## Custom Initialization

If present, the `config.py` module is loaded before the `LLMRails` instance is initialized.

If the `config.py` module contains an `init` function, it is called as part of the initialization of the `LLMRails` instance. For example, you can use the `init` function to initialize a database connection and register it as a custom action parameter using the `register_action_param(...)` function:
```python
from nemoguardrails import LLMRails


def init(app: LLMRails):
    # Initialize the database connection
    db = ...

    # Register the action parameter
    app.register_action_param("db", db)
```
Custom action parameters are passed to custom actions when they are invoked.
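For illustration, a custom action can simply declare a `db` parameter to receive the registered value. The sketch below is hypothetical (the `fetch_user_profile` action and the dict standing in for a real database connection are made up for the example):

```python
import asyncio


# Hypothetical custom action: it declares a `db` parameter, which Guardrails
# fills in with the value registered via app.register_action_param("db", db).
async def fetch_user_profile(user_id: str, db=None):
    record = db.get(user_id) if db is not None else None
    return record or "unknown user"


# In config.py, the action itself would be registered with:
#   app.register_action(fetch_user_profile, name="fetch_user_profile")

# Standalone check, with a plain dict standing in for the database:
fake_db = {"alice": "Alice, premium tier"}
print(asyncio.run(fetch_user_profile("alice", db=fake_db)))
# → Alice, premium tier
```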
## Custom Data Access

If you need to pass additional configuration data to any custom component, you can use the `custom_data` field in your `config.yml`:
```yaml
custom_data:
  custom_config_field: "some_value"
```
For example, you can access the custom configuration inside the `init` function in your `config.py`:

```python
def init(app: LLMRails):
    config = app.config

    # Do something with config.custom_data, e.g. read a value:
    custom_config_field = config.custom_data.get("custom_config_field")
```
## Custom LLM Provider Registration

NeMo Guardrails supports two types of custom LLM providers:

- Text completion models (`BaseLLM`): for models that work with string prompts
- Chat models (`BaseChatModel`): for models that work with message-based conversations
### Custom Text Completion LLM (`BaseLLM`)

To register a custom text completion LLM provider, create a class that inherits from `BaseLLM` and register it using `register_llm_provider`.

Required methods:

- `_call`: synchronous text completion
- `_llm_type`: returns the LLM type identifier

Optional methods:

- `_acall`: asynchronous text completion (recommended)
- `_stream`: streaming text completion
- `_astream`: async streaming text completion
- `_identifying_params`: returns parameters for model identification
```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLLM
from langchain_core.outputs import GenerationChunk

from nemoguardrails.llm.providers import register_llm_provider


class MyCustomTextLLM(BaseLLM):
    """Custom text completion LLM."""

    @property
    def _llm_type(self) -> str:
        return "custom_text_llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Synchronous text completion."""
        # Your implementation here
        return "Generated text response"

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Asynchronous text completion (recommended)."""
        # Your async implementation here
        return "Generated text response"

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Optional: streaming text completion."""
        # Yield chunks of text
        yield GenerationChunk(text="chunk1")
        yield GenerationChunk(text="chunk2")


register_llm_provider("custom_text_llm", MyCustomTextLLM)
```
### Custom Chat Model (`BaseChatModel`)

To register a custom chat model, create a class that inherits from `BaseChatModel` and register it using `register_chat_provider`.

Required methods:

- `_generate`: synchronous chat completion
- `_llm_type`: returns the LLM type identifier

Optional methods:

- `_agenerate`: asynchronous chat completion (recommended)
- `_stream`: streaming chat completion
- `_astream`: async streaming chat completion
```python
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from nemoguardrails.llm.providers import register_chat_provider


class MyCustomChatModel(BaseChatModel):
    """Custom chat model."""

    @property
    def _llm_type(self) -> str:
        return "custom_chat_model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Synchronous chat completion."""
        # Convert messages to your model's format and generate a response
        response_text = "Generated chat response"
        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronous chat completion (recommended)."""
        # Your async implementation
        response_text = "Generated chat response"
        message = AIMessage(content=response_text)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Optional: streaming chat completion."""
        # Yield chunks
        chunk = ChatGenerationChunk(message=AIMessageChunk(content="chunk1"))
        yield chunk


register_chat_provider("custom_chat_model", MyCustomChatModel)
```
### Using Custom LLM Providers

After registering your custom provider, you can use it in your configuration:

```yaml
models:
  - type: main
    engine: custom_text_llm  # or custom_chat_model
```
### Important Notes

1. Import from `langchain-core`: always import the base classes from `langchain_core.language_models`:

   ```python
   from langchain_core.language_models import BaseLLM, BaseChatModel
   ```

2. Implement the async methods: for better performance, always implement `_acall` (for `BaseLLM`) or `_agenerate` (for `BaseChatModel`).

3. Choose the right base class:

   - Use `BaseLLM` for text completion models (prompt → text).
   - Use `BaseChatModel` for chat models (messages → message).

4. Registration functions:

   - Use `register_llm_provider()` for `BaseLLM` subclasses.
   - Use `register_chat_provider()` for `BaseChatModel` subclasses.
## Custom Embedding Provider Registration

You can also register a custom embedding provider by using the `LLMRails.register_embedding_provider` function.

To register a custom embedding provider, create a class that inherits from `EmbeddingModel` and register it in your `config.py`:
```python
from typing import List

from nemoguardrails import LLMRails
from nemoguardrails.embeddings.providers.base import EmbeddingModel


class CustomEmbeddingModel(EmbeddingModel):
    """An implementation of a custom embedding provider."""

    engine_name = "CustomEmbeddingModel"

    def __init__(self, embedding_model: str):
        # Initialize the model
        ...

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode the provided documents into embeddings.

        Args:
            documents (List[str]): The list of documents for which embeddings should be created.

        Returns:
            List[List[float]]: The list of embeddings corresponding to the input documents.
        """
        ...

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode the provided documents into embeddings.

        Args:
            documents (List[str]): The list of documents for which embeddings should be created.

        Returns:
            List[List[float]]: The list of embeddings corresponding to the input documents.
        """
        ...


def init(app: LLMRails):
    """Initialization function in your config.py."""
    app.register_embedding_provider(CustomEmbeddingModel, "CustomEmbeddingModel")
```
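Before wiring a provider into a configuration, it can be useful to exercise the `encode`/`encode_async` pair directly. The sketch below uses a toy, dependency-free stand-in with the same method surface; the hashing scheme is purely illustrative, not a real embedding model:

```python
import asyncio
import hashlib
from typing import List


class ToyEmbeddingModel:
    """Stand-in with the same encode/encode_async surface as EmbeddingModel."""

    engine_name = "ToyEmbeddingModel"

    def __init__(self, embedding_model: str, dim: int = 8):
        self.embedding_model = embedding_model
        self.dim = dim

    def encode(self, documents: List[str]) -> List[List[float]]:
        # Deterministic toy "embedding": the first `dim` bytes of a SHA-256
        # digest, scaled into [0, 1). A real provider would call a model here.
        return [
            [b / 256.0 for b in hashlib.sha256(doc.encode()).digest()[: self.dim]]
            for doc in documents
        ]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        # A real provider would await a model call; the toy just delegates.
        return self.encode(documents)


model = ToyEmbeddingModel("toy-model")
vectors = asyncio.run(model.encode_async(["hello", "world"]))
print(len(vectors), len(vectors[0]))
# → 2 8
```

The key property to check is that both methods return one fixed-length vector per input document, since the Guardrails embedding search index relies on consistent dimensions.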
You can then use the custom embedding provider in your configuration, referencing the engine name under which it was registered:

```yaml
models:
  # ...
  - type: embeddings
    engine: CustomEmbeddingModel
    model: SomeModelName  # Must be supported by the provider.
```