data_actions

Extensible data action framework for pre/post-processing, validation, and generation.

Defines BaseAction and its subclasses (GenerateAction, ColAction, ValidationAction) which encapsulate data transformations applied at different pipeline phases. ActionExecutor orchestrates running the registered actions in order.

Classes:

  • ProcessFn: Callable that transforms a DataFrame in-place during a processing phase.
  • ValidateBatchFn: Callable that splits a batch into valid and rejected DataFrames.
  • ProcessPhase: Pipeline phases that apply DataFrame-to-DataFrame transformations.
  • ValidateBatchPhase: Pipeline phase for batch validation.
  • BaseAction: Abstract base class for all data actions in the pipeline.
  • GenerateAction: Action that generates net-new data for the DataFrame.
  • GenExpression: Action that generates a column from jinja transforms_v2 expressions.
  • GenRawExpression: Low-level action that passes raw transforms_v2 update payloads.
  • ReplaceDataSource: Action that replaces a column's values with data from a configured data source.
  • GenDatetimeDistribution: Generate a datetime from a provided datetime distribution.
  • DateConstraint: Action that enforces an ordering constraint between two date columns.
  • ColAction: Action that operates on a single named column.
  • DatetimeCol: ColAction for datetime columns with an optional strftime format.
  • ActionExecutor: Orchestrate a sequence of BaseAction instances across pipeline phases.

Functions:

  • data_actions_fn: Applies an action executor to a DataFrame.

ProcessFn

Bases: Protocol

Callable that transforms a DataFrame in-place during a processing phase.

ValidateBatchFn

Bases: Protocol

Callable that splits a batch into valid and rejected DataFrames.

ProcessPhase

Bases: str, Enum

Pipeline phases that apply DataFrame-to-DataFrame transformations.

ValidateBatchPhase

Bases: str, Enum

Pipeline phase for batch validation.

BaseAction pydantic-model

Bases: BaseModel, ABC

Abstract base class for all data actions in the pipeline.

Subclasses implement one or more phase methods (preprocess, postprocess, validate_batch, generate) to transform data at the corresponding pipeline stage. The functions method introspects which methods were actually overridden, so only non-default actions run.

State can be shared across phases via set_state / get_state, which persist to the ActionCtx.state dictionary keyed by the action's hash.
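As an illustrative sketch of the phase-method pattern, the stand-in below mirrors a preprocess/postprocess pair with plain Python and pandas (the real BaseAction is a pydantic model in nemo_safe_synthesizer; the class and column names here are hypothetical):

```python
import pandas as pd

# Hedged sketch: a minimal stand-in for a BaseAction subclass that
# encodes datetimes as strings before training and reverts afterward.
class EncodeDatetime:
    def __init__(self, col: str, fmt: str = "%Y-%m-%d"):
        self.col = col
        self.fmt = fmt

    def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        # Encode datetimes as formatted strings before training.
        df[self.col] = pd.to_datetime(df[self.col]).dt.strftime(self.fmt)
        return df

    def postprocess(self, df: pd.DataFrame) -> pd.DataFrame:
        # Revert the preprocessing after generation.
        df[self.col] = pd.to_datetime(df[self.col], format=self.fmt)
        return df

action = EncodeDatetime("ts")
df = pd.DataFrame({"ts": ["2024-01-02", "2024-03-04"]})
encoded = action.preprocess(df.copy())
restored = action.postprocess(encoded.copy())
```

Because postprocess is the inverse of preprocess, round-tripping a frame through both leaves the datetime column intact.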

Config:

  • alias_generator: type_alias_fn

Validators:

preprocess(df)

Transform the input dataset before training.

Override to modify the shape or contents of the data (e.g., encoding datetimes, dropping columns). The default implementation is a no-op.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def preprocess(self, df: pd.DataFrame) -> pd.DataFrame:
    """Transform the input dataset before training.

    Override to modify the shape or contents of the data (e.g., encoding
    datetimes, dropping columns). The default implementation is a no-op.
    """
    return df

postprocess(df)

Transform generated data after generation, often reverting preprocessing.

The default implementation is a no-op.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def postprocess(self, df: pd.DataFrame) -> pd.DataFrame:
    """Transform generated data after generation, often reverting preprocessing.

    The default implementation is a no-op.
    """
    return df

validate_batch(batch, df)

Split a generated batch into valid and rejected rows.

Parameters:

  • batch (DataFrame): Newly generated data to validate. Required.
  • df (DataFrame): Reference dataset providing context for validation. Required.

Returns:

  • tuple[DataFrame, DataFrame]: A tuple of (valid_rows, rejected_rows) DataFrames.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def validate_batch(self, batch: pd.DataFrame, df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Split a generated batch into valid and rejected rows.

    Args:
        batch: Newly generated data to validate.
        df: Reference dataset providing context for validation.

    Returns:
        A tuple of (valid_rows, rejected_rows) DataFrames.
    """
    batch_copy = batch.copy()
    valid_mask = self._validate_batch(batch, df)
    return batch_copy[valid_mask], batch_copy[~valid_mask]
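The mask-based split the method performs can be sketched in plain pandas (the `split_batch` helper and the age-based rule are illustrative, standing in for a subclass's `_validate_batch`):

```python
import pandas as pd

# Sketch of the valid/rejected partition done by validate_batch: a
# boolean mask (what a _validate_batch override would return) splits
# the batch into kept and rejected rows.
def split_batch(
    batch: pd.DataFrame, valid_mask: pd.Series
) -> tuple[pd.DataFrame, pd.DataFrame]:
    batch_copy = batch.copy()
    return batch_copy[valid_mask], batch_copy[~valid_mask]

batch = pd.DataFrame({"age": [25, -3, 40]})
mask = batch["age"] >= 0  # e.g. reject rows with negative ages
valid, rejected = split_batch(batch, mask)
```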

generate(df)

Generate new data and merge it into the DataFrame.

Override to create net-new columns or rows. The default implementation is a no-op.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def generate(self, df: pd.DataFrame) -> pd.DataFrame:
    """Generate new data and merge it into the DataFrame.

    Override to create net-new columns or rows. The default implementation
    is a no-op.
    """
    return df

functions()

Return a Functions bundle containing only the overridden phase methods.

Methods that were not overridden from BaseAction are excluded so that only actions with real work appear during debugging.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def functions(self) -> Functions:
    """Return a ``Functions`` bundle containing only the overridden phase methods.

    Methods that were not overridden from ``BaseAction`` are excluded so
    that only actions with real work appear during debugging.
    """

    # We use FunctionType annotation for method rather than Callable
    # because Callable does not guarantee a __name__ attribute, so type
    # checking with ty will fail. See
    # https://docs.astral.sh/ty/reference/typing-faq/#why-does-ty-say-callable-has-no-attribute-__name__
    def _method_if_overridden(method: FunctionType) -> FunctionType | None:
        method_fn = getattr(method, "__func__", None)
        class_fn = getattr(BaseAction, method.__name__)
        if method_fn is not class_fn:
            return method
        else:
            return None

    return Functions(
        preprocess=_method_if_overridden(self.preprocess),
        postprocess=_method_if_overridden(self.postprocess),
        validate_batch=(self.validate_batch if _method_if_overridden(self._validate_batch) else None),
        generate=_method_if_overridden(self.generate),
    )

add_ctx(info) pydantic-validator

Inject ActionCtx from pydantic's validation context, if provided.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
@model_validator(mode="after")
def add_ctx(self, info: ValidationInfo) -> "BaseAction":
    """Inject ``ActionCtx`` from pydantic's validation context, if provided."""
    self._ctx = DEFAULT_ACTION_CTX
    if pydantic_ctx := info.context:
        if action_ctx := pydantic_ctx.get("action_ctx"):
            self._ctx = action_ctx

    return self

with_ctx(ctx)

Attach an ActionCtx and return self for chaining.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def with_ctx(self, ctx: ActionCtx) -> "BaseAction":
    """Attach an ``ActionCtx`` and return self for chaining."""
    self._ctx = ctx
    return self

get_type()

Return the discriminator type_ value, or "unknown" if unset.

Works around the fact that type_ cannot be an abstract property on BaseAction due to pydantic discriminator constraints.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def get_type(self) -> str:
    """Return the discriminator ``type_`` value, or ``"unknown"`` if unset.

    Works around the fact that ``type_`` cannot be an abstract property
    on ``BaseAction`` due to pydantic discriminator constraints.
    """
    return getattr(self, "type_", "unknown")

hash()

Deterministic key for storing per-action state in ActionCtx.state.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def hash(self) -> str:
    """Deterministic key for storing per-action state in ``ActionCtx.state``."""
    return str(tuple(sorted(self.model_dump().items())))

set_state(state_obj)

Persist a Pydantic model as JSON in ActionCtx.state.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def set_state(self, state_obj: BaseModel) -> None:
    """Persist a Pydantic model as JSON in ``ActionCtx.state``."""
    self._ctx.state[self.hash()] = state_obj.model_dump_json()

get_state(state_obj_type)

Retrieve and deserialize a previously persisted state object.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def get_state(self, state_obj_type: type[BaseModelT]) -> BaseModelT:
    """Retrieve and deserialize a previously persisted state object."""
    state_obj_json = self._ctx.state[self.hash()]
    return state_obj_type.model_validate(json.loads(state_obj_json))
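The state round-trip can be sketched with a plain dict and json standing in for ActionCtx.state and the pydantic serialization (the key string is illustrative; the real key comes from BaseAction.hash()):

```python
import json

# Sketch of set_state/get_state: per-action state is serialized to JSON
# and stored under a deterministic key in a shared state dict.
state: dict[str, str] = {}       # stands in for ActionCtx.state
action_key = "('col', 'ts')"     # stands in for BaseAction.hash()

def set_state(obj: dict) -> None:
    state[action_key] = json.dumps(obj)

def get_state() -> dict:
    return json.loads(state[action_key])

set_state({"column_index": 3})   # e.g. persisted during preprocess
restored = get_state()           # e.g. read back during postprocess
```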

GenerateAction pydantic-model

Bases: BaseAction, ABC

Action that generates net-new data for the DataFrame.

GenerateAction subclasses must implement generate. The phase field controls when generate runs:

  • GENERATE (default) -- after training, during synthetic data creation.
  • PREPROCESS -- before training.
  • POSTPROCESS -- after generation, for cleanup.

Create a new GenerateAction when you need to synthesize a column based on other columns, fill in faker data, etc.
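A hedged sketch of what a generate-phase hook does, deriving a net-new column from existing ones (the full-name derivation and column names are illustrative, not from the library):

```python
import pandas as pd

# Sketch of a GenerateAction-style generate method: synthesize a new
# column from columns already present in the frame.
def generate(df: pd.DataFrame) -> pd.DataFrame:
    df["full_name"] = df["first"] + " " + df["last"]
    return df

df = pd.DataFrame({"first": ["Ada"], "last": ["Lovelace"]})
out = generate(df)
```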

Fields:

Validators:

functions()

Route generate to the correct phase slot based on self.phase.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def functions(self) -> Functions:
    """Route ``generate`` to the correct phase slot based on ``self.phase``."""
    fns = Functions()
    if self.phase == ProcessPhase.PREPROCESS:
        fns.preprocess = self.generate
    elif self.phase == ProcessPhase.POSTPROCESS:
        fns.postprocess = self.generate
    elif self.phase == ProcessPhase.GENERATE:
        fns.generate = self.generate

    return fns

generate(df) abstractmethod

Generate new data based on the existing data in the DataFrame.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
@abstractmethod
def generate(self, df: pd.DataFrame) -> pd.DataFrame:
    """Generate new data based on the existing data in the DataFrame."""
    ...

generate_records(num_records)

Generate records without an existing DataFrame.

Creates an empty DataFrame with num_records rows, runs generate, and returns the result as a list of dicts.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def generate_records(self, num_records: int) -> list[dict[Hashable, Any]]:
    """Generate records without an existing DataFrame.

    Creates an empty DataFrame with ``num_records`` rows, runs ``generate``,
    and returns the result as a list of dicts.
    """
    df = pd.DataFrame(index=range(num_records))
    return self.generate(df).to_dict("records")
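The same flow can be sketched standalone: build an empty frame with the requested number of rows, run a generate function over it, and return records (the id column is a stand-in for a real generate implementation):

```python
import pandas as pd

# Mirrors generate_records: an empty DataFrame with num_records rows is
# passed through generate, and the result is returned as dicts.
def generate_records(num_records: int) -> list[dict]:
    df = pd.DataFrame(index=range(num_records))
    df["id"] = df.index  # stand-in for a real generate() implementation
    return df.to_dict("records")

records = generate_records(3)
```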

GenExpression pydantic-model

Bases: GenerateAction

Fields:

Validators:

expression = None pydantic-field

A jinja transforms_v2 expression that specifies the value of the column.

expressions = None pydantic-field

Similar to expression, but allows specifying multiple statements that are passed in sequence to transforms_v2. This can be useful for more complex sets of expressions.

dtype = None pydantic-field

If specified, the column will be cast as this dtype after generation.
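A hypothetical payload assembling the fields above (the key names, target column, and jinja expression are illustrative assumptions, not taken from the library's serialized schema):

```python
# Hypothetical GenExpression-style payload; "col" and the serialized key
# names are assumptions for illustration only.
gen_expression = {
    "type": "gen_expression",
    "col": "initials",
    "expression": "{{ first_name[0] }}{{ last_name[0] }}",
    "dtype": "string",
}
```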

GenRawExpression pydantic-model

Bases: GenerateAction

Low-level action that passes raw transforms_v2 update payloads.

Unlike GenExpression which targets a single column, this action accepts a full list of TransformsUpdate steps. Prefer GenExpression for simpler use cases.

Fields:

Validators:

ReplaceDataSource(**data) pydantic-model

Bases: BaseAction

Fields:

  • type_ (Literal['replace_datasource'])
  • col (str)
  • data_source (DataSourceT)

Validators:

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def __init__(self, /, **data: Any) -> None:
    super().__init__(**data)
    self.data_source = self.data_source.with_ctx(self._ctx)

State pydantic-model

Bases: BaseModel

Fields:

column_index pydantic-field

The index at which col appeared before preprocessing dropped it. If None, col was not in the original df.

GenDatetimeDistribution pydantic-model

Bases: GenerateAction

Generate a datetime from a provided datetime distribution.

Fields:

  • phase (ProcessPhase)
  • type_ (Literal['gen_datetime_distribution'])
  • col (str)
  • distribution (DatetimeDistributionT)

Validators:

DateConstraint pydantic-model

Bases: BaseAction

Fields:

  • type_ (Literal['date_constraint'])
  • colA (str)
  • colB (str)
  • operator (Literal['gt', 'ge', 'lt', 'le'])

Validators:
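What a DateConstraint plausibly checks per row can be sketched as a pandas comparison between the two columns (the column names and the "lt" choice are illustrative):

```python
import pandas as pd

# Sketch of a colA-vs-colB date constraint with operator "lt"
# (colA < colB); rows failing the mask would be rejected.
df = pd.DataFrame({
    "start": pd.to_datetime(["2024-01-01", "2024-06-01"]),
    "end": pd.to_datetime(["2024-02-01", "2024-05-01"]),
})
valid_mask = df["start"] < df["end"]  # operator "lt"
```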

ColAction pydantic-model

Bases: BaseAction, ABC

Action that operates on a single named column.

Useful for defining serialization/deserialization rules (e.g., datetime formatting, categorical validation) applied before training or after generation.

Fields:

  • name (str)

Validators:

DatetimeCol pydantic-model

Bases: ColAction

Fields:

  • name (str)
  • type_ (Literal['datetime'])
  • format (Optional[str])

Validators:

format = None pydantic-field

Human-readable format of the datetime (see strftime in stdlib). If not specified, we will attempt to autodetect.

ActionExecutor(**data) pydantic-model

Bases: BaseModel

Orchestrate a sequence of BaseAction instances across pipeline phases.

Groups each action's overridden methods by phase (preprocess, postprocess, validate_batch, generate) and runs them in order. Postprocess functions run in reverse order to properly unwind preprocessing transformations.

Fields:

  • actions (list[ActionT])
  • ctx (Optional[ActionCtx])
  • _phase_to_functions (dict[FunctionPhase, list[Callable]])
Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def __init__(self, /, **data: Any) -> None:
    super().__init__(**data)

    if self.ctx is None:
        self.ctx = ActionCtx()

    # Rebuild the actions with the `ctx` properly attached. This ensures that each
    # action (and its corresponding functions) properly references the same
    # `ctx` during runtime.
    self.actions = [a.with_ctx(self.ctx) for a in self.actions]
    self._phase_to_functions = defaultdict(list)

    for action in self.actions:
        fns = action.functions()
        if fn := fns.preprocess:
            self._phase_to_functions[ProcessPhase.PREPROCESS].append(fn)
        if fn := fns.postprocess:
            # postprocess should go in reverse order to properly unravel preprocess
            self._phase_to_functions[ProcessPhase.POSTPROCESS].insert(0, fn)
        if fn := fns.validate_batch:
            self._phase_to_functions[ValidateBatchPhase.VALIDATE_BATCH].append(fn)
        if fn := fns.generate:
            self._phase_to_functions[ProcessPhase.GENERATE].append(fn)
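The executor's ordering rule can be sketched in isolation: preprocess functions run in registration order, while postprocess functions are prepended, so the last preprocessing step is the first to be undone:

```python
from collections import defaultdict

# Sketch of the phase grouping above: append preprocess fns, insert
# postprocess fns at the front to unwind preprocessing in reverse.
phase_to_functions: dict[str, list[str]] = defaultdict(list)
for name in ["a", "b", "c"]:  # registration order of three actions
    phase_to_functions["preprocess"].append(name)
    phase_to_functions["postprocess"].insert(0, name)
```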

data_actions_fn(action_executor)

Applies an action executor to a dataframe.

Source code in src/nemo_safe_synthesizer/data_processing/actions/data_actions.py
def data_actions_fn(
    action_executor: ActionExecutor,
) -> utils.DataActionsFn:
    """Applies an action executor to a dataframe."""

    def fn(batch: pd.DataFrame, df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
        logger.info("Applying data_config postprocessing")

        logger.debug(f"Before postprocess: {utils.debug_fmt(batch)}")
        batch = action_executor.postprocess(batch)
        logger.debug(f"After postprocess: {utils.debug_fmt(batch)}")
        logger.info(
            f"Applying data_config validation on output batch of size [{len(batch)}]",
        )
        valid_df, rejected_df = action_executor.validate_batch(batch, df)
        logger.debug(f"valid_df after validate: {utils.debug_fmt(valid_df)}")
        logger.debug(f"rejected_df after validate: {utils.debug_fmt(rejected_df)}")

        logger.info(
            f"After data_config validation, output batch size is [{len(valid_df)}]",
        )
        return valid_df, rejected_df

    return fn