stopping

Patience-based stopping condition for the generation loop.

Classes:

| Name | Description |
| --- | --- |
| GenerationStopCondition | Stopping conditions for the generation process. |

GenerationStopCondition(invalid_fraction_threshold, patience)

Stopping conditions for the generation process.

Empirically, the fraction of invalid records generated is a good indicator of the model's performance. This class implements a condition for stopping generation (and potentially training) based on the invalid fraction of records generated.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| invalid_fraction_threshold | float | Stop generation once the invalid fraction has been at or above this threshold for patience consecutive batches. | required |
| patience | int | Number of consecutive batches at or above the threshold needed to trigger the stop. | required |
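To make the interaction between the two parameters concrete, here is a minimal sketch of the counting rule, based on the comparison logic in the source listing for has_been_reached on this page. The helper name first_stop_index and the sample values are hypothetical and are not part of the library.

```python
from typing import Optional, Sequence


def first_stop_index(
    fractions: Sequence[float], threshold: float, patience: int
) -> Optional[int]:
    """Return the 0-based index of the batch that triggers the stop, or None.

    Hypothetical helper for illustration only; it is not part of the library.
    """
    streak = 0  # consecutive batches at or above the threshold
    for i, fraction in enumerate(fractions):
        if fraction >= threshold:
            streak += 1
            if streak >= patience:
                return i
        else:
            # A single batch below the threshold resets the streak.
            streak = 0
    return None


# Made-up invalid fractions for five batches.
history = [0.9, 0.2, 0.85, 0.95, 0.9]
stop_at = first_stop_index(history, threshold=0.8, patience=3)
print(stop_at)  # 4: batches 2, 3 and 4 form the first run of three
```

Note that the first batch (0.9) does not count toward the final run, because the second batch (0.2) falls below the threshold and resets the streak.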

Methods:

| Name | Description |
| --- | --- |
| has_been_reached | Check whether the invalid-fraction threshold has been exceeded for patience consecutive batches. |

Source code in src/nemo_safe_synthesizer/generation/stopping.py
def __init__(self, invalid_fraction_threshold: float, patience: int):
    self.counter = 0
    self.patience = patience
    self.invalid_fraction_threshold = invalid_fraction_threshold
    self.last_value = None

has_been_reached(invalid_fraction)

Check whether the invalid-fraction threshold has been exceeded for patience consecutive batches.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| invalid_fraction | float | Running average of the invalid-record fraction for the most recent batch. | required |

Returns:

| Type | Description |
| --- | --- |
| bool | True if the threshold was exceeded for patience consecutive batches, False otherwise. |

Source code in src/nemo_safe_synthesizer/generation/stopping.py
def has_been_reached(self, invalid_fraction: float) -> bool:
    """Check whether the invalid-fraction threshold has been exceeded for ``patience`` consecutive batches.

    Args:
        invalid_fraction: Running average of the invalid-record
            fraction for the most recent batch.

    Returns:
        ``True`` if the threshold was exceeded for ``patience``
        consecutive batches, ``False`` otherwise.
    """
    is_reached = False
    self.last_value = invalid_fraction
    if invalid_fraction >= self.invalid_fraction_threshold:
        self.counter += 1
        if self.counter >= self.patience:
            is_reached = True
            logger.info(
                f"🛑 Stopping condition reached: {invalid_fraction = :.2} >= "
                f"{self.invalid_fraction_threshold} for {self.counter} "
                "consecutive batches.",
            )
    else:
        self.counter = 0
    return is_reached
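
The following sketch shows how the condition might be polled once per batch inside a generation loop. So that it runs standalone, it defines a local mirror of the class based on the listings above, with logging omitted; in real code the class would be imported from the module instead. The list batch_invalid_fractions is made-up data standing in for whatever the generator reports.

```python
class GenerationStopCondition:
    """Local mirror of the documented class (logging omitted) for illustration."""

    def __init__(self, invalid_fraction_threshold: float, patience: int):
        self.counter = 0
        self.patience = patience
        self.invalid_fraction_threshold = invalid_fraction_threshold
        self.last_value = None

    def has_been_reached(self, invalid_fraction: float) -> bool:
        self.last_value = invalid_fraction
        if invalid_fraction >= self.invalid_fraction_threshold:
            self.counter += 1
            return self.counter >= self.patience
        self.counter = 0
        return False


# Hypothetical per-batch invalid fractions; real values come from the generator.
batch_invalid_fractions = [0.1, 0.9, 0.3, 0.85, 0.9, 0.95, 0.2]

stop = GenerationStopCondition(invalid_fraction_threshold=0.8, patience=3)
batches_processed = 0
for fraction in batch_invalid_fractions:
    batches_processed += 1
    # Poll exactly once per batch: each call updates the internal counter.
    if stop.has_been_reached(fraction):
        break

print(batches_processed)  # 6: the run 0.85, 0.9, 0.95 ends at the sixth batch
print(stop.last_value)    # 0.95
```

Because has_been_reached mutates counter and last_value, calling it more than once for the same batch would count that batch twice toward patience.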