"""Contains the HNSWBlocker class for performing blocking using the HNSW algorithm."""
import logging
import os
import warnings
from typing import Any
import hnswlib
import pandas as pd
from .base import BlockingMethod
from .data_handler import DataHandler
from .helper_functions import rearrange_array
# Module-level logger; HNSWBlocker.block sets its level from the ``verbose`` flag.
logger = logging.getLogger(__name__)
class HNSWBlocker(BlockingMethod):
    """
    A class for performing blocking using the Hierarchical Navigable Small World (HNSW) algorithm.

    This class implements blocking functionality using the HNSW algorithm for efficient
    similarity search and nearest neighbor queries.

    Parameters
    ----------
    None

    Attributes
    ----------
    index : hnswlib.Index or None
        The HNSW index used for nearest neighbor search; ``None`` until
        :meth:`block` has been called.
    x_columns : list of str or None
        Column names of the reference dataset; ``None`` until :meth:`block`
        has been called.
    SPACE_MAP : dict
        Mapping of distance metric names to their HNSW implementations

    See Also
    --------
    BlockingMethod : Abstract base class defining the blocking interface

    Notes
    -----
    For more details about the HNSW algorithm, see:
    https://github.com/nmslib/hnswlib
    """

    # Accepted ``controls["hnsw"]["distance"]`` values -> hnswlib ``space`` names.
    # "euclidean" is an alias of "l2".
    # NOTE(review): hnswlib documents its "l2" space as *squared* L2, so the
    # "dist" column for "euclidean" may be squared -- confirm callers expect that.
    SPACE_MAP: dict[str, str] = {
        "l2": "l2",
        "euclidean": "l2",
        "cosine": "cosine",
        "ip": "ip",
    }

    def __init__(self) -> None:
        """
        Initialize the HNSWBlocker instance.

        Creates a new HNSWBlocker with no index; ``index`` and ``x_columns``
        are populated by :meth:`block`.
        """
        # Start as None (matching the class docstring) so the attributes
        # always exist, even before ``block`` has run.
        self.index: hnswlib.Index | None = None
        self.x_columns: list[str] | None = None

    def block(
        self,
        x: DataHandler,
        y: DataHandler,
        k: int,
        verbose: bool | None,
        controls: dict[str, Any],
    ) -> pd.DataFrame:
        """
        Perform blocking using the HNSW algorithm.

        Parameters
        ----------
        x : DataHandler
            Reference dataset containing features for indexing
        y : DataHandler
            Query dataset to find nearest neighbors for
        k : int
            Rank of the neighbour reported for each query point: column
            ``k - 1`` of the neighbour arrays is returned, so ``k`` must not
            exceed the (possibly reduced) ``k_search``. ``k`` itself is not
            adjusted.
        verbose : bool, optional
            If True, log progress at INFO level; otherwise only WARNING and above.
            Note that this sets the level of this module's logger.
        controls : dict
            Algorithm control parameters with the following structure:
            {
                'random_seed': int,  # None/missing -> 100
                'hnsw': {
                    'k_search': int,   # neighbours retrieved per query point
                    'distance': str,   # a key of SPACE_MAP
                    'n_threads': int,
                    'path': str,       # falsy -> the index is not saved
                    'ef_c': int,
                    'ef_s': int,
                    'M': int,
                }
            }
            If ``k_search`` exceeds the number of reference points it is
            reduced to that number and a ``UserWarning`` is emitted.

        Returns
        -------
        pandas.DataFrame
            DataFrame containing the blocking results with columns:
            - 'y': row positions ``0 .. n-1`` of the query dataset
            - 'x': indices of matched items from reference dataset
            - 'dist': distances to matched items

        Raises
        ------
        KeyError
            If ``controls`` has no 'hnsw' entry or 'distance' is not a key of
            ``SPACE_MAP``.
        ValueError
            If 'path' is set but is not an existing directory.

        Notes
        -----
        The function builds an HNSW index from the reference dataset and finds
        the k_search-nearest neighbors for each point in the query dataset. The
        index parameters ef_c (construction) and ef_s (search) control the
        trade-off between search accuracy and speed.
        """
        logger.setLevel(logging.INFO if verbose else logging.WARNING)
        self.x_columns = list(x.cols)

        hnsw_controls = controls["hnsw"]
        distance = hnsw_controls.get("distance")
        n_threads = hnsw_controls.get("n_threads")
        path = hnsw_controls.get("path")
        k_search = hnsw_controls.get("k_search")
        space = self.SPACE_MAP[distance]

        # An explicit None (as well as a missing key) falls back to the default seed.
        seed = controls.get("random_seed")
        if seed is None:
            seed = 100

        logger.info("Initializing HNSW index...")
        X = x.to_dense()
        Y = y.to_dense()
        n_ref = X.shape[0]

        self.index = hnswlib.Index(space=space, dim=X.shape[1])
        self.index.init_index(
            max_elements=n_ref,
            ef_construction=hnsw_controls.get("ef_c"),
            M=hnsw_controls.get("M"),
            random_seed=seed,
        )
        self.index.set_num_threads(n_threads)

        logger.info("Adding items to index...")
        self.index.add_items(X)
        self.index.set_ef(hnsw_controls.get("ef_s"))

        logger.info("Querying index...")
        # Cannot ask for more neighbours than there are indexed points.
        if k_search > n_ref:
            warnings.warn(
                f"k_search ({k_search}) is larger than the number of reference points "
                f"({n_ref}). Adjusted k_search to {n_ref}.",
                category=UserWarning,
                stacklevel=2,
            )
            k_search = n_ref

        indices, distances = self.index.knn_query(Y, k=k_search, num_threads=n_threads)

        # For k == 2 the neighbour/distance arrays are reordered by
        # rearrange_array before column k-1 is taken.
        # NOTE(review): presumably this handles self-matches when x and y hold
        # the same records -- confirm against rearrange_array.
        K_VAL = 2
        if k == K_VAL:
            indices, distances = rearrange_array(indices, distances)

        if path:
            self._save_index(path)

        result = pd.DataFrame(
            {
                "y": range(y.shape[0]),
                "x": indices[:, k - 1],
                "dist": distances[:, k - 1],
            }
        )
        logger.info("Process completed successfully.")
        return result

    def _save_index(self, path: str) -> None:
        """
        Save the HNSW index and column names to files.

        Parameters
        ----------
        path : str
            Existing directory where the files will be saved

        Raises
        ------
        ValueError
            If ``path`` is not an existing directory
        RuntimeError
            If called before :meth:`block` has built the index

        Notes
        -----
        Creates two files inside ``path``:
        - 'index.hnsw': The HNSW index file
        - 'index-colnames.txt': A text file with column names, one per line
        """
        # The files are written *inside* ``path``, so ``path`` itself must be a
        # directory. The former check, ``exists(dirname(path))``, had two bugs:
        # it rejected bare relative names such as "out" (dirname is ""), and it
        # accepted a missing "parent/out" whenever "parent" existed.
        if not os.path.isdir(path):
            raise ValueError("Provided path is incorrect")
        if self.index is None or self.x_columns is None:
            raise RuntimeError("The index has not been built yet; call block() first.")

        path_ann = os.path.join(path, "index.hnsw")
        path_ann_cols = os.path.join(path, "index-colnames.txt")
        logger.info("Writing an index to %s", path_ann)
        self.index.save_index(path_ann)
        with open(path_ann_cols, "w", encoding="utf-8") as f:
            f.write("\n".join(self.x_columns))