You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

393 lines
16 KiB

  1. """Test cases for utils"""
  2. import unittest
  3. import pipeline.utils as utils
  4. class MockSource():
  5. """Mock Source object for testing Feature"""
  6. def __init__(self, source):
  7. self.source = source
  8. def __str__(self):
  9. return self.source
  10. def to_json(self):
  11. return {"source": self.source}
  12. def __eq__(self, other):
  13. return self.source == other.source
  14. class MockInterval():
  15. """Mock Interval object for testing Feature"""
  16. def __init__(self, start, end):
  17. self.start = start
  18. self.end = end
  19. def to_json(self):
  20. return {"start": self.start, "end": self.end}
  21. class TestSource(unittest.TestCase):
  22. """Source is a container for source, path, and provider of a media file
  23. source -- the source of the media file (eg, a URL or a local path)
  24. path -- the path to the media file
  25. provider -- the provider of the media file (eg, "FileInputJSON")
  26. Accessing the object should return the path to the media file.
  27. Methods:
  28. duration() -- return the duration of the media file (uses ffprobe, result is cached)
  29. """
  30. # Happy path tests
  31. def setUp(self):
  32. self.source = "audio_clips/testclip-5min.aac"
  33. self.path = "audio_clips/testclip-5min.aac"
  34. self.provider = "FileInputJSON"
  35. self._duration = 306.121214 # duration of testclip-5min.aac
  36. def test_init(self):
  37. source = utils.Source(self.source, self.path, self.provider)
  38. self.assertEqual(source.source, self.source)
  39. self.assertEqual(source.path, self.path)
  40. self.assertEqual(source.provider, self.provider)
  41. def test_str(self):
  42. """Accessing the object should return the path to the media file"""
  43. source = utils.Source(self.source, self.path, self.provider)
  44. self.assertEqual(str(source), self.path)
  45. def test_repr(self):
  46. source = utils.Source(self.source, self.path, self.provider)
  47. self.assertEqual(repr(source), f"Source({self.source}, {self.path}, {self.provider})")
  48. def test_duration(self):
  49. """Use a mock duration of 306.121214 for the test clip"""
  50. source = utils.Source(self.source, self.path, self.provider)
  51. self.assertEqual(source.duration(), self._duration)
  52. def test_to_json(self):
  53. source = utils.Source(self.source, self.path, self.provider)
  54. self.assertEqual(source.to_json(), {"source": self.source, "path": self.path, "provider": self.provider})
  55. # Sad path tests
  56. def test_init_no_source(self):
  57. with self.assertRaises(ValueError):
  58. utils.Source("", self.path, self.provider)
  59. def test_init_no_path(self):
  60. with self.assertRaises(ValueError):
  61. utils.Source(self.source, "", self.provider)
  62. def test_init_no_provider(self):
  63. with self.assertRaises(ValueError):
  64. utils.Source(self.source, self.path, "")
  65. def test_duration_no_file(self):
  66. """Test that duration raises FileNotFoundError if the file does not exist"""
  67. source = utils.Source(self.source, "fakepath-noexist!😇", self.provider)
  68. # if that file actually exists I'll eat my hat
  69. with self.assertRaises(FileNotFoundError):
  70. source.duration()
  71. class TestSourceMedia(unittest.TestCase):
  72. """SourceMedia is a container for Source objects"""
  73. class FakeSource():
  74. """Fake source object for testing SourceMedia
  75. Since SourceMedia doesn't actually access any of the attributes of Source objects,
  76. we can use a very empty class for testing.
  77. """
  78. def setUp(self):
  79. self.source1 = self.FakeSource()
  80. self.source2 = self.FakeSource()
  81. self.source3 = self.FakeSource()
  82. self.sources = [self.source1, self.source2, self.source3]
  83. def test_init(self):
  84. source_media = utils.SourceMedia(sources=self.sources)
  85. self.assertEqual(source_media.sources, self.sources)
  86. def test_iter(self):
  87. source_media = utils.SourceMedia(sources=self.sources)
  88. self.assertEqual(list(source_media), self.sources)
  89. class TestInterval(unittest.TestCase):
  90. """Interval is a container for start, end and duration attributes"""
  91. def setUp(self):
  92. self.start = 0
  93. self.end = 10
  94. self.duration = 10
  95. # Happy path tests
  96. def test_init_start_end(self):
  97. """Create an Interval using start and end attributes"""
  98. interval = utils.Interval(start=self.start, end=self.end)
  99. self.assertEqual(interval.start, self.start)
  100. self.assertEqual(interval.end, self.end)
  101. self.assertEqual(interval.duration, self.duration)
  102. def test_init_start_duration(self):
  103. """Create an Interval using start and duration attributes"""
  104. interval = utils.Interval(start=self.start, duration=self.duration)
  105. self.assertEqual(interval.start, self.start)
  106. self.assertEqual(interval.end, self.end)
  107. self.assertEqual(interval.duration, self.duration)
  108. def test_init_end_duration(self):
  109. """Create an Interval using end and duration attributes"""
  110. interval = utils.Interval(end=self.end, duration=self.duration)
  111. self.assertEqual(interval.start, self.start)
  112. self.assertEqual(interval.end, self.end)
  113. self.assertEqual(interval.duration, self.duration)
  114. def test_from_start_classmethod(self):
  115. """Create an Interval using the from_start classmethod (ie start attribute only- uses default duration)"""
  116. interval = utils.Interval.from_start(start=self.start)
  117. self.assertEqual(interval.start, self.start)
  118. self.assertEqual(interval.end, interval.start + utils.Interval.DEFAULT_DURATION)
  119. self.assertEqual(interval.duration, utils.Interval.DEFAULT_DURATION)
  120. def test_from_end_classmethod(self):
  121. """Create an Interval using the from_end classmethod (ie end attribute only- uses default duration)"""
  122. interval = utils.Interval.from_end(end=self.end)
  123. self.assertEqual(interval.start, interval.end - utils.Interval.DEFAULT_DURATION)
  124. self.assertEqual(interval.end, self.end)
  125. self.assertEqual(interval.duration, utils.Interval.DEFAULT_DURATION)
  126. def test_repr(self):
  127. interval = utils.Interval(start=self.start, end=self.end)
  128. self.assertEqual(repr(interval), f"Interval({self.start}, {self.end}, {self.duration})")
  129. def test_lt(self):
  130. """Test the __lt__ method used for sorting Interval objects based on start time
  131. If the start times are equal, the interval with the smaller end time is considered smaller
  132. """
  133. interval1 = utils.Interval(start=0, end=10)
  134. interval2 = utils.Interval(start=0, end=15)
  135. interval3 = utils.Interval(start=5, end=10)
  136. interval4 = utils.Interval(start=5, end=15)
  137. self.assertTrue(interval1 < interval2) # same start, interval1 has smaller end
  138. self.assertTrue(interval1 < interval3) # interval1 start is smaller
  139. self.assertTrue(interval3 > interval2) # interval3 start is larger
  140. self.assertTrue(interval3 < interval4) # same start, interval3 has smaller end
  141. def test_eq(self):
  142. """Test the __eq__ method for comparing Interval objects"""
  143. interval1 = utils.Interval(start=0, end=10)
  144. interval2 = utils.Interval(start=0, end=10)
  145. interval3 = utils.Interval(start=0, end=15)
  146. self.assertEqual(interval1, interval2)
  147. self.assertNotEqual(interval1, interval3)
  148. def test_to_json(self):
  149. """Test to_json method"""
  150. interval = utils.Interval(start=self.start, end=self.end)
  151. self.assertEqual(interval.to_json(), {"start": self.start, "end": self.end, "duration": self.duration})
  152. def test_move_start(self):
  153. """Test the move_start method - changes start time to time specified, keeps end time constant"""
  154. interval = utils.Interval(start=self.start, end=self.end)
  155. interval.move_start(5)
  156. self.assertEqual(interval.start, 5)
  157. self.assertEqual(interval.end, 10)
  158. self.assertEqual(interval.duration, 5)
  159. def test_move_start_relative(self):
  160. """Test the move_start method with relative=True - changes start time by a relative amount"""
  161. interval = utils.Interval(start=self.start, end=self.end)
  162. interval.move_start(2, relative=True)
  163. self.assertEqual(interval.start, 2)
  164. self.assertEqual(interval.end, 10)
  165. self.assertEqual(interval.duration, 8)
  166. def test_move_end(self):
  167. """Test the move_end method - changes end time to time specified, keeps start time constant"""
  168. with self.subTest("non-relative"):
  169. interval = utils.Interval(start=self.start, end=self.end)
  170. interval.move_end(15)
  171. self.assertEqual(interval.start, 0)
  172. self.assertEqual(interval.end, 15)
  173. self.assertEqual(interval.duration, 15)
  174. with self.subTest("relative"):
  175. interval = utils.Interval(start=self.start, end=self.end)
  176. interval.move_end(5, relative=True)
  177. self.assertEqual(interval.start, 0)
  178. self.assertEqual(interval.end, 15)
  179. self.assertEqual(interval.duration, 15)
  180. def test_update_duration(self):
  181. """Test the update_duration method - changes duration to time specified, keeps start time constant"""
  182. interval = utils.Interval(start=self.start, end=self.end)
  183. interval.update_duration(15)
  184. self.assertEqual(interval.start, 0)
  185. self.assertEqual(interval.end, 15)
  186. self.assertEqual(interval.duration, 15)
  187. # test with relative=True
  188. interval.update_duration(5, relative=True)
  189. self.assertEqual(interval.start, 0)
  190. def test_overlaps(self):
  191. """Test the overlaps method - returns True if the interval overlaps with another interval"""
  192. interval1 = utils.Interval(start=0, end=10)
  193. interval2 = utils.Interval(start=5, end=15) # overlaps with interval1
  194. interval3 = utils.Interval(start=15, end=20) # does not overlap with interval1
  195. interval4 = utils.Interval(start=10, end=15) # touch overlap with interval1
  196. # test with overlapping interval
  197. self.assertTrue(interval1.overlaps(interval2))
  198. self.assertTrue(interval2.overlaps(interval1))
  199. # test with non-overlapping interval
  200. self.assertFalse(interval1.overlaps(interval3))
  201. self.assertFalse(interval3.overlaps(interval1))
  202. # test with touching interval
  203. self.assertTrue(interval1.overlaps(interval4))
  204. self.assertTrue(interval4.overlaps(interval1))
  205. # Unhappy path tests
  206. def test_init_no_start_end(self):
  207. with self.assertRaises(ValueError):
  208. utils.Interval()
  209. def test_init_start_end_duration(self):
  210. with self.assertRaises(ValueError):
  211. utils.Interval(start=self.start, end=self.end, duration=self.duration)
  212. def test_init_start_after_end(self):
  213. with self.assertRaises(ValueError):
  214. utils.Interval(start=10, end=0)
  215. def test_init_negative_duration(self):
  216. with self.assertRaises(ValueError):
  217. utils.Interval(start=0, duration=-10)
  218. with self.assertRaises(ValueError):
  219. utils.Interval(end=0, duration=-10)
  220. class TestFeature(unittest.TestCase):
  221. """Test the Feature class"""
  222. # happy path tests
  223. def test_init(self):
  224. """Test creation of a Feature object.
  225. Needs: interval, source, feature_extractor and score"""
  226. source = MockSource("source")
  227. interval = MockInterval(0, 10)
  228. feature_extractor = "feature_extractor"
  229. score = 0.5
  230. feature = utils.Feature(interval, source, feature_extractor, score)
  231. # NOTE: LSP complains about assign MockSource to source, but it's fine
  232. self.assertEqual(feature.interval, interval)
  233. self.assertEqual(feature.source, source)
  234. self.assertEqual(feature.feature_extractor, feature_extractor)
  235. self.assertEqual(feature.score, score)
  236. def test_repr(self):
  237. source = MockSource("source")
  238. interval = MockInterval(0, 10)
  239. feature_extractor = "test"
  240. score = 0.5
  241. feature = utils.Feature(interval, source, feature_extractor, score)
  242. self.assertEqual(repr(feature), f"Feature({interval}, {source}, {feature_extractor}, {score})")
  243. def test_to_json(self):
  244. """test Feature.to_json method"""
  245. source = MockSource("source")
  246. interval = MockInterval(0, 10)
  247. feature_extractor = "test"
  248. score = 0.5
  249. feature = utils.Feature(interval, source, feature_extractor, score)
  250. self.assertEqual(feature.to_json(), {"interval": interval.to_json(), "source": source.to_json(),
  251. "feature_extractor": feature_extractor, "score": score})
  252. def test_classmethod_from_start(self):
  253. """Test the @classmethod for creating Feature with Interval.from_start"""
  254. source = MockSource("source")
  255. feature_extractor = "test"
  256. score = 0.5
  257. feature = utils.Feature.from_start(start=0, source=source, feature_extractor=feature_extractor, score=score)
  258. self.assertEqual(feature.interval.start, 0)
  259. self.assertEqual(feature.interval.end, 5)
  260. self.assertEqual(feature.interval.duration, 5)
  261. self.assertEqual(feature.source, source)
  262. self.assertEqual(feature.feature_extractor, feature_extractor)
  263. self.assertEqual(feature.score, score)
  264. def test_classmethod_from_end(self):
  265. """Test the @classmethod for creating Feature with Interval.from_end"""
  266. source = MockSource("source")
  267. feature_extractor = "test"
  268. score = 0.5
  269. feature = utils.Feature.from_end(end=10, source=source, feature_extractor=feature_extractor, score=score)
  270. self.assertEqual(feature.interval.start, 5)
  271. self.assertEqual(feature.interval.end, 10)
  272. self.assertEqual(feature.interval.duration, 5)
  273. self.assertEqual(feature.source, source)
  274. self.assertEqual(feature.feature_extractor, feature_extractor)
  275. self.assertEqual(feature.score, score)
  276. def test_lt(self):
  277. """test __lt__ method for sorting Feature objects"""
  278. # TODO: ensure goood coverage of corresponding __lt__ method of Interval
  279. feature1 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
  280. feature2 = utils.Feature(utils.Interval(0, 15), MockSource("source"), "test_1", 0.5)
  281. feature3 = utils.Feature(utils.Interval(5, 10), MockSource("source"), "test_1", 0.5)
  282. feature4 = utils.Feature(utils.Interval(5, 15), MockSource("source"), "test_1", 0.5)
  283. feature5 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.6) # score
  284. feature6 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_2", 0.5) # feature_extractor
  285. # test sorting based on Interval
  286. self.assertTrue(feature1 < feature2)
  287. self.assertTrue(feature1 < feature3)
  288. self.assertTrue(feature3 < feature4)
  289. # test sorting based on feature_extractor
  290. self.assertTrue(feature1 < feature6)
  291. # test sorting based on score
  292. self.assertTrue(feature1 < feature5)
  293. def test_eq(self):
  294. """test __eq__ method for comparing Feature objects"""
  295. feature1 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
  296. feature2 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
  297. feature3 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.6)
  298. self.assertEqual(feature1, feature2)
  299. self.assertNotEqual(feature1, feature3)
  300. def test_feature_no_score(self):
  301. """Test creating a Feature without a score"""
  302. source = MockSource("source")
  303. interval = MockInterval(0, 10)
  304. feature_extractor = "test"
  305. feature = utils.Feature(interval, source, feature_extractor)
  306. self.assertEqual(feature.score, 0.0)
  307. # unhappy path tests
  308. def test_init_unhappy(self):
  309. """Test init with pats missing"""
  310. with self.assertRaises(ValueError):
  311. utils.Feature(None, None, None, None)
  312. # missing interval
  313. with self.assertRaises(ValueError):
  314. utils.Feature(None, MockSource("source"), "feature_extractor", 0.5)
  315. # missing source
  316. with self.assertRaises(ValueError):
  317. utils.Feature(MockInterval(0, 10), None, "feature_extractor", 0.5)
  318. # missing feature_extractor
  319. feature = utils.Feature(MockInterval(0, 10), MockSource("source"), None, 0.5)
  320. self.assertEqual(feature.feature_extractor, "unknown")
  321. if __name__ == "__main__":
  322. unittest.main()