You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

168 lines
6.3 KiB

  1. from abc import ABC
  2. import logging
  3. import random
  4. import subprocess
  5. from ast import literal_eval
  6. from pipeline.utils import SourceMedia, Feature, Interval
  7. logger = logging.getLogger(__name__)
  8. class FeatureExtractor(ABC):
  9. """Feature extractor interface."""
  10. # TODO: #API -- decide if .features will be a member variable
  11. def setup(self):
  12. pass
  13. def run(self):
  14. pass
  15. def teardown(self):
  16. pass
  17. class LaughterFeatureExtractor(FeatureExtractor):
  18. """Feature extractor for laughter detection.
  19. This class is responsible for extracting features corresponding to laughter in media files.
  20. Here:
  21. setup() is used to validate input files & config, which may involve processing video files to extract audio
  22. run() is used to extract features from the audio using jrgillick's laughter-detection
  23. teardown() is used to clean up any temporary files created during setup according to the config
  24. See: https://github.com/jrgillick/laughter-detection for the laughter-detection library
  25. """
  26. def __init__(self, input_files=None, config=None):
  27. """It is expected that input_files is a SourceMedia object"""
  28. self.input_files = input_files
  29. self.config = config
  30. self.features = []
  31. def _laughdetect(self, audio_file):
  32. """Run laughter detection on the audio file"""
  33. laugh_detector_dir = "/home/robert/mounts/980data/code/laughter-detection/"
  34. laugh_detector_script = "segment_laughter.py"
  35. # fake output for testing
  36. # laugh_detector_path = "tests/fake_segment_laughter.py"
  37. laugh_detector_cmd = ["python", f"{laugh_detector_dir}{laugh_detector_script}",
  38. f"--input_audio_file={audio_file}"]
  39. # run command, capture output, ignore exit status
  40. laugh_output = subprocess.run(laugh_detector_cmd,
  41. stdout=subprocess.PIPE,
  42. cwd=laugh_detector_dir).stdout.decode("utf-8")
  43. # ↑ have to include cwd to keep laughter-detection imports happy
  44. # also, it isn't happy if no output dir is specified but we get laughs so it's grand
  45. # laughs are lines in stdout that start with "instance:", followed by a space and a 2-tuple of floats
  46. # so jump to the 10th character and evaluate the rest of the line
  47. return [literal_eval(instance[10:])
  48. for instance in laugh_output.splitlines()
  49. if instance.startswith("instance: ")]
  50. def _adjust_features(self):
  51. """Adjust features according to config
  52. Generically, this ensures features conform to config - min/max feature length, etc.
  53. In the context of LaughterFeatureExtractor, there is some secret sauce: things that
  54. cause a laugh generally /precede/ the laugh, so we want more team before the detected start
  55. than at the end. For example, for a minimum feature length of 15s, we might prepend 10 seconds,
  56. and append 5 seconds (for example), or 12s and 3s. We may wish to do this pre/post adjustment
  57. for all laughter features found, regardless of length.
  58. TODO: figure out how we're going to handle length adjustments
  59. TODO: config for length adjustments per design doc
  60. TODO: play with numbers more to see what works best
  61. """
  62. PREPEND = 7.0
  63. APPEND = 3.0
  64. for feature in self.features:
  65. # do the pre & post adjustment
  66. feature.interval.move_start(-PREPEND, relative=True)
  67. feature.interval.move_end(APPEND, relative=True)
  68. def setup(self):
  69. """Setup the laughter feature extractor -- validate input files & config
  70. jrgillick's laughter-detection library can work with AV files directly
  71. TODO: validate input files
  72. TODO: handle config
  73. """
  74. logger.debug("LaughterFeatureExtractor setup")
  75. # Validate input files
  76. if not self.input_files:
  77. raise ValueError("No input files provided")
  78. # TODO: convert video to audio if needed
  79. def run(self):
  80. """Extract laughter features for each input file"""
  81. if self.input_files:
  82. for file in self.input_files:
  83. laughs = self._laughdetect(file.path)
  84. for laugh in laughs:
  85. start, end = laugh
  86. self.features.append(Feature(interval=Interval(start=start, end=end),
  87. source="laughter", path=file.path))
  88. # TODO: implement options eg minimum feature length
  89. # adjust features
  90. self._adjust_features()
  91. def teardown(self):
  92. pass
  93. class RandomFeatureExtractor(FeatureExtractor):
  94. """Feature extractor for random feature generation.
  95. This class is responsible for generating random features for testing purposes.
  96. Here:
  97. setup() is used to validate input files & config
  98. run() is used to generate random features
  99. teardown() is used to clean up any temporary files created during setup according to the config
  100. """
  101. NUM_FEATURES = 5
  102. MAX_DURATION = 20.0
  103. def __init__(self, input_files=None, config=None):
  104. """It is expected that input_files is a SourceMedia object"""
  105. self.input_files = input_files
  106. self.config = config
  107. self.features = []
  108. def setup(self):
  109. """Setup the random feature extractor -- validate input files & config"""
  110. logger.debug("RandomFeatureExtractor setup")
  111. # Validate input files
  112. if not self.input_files:
  113. raise ValueError("No input files provided")
  114. def run(self):
  115. """Generate random features for each input file"""
  116. # check self.input_files is of type SourceMedia
  117. if not self.input_files or not isinstance(self.input_files, SourceMedia):
  118. raise ValueError("No input files provided")
  119. for file in self.input_files:
  120. for _ in range(self.NUM_FEATURES):
  121. # round to 3 decimal places
  122. duration = random.random() * self.MAX_DURATION
  123. start = random.random() * file.duration() - duration
  124. self.features.append(Feature(interval=Interval(start=start, duration=duration),
  125. source="random", path=file.path))
  126. def teardown(self):
  127. pass