| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| use std::collections::HashMap; |
| use std::sync::Arc; |
|
|
| use crate::core::{Id, Point}; |
| use crate::core::proximity::Proximity; |
| use crate::ports::{Near, NearError, NearResult, SearchResult}; |
|
|
| |
| pub struct FlatIndex { |
| |
| points: HashMap<Id, Point>, |
|
|
| |
| dimensionality: usize, |
|
|
| |
| proximity: Arc<dyn Proximity>, |
|
|
| |
| |
| higher_is_better: bool, |
| } |
|
|
| impl FlatIndex { |
| |
| |
| |
| |
| |
| pub fn new( |
| dimensionality: usize, |
| proximity: Arc<dyn Proximity>, |
| higher_is_better: bool, |
| ) -> Self { |
| Self { |
| points: HashMap::new(), |
| dimensionality, |
| proximity, |
| higher_is_better, |
| } |
| } |
|
|
| |
| pub fn cosine(dimensionality: usize) -> Self { |
| use crate::core::proximity::Cosine; |
| Self::new(dimensionality, Arc::new(Cosine), true) |
| } |
|
|
| |
| pub fn euclidean(dimensionality: usize) -> Self { |
| use crate::core::proximity::Euclidean; |
| Self::new(dimensionality, Arc::new(Euclidean), false) |
| } |
|
|
| |
| fn sort_results(&self, results: &mut Vec<SearchResult>) { |
| if self.higher_is_better { |
| |
| results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap()); |
| } else { |
| |
| results.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap()); |
| } |
| } |
| } |
|
|
| impl Near for FlatIndex { |
| fn near(&self, query: &Point, k: usize) -> NearResult<Vec<SearchResult>> { |
| |
| if query.dimensionality() != self.dimensionality { |
| return Err(NearError::DimensionalityMismatch { |
| expected: self.dimensionality, |
| got: query.dimensionality(), |
| }); |
| } |
|
|
| |
| let mut results: Vec<SearchResult> = self |
| .points |
| .iter() |
| .map(|(id, point)| { |
| let score = self.proximity.proximity(query, point); |
| SearchResult::new(*id, score) |
| }) |
| .collect(); |
|
|
| |
| self.sort_results(&mut results); |
|
|
| |
| results.truncate(k); |
|
|
| Ok(results) |
| } |
|
|
| fn within(&self, query: &Point, threshold: f32) -> NearResult<Vec<SearchResult>> { |
| |
| if query.dimensionality() != self.dimensionality { |
| return Err(NearError::DimensionalityMismatch { |
| expected: self.dimensionality, |
| got: query.dimensionality(), |
| }); |
| } |
|
|
| |
| let mut results: Vec<SearchResult> = self |
| .points |
| .iter() |
| .filter_map(|(id, point)| { |
| let score = self.proximity.proximity(query, point); |
| let within = if self.higher_is_better { |
| score >= threshold |
| } else { |
| score <= threshold |
| }; |
| if within { |
| Some(SearchResult::new(*id, score)) |
| } else { |
| None |
| } |
| }) |
| .collect(); |
|
|
| |
| self.sort_results(&mut results); |
|
|
| Ok(results) |
| } |
|
|
| fn add(&mut self, id: Id, point: &Point) -> NearResult<()> { |
| if point.dimensionality() != self.dimensionality { |
| return Err(NearError::DimensionalityMismatch { |
| expected: self.dimensionality, |
| got: point.dimensionality(), |
| }); |
| } |
|
|
| self.points.insert(id, point.clone()); |
| Ok(()) |
| } |
|
|
| fn remove(&mut self, id: Id) -> NearResult<()> { |
| self.points.remove(&id); |
| Ok(()) |
| } |
|
|
| fn rebuild(&mut self) -> NearResult<()> { |
| |
| Ok(()) |
| } |
|
|
| fn is_ready(&self) -> bool { |
| true |
| } |
|
|
| fn len(&self) -> usize { |
| self.points.len() |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an id whose 16 bytes are all `byte`.
    fn id_of(byte: u8) -> Id {
        Id::from_bytes([byte; 16])
    }

    /// Cosine index over three unit axes plus a normalized diagonal in the x/y plane.
    fn setup_index() -> FlatIndex {
        let mut index = FlatIndex::cosine(3);

        index.add(id_of(1), &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
        index.add(id_of(2), &Point::new(vec![0.0, 1.0, 0.0])).unwrap();
        index.add(id_of(3), &Point::new(vec![0.0, 0.0, 1.0])).unwrap();
        index
            .add(id_of(4), &Point::new(vec![0.7, 0.7, 0.0]).normalize())
            .unwrap();

        index
    }

    #[test]
    fn test_flat_index_near() {
        let index = setup_index();

        // Query along the x axis: the x-axis point must rank first.
        let x_axis = Point::new(vec![1.0, 0.0, 0.0]);
        let hits = index.near(&x_axis, 2).unwrap();

        assert_eq!(hits.len(), 2);

        let best = &hits[0];
        assert_eq!(best.id, id_of(1));
        assert!((best.score - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_flat_index_within_cosine() {
        let index = setup_index();

        let x_axis = Point::new(vec![1.0, 0.0, 0.0]);
        let hits = index.within(&x_axis, 0.5).unwrap();

        // Only the x-axis point and the x/y diagonal reach the threshold.
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_flat_index_euclidean() {
        let mut index = FlatIndex::euclidean(2);

        index.add(id_of(1), &Point::new(vec![0.0, 0.0])).unwrap();
        index.add(id_of(2), &Point::new(vec![1.0, 0.0])).unwrap();
        index.add(id_of(3), &Point::new(vec![5.0, 0.0])).unwrap();

        let origin = Point::new(vec![0.0, 0.0]);
        let hits = index.near(&origin, 2).unwrap();

        // Lower score ranks first: the origin itself, then the point at distance 1.
        let (first, second) = (&hits[0], &hits[1]);

        assert_eq!(first.id, id_of(1));
        assert!(first.score.abs() < 0.0001);

        assert_eq!(second.id, id_of(2));
        assert!((second.score - 1.0).abs() < 0.0001);
    }

    #[test]
    fn test_flat_index_add_remove() {
        let mut index = FlatIndex::cosine(3);
        let id = id_of(1);

        index.add(id, &Point::new(vec![1.0, 0.0, 0.0])).unwrap();
        assert_eq!(index.len(), 1);

        index.remove(id).unwrap();
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_flat_index_dimensionality_check() {
        let mut index = FlatIndex::cosine(3);

        // A 2-dimensional point must be rejected by a 3-dimensional index.
        let two_dims = Point::new(vec![1.0, 0.0]);
        let outcome = index.add(Id::now(), &two_dims);

        if let Err(NearError::DimensionalityMismatch { expected, got }) = outcome {
            assert_eq!(expected, 3);
            assert_eq!(got, 2);
        } else {
            panic!("Expected DimensionalityMismatch error");
        }
    }

    #[test]
    fn test_flat_index_ready() {
        assert!(FlatIndex::cosine(3).is_ready());
    }
}
|
|