Selaa lähdekoodia

feat: add RandomFeatureExtractor

Simple feature extractor that picks random times, used to exercise other
parts of the pipeline
main
Rob Hallam 4 kuukautta sitten
vanhempi
commit
246aefd34c
1 muutettua tiedostoa jossa 51 lisäystä ja 0 poistoa
  1. +51
    -0
      pipeline/feature_extractors.py

+ 51
- 0
pipeline/feature_extractors.py Näytä tiedosto

@@ -1,5 +1,10 @@
from abc import ABC
import logging
import random
import subprocess
from ast import literal_eval
from pipeline.utils import SourceMedia, Feature, Interval

logger = logging.getLogger(__name__)

class FeatureExtractor(ABC):
@@ -114,3 +119,49 @@ class LaughterFeatureExtractor(FeatureExtractor):
def teardown(self):
pass

class RandomFeatureExtractor(FeatureExtractor):
    """Feature extractor that generates random features for testing purposes.

    Here:

    setup() validates that input files were provided

    run() generates NUM_FEATURES random features for each input file

    teardown() is a no-op
    """
    # Number of random features generated per input file.
    NUM_FEATURES = 5
    # Upper bound on a generated feature's duration, compared against the
    # value returned by each file's duration().
    MAX_DURATION = 20.0

    def __init__(self, input_files=None, config=None):
        """It is expected that input_files is a SourceMedia object"""
        self.input_files = input_files
        self.config = config
        self.features = []

    def setup(self):
        """Setup the random feature extractor -- validate input files.

        Raises:
            ValueError: if no input files were provided.
        """
        logger.debug("RandomFeatureExtractor setup")

        # Validate input files
        if not self.input_files:
            raise ValueError("No input files provided")

    def run(self):
        """Generate random features for each input file.

        For every file, NUM_FEATURES features are appended to self.features.
        Each feature's interval lies within [0, file.duration()]: its duration
        is at most MAX_DURATION and never longer than the file itself.

        Raises:
            ValueError: if input_files is empty or is not a SourceMedia object.
        """
        if not self.input_files:
            raise ValueError("No input files provided")
        if not isinstance(self.input_files, SourceMedia):
            raise ValueError("input_files must be a SourceMedia object")

        for file in self.input_files:
            # Loop-invariant: query the file's duration once per file rather
            # than once per generated feature. Clamp so a non-positive
            # duration cannot yield negative intervals.
            file_duration = max(0.0, file.duration())
            for _ in range(self.NUM_FEATURES):
                # Cap the feature length so it always fits inside the file.
                duration = min(random.random() * self.MAX_DURATION, file_duration)
                # Pick the start from the range that leaves room for the whole
                # feature; the parentheses matter -- without them the start
                # could be negative.
                start = random.random() * (file_duration - duration)
                self.features.append(
                    Feature(
                        interval=Interval(start=start, duration=duration),
                        source="random",
                        path=file.path,
                    )
                )

    def teardown(self):
        """Teardown hook; performs no work."""
        pass

Ladataan…
Peruuta
Tallenna