Non puoi selezionare più di 25 argomenti Gli argomenti devono iniziare con una lettera o un numero, possono includere trattini ('-') e possono essere lunghi fino a 35 caratteri.

test_utils.py 15 KiB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. """Test cases for utils"""
  2. import unittest
  3. import pipeline.utils as utils
  4. class MockSource():
  5. """Mock Source object for testing Feature"""
  6. def __init__(self, source):
  7. self.source = source
  8. def __str__(self):
  9. return self.source
  10. def to_json(self):
  11. return {"source": self.source}
  12. def __eq__(self, other):
  13. return self.source == other.source
  14. class MockInterval():
  15. """Mock Interval object for testing Feature"""
  16. def __init__(self, start, end):
  17. self.start = start
  18. self.end = end
  19. def to_json(self):
  20. return {"start": self.start, "end": self.end}
  21. class TestSource(unittest.TestCase):
  22. """Source is a container for source, path, and provider of a media file
  23. source -- the source of the media file (eg, a URL or a local path)
  24. path -- the path to the media file
  25. provider -- the provider of the media file (eg, "FileInputJSON")
  26. Accessing the object should return the path to the media file.
  27. Methods:
  28. duration() -- return the duration of the media file (uses ffprobe, result is cached)
  29. """
  30. # Happy path tests
  31. def setUp(self):
  32. self.source = "audio_clips/testclip-5min.aac"
  33. self.path = "audio_clips/testclip-5min.aac"
  34. self.provider = "FileInputJSON"
  35. self._duration = 306.121214 # duration of testclip-5min.aac
  36. def test_init(self):
  37. source = utils.Source(self.source, self.path, self.provider)
  38. self.assertEqual(source.source, self.source)
  39. self.assertEqual(source.path, self.path)
  40. self.assertEqual(source.provider, self.provider)
  41. def test_str(self):
  42. """Accessing the object should return the path to the media file"""
  43. source = utils.Source(self.source, self.path, self.provider)
  44. self.assertEqual(str(source), self.path)
  45. def test_repr(self):
  46. source = utils.Source(self.source, self.path, self.provider)
  47. self.assertEqual(repr(source), f"Source({self.source}, {self.path}, {self.provider})")
  48. def test_duration(self):
  49. """Use a mock duration of 306.121214 for the test clip"""
  50. source = utils.Source(self.source, self.path, self.provider)
  51. self.assertEqual(source.duration(), self._duration)
  52. def test_to_json(self):
  53. source = utils.Source(self.source, self.path, self.provider)
  54. self.assertEqual(source.to_json(), {"source": self.source, "path": self.path, "provider": self.provider})
  55. # Sad path tests
  56. def test_init_no_source(self):
  57. with self.assertRaises(ValueError):
  58. utils.Source("", self.path, self.provider)
  59. def test_init_no_path(self):
  60. with self.assertRaises(ValueError):
  61. utils.Source(self.source, "", self.provider)
  62. def test_init_no_provider(self):
  63. with self.assertRaises(ValueError):
  64. utils.Source(self.source, self.path, "")
  65. def test_duration_no_file(self):
  66. """Test that duration raises FileNotFoundError if the file does not exist"""
  67. source = utils.Source(self.source, "fakepath-noexist!😇", self.provider)
  68. # if that file actually exists I'll eat my hat
  69. with self.assertRaises(FileNotFoundError):
  70. source.duration()
  71. class TestSourceMedia(unittest.TestCase):
  72. """SourceMedia is a container for Source objects"""
  73. class FakeSource():
  74. """Fake source object for testing SourceMedia
  75. Since SourceMedia doesn't actually access any of the attributes of Source objects,
  76. we can use a very empty class for testing.
  77. """
  78. def setUp(self):
  79. self.source1 = self.FakeSource()
  80. self.source2 = self.FakeSource()
  81. self.source3 = self.FakeSource()
  82. self.sources = [self.source1, self.source2, self.source3]
  83. def test_init(self):
  84. source_media = utils.SourceMedia(sources=self.sources)
  85. self.assertEqual(source_media.sources, self.sources)
  86. def test_iter(self):
  87. source_media = utils.SourceMedia(sources=self.sources)
  88. self.assertEqual(list(source_media), self.sources)
  89. class TestInterval(unittest.TestCase):
  90. """Interval is a container for start, end and duration attributes"""
  91. def setUp(self):
  92. self.start = 0
  93. self.end = 10
  94. self.duration = 10
  95. # Happy path tests
  96. def test_init_start_end(self):
  97. """Create an Interval using start and end attributes"""
  98. interval = utils.Interval(start=self.start, end=self.end)
  99. self.assertEqual(interval.start, self.start)
  100. self.assertEqual(interval.end, self.end)
  101. self.assertEqual(interval.duration, self.duration)
  102. def test_init_start_duration(self):
  103. """Create an Interval using start and duration attributes"""
  104. interval = utils.Interval(start=self.start, duration=self.duration)
  105. self.assertEqual(interval.start, self.start)
  106. self.assertEqual(interval.end, self.end)
  107. self.assertEqual(interval.duration, self.duration)
  108. def test_init_end_duration(self):
  109. """Create an Interval using end and duration attributes"""
  110. interval = utils.Interval(end=self.end, duration=self.duration)
  111. self.assertEqual(interval.start, self.start)
  112. self.assertEqual(interval.end, self.end)
  113. self.assertEqual(interval.duration, self.duration)
  114. def test_from_start_classmethod(self):
  115. """Create an Interval using the from_start classmethod (ie start attribute only- uses default duration)"""
  116. interval = utils.Interval.from_start(start=self.start)
  117. self.assertEqual(interval.start, self.start)
  118. self.assertEqual(interval.end, interval.start + utils.Interval.DEFAULT_DURATION)
  119. self.assertEqual(interval.duration, utils.Interval.DEFAULT_DURATION)
  120. def test_from_end_classmethod(self):
  121. """Create an Interval using the from_end classmethod (ie end attribute only- uses default duration)"""
  122. interval = utils.Interval.from_end(end=self.end)
  123. self.assertEqual(interval.start, interval.end - utils.Interval.DEFAULT_DURATION)
  124. self.assertEqual(interval.end, self.end)
  125. self.assertEqual(interval.duration, utils.Interval.DEFAULT_DURATION)
  126. def test_repr(self):
  127. interval = utils.Interval(start=self.start, end=self.end)
  128. self.assertEqual(repr(interval), f"Interval({self.start}, {self.end}, {self.duration})")
  129. def test_lt(self):
  130. """Test the __lt__ method used for sorting Interval objects based on start time
  131. If the start times are equal, the interval with the smaller end time is considered smaller
  132. """
  133. interval1 = utils.Interval(start=0, end=10)
  134. interval2 = utils.Interval(start=0, end=15)
  135. interval3 = utils.Interval(start=5, end=10)
  136. interval4 = utils.Interval(start=5, end=15)
  137. self.assertTrue(interval1 < interval2) # same start, interval1 has smaller end
  138. self.assertTrue(interval1 < interval3) # interval1 start is smaller
  139. self.assertTrue(interval3 > interval2) # interval3 start is larger
  140. self.assertTrue(interval3 < interval4) # same start, interval3 has smaller end
  141. def test_eq(self):
  142. """Test the __eq__ method for comparing Interval objects"""
  143. interval1 = utils.Interval(start=0, end=10)
  144. interval2 = utils.Interval(start=0, end=10)
  145. interval3 = utils.Interval(start=0, end=15)
  146. self.assertEqual(interval1, interval2)
  147. self.assertNotEqual(interval1, interval3)
  148. def test_to_json(self):
  149. """Test to_json method"""
  150. interval = utils.Interval(start=self.start, end=self.end)
  151. self.assertEqual(interval.to_json(), {"start": self.start, "end": self.end, "duration": self.duration})
  152. def test_move_start(self):
  153. """Test the move_start method - changes start time to time specified, keeps end time constant"""
  154. interval = utils.Interval(start=self.start, end=self.end)
  155. interval.move_start(5)
  156. self.assertEqual(interval.start, 5)
  157. self.assertEqual(interval.end, 10)
  158. self.assertEqual(interval.duration, 5)
  159. def test_move_start_relative(self):
  160. """Test the move_start method with relative=True - changes start time by a relative amount"""
  161. interval = utils.Interval(start=self.start, end=self.end)
  162. interval.move_start(2, relative=True)
  163. self.assertEqual(interval.start, 2)
  164. self.assertEqual(interval.end, 10)
  165. self.assertEqual(interval.duration, 8)
  166. def test_move_end(self):
  167. """Test the move_end method - changes end time to time specified, keeps start time constant"""
  168. interval = utils.Interval(start=self.start, end=self.end)
  169. interval.move_end(15)
  170. self.assertEqual(interval.start, 0)
  171. self.assertEqual(interval.end, 15)
  172. self.assertEqual(interval.duration, 15)
  173. def test_update_duration(self):
  174. """Test the update_duration method - changes duration to time specified, keeps start time constant"""
  175. interval = utils.Interval(start=self.start, end=self.end)
  176. interval.update_duration(15)
  177. self.assertEqual(interval.start, 0)
  178. self.assertEqual(interval.end, 15)
  179. self.assertEqual(interval.duration, 15)
  180. # test with relative=True
  181. interval.update_duration(5, relative=True)
  182. self.assertEqual(interval.start, 0)
  183. # Unhappy path tests
  184. def test_init_no_start_end(self):
  185. with self.assertRaises(ValueError):
  186. utils.Interval()
  187. def test_init_start_end_duration(self):
  188. with self.assertRaises(ValueError):
  189. utils.Interval(start=self.start, end=self.end, duration=self.duration)
  190. def test_init_start_after_end(self):
  191. with self.assertRaises(ValueError):
  192. utils.Interval(start=10, end=0)
  193. def test_init_negative_duration(self):
  194. with self.assertRaises(ValueError):
  195. utils.Interval(start=0, duration=-10)
  196. with self.assertRaises(ValueError):
  197. utils.Interval(end=0, duration=-10)
  198. class TestFeature(unittest.TestCase):
  199. """Test the Feature class"""
  200. # happy path tests
  201. def test_init(self):
  202. """Test creation of a Feature object.
  203. Needs: interval, source, feature_extractor and score"""
  204. source = MockSource("source")
  205. interval = MockInterval(0, 10)
  206. feature_extractor = "feature_extractor"
  207. score = 0.5
  208. feature = utils.Feature(interval, source, feature_extractor, score)
  209. # NOTE: LSP complains about assign MockSource to source, but it's fine
  210. self.assertEqual(feature.interval, interval)
  211. self.assertEqual(feature.source, source)
  212. self.assertEqual(feature.feature_extractor, feature_extractor)
  213. self.assertEqual(feature.score, score)
  214. def test_repr(self):
  215. source = MockSource("source")
  216. interval = MockInterval(0, 10)
  217. feature_extractor = "test"
  218. score = 0.5
  219. feature = utils.Feature(interval, source, feature_extractor, score)
  220. self.assertEqual(repr(feature), f"Feature({interval}, {source}, {feature_extractor}, {score})")
  221. def test_to_json(self):
  222. """test Feature.to_json method"""
  223. source = MockSource("source")
  224. interval = MockInterval(0, 10)
  225. feature_extractor = "test"
  226. score = 0.5
  227. feature = utils.Feature(interval, source, feature_extractor, score)
  228. self.assertEqual(feature.to_json(), {"interval": interval.to_json(), "source": source.to_json(),
  229. "feature_extractor": feature_extractor, "score": score})
  230. def test_classmethod_from_start(self):
  231. """Test the @classmethod for creating Feature with Interval.from_start"""
  232. source = MockSource("source")
  233. feature_extractor = "test"
  234. score = 0.5
  235. feature = utils.Feature.from_start(start=0, source=source, feature_extractor=feature_extractor, score=score)
  236. self.assertEqual(feature.interval.start, 0)
  237. self.assertEqual(feature.interval.end, 5)
  238. self.assertEqual(feature.interval.duration, 5)
  239. self.assertEqual(feature.source, source)
  240. self.assertEqual(feature.feature_extractor, feature_extractor)
  241. self.assertEqual(feature.score, score)
  242. def test_classmethod_from_end(self):
  243. """Test the @classmethod for creating Feature with Interval.from_end"""
  244. source = MockSource("source")
  245. feature_extractor = "test"
  246. score = 0.5
  247. feature = utils.Feature.from_end(end=10, source=source, feature_extractor=feature_extractor, score=score)
  248. self.assertEqual(feature.interval.start, 5)
  249. self.assertEqual(feature.interval.end, 10)
  250. self.assertEqual(feature.interval.duration, 5)
  251. self.assertEqual(feature.source, source)
  252. self.assertEqual(feature.feature_extractor, feature_extractor)
  253. self.assertEqual(feature.score, score)
  254. def test_lt(self):
  255. """test __lt__ method for sorting Feature objects"""
  256. # TODO: ensure goood coverage of corresponding __lt__ method of Interval
  257. feature1 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
  258. feature2 = utils.Feature(utils.Interval(0, 15), MockSource("source"), "test_1", 0.5)
  259. feature3 = utils.Feature(utils.Interval(5, 10), MockSource("source"), "test_1", 0.5)
  260. feature4 = utils.Feature(utils.Interval(5, 15), MockSource("source"), "test_1", 0.5)
  261. feature5 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.6) # score
  262. feature6 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_2", 0.5) # feature_extractor
  263. # test sorting based on Interval
  264. self.assertTrue(feature1 < feature2)
  265. self.assertTrue(feature1 < feature3)
  266. self.assertTrue(feature3 < feature4)
  267. # test sorting based on feature_extractor
  268. self.assertTrue(feature1 < feature6)
  269. # test sorting based on score
  270. self.assertTrue(feature1 < feature5)
  271. def test_eq(self):
  272. """test __eq__ method for comparing Feature objects"""
  273. feature1 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
  274. feature2 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
  275. feature3 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.6)
  276. self.assertEqual(feature1, feature2)
  277. self.assertNotEqual(feature1, feature3)
  278. # unhappy path tests
  279. def test_init_unhappy(self):
  280. """Test init with pats missing"""
  281. with self.assertRaises(ValueError):
  282. utils.Feature(None, None, None, None)
  283. # missing interval
  284. with self.assertRaises(ValueError):
  285. utils.Feature(None, MockSource("source"), "feature_extractor", 0.5)
  286. # missing source
  287. with self.assertRaises(ValueError):
  288. utils.Feature(MockInterval(0, 10), None, "feature_extractor", 0.5)
  289. # missing feature_extractor
  290. feature = utils.Feature(MockInterval(0, 10), MockSource("source"), None, 0.5)
  291. self.assertEqual(feature.feature_extractor, "unknown")
  292. if __name__ == "__main__":
  293. unittest.main()