Skip to content

correlation

correlation

Classes:

| Name | Description |
| --- | --- |
| Correlation | Column Correlation Stability metric. |

Correlation pydantic-model

Bases: Component

Column Correlation Stability metric.

Computes per-column-pair correlations (Pearson, Theil's U, Correlation Ratio) for both the reference and the output dataframes, then derives a stability score from the mean absolute difference between the two correlation matrices.

Config:

  • arbitrary_types_allowed: True

Fields:

reference_correlation = None pydantic-field

Correlation matrix for the reference data.

output_correlation = None pydantic-field

Correlation matrix for the output data.

correlation_difference = None pydantic-field

Element-wise absolute difference of the two matrices.

jinja_context cached property

Template context with combined correlation heatmap figure.

from_evaluation_dataset(evaluation_dataset, config=None) staticmethod

Compute correlation matrices and the correlation stability score.

Source code in src/nemo_safe_synthesizer/evaluation/components/correlation.py
@staticmethod
def from_evaluation_dataset(
    evaluation_dataset: EvaluationDataset, config: SafeSynthesizerParameters | None = None
) -> Correlation:
    """Build a ``Correlation`` component from an evaluation dataset.

    The reference and output frames are both narrowed to the dataset's
    tabular columns and handed to ``_get_correlation_calculations``. The
    mean absolute error it reports is then converted into the component
    score by ``_get_field_correlation_stability``.

    Args:
        evaluation_dataset: Source of the reference/output frames, the
            tabular and nominal column lists, and the evaluation fields.
        config: Optional parameters object. It is not referenced in this
            method body.

    Returns:
        A ``Correlation`` carrying the reference matrix, the output matrix,
        their difference matrix, and the resulting score.
    """
    # Correlations are restricted to the tabular columns only.
    tabular_columns = evaluation_dataset.get_tabular_columns()
    # Nominal columns are passed separately because they use a different
    # calculation (Theil's U).
    nominal_columns = evaluation_dataset.get_nominal_columns()

    calculations = Correlation._get_correlation_calculations(
        reference=evaluation_dataset.reference[tabular_columns],  # ty: ignore[invalid-argument-type]
        output=evaluation_dataset.output[tabular_columns],  # ty: ignore[invalid-argument-type]
        nominal_columns=nominal_columns,
        fields=evaluation_dataset.evaluation_fields,
    )
    # The helper yields exactly four values, in this order.
    reference_matrix, output_matrix, difference_matrix, mae = calculations

    stability_score = Correlation._get_field_correlation_stability(mae)
    return Correlation(
        reference_correlation=reference_matrix,
        output_correlation=output_matrix,
        correlation_difference=difference_matrix,
        score=stability_score,
    )