diff --git a/pipeline/feature_extractors.py b/pipeline/feature_extractors.py index 0aed58e..65825ad 100644 --- a/pipeline/feature_extractors.py +++ b/pipeline/feature_extractors.py @@ -407,6 +407,29 @@ class JSONFeatureExtractor(FeatureExtractor): class WordFeatureExtractor(FeatureExtractor): """Feature extractor for specific word detection (uses Whisper)""" + # set defaults for whisper settings + DEFAULT_MODEL_SIZE = "medium" + DEFAULT_DEVICE = "cpu" + DEFAULT_COMPUTE_TYPE = "int8" + DEFAULT_BEAM_SIZE = 5 + DEFAULT_BATCH_SIZE = 16 + DEFAULT_PIPELINE_TYPE = "batched" # or "stream" + + words = [] + + def _transcribe(self, model, file, **kwargs): + """Defined here to allow for mocking in tests""" + return model.transcribe(file, **kwargs) + + def _whispermodel(self, model_size=DEFAULT_MODEL_SIZE, + device=DEFAULT_DEVICE, compute_type=DEFAULT_COMPUTE_TYPE): + """Defined here to allow for mocking out in tests""" + return WhisperModel(model_size, device=device, compute_type=compute_type) + + def _batched_inference_pipeline(self, model): + """Defined here to allow for mocking out in tests""" + return BatchedInferencePipeline(model=model) + def __init__(self, input_files=None, config=None): if not input_files: raise ValueError("No input files provided!") diff --git a/test/test_feature_extractors.py b/test/test_feature_extractors.py index edf24d3..796c4d9 100644 --- a/test/test_feature_extractors.py +++ b/test/test_feature_extractors.py @@ -1,4 +1,8 @@ """test_feature_extractors.py - test pipeline feature extractors""" +import sys +from unittest.mock import patch, mock_open, MagicMock # +sys.modules["faster_whisper"] = MagicMock() # mock faster_whisper as it is a slow import + import unittest import os import random @@ -6,7 +10,6 @@ import pytest import pipeline.feature_extractors as extractors from pipeline.utils import Source, SourceMedia # technically makes this an integration test, but... 
-from unittest.mock import patch, mock_open class TestSource(): """Provide utils.Source for testing""" @@ -282,3 +285,70 @@ class TestJSONFeatureExtractor(unittest.TestCase): m = unittest.mock.mock_open(read_data='[{"foo": "bar"}]') with unittest.mock.patch("builtins.open", m): test_extractor._read_json_from_file("foo.json") + + +class TestWordFeatureExtractor(unittest.TestCase): + """Test WordFeatureExtractor""" + + @classmethod + def setUpClass(cls): + sys.modules["faster_whisper"] = MagicMock() + + _MOCK_SENTENCE = "the quick brown fox jumps over the lazy dog".split() + class MockSegment(): + """Mock Segment -- has start, end and text attributes""" + def __init__(self, start, end, text): + self.start = start + self.end = end + self.text = text + + def mock_transcribe(self, *args, **kwargs): + """Mock for WhisperModel.model.transcribe + + returns a 2-tuple: + - list of segments + + segment = start, end, text + - info = language, language_probability + + We will mock the segments - this provides 9 segments for the sentence: + "the quick brown fox jumps over the lazy dog" + """ + segments = [] + for i in range(len(self._MOCK_SENTENCE)): + segments.append(self.MockSegment(i, i+1, self._MOCK_SENTENCE[i])) + return segments, {"language": "en", "language_probability": 0.9} + + def test_basic_init(self): + video_source = TestSourceMedia().one_colour_silent_audio() + test_extractor = extractors.WordFeatureExtractor(input_files=video_source) + self.assertTrue(test_extractor) + + def test_init_no_input_videos(self): + """test init - no input files""" + with self.assertRaises(ValueError): + test_extractor = extractors.WordFeatureExtractor() + + def test_extract_no_words_supplied(self): + """Test extract with basic input file but no words specified returns zero features""" + video_source = TestSourceMedia().one_colour_silent_audio() + test_extractor = extractors.WordFeatureExtractor(input_files=video_source) + test_extractor.setup() + test_extractor.run() 
test_extractor.teardown() + self.assertEqual(test_extractor.features, []) + + def test_extract_mocked_transcribe(self): + """Mock out the actual call to transcribe""" + video_source = TestSourceMedia().one_colour_silent_audio() + test_extractor = extractors.WordFeatureExtractor(input_files=video_source) + # mock _transcribe and mock out model and batched pipeline for speed + test_extractor._transcribe = self.mock_transcribe + test_extractor._model = MagicMock() + test_extractor._batched_model = MagicMock() + # set up and run the extractor + test_extractor.setup(words=self._MOCK_SENTENCE) + test_extractor.run() + test_extractor.teardown() + + self.assertEqual(len(test_extractor.features), 9) +