Source code for logflow.relationsdiscover.StoppingCondition

# Copyright 2020 BULL SAS All rights reserved #
import datetime

class StoppingCondition:
    """Decide when the learning loop should stop.

    One of three strategies is selected with ``method``:

    1) ``"earlystopping"``: every call to :meth:`update` compares the new
       metric with the previous one. When the increase is below
       ``condition_value`` a counter is incremented, otherwise it is reset
       to 0. Learning stops once the counter reaches ``condition_step``.
    2) ``"timer"``: a deadline ``now + duration`` is computed when the object
       is built; learning stops once the wall clock passes that deadline.
    3) ``"epoch"``: every call to :meth:`update` counts one epoch; learning
       stops once ``condition_epoch`` epochs have been counted.

    Once :meth:`stop` has returned True it keeps returning True.

    NOTE(review): earlier documentation stated that the timer excludes the
    duration of the testing step. Nothing in this class pauses the timer, so
    that would have to be handled by the caller -- confirm.

    Args:
        method (str, optional): "earlystopping", "timer" or "epoch".
            Defaults to "earlystopping". Any other value yields a condition
            that is never reached.
        condition_value (float, optional): minimal increase of the metric
            that counts as progress. Defaults to 0.005.
        condition_step (int, optional): number of consecutive steps without
            enough progress before stopping. Defaults to 3.
        duration (int, optional): duration of the learning step in seconds
            (timer method). Defaults to 60.
        condition_epoch (int, optional): number of epochs to be done
            (epoch method). Defaults to 3.
    """

    def __init__(self, method="earlystopping", condition_value=0.005, condition_step=3, duration=60, condition_epoch=3):
        self.method = method
        # Last metric value seen by update(); the baseline for the first call is 0.
        self.metric_value = 0
        if self.method == "earlystopping":
            self.condition_value = condition_value
            self.condition_step = condition_step
            # Consecutive updates whose increase was below condition_value.
            self.nb_step = 0
        elif self.method == "timer":
            self.start = datetime.datetime.now()
            self.end = self.start + datetime.timedelta(seconds=duration)
        elif self.method == "epoch":
            self.nb_epoch = 0
            self.condition_epoch = condition_epoch
        # Latched to True by stop() the first time the condition is met.
        self.stopped = False
        # Difference between the two most recent metrics; -1 until update() runs.
        self.last_increased = -1

    def stop(self) -> bool:
        """Compute the condition.

        Returns:
            bool: True if the stopping condition is reached (now or at any
            earlier call), else False. An unknown method always gives False.
        """
        if self.stopped:
            return True

        if self.method == "earlystopping":
            # ">=" rather than "==": the counter may have stepped past the
            # threshold if update() ran several times between two checks.
            reached = self.nb_step >= self.condition_step
        elif self.method == "timer":
            reached = datetime.datetime.now() > self.end
        elif self.method == "epoch":
            reached = self.nb_epoch >= self.condition_epoch
        else:
            # Unknown method: report "not reached" instead of returning None.
            reached = False

        if reached:
            self.stopped = True
        return reached

    def update(self, metric: float = 0.1) -> None:
        """Record one learning step.

        For "earlystopping", compares ``metric`` with the previously stored
        value and updates the no-progress counter. For "epoch", increments
        the epoch counter. Does nothing for the other methods.

        Args:
            metric (float, optional): value of the metric. Only used with
                the earlystopping method. Defaults to 0.1.
        """
        if self.method == "earlystopping":
            increase = metric - self.metric_value
            if increase < self.condition_value:
                self.nb_step += 1
            else:
                self.nb_step = 0
            self.last_increased = increase
            self.metric_value = metric
        elif self.method == "epoch":
            self.nb_epoch += 1

    def __str__(self) -> str:
        """Return the string representation of the condition.

        Returns:
            str: human-readable state of the condition for the active method.
        """
        if self.method == "earlystopping":
            if self.stopped:
                return "Condition is reached, last increase is: " + str(self.last_increased)
            return ("Condition is not reached, last increase is: " + str(self.last_increased)
                    + " number of steps: " + str(self.nb_step))

        if self.method == "timer":
            # Read the clock once so "remaining" and "now" are consistent.
            now = datetime.datetime.now()
            if self.stopped:
                return "Condition is reached, expected end: " + str(self.end) + " now: " + str(now)
            return ("Condition is not reached, expected end: " + str(self.end)
                    + " remaining: " + str(self.end - now) + " now: " + str(now))

        if self.method == "epoch":
            prefix = "Condition is reached" if self.stopped else "Condition is not reached"
            return prefix + ", epoch is : " + str(self.nb_epoch) + " end at: " + str(self.condition_epoch)

        # __str__ must return a str; falling through would raise TypeError.
        return "Unknown stopping method: " + repr(self.method)