Nevar pievienot vairāk kā 25 tēmas. Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

168 rindas
6.3 KiB

  1. from abc import ABC
  2. import logging
  3. import random
  4. import subprocess
  5. from ast import literal_eval
  6. from pipeline.utils import SourceMedia, Feature, Interval
  7. logger = logging.getLogger(__name__)
  class FeatureExtractor(ABC):
      """Common interface shared by every feature extractor.

      The life cycle is three hooks, each of which does nothing (and returns
      ``None``) unless a subclass overrides it: ``setup()``, ``run()`` and
      ``teardown()``.
      """

      # TODO: #API -- decide if .features will be a member variable

      def setup(self):
          """Hook for preparation work; a no-op in the base class."""

      def run(self):
          """Hook for the actual feature extraction; a no-op in the base class."""

      def teardown(self):
          """Hook for clean-up work; a no-op in the base class."""
  class LaughterFeatureExtractor(FeatureExtractor):
      """Feature extractor for laughter detection.

      This class is responsible for extracting features corresponding to
      laughter in media files.

      Here:

      setup() is used to validate input files & config, which may involve
      processing video files to extract audio

      run() is used to extract features from the audio using jrgillick's
      laughter-detection

      teardown() is used to clean up any temporary files created during setup
      according to the config

      See: https://github.com/jrgillick/laughter-detection for the
      laughter-detection library
      """

      # Where the laughter-detection checkout lives and which script to run.
      # Kept as class attributes so a subclass (or a test) can override them
      # without editing method bodies.
      LAUGH_DETECTOR_DIR = "/home/robert/mounts/980data/code/laughter-detection/"
      LAUGH_DETECTOR_SCRIPT = "segment_laughter.py"

      # Prefix of the detector's stdout lines that describe one detected laugh.
      INSTANCE_PREFIX = "instance: "

      # Offsets applied to the start / end of every detected laugh.
      # TODO: play with numbers more to see what works best
      PREPEND = 7.0
      APPEND = 3.0

      def __init__(self, input_files=None, config=None):
          """It is expected that input_files is a SourceMedia object"""
          self.input_files = input_files
          self.config = config
          self.features = []

      def _parse_laughs(self, output):
          """Extract (start, end) pairs from the detector's stdout text.

          Laughs are lines that start with ``"instance: "`` followed by a
          literal 2-tuple of numbers. Every other line is ignored. A line that
          carries the prefix but cannot be parsed into exactly two values is
          logged and skipped instead of aborting the whole extraction.
          """
          prefix = self.INSTANCE_PREFIX
          laughs = []
          for line in output.splitlines():
              if not line.startswith(prefix):
                  continue
              try:
                  # literal_eval only accepts Python literals, so arbitrary
                  # expressions in the detector output are never executed
                  start, end = literal_eval(line[len(prefix):])
              except (ValueError, SyntaxError, TypeError):
                  logger.warning("Skipping unparseable laughter instance: %r", line)
                  continue
              laughs.append((start, end))
          return laughs

      def _laughdetect(self, audio_file):
          """Run laughter detection on the audio file

          Returns a list of (start, end) pairs, one per laugh reported by the
          detector script on stdout.
          """
          import os  # local import: only needed to build the script path

          # fake output for testing
          # laugh_detector_path = "tests/fake_segment_laughter.py"
          script_path = os.path.join(self.LAUGH_DETECTOR_DIR, self.LAUGH_DETECTOR_SCRIPT)
          laugh_detector_cmd = ["python", script_path,
                                f"--input_audio_file={audio_file}"]

          # run command, capture output, ignore exit status
          # have to include cwd to keep laughter-detection imports happy
          # also, it isn't happy if no output dir is specified but we get laughs
          # so it's grand
          completed = subprocess.run(laugh_detector_cmd,
                                     stdout=subprocess.PIPE,
                                     cwd=self.LAUGH_DETECTOR_DIR,
                                     check=False)
          return self._parse_laughs(completed.stdout.decode("utf-8"))

      def _adjust_features(self):
          """Adjust features according to config

          Generically, this ensures features conform to config - min/max
          feature length, etc.

          In the context of LaughterFeatureExtractor, there is some secret
          sauce: things that cause a laugh generally /precede/ the laugh, so we
          want more time before the detected start than at the end. For
          example, for a minimum feature length of 15s, we might prepend 10
          seconds, and append 5 seconds (for example), or 12s and 3s. We may
          wish to do this pre/post adjustment for all laughter features found,
          regardless of length.

          Currently every feature in self.features gets its start moved by
          -PREPEND and its end moved by +APPEND (both passed with
          relative=True).

          TODO: figure out how we're going to handle length adjustments
          TODO: config for length adjustments per design doc
          """
          for feature in self.features:
              # do the pre & post adjustment
              feature.interval.move_start(-self.PREPEND, relative=True)
              feature.interval.move_end(self.APPEND, relative=True)

      def setup(self):
          """Setup the laughter feature extractor -- validate input files & config

          jrgillick's laughter-detection library can work with AV files directly

          Raises ValueError if no input files were provided.

          TODO: validate input files
          TODO: handle config
          """
          logger.debug("LaughterFeatureExtractor setup")

          # Validate input files
          if not self.input_files:
              raise ValueError("No input files provided")

          # TODO: convert video to audio if needed

      def run(self):
          """Extract laughter features for each input file

          Appends one Feature per detected laugh to self.features, then applies
          the pre/post padding once, after all files have been processed.
          """
          if self.input_files:
              for file in self.input_files:
                  for start, end in self._laughdetect(file.path):
                      self.features.append(Feature(interval=Interval(start=start, end=end),
                                                   source="laughter", path=file.path))
              # TODO: implement options eg minimum feature length

          # adjust features
          self._adjust_features()

      def teardown(self):
          """Nothing to clean up yet."""
          pass
  class RandomFeatureExtractor(FeatureExtractor):
      """Feature extractor for random feature generation.

      This class is responsible for generating random features for testing
      purposes.

      Here:

      setup() is used to validate input files & config

      run() is used to generate random features

      teardown() is used to clean up any temporary files created during setup
      according to the config
      """

      # Number of random features generated per input file.
      NUM_FEATURES = 5
      # Upper bound on the length of a generated feature.
      MAX_DURATION = 20.0

      def __init__(self, input_files=None, config=None):
          """It is expected that input_files is a SourceMedia object"""
          self.input_files = input_files
          self.config = config
          self.features = []

      def setup(self):
          """Setup the random feature extractor -- validate input files & config

          Raises ValueError if no input files were provided.
          """
          logger.debug("RandomFeatureExtractor setup")

          # Validate input files
          if not self.input_files:
              raise ValueError("No input files provided")

      def run(self):
          """Generate random features for each input file

          For every file, NUM_FEATURES features are appended to self.features.
          Each feature has a random duration of at most MAX_DURATION (and never
          longer than the file itself) and a random start chosen so that the
          whole feature lies within [0, file duration].

          Raises ValueError if input_files is empty or is not a SourceMedia.
          """
          if not self.input_files:
              raise ValueError("No input files provided")
          # check self.input_files is of type SourceMedia
          if not isinstance(self.input_files, SourceMedia):
              raise ValueError("input_files must be a SourceMedia object")

          for file in self.input_files:
              # assumes file.duration() returns the media length as a
              # non-negative number in the same unit as MAX_DURATION
              # -- TODO confirm
              file_duration = file.duration()
              longest = min(self.MAX_DURATION, file_duration)
              for _ in range(self.NUM_FEATURES):
                  duration = random.random() * longest
                  # pick the start from the room left over once the feature's
                  # own length is reserved, so start >= 0 and the feature ends
                  # no later than the end of the file
                  start = random.random() * (file_duration - duration)
                  self.features.append(Feature(interval=Interval(start=start, duration=duration),
                                               source="random", path=file.path))

      def teardown(self):
          """Nothing to clean up."""
          pass