瀏覽代碼

test: [utils] add TestFeature fixtures

Gets utils coverage up to 100% (for now)
main
Rob Hallam 2 月之前
父節點
當前提交
6cb0d723ac
共有 1 個檔案被更改,包括 111 行新增0 行删除
  1. +111
    -0
      test/test_utils.py

+ 111
- 0
test/test_utils.py 查看文件

@@ -245,6 +245,117 @@ class TestInterval(unittest.TestCase):
with self.assertRaises(ValueError):
utils.Interval(end=0, duration=-10)

class TestFeature(unittest.TestCase):
"""Test the Feature class"""

# happy path tests

def test_init(self):
"""Test creation of a Feature object.

Needs: interval, source, feature_extractor and score"""
source = MockSource("source")
interval = MockInterval(0, 10)
feature_extractor = "feature_extractor"
score = 0.5
feature = utils.Feature(interval, source, feature_extractor, score)
# NOTE: LSP complains about assign MockSource to source, but it's fine
self.assertEqual(feature.interval, interval)
self.assertEqual(feature.source, source)
self.assertEqual(feature.feature_extractor, feature_extractor)
self.assertEqual(feature.score, score)

def test_repr(self):
source = MockSource("source")
interval = MockInterval(0, 10)
feature_extractor = "test"
score = 0.5
feature = utils.Feature(interval, source, feature_extractor, score)
self.assertEqual(repr(feature), f"Feature({interval}, {source}, {feature_extractor}, {score})")

def test_to_json(self):
"""test Feature.to_json method"""
source = MockSource("source")
interval = MockInterval(0, 10)
feature_extractor = "test"
score = 0.5
feature = utils.Feature(interval, source, feature_extractor, score)
self.assertEqual(feature.to_json(), {"interval": interval.to_json(), "source": source.to_json(),
"feature_extractor": feature_extractor, "score": score})

def test_classmethod_from_start(self):
"""Test the @classmethod for creating Feature with Interval.from_start"""
source = MockSource("source")
feature_extractor = "test"
score = 0.5
feature = utils.Feature.from_start(start=0, source=source, feature_extractor=feature_extractor, score=score)
self.assertEqual(feature.interval.start, 0)
self.assertEqual(feature.interval.end, 5)
self.assertEqual(feature.interval.duration, 5)
self.assertEqual(feature.source, source)
self.assertEqual(feature.feature_extractor, feature_extractor)
self.assertEqual(feature.score, score)

def test_classmethod_from_end(self):
"""Test the @classmethod for creating Feature with Interval.from_end"""
source = MockSource("source")
feature_extractor = "test"
score = 0.5
feature = utils.Feature.from_end(end=10, source=source, feature_extractor=feature_extractor, score=score)
self.assertEqual(feature.interval.start, 5)
self.assertEqual(feature.interval.end, 10)
self.assertEqual(feature.interval.duration, 5)
self.assertEqual(feature.source, source)
self.assertEqual(feature.feature_extractor, feature_extractor)
self.assertEqual(feature.score, score)

def test_lt(self):
"""test __lt__ method for sorting Feature objects"""
# TODO: ensure goood coverage of corresponding __lt__ method of Interval
feature1 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
feature2 = utils.Feature(utils.Interval(0, 15), MockSource("source"), "test_1", 0.5)
feature3 = utils.Feature(utils.Interval(5, 10), MockSource("source"), "test_1", 0.5)
feature4 = utils.Feature(utils.Interval(5, 15), MockSource("source"), "test_1", 0.5)
feature5 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.6) # score
feature6 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_2", 0.5) # feature_extractor

# test sorting based on Interval
self.assertTrue(feature1 < feature2)
self.assertTrue(feature1 < feature3)
self.assertTrue(feature3 < feature4)

# test sorting based on feature_extractor
self.assertTrue(feature1 < feature6)

# test sorting based on score
self.assertTrue(feature1 < feature5)

def test_eq(self):
"""test __eq__ method for comparing Feature objects"""
feature1 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
feature2 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.5)
feature3 = utils.Feature(utils.Interval(0, 10), MockSource("source"), "test_1", 0.6)
self.assertEqual(feature1, feature2)
self.assertNotEqual(feature1, feature3)

# unhappy path tests

def test_init_unhappy(self):
"""Test init with pats missing"""
with self.assertRaises(ValueError):
utils.Feature(None, None, None, None)

# missing interval
with self.assertRaises(ValueError):
utils.Feature(None, MockSource("source"), "feature_extractor", 0.5)

# missing source
with self.assertRaises(ValueError):
utils.Feature(MockInterval(0, 10), None, "feature_extractor", 0.5)

# missing feature_extractor
feature = utils.Feature(MockInterval(0, 10), MockSource("source"), None, 0.5)
self.assertEqual(feature.feature_extractor, "unknown")

if __name__ == "__main__":
unittest.main()

Loading…
取消
儲存