Skip to content

vllm_backend

vllm_backend

vLLM-based generation backend for tabular data synthesis.

Classes:

Name Description
VllmBackend

Generation backend using vLLM for high-throughput inference.

VllmBackend(config, model_metadata, workdir, **kwargs)

Bases: GeneratorBackend

Generation backend using vLLM for high-throughput inference.

Loads the base model with a LoRA adapter via vLLM and generates synthetic records in batches. Supports optional structured generation (regex or JSON schema) to constrain outputs.

Parameters:

Name Type Description Default
config SafeSynthesizerParameters

Pipeline configuration.

required
model_metadata ModelMetadata

Model metadata (prompt template, adapter path, sequence length, etc.).

required
workdir Workdir

Working directory containing the adapter and schema.

required
**kwargs

Additional options. use_detailed_logs (bool) enables verbose error messages (disabled by default to avoid leaking sensitive data).

{}

Methods:

Name Description
teardown

Release GPU memory and distributed resources. Idempotent -- safe to call multiple times.

initialize

Initialize and load the model into memory.

prepare_params

Parse parameters and configure the generation method.

generate

Generate synthetic tabular data in batches until the target count is reached.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def __init__(
    self,
    config: SafeSynthesizerParameters,
    model_metadata: ModelMetadata,
    workdir: Workdir,
    **kwargs,
):
    """Set up the vLLM generation backend.

    Args:
        config: Pipeline configuration.
        model_metadata: Model metadata (prompt template, adapter path,
            sequence length, etc.).
        workdir: Working directory containing the adapter and schema.
        **kwargs: Additional options. ``use_detailed_logs`` (bool) enables
            verbose error messages (disabled by default to avoid leaking
            sensitive data).
    """
    self.config = config
    self.model_metadata = model_metadata
    self.workdir = workdir
    self.remote = False

    # Column order and the generation prompt are both derived from the
    # JSON schema persisted in the working directory.
    self.schema = load_json(self.workdir.schema_file)
    self.columns = list(self.schema["properties"].keys())
    self.prompt = utils.create_schema_prompt(
        self.columns,
        instruction=self.model_metadata.instruction,
        prompt_template=self.model_metadata.prompt_config.template,
    )

    self.llm: vLLM | None = None
    self.logits_processors = []

    # Do not generate detailed error messages in production to avoid leaking sensitive data.
    self.use_detailed_logs = kwargs.pop("use_detailed_logs", False)

    self.gen_method: partial | None = None
    self._gen_method: partial | None = None
    self.processor = create_processor(self.schema, self.model_metadata, self.config)

    # Prefer the workdir's adapter when present; otherwise fall back to
    # the path recorded in the model metadata.
    adapter_path = self.workdir.adapter_path or self.model_metadata.adapter_path
    self.lora_req = LoRARequest("lora", 1, str(adapter_path)) if adapter_path else None
    self._torn_down = False

teardown()

Release GPU memory and distributed resources. Idempotent -- safe to call multiple times.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def teardown(self) -> None:
    """Release GPU memory and distributed resources. Idempotent -- safe to call multiple times."""
    if self._torn_down:
        return
    self._torn_down = True

    def _best_effort(cleanup_fn, failure_msg: str) -> None:
        # Teardown must never raise; failures are only surfaced at debug level.
        try:
            cleanup_fn()
        except Exception:
            logger.debug(failure_msg, exc_info=True)

    _best_effort(
        cleanup_dist_env_and_memory,
        "cleanup_dist_env_and_memory failed during teardown",
    )

    # Drop references to the engine and prepared generation callables so the
    # underlying model can be garbage-collected.
    self.llm = None
    self._gen_method = None
    self.gen_method = None

    _best_effort(cleanup_memory, "cleanup_memory failed during teardown")

initialize(**kwargs)

Initialize and load the model into memory.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def initialize(self, **kwargs) -> None:
    """Initialize and load the model into memory."""
    self._torn_down = False

    # vLLM 0.12+ accepts attention_config as a constructor arg (replaces the
    # VLLM_ATTENTION_BACKEND env var used in 0.11.x).
    backend_name = self.config.generation.attention_backend
    if backend_name in (None, "auto"):
        attention_config = None
    else:
        attention_config = {"backend": backend_name}

    # note this only works for single GPU setups
    gpu_fraction = get_max_vram().get(0, 0.8)

    # vllm requires this "config" to set the backend ahead of time.
    structured_outputs_config = StructuredOutputsConfig(
        backend=self.config.generation.structured_generation_backend,
        disable_fallback=True,
    )

    # Unsloth patches model attention forward functions with torch.compiler.disable().
    # vLLM compiles TransformersForCausalLM with fullgraph=True via @support_torch_compile.
    # PyTorch >= 2.9.1 changed fullgraph=True to raise immediately on torch.compiler.disable()
    # rather than silently breaking the graph (pytorch#8e83e24). This combination produces:
    #   torch._dynamo.exc.Unsupported: Skip inlining `torch.compiler.disable()`d function
    # Passing enforce_eager=True skips vLLM's torch.compile pipeline entirely for these runs.
    # check this when updating unsloth in the future.
    enforce_eager = self.config.training.use_unsloth is True

    with heartbeat("Model loading", logger_name=__name__, model=self.config.training.pretrained_model):
        self.llm = vLLM(
            model=self.config.training.pretrained_model,
            gpu_memory_utilization=gpu_fraction,
            enable_lora=True,
            max_lora_rank=self.config.training.lora_r,
            structured_outputs_config=structured_outputs_config,
            enforce_eager=enforce_eager,
            attention_config=attention_config,
        )

prepare_params(**kwargs)

Parse parameters and configure the generation method.

Parses a dictionary of parameters into vLLM SamplingParams, applying necessary transformations from our API to vLLM's API.

Parameters:

Name Type Description Default
**kwargs

Sampling parameters to configure.

{}
Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def prepare_params(self, **kwargs) -> None:
    """Parse parameters and configure the generation method.

    Parses a dictionary of parameters into vLLM ``SamplingParams``,
    applying necessary transformations from our API to vLLM's API.

    Args:
        **kwargs: Sampling parameters to configure.
    """
    # Structured-output constraints ride along with the sampling kwargs.
    kwargs["structured_outputs"] = self._build_structured_output_params()

    resolved_temperature = self._resolve_temperature(kwargs)
    param_mapping = self._get_api_param_mapping(resolved_temperature)
    vllm_kwargs = self._transform_kwargs_to_sampling_params(kwargs, param_mapping)

    real_params = SamplingParams(**vllm_kwargs)
    logger.debug(f"SamplingParams: {real_params!r}")

    # Create a partially parametrized version of the underlying vllm.LLM.generate
    # method that is immediately callable downstream.
    if TYPE_CHECKING:
        assert self.llm is not None
    self._gen_method = partial(
        self.llm.generate,
        sampling_params=real_params,
        lora_request=self.lora_req,
        # Show vLLM's tqdm progress bar only when debug logging is enabled.
        use_tqdm=logger.isEnabledFor(logging.DEBUG),
    )

generate(data_actions_fn=None)

Generate synthetic tabular data in batches until the target count is reached.

Iterates over generation batches, applying the processor to each LLM output, until the configured num_records target is met or a stopping condition fires.

Parameters:

Name Type Description Default
data_actions_fn DataActionsFn | None

Optional post-processing / validation function applied to each batch of generated records.

None

Returns:

Type Description
GenerateJobResults

Results containing the generated DataFrame and statistics.

Source code in src/nemo_safe_synthesizer/generation/vllm_backend.py
def generate(
    self,
    data_actions_fn: utils.DataActionsFn | None = None,
) -> GenerateJobResults:
    """Generate synthetic tabular data in batches until the target count is reached.

    Iterates over generation batches, applying the processor to each
    LLM output, until the configured ``num_records`` target is met or
    a stopping condition fires.

    Args:
        data_actions_fn: Optional post-processing / validation function
            applied to each batch of generated records.

    Returns:
        Results containing the generated DataFrame and statistics.
    """
    generation_start = time.monotonic()

    # Non-tabular processors need special tokens and stop strings preserved
    # in the raw output, so skipping/stripping is toggled off for them.
    need_special_token_outputs = not isinstance(self.processor, TabularDataProcessor)
    sampling_kwargs = dict(
        temperature=self.config.generation.temperature,
        repetition_penalty=self.config.generation.repetition_penalty,
        top_p=self.config.generation.top_p,
        top_k=FIXED_RUNTIME_GENERATE_ARGS["top_k"],
        min_p=FIXED_RUNTIME_GENERATE_ARGS["min_p"],
        logits_processors=self.logits_processors,
        max_tokens=self.model_metadata.max_seq_length,
        skip_special_tokens=not need_special_token_outputs,
        include_stop_str_in_output=need_special_token_outputs,
        ignore_eos=False,
    )

    # Builds self._gen_method (a partial over vllm.LLM.generate) from these kwargs.
    self.prepare_params(**sampling_kwargs)

    # The batches object collects batches and keeps track of the stopping condition.
    batches = GenerationBatches(
        target_num_records=self.config.generation.num_records,
        invalid_fraction_threshold=self.config.generation.invalid_fraction_threshold,
        patience=self.config.generation.patience,
        data_actions_fn=data_actions_fn,
    )

    with heartbeat(
        "Generation",
        logger_name=__name__,
        target_records=self.config.generation.num_records,
        progress_note=("Long stretches with no new records are normal."),
    ):
        while batches.num_valid_records < self.config.generation.num_records:
            # Generate a batch from prompts and process the responses.
            num_prompts = batches.get_next_num_prompts()
            start_time = time.perf_counter()
            # NOTE(review): sampling_kwargs is forwarded here even though
            # prepare_params already baked it into _gen_method — presumably
            # _generate_batch uses it for bookkeeping; confirm against its
            # implementation.
            batch: Batch = self._generate_batch(
                num_prompts_per_batch=num_prompts,
                batch=Batch(processor=self.processor),
                **sampling_kwargs,
            )
            duration = time.perf_counter() - start_time
            batches.add_batch(batch)

            # Log generation summary and progress.
            batch.log_summary(detailed_errors=self.use_detailed_logs)
            self._log_batch_timing_and_progress(
                batch=batch,
                duration=duration,
                num_records=self.config.generation.num_records,
                num_valid_records=batches.num_valid_records,
                batches=batches,
            )
            # Check if the generation job should stop.
            if batches.status in [
                GenerationStatus.STOP_NO_RECORDS,
                GenerationStatus.STOP_METRIC_REACHED,
            ]:
                break

    batches.job_complete()
    batches.log_status()

    # Cap the result at num_records only for a fully-complete, ungrouped run;
    # otherwise return everything collected.
    max_num_records = (
        self.config.generation.num_records
        if self.config.data.group_training_examples_by is None and batches.status == GenerationStatus.COMPLETE
        else None
    )

    generation_time_sec = time.monotonic() - generation_start
    self.elapsed_time = generation_time_sec
    self.gen_results = GenerateJobResults.from_batches(
        batches=batches,
        columns=self.columns,
        max_num_records=max_num_records,
        elapsed_time=self.elapsed_time,
    )

    return self.gen_results