|
|
@@ -10,6 +10,10 @@ import tempfile |
|
|
|
# for visualisations: |
|
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
|
|
# for encoding as JSON |
|
|
|
from pipeline.utils import Feature |
|
|
|
|
|
|
|
|
|
|
|
class Producer(ABC): |
|
|
|
"""Generic producer interface.""" |
|
|
|
def __init__(self, features): |
|
|
@@ -141,6 +145,13 @@ class VisualisationProducer(Producer): |
|
|
|
plt.savefig("/tmp/visualisation.png") |
|
|
|
plt.close() |
|
|
|
|
|
|
|
class PipelineJSONEncoder(json.JSONEncoder): |
|
|
|
def default(self, obj): |
|
|
|
if hasattr(obj, 'to_json'): |
|
|
|
return obj.to_json() |
|
|
|
else: |
|
|
|
return json.JSONEncoder.default(self, obj) |
|
|
|
|
|
|
|
class JSONProducer(Producer): |
|
|
|
"""Produce JSON output""" |
|
|
|
def __init__(self, features): |
|
|
|