# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public LLM types for NeMo Guardrails.

This module defines the stable, public dataclasses and Protocols that make
up the LLM interop surface for NeMo Guardrails. The data types are
implemented as plain Python dataclasses (not Pydantic) so they remain
lightweight and introduce no additional runtime dependencies for
downstream integrators.

The public surface defined here is stable across minor versions: breaking
changes are reserved for major version bumps. Custom ``LLMModel`` and
``LLMFramework`` implementations should import the relevant types from
this module so they stay aligned with the canonical definitions.
"""
import json
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Literal, Optional, Protocol, Union, runtime_checkable
# Public names exported by a star-import of this module (kept alphabetical).
__all__ = [
    "ChatMessage",
    "FinishReason",
    "LLMFramework",
    "LLMModel",
    "LLMResponse",
    "LLMResponseChunk",
    "Role",
    "ToolCall",
    "ToolCallFunction",
    "UsageInfo",
]
class Role(str, Enum):
    """Canonical speaker roles for a chat message.

    Members also subclass ``str``, so each one compares equal to its
    lowercase value (for example ``Role.USER == "user"``).
    """

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    TOOL = "tool"
@dataclass
class UsageInfo:
    """Token usage counters attached to an ``LLMResponse`` or ``LLMResponseChunk``.

    The three main counters default to ``0``. The optional breakdowns
    default to ``None`` rather than ``0``, which lets callers distinguish
    "not reported" from an explicit zero.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    total_tokens: int = 0
    # Optional breakdowns; None means no value was supplied.
    reasoning_tokens: Optional[int] = None
    cached_tokens: Optional[int] = None
# Normalized reason a generation ended, used by LLMResponse.finish_reason and
# LLMResponseChunk.finish_reason. Presumably adapters map provider-native
# values onto these; confirm in adapter implementations.
FinishReason = Literal["stop", "length", "tool_calls", "content_filter", "error", "other"]

# Keys that ChatMessage.from_dict maps onto dedicated fields; any other key in
# the input dict is preserved in ``provider_metadata``. A frozenset keeps this
# shared lookup constant from being mutated at runtime.
_STANDARD_MESSAGE_KEYS = frozenset(
    {"role", "content", "tool_calls", "tool_call_id", "name", "provider_metadata"}
)
# Maps every accepted role spelling (the canonical values plus aliases used
# by other message formats) onto a canonical Role. ChatMessage.from_dict looks
# up the raw "role" (or fallback "type") value here and raises ValueError for
# anything not listed.
_ROLE_ALIASES = {
    "bot": Role.ASSISTANT,
    "assistant": Role.ASSISTANT,
    # LangChain-style dicts use ``type`` values "ai"/"human"; from_dict
    # falls back to the "type" key, so accept "ai" alongside "human".
    "ai": Role.ASSISTANT,
    "human": Role.USER,
    "user": Role.USER,
    "developer": Role.SYSTEM,
    "system": Role.SYSTEM,
    "tool": Role.TOOL,
}
def _parse_tool_call_arguments(raw_args: Any) -> Dict[str, Any]:
    """Normalize raw tool call arguments into a dict.

    Args:
        raw_args: A dict, a string holding a JSON object, or ``None``.

    Returns:
        The arguments as a dict. ``None`` and blank strings yield ``{}``,
        matching the default used when the arguments key is absent.

    Raises:
        ValueError: If a string is not valid JSON or does not decode to a
            JSON object, or if ``raw_args`` is any other non-dict type.
    """
    if raw_args is None:
        return {}
    if isinstance(raw_args, str):
        # json.loads("") raises; treat a blank payload like a missing one.
        if not raw_args.strip():
            return {}
        try:
            args_dict = json.loads(raw_args)
        except json.JSONDecodeError as err:
            raise ValueError(f"Tool call arguments are not valid JSON: {raw_args!r}") from err
        if not isinstance(args_dict, dict):
            raise ValueError(
                f"Tool call arguments must be a JSON object, got {type(args_dict).__name__}: {raw_args!r}"
            )
        return args_dict
    if not isinstance(raw_args, dict):
        raise ValueError(
            f"Tool call arguments must be a dict, got {type(raw_args).__name__}: {raw_args!r}"
        )
    return raw_args


def _parse_tool_call(tc: Dict[str, Any]) -> "ToolCall":
    """Build a ToolCall from either the nested or the legacy flat dict format."""
    func_data = tc.get("function")
    if func_data is not None:
        # Canonical nested format: {"function": {"name": ..., "arguments": ...}}
        name = func_data.get("name", "")
        raw_args = func_data.get("arguments", {})
    else:
        # Legacy flat format: {"name": ..., "args": ...}
        name = tc.get("name", "")
        raw_args = tc.get("args", {})
    return ToolCall(
        id=tc.get("id", ""),
        type=tc.get("type", "function"),
        function=ToolCallFunction(
            name=name,
            arguments=_parse_tool_call_arguments(raw_args),
        ),
    )


@dataclass
class ChatMessage:
    """A single provider-agnostic chat message.

    ``content`` is either plain text or a list of dicts (presumably
    multimodal content parts). ``provider_metadata`` holds provider-specific
    extras; ``from_dict`` stores any unrecognized input keys there.
    """

    role: Role
    content: Optional[Union[str, List[Dict[str, Any]]]] = None
    # Quoted forward reference so this class does not depend on ToolCall
    # being defined first (consistent with the file's other string refs).
    tool_calls: Optional[List["ToolCall"]] = None
    tool_call_id: Optional[str] = None
    name: Optional[str] = None
    provider_metadata: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_user(cls, content: str, **kwargs) -> "ChatMessage":
        """Create a ``Role.USER`` message; ``kwargs`` set the other fields."""
        return cls(role=Role.USER, content=content, **kwargs)

    @classmethod
    def from_assistant(cls, content: str, **kwargs) -> "ChatMessage":
        """Create a ``Role.ASSISTANT`` message; ``kwargs`` set the other fields."""
        return cls(role=Role.ASSISTANT, content=content, **kwargs)

    @classmethod
    def from_system(cls, content: str, **kwargs) -> "ChatMessage":
        """Create a ``Role.SYSTEM`` message; ``kwargs`` set the other fields."""
        return cls(role=Role.SYSTEM, content=content, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, omitting unset fields.

        ``role`` is emitted as its string value, and tool calls are
        serialized via their own ``to_dict``. ``provider_metadata`` is
        included only when non-empty. It is included by reference, not
        copied.
        """
        payload: Dict[str, Any] = {"role": self.role.value}
        if self.content is not None:
            payload["content"] = self.content
        if self.tool_calls is not None:
            payload["tool_calls"] = [tc.to_dict() for tc in self.tool_calls]
        if self.tool_call_id is not None:
            payload["tool_call_id"] = self.tool_call_id
        if self.name is not None:
            payload["name"] = self.name
        if self.provider_metadata:
            payload["provider_metadata"] = self.provider_metadata
        return payload

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ChatMessage":
        """Create a ChatMessage from a dict.

        Accepts both the canonical nested tool call format
        (``{"function": {"name": ..., "arguments": ...}}``) and the legacy
        flat format (``{"name": ..., "args": ...}``). JSON string arguments
        are parsed automatically; ``None`` or blank arguments become ``{}``.
        The role is read from ``"role"``, falling back to ``"type"``. Role
        aliases like "bot", "human", and "developer" are mapped to canonical
        Role values. Unknown keys are captured into ``provider_metadata``,
        and an explicit ``"provider_metadata"`` dict (or ``None``) is merged
        on top.

        Raises:
            ValueError: If the role is missing or unrecognized, or if tool
                call arguments are malformed.
        """
        raw_role = d.get("role") or d.get("type")
        if raw_role is None:
            raise ValueError("Missing required key: 'role'")
        # Only strings can match an alias; guarding here also turns an
        # unhashable value (e.g. a dict) into ValueError instead of TypeError.
        role = _ROLE_ALIASES.get(raw_role) if isinstance(raw_role, str) else None
        if role is None:
            raise ValueError(f"Unknown role: {raw_role}")

        raw_tool_calls = d.get("tool_calls")
        tool_calls = None if raw_tool_calls is None else [_parse_tool_call(tc) for tc in raw_tool_calls]

        extra = {k: v for k, v in d.items() if k not in _STANDARD_MESSAGE_KEYS}
        # ``or {}`` also covers an explicit ``"provider_metadata": None``,
        # which cannot be ``**``-unpacked.
        provider_metadata = {**extra, **(d.get("provider_metadata") or {})}
        return cls(
            role=role,
            content=d.get("content"),
            tool_calls=tool_calls,
            tool_call_id=d.get("tool_call_id"),
            name=d.get("name"),
            provider_metadata=provider_metadata,
        )
@dataclass
class LLMResponse:
    """Complete result returned by ``LLMModel.generate_async``.

    Only ``content`` is required; every other field defaults to ``None``
    when the backend does not supply it.
    """

    content: str
    # Reasoning text, when the backend reports it separately from ``content``.
    reasoning: Optional[str] = None
    # Quoted forward reference so this class does not depend on ToolCall
    # being defined first.
    tool_calls: Optional[List["ToolCall"]] = None
    model: Optional[str] = None
    finish_reason: Optional[FinishReason] = None
    # Presumably the stop sequence that ended generation; confirm in adapters.
    stop_sequence: Optional[str] = None
    request_id: Optional[str] = None
    usage: Optional[UsageInfo] = None
    provider_metadata: Optional[Dict[str, Any]] = None
@dataclass
class LLMResponseChunk:
    """One item yielded by ``LLMModel.stream_async``.

    All fields are optional. The ``delta_*`` fields presumably carry only
    what was produced since the previous chunk (confirm against adapter
    implementations).
    """

    delta_content: Optional[str] = None
    delta_reasoning: Optional[str] = None
    # Quoted forward reference so this class does not depend on ToolCall
    # being defined first.
    delta_tool_calls: Optional[List["ToolCall"]] = None
    model: Optional[str] = None
    finish_reason: Optional[FinishReason] = None
    request_id: Optional[str] = None
    usage: Optional[UsageInfo] = None
    provider_metadata: Optional[Dict[str, Any]] = None
@runtime_checkable
class LLMModel(Protocol):
    """Protocol that all LLM backends must implement.

    Adapters wrap provider-specific SDKs (LangChain, LiteLLM, OpenAI, etc.)
    behind this interface so the core pipeline remains framework-agnostic.

    ``prompt`` accepts either a plain string or a list of ``ChatMessage``
    objects. Adapters convert ``ChatMessage`` to whatever their SDK expects.
    ``**kwargs`` are forwarded to the underlying SDK (e.g. temperature,
    max_tokens).

    Because the protocol is ``runtime_checkable``, ``isinstance`` checks only
    that these members exist, not their signatures or async-ness.
    """

    async def generate_async(
        self,
        prompt: Union[str, List["ChatMessage"]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> "LLMResponse":
        """Generate a complete response for ``prompt``.

        ``stop`` is an optional list of stop sequences; ``kwargs`` are
        forwarded to the underlying SDK.
        """
        ...

    def stream_async(
        self,
        prompt: Union[str, List["ChatMessage"]],
        *,
        stop: Optional[List[str]] = None,
        **kwargs,
    ) -> AsyncIterator["LLMResponseChunk"]:
        """Implementations must be async generator functions (use ``yield``)."""
        ...

    @property
    def model_name(self) -> str:
        """Name of the model served by this backend."""
        ...

    @property
    def provider_name(self) -> Optional[str]:
        """Name of the provider, or ``None`` when not available."""
        ...

    @property
    def provider_url(self) -> Optional[str]:
        """Provider URL, or ``None`` when not available."""
        ...
@runtime_checkable
class LLMFramework(Protocol):
    """Protocol for pluggable LLM framework backends.

    Each framework (LangChain, LiteLLM, etc.) implements this protocol to
    provide a factory for creating ``LLMModel`` instances and managing
    its own set of providers.

    ``model_kwargs`` carries all provider-specific configuration. Framework
    implementations extract what they need (e.g. LangChain pops ``mode``
    to choose between chat and text completion models).
    """

    # NOTE(review): the header line of this method was missing from the
    # reviewed source (only its parameter list survived). The name
    # ``create_model`` and the synchronous ``def`` are inferred from the class
    # docstring ("a factory for creating LLMModel instances"). Confirm
    # against implementations and callers.
    def create_model(
        self,
        model_name: str,
        provider_name: str,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> LLMModel:
        """Create an ``LLMModel`` for ``model_name`` via ``provider_name``."""
        ...

    def register_provider(self, name: str, provider_cls: Any) -> None:
        """Register ``provider_cls`` with this framework under ``name``."""
        ...

    def get_provider_names(self) -> List[str]:
        """Return the names of the providers this framework knows about."""
        ...

    async def reset(self) -> None:
        """Release all framework-owned resources and clear all registered state.

        Implementations should close any pooled connections, clear registered
        providers, and return the framework to its initial state. Callers can
        continue using the instance after reset; new resources will be created
        on demand.
        """
        ...