nearest_neighbors
Nearest neighbor search abstraction with GPU acceleration support.
This module provides a unified interface for nearest neighbor search that:

- Uses PyTorch (CUDA) when available for GPU-accelerated search
- Falls back to sklearn `NearestNeighbors` for CPU-only environments
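The selection rule described above can be sketched as follows. This is a minimal illustration, not the module's actual logic: `pick_backend` and the returned strings `"torch"` and `"sklearn"` are hypothetical names chosen for this example. The only library calls used are `importlib.util.find_spec` and `torch.cuda.is_available`.

```python
import importlib.util


def pick_backend() -> str:
    """Return "torch" if PyTorch with CUDA is usable, else "sklearn".

    Hypothetical helper: the module's real selection code is not shown
    in this reference, so treat this as an assumption.
    """
    # find_spec checks whether the package is installed without importing it.
    if importlib.util.find_spec("torch") is not None:
        import torch

        # A CPU-only PyTorch build reports False here, so we still fall back.
        if torch.cuda.is_available():
            return "torch"
    return "sklearn"


backend = pick_backend()
print(f"selected backend: {backend}")
```

Checking for CUDA rather than for PyTorch alone matters because a CPU-only PyTorch install would otherwise be picked without any GPU benefit.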
Usage:

    from nemo_safe_synthesizer.evaluation.nearest_neighbors import NearestNeighborSearch

    # Create search index
    nn = NearestNeighborSearch(n_neighbors=5)
    nn.fit(data)  # numpy array of shape (n_samples, n_features)

    # Query
    distances, indices = nn.kneighbors(queries)  # input is numpy array of shape (n_queries, n_features)
Classes:
| Name | Description |
|---|---|
| `NearestNeighborSearch` | Unified nearest neighbor search with GPU acceleration support. |
NearestNeighborSearch(n_neighbors=5)
Unified nearest neighbor search with GPU acceleration support.
Uses PyTorch (CUDA) when available and falls back to sklearn otherwise. Both backends compute exact brute-force L2 distances, so results are consistent regardless of which backend is active.
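For reference, exact brute-force L2 search means computing the distance from every query to every indexed point and keeping the smallest `k`. The NumPy sketch below illustrates that computation only; `brute_force_knn` is a hypothetical name, and neither backend is claimed to be implemented this way.

```python
import numpy as np


def brute_force_knn(data, queries, k):
    """Exact k-nearest neighbors under L2 distance (reference sketch)."""
    data = np.asarray(data, dtype=np.float32)
    queries = np.asarray(queries, dtype=np.float32)
    # Pairwise differences via broadcasting:
    # shape (n_queries, n_samples, n_features).
    diff = queries[:, None, :] - data[None, :, :]
    dists = np.sqrt((diff ** 2).sum(axis=-1))
    # Sort each row by distance and keep the k closest points.
    indices = np.argsort(dists, axis=1, kind="stable")[:, :k]
    distances = np.take_along_axis(dists, indices, axis=1)
    return distances, indices


data = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]])
queries = np.array([[0.0, 0.0]])
distances, indices = brute_force_knn(data, queries, k=2)
print(indices)    # nearest point first: row 0, then row 2
print(distances)  # 0.0 and 1.0
```

The broadcasted difference array uses memory proportional to `n_queries * n_samples * n_features`, so this form is only suitable for small inputs.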
Attributes:
| Name | Type | Description |
|---|---|---|
| `n_neighbors` | | Number of neighbors to return in queries. |
| `use_gpu` | | Whether GPU acceleration is being used. |

Initialize the nearest neighbor search.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n_neighbors` | `int` | Number of neighbors to find in queries. | `5` |
Methods:
| Name | Description |
|---|---|
| `fit` | Build the search index from data. |
| `kneighbors` | Find k nearest neighbors for query points. |
Source code in src/nemo_safe_synthesizer/evaluation/nearest_neighbors.py
backend_name (property)
Return the name of the backend being used.
fit(data)
Build the search index from data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `data` | `ndarray` | Array of shape (n_samples, n_features) with float32 values. | required |
Returns:
| Type | Description |
|---|---|
| `NearestNeighborSearch` | Self for method chaining. |
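Returning the instance from `fit` lets construction and fitting be written as one expression. The sketch below uses a hypothetical stand-in class, `ToyIndex`, to show the pattern together with one way to coerce input to float32. Whether the real `fit` converts the dtype itself is an assumption, not something this reference states.

```python
import numpy as np


class ToyIndex:
    """Hypothetical stand-in used only to illustrate `fit` returning self."""

    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors
        self._data = None

    def fit(self, data):
        # Coerce to a C-contiguous float32 array (assumed behavior).
        arr = np.ascontiguousarray(data, dtype=np.float32)
        if arr.ndim != 2:
            raise ValueError("data must have shape (n_samples, n_features)")
        self._data = arr
        return self  # enables ToyIndex(...).fit(data) in one expression


raw = [[0, 1], [2, 3], [4, 5]]  # Python ints, not float32
index = ToyIndex(n_neighbors=2).fit(raw)
print(index._data.dtype, index._data.shape)
```

Converting once in `fit` keeps every later query working against the same dtype and memory layout.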
kneighbors(queries, n_neighbors=None)
Find k nearest neighbors for query points.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `queries` | `ndarray` | Array of shape (n_queries, n_features) with query points. | required |
| `n_neighbors` | `int \| None` | Number of neighbors to return. If `None`, uses `self.n_neighbors`. | `None` |
Returns:
| Type | Description |
|---|---|
| `tuple[ndarray, ndarray]` | Tuple of `(distances, indices)`. |
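The `n_neighbors=None` default means a query uses the value given at construction unless a per-call value is passed. The sketch below shows that resolution with a hypothetical stand-in class, `ToyIndex`, which is not the library's implementation. The output shape `(n_queries, k)` for both arrays and the nearest-first ordering are assumptions that follow the sklearn convention.

```python
import numpy as np


class ToyIndex:
    """Hypothetical stand-in; not the library's implementation."""

    def __init__(self, n_neighbors=5):
        self.n_neighbors = n_neighbors

    def fit(self, data):
        self._data = np.asarray(data, dtype=np.float32)
        return self

    def kneighbors(self, queries, n_neighbors=None):
        # Fall back to the constructor value when no per-call k is given.
        k = self.n_neighbors if n_neighbors is None else n_neighbors
        q = np.asarray(queries, dtype=np.float32)
        # L2 distance from every query to every indexed point.
        dists = np.linalg.norm(q[:, None, :] - self._data[None, :, :], axis=-1)
        indices = np.argsort(dists, axis=1, kind="stable")[:, :k]
        return np.take_along_axis(dists, indices, axis=1), indices


points = np.arange(10, dtype=np.float32).reshape(-1, 1)  # 0.0, 1.0, ..., 9.0
index = ToyIndex(n_neighbors=3).fit(points)
queries = np.array([[0.2], [8.9]], dtype=np.float32)

d_default, i_default = index.kneighbors(queries)         # uses n_neighbors=3
d_one, i_one = index.kneighbors(queries, n_neighbors=1)  # per-call override
print(i_default.shape, i_one.shape)
```

Testing `n_neighbors is None` rather than relying on truthiness keeps an explicit value from being silently replaced by the default.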