Source code for logflow.relationsdiscover.Saver

# Copyright 2020 BULL SAS All rights reserved #
from logflow.relationsdiscover.Model import LSTMLayer
from logflow.relationsdiscover.Result import Result
import pickle
import os
from loguru import logger
from time import sleep
import random
import numpy as np # type: ignore

class Saver:
    """Save and load the model from a file.

    The file is a pickled dict laid out as follows::

        file["LSTM"][cardinality] = model.state_dict()
        file["Result"][cardinality][condition] = result

    Args:
        name_model (str): name of the dataset.
        path_model (str): path of the model to save. It is concatenated
            directly with ``name_model`` (no separator is inserted), so it
            should end with a path separator.
        cardinality (int, optional): cardinality to save. Defaults to -1.
        lock (optional): lock object (anything with ``acquire``/``release``)
            guarding the file. Defaults to -1, which means no locking.
    """

    def __init__(self, name_model: str, path_model: str, cardinality=-1, lock=-1):
        self.path = path_model + name_model + "_model.lf"
        self.cardinality = cardinality
        self.lock = lock

    def _acquire_lock(self) -> bool:
        """Acquire ``self.lock`` if a lock was provided.

        Returns:
            bool: True if the lock was acquired and must be released.
        """
        # -1 (the documented default) and None mean "no lock", so a Saver
        # built with default arguments can still save and load.
        if self.lock is None or self.lock == -1:
            return False
        self.lock.acquire()
        return True

    def _release_lock(self, acquired: bool) -> None:
        """Release ``self.lock`` if ``_acquire_lock`` took it."""
        if acquired:
            self.lock.release()

    def save(self, model: "LSTMLayer", result: "Result", condition="Test"):
        """Save the model and its result into the file.

        The existing file (if any) is read, the entries for this cardinality
        are replaced, and the whole dict is written back. The lock (if any)
        is held for the entire read-modify-write and is released even if an
        error occurs.

        Args:
            model (LSTMLayer): model whose ``state_dict()`` is saved.
            result (Result): result to save.
            condition (str): key under which ``result`` is stored (e.g. Test
                or train results). With ``"temp"`` only the model is saved.
        """
        acquired = self._acquire_lock()
        try:
            content = {}
            if os.path.isfile(self.path):
                with open(self.path, "rb") as input_file:
                    logger.info("[" + str(self.cardinality) + "] Loading: " + self.path)
                    content = pickle.load(input_file)

            content.setdefault("LSTM", {})[self.cardinality] = model.state_dict()

            # Keep only the latest version of the results
            results = content.setdefault("Result", {}).setdefault(self.cardinality, {})
            if condition != "temp":
                results[condition] = result

            with open(self.path, "wb") as output_file:
                pickle.dump(content, output_file)
            logger.info("[" + str(self.cardinality) + "] Saving: " + self.path)
        finally:
            # Always release, otherwise every other holder of the lock blocks forever.
            self._release_lock(acquired)

    def load(self, model: "LSTMLayer") -> "LSTMLayer":
        """Load the model.

        Note that the model must be created before. This function loads only
        the parameters inside the model. The file is read with ``pickle``, so
        only load files that come from a trusted source.

        Args:
            model (LSTMLayer): object to use for loading the model.

        Raises:
            Exception: the file is not found.
            KeyError: the file holds no model for this cardinality.

        Returns:
            LSTMLayer: the loaded model (the same object that was passed in).
        """
        if not os.path.isfile(self.path):
            logger.critical("Trying to load an unknown file: " + str(self.path))
            raise Exception("Trying to load an unknown file")

        # Hold the lock only while reading, so a concurrent save cannot
        # truncate the file mid-read.
        acquired = self._acquire_lock()
        try:
            with open(self.path, "rb") as input_file:
                content = pickle.load(input_file)
        finally:
            self._release_lock(acquired)

        model.load_state_dict(content["LSTM"][self.cardinality])
        return model