improvement
This commit is contained in:
137
ReadMe.md
137
ReadMe.md
@@ -11,6 +11,9 @@ A library focused purely on gesture indexing mathematics for DHT-based path comp
|
||||
4. **Spectral Embeddings** - Laplacian eigenvalues for gesture signature
|
||||
5. **Dimensionality Reduction** - PCA for feature compression
|
||||
6. **Spatial Indexing** - Morton/Z-order curve for integer keys
|
||||
7. **Multiple Hashing Strategies** - Moment, spectral, hybrid, and global vector addressing
|
||||
8. **Multi-Probe Hashing** - Query neighboring buckets for improved recall
|
||||
9. **HNSW Index** - Approximate nearest neighbor search for fast similarity
|
||||
|
||||
## Usage Example
|
||||
|
||||
@@ -47,7 +50,7 @@ fn main() {
|
||||
}
|
||||
```
|
||||
|
||||
### Similarity Search
|
||||
### Similarity Search with HNSW
|
||||
|
||||
```rust
|
||||
use redoal::*;
|
||||
@@ -82,6 +85,76 @@ fn find_similar_gestures(query: &[Point], database: &[(&str, Vec<Point>)]) -> Ve
|
||||
}
|
||||
```
|
||||
|
||||
### Multi-Probe Hashing for Improved Recall
|
||||
|
||||
```rust
|
||||
use redoal::*;
|
||||
|
||||
fn multi_probe_query(points: &[Point]) {
|
||||
// Use moment hash for stable partitioning
|
||||
let moment_key = hu_moment_hash(points, 10);
|
||||
|
||||
// Generate neighboring keys for multi-probe
|
||||
let neighbors = neighboring_keys(moment_key, 3);
|
||||
|
||||
println!("Querying {} buckets", neighbors.len());
|
||||
for key in neighbors {
|
||||
println!("Bucket: {}", key);
|
||||
}
|
||||
|
||||
// Or use adaptive probing
|
||||
let keys = adaptive_probe(
|
||||
points,
|
||||
10,
|
||||
5,
|
||||
15,
|
||||
HashStrategy::Hybrid,
|
||||
);
|
||||
|
||||
println!("Adaptive probe found {} keys", keys.len());
|
||||
}
|
||||
```
|
||||
|
||||
### HNSW Index for Fast Local Search
|
||||
|
||||
```rust
|
||||
use redoal::*;
|
||||
|
||||
fn hnsw_example() {
|
||||
// Create HNSW index
|
||||
let mut index = HnswIndex::new(HnswConfig::default());
|
||||
|
||||
// Add gestures to index
|
||||
let gesture1 = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
];
|
||||
|
||||
let norm1 = normalize(&gesture1);
|
||||
let resamp1 = resample(&norm1, 64);
|
||||
let embedding1 = spectral_signature(&resamp1, 4);
|
||||
index.add(&embedding1, "triangle");
|
||||
|
||||
// Query the index
|
||||
let query = vec![
|
||||
Point::new(0.1, 0.1),
|
||||
Point::new(1.1, 0.1),
|
||||
Point::new(0.6, 1.1),
|
||||
];
|
||||
|
||||
let norm = normalize(&query);
|
||||
let resamp = resample(&norm, 64);
|
||||
let embedding = spectral_signature(&resamp, 4);
|
||||
|
||||
let results = index.search(&embedding, 5);
|
||||
|
||||
for (label, distance) in results {
|
||||
println!("Found {} at distance {}", label, distance);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Mathematical Operations
|
||||
|
||||
| Module | Function | Purpose |
|
||||
@@ -93,6 +166,66 @@ fn find_similar_gestures(query: &[Point], database: &[(&str, Vec<Point>)]) -> Ve
|
||||
| `spectral` | `spectral_signature(points, k)` | Compute k Laplacian eigenvalues |
|
||||
| `pca` | `pca(data, k)` | Dimensionality reduction to k principal components |
|
||||
| `morton` | `morton2(x, y)` | Convert 2D coordinates to 64-bit Morton code |
|
||||
| `hashing` | `hu_moment_hash()` | Moment-based hashing for DHT |
|
||||
| `hashing` | `spectral_hash()` | Spectral-based hashing |
|
||||
| `hashing` | `hybrid_hash()` | Combined moment+spectral hashing |
|
||||
| `hashing` | `vector_to_dht_key()` | Global vector addressing |
|
||||
| `hashing` | `neighboring_keys()` | Multi-probe hashing |
|
||||
| `hashing` | `adaptive_probe()` | Adaptive query planning |
|
||||
| `hnsw` | `HnswIndex` | Approximate nearest neighbor search |
|
||||
|
||||
## Hashing Strategies
|
||||
|
||||
### Moment Hash (Stable)
|
||||
- Uses Hu invariant moments
|
||||
- Translation and scale invariant
|
||||
- Good for broad gesture categories
|
||||
- Coarse partitioning
|
||||
|
||||
### Spectral Hash (Precise)
|
||||
- Uses Laplacian eigenvalues
|
||||
- More sensitive to small changes
|
||||
- Better for fine-grained similarity
|
||||
- Requires multi-probe for robustness
|
||||
|
||||
### Hybrid Hash (Balanced)
|
||||
- Combines moment and spectral features
|
||||
- Weighted fusion for optimal balance
|
||||
- Good default choice
|
||||
|
||||
### Global Vector Addressing
|
||||
- Directly maps embeddings to DHT keys
|
||||
- No intermediate hashing
|
||||
- Most precise but requires careful quantization
|
||||
|
||||
## Multi-Probe Hashing
|
||||
|
||||
To handle hash instability and improve recall:
|
||||
|
||||
```rust
|
||||
// Basic multi-probe
|
||||
let key = hu_moment_hash(points, 10);
|
||||
let neighbors = neighboring_keys(key, 3); // Query 3 neighboring buckets
|
||||
|
||||
// Adaptive probing
|
||||
let keys = adaptive_probe(
|
||||
points,
|
||||
10, // quantization bits
|
||||
5, // min peers
|
||||
15, // max peers
|
||||
HashStrategy::Hybrid, // hash strategy
|
||||
);
|
||||
```
|
||||
|
||||
## HNSW Index
|
||||
|
||||
For fast local similarity search on peers:
|
||||
|
||||
```rust
|
||||
let mut index = HnswIndex::new(HnswConfig::default());
|
||||
index.add(&embedding, "gesture_label");
|
||||
let results = index.search(&query_embedding, 10);
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
@@ -108,4 +241,4 @@ Run tests with:
|
||||
cargo test
|
||||
```
|
||||
|
||||
All tests pass, demonstrating correct implementation of gesture indexing mathematics.
|
||||
All tests pass, demonstrating correct implementation of gesture indexing mathematics and distributed search capabilities.
|
||||
|
||||
375
src/hashing.rs
Normal file
375
src/hashing.rs
Normal file
@@ -0,0 +1,375 @@
|
||||
/// Gesture hashing strategies for DHT-based similarity search
|
||||
///
|
||||
/// This module provides multiple hashing techniques:
|
||||
/// - Moment-based hashing (stable but coarse)
|
||||
/// - Spectral-based hashing (precise but sensitive)
|
||||
/// - Hybrid hashing (balanced approach)
|
||||
/// - Global vector addressing (direct embedding to DHT key)
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use redoal::point::Point;
|
||||
/// use redoal::{hu_moment_hash, spectral_hash, hybrid_hash, vector_to_dht_key};
|
||||
///
|
||||
/// let points = vec![
|
||||
/// Point::new(0.0, 0.0),
|
||||
/// Point::new(1.0, 0.0),
|
||||
/// Point::new(0.5, 1.0),
|
||||
/// ];
|
||||
///
|
||||
/// // Moment-based hash (stable)
|
||||
/// let moment_key = hu_moment_hash(&points, 10);
|
||||
///
|
||||
/// // Spectral-based hash (precise)
|
||||
/// let spectral_key = spectral_hash(&points, 10);
|
||||
///
|
||||
/// // Hybrid hash (balanced)
|
||||
/// let hybrid_key = hybrid_hash(&points, 0.7, 0.3, 10);
|
||||
///
|
||||
/// // Global vector addressing
|
||||
/// let spectral = redoal::spectral_signature(&points, 4);
|
||||
/// let vector_key = vector_to_dht_key(&spectral, 12);
|
||||
/// ```
|
||||
use crate::Point;
|
||||
use crate::hu_moments;
|
||||
use crate::spectral_signature;
|
||||
|
||||
/// Moment-based hash using Hu moments
|
||||
///
|
||||
/// This is the most stable hash but provides coarse partitioning.
|
||||
/// Good for broad gesture categories (circles, lines, spirals).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `points` - Gesture points
|
||||
/// * `quantization_bits` - Number of bits for quantization (higher = more buckets)
|
||||
///
|
||||
/// # Returns
|
||||
/// 64-bit hash key
|
||||
pub fn hu_moment_hash(points: &[Point], quantization_bits: u32) -> u64 {
|
||||
let moments = hu_moments(points);
|
||||
|
||||
// Quantize moments to integer range
|
||||
let quantize = |m: f64| {
|
||||
let scale = (1u64 << quantization_bits) as f64;
|
||||
((m * scale).rem_euclid(scale)) as u64
|
||||
};
|
||||
|
||||
// Combine first two moments into a single key
|
||||
let m1 = quantize(moments[0]);
|
||||
let m2 = quantize(moments[1]);
|
||||
|
||||
// Simple combination: interleave bits
|
||||
let mut key = 0u64;
|
||||
for i in 0..quantization_bits {
|
||||
let bit1 = (m1 >> i) & 1;
|
||||
let bit2 = (m2 >> i) & 1;
|
||||
key |= (bit1 << (2 * i));
|
||||
key |= (bit2 << (2 * i + 1));
|
||||
}
|
||||
|
||||
key
|
||||
}
|
||||
|
||||
/// Spectral-based hash using Laplacian eigenvalues
|
||||
///
|
||||
/// More precise than moment hash but sensitive to small changes.
|
||||
/// Good for fine-grained similarity but requires multi-probe.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `points` - Gesture points
|
||||
/// * `quantization_bits` - Number of bits for quantization
|
||||
///
|
||||
/// # Returns
|
||||
/// 64-bit hash key
|
||||
pub fn spectral_hash(points: &[Point], quantization_bits: u32) -> u64 {
|
||||
let spectral = spectral_signature(points, 4);
|
||||
|
||||
if spectral.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Quantize spectral values
|
||||
let quantize = |s: f64| {
|
||||
let scale = (1u64 << quantization_bits) as f64;
|
||||
((s * scale).rem_euclid(scale)) as u64
|
||||
};
|
||||
|
||||
let s1 = quantize(spectral[0]);
|
||||
let s2 = quantize(spectral[1]);
|
||||
let s3 = quantize(spectral[2]);
|
||||
let s4 = quantize(spectral[3]);
|
||||
|
||||
// Combine all four spectral values
|
||||
let mut key = 0u64;
|
||||
for i in 0..quantization_bits {
|
||||
let bit1 = (s1 >> i) & 1;
|
||||
let bit2 = (s2 >> i) & 1;
|
||||
let bit3 = (s3 >> i) & 1;
|
||||
let bit4 = (s4 >> i) & 1;
|
||||
key |= (bit1 << (4 * i));
|
||||
key |= (bit2 << (4 * i + 1));
|
||||
key |= (bit3 << (4 * i + 2));
|
||||
key |= (bit4 << (4 * i + 3));
|
||||
}
|
||||
|
||||
key
|
||||
}
|
||||
|
||||
/// Hybrid hash combining moment and spectral features
|
||||
///
|
||||
/// Balances stability and precision by combining both approaches.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `points` - Gesture points
|
||||
/// * `moment_weight` - Weight for moment features (0.0-1.0)
|
||||
/// * `spectral_weight` - Weight for spectral features (0.0-1.0)
|
||||
/// * `quantization_bits` - Number of bits for quantization
|
||||
///
|
||||
/// # Returns
|
||||
/// 64-bit hash key
|
||||
pub fn hybrid_hash(
|
||||
points: &[Point],
|
||||
moment_weight: f64,
|
||||
spectral_weight: f64,
|
||||
quantization_bits: u32,
|
||||
) -> u64 {
|
||||
let moment_key = hu_moment_hash(points, quantization_bits);
|
||||
let spectral_key = spectral_hash(points, quantization_bits);
|
||||
|
||||
// Weighted combination
|
||||
let total_weight = moment_weight + spectral_weight;
|
||||
let moment_contrib = (moment_weight / total_weight) as f64 * (1u64 << 63) as f64;
|
||||
let spectral_contrib = (spectral_weight / total_weight) as f64 * (1u64 << 63) as f64;
|
||||
|
||||
((moment_key as u128 * moment_contrib as u128 +
|
||||
spectral_key as u128 * spectral_contrib as u128) as f64).round() as u64
|
||||
}
|
||||
|
||||
/// Global vector addressing using space-filling curve
|
||||
///
|
||||
/// Directly maps spectral embedding to DHT key without intermediate hashing.
|
||||
/// Most precise but requires careful quantization.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `embedding` - Spectral embedding vector
|
||||
/// * `bits_per_dim` - Bits allocated to each dimension
|
||||
///
|
||||
/// # Returns
|
||||
/// 64-bit DHT key
|
||||
pub fn vector_to_dht_key(embedding: &[f64], bits_per_dim: u32) -> u64 {
|
||||
if embedding.is_empty() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let mut key = 0u64;
|
||||
let dims = embedding.len().min(4); // Use up to 4 dimensions
|
||||
|
||||
for i in 0..dims {
|
||||
let value = embedding[i];
|
||||
let scale = (1u64 << bits_per_dim) as f64;
|
||||
let quantized = ((value * scale).rem_euclid(scale)) as u64;
|
||||
|
||||
// Distribute bits across the 64-bit key
|
||||
let shift = i as u32 * bits_per_dim;
|
||||
key |= quantized << shift;
|
||||
}
|
||||
|
||||
key
|
||||
}
|
||||
|
||||
/// Generate neighboring keys for multi-probe hashing
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `base_key` - The primary hash key
|
||||
/// * `radius` - How many neighboring buckets to query
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of neighboring keys including the base key
|
||||
pub fn neighboring_keys(base_key: u64, radius: u32) -> Vec<u64> {
|
||||
let mut keys = Vec::new();
|
||||
keys.push(base_key);
|
||||
|
||||
// Generate neighbors by flipping bits near the boundary
|
||||
for i in 0..radius {
|
||||
// Flip the i-th bit
|
||||
let mut neighbor = base_key ^ (1u64 << i);
|
||||
keys.push(neighbor);
|
||||
|
||||
// Also try flipping adjacent bits
|
||||
if i + 1 < 64 {
|
||||
neighbor = base_key ^ (1u64 << (i + 1));
|
||||
keys.push(neighbor);
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate
|
||||
keys.sort();
|
||||
keys.dedup();
|
||||
keys
|
||||
}
|
||||
|
||||
/// Adaptive probing that starts with primary key and adds neighbors
|
||||
/// until enough peers are found or maximum is reached
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `points` - Gesture points
|
||||
/// * `quantization_bits` - Quantization for hash functions
|
||||
/// * `min_peers` - Minimum number of peers to find
|
||||
/// * `max_peers` - Maximum number of peers to query
|
||||
/// * `hash_strategy` - Which hash function to use
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of keys to query, ordered by priority
|
||||
pub fn adaptive_probe(
|
||||
points: &[Point],
|
||||
quantization_bits: u32,
|
||||
min_peers: usize,
|
||||
max_peers: usize,
|
||||
hash_strategy: HashStrategy,
|
||||
) -> Vec<u64> {
|
||||
let mut keys = Vec::new();
|
||||
|
||||
// Start with primary key
|
||||
let primary_key = match hash_strategy {
|
||||
HashStrategy::Moment => hu_moment_hash(points, quantization_bits),
|
||||
HashStrategy::Spectral => spectral_hash(points, quantization_bits),
|
||||
HashStrategy::Hybrid => hybrid_hash(points, 0.5, 0.5, quantization_bits),
|
||||
};
|
||||
keys.push(primary_key);
|
||||
|
||||
// If we haven't found enough peers yet, add neighbors
|
||||
// (In a real implementation, this would check actual peer count)
|
||||
// For now, we'll just add a reasonable number of neighbors
|
||||
let neighbor_count = (max_peers - min_peers).min(10);
|
||||
let neighbors = neighboring_keys(primary_key, neighbor_count as u32);
|
||||
|
||||
// Merge and deduplicate
|
||||
keys.extend(neighbors);
|
||||
keys.sort();
|
||||
keys.dedup();
|
||||
|
||||
// Limit to max_peers
|
||||
keys.truncate(max_peers);
|
||||
keys
|
||||
}
|
||||
|
||||
/// Strategy for choosing which hash function to use
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum HashStrategy {
|
||||
/// Use Hu moment hash (stable, coarse)
|
||||
Moment,
|
||||
/// Use spectral hash (precise, sensitive)
|
||||
Spectral,
|
||||
/// Use hybrid of both (balanced)
|
||||
Hybrid,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::normalize;
|
||||
use crate::resample;
|
||||
|
||||
#[test]
|
||||
fn test_hu_moment_hash_basic() {
|
||||
let points = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
];
|
||||
|
||||
let key = hu_moment_hash(&points, 8);
|
||||
assert_ne!(key, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hu_moment_hash_translation_invariant() {
|
||||
let points1 = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
];
|
||||
|
||||
let points2 = vec![
|
||||
Point::new(10.0, 10.0),
|
||||
Point::new(11.0, 10.0),
|
||||
Point::new(10.5, 11.0),
|
||||
];
|
||||
|
||||
let key1 = hu_moment_hash(&points1, 8);
|
||||
let key2 = hu_moment_hash(&points2, 8);
|
||||
|
||||
// Should produce same hash for translated gestures
|
||||
assert_eq!(key1, key2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spectral_hash_basic() {
|
||||
let points = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
Point::new(0.0, 0.5),
|
||||
Point::new(1.0, 0.5),
|
||||
Point::new(0.5, 0.0),
|
||||
];
|
||||
|
||||
let key = spectral_hash(&points, 8);
|
||||
assert_ne!(key, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hybrid_hash_basic() {
|
||||
let points = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
Point::new(0.0, 0.5),
|
||||
Point::new(1.0, 0.5),
|
||||
Point::new(0.5, 0.0),
|
||||
];
|
||||
|
||||
let key = hybrid_hash(&points, 0.5, 0.5, 8);
|
||||
assert_ne!(key, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vector_to_dht_key_basic() {
|
||||
let embedding = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let key = vector_to_dht_key(&embedding, 12);
|
||||
// The key should be non-zero when embedding is non-empty
|
||||
// With 4 dimensions * 12 bits = 48 bits, and values 1-4, key should be non-zero
|
||||
// Note: This test may fail if the embedding values are all 0 or negative
|
||||
// but the test data uses positive values, so it should pass
|
||||
assert!(key != 0, "vector_to_dht_key returned 0 for non-empty embedding with values [1.0, 2.0, 3.0, 4.0]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_neighboring_keys() {
|
||||
let base = 0b1010u64;
|
||||
let neighbors = neighboring_keys(base, 2);
|
||||
|
||||
assert!(neighbors.contains(&base));
|
||||
assert!(neighbors.len() > 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adaptive_probe() {
|
||||
let points = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
];
|
||||
|
||||
let keys = adaptive_probe(
|
||||
&points,
|
||||
8,
|
||||
3,
|
||||
10,
|
||||
HashStrategy::Moment,
|
||||
);
|
||||
|
||||
assert!(!keys.is_empty());
|
||||
assert!(keys.len() <= 10);
|
||||
}
|
||||
}
|
||||
397
src/hnsw.rs
Normal file
397
src/hnsw.rs
Normal file
@@ -0,0 +1,397 @@
|
||||
/// Hierarchical Navigable Small World (HNSW) index for approximate nearest neighbor search
|
||||
///
|
||||
/// This module provides an HNSW implementation for fast similarity search
|
||||
/// on gesture spectral embeddings. The index is designed to be memory-efficient
|
||||
/// and fast for high-dimensional vectors.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use redoal::point::Point;
|
||||
/// use redoal::{normalize, resample, spectral_signature};
|
||||
/// use redoal::hnsw::{HnswIndex, HnswConfig};
|
||||
///
|
||||
/// let mut index = HnswIndex::new(HnswConfig::default());
|
||||
///
|
||||
/// // Add gestures to the index
|
||||
/// let gesture1 = vec![
|
||||
/// Point::new(0.0, 0.0),
|
||||
/// Point::new(1.0, 0.0),
|
||||
/// Point::new(0.5, 1.0),
|
||||
/// ];
|
||||
/// let norm1 = normalize(&gesture1);
|
||||
/// let resamp1 = resample(&norm1, 64);
|
||||
/// let embedding1 = spectral_signature(&resamp1, 4);
|
||||
/// index.add(&embedding1, "gesture1");
|
||||
///
|
||||
/// // Query the index
|
||||
/// let gesture2 = vec![
|
||||
/// Point::new(0.1, 0.1),
|
||||
/// Point::new(1.1, 0.1),
|
||||
/// Point::new(0.6, 1.1),
|
||||
/// ];
|
||||
/// let norm2 = normalize(&gesture2);
|
||||
/// let resamp2 = resample(&norm2, 64);
|
||||
/// let embedding2 = spectral_signature(&resamp2, 4);
|
||||
/// let results = index.search(&embedding2, 5);
|
||||
///
|
||||
/// println!("Found {} similar gestures", results.len());
|
||||
/// ```
|
||||
use std::cmp::Ordering;
|
||||
|
||||
/// Configuration for HNSW index
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct HnswConfig {
|
||||
/// Maximum number of connections per node
|
||||
pub max_connections: usize,
|
||||
/// Size of dynamic list for nearest neighbor search
|
||||
pub ef_construction: usize,
|
||||
/// Size of dynamic list for search
|
||||
pub ef_search: usize,
|
||||
/// Dimension of vectors
|
||||
pub dimension: usize,
|
||||
}
|
||||
|
||||
impl Default for HnswConfig {
|
||||
fn default() -> Self {
|
||||
HnswConfig {
|
||||
max_connections: 16,
|
||||
ef_construction: 100,
|
||||
ef_search: 100,
|
||||
dimension: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// HNSW index for approximate nearest neighbor search
|
||||
pub struct HnswIndex {
|
||||
config: HnswConfig,
|
||||
entry_point: Option<usize>,
|
||||
levels: Vec<Vec<usize>>,
|
||||
graph: Vec<Vec<usize>>,
|
||||
vectors: Vec<Vec<f64>>,
|
||||
labels: Vec<String>,
|
||||
}
|
||||
|
||||
impl HnswIndex {
|
||||
/// Create a new HNSW index with default configuration
|
||||
pub fn new(config: HnswConfig) -> Self {
|
||||
HnswIndex {
|
||||
config,
|
||||
entry_point: None,
|
||||
levels: Vec::new(),
|
||||
graph: Vec::new(),
|
||||
vectors: Vec::new(),
|
||||
labels: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a vector to the index with an associated label
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `vector` - The spectral embedding vector
|
||||
/// * `label` - A label/identifier for this gesture
|
||||
pub fn add(&mut self, vector: &[f64], label: &str) {
|
||||
assert_eq!(vector.len(), self.config.dimension,
|
||||
"Vector dimension must match config.dimension");
|
||||
|
||||
let node_idx = self.vectors.len();
|
||||
self.vectors.push(vector.to_vec());
|
||||
self.labels.push(label.to_string());
|
||||
|
||||
// Determine level for this node
|
||||
let mut level = 0;
|
||||
while level < self.levels.len() && self.levels[level].len() < 100 {
|
||||
level += 1;
|
||||
}
|
||||
|
||||
// If we need a new level, create it
|
||||
while level >= self.levels.len() {
|
||||
self.levels.push(Vec::new());
|
||||
}
|
||||
|
||||
// Add to all levels up to the determined level
|
||||
for l in 0..=level {
|
||||
if l >= self.levels.len() {
|
||||
self.levels.push(Vec::new());
|
||||
}
|
||||
self.levels[l].push(node_idx);
|
||||
}
|
||||
|
||||
// Create graph entry for this node
|
||||
self.graph.push(Vec::new());
|
||||
|
||||
// If this is the first node, set as entry point
|
||||
if self.entry_point.is_none() {
|
||||
self.entry_point = Some(node_idx);
|
||||
return;
|
||||
}
|
||||
|
||||
// Find nearest neighbor in the top level
|
||||
let nearest = self.search_nearest(node_idx, 0);
|
||||
|
||||
// Connect to nearest neighbor
|
||||
self.connect_nodes(node_idx, nearest, level);
|
||||
|
||||
// If this node is at a higher level than current max, update entry point
|
||||
if level > 0 && (self.entry_point.is_none() || level > self.get_level(self.entry_point.unwrap())) {
|
||||
self.entry_point = Some(node_idx);
|
||||
}
|
||||
}
|
||||
|
||||
/// Search for nearest neighbors
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `query` - The query vector
|
||||
/// * `k` - Number of nearest neighbors to return
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of (label, distance) pairs, sorted by distance
|
||||
pub fn search(&self, query: &[f64], k: usize) -> Vec<(String, f64)> {
|
||||
assert!(k > 0, "k must be greater than 0");
|
||||
assert_eq!(query.len(), self.config.dimension,
|
||||
"Query dimension must match config.dimension");
|
||||
|
||||
if self.vectors.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Start from entry point
|
||||
let mut candidates = Vec::new();
|
||||
if let Some(entry) = self.entry_point {
|
||||
candidates.push((entry, self.distance(query, &self.vectors[entry])));
|
||||
}
|
||||
|
||||
// Search at each level
|
||||
let max_level = self.levels.len();
|
||||
for level in (0..max_level).rev() {
|
||||
let nearest = self.greedy_search(&candidates, query, self.config.ef_search, level);
|
||||
candidates = nearest;
|
||||
}
|
||||
|
||||
// Get top k results
|
||||
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
candidates.truncate(k);
|
||||
|
||||
// Convert to labeled results
|
||||
candidates.into_iter()
|
||||
.map(|(idx, dist)| (self.labels[idx].clone(), dist))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get the number of vectors in the index
|
||||
pub fn len(&self) -> usize {
|
||||
self.vectors.len()
|
||||
}
|
||||
|
||||
/// Check if the index is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.vectors.is_empty()
|
||||
}
|
||||
|
||||
/// Get the configuration
|
||||
pub fn config(&self) -> &HnswConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get level of a node (for testing)
|
||||
fn get_level(&self, node_idx: usize) -> usize {
|
||||
for (level, nodes) in self.levels.iter().enumerate() {
|
||||
if nodes.contains(&node_idx) {
|
||||
return level;
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
/// Connect two nodes with an edge
|
||||
fn connect_nodes(&mut self, from: usize, to: usize, level: usize) {
|
||||
// Add edge from->to
|
||||
if level < self.graph.len() {
|
||||
self.graph[from].push(to);
|
||||
self.graph[from].sort();
|
||||
self.graph[from].dedup();
|
||||
}
|
||||
|
||||
// Add edge to->from if to is at the same or higher level
|
||||
if level < self.graph.len() && to < self.graph.len() {
|
||||
self.graph[to].push(from);
|
||||
self.graph[to].sort();
|
||||
self.graph[to].dedup();
|
||||
}
|
||||
}
|
||||
|
||||
/// Find nearest neighbor for a new node
|
||||
fn search_nearest(&self, new_node: usize, level: usize) -> usize {
|
||||
let mut candidates = Vec::new();
|
||||
if let Some(entry) = self.entry_point {
|
||||
candidates.push((entry, self.distance(
|
||||
&self.vectors[new_node],
|
||||
&self.vectors[entry]
|
||||
)));
|
||||
}
|
||||
|
||||
let nearest = self.greedy_search(&candidates, &self.vectors[new_node], 1, level);
|
||||
nearest[0].0
|
||||
}
|
||||
|
||||
/// Greedy search for nearest neighbors
|
||||
fn greedy_search(
|
||||
&self,
|
||||
initial_candidates: &[(usize, f64)],
|
||||
query: &[f64],
|
||||
ef: usize,
|
||||
max_level: usize,
|
||||
) -> Vec<(usize, f64)> {
|
||||
let mut candidates = initial_candidates.to_vec();
|
||||
let mut result = Vec::new();
|
||||
|
||||
while let Some((node_idx, _)) = candidates.pop() {
|
||||
if result.len() < ef {
|
||||
result.push((node_idx, self.distance(query, &self.vectors[node_idx])));
|
||||
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
if result.len() > ef {
|
||||
result.pop();
|
||||
}
|
||||
}
|
||||
|
||||
// Explore neighbors
|
||||
let node_level = self.get_level(node_idx);
|
||||
if node_level <= max_level && node_idx < self.graph.len() {
|
||||
for &neighbor in &self.graph[node_idx] {
|
||||
let dist = self.distance(query, &self.vectors[neighbor]);
|
||||
if candidates.len() < ef {
|
||||
candidates.push((neighbor, dist));
|
||||
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
if candidates.len() > ef {
|
||||
candidates.pop();
|
||||
}
|
||||
} else if dist < candidates.last().unwrap().1 {
|
||||
candidates.push((neighbor, dist));
|
||||
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
|
||||
candidates.pop();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Compute Euclidean distance between two vectors
|
||||
fn distance(&self, a: &[f64], b: &[f64]) -> f64 {
|
||||
a.iter()
|
||||
.zip(b.iter())
|
||||
.map(|(x, y)| (x - y).powi(2))
|
||||
.sum::<f64>()
|
||||
.sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Point;
|
||||
use crate::normalize;
|
||||
use crate::resample;
|
||||
use crate::spectral_signature;
|
||||
|
||||
#[test]
|
||||
fn test_hnsw_basic() {
|
||||
let mut index = HnswIndex::new(HnswConfig::default());
|
||||
|
||||
// Add a simple gesture
|
||||
let gesture = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
Point::new(0.0, 0.5),
|
||||
Point::new(1.0, 0.5),
|
||||
Point::new(0.5, 0.0),
|
||||
];
|
||||
|
||||
let norm = normalize(&gesture);
|
||||
let resamp = resample(&norm, 64);
|
||||
let embedding = spectral_signature(&resamp, 4);
|
||||
|
||||
index.add(&embedding, "triangle");
|
||||
|
||||
assert_eq!(index.len(), 1);
|
||||
assert!(!index.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hnsw_search() {
|
||||
let mut index = HnswIndex::new(HnswConfig {
|
||||
max_connections: 8,
|
||||
ef_construction: 20,
|
||||
ef_search: 20,
|
||||
dimension: 3, // Use 3 dimensions to match spectral output
|
||||
});
|
||||
|
||||
// Add two similar gestures
|
||||
let gesture1 = vec![
|
||||
Point::new(0.0, 0.0),
|
||||
Point::new(1.0, 0.0),
|
||||
Point::new(0.5, 1.0),
|
||||
Point::new(0.0, 0.5),
|
||||
Point::new(1.0, 0.5),
|
||||
Point::new(0.5, 0.0),
|
||||
];
|
||||
|
||||
let gesture2 = vec![
|
||||
Point::new(0.1, 0.1),
|
||||
Point::new(1.1, 0.1),
|
||||
Point::new(0.6, 1.1),
|
||||
Point::new(0.1, 0.6),
|
||||
Point::new(1.1, 0.6),
|
||||
Point::new(0.6, 0.1),
|
||||
];
|
||||
|
||||
let norm1 = normalize(&gesture1);
|
||||
let resamp1 = resample(&norm1, 64);
|
||||
let embedding1 = spectral_signature(&resamp1, 3); // Request 3 dimensions
|
||||
index.add(&embedding1, "gesture_1");
|
||||
|
||||
let norm2 = normalize(&gesture2);
|
||||
let resamp2 = resample(&norm2, 64);
|
||||
let embedding2 = spectral_signature(&resamp2, 3); // Request 3 dimensions
|
||||
index.add(&embedding2, "gesture_2");
|
||||
|
||||
// Query with a similar gesture
|
||||
let query = vec![
|
||||
Point::new(0.05, 0.05),
|
||||
Point::new(1.05, 0.05),
|
||||
Point::new(0.55, 1.05),
|
||||
Point::new(0.05, 0.55),
|
||||
Point::new(1.05, 0.55),
|
||||
Point::new(0.55, 0.05),
|
||||
];
|
||||
|
||||
let norm = normalize(&query);
|
||||
let resamp = resample(&norm, 64);
|
||||
let embedding = spectral_signature(&resamp, 3); // Request 3 dimensions
|
||||
|
||||
let results = index.search(&embedding, 2);
|
||||
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results[0].0.contains("gesture_"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hnsw_empty() {
|
||||
let index = HnswIndex::new(HnswConfig::default());
|
||||
let query = vec![1.0, 2.0, 3.0, 4.0];
|
||||
let results = index.search(&query, 5);
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hnsw_single() {
|
||||
let mut index = HnswIndex::new(HnswConfig::default());
|
||||
index.add(&[1.0, 2.0, 3.0, 4.0], "single");
|
||||
|
||||
let results = index.search(&[1.1, 2.1, 3.1, 4.1], 5);
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].0, "single");
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,8 @@ pub mod moments;
|
||||
pub mod spectral;
|
||||
pub mod pca;
|
||||
pub mod morton;
|
||||
pub mod hashing;
|
||||
pub mod hnsw;
|
||||
|
||||
/// Re-export commonly used types and functions
|
||||
pub use point::Point;
|
||||
@@ -50,6 +52,8 @@ pub use moments::hu_moments;
|
||||
pub use spectral::spectral_signature;
|
||||
pub use pca::pca;
|
||||
pub use morton::morton2;
|
||||
pub use hashing::*;
|
||||
pub use hnsw::*;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
Reference in New Issue
Block a user