Creating Custom Actions#

This section describes how to create custom actions in the actions.py file.

The @action Decorator#

Use the @action decorator from nemoguardrails.actions to define custom actions:

from nemoguardrails.actions import action

@action()
async def my_custom_action():
    """A simple custom action."""
    return "result"

Decorator Parameters#

Parameter

Type

Description

Default

name

str

Custom name for the action

Function name

is_system_action

bool

Mark as system action (runs in guardrails context)

False

execute_async

bool

Execute asynchronously without blocking

False

Custom Action Name#

Override the default action name:

@action(name="validate_user_input")
async def check_input(text: str):
    """Validates user input."""
    return len(text) > 0

Call from Colang:

$is_valid = execute validate_user_input(text=$user_message)

System Actions#

System actions have access to the guardrails context and are typically used for input/output validation:

@action(is_system_action=True)
async def check_policy_compliance(context: Optional[dict] = None):
    """Check if message complies with policy."""
    message = context.get("last_user_message", "")
    # Validation logic
    return True

Async Execution#

For long-running operations, use execute_async=True to prevent blocking:

@action(execute_async=True)
async def call_external_api(endpoint: str):
    """Call an external API without blocking."""
    response = await http_client.get(endpoint)
    return response.json()

Function Parameters#

Actions can accept parameters of the following types:

Type

Example

str

"hello"

int

42

float

3.14

bool

True

list

["a", "b", "c"]

dict

{"key": "value"}

Basic Parameters#

@action()
async def greet_user(name: str, formal: bool = False):
    """Generate a greeting."""
    if formal:
        return f"Good day, {name}."
    return f"Hello, {name}!"

Call from Colang:

$greeting = execute greet_user(name="Alice", formal=True)

Optional Parameters with Defaults#

@action()
async def search_documents(
    query: str,
    max_results: int = 10,
    include_metadata: bool = False
):
    """Search documents with optional parameters."""
    results = perform_search(query, limit=max_results)
    if include_metadata:
        return {"results": results, "count": len(results)}
    return results

Return Values#

Actions can return various types:

Simple Return#

@action()
async def get_status():
    return "active"

Dictionary Return#

@action()
async def get_user_info(user_id: str):
    return {
        "id": user_id,
        "name": "John Doe",
        "role": "admin"
    }

Boolean Return (for validation)#

@action(is_system_action=True)
async def is_safe_content(context: Optional[dict] = None):
    content = context.get("bot_message", "")
    # Returns True if safe, False if blocked
    return not contains_harmful_content(content)

Error Handling#

Handle errors gracefully within actions:

@action()
async def fetch_data(url: str):
    """Fetch data with error handling."""
    try:
        response = await http_client.get(url)
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Log the error
        print(f"Error fetching data: {e}")
        # Return a safe default or raise
        return None

Example Actions#

Input Validation Action#

from typing import Optional
from nemoguardrails.actions import action

@action(is_system_action=True)
async def check_input_length(context: Optional[dict] = None):
    """Ensure user input is not too long."""
    user_message = context.get("last_user_message", "")
    max_length = 1000

    if len(user_message) > max_length:
        return False  # Block the input

    return True  # Allow the input

Output Filtering Action#

@action(is_system_action=True)
async def filter_sensitive_data(context: Optional[dict] = None):
    """Check for sensitive data in bot response."""
    bot_response = context.get("bot_message", "")

    sensitive_patterns = [
        r"\b\d{3}-\d{2}-\d{4}\b",  # SSN pattern
        r"\b\d{16}\b",              # Credit card pattern
    ]

    import re
    for pattern in sensitive_patterns:
        if re.search(pattern, bot_response):
            return True  # Contains sensitive data

    return False  # No sensitive data found

External API Action#

import aiohttp

@action(execute_async=True)
async def query_knowledge_base(query: str, top_k: int = 5):
    """Query an external knowledge base API."""
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "https://api.example.com/search",
            json={"query": query, "limit": top_k}
        ) as response:
            data = await response.json()
            return data.get("results", [])