Creating Custom Actions
This section describes how to create custom actions in the actions.py file.
The @action Decorator
Use the @action decorator from nemoguardrails.actions to define custom actions:
```python
from nemoguardrails.actions import action


@action()
async def my_custom_action():
    """A simple custom action."""
    return "result"
```
Decorator Parameters
| Parameter | Type | Description | Default |
|---|---|---|---|
| name | str | Custom name for the action | Function name |
| is_system_action | bool | Always run locally, bypassing the actions server | False |
| execute_async | bool | Don't block event processing while the action runs (Colang 2.x only) | False |
| output_mapping | Callable | Function to interpret the action result for blocking decisions | None |
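To make the table concrete, here is a minimal stand-in for the decorator that only records these options on the function. It is a hypothetical sketch, not the nemoguardrails implementation: the name action_sketch and the action_meta attribute are assumptions chosen for illustration, and the defaults are the ones listed in the table.

```python
from typing import Any, Callable, Optional


def action_sketch(
    is_system_action: bool = False,
    name: Optional[str] = None,
    execute_async: bool = False,
    output_mapping: Optional[Callable[[Any], bool]] = None,
):
    """Hypothetical stand-in for @action: records the options on the function."""

    def decorator(fn):
        # Store the options as plain metadata; the function itself is unchanged.
        fn.action_meta = {
            "name": name or fn.__name__,  # falls back to the function name
            "is_system_action": is_system_action,
            "execute_async": execute_async,
            "output_mapping": output_mapping,
        }
        return fn

    return decorator


@action_sketch()
async def my_custom_action():
    return "result"


@action_sketch(name="validate_user_input", is_system_action=True)
async def check_input(text: str):
    return len(text) > 0


print(my_custom_action.action_meta["name"])  # my_custom_action
print(check_input.action_meta["name"])  # validate_user_input
```

Returning the original function keeps it directly callable, which is convenient for unit tests.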
Custom Action Name
Override the default action name:
```python
@action(name="validate_user_input")
async def check_input(text: str):
    """Validates user input."""
    return len(text) > 0
```
Call from Colang:
```colang
$is_valid = execute validate_user_input(text=$user_message)
```
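Because Colang only knows the registered name, a call such as execute validate_user_input(...) has to be resolved to the Python function check_input. The sketch below models that lookup with a plain dictionary. ACTIONS, register, and execute are hypothetical helpers, not part of nemoguardrails, which loads and dispatches actions for you.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict

# Hypothetical registry: maps the name used in Colang to the Python function.
ACTIONS: Dict[str, Callable[..., Awaitable[Any]]] = {}


def register(fn: Callable[..., Awaitable[Any]], name: str = "") -> None:
    # Use the custom name if given, otherwise the function's own name.
    ACTIONS[name or fn.__name__] = fn


async def check_input(text: str):
    """Validates user input."""
    return len(text) > 0


register(check_input, name="validate_user_input")


async def execute(action_name: str, **kwargs: Any) -> Any:
    # Colang-style lookup: the caller only knows the registered name.
    return await ACTIONS[action_name](**kwargs)


print(asyncio.run(execute("validate_user_input", text="hello")))  # True
print("check_input" in ACTIONS)  # False
```

Once a custom name is set, the Python function name is no longer the key, so Colang code must use the custom name.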
System Actions
When is_system_action=True, the action always runs locally, even when an actions_server_url is configured. This is important for actions that need access to special parameters like context, llm, config, and events, which are only injected for locally-run actions.
Note
When no actions_server_url is configured, all actions run locally and receive special parameters regardless of the is_system_action setting. The flag only affects behavior when an actions server is in use.
```python
from typing import Optional


@action(is_system_action=True)
async def check_policy_compliance(context: Optional[dict] = None):
    """Check if message complies with policy."""
    context = context or {}  # context is None when nothing was injected
    message = context.get("last_user_message", "")
    # Validation logic goes here; return False to block the message.
    return True
```
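The following sketch models the idea of injecting special parameters only for local execution. run_action and SPECIAL_PARAMS are hypothetical names; the real runtime's dispatch is more involved, so treat this as an illustration of the behavior described in this section rather than the library's code. It inspects the function signature so that only declared parameters receive values, and the local=False case shows what an action sees when nothing is injected, which is why guarding context against None is worthwhile.

```python
import asyncio
import inspect
from typing import Any, Dict, Optional

# The four special parameter names listed in this section.
SPECIAL_PARAMS = ("context", "llm", "config", "events")


async def run_action(fn, kwargs: Dict[str, Any], runtime: Dict[str, Any], local: bool) -> Any:
    """Hypothetical dispatcher: inject special params only for local execution."""
    call_kwargs = dict(kwargs)
    if local:
        accepted = inspect.signature(fn).parameters
        for key in SPECIAL_PARAMS:
            # Only pass values the function declares and the runtime can supply.
            if key in accepted and key in runtime:
                call_kwargs[key] = runtime[key]
    return await fn(**call_kwargs)


async def check_policy_compliance(context: Optional[dict] = None):
    context = context or {}
    return "forbidden" not in context.get("last_user_message", "")


runtime = {"context": {"last_user_message": "a forbidden request"}}
print(asyncio.run(run_action(check_policy_compliance, {}, runtime, local=True)))  # False
print(asyncio.run(run_action(check_policy_compliance, {}, runtime, local=False)))  # True
```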
Async Execution
When execute_async=True, the event processing loop does not wait for the action to complete before continuing. The action runs in the background and the result is picked up later via polling. This is useful for long-running operations where you don’t need the result immediately.
Note
This flag is only supported in the Colang 2.x runtime. In the Colang 1.0 runtime, it is stored in metadata but has no effect.
```python
@action(execute_async=True)
async def call_external_api(endpoint: str):
    """Call an external API without blocking event processing."""
    # http_client is a placeholder for your own async HTTP client.
    response = await http_client.get(endpoint)
    return response.json()
```
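The background-execution pattern can be modeled with plain asyncio: start the action as a task, keep handling other work, and poll until it finishes. This is a hypothetical sketch of the polling idea, not the Colang 2.x runtime itself; slow_lookup stands in for a call like call_external_api.

```python
import asyncio


async def slow_lookup(endpoint: str) -> dict:
    # Stand-in for a long-running call such as an external API request.
    await asyncio.sleep(0.05)
    return {"endpoint": endpoint, "status": "ok"}


async def event_loop_sketch() -> list:
    log = []
    # Start the action in the background instead of awaiting it directly.
    task = asyncio.create_task(slow_lookup("/status"))
    while not task.done():
        # Other events keep being processed while the action runs.
        log.append("processed another event")
        await asyncio.sleep(0.01)
    # The poll sees the finished task; collect its result.
    log.append(task.result())
    return log


log = asyncio.run(event_loop_sketch())
print(log[-1])  # {'endpoint': '/status', 'status': 'ok'}
```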
Output Mapping
The output_mapping parameter controls how the action’s return value is interpreted to determine whether the output should be blocked. It accepts a callable that takes the return value and returns True if the output is not safe and should be blocked.
When no output_mapping is provided, the default behavior is:
- Boolean results: True means allowed, False means blocked.
- Numeric results: values below 0.5 are blocked.
- Other types: allowed by default.
```python
from typing import Optional

# detect_hallucination and is_safe are placeholders for your own checks.

@action(output_mapping=lambda value: value)
async def check_hallucination(context: Optional[dict] = None):
    """Return True if hallucination detected (blocked), False if safe."""
    context = context or {}
    return detect_hallucination(context.get("bot_message", ""))


@action(is_system_action=True, output_mapping=lambda value: not value)
async def check_output_safety(context: Optional[dict] = None):
    """Return True if safe (allowed); the mapping turns this into not-blocked."""
    context = context or {}
    return is_safe(context.get("bot_message", ""))
```
You can also define a custom mapping function for more complex logic:
```python
def my_custom_mapping(result):
    # Block when the score is below 0.7; allow anything that is not a dict.
    if isinstance(result, dict):
        return result.get("score", 1.0) < 0.7
    return False


@action(output_mapping=my_custom_mapping)
async def score_safety(context: Optional[dict] = None):
    """Return a dict with a safety score (compute_score is a placeholder)."""
    context = context or {}
    return {"score": compute_score(context.get("bot_message", ""))}
```
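The default rules and the override can be restated as a small function, which helps when checking what a given return value will do. This sketch paraphrases the behavior described in this section; default_is_blocked and is_blocked are hypothetical names, not nemoguardrails APIs. Because bool is a subclass of int in Python, the boolean check has to come before the numeric one.

```python
from typing import Any, Callable, Optional


def default_is_blocked(value: Any) -> bool:
    """Restates the documented default rules (not the library's own code)."""
    # bool is a subclass of int in Python, so test for it first.
    if isinstance(value, bool):
        return not value  # True -> allowed, False -> blocked
    if isinstance(value, (int, float)):
        return value < 0.5  # numeric results below 0.5 are blocked
    return False  # any other type is allowed


def is_blocked(value: Any, output_mapping: Optional[Callable[[Any], bool]] = None) -> bool:
    # A custom output_mapping, when present, replaces the default rules.
    if output_mapping is not None:
        return bool(output_mapping(value))
    return default_is_blocked(value)


print(is_blocked(False))  # True
print(is_blocked(0.3))  # True
print(is_blocked("some text"))  # False
print(is_blocked(True, output_mapping=lambda v: v))  # True
```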
Function Parameters
Actions can accept parameters passed from Colang by name, such as strings, integers, and booleans, as the following examples show.
Basic Parameters
```python
@action()
async def greet_user(name: str, formal: bool = False):
    """Generate a greeting."""
    if formal:
        return f"Good day, {name}."
    return f"Hello, {name}!"
```
Call from Colang:
```colang
$greeting = execute greet_user(name="Alice", formal=True)
```
Optional Parameters with Defaults
```python
@action()
async def search_documents(
    query: str,
    max_results: int = 10,
    include_metadata: bool = False,
):
    """Search documents with optional parameters."""
    # perform_search is a placeholder for your own search backend.
    results = perform_search(query, limit=max_results)
    if include_metadata:
        return {"results": results, "count": len(results)}
    return results
```
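Colang passes arguments by name (for example max_results=5), so any optional parameter that is omitted falls back to its Python default. The sketch below uses inspect.signature(...).bind() to show which values a call would receive and to reject missing or unknown arguments early with a TypeError. bind_args and the in-memory corpus that replaces perform_search are hypothetical; nemoguardrails performs its own argument handling.

```python
import asyncio
import inspect
from typing import Any, Dict


async def search_documents(query: str, max_results: int = 10, include_metadata: bool = False):
    # Hypothetical in-memory corpus standing in for perform_search().
    corpus = ["guardrails intro", "guardrails actions", "colang basics"]
    results = [doc for doc in corpus if query in doc][:max_results]
    if include_metadata:
        return {"results": results, "count": len(results)}
    return results


def bind_args(fn, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Check named arguments against the signature and fill in defaults."""
    bound = inspect.signature(fn).bind(**kwargs)  # raises TypeError on bad args
    bound.apply_defaults()
    return dict(bound.arguments)


print(bind_args(search_documents, {"query": "guardrails"}))
# {'query': 'guardrails', 'max_results': 10, 'include_metadata': False}
args = bind_args(search_documents, {"query": "guardrails", "max_results": 1})
print(asyncio.run(search_documents(**args)))  # ['guardrails intro']
```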
Return Values
Actions can return various types:
Simple Return
```python
@action()
async def get_status():
    return "active"
```
Dictionary Return
```python
@action()
async def get_user_info(user_id: str):
    return {
        "id": user_id,
        "name": "John Doe",
        "role": "admin",
    }
```
Boolean Return (for validation)
```python
@action(is_system_action=True)
async def is_safe_content(context: Optional[dict] = None):
    context = context or {}
    content = context.get("bot_message", "")
    # Return True if safe (allowed), False to block.
    # contains_harmful_content is a placeholder for your own check.
    return not contains_harmful_content(content)
```
Error Handling
Handle errors gracefully within actions:
```python
import logging

log = logging.getLogger(__name__)


@action()
async def fetch_data(url: str):
    """Fetch data with error handling."""
    try:
        # http_client is a placeholder for your own async HTTP client.
        response = await http_client.get(url)
        response.raise_for_status()
        return response.json()
    except Exception:
        # Log the error with its traceback, then return a safe default (or re-raise).
        log.exception("Error fetching data from %s", url)
        return None
```
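A related pattern, not part of the example above, is bounding how long an action may wait on a slow dependency. The sketch below wraps an awaitable in asyncio.wait_for and returns a safe default on timeout or error. with_timeout is a hypothetical helper, and slow_fetch and fast_fetch stand in for real upstream calls.

```python
import asyncio
import logging
from typing import Any, Awaitable

log = logging.getLogger(__name__)


async def with_timeout(coro: Awaitable[Any], seconds: float, default: Any = None) -> Any:
    """Await coro, but return a safe default if it is too slow or fails."""
    try:
        return await asyncio.wait_for(coro, timeout=seconds)
    except asyncio.TimeoutError:
        log.warning("Action timed out after %.2fs", seconds)
    except Exception:
        log.exception("Action failed")
    return default


async def slow_fetch() -> dict:
    await asyncio.sleep(1.0)  # hypothetical slow upstream call
    return {"data": "late"}


async def fast_fetch() -> dict:
    return {"data": "ok"}


print(asyncio.run(with_timeout(slow_fetch(), seconds=0.05)))  # None
print(asyncio.run(with_timeout(fast_fetch(), seconds=0.05)))  # {'data': 'ok'}
```

On timeout, asyncio.wait_for cancels the wrapped coroutine, so the slow call does not keep running in the background.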
Example Actions
Input Validation Action
```python
from typing import Optional

from nemoguardrails.actions import action


@action(is_system_action=True)
async def check_input_length(context: Optional[dict] = None):
    """Ensure user input is not too long."""
    context = context or {}
    user_message = context.get("last_user_message", "")
    max_length = 1000
    if len(user_message) > max_length:
        return False  # Block the input
    return True  # Allow the input
```
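Since the validation logic is an ordinary coroutine, you can exercise it directly in a unit test. The sketch below copies check_input_length without the decorator so it runs without nemoguardrails installed; it assumes that @action returns the function unchanged, which you should verify for your version. The run helper is hypothetical.

```python
import asyncio
from typing import Optional


# Copy of the example action without @action (assumption: the decorator
# leaves the function directly callable).
async def check_input_length(context: Optional[dict] = None):
    context = context or {}
    user_message = context.get("last_user_message", "")
    max_length = 1000
    return len(user_message) <= max_length  # False blocks the input


def run(context: Optional[dict]) -> bool:
    return asyncio.run(check_input_length(context=context))


print(run({"last_user_message": "hi"}))  # True
print(run({"last_user_message": "x" * 1001}))  # False
print(run(None))  # True
```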
Output Filtering Action
```python
import re
from typing import Optional

from nemoguardrails.actions import action


# This action returns True when sensitive data is found, so map True to "blocked".
@action(is_system_action=True, output_mapping=lambda value: value)
async def filter_sensitive_data(context: Optional[dict] = None):
    """Check for sensitive data in bot response."""
    context = context or {}
    bot_response = context.get("bot_message", "")
    sensitive_patterns = [
        r"\b\d{3}-\d{2}-\d{4}\b",  # SSN pattern
        r"\b\d{16}\b",  # Credit card pattern
    ]
    for pattern in sensitive_patterns:
        if re.search(pattern, bot_response):
            return True  # Contains sensitive data
    return False  # No sensitive data found
```
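If you need to report which rule fired, or want to avoid handling the pattern strings on every call, you can compile the patterns once and return the matching labels. This is a hypothetical refinement: find_sensitive and the label names are illustrative, and an action could return bool(find_sensitive(bot_response)) to keep the True-means-sensitive convention.

```python
import re
from typing import Dict, List

# Compile once at import time instead of on every call.
SENSITIVE_PATTERNS: Dict[str, re.Pattern] = {
    "ssn": re.compile(r"\b\d{3}-\d{2}-\d{4}\b"),
    "credit_card": re.compile(r"\b\d{16}\b"),
}


def find_sensitive(text: str) -> List[str]:
    """Return the labels of every pattern that matches the text."""
    return [label for label, pattern in SENSITIVE_PATTERNS.items() if pattern.search(text)]


print(find_sensitive("My SSN is 123-45-6789"))  # ['ssn']
print(find_sensitive("Card: 4111111111111111"))  # ['credit_card']
print(find_sensitive("Nothing to see here"))  # []
```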
External API Action
```python
import aiohttp

from nemoguardrails.actions import action


@action(execute_async=True)
async def query_knowledge_base(query: str, top_k: int = 5):
    """Query an external knowledge base API."""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.example.com/search",
            json={"query": query, "limit": top_k},
        ) as response:
            response.raise_for_status()  # fail on HTTP errors instead of parsing an error body
            data = await response.json()
            return data.get("results", [])
```
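Network-bound actions are easier to test when the HTTP call is passed in rather than hard-coded. In the sketch below, search_kb holds the request and response handling, and post_json is a hypothetical async callable: in production it might wrap aiohttp's session.post(...) as in the example above, while tests pass a fake. None of these names come from nemoguardrails.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

# Hypothetical transport type: takes a URL and JSON payload, returns parsed JSON.
PostJSON = Callable[[str, Dict[str, Any]], Awaitable[Dict[str, Any]]]

SEARCH_URL = "https://api.example.com/search"


async def search_kb(post_json: PostJSON, query: str, top_k: int = 5) -> List[Any]:
    """Core logic, independent of any HTTP library."""
    data = await post_json(SEARCH_URL, {"query": query, "limit": top_k})
    return data.get("results", [])


async def fake_post_json(url: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    # Test double: builds results from the request instead of calling the network.
    return {"results": [f"{payload['query']}#{i}" for i in range(payload["limit"])]}


print(asyncio.run(search_kb(fake_post_json, "guardrails", top_k=2)))
# ['guardrails#0', 'guardrails#1']
```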