# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""LLM Rails entry point."""
import asyncio
import importlib.util
import json
import logging
import os
import re
import threading
import time
import warnings
from functools import partial
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
Union,
cast,
overload,
)
from langchain_core.language_models import BaseChatModel, BaseLLM
from typing_extensions import Self
from nemoguardrails.actions.llm.generation import LLMGenerationActions
from nemoguardrails.actions.llm.utils import (
extract_bot_thinking_from_events,
extract_tool_calls_from_events,
get_and_clear_response_metadata_contextvar,
get_colang_history,
)
from nemoguardrails.actions.output_mapping import is_output_blocked
from nemoguardrails.actions.v2_x.generation import LLMGenerationActionsV2dotx
from nemoguardrails.colang import parse_colang_file
from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id, compute_context
from nemoguardrails.colang.v1_0.runtime.runtime import Runtime, RuntimeV1_0
from nemoguardrails.colang.v2_x.runtime.flows import Action, State
from nemoguardrails.colang.v2_x.runtime.runtime import RuntimeV2_x
from nemoguardrails.colang.v2_x.runtime.serialization import (
json_to_state,
state_to_json,
)
from nemoguardrails.context import (
explain_info_var,
generation_options_var,
llm_stats_var,
raw_llm_request,
streaming_handler_var,
)
from nemoguardrails.embeddings.index import EmbeddingsIndex
from nemoguardrails.embeddings.providers import register_embedding_provider
from nemoguardrails.embeddings.providers.base import EmbeddingModel
from nemoguardrails.exceptions import (
InvalidModelConfigurationError,
InvalidRailsConfigurationError,
StreamingNotSupportedError,
)
from nemoguardrails.kb.kb import KnowledgeBase
from nemoguardrails.llm.cache import CacheInterface, LFUCache
from nemoguardrails.llm.models.initializer import (
ModelInitializationError,
init_llm_model,
)
from nemoguardrails.logging.explain import ExplainInfo
from nemoguardrails.logging.processing_log import compute_generation_log
from nemoguardrails.logging.stats import LLMStats
from nemoguardrails.logging.verbose import set_verbose
from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
from nemoguardrails.rails.llm.config import (
EmbeddingSearchProvider,
OutputRailsStreamingConfig,
RailsConfig,
)
from nemoguardrails.rails.llm.options import (
GenerationLog,
GenerationOptions,
GenerationResponse,
RailsResult,
RailStatus,
RailType,
)
from nemoguardrails.rails.llm.utils import (
get_action_details_from_flow_id,
get_history_cache_key,
)
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
from nemoguardrails.utils import (
extract_error_json,
get_or_create_event_loop,
new_event_dict,
new_uuid,
)
log = logging.getLogger(__name__)
process_events_semaphore = asyncio.Semaphore(1)
class LLMRails:
"""Rails based on a given configuration."""
config: RailsConfig
llm: Optional[Union[BaseLLM, BaseChatModel]]
runtime: Runtime
def __init__(
self,
config: RailsConfig,
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
verbose: bool = False,
):
"""Initializes the LLMRails instance.
Args:
config: A rails configuration.
llm: An optional LLM engine to use. If provided, this will be used as the main LLM
and will take precedence over any main LLM specified in the config.
verbose: Whether the logging should be verbose or not.
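        Example:
            A minimal construction sketch; the path ``"./config"`` below is a
            placeholder for an actual guardrails configuration directory:
            ```python
            from nemoguardrails import LLMRails, RailsConfig

            config = RailsConfig.from_path("./config")  # placeholder path
            rails = LLMRails(config, verbose=True)
            ```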
"""
self.config = config
self.llm = llm
self.verbose = verbose
if self.verbose:
set_verbose(True, llm_calls=True)
# We allow the user to register additional embedding search providers, so we keep
# an index of them.
self.embedding_search_providers = {}
        # The default embeddings model uses the FastEmbed engine.
self.default_embedding_model = "all-MiniLM-L6-v2"
self.default_embedding_engine = "FastEmbed"
self.default_embedding_params = {}
# We keep a cache of the events history associated with a sequence of user messages.
# TODO: when we update the interface to allow to return a "state object", this
# should be removed
self.events_history_cache = {}
        # We also load the default flows, but only for Colang 1.0.
# TODO: decide on the default flows for 2.x.
if config.colang_version == "1.0":
# We also load the default flows from the `llm_flows.co` file in the current folder.
current_folder = os.path.dirname(__file__)
default_flows_file = "llm_flows.co"
default_flows_path = os.path.join(current_folder, default_flows_file)
with open(default_flows_path, "r") as f:
default_flows_content = f.read()
default_flows = parse_colang_file(default_flows_file, default_flows_content)["flows"]
# We mark all the default flows as system flows.
for flow_config in default_flows:
flow_config["is_system_flow"] = True
# We add the default flows to the config.
self.config.flows.extend(default_flows)
# We also need to load the content from the components library.
library_path = os.path.join(os.path.dirname(__file__), "../../library")
for root, dirs, files in os.walk(library_path):
for file in files:
# Extract the full path for the file
full_path = os.path.join(root, file)
if file.endswith(".co"):
log.debug(f"Loading file: {full_path}")
with open(full_path, "r", encoding="utf-8") as f:
content = parse_colang_file(file, content=f.read(), version=config.colang_version)
if not content:
continue
# We mark all the flows coming from the guardrails library as system flows.
for flow_config in content["flows"]:
flow_config["is_system_flow"] = True
# We load all the flows
self.config.flows.extend(content["flows"])
# And all the messages as well, if they have not been overwritten
for message_id, utterances in content.get("bot_messages", {}).items():
if message_id not in self.config.bot_messages:
self.config.bot_messages[message_id] = utterances
# Last but not least, we mark all the flows that are used in any of the rails
# as system flows (so they don't end up in the prompt).
rail_flow_ids = config.rails.input.flows + config.rails.output.flows + config.rails.retrieval.flows
for flow_config in self.config.flows:
if flow_config.get("id") in rail_flow_ids:
flow_config["is_system_flow"] = True
# We also mark them as subflows by default, to simplify the syntax
flow_config["is_subflow"] = True
# We check if the configuration or any of the imported ones have config.py modules.
config_modules = []
for _path in list(self.config.imported_paths.values() if self.config.imported_paths else []) + [
self.config.config_path
]:
if _path:
filepath = os.path.join(_path, "config.py")
if os.path.exists(filepath):
filename = os.path.basename(filepath)
spec = importlib.util.spec_from_file_location(filename, filepath)
if spec and spec.loader:
config_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config_module)
config_modules.append(config_module)
colang_version_to_runtime: Dict[str, Type[Runtime]] = {
"1.0": RuntimeV1_0,
"2.x": RuntimeV2_x,
}
if config.colang_version not in colang_version_to_runtime:
raise InvalidRailsConfigurationError(
f"Unsupported colang version: {config.colang_version}. Supported versions: {list(colang_version_to_runtime.keys())}"
)
# First, we initialize the runtime.
self.runtime = colang_version_to_runtime[config.colang_version](config=config, verbose=verbose)
        # If any of the config modules has an `init` function, we call it.
# We need to call this here because the `init` might register additional
# LLM providers.
for config_module in config_modules:
if hasattr(config_module, "init"):
config_module.init(self)
# If we have a customized embedding model, we'll use it.
for model in self.config.models:
if model.type == "embeddings":
self.default_embedding_model = model.model
self.default_embedding_engine = model.engine
self.default_embedding_params = model.parameters or {}
break
        # InteractionLogAdapters used for tracing.
        # We make sure they are created only after config.py has been loaded.
if config.tracing:
from nemoguardrails.tracing import create_log_adapters
self._log_adapters = create_log_adapters(config.tracing)
else:
self._log_adapters = None
# We run some additional checks on the config
self._validate_config()
# Next, we initialize the LLM engines (main engine and action engines if specified).
self._init_llms()
# Next, we initialize the LLM Generate actions and register them.
llm_generation_actions_class = (
LLMGenerationActions if config.colang_version == "1.0" else LLMGenerationActionsV2dotx
)
self.llm_generation_actions = llm_generation_actions_class(
config=config,
llm=self.llm,
llm_task_manager=self.runtime.llm_task_manager,
get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
verbose=verbose,
)
# If there's already an action registered, we don't override.
self.runtime.register_actions(self.llm_generation_actions, override=False)
# Next, we initialize the Knowledge Base
# There are still some edge cases not covered by nest_asyncio.
# Using a separate thread always for now.
loop = get_or_create_event_loop()
if True or check_sync_call_from_async_loop():
t = threading.Thread(target=asyncio.run, args=(self._init_kb(),))
t.start()
t.join()
else:
loop.run_until_complete(self._init_kb())
# We also register the kb as a parameter that can be passed to actions.
self.runtime.register_action_param("kb", self.kb)
# Reference to the general ExplainInfo object.
self.explain_info = None
def update_llm(self, llm):
"""Replace the main LLM with the provided one.
Arguments:
llm: The new LLM that should be used.
"""
self.llm = llm
self.llm_generation_actions.llm = llm
self.runtime.register_action_param("llm", llm)
def _validate_config(self):
"""Runs additional validation checks on the config."""
if self.config.colang_version == "1.0":
existing_flows_names = set([flow.get("id") for flow in self.config.flows])
else:
existing_flows_names = set([flow.get("name") for flow in self.config.flows])
for flow_name in self.config.rails.input.flows:
# content safety check input/output flows are special as they have parameters
flow_name = _normalize_flow_id(flow_name)
if flow_name not in existing_flows_names:
raise InvalidRailsConfigurationError(f"The provided input rail flow `{flow_name}` does not exist")
for flow_name in self.config.rails.output.flows:
flow_name = _normalize_flow_id(flow_name)
if flow_name not in existing_flows_names:
raise InvalidRailsConfigurationError(f"The provided output rail flow `{flow_name}` does not exist")
for flow_name in self.config.rails.retrieval.flows:
if flow_name not in existing_flows_names:
raise InvalidRailsConfigurationError(f"The provided retrieval rail flow `{flow_name}` does not exist")
# If both passthrough mode and single call mode are specified, we raise an exception.
if self.config.passthrough and self.config.rails.dialog.single_call.enabled:
raise InvalidRailsConfigurationError(
"The passthrough mode and the single call dialog rails mode can't be used at the same time. "
"The single call mode needs to use an altered prompt when prompting the LLM. "
)
async def _init_kb(self):
"""Initializes the knowledge base."""
self.kb = None
if not self.config.docs:
return
documents = [doc.content for doc in self.config.docs]
self.kb = KnowledgeBase(
documents=documents,
config=self.config.knowledge_base,
get_embedding_search_provider_instance=self._get_embeddings_search_provider_instance,
)
self.kb.init()
await self.kb.build()
def _prepare_model_kwargs(self, model_config):
"""
Prepare kwargs for model initialization, including API key from environment variable.
Args:
model_config: The model configuration object
Returns:
dict: The prepared kwargs for model initialization
"""
kwargs = model_config.parameters or {}
# If the optional API Key Environment Variable is set, add it to kwargs
if model_config.api_key_env_var:
api_key = os.environ.get(model_config.api_key_env_var)
if api_key:
kwargs["api_key"] = api_key
return kwargs
def _init_llms(self):
"""
Initializes the right LLM engines based on the configuration.
There can be multiple LLM engines and types that can be specified in the config.
The main LLM engine is the one that will be used for all the core guardrails generations.
Other LLM engines can be specified for use in specific actions.
The reason we provide an option for decoupling the main LLM engine from the action LLM
is to allow for flexibility in using specialized LLM engines for specific actions.
Raises:
ModelInitializationError: If any model initialization fails
"""
# If the user supplied an already-constructed LLM via the constructor we
# treat it as the *main* model, but **still** iterate through the
# configuration to load any additional models (e.g. `content_safety`).
if self.llm:
# If an LLM was provided via constructor, use it as the main LLM
# Log a warning if a main LLM is also specified in the config
if any(model.type == "main" for model in self.config.models):
log.warning(
"Both an LLM was provided via constructor and a main LLM is specified in the config. "
"The LLM provided via constructor will be used and the main LLM from config will be ignored."
)
self.runtime.register_action_param("llm", self.llm)
else:
# Otherwise, initialize the main LLM from the config
main_model = next((model for model in self.config.models if model.type == "main"), None)
if main_model and main_model.model:
kwargs = self._prepare_model_kwargs(main_model)
self.llm = init_llm_model(
model_name=main_model.model,
provider_name=main_model.engine,
mode="chat",
kwargs=kwargs,
)
self.runtime.register_action_param("llm", self.llm)
else:
log.warning("No main LLM specified in the config and no LLM provided via constructor.")
llms = dict()
for llm_config in self.config.models:
if llm_config.type in ["embeddings", "jailbreak_detection"]:
continue
# If a constructor LLM is provided, skip initializing any 'main' model from config
if self.llm and llm_config.type == "main":
continue
try:
model_name = llm_config.model
if not model_name:
raise InvalidModelConfigurationError(
f"`model` field must be set in model configuration: {llm_config.model_dump_json()}"
)
provider_name = llm_config.engine
kwargs = self._prepare_model_kwargs(llm_config)
mode = llm_config.mode
llm_model = init_llm_model(
model_name=model_name,
provider_name=provider_name,
mode=mode,
kwargs=kwargs,
)
# Configure the model based on its type
if llm_config.type == "main":
# If a main LLM was already injected, skip creating another
# one. Otherwise, create and register it.
if not self.llm:
self.llm = llm_model
self.runtime.register_action_param("llm", self.llm)
else:
model_name = f"{llm_config.type}_llm"
if not hasattr(self, model_name):
setattr(self, model_name, llm_model)
self.runtime.register_action_param(model_name, getattr(self, model_name))
# this is used for content safety and topic control
llms[llm_config.type] = getattr(self, model_name)
except ModelInitializationError as e:
log.error("Failed to initialize model: %s", str(e))
raise
except Exception as e:
log.error("Unexpected error initializing model: %s", str(e))
raise
self.runtime.register_action_param("llms", llms)
self._initialize_model_caches()
def _create_model_cache(self, model) -> LFUCache:
"""
Create cache instance for a model based on its configuration.
Args:
model: The model configuration object
Returns:
LFUCache: The cache instance
"""
if model.cache.maxsize <= 0:
raise ValueError(
f"Invalid cache maxsize for model '{model.type}': {model.cache.maxsize}. "
"Capacity must be greater than 0. Skipping cache creation."
)
stats_logging_interval = None
if model.cache.stats.enabled and model.cache.stats.log_interval is not None:
stats_logging_interval = model.cache.stats.log_interval
cache = LFUCache(
maxsize=model.cache.maxsize,
track_stats=model.cache.stats.enabled,
stats_logging_interval=stats_logging_interval,
)
log.info(f"Created cache for model '{model.type}' with maxsize {model.cache.maxsize}")
return cache
def _initialize_model_caches(self) -> None:
"""Initialize caches for configured models."""
        model_caches: Dict[str, CacheInterface] = {}
for model in self.config.models:
if model.type in ["main", "embeddings"]:
continue
if model.cache and model.cache.enabled:
cache = self._create_model_cache(model)
model_caches[model.type] = cache
                log.info("Initialized cache for model '%s'", model.type)
if model_caches:
self.runtime.register_action_param("model_caches", model_caches)
def _get_embeddings_search_provider_instance(
self, esp_config: Optional[EmbeddingSearchProvider] = None
) -> EmbeddingsIndex:
if esp_config is None:
esp_config = EmbeddingSearchProvider()
if esp_config.name == "default":
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
return BasicEmbeddingsIndex(
embedding_model=esp_config.parameters.get("embedding_model", self.default_embedding_model),
embedding_engine=esp_config.parameters.get("embedding_engine", self.default_embedding_engine),
embedding_params=esp_config.parameters.get("embedding_parameters", self.default_embedding_params),
cache_config=esp_config.cache,
# We make sure we also pass additional relevant params.
**{
k: v
for k, v in esp_config.parameters.items()
if k
in [
"use_batching",
"max_batch_size",
"matx_batch_hold",
"search_threshold",
]
and v is not None
},
)
else:
if esp_config.name not in self.embedding_search_providers:
raise Exception(f"Unknown embedding search provider: {esp_config.name}")
else:
kwargs = esp_config.parameters
return self.embedding_search_providers[esp_config.name](**kwargs)
def _get_events_for_messages(self, messages: List[dict], state: Any):
"""Return the list of events corresponding to the provided messages.
        Tries to find a prefix of messages for which we already have a list of events
        in the cache. The remaining messages are converted to events directly.
The reason this cache exists is that we want to benefit from events generated in
previous turns, which can't be computed again because it would be expensive (e.g.,
involving multiple LLM calls).
        Once an explicit state object is supported, this mechanism can be removed.
        Args:
            messages: The list of messages.
            state: The state object (used by Colang 2.x to resolve action-related events).
Returns:
A list of events.
"""
events = []
if self.config.colang_version == "1.0":
# We try to find the longest prefix of messages for which we have a cache
# of events.
p = len(messages) - 1
while p > 0:
cache_key = get_history_cache_key(messages[0:p])
if cache_key in self.events_history_cache:
events = self.events_history_cache[cache_key].copy()
break
p -= 1
# For the rest of the messages, we transform them directly into events.
# TODO: Move this to separate function once more types of messages are supported.
for idx in range(p, len(messages)):
msg = messages[idx]
if msg["role"] == "user":
events.append(
{
"type": "UtteranceUserActionFinished",
"final_transcript": msg["content"],
}
)
# If it's not the last message, we also need to add the `UserMessage` event
if idx != len(messages) - 1:
events.append(
{
"type": "UserMessage",
"text": msg["content"],
}
)
elif msg["role"] == "assistant":
if msg.get("tool_calls"):
events.append({"type": "BotToolCalls", "tool_calls": msg["tool_calls"]})
else:
action_uid = new_uuid()
start_event = new_event_dict(
"StartUtteranceBotAction",
script=msg["content"],
action_uid=action_uid,
)
finished_event = new_event_dict(
"UtteranceBotActionFinished",
final_script=msg["content"],
is_success=True,
action_uid=action_uid,
)
events.extend([start_event, finished_event])
elif msg["role"] == "context":
events.append({"type": "ContextUpdate", "data": msg["content"]})
elif msg["role"] == "event":
events.append(msg["event"])
elif msg["role"] == "system":
# Handle system messages - convert them to SystemMessage events
events.append({"type": "SystemMessage", "content": msg["content"]})
elif msg["role"] == "tool":
# For the last tool message, create grouped tool event and synthetic UserMessage
if idx == len(messages) - 1:
# Find the original user message for response generation
user_message = None
for prev_msg in reversed(messages[:idx]):
if prev_msg["role"] == "user":
user_message = prev_msg["content"]
break
if user_message:
# If tool input rails are configured, group all tool messages
if self.config.rails.tool_input.flows:
# Collect all tool messages for grouped processing
tool_messages = []
for tool_idx in range(len(messages)):
if messages[tool_idx]["role"] == "tool":
tool_messages.append(
{
"content": messages[tool_idx]["content"],
"name": messages[tool_idx].get("name", "unknown"),
"tool_call_id": messages[tool_idx].get("tool_call_id", ""),
}
)
events.append(
{
"type": "UserToolMessages",
"tool_messages": tool_messages,
}
)
else:
events.append({"type": "UserMessage", "text": user_message})
else:
for idx in range(len(messages)):
msg = messages[idx]
if msg["role"] == "user":
events.append(
{
"type": "UtteranceUserActionFinished",
"final_transcript": msg["content"],
}
)
elif msg["role"] == "assistant":
raise ValueError(
"Providing `assistant` messages as input is not supported for Colang 2.0 configurations."
)
elif msg["role"] == "context":
events.append({"type": "ContextUpdate", "data": msg["content"]})
elif msg["role"] == "event":
events.append(msg["event"])
elif msg["role"] == "system":
# Handle system messages - convert them to SystemMessage events
events.append({"type": "SystemMessage", "content": msg["content"]})
elif msg["role"] == "tool":
action_uid = msg["tool_call_id"]
return_value = msg["content"]
action: Action = state.actions[action_uid]
events.append(
new_event_dict(
f"{action.name}Finished",
action_uid=action_uid,
action_name=action.name,
status="success",
is_success=True,
return_value=return_value,
events=[],
)
)
return events
@staticmethod
def _ensure_explain_info() -> ExplainInfo:
"""Ensure that the ExplainInfo variable is present in the current context
Returns:
A ExplainInfo class containing the llm calls' statistics
"""
explain_info = explain_info_var.get()
if explain_info is None:
explain_info = ExplainInfo()
explain_info_var.set(explain_info)
return explain_info
async def generate_async(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[Union[dict, State]] = None,
streaming_handler: Optional[StreamingHandler] = None,
) -> Union[str, dict, GenerationResponse, Tuple[dict, dict]]:
"""Generate a completion or a next message.
The format for messages is the following:
```python
[
{"role": "context", "content": {"user_name": "John"}},
{"role": "user", "content": "Hello! How are you?"},
{"role": "assistant", "content": "I am fine, thank you!"},
{"role": "event", "event": {"type": "UserSilent"}},
...
]
```
Args:
prompt: The prompt to be used for completion.
messages: The history of messages to be used to generate the next message.
options: Options specific for the generation.
state: The state object that should be used as the starting point.
streaming_handler: If specified, and the config supports streaming, the
provided handler will be used for streaming.
Returns:
The completion (when a prompt is provided) or the next message.
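        Example:
            A minimal usage sketch, assuming ``rails`` is an initialized ``LLMRails``
            instance:
            ```python
            response = await rails.generate_async(
                messages=[{"role": "user", "content": "Hello! How are you?"}],
                options={"log": {"activated_rails": True}},
            )
            ```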
        Note: system messages are not yet supported.
        """
# convert options to gen_options of type GenerationOptions
gen_options: Optional[GenerationOptions] = None
if prompt is None and messages is None:
raise ValueError("Either prompt or messages must be provided.")
if prompt is not None and messages is not None:
raise ValueError("Only one of prompt or messages can be provided.")
if prompt is not None:
# Currently, we transform the prompt request into a single turn conversation
messages = [{"role": "user", "content": prompt}]
# If a state object is specified, then we switch to "generation options" mode.
# This is because we want the output to be a GenerationResponse which will contain
# the output state.
if state is not None:
# We deserialize the state if needed.
if isinstance(state, dict) and state.get("version", "1.0") == "2.x":
state = json_to_state(state["state"])
if options is None:
gen_options = GenerationOptions()
elif isinstance(options, dict):
gen_options = GenerationOptions(**options)
else:
gen_options = options
else:
# We allow options to be specified both as a dict and as an object.
if options and isinstance(options, dict):
gen_options = GenerationOptions(**options)
elif isinstance(options, GenerationOptions):
gen_options = options
elif options is None:
gen_options = None
else:
raise TypeError("options must be a dict or GenerationOptions")
# Save the generation options in the current async context.
# At this point, gen_options is either None or GenerationOptions
generation_options_var.set(gen_options)
if streaming_handler:
streaming_handler_var.set(streaming_handler)
# Initialize the object with additional explanation information.
# We allow this to also be set externally. This is useful when multiple parallel
# requests are made.
self.explain_info = self._ensure_explain_info()
raw_llm_request.set(messages)
# If we have generation options, we also add them to the context
if gen_options:
messages = [
{
"role": "context",
"content": {"generation_options": gen_options.model_dump()},
}
] + (messages or [])
# If the last message is from the assistant, rather than the user, then
# we move that to the `$bot_message` variable. This is to enable a more
# convenient interface. (only when dialog rails are disabled)
if messages and messages[-1]["role"] == "assistant" and gen_options and gen_options.rails.dialog is False:
# We already have the first message with a context update, so we use that
messages[0]["content"]["bot_message"] = messages[-1]["content"]
messages = messages[0:-1]
# TODO: Add support to load back history of events, next to history of messages
# This is important as without it, the LLM prediction is not as good.
t0 = time.time()
# Initialize the LLM stats
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
processing_log = []
# The array of events corresponding to the provided sequence of messages.
events = self._get_events_for_messages(messages, state) # type: ignore
if self.config.colang_version == "1.0":
# If we had a state object, we also need to prepend the events from the state.
state_events = []
if state:
assert isinstance(state, dict)
state_events = state["events"]
new_events = []
# Compute the new events.
try:
new_events = await self.runtime.generate_events(state_events + events, processing_log=processing_log)
output_state = None
except Exception as e:
log.error("Error in generate_async: %s", e, exc_info=True)
streaming_handler = streaming_handler_var.get()
if streaming_handler:
# Push an error chunk instead of None.
error_message = str(e)
error_dict = extract_error_json(error_message)
error_payload: str = json.dumps(error_dict)
await streaming_handler.push_chunk(error_payload)
# push a termination signal
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
# Re-raise the exact exception
raise
else:
# In generation mode, by default the bot response is an instant action.
instant_actions = ["UtteranceBotAction"]
if self.config.rails.actions.instant_actions is not None:
instant_actions = self.config.rails.actions.instant_actions
# Cast this explicitly to avoid certain warnings
runtime: RuntimeV2_x = cast(RuntimeV2_x, self.runtime)
# Compute the new events.
# In generation mode, the processing is always blocking, i.e., it waits for
# all local actions (sync and async).
new_events, output_state = await runtime.process_events(
events, state=state, instant_actions=instant_actions, blocking=True
)
# We also encode the output state as a JSON
output_state = {"state": state_to_json(output_state), "version": "2.x"}
# Extract and join all the messages from StartUtteranceBotAction events as the response.
responses = []
response_tool_calls = []
response_events = []
new_extra_events = []
exception = None
# The processing is different for Colang 1.0 and 2.0
if self.config.colang_version == "1.0":
for event in new_events:
if event["type"] == "StartUtteranceBotAction":
# Check if we need to remove a message
if event["script"] == "(remove last message)":
responses = responses[0:-1]
else:
responses.append(event["script"])
elif event["type"].endswith("Exception"):
exception = event
else:
for event in new_events:
start_action_match = re.match(r"Start(.*Action)", event["type"])
if start_action_match:
action_name = start_action_match[1]
# TODO: is there an elegant way to extract just the arguments?
arguments = {
k: v
for k, v in event.items()
if k != "type"
and k != "uid"
and k != "event_created_at"
and k != "source_uid"
and k != "action_uid"
}
response_tool_calls.append(
{
"id": event["action_uid"],
"type": "function",
"function": {"name": action_name, "arguments": arguments},
}
)
elif event["type"] == "UtteranceBotActionFinished":
responses.append(event["final_script"])
else:
# We just append the event
response_events.append(event)
if exception:
new_message: dict = {"role": "exception", "content": exception}
else:
# Ensure all items in responses are strings
responses = [str(response) if not isinstance(response, str) else response for response in responses]
new_message: dict = {"role": "assistant", "content": "\n".join(responses)}
if response_tool_calls:
new_message["tool_calls"] = response_tool_calls
if response_events:
new_message["events"] = response_events
if self.config.colang_version == "1.0":
events.extend(new_events)
events.extend(new_extra_events)
# If a state object is not used, then we use the implicit caching
if state is None:
# Save the new events in the history and update the cache
cache_key = get_history_cache_key((messages) + [new_message]) # type: ignore
self.events_history_cache[cache_key] = events
else:
output_state = {"events": events}
# If logging is enabled, we log the conversation
# TODO: add support for logging flag
self.explain_info.colang_history = get_colang_history(events)
if self.verbose:
log.info(f"Conversation history so far: \n{self.explain_info.colang_history}")
total_time = time.time() - t0
log.info("--- :: Total processing took %.2f seconds. LLM Stats: %s" % (total_time, llm_stats))
# If there is a streaming handler, we make sure we close it now
streaming_handler = streaming_handler_var.get()
if streaming_handler:
# print("Closing the stream handler explicitly")
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
        # If tracing is enabled, we need to set the GenerationLog attributes.
original_log_options = None
if self.config.tracing.enabled:
if gen_options is None:
gen_options = GenerationOptions()
else:
# create a copy of the gen_options to avoid modifying the original
gen_options = gen_options.model_copy(deep=True)
original_log_options = gen_options.log.model_copy(deep=True)
            # Enable the log options; this is aggressive, but they are required for tracing.
if (
not gen_options.log.activated_rails
or not gen_options.log.llm_calls
or not gen_options.log.internal_events
):
gen_options.log.activated_rails = True
gen_options.log.llm_calls = True
gen_options.log.internal_events = True
tool_calls = extract_tool_calls_from_events(new_events)
llm_metadata = get_and_clear_response_metadata_contextvar()
reasoning_content = extract_bot_thinking_from_events(new_events)
# If we have generation options, we prepare a GenerationResponse instance.
if gen_options:
# If a prompt was used, we only need to return the content of the message.
if prompt:
res = GenerationResponse(response=new_message["content"])
else:
res = GenerationResponse(response=[new_message])
if reasoning_content:
res.reasoning_content = reasoning_content
if tool_calls:
res.tool_calls = tool_calls
if llm_metadata:
res.llm_metadata = llm_metadata
if self.config.colang_version == "1.0":
# If output variables are specified, we extract their values
if gen_options and gen_options.output_vars:
context = compute_context(events)
output_vars = gen_options.output_vars
if isinstance(output_vars, list):
# If we have only a selection of keys, we filter to only that.
res.output_data = {k: context.get(k) for k in output_vars}
else:
# Otherwise, we return the full context
res.output_data = context
_log = compute_generation_log(processing_log)
# Include information about activated rails and LLM calls if requested
log_options = gen_options.log if gen_options else None
if log_options and (log_options.activated_rails or log_options.llm_calls):
res.log = GenerationLog()
# We always include the stats
res.log.stats = _log.stats
if log_options.activated_rails:
res.log.activated_rails = _log.activated_rails
if log_options.llm_calls:
res.log.llm_calls = []
for activated_rail in _log.activated_rails:
for executed_action in activated_rail.executed_actions:
res.log.llm_calls.extend(executed_action.llm_calls)
# Include internal events if requested
if log_options and log_options.internal_events:
if res.log is None:
res.log = GenerationLog()
res.log.internal_events = new_events
# Include the Colang history if requested
if log_options and log_options.colang_history:
if res.log is None:
res.log = GenerationLog()
res.log.colang_history = get_colang_history(events)
# Include the raw llm output if requested
if gen_options and gen_options.llm_output:
# Currently, we include the output from the generation LLM calls.
for activated_rail in _log.activated_rails:
if activated_rail.type == "generation":
for executed_action in activated_rail.executed_actions:
for llm_call in executed_action.llm_calls:
res.llm_output = llm_call.raw_response
else:
if gen_options and gen_options.output_vars:
raise ValueError("The `output_vars` option is not supported for Colang 2.0 configurations.")
log_options = gen_options.log if gen_options else None
if log_options and (
log_options.activated_rails
or log_options.llm_calls
or log_options.internal_events
or log_options.colang_history
):
raise ValueError("The `log` option is not supported for Colang 2.0 configurations.")
if gen_options and gen_options.llm_output:
raise ValueError("The `llm_output` option is not supported for Colang 2.0 configurations.")
# Include the state
if state is not None:
res.state = output_state
if self.config.tracing.enabled:
            # TODO: move this to the top once the circular dependency with eval is resolved.
            # Lazy import to avoid the circular dependency.
from nemoguardrails.tracing import Tracer
span_format = getattr(self.config.tracing, "span_format", "opentelemetry")
enable_content_capture = getattr(self.config.tracing, "enable_content_capture", False)
# Create a Tracer instance with instantiated adapters and span configuration
tracer = Tracer(
input=messages,
response=res,
adapters=self._log_adapters,
span_format=span_format,
enable_content_capture=enable_content_capture,
)
await tracer.export_async()
# respect original log specification, if tracing added information to the output
if original_log_options:
if not any(
(
original_log_options.internal_events,
original_log_options.activated_rails,
original_log_options.llm_calls,
original_log_options.colang_history,
)
):
res.log = None
else:
# Ensure res.log exists before setting attributes
if res.log is not None:
if not original_log_options.internal_events:
res.log.internal_events = []
if not original_log_options.activated_rails:
res.log.activated_rails = []
if not original_log_options.llm_calls:
res.log.llm_calls = []
return res
else:
# If a prompt is used, we only return the content of the message.
if reasoning_content:
thinking_trace = f"<think>{reasoning_content}</think>\n"
new_message["content"] = thinking_trace + new_message["content"]
if prompt:
return new_message["content"]
else:
if tool_calls:
new_message["tool_calls"] = tool_calls
return new_message
def _validate_streaming_with_output_rails(self) -> None:
if len(self.config.rails.output.flows) > 0 and (
not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled
):
raise StreamingNotSupportedError(
"stream_async() cannot be used when output rails are configured but "
"rails.output.streaming.enabled is False. Either set "
"rails.output.streaming.enabled to True in your configuration, or use "
"generate_async() instead of stream_async()."
)
@overload
def stream_async(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[Union[dict, State]] = None,
include_metadata: Literal[False] = False,
generator: Optional[AsyncIterator[str]] = None,
include_generation_metadata: Optional[bool] = None,
) -> AsyncIterator[str]: ...
@overload
def stream_async(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[Union[dict, State]] = None,
include_metadata: Literal[True] = ...,
generator: Optional[AsyncIterator[str]] = None,
include_generation_metadata: Optional[bool] = None,
) -> AsyncIterator[Union[str, dict]]: ...
def stream_async(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[Union[dict, State]] = None,
include_metadata: Optional[bool] = False,
generator: Optional[AsyncIterator[str]] = None,
include_generation_metadata: Optional[bool] = None,
) -> AsyncIterator[Union[str, dict]]:
"""Simplified interface for getting directly the streamed tokens from the LLM."""
if include_generation_metadata is not None:
warnings.warn(
"include_generation_metadata is deprecated, use include_metadata instead. "
"It will be removed in version 0.22.0.",
DeprecationWarning,
stacklevel=2,
)
include_metadata = include_generation_metadata
self._validate_streaming_with_output_rails()
# if an external generator is provided, use it directly
if generator:
if self.config.rails.output.streaming and self.config.rails.output.streaming.enabled:
return self._run_output_rails_in_streaming(
streaming_handler=generator,
output_rails_streaming_config=self.config.rails.output.streaming,
messages=messages,
prompt=prompt,
)
else:
return generator
self.explain_info = self._ensure_explain_info()
streaming_handler = StreamingHandler(include_metadata=include_metadata)
# Create a properly managed task with exception handling
async def _generation_task():
try:
await self.generate_async(
prompt=prompt,
messages=messages,
streaming_handler=streaming_handler,
options=options,
state=state,
)
except Exception as e:
# If an exception occurs during generation, push it to the streaming handler as a json string
# This ensures the streaming pipeline is properly terminated
log.error(f"Error in generation task: {e}", exc_info=True)
error_message = str(e)
error_dict = extract_error_json(error_message)
error_payload = json.dumps(error_dict)
await streaming_handler.push_chunk(error_payload)
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore
task = asyncio.create_task(_generation_task())
# Store task reference to prevent garbage collection and ensure proper cleanup
if not hasattr(self, "_active_tasks"):
self._active_tasks = set()
self._active_tasks.add(task)
# Clean up task when it's done
def task_done_callback(task):
self._active_tasks.discard(task)
task.add_done_callback(task_done_callback)
        # When output rails streaming is enabled, we wrap the streaming handler.
if self.config.rails.output.streaming and self.config.rails.output.streaming.enabled:
base_iterator = self._run_output_rails_in_streaming(
streaming_handler=streaming_handler,
output_rails_streaming_config=self.config.rails.output.streaming,
messages=messages,
prompt=prompt,
)
else:
base_iterator = streaming_handler
async def wrapped_iterator():
try:
async for chunk in base_iterator:
if chunk is not None:
yield chunk
finally:
await task
return wrapped_iterator()
def generate(
self,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
options: Optional[Union[dict, GenerationOptions]] = None,
state: Optional[dict] = None,
):
"""Synchronous version of generate_async."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate` inside async code. "
"You should replace with `await generate_async(...)` or use `nest_asyncio.apply()`."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(
self.generate_async(
prompt=prompt,
messages=messages,
options=options,
state=state,
)
)
async def generate_events_async(
self,
events: List[dict],
) -> List[dict]:
"""Generate the next events based on the provided history.
The format for events is the following:
```python
[
{"type": "...", ...},
...
]
```
        Args:
            events: The history of events to be used to generate the next events.
Returns:
            The newly generated event(s).
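        Example:
            A minimal sketch, assuming ``rails`` is an initialized ``LLMRails``
            instance:
            ```python
            new_events = await rails.generate_events_async(
                [{"type": "UtteranceUserActionFinished", "final_transcript": "Hello!"}]
            )
            ```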
"""
t0 = time.time()
# Initialize the LLM stats
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
# Compute the new events.
processing_log = []
new_events = await self.runtime.generate_events(events, processing_log=processing_log)
# If logging is enabled, we log the conversation
# TODO: add support for logging flag
if self.verbose:
history = get_colang_history(events)
log.info(f"Conversation history so far: \n{history}")
log.info("--- :: Total processing took %.2f seconds." % (time.time() - t0))
log.info("--- :: Stats: %s" % llm_stats)
return new_events
def generate_events(
self,
events: List[dict],
) -> List[dict]:
"""Synchronous version of `LLMRails.generate_events_async`."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate_events` inside async code. "
"You should replace with `await generate_events_async(...)` or use `nest_asyncio.apply()`."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.generate_events_async(events=events))
async def process_events_async(
self,
events: List[dict],
state: Union[Optional[dict], State] = None,
blocking: bool = False,
) -> Tuple[List[dict], Union[dict, State]]:
"""Process a sequence of events in a given state.
The events will be processed one by one, in the input order.
Args:
events: A sequence of events that needs to be processed.
state: The state that should be used as the starting point. If not provided,
                a clean state will be used.
            blocking: Whether to block until all local actions have finished processing.
Returns:
            A tuple ``(output_events, output_state)`` containing the output events and
            the resulting state.
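        Example:
            A minimal sketch, assuming ``rails`` is an initialized ``LLMRails``
            instance; the returned state can be passed back on the next call:
            ```python
            events = [{"type": "UtteranceUserActionFinished", "final_transcript": "Hi!"}]
            output_events, output_state = await rails.process_events_async(events)
            ```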
"""
t0 = time.time()
llm_stats = LLMStats()
llm_stats_var.set(llm_stats)
# Compute the new events.
# We need to protect 'process_events' to be called only once at a time
# TODO (cschueller): Why is this?
async with process_events_semaphore:
output_events, output_state = await self.runtime.process_events(events, state, blocking)
took = time.time() - t0
        # Skip logging when processing finished almost instantly (e.g., there were no events).
if took > 0.1:
log.info("--- :: Total processing took %.2f seconds." % took)
log.info("--- :: Stats: %s" % llm_stats)
return output_events, output_state
def process_events(
self,
events: List[dict],
state: Union[Optional[dict], State] = None,
blocking: bool = False,
) -> Tuple[List[dict], Union[dict, State]]:
"""Synchronous version of `LLMRails.process_events_async`."""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `generate_events` inside async code. "
"You should replace with `await generate_events_async(...)."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.process_events_async(events, state, blocking))
async def check_async(
self,
messages: List[dict],
rail_types: Optional[List[RailType]] = None,
) -> RailsResult:
"""Run rails on messages based on their content (asynchronous).
When ``rail_types`` is not provided, automatically determines which rails
to run based on message roles:
- Only user messages: runs input rails
- Only assistant messages: runs output rails
- Both user and assistant messages: runs both input and output rails
- No user/assistant messages: logs warning and returns passing result
When ``rail_types`` is provided, runs exactly the specified rail types,
skipping the auto-detection logic.
Args:
messages: List of message dicts with 'role' and 'content' fields.
Messages can contain any roles, but only user/assistant roles
determine which rails execute when ``rail_types`` is not provided.
            rail_types: Optional list of rail types to run, e.g.
``[RailType.INPUT]`` or ``[RailType.OUTPUT]``.
When provided, overrides automatic detection.
Returns:
RailsResult containing:
- status: PASSED, MODIFIED, or BLOCKED
- content: The final content after rails processing
- rail: Name of the rail that blocked (if blocked)
Examples:
Check user input (auto-detected):
result = await rails.check_async([{"role": "user", "content": "Hello!"}])
if result.status == RailStatus.BLOCKED:
print(f"Blocked by: {result.rail}")
Check bot output with context (auto-detected):
result = await rails.check_async([
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi there!"}
])
Run only input rails explicitly:
result = await rails.check_async(messages, rail_types=[RailType.INPUT])
"""
if rail_types is not None:
options: Optional[dict] = {"rails": [r.value for r in rail_types]}
else:
options = _determine_rails_from_messages(messages)
if options is None:
last_content = messages[-1].get("content", "") if messages else ""
return RailsResult(status=RailStatus.PASSED, content=last_content)
rails_to_run = options["rails"]
if "output" in rails_to_run:
original_content = _get_last_content_by_role(messages, "assistant")
else:
original_content = _get_last_content_by_role(messages, "user")
messages = _normalize_messages_for_rails(messages, rails_to_run)
options["log"] = {"activated_rails": True}
response = await self.generate_async(messages=messages, options=options)
if not isinstance(response, GenerationResponse):
raise RuntimeError(f"Expected GenerationResponse, got {type(response).__name__}")
blocking_rail = _get_blocking_rail(response)
result_content = _get_last_response_content(response)
if blocking_rail:
return RailsResult(status=RailStatus.BLOCKED, content=result_content, rail=blocking_rail)
if result_content != original_content:
return RailsResult(status=RailStatus.MODIFIED, content=result_content)
return RailsResult(status=RailStatus.PASSED, content=result_content)
def check(
self,
messages: List[dict],
rail_types: Optional[List[RailType]] = None,
) -> RailsResult:
"""Run rails on messages based on their content (synchronous).
This is a synchronous wrapper around check_async().
Args:
messages: List of message dicts with 'role' and 'content' fields.
            rail_types: Optional list of rail types to run. See check_async() for details.
Returns:
RailsResult containing status, content, and optional blocking rail name.
"""
if check_sync_call_from_async_loop():
raise RuntimeError(
"You are using the sync `check` inside async code. You should replace with `await check_async(...)`."
)
loop = get_or_create_event_loop()
return loop.run_until_complete(self.check_async(messages, rail_types=rail_types))
def register_action(self, action: Callable, name: Optional[str] = None) -> Self:
"""Register a custom action for the rails configuration."""
self.runtime.register_action(action, name)
return self
def register_action_param(self, name: str, value: Any) -> Self:
"""Registers a custom action parameter."""
self.runtime.register_action_param(name, value)
return self
def register_filter(self, filter_fn: Callable, name: Optional[str] = None) -> Self:
"""Register a custom filter for the rails configuration."""
self.runtime.llm_task_manager.register_filter(filter_fn, name)
return self
def register_output_parser(self, output_parser: Callable, name: str) -> Self:
"""Register a custom output parser for the rails configuration."""
self.runtime.llm_task_manager.register_output_parser(output_parser, name)
return self
def register_prompt_context(self, name: str, value_or_fn: Any) -> Self:
"""Register a value to be included in the prompt context.
        Args:
            name: The name of the variable or function that will be used.
            value_or_fn: The value or function that will be used to generate the value.
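        Example:
            A minimal sketch registering a dynamically computed value; the variable
            name ``current_date`` is illustrative only:
            ```python
            from datetime import datetime

            rails.register_prompt_context(
                "current_date", lambda: datetime.now().strftime("%Y-%m-%d")
            )
            ```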
"""
self.runtime.llm_task_manager.register_prompt_context(name, value_or_fn)
return self
def register_embedding_search_provider(self, name: str, cls: Type[EmbeddingsIndex]) -> Self:
"""Register a new embedding search provider.
Args:
name: The name of the embedding search provider that will be used.
            cls: The class that will be used to generate and search embeddings.
"""
self.embedding_search_providers[name] = cls
return self
def register_embedding_provider(self, cls: Type[EmbeddingModel], name: Optional[str] = None) -> Self:
"""Register a custom embedding provider.
Args:
            cls (Type[EmbeddingModel]): The embedding model class.
name (str): The name of the embedding engine. If available in the model, it will be used.
Raises:
ValueError: If the engine name is not provided and the model does not have an engine name.
ValueError: If the model does not have 'encode' or 'encode_async' methods.
"""
register_embedding_provider(engine_name=name, model=cls)
return self
def explain(self) -> ExplainInfo:
"""Helper function to return the latest ExplainInfo object."""
if self.explain_info is None:
self.explain_info = self._ensure_explain_info()
return self.explain_info
def __getstate__(self):
return {"config": self.config}
def __setstate__(self, state):
if state["config"].config_path:
config = RailsConfig.from_path(state["config"].config_path)
else:
config = state["config"]
self.__init__(config=config, verbose=False)
async def _run_output_rails_in_streaming(
self,
streaming_handler: AsyncIterator[str],
output_rails_streaming_config: OutputRailsStreamingConfig,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
stream_first: Optional[bool] = None,
) -> AsyncIterator[str]:
"""
1. Buffers tokens from 'streaming_handler' via BufferStrategy.
        2. Runs the output rails flows on each buffered chunk (sequentially, or in
           parallel when parallel mode is enabled).
        3. Yields the chunk if it is not blocked, or a JSON error payload if it is blocked.
"""
def _get_last_context_message(
messages: Optional[List[dict]] = None,
) -> dict:
if messages is None:
return {}
for message in reversed(messages):
if message.get("role") == "context":
return message
return {}
def _get_latest_user_message(
messages: Optional[List[dict]] = None,
) -> dict:
if messages is None:
return {}
for message in reversed(messages):
if message.get("role") == "user":
return message
return {}
def _prepare_context_for_parallel_rails(
chunk_str: str,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
) -> dict:
"""Prepare context for parallel rails execution."""
context_message = _get_last_context_message(messages)
user_message = prompt or _get_latest_user_message(messages)
context = {
"user_message": user_message,
"bot_message": chunk_str,
}
if context_message:
context.update(context_message["content"])
return context
def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]:
"""Create events for running output rails on a chunk."""
return [
{"type": "ContextUpdate", "data": context},
{"type": "BotMessage", "text": chunk_str},
]
def _prepare_params(
flow_id: str,
action_name: str,
bot_response_chunk: str,
prompt: Optional[str] = None,
messages: Optional[List[dict]] = None,
action_params: Dict[str, Any] = {},
):
context_message = _get_last_context_message(messages)
user_message = prompt or _get_latest_user_message(messages)
context = {
"user_message": user_message,
"bot_message": bot_response_chunk,
}
if context_message:
context.update(context_message["content"])
model_name = flow_id.split("$")[-1].split("=")[-1].strip('"')
            # We pass the action params that are defined in the flow.
            # Caveat: some flows (e.g., prompt_security) use bot_response=$bot_message,
            # so we resolve those placeholders in action_params here.
for key, value in action_params.items():
if value == "$bot_message":
action_params[key] = bot_response_chunk
elif value == "$user_message":
action_params[key] = user_message
return {
                # TODO: are there other context variables that need to be passed?
                # Passing events to compute the context was not successful, and the
                # context var approach failed because it runs in a different context.
"context": context,
"llm_task_manager": self.runtime.llm_task_manager,
"config": self.config,
"model_name": model_name,
"llms": self.runtime.registered_action_params.get("llms", {}),
"llm": self.runtime.registered_action_params.get(f"{action_name}_llm", self.llm),
**action_params,
}
buffer_strategy = get_buffer_strategy(output_rails_streaming_config)
output_rails_flows_id = self.config.rails.output.flows
stream_first = stream_first or output_rails_streaming_config.stream_first
get_action_details = partial(get_action_details_from_flow_id, flows=self.config.flows)
parallel_mode = getattr(self.config.rails.output, "parallel", False)
async for chunk_batch in buffer_strategy(streaming_handler):
user_output_chunks = chunk_batch.user_output_chunks
# format processing_context for output rails processing (needs full context)
bot_response_chunk = buffer_strategy.format_chunks(chunk_batch.processing_context)
            # Check whether user_output_chunks is a list of individual chunks or a JSON string;
            # by convention, a JSON string means an error occurred and the error dict is stored as JSON.
if not isinstance(user_output_chunks, list):
try:
json.loads(user_output_chunks)
yield user_output_chunks
return
except (json.JSONDecodeError, TypeError):
# if it's not JSON, treat it as empty list
user_output_chunks = []
if stream_first:
# yield the individual chunks directly from the buffer strategy
for chunk in user_output_chunks:
yield chunk
if parallel_mode:
try:
context = _prepare_context_for_parallel_rails(bot_response_chunk, prompt, messages)
events = _create_events_for_chunk(bot_response_chunk, context)
flows_with_params = {}
for flow_id in output_rails_flows_id:
action_name, action_params = get_action_details(flow_id)
params = _prepare_params(
flow_id=flow_id,
action_name=action_name,
bot_response_chunk=bot_response_chunk,
prompt=prompt,
messages=messages,
action_params=action_params,
)
flows_with_params[flow_id] = {
"action_name": action_name,
"params": params,
}
result_tuple = await self.runtime.action_dispatcher.execute_action(
"run_output_rails_in_parallel_streaming",
{
"flows_with_params": flows_with_params,
"events": events,
},
)
# ActionDispatcher.execute_action always returns (result, status)
result, status = result_tuple
if status != "success":
log.error(f"Parallel rails execution failed with status: {status}")
# continue processing the chunk even if rails fail
pass
else:
# if there are any stop events, content was blocked or internal error occurred
result_events = getattr(result, "events", None)
if result_events:
# extract the flow info from the first stop event
stop_event = result_events[0]
blocked_flow = stop_event.get("flow_id", "output rails")
error_type = stop_event.get("error_type")
if error_type == "internal_error":
error_message = stop_event.get("error_message", "Unknown error")
reason = f"Internal error in {blocked_flow} rail: {error_message}"
error_code = "rail_execution_failure"
error_type = "internal_error"
else:
reason = f"Blocked by {blocked_flow} rails."
error_code = "content_blocked"
error_type = "guardrails_violation"
error_data = {
"error": {
"message": reason,
"type": error_type,
"param": blocked_flow,
"code": error_code,
}
}
yield json.dumps(error_data)
return
except Exception as e:
log.error(f"Error in parallel rail execution: {e}")
# don't block the stream for rail execution errors
# continue processing the chunk
pass
# update explain info for parallel mode
self.explain_info = self._ensure_explain_info()
else:
for flow_id in output_rails_flows_id:
action_name, action_params = get_action_details(flow_id)
params = _prepare_params(
flow_id=flow_id,
action_name=action_name,
bot_response_chunk=bot_response_chunk,
prompt=prompt,
messages=messages,
action_params=action_params,
)
result = await self.runtime.action_dispatcher.execute_action(action_name, params)
self.explain_info = self._ensure_explain_info()
action_func = self.runtime.action_dispatcher.get_action(action_name)
# Use the mapping to decide if the result indicates blocked content.
if is_output_blocked(result, action_func):
reason = f"Blocked by {flow_id} rails."
# return the error as a plain JSON string (not in SSE format)
# NOTE: When integrating with the OpenAI Python client, the server code should:
# 1. detect this JSON error object in the stream
# 2. terminate the stream
# 3. format the error following OpenAI's SSE format
# the OpenAI client will then properly raise an APIError with this error message
error_data = {
"error": {
"message": reason,
"type": "guardrails_violation",
"param": flow_id,
"code": "content_blocked",
}
}
# return as plain JSON: the server should detect this JSON and convert it to an HTTP error
yield json.dumps(error_data)
return
if not stream_first:
# yield the individual chunks directly from the buffer strategy
for chunk in user_output_chunks:
yield chunk
def _determine_rails_from_messages(messages: List[dict]) -> Optional[dict]:
    roles = {msg.get("role") for msg in messages}
has_user = "user" in roles
has_assistant = "assistant" in roles
if not has_user and not has_assistant:
log.warning(
"check() called with no user or assistant messages. "
"Only system, context, or tool messages found. "
"Returning passing result without running rails."
)
return None
if has_user and has_assistant:
return {"rails": ["input", "output"]}
if has_user:
return {"rails": ["input"]}
return {"rails": ["output"]}
def _normalize_messages_for_rails(
messages: List[dict],
rails: List[str],
) -> List[dict]:
if rails == ["output"]:
has_user = any(msg.get("role") == "user" for msg in messages)
if not has_user:
return [{"role": "user", "content": ""}] + messages
return messages
def _get_last_content_by_role(messages: List[dict], role: str) -> str:
for msg in reversed(messages):
if msg.get("role") == role:
return msg.get("content", "")
return ""
def _get_blocking_rail(response: "GenerationResponse") -> Optional[str]:
if response.log and response.log.activated_rails:
for rail in response.log.activated_rails:
if rail.stop:
return rail.name
return None
def _get_last_response_content(response: "GenerationResponse") -> str:
if isinstance(response.response, list) and response.response:
return response.response[-1].get("content", "")
if isinstance(response.response, str):
return response.response
return ""