diff --git a/pipeline/utils.py b/pipeline/utils.py index e7a7065..9a180df 100644 --- a/pipeline/utils.py +++ b/pipeline/utils.py @@ -179,6 +179,9 @@ class Interval(): return self.end < other.end return self.start < other.start + def __eq__(self, other): + return self.start == other.start and self.end == other.end + def to_json(self): """Return a dict representation of the interval for JSON encoding @@ -268,7 +271,7 @@ class Feature(): feature_extractor=feature_extractor, score=score) def __repr__(self): - return f"Feature({self.interval}, {self.source}, {self.score})" + return f"Feature({self.interval}, {self.source}, {self.feature_extractor}, {self.score})" def __lt__(self, other): """Sort based on interval, then feature_extractor, then score""" @@ -278,6 +281,12 @@ class Feature(): return self.feature_extractor < other.feature_extractor return self.interval < other.interval + def __eq__(self, other): + return self.interval == other.interval \ + and self.source == other.source \ + and self.feature_extractor == other.feature_extractor \ + and self.score == other.score + def to_json(self): """Return a dict representation of the feature for JSON encoding