ソースを参照

feat: [TTA] add algo for dropping Features in adjust()

Tries to drop the lowest-scoring Features until the target time (range) is
reached. This is not optimised and a relatively naïve approach- there are many
inputs which would result in a non-ideal pruning.
main
Rob Hallam 2ヶ月前
コミット
f253250239
1個のファイルの変更34行の追加2行の削除
  1. +34
    -2
      pipeline/adjusters.py

+ 34
- 2
pipeline/adjusters.py ファイルの表示

@@ -84,7 +84,7 @@ class TargetTimeAdjuster(Adjuster):
self.strategy = strategy

def adjust(self) -> list:
"""Drop Features until the target time within the margin is reached.
"""Drop Features until the target time within the margin is reached. Prioritise dropping lower scoring Features.

Approach:

@@ -111,4 +111,36 @@ class TargetTimeAdjuster(Adjuster):
return self.features

# sort list of Features by score (primary) and by duration (secondary)
sorted_features = sorted(self.features, key=lambda x: (x.score, x.time))
sorted_features = self._sort_by_score_time(self.features)
drop_indices = [] # indices of Features to drop

# first pass- drop lowest scoring Features until we are within the target time
for i in range(len(sorted_features)):
# check if dropping this Feature would put us in the target range:
# if so, drop it and return
if (total_time - sorted_features[i].interval.duration >= target_time_min and
total_time - sorted_features[i].interval.duration <= target_time_max):
drop_indices.append(i)
break

elif (total_time - sorted_features[i].interval.duration > target_time_max):
drop_indices.append(i)
total_time -= sorted_features[i].interval.duration

for i in drop_indices:
self.features.remove(sorted_features[i])

# if we are now within the target time, return the Features
total_time = self._features_total_time(features=self.features)
if total_time <= target_time_max:
return self.features

# else: we are still over the target time
# so drop the lowest scoring Features until we are UNDER the target time
for i in range(len(sorted_features)):
self.features.remove(sorted_features[i])
total_time -= sorted_features[i].interval.duration
if total_time <= target_time_max:
break

return self.features

読み込み中…
キャンセル
保存