"""Contains the FaissBlocker class for performing blocking using one of the FAISS algorithms."""
import logging
import os
import warnings
from typing import Any
import faiss
import numpy as np
import pandas as pd
from .base import BlockingMethod
from .data_handler import DataHandler
from .helper_functions import rearrange_array
logger = logging.getLogger(__name__)
class FaissBlocker(BlockingMethod):
    """
    A class for performing blocking using the FAISS (Facebook AI Similarity Search) library.

    This class implements blocking functionality using Facebook's FAISS library for
    efficient similarity search and nearest neighbor queries. It supports multiple
    distance metrics and index types.

    Parameters
    ----------
    None

    Attributes
    ----------
    index : faiss.Index
        The FAISS index used for nearest neighbor search (assigned by `block`)
    x_columns : list of str
        Column names of the reference dataset (assigned by `block`)
    METRIC_MAP : dict
        Mapping of distance metric names to FAISS metric types

    See Also
    --------
    BlockingMethod : Abstract base class defining the blocking interface
    faiss.Index : The underlying FAISS index implementation

    Notes
    -----
    The available index types from FAISS are: 'flat', 'hnsw', and 'lsh'.

    - 'flat' is a brute-force exact search (most accurate but slowest)
    - 'hnsw' is a Hierarchical Navigable Small World graph algorithm
      (good balance of speed and accuracy)
    - 'lsh' is a Locality Sensitive Hashing algorithm
      (fastest but approximate results)

    For more details about the FAISS library and implementation, see:
    https://github.com/facebookresearch/faiss

    Some distance metrics require special handling:

    - Cosine similarity is implemented through L2 normalization
    - Jensen-Shannon and Canberra metrics require smoothing to handle zero values
    - The selected distance metric does not affect the algorithm if 'lsh' is selected

    FAISS has no global ``random_seed`` parameter. When ``controls["random_seed"]``
    is provided, this class seeds the relevant components explicitly: the HNSW
    random generator for 'hnsw', and a random rotation matrix placed in front of
    the LSH index for 'lsh'. For background, see:
    https://gist.github.com/mdouze/1892178b5663b80e85ab076966c59c28
    """

    METRIC_MAP: dict[str, Any] = {
        "euclidean": faiss.METRIC_L2,
        "l2": faiss.METRIC_L2,
        "inner_product": faiss.METRIC_INNER_PRODUCT,
        "cosine": faiss.METRIC_INNER_PRODUCT,
        # note: later handled by vector normalisation
        # see:(https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances)
        "l1": faiss.METRIC_L1,
        "manhattan": faiss.METRIC_L1,
        "linf": faiss.METRIC_Linf,
        # "lp" : faiss.METRIC_Lp,
        "canberra": faiss.METRIC_Canberra,
        # note: requires smoothing since 0/0 is undefined
        "bray_curtis": faiss.METRIC_BrayCurtis,
        "jensen_shannon": faiss.METRIC_JensenShannon,
        # note: requires smoothing since log(0) is undefined
    }

    def __init__(self) -> None:
        """
        Initialize the FaissBlocker instance.

        Only declares the types of ``index`` and ``x_columns``; both attributes
        are assigned when `block` is called.
        """
        self.index: faiss.Index
        self.x_columns: list[str]

    def block(  # noqa: PLR0915, PLR0912
        self,
        x: DataHandler,
        y: DataHandler,
        k: int,
        verbose: bool | None,
        controls: dict[str, Any],
    ) -> pd.DataFrame:
        """
        Perform blocking using the FAISS algorithm.

        Parameters
        ----------
        x : DataHandler
            Reference dataset containing features for indexing
        y : DataHandler
            Query dataset to find nearest neighbors for
        k : int
            Number of nearest neighbors to find
        verbose : bool, optional
            If True, log detailed progress information
        controls : dict
            Algorithm control parameters with the following structure:
            {
                'random_seed': int or None,
                'faiss': {
                    'index_type': ['flat', 'hnsw', 'lsh'],
                    'distance': str,
                    'k_search': int,
                    'path': str,
                    'hnsw_M': int,
                    'hnsw_ef_construction': int,
                    'hnsw_ef_search': int,
                    'lsh_nbits': int, (gets multiplied by dimensions)
                    'lsh_rotate_data': bool,
                }
            }

        Returns
        -------
        pandas.DataFrame
            DataFrame containing the blocking results with columns:
            - 'y': indices from query dataset
            - 'x': indices of matched items from reference dataset
            - 'dist': distances to matched items

        Raises
        ------
        ValueError
            If ``index_type`` is not one of 'flat', 'hnsw', 'lsh', or if a
            ``path`` is given that is not an existing directory.
        KeyError
            If ``controls`` has no 'faiss' entry or ``distance`` is not a key
            of ``METRIC_MAP``.

        Warns
        -----
        UserWarning
            If ``k_search`` exceeds the number of reference points; it is then
            reduced to that number.

        Notes
        -----
        Special preprocessing is applied for certain metrics:

        - For cosine similarity, vectors are L2-normalized
        - For Jensen-Shannon and Canberra metrics, small constant is added
          to prevent undefined values
        - For LSH index, the distance calculation is determined by the hash function,
          not directly by the selected distance metric
        """
        logger.setLevel(logging.INFO if verbose else logging.WARNING)

        self.x_columns = list(x.cols)

        faiss_controls = controls["faiss"]
        distance = faiss_controls.get("distance")
        k_search = faiss_controls.get("k_search")
        path = faiss_controls.get("path")
        index_type = faiss_controls.get("index_type", "hnsw")
        seed = controls.get("random_seed")

        if index_type not in {"flat", "hnsw", "lsh"}:
            raise ValueError(
                f"Invalid index_type '{index_type}'. Must be one of 'flat', 'hnsw', or 'lsh'."
            )

        X = x.to_dense()
        Y = y.to_dense()

        # NOTE(review): both preprocessing steps below modify X and Y in place;
        # confirm that `to_dense()` returns arrays the caller does not reuse.
        if distance == "cosine":
            # Inner product of unit-length vectors equals cosine similarity.
            faiss.normalize_L2(X)
            faiss.normalize_L2(Y)
        elif distance in {"jensen_shannon", "canberra"}:
            # Shift away from exact zeros so 0/0 and log(0) cannot occur.
            smooth = 1e-12
            X += smooth
            Y += smooth

        metric = self.METRIC_MAP[distance]
        self.index = self._create_index(
            dim=X.shape[1],
            index_type=index_type,
            metric=metric,
            faiss_controls=faiss_controls,
            seed=seed,
        )

        logger.info("Building index...")
        self.index.add(x=X)

        logger.info("Querying index...")
        n_reference = X.shape[0]
        if k_search > n_reference:
            # Cannot retrieve more neighbours than there are reference points.
            original_k_search = k_search
            k_search = n_reference
            warnings.warn(
                f"k_search ({original_k_search}) is larger than the number of reference points "
                f"({n_reference}). Adjusted k_search to {k_search}.",
                category=UserWarning,
                stacklevel=2,
            )

        distances, indices = self.index.search(x=Y, k=k_search)

        if distance == "cosine" and index_type != "lsh":
            # Map inner products of unit vectors from [-1, 1] onto a [0, 1] distance.
            distances = (1 - distances) / 2

        K_VAL = 2
        if k == K_VAL:
            indices, distances = rearrange_array(indices, distances)

        if path:
            self._save_index(path)

        # Only the k-th neighbour column is reported for each query row.
        result = pd.DataFrame(
            {
                "y": np.arange(Y.shape[0]),
                "x": indices[:, k - 1],
                "dist": distances[:, k - 1],
            }
        )

        logger.info("Process completed successfully.")

        return result

    def _create_index(
        self,
        dim: int,
        index_type: str,
        metric: Any,
        faiss_controls: dict[str, Any],
        seed: Any,
    ) -> faiss.Index:
        """
        Build an empty FAISS index of the requested type.

        Parameters
        ----------
        dim : int
            Dimensionality of the vectors to be indexed
        index_type : str
            One of 'flat', 'hnsw' or 'lsh' (already validated by the caller)
        metric : Any
            FAISS metric constant taken from ``METRIC_MAP``
        faiss_controls : dict
            The ``controls['faiss']`` dictionary with index-specific settings
        seed : Any
            Random seed or None; converted with ``int()`` when used

        Returns
        -------
        faiss.Index
            The configured, still empty index
        """
        if index_type == "flat":
            return faiss.IndexFlat(dim, metric)

        if index_type == "hnsw":
            index = faiss.IndexHNSWFlat(dim, faiss_controls.get("hnsw_M"), metric)
            index.hnsw.efConstruction = faiss_controls.get("hnsw_ef_construction")
            index.hnsw.efSearch = faiss_controls.get("hnsw_ef_search")
            if seed is not None:
                index.hnsw.rng = faiss.RandomGenerator(int(seed))
            return index

        # index_type == "lsh": the metric is not passed to the LSH index.
        nbits = faiss_controls.get("lsh_nbits") * dim
        if not isinstance(nbits, int):
            nbits = round(nbits)
        rotate_data = faiss_controls.get("lsh_rotate_data")

        if seed is None:
            return faiss.IndexLSH(dim, nbits, rotate_data)

        # Seeded variant: an explicitly seeded rotation is applied in front of
        # an LSH index created without its own rotation. `rotate_data` is not
        # consulted on this path.
        rot = faiss.RandomRotationMatrix(dim, dim)
        rot.init(int(seed))
        base = faiss.IndexLSH(dim, nbits, False, False)
        return faiss.IndexPreTransform(rot, base)

    def _save_index(self, path: str) -> None:
        """
        Save the FAISS index and column names to files.

        Parameters
        ----------
        path : str
            Existing directory where the files will be saved

        Raises
        ------
        ValueError
            If the provided path is not an existing directory

        Notes
        -----
        Creates two files:

        - 'index.faiss': The FAISS index file
        - 'index-colnames.txt': A text file with column names
        """
        # Both files are written *inside* `path`, so `path` itself (not merely
        # its parent) has to be an existing directory.
        if not os.path.isdir(path):
            raise ValueError("Provided path is incorrect")

        path_faiss = os.path.join(path, "index.faiss")
        path_faiss_cols = os.path.join(path, "index-colnames.txt")

        logger.info("Writing index to %s", path_faiss)

        faiss.write_index(self.index, path_faiss)

        with open(path_faiss_cols, "w", encoding="utf-8") as f:
            # str() keeps this working when column labels are not strings.
            f.write("\n".join(map(str, self.x_columns)))