From dd24b4fd107810b7b17d285ae36a509f0fa408b5 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 7 Sep 2023 09:23:07 -0700 Subject: [PATCH 1/5] lsh doc --- datasketch/lsh.py | 419 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 294 insertions(+), 125 deletions(-) diff --git a/datasketch/lsh.py b/datasketch/lsh.py index 4b57aad9..f665e27f 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -1,45 +1,47 @@ +from __future__ import annotations import pickle import struct - -from datasketch.storage import ( - ordered_storage, unordered_storage, _random_name) +from typing import Callable, Dict, Hashable, List, Optional, Tuple, Union +from datasketch.minhash import MinHash +from datasketch.weighted_minhash import WeightedMinHash +from datasketch.storage import ordered_storage, unordered_storage, _random_name from scipy.integrate import quad as integrate def _false_positive_probability(threshold, b, r): - _probability = lambda s : 1 - (1 - s**float(r))**float(b) + _probability = lambda s: 1 - (1 - s ** float(r)) ** float(b) a, err = integrate(_probability, 0.0, threshold) return a def _false_negative_probability(threshold, b, r): - _probability = lambda s : 1 - (1 - (1 - s**float(r))**float(b)) + _probability = lambda s: 1 - (1 - (1 - s ** float(r)) ** float(b)) a, err = integrate(_probability, threshold, 1.0) return a -def _optimal_param(threshold, num_perm, false_positive_weight, - false_negative_weight): - ''' +def _optimal_param(threshold, num_perm, false_positive_weight, false_negative_weight): + """ Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum of probabilities of false positive and false negative. 
- ''' + """ min_error = float("inf") opt = (0, 0) - for b in range(1, num_perm+1): + for b in range(1, num_perm + 1): max_r = int(num_perm / b) - for r in range(1, max_r+1): + for r in range(1, max_r + 1): fp = _false_positive_probability(threshold, b, r) fn = _false_negative_probability(threshold, b, r) - error = fp*false_positive_weight + fn*false_negative_weight + error = fp * false_positive_weight + fn * false_negative_weight if error < min_error: min_error = error opt = (b, r) return opt + class MinHashLSH(object): - ''' + """ The :ref:`minhash_lsh` index. It supports query with `Jaccard similarity`_ threshold. Reference: `Chapter 3, Mining of Massive Datasets @@ -49,27 +51,27 @@ class MinHashLSH(object): threshold (float): The Jaccard similarity threshold between 0.0 and 1.0. The initialized MinHash LSH will be optimized for the threshold by minizing the false positive and false negative. - num_perm (int, optional): The number of permutation functions used + num_perm (Optional[int]): The number of permutation functions used by the MinHash to be indexed. For weighted MinHash, this is the sample size (`sample_size`). - weights (tuple, optional): Used to adjust the relative importance of + weights (Optional[Tuple[float, float]]): Used to adjust the relative importance of minimizing false positive and false negative when optimizing for the Jaccard similarity threshold. `weights` is a tuple in the format of :code:`(false_positive_weight, false_negative_weight)`. - params (tuple, optional): The LSH parameters (i.e., number of bands and size + params (Tuple[int, int]): The LSH parameters (i.e., number of bands and size of each bands). This is used to bypass the parameter optimization step in the constructor. `threshold` and `weights` will be ignored if this is given. - storage_config (dict, optional): Type of storage service to use for storing + storage_config (Optional[Dict]): Type of storage service to use for storing hashtables and keys. 
`basename` is an optional property whose value will be used as the prefix to stored keys. If this is not set, a random string will be generated instead. If you set this, you will be responsible for ensuring there are no key collisions. - prepickle (bool, optional): If True, all keys are pickled to bytes before - insertion. If None, a default value is chosen based on the + prepickle (bool): If True, all keys are pickled to bytes before + insertion. If False, a default value is chosen based on the `storage_config`. - hashfunc (function, optional): If a hash function is provided it will be used to + hashfunc (Optional[Callable[[bytes], bytes]]): If a hash function is provided it will be used to compress the index keys to reduce the memory footprint. This could cause a higher false positive rate. @@ -79,11 +81,70 @@ class MinHashLSH(object): For example, if minimizing false negative (or maintaining high recall) is more important, assign more weight toward false negative: weights=(0.4, 0.6). Try to live with a small difference between weights (i.e. < 0.5). - ''' - def __init__(self, threshold=0.9, num_perm=128, weights=(0.5, 0.5), - params=None, storage_config=None, prepickle=None, hashfunc=None): - storage_config = {'type': 'dict'} if not storage_config else storage_config + Examples: + + Create an index with 128 permutation functions optimized for Jaccard + threshold 0.9: + + .. code-block:: python + + lsh = MinHashLSH(threshold=0.9, num_perm=128) + print(lsh.b, lsh.r) + # 5 25 + + The built-in optimizer will try to minimize the weighted sum of + probabilities of false positive and false negative. The algorithm is + a simple grid search over the space of possible parameters. + + Note that it is possible to get :attr:`b` (number of bands) and + :attr:`r` (band size) that do not sum to :attr:`num_perm`, leading to + unused permutation values in the indexed MinHash. 
+ This is because the optimizer only considers bands of + the same size, and the number of bands is not necessarily a divisor of + :attr:`num_perm`. + + Instead of using the built-in optimizer, you can customize the LSH + parameters your self. The snippet below creates an index with 128 + permutation functions and 16 bands each with size 8, skipping the + optimization step: + + .. code-block:: python + + lsh = MinHashLSH(num_perm=128, params=(16, 8)) + print(lsh.b, lsh.r) + # 16 8 + + Create an index backed by Redis storage: + + .. code-block:: python + + lsh = MinHashLSH(threshold=0.9, num_perm=128, storage_config={ + 'type': 'redis', + 'basename': b'mylsh', # optional, defaults to a random string. + 'redis': {'host': 'localhost', 'port': 6379}, + }) + + The `basename` property is optional. It is used to generate key prefixes + in the storage layer to uniquely identify data associated with this LSH. + Thus, if you create a new LSH object with the same `basename`, you will + be using the same underlying data in the storage layer associated with + a previous LSH object. If you do not set this property, a random string + will be generated instead. + + """ + + def __init__( + self, + threshold: float = 0.9, + num_perm: int = 128, + weights: Tuple[float, float] = (0.5, 0.5), + params: Optional[Tuple[int, int]] = None, + storage_config: Optional[Dict] = None, + prepickle: bool = False, + hashfunc: Optional[Callable[[bytes], bytes]] = None, + ) -> None: + storage_config = {"type": "dict"} if not storage_config else storage_config self._buffer_size = 50000 if threshold > 1.0 or threshold < 0.0: raise ValueError("threshold must be in [0.0, 1.0]") @@ -97,18 +158,24 @@ def __init__(self, threshold=0.9, num_perm=128, weights=(0.5, 0.5), if params is not None: self.b, self.r = params if self.b * self.r > num_perm: - raise ValueError("The product of b and r in params is " - "{} * {} = {} -- it must be less than num_perm {}. 
" - "Did you forget to specify num_perm?".format( - self.b, self.r, self.b*self.r, num_perm)) + raise ValueError( + "The product of b and r in params is " + "{} * {} = {} -- it must be less than num_perm {}. " + "Did you forget to specify num_perm?".format( + self.b, self.r, self.b * self.r, num_perm + ) + ) else: false_positive_weight, false_negative_weight = weights - self.b, self.r = _optimal_param(threshold, num_perm, - false_positive_weight, false_negative_weight) + self.b, self.r = _optimal_param( + threshold, num_perm, false_positive_weight, false_negative_weight + ) if self.b < 2: raise ValueError("The number of bands are too small (b < 2)") - self.prepickle = storage_config['type'] == 'redis' if prepickle is None else prepickle + self.prepickle = ( + storage_config["type"] == "redis" if prepickle is None else prepickle + ) self.hashfunc = hashfunc if hashfunc: @@ -116,84 +183,167 @@ def __init__(self, threshold=0.9, num_perm=128, weights=(0.5, 0.5), else: self._H = self._byteswap - basename = storage_config.get('basename', _random_name(11)) + basename = storage_config.get("basename", _random_name(11)) self.hashtables = [ - unordered_storage(storage_config, name=b''.join([basename, b'_bucket_', struct.pack('>H', i)])) - for i in range(self.b)] - self.hashranges = [(i*self.r, (i+1)*self.r) for i in range(self.b)] - self.keys = ordered_storage(storage_config, name=b''.join([basename, b'_keys'])) + unordered_storage( + storage_config, + name=b"".join([basename, b"_bucket_", struct.pack(">H", i)]), + ) + for i in range(self.b) + ] + self.hashranges = [(i * self.r, (i + 1) * self.r) for i in range(self.b)] + self.keys = ordered_storage(storage_config, name=b"".join([basename, b"_keys"])) @property - def buffer_size(self): + def buffer_size(self) -> int: return self._buffer_size @buffer_size.setter - def buffer_size(self, value): + def buffer_size(self, value: int) -> None: self.keys.buffer_size = value for t in self.hashtables: t.buffer_size = value 
self._buffer_size = value - def insert(self, key, minhash, check_duplication=True): - ''' - Insert a key to the index, together - with a MinHash (or weighted MinHash) of the set referenced by - the key. + def insert( + self, + key: Hashable, + minhash: Union[MinHash, WeightedMinHash], + check_duplication: bool = True, + ): + """ + Insert a key to the index, together with a MinHash or Weighted MinHash + of the set referenced by the key. - :param str key: The identifier of the set. - :param datasketch.MinHash minhash: The MinHash of the set. - :param bool check_duplication: To avoid duplicate keys in the storage (`default=True`). - It's recommended to not change the default, but - if you want to avoid the overhead during insert - you can set `check_duplication = False`. - ''' + Args: + key (Hashable): The unique identifier of the set. + minhash (Union[MinHash, WeightedMinHash]): The MinHash of the set. + check_duplication (bool): To avoid duplicate keys in the storage + (`default=True`). It's recommended to not change the default, but + if you want to avoid the overhead during insert you can set + `check_duplication = False`. + + """ self._insert(key, minhash, check_duplication=check_duplication, buffer=False) - def insertion_session(self, buffer_size=50000): - ''' + def insertion_session(self, buffer_size: int = 50000) -> MinHashLSHInsertionSession: + """ Create a context manager for fast insertion into this index. - :param int buffer_size: The buffer size for insert_session mode (default=50000). + Args: + buffer_size (int): The buffer size for insert_session mode (default=50000). Returns: - datasketch.lsh.MinHashLSHInsertionSession - ''' + MinHashLSHInsertionSession: The context manager. + + Example: + + Insert 100 MinHashes into an Redis-backed index using a session: + + .. 
code-block:: python + + from datasketch import MinHash, MinHashLSH + import numpy as np + + minhashes = [] + for i in range(100): + m = MinHash(num_perm=128) + m.update_batch(np.random.randint(low=0, high=30, size=10)) + minhashes.append(m) + + lsh = MinHashLSH(threshold=0.5, num_perm=128, storage_config={ + 'type': 'redis', + 'redis': {'host': 'localhost', 'port': 6379}, + }) + with lsh.insertion_session() as session: + for i, m in enumerate(minhashes): + session.insert(i, m) + + """ return MinHashLSHInsertionSession(self, buffer_size=buffer_size) - def _insert(self, key, minhash, check_duplication=True, buffer=False): + def _insert( + self, + key: Hashable, + minhash: Union[MinHash, WeightedMinHash], + check_duplication: bool = True, + buffer: bool = False, + ): if len(minhash) != self.h: - raise ValueError("Expecting minhash with length %d, got %d" - % (self.h, len(minhash))) + raise ValueError( + "Expecting minhash with length %d, got %d" % (self.h, len(minhash)) + ) if self.prepickle: key = pickle.dumps(key) if check_duplication and key in self.keys: raise ValueError("The given key already exists") - Hs = [self._H(minhash.hashvalues[start:end]) - for start, end in self.hashranges] + Hs = [self._H(minhash.hashvalues[start:end]) for start, end in self.hashranges] self.keys.insert(key, *Hs, buffer=buffer) for H, hashtable in zip(Hs, self.hashtables): hashtable.insert(H, key, buffer=buffer) - def query(self, minhash): - ''' + def query(self, minhash) -> List[Hashable]: + """ Giving the MinHash of the query set, retrieve the keys that reference sets with Jaccard similarities likely greater than the threshold. Results are based on minhash segment collision and are thus approximate. For more accurate results, - filter again with `minhash.jaccard`. For exact results, + filter again with :meth:`MinHash.jaccard`. For exact results, filter by computing Jaccard similarity using original sets. Args: - minhash (datasketch.MinHash): The MinHash of the query set. 
+ minhash (MinHash): The MinHash of the query set. Returns: - `list` of unique keys. - ''' + list: a list of unique keys. + + Example: + + Query and rank results using :meth:`MinHash.jaccard`. + + .. code-block:: python + + from datasketch import MinHash, MinHashLSH + import numpy as np + + # Generate 100 MinHashes. + minhashes = [] + for i in range(100): + m = MinHash(num_perm=128) + m.update_batch(np.random.randint(low=0, high=30, size=10)) + minhashes.append(m) + + # Create LSH index. + lsh = MinHashLSH(threshold=0.5, num_perm=128) + for i, m in enumerate(minhashes): + lsh.insert(i, m) + + # Get the initial results from LSH. + query = minhashes[0] + results = lsh.query(query) + + # Rank results using Jaccard similarity estimated by MinHash. + results = [(query.jaccard(minhashes[key]), key) for key in results] + results.sort(reverse=True) + print(results) + + Output: + + .. code-block:: + + [(1.0, 0), (0.421875, 4), (0.4140625, 19), (0.359375, 58), (0.3359375, 78), (0.265625, 62), (0.2578125, 11), (0.25, 98), (0.171875, 21)] + + Note that although the threshold is set to 0.5, the results are not + guaranteed to be above 0.5 because the LSH index is approximate and + the Jaccard similarity is estimated by MinHash. + + """ if len(minhash) != self.h: - raise ValueError("Expecting minhash with length %d, got %d" - % (self.h, len(minhash))) + raise ValueError( + "Expecting minhash with length %d, got %d" % (self.h, len(minhash)) + ) candidates = set() for (start, end), hashtable in zip(self.hashranges, self.hashtables): H = self._H(minhash.hashvalues[start:end]) @@ -204,39 +354,40 @@ def query(self, minhash): else: return list(candidates) - def add_to_query_buffer(self, minhash): - ''' + def add_to_query_buffer(self, minhash: Union[MinHash, WeightedMinHash]) -> None: + """ Giving the MinHash of the query set, buffer queries to retrieve the keys that references sets with Jaccard similarities greater than the threshold. 
Buffered queries can be executed using - `collect_query_buffer`. The combination of these + :meth:`collect_query_buffer`. The combination of these functions is way faster if cassandra backend is used with `shared_buffer`. Args: - minhash (datasketch.MinHash): The MinHash of the query set. - ''' + minhash (MinHash): The MinHash of the query set. + """ if len(minhash) != self.h: - raise ValueError("Expecting minhash with length %d, got %d" - % (self.h, len(minhash))) + raise ValueError( + "Expecting minhash with length %d, got %d" % (self.h, len(minhash)) + ) for (start, end), hashtable in zip(self.hashranges, self.hashtables): H = self._H(minhash.hashvalues[start:end]) hashtable.add_to_select_buffer([H]) - def collect_query_buffer(self): - ''' + def collect_query_buffer(self) -> List[Hashable]: + """ Execute and return buffered queries given - by `add_to_query_buffer`. + by :meth:`add_to_query_buffer`. If multiple query MinHash were added to the query buffer, the intersection of the results of all query MinHash will be returned. Returns: - `list` of unique keys. - ''' + list: a list of unique keys. + """ collected_result_sets = [ set(collected_result_lists) for hashtable in self.hashtables @@ -245,29 +396,34 @@ def collect_query_buffer(self): if not collected_result_sets: return [] if self.prepickle: - return [pickle.loads(key) for key in set.intersection(*collected_result_sets)] + return [ + pickle.loads(key) for key in set.intersection(*collected_result_sets) + ] return list(set.intersection(*collected_result_sets)) - def __contains__(self, key): - ''' + def __contains__(self, key: Hashable) -> bool: + """ Args: - key (hashable): The unique identifier of a set. + key (Hashable): The unique identifier of a set. Returns: bool: True only if the key exists in the index. - ''' + """ if self.prepickle: key = pickle.dumps(key) return key in self.keys - def remove(self, key): - ''' + def remove(self, key: Hashable) -> None: + """ Remove the key from the index. 
Args: - key (hashable): The unique identifier of a set. + key (Hashable): The unique identifier of a set. - ''' + Raises: + ValueError: If the key does not exist. + + """ if self.prepickle: key = pickle.dumps(key) if key not in self.keys: @@ -278,11 +434,11 @@ def remove(self, key): hashtable.remove(H) self.keys.remove(key) - def is_empty(self): - ''' + def is_empty(self) -> bool: + """ Returns: - bool: Check if the index is empty. - ''' + bool: `True` only if the index is empty. + """ return any(t.size() == 0 for t in self.hashtables) def _byteswap(self, hs): @@ -293,8 +449,9 @@ def _hashed_byteswap(self, hs): def _query_b(self, minhash, b): if len(minhash) != self.h: - raise ValueError("Expecting minhash with length %d, got %d" - % (self.h, len(minhash))) + raise ValueError( + "Expecting minhash with length %d, got %d" % (self.h, len(minhash)) + ) if b > len(self.hashtables): raise ValueError("b must be less or equal to the number of hash tables") candidates = set() @@ -308,30 +465,34 @@ def _query_b(self, minhash, b): else: return candidates - def get_counts(self): - ''' - Returns a list of length ``self.b`` with elements representing the - number of keys stored under each bucket for the given permutation. - ''' - counts = [ - hashtable.itemcounts() for hashtable in self.hashtables] - return counts - - def get_subset_counts(self, *keys): - ''' - Returns the bucket allocation counts (see :func:`~datasketch.MinHashLSH.get_counts` above) + def get_counts(self) -> List[Dict[Hashable, int]]: + """ + Returns a list of length :attr:`b` (i.e., number of hash tables) with + each element a dictionary mapping hash table bucket key to the number + of indexed keys stored under each bucket. + + Returns: + list: a list of dictionaries. 
+ """ + return [hashtable.itemcounts() for hashtable in self.hashtables] + + def get_subset_counts(self, *keys: Hashable) -> List[Dict[Hashable, int]]: + """ + Returns the bucket allocation counts (see :meth:`get_counts` above) restricted to the list of keys given. Args: - keys (hashable) : the keys for which to get the bucket allocation - counts - ''' + keys (Hashable) : the keys for which to get the bucket allocation + counts. + + Returns: + list: a list of dictionaries. + """ if self.prepickle: key_set = [pickle.dumps(key) for key in set(keys)] else: key_set = list(set(keys)) - hashtables = [unordered_storage({'type': 'dict'}) for _ in - range(self.b)] + hashtables = [unordered_storage({"type": "dict"}) for _ in range(self.b)] Hss = self.keys.getmany(*key_set) for key, Hs in zip(key_set, Hss): for H, hashtable in zip(Hs, hashtables): @@ -340,33 +501,41 @@ def get_subset_counts(self, *keys): class MinHashLSHInsertionSession: - '''Context manager for batch insertion of documents into a MinHashLSH. - ''' + """Context manager for batch insertion of documents into a MinHashLSH. + + Args: + lsh (MinHashLSH): The MinHashLSH to insert into. + buffer_size (int): The buffer size for insert_session mode. 
+ """ - def __init__(self, lsh, buffer_size): + def __init__(self, lsh: MinHashLSH, buffer_size: int): self.lsh = lsh self.lsh.buffer_size = buffer_size - def __enter__(self): + def __enter__(self) -> MinHashLSHInsertionSession: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() - def close(self): + def close(self) -> None: self.lsh.keys.empty_buffer() for hashtable in self.lsh.hashtables: hashtable.empty_buffer() - def insert(self, key, minhash, check_duplication=True): - ''' + def insert( + self, + key: Hashable, + minhash: Union[MinHash, WeightedMinHash], + check_duplication=True, + ) -> None: + """ Insert a unique key to the index, together with a MinHash (or weighted MinHash) of the set referenced by the key. Args: - key (hashable): The unique identifier of the set. - minhash (datasketch.MinHash): The MinHash of the set. - ''' - self.lsh._insert(key, minhash, check_duplication=check_duplication, - buffer=True) + key (Hashable): The unique identifier of the set. + minhash (Union[MinHash, WeightedMinhash]): The MinHash of the set. + """ + self.lsh._insert(key, minhash, check_duplication=check_duplication, buffer=True) From e43b12f7fe2a4dfa3eddb1c4ce7d4c8e6a124574 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 7 Sep 2023 10:08:46 -0700 Subject: [PATCH 2/5] update minhash doc --- datasketch/lsh.py | 11 +- datasketch/minhash.py | 286 ++++++++++++++++++++++++++---------------- docs/conf.py | 101 ++++++++------- 3 files changed, 235 insertions(+), 163 deletions(-) diff --git a/datasketch/lsh.py b/datasketch/lsh.py index f665e27f..7a43d496 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -308,12 +308,11 @@ def query(self, minhash) -> List[Hashable]: from datasketch import MinHash, MinHashLSH import numpy as np - # Generate 100 MinHashes. 
- minhashes = [] - for i in range(100): - m = MinHash(num_perm=128) - m.update_batch(np.random.randint(low=0, high=30, size=10)) - minhashes.append(m) + # Generate 100 random MinHashes. + minhashes = MinHash.bulk( + np.random.randint(low=0, high=30, size=(100, 10)), + num_perm=128 + ) # Create LSH index. lsh = MinHashLSH(threshold=0.5, num_perm=128) diff --git a/datasketch/minhash.py b/datasketch/minhash.py index 3964f306..bc554bce 100644 --- a/datasketch/minhash.py +++ b/datasketch/minhash.py @@ -1,4 +1,6 @@ -import random, copy, struct +from __future__ import annotations +import copy +from typing import Callable, Generator, Iterable, List, Optional, Tuple import warnings import numpy as np @@ -10,36 +12,41 @@ # http://en.wikipedia.org/wiki/Mersenne_prime _mersenne_prime = np.uint64((1 << 61) - 1) _max_hash = np.uint64((1 << 32) - 1) -_hash_range = (1 << 32) +_hash_range = 1 << 32 + class MinHash(object): - '''MinHash is a probabilistic data structure for computing + """MinHash is a probabilistic data structure for computing `Jaccard similarity`_ between sets. Args: - num_perm (int, optional): Number of random permutation functions. + num_perm (Optional[int]): Number of random permutation functions. It will be ignored if `hashvalues` is not None. - seed (int, optional): The random seed controls the set of random + seed (Optional[int]): The random seed controls the set of random permutation functions generated for this MinHash. - hashfunc (optional): The hash function used by this MinHash. - It takes the input passed to the `update` method and + hashfunc (Optional[Callable]): The hash function used by + this MinHash. + It takes the input passed to the :meth:`update` method and returns an integer that can be encoded with 32 bits. The default hash function is based on SHA1 from hashlib_. + Users can use `farmhash` for better performance. + See the example in :meth:`update`. hashobj (**deprecated**): This argument is deprecated since version 1.4.0. 
It is a no-op and has been replaced by `hashfunc`. - hashvalues (`numpy.array` or `list`, optional): The hash values is + hashvalues (Optional[Iterable]): The hash values is the internal state of the MinHash. It can be specified for faster - initialization using the existing state from another MinHash. - permutations (optional): The permutation function parameters. This argument + initialization using the existing :attr:`hashvalues` of another MinHash. + permutations (Optional[Tuple[Iterable, Iterable]]): The permutation + function parameters as a tuple of two lists. This argument can be specified for faster initialization using the existing - state from another MinHash. + :attr:`permutations` from another MinHash. Note: To save memory usage, consider using :class:`datasketch.LeanMinHash`. Note: Since version 1.1.1, MinHash will only support serialization using - `pickle`_. ``serialize`` and ``deserialize`` methods are removed, + pickle_. ``serialize`` and ``deserialize`` methods are removed, and are supported in :class:`datasketch.LeanMinHash` instead. MinHash serialized before version 1.1.1 cannot be deserialized properly in newer versions (`need to migrate? `_). @@ -49,25 +56,33 @@ class MinHash(object): instead of Python's built-in random package. This change makes the hash values consistent across different Python versions. The side-effect is that now MinHash created before version 1.1.3 won't - work (i.e., ``jaccard``, ``merge`` and ``union``) + work (i.e., :meth:`jaccard`, :meth:`merge` and :meth:`union`) with those created after. .. _`Jaccard similarity`: https://en.wikipedia.org/wiki/Jaccard_index .. _hashlib: https://docs.python.org/3.5/library/hashlib.html .. _`pickle`: https://docs.python.org/3/library/pickle.html - ''' - - def __init__(self, num_perm=128, seed=1, - hashfunc=sha1_hash32, - hashobj=None, # Deprecated. 
- hashvalues=None, permutations=None): + """ + + def __init__( + self, + num_perm: int = 128, + seed: int = 1, + hashfunc: Callable = sha1_hash32, + hashobj=None, # Deprecated. + hashvalues: Optional[Iterable] = None, + permutations: Optional[Tuple[Iterable, Iterable]] = None, + ): if hashvalues is not None: num_perm = len(hashvalues) if num_perm > _hash_range: # Because 1) we don't want the size to be too large, and # 2) we are using 4 bytes to store the size value - raise ValueError("Cannot have more than %d number of\ - permutation functions" % _hash_range) + raise ValueError( + "Cannot have more than %d number of\ + permutation functions" + % _hash_range + ) self.seed = seed self.num_perm = num_perm # Check the hash function. @@ -76,8 +91,9 @@ def __init__(self, num_perm=128, seed=1, self.hashfunc = hashfunc # Check for use of hashobj and issue warning. if hashobj is not None: - warnings.warn("hashobj is deprecated, use hashfunc instead.", - DeprecationWarning) + warnings.warn( + "hashobj is deprecated, use hashfunc instead.", DeprecationWarning + ) # Initialize hash values if hashvalues is not None: self.hashvalues = self._parse_hashvalues(hashvalues) @@ -91,23 +107,30 @@ def __init__(self, num_perm=128, seed=1, if len(self) != len(self.permutations[0]): raise ValueError("Numbers of hash values and permutations mismatch") - def _init_hashvalues(self, num_perm): - return np.ones(num_perm, dtype=np.uint64)*_max_hash + def _init_hashvalues(self, num_perm: int) -> np.ndarray: + return np.ones(num_perm, dtype=np.uint64) * _max_hash - def _init_permutations(self, num_perm): + def _init_permutations(self, num_perm: int) -> np.ndarray: # Create parameters for a random bijective permutation function # that maps a 32-bit hash value to another 32-bit hash value. 
# http://en.wikipedia.org/wiki/Universal_hashing gen = np.random.RandomState(self.seed) - return np.array([ - (gen.randint(1, _mersenne_prime, dtype=np.uint64), gen.randint(0, _mersenne_prime, dtype=np.uint64)) for _ in range(num_perm) - ], dtype=np.uint64).T + return np.array( + [ + ( + gen.randint(1, _mersenne_prime, dtype=np.uint64), + gen.randint(0, _mersenne_prime, dtype=np.uint64), + ) + for _ in range(num_perm) + ], + dtype=np.uint64, + ).T def _parse_hashvalues(self, hashvalues): return np.array(hashvalues, dtype=np.uint64) - def update(self, b): - '''Update this MinHash with a new value. + def update(self, b) -> None: + """Update this MinHash with a new value. The value will be hashed using the hash function specified by the `hashfunc` argument in the constructor. @@ -132,19 +155,19 @@ def _hash_32(b): return farmhash.hash32(b) minhash = MinHash(hashfunc=_hash_32) minhash.update("new value") - ''' + """ hv = self.hashfunc(b) a, b = self.permutations phv = np.bitwise_and((a * hv + b) % _mersenne_prime, _max_hash) self.hashvalues = np.minimum(phv, self.hashvalues) - def update_batch(self, b): - '''Update this MinHash with new values. + def update_batch(self, b: Iterable) -> None: + """Update this MinHash with new values. The values will be hashed using the hash function specified by the `hashfunc` argument in the constructor. Args: - b (list): List of values to be hashed using the hash function specified. + b (Iterable): Values to be hashed using the hash function specified. 
Example: To update with new string values (using the default SHA1 hash @@ -154,137 +177,189 @@ def update_batch(self, b): minhash = Minhash() minhash.update_batch([s.encode('utf-8') for s in ["token1", "token2"]]) - ''' - hv = np.array([self.hashfunc(_b) for _b in b], dtype=np.uint64,ndmin=2).T + """ + hv = np.array([self.hashfunc(_b) for _b in b], dtype=np.uint64, ndmin=2).T a, b = self.permutations phv = (hv * a + b) % _mersenne_prime & _max_hash self.hashvalues = np.vstack([phv, self.hashvalues]).min(axis=0) - def jaccard(self, other): - '''Estimate the `Jaccard similarity`_ (resemblance) between the sets + def jaccard(self, other: MinHash) -> float: + """Estimate the `Jaccard similarity`_ (resemblance) between the sets represented by this MinHash and the other. Args: - other (datasketch.MinHash): The other MinHash. + other (MinHash): The other MinHash. Returns: float: The Jaccard similarity, which is between 0.0 and 1.0. - ''' + + Raises: + ValueError: If the two MinHashes have different numbers of + permutation functions or different seeds. + """ if other.seed != self.seed: - raise ValueError("Cannot compute Jaccard given MinHash with\ - different seeds") + raise ValueError( + "Cannot compute Jaccard given MinHash with\ + different seeds" + ) if len(self) != len(other): - raise ValueError("Cannot compute Jaccard given MinHash with\ - different numbers of permutation functions") - return float(np.count_nonzero(self.hashvalues==other.hashvalues)) /\ - float(len(self)) - - def count(self): - '''Estimate the cardinality count based on the technique described in + raise ValueError( + "Cannot compute Jaccard given MinHash with\ + different numbers of permutation functions" + ) + return float(np.count_nonzero(self.hashvalues == other.hashvalues)) / float( + len(self) + ) + + def count(self) -> float: + """Estimate the cardinality count based on the technique described in `this paper `_. Returns: int: The estimated cardinality of the set represented by this MinHash. 
- ''' + """ k = len(self) return float(k) / np.sum(self.hashvalues / float(_max_hash)) - 1.0 - def merge(self, other): - '''Merge the other MinHash with this one, making this one the union + def merge(self, other: MinHash) -> None: + """Merge the other MinHash with this one, making this one the union of both. Args: - other (datasketch.MinHash): The other MinHash. - ''' + other (MinHash): The other MinHash. + + Raises: + ValueError: If the two MinHashes have different numbers of + permutation functions or different seeds. + """ if other.seed != self.seed: - raise ValueError("Cannot merge MinHash with\ - different seeds") + raise ValueError( + "Cannot merge MinHash with\ + different seeds" + ) if len(self) != len(other): - raise ValueError("Cannot merge MinHash with\ - different numbers of permutation functions") + raise ValueError( + "Cannot merge MinHash with\ + different numbers of permutation functions" + ) self.hashvalues = np.minimum(other.hashvalues, self.hashvalues) - def digest(self): - '''Export the hash values, which is the internal state of the + def digest(self) -> np.ndarray: + """Export the hash values, which is the internal state of the MinHash. Returns: - numpy.array: The hash values which is a Numpy array. - ''' + numpy.ndarray: The hash values which is a Numpy array. + """ return copy.copy(self.hashvalues) - def is_empty(self): - ''' + def is_empty(self) -> bool: + """ Returns: bool: If the current MinHash is empty - at the state of just initialized. - ''' + """ if np.any(self.hashvalues != _max_hash): return False return True - def clear(self): - ''' + def clear(self) -> None: + """ Clear the current state of the MinHash. All hash values are reset. - ''' + """ self.hashvalues = self._init_hashvalues(len(self)) - def copy(self): - ''' - :returns: datasketch.MinHash -- A copy of this MinHash by exporting its state. 
- ''' - return MinHash(seed=self.seed, hashfunc=self.hashfunc, - hashvalues=self.digest(), - permutations=self.permutations) - - def __len__(self): - ''' - :returns: int -- The number of hash values. - ''' + def copy(self) -> MinHash: + """ + Returns: + MinHash: a copy of this MinHash by exporting its state. + """ + return MinHash( + seed=self.seed, + hashfunc=self.hashfunc, + hashvalues=self.digest(), + permutations=self.permutations, + ) + + def __len__(self) -> int: + """ + Returns: + int: The number of hash values. + """ return len(self.hashvalues) - def __eq__(self, other): - ''' - :returns: bool -- If their seeds and hash values are both equal then two are equivalent. - ''' - return type(self) is type(other) and \ - self.seed == other.seed and \ - np.array_equal(self.hashvalues, other.hashvalues) + def __eq__(self, other: MinHash) -> bool: + """ + Returns: + bool: If their seeds and hash values are both equal then two are equivalent. + """ + return ( + type(self) is type(other) + and self.seed == other.seed + and np.array_equal(self.hashvalues, other.hashvalues) + ) @classmethod - def union(cls, *mhs): - '''Create a MinHash which is the union of the MinHash objects passed as arguments. + def union(cls, *mhs: MinHash) -> MinHash: + """Create a MinHash which is the union of the MinHash objects passed as arguments. Args: - *mhs: The MinHash objects to be united. The argument list length is variable, + *mhs (MinHash): The MinHash objects to be united. The argument list length is variable, but must be at least 2. Returns: - datasketch.MinHash: A new union MinHash. - ''' + MinHash: a new union MinHash. + + Raises: + ValueError: If the number of MinHash objects passed as arguments is less than 2, + or if the MinHash objects passed as arguments have different seeds or + different numbers of permutation functions. + + Example: + + .. 
code-block:: python + + from datasketch import MinHash + import numpy as np + + m1 = MinHash(num_perm=128) + m1.update_batch(np.random.randint(low=0, high=30, size=10)) + + m2 = MinHash(num_perm=128) + m2.update_batch(np.random.randint(low=0, high=30, size=10)) + + # Union m1 and m2. + m = MinHash.union(m1, m2) + """ if len(mhs) < 2: raise ValueError("Cannot union less than 2 MinHash") num_perm = len(mhs[0]) seed = mhs[0].seed if any((seed != m.seed or num_perm != len(m)) for m in mhs): - raise ValueError("The unioning MinHash must have the\ - same seed and number of permutation functions") + raise ValueError( + "The unioning MinHash must have the\ + same seed and number of permutation functions" + ) hashvalues = np.minimum.reduce([m.hashvalues for m in mhs]) permutations = mhs[0].permutations - return cls(num_perm=num_perm, seed=seed, hashvalues=hashvalues, - permutations=permutations) + return cls( + num_perm=num_perm, + seed=seed, + hashvalues=hashvalues, + permutations=permutations, + ) @classmethod - def bulk(cls, b, **minhash_kwargs): - '''Compute MinHashes in bulk. This method avoids unnecessary + def bulk(cls, b: Iterable, **minhash_kwargs) -> List[MinHash]: + """Compute MinHashes in bulk. This method avoids unnecessary overhead when initializing many minhashes by reusing the initialized state. Args: b (Iterable): An Iterable of lists of bytes, each list is hashed in to one MinHash in the output. - minhash_kwargs: Keyword arguments used to initialize MinHash, + **minhash_kwargs: Keyword arguments used to initialize MinHash, will be used for all minhashes. Returns: @@ -299,12 +374,12 @@ def bulk(cls, b, **minhash_kwargs): [b'token4', b'token5', b'token6']] minhashes = MinHash.bulk(data, num_perm=64) - ''' + """ return list(cls.generator(b, **minhash_kwargs)) @classmethod - def generator(cls, b, **minhash_kwargs): - '''Compute MinHashes in a generator. 
This method avoids unnecessary + def generator(cls, b: Iterable, **minhash_kwargs) -> Generator[MinHash]: + """Compute MinHashes in a generator. This method avoids unnecessary overhead when initializing many minhashes by reusing the initialized state. @@ -315,7 +390,7 @@ def generator(cls, b, **minhash_kwargs): will be used for all minhashes. Returns: - A generator of computed MinHashes. + Generator[MinHash]: a generator of computed MinHashes. Example: @@ -328,10 +403,9 @@ def generator(cls, b, **minhash_kwargs): # do something useful minhash - ''' + """ m = cls(**minhash_kwargs) for _b in b: _m = m.copy() _m.update_batch(_b) yield _m - diff --git a/docs/conf.py b/docs/conf.py index 345f2679..a6b8c280 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,41 +28,42 @@ # needs_sphinx = '1.0' from datetime import datetime + year = datetime.now().year # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.napoleon', - 'sphinx.ext.doctest', - 'sphinx.ext.coverage', - 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', - 'sphinx.ext.githubpages', + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.doctest", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx.ext.githubpages", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. # # source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. 
-project = 'datasketch' -copyright = '%d, Eric Zhu' % year -author = 'ekzhu' +project = "datasketch" +copyright = "%d, Eric Zhu" % year +author = "ekzhu" import datasketch @@ -81,7 +82,7 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = "en" # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: @@ -95,7 +96,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -117,7 +118,7 @@ # show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. # modindex_common_prefix = [] @@ -134,7 +135,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "sphinx_rtd_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -146,8 +147,7 @@ # 'github_repo' : 'datasketch', # 'github_type': 'star', # 'fixed_sidebar' : True, - 'analytics_id': "UA-93507731-1", - + "analytics_id": "UA-93507731-1", # 'analytics_anonymize_ip': False, # 'logo_only': False, # 'display_version': True, @@ -189,7 +189,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +html_static_path = ["_static"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -211,11 +211,11 @@ # Custom sidebar templates, maps document names to template names. # html_sidebars = { - '**': [ - 'about.html', - 'navigation.html', - 'relations.html', - 'searchbox.html', + "**": [ + "about.html", + "navigation.html", + "relations.html", + "searchbox.html", ] } @@ -276,34 +276,30 @@ # html_search_scorer = 'scorer.js' # Output file base name for HTML help builder. -htmlhelp_basename = 'datasketchdoc' +htmlhelp_basename = "datasketchdoc" # -- Options for LaTeX output --------------------------------------------- latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'datasketch.tex', 'datasketch Documentation', - 'ekzhu', 'manual'), + (master_doc, "datasketch.tex", "datasketch Documentation", "ekzhu", "manual"), ] # The name of an image file (relative to this directory) to place at the top of @@ -343,10 +339,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). 
-man_pages = [
-    (master_doc, 'datasketch', 'datasketch Documentation',
-     [author], 1)
-]
+man_pages = [(master_doc, "datasketch", "datasketch Documentation", [author], 1)]
 
 # If true, show URL addresses after external links.
 #
@@ -359,9 +352,15 @@
 # (source start file, target name, title, author,
 #  dir menu entry, description, category)
 texinfo_documents = [
-    (master_doc, 'datasketch', 'datasketch Documentation',
-     author, 'datasketch', 'One line description of project.',
-     'Miscellaneous'),
+    (
+        master_doc,
+        "datasketch",
+        "datasketch Documentation",
+        author,
+        "datasketch",
+        "One line description of project.",
+        "Miscellaneous",
+    ),
 ]
 
 # Documents to append as an appendix to all manuals.
@@ -382,4 +381,4 @@
 
 # -- Additional stuff
-autodoc_member_order = 'bysource'
+autodoc_member_order = "bysource"

From 419cadc09d7fc637b391d1cab73905f5804bce4d Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Thu, 7 Sep 2023 10:16:54 -0700
Subject: [PATCH 3/5] fix bug

---
 datasketch/lsh.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datasketch/lsh.py b/datasketch/lsh.py
index 7a43d496..bb461de1 100644
--- a/datasketch/lsh.py
+++ b/datasketch/lsh.py
@@ -174,7 +174,7 @@ def __init__(
             raise ValueError("The number of bands are too small (b < 2)")
 
         self.prepickle = (
-            storage_config["type"] == "redis" if prepickle is None else prepickle
+            storage_config["type"] == "redis" if prepickle is None else prepickle
         )
 
         self.hashfunc = hashfunc

From 22ba154dae25a52d0fa60b0cebfa74ef2c27e428 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Thu, 7 Sep 2023 10:19:00 -0700
Subject: [PATCH 4/5] fix bug

---
 datasketch/lsh.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/datasketch/lsh.py b/datasketch/lsh.py
index bb461de1..7edd3f03 100644
--- a/datasketch/lsh.py
+++ b/datasketch/lsh.py
@@ -68,8 +68,8 @@ class MinHashLSH(object):
             `basename` is an optional property whose value will be used as the prefix to stored keys. 
If this is not set, a random string will be generated instead. If you set this, you will be responsible for ensuring there are no key collisions. - prepickle (bool): If True, all keys are pickled to bytes before - insertion. If False, a default value is chosen based on the + prepickle (Optional[bool]): If True, all keys are pickled to bytes before + insertion. If not specified, a default value is chosen based on the `storage_config`. hashfunc (Optional[Callable[[bytes], bytes]]): If a hash function is provided it will be used to compress the index keys to reduce the memory footprint. This could cause a higher @@ -141,7 +141,7 @@ def __init__( weights: Tuple[float, float] = (0.5, 0.5), params: Optional[Tuple[int, int]] = None, storage_config: Optional[Dict] = None, - prepickle: bool = False, + prepickle: Optional[bool] = None, hashfunc: Optional[Callable[[bytes], bytes]] = None, ) -> None: storage_config = {"type": "dict"} if not storage_config else storage_config From a0d267b48f72673432e482bb2f663cce3e770cbb Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 7 Sep 2023 11:32:14 -0700 Subject: [PATCH 5/5] update doc --- datasketch/hyperloglog.py | 231 +++++++++++++++++++-------------- datasketch/lean_minhash.py | 98 +++++++------- datasketch/lsh.py | 6 +- datasketch/lshensemble.py | 171 ++++++++++++++---------- datasketch/lshforest.py | 90 +++++++------ datasketch/minhash.py | 14 +- datasketch/weighted_minhash.py | 137 ++++++++++--------- 7 files changed, 425 insertions(+), 322 deletions(-) diff --git a/datasketch/hyperloglog.py b/datasketch/hyperloglog.py index cc7fbd64..927b1c11 100644 --- a/datasketch/hyperloglog.py +++ b/datasketch/hyperloglog.py @@ -1,6 +1,9 @@ +from __future__ import annotations import struct, copy +from typing import Callable, Optional import numpy as np import warnings + try: from .hyperloglog_const import _thresholds, _raw_estimate, _bias except ImportError: @@ -10,14 +13,14 @@ from datasketch.hashfunc import sha1_hash32, sha1_hash64 # Get 
the number of bits starting from the first non-zero bit to the right -_bit_length = lambda bits : bits.bit_length() +_bit_length = lambda bits: bits.bit_length() # For < Python 2.7 -if not hasattr(int, 'bit_length'): - _bit_length = lambda bits : len(bin(bits)) - 2 if bits > 0 else 0 +if not hasattr(int, "bit_length"): + _bit_length = lambda bits: len(bin(bits)) - 2 if bits > 0 else 0 class HyperLogLog(object): - ''' + """ The HyperLogLog data sketch for estimating cardinality of very large dataset in a single pass. The original HyperLogLog is described `here @@ -27,20 +30,20 @@ class HyperLogLog(object): https://github.com/svpcom/hyperloglog Args: - p (int, optional): The precision parameter. It is ignored if + p (int): The precision parameter. It is ignored if the `reg` is given. - reg (numpy.array, optional): The internal state. + reg (Optional[numpy.ndarray]): The internal state. This argument is for initializing the HyperLogLog from an existing one. - hashfunc (optional): The hash function used by this MinHash. + hashfunc (Callable): The hash function used by this MinHash. It takes the input passed to the `update` method and returns an integer that can be encoded with 32 bits. The default hash function is based on SHA1 from hashlib_. hashobj (**deprecated**): This argument is deprecated since version 1.4.0. It is a no-op and has been replaced by `hashfunc`. 
- ''' + """ - __slots__ = ('p', 'm', 'reg', 'alpha', 'max_rank', 'hashfunc') + __slots__ = ("p", "m", "reg", "alpha", "max_rank", "hashfunc") # The range of the hash values used for HyperLogLog _hash_range_bit = 32 @@ -57,7 +60,13 @@ def _get_alpha(self, p): return 0.709 return 0.7213 / (1.0 + 1.079 / (1 << p)) - def __init__(self, p=8, reg=None, hashfunc=sha1_hash32, hashobj=None): + def __init__( + self, + p: int = 8, + reg: Optional[np.ndarray] = None, + hashfunc: Callable = sha1_hash32, + hashobj: Optional[object] = None, # Deprecated + ): if reg is None: self.p = p self.m = 1 << p @@ -70,8 +79,10 @@ def __init__(self, p=8, reg=None, hashfunc=sha1_hash32, hashobj=None): self.m = reg.size self.p = _bit_length(self.m) - 1 if 1 << self.p != self.m: - raise ValueError("The imported register has \ - incorrect size. Expect a power of 2.") + raise ValueError( + "The imported register has \ + incorrect size. Expect a power of 2." + ) # Generally we trust the user to import register that contains # reasonable counter values, so we don't check for every values. self.reg = reg @@ -80,15 +91,16 @@ def __init__(self, p=8, reg=None, hashfunc=sha1_hash32, hashobj=None): raise ValueError("The hashfunc must be a callable.") # Check for use of hashobj and issue warning. if hashobj is not None: - warnings.warn("hashobj is deprecated, use hashfunc instead.", - DeprecationWarning) + warnings.warn( + "hashobj is deprecated, use hashfunc instead.", DeprecationWarning + ) self.hashfunc = hashfunc # Common settings self.alpha = self._get_alpha(self.p) self.max_rank = self._hash_range_bit - self.p - def update(self, b): - ''' + def update(self, b) -> None: + """ Update the HyperLogLog with a new data value in bytes. The value will be hashed using the hash function specified by the `hashfunc` argument in the constructor. 
@@ -114,7 +126,7 @@ def _hash_32(b): return farmhash.hash32(b) hll = HyperLogLog(hashfunc=_hash_32) hll.update("new value") - ''' + """ # Digest the hash object to get the hash value hv = self.hashfunc(b) # Get the index of the register using the first p bits of the hash @@ -124,20 +136,24 @@ def _hash_32(b): # Update the register self.reg[reg_index] = max(self.reg[reg_index], self._get_rank(bits)) - def count(self): - ''' + def count(self) -> float: + """ Estimate the cardinality of the data values seen so far. Returns: - int: The estimated cardinality. - ''' + float: The estimated cardinality. + """ # Use HyperLogLog estimation function - e = self.alpha * float(self.m ** 2) / np.sum(2.0**(-self.reg)) + e = self.alpha * float(self.m**2) / np.sum(2.0 ** (-self.reg)) # Small range correction small_range_threshold = (5.0 / 2.0) * self.m - if abs(e-small_range_threshold)/small_range_threshold < 0.15: - warnings.warn(("Warning: estimate is close to error correction threshold. " - +"Output may not satisfy HyperLogLog accuracy guarantee.")) + if abs(e - small_range_threshold) / small_range_threshold < 0.15: + warnings.warn( + ( + "Warning: estimate is close to error correction threshold. " + + "Output may not satisfy HyperLogLog accuracy guarantee." + ) + ) if e <= small_range_threshold: num_zero = self.m - np.count_nonzero(self.reg) return self._linearcounting(num_zero) @@ -147,131 +163,147 @@ def count(self): # Large range correction return self._largerange_correction(e) - def merge(self, other): - ''' + def merge(self, other: HyperLogLog) -> None: + """ Merge the other HyperLogLog with this one, making this the union of the two. Args: - other (datasketch.HyperLogLog): - ''' + other (HyperLogLog): The other HyperLogLog to be merged. + """ if self.m != other.m or self.p != other.p: - raise ValueError("Cannot merge HyperLogLog with different\ - precisions.") + raise ValueError( + "Cannot merge HyperLogLog with different\ + precisions." 
+ ) self.reg = np.maximum(self.reg, other.reg) - def digest(self): - ''' + def digest(self) -> np.ndarray: + """ Returns: numpy.array: The current internal state. - ''' + """ return copy.copy(self.reg) - def copy(self): - ''' + def copy(self) -> HyperLogLog: + """ Create a copy of the current HyperLogLog by exporting its state. Returns: - datasketch.HyperLogLog: - ''' + HyperLogLog: A copy of the current HyperLogLog. + """ return self.__class__(reg=self.digest(), hashfunc=self.hashfunc) - def is_empty(self): - ''' + def is_empty(self) -> bool: + """ Returns: bool: True if the current HyperLogLog is empty - at the state of just initialized. - ''' + """ if np.any(self.reg): return False return True - def clear(self): - ''' + def clear(self) -> None: + """ Reset the current HyperLogLog to empty. - ''' + """ self.reg = np.zeros((self.m,), dtype=np.int8) - def __len__(self): - ''' + def __len__(self) -> int: + """ Returns: int: Get the size of the HyperLogLog as the size of `reg`. - ''' + """ return len(self.reg) - def __eq__(self, other): - ''' + def __eq__(self, other: HyperLogLog) -> bool: + """ Check equivalence between two HyperLogLogs Args: - other (datasketch.HyperLogLog): + other (HyperLogLog): Returns: bool: True if both have the same internal state. 
- ''' - return type(self) is type(other) and \ - self.p == other.p and \ - self.m == other.m and \ - np.array_equal(self.reg, other.reg) + """ + return ( + type(self) is type(other) + and self.p == other.p + and self.m == other.m + and np.array_equal(self.reg, other.reg) + ) def _get_rank(self, bits): rank = self.max_rank - _bit_length(bits) + 1 if rank <= 0: - raise ValueError("Hash value overflow, maximum size is %d\ - bits" % self.max_rank) + raise ValueError( + "Hash value overflow, maximum size is %d\ + bits" + % self.max_rank + ) return rank def _linearcounting(self, num_zero): return self.m * np.log(self.m / float(num_zero)) def _largerange_correction(self, e): - return - (1 << 32) * np.log(1.0 - e / (1 << 32)) + return -(1 << 32) * np.log(1.0 - e / (1 << 32)) @classmethod - def union(cls, *hyperloglogs): + def union(cls, *hyperloglogs: HyperLogLog) -> HyperLogLog: if len(hyperloglogs) < 2: - raise ValueError("Cannot union less than 2 HyperLogLog\ - sketches") + raise ValueError( + "Cannot union less than 2 HyperLogLog\ + sketches" + ) m = hyperloglogs[0].m if not all(h.m == m for h in hyperloglogs): - raise ValueError("Cannot union HyperLogLog sketches with\ - different precisions") + raise ValueError( + "Cannot union HyperLogLog sketches with\ + different precisions" + ) reg = np.maximum.reduce([h.reg for h in hyperloglogs]) h = cls(reg=reg) return h - def bytesize(self): + def bytesize(self) -> int: + """Get the size of the HyperLogLog in bytes.""" # Since p is no larger than 64, use 8 bits - p_size = struct.calcsize('B') + p_size = struct.calcsize("B") # Each register value is no larger than 64, use 8 bits # TODO: is there a way to use 5 bits instead of 8 bits # to store integer in Python? 
- reg_val_size = struct.calcsize('B') + reg_val_size = struct.calcsize("B") return p_size + reg_val_size * self.m def serialize(self, buf): if len(buf) < self.bytesize(): - raise ValueError("The buffer does not have enough space\ - for holding this HyperLogLog.") - fmt = 'B%dB' % self.m + raise ValueError( + "The buffer does not have enough space\ + for holding this HyperLogLog." + ) + fmt = "B%dB" % self.m struct.pack_into(fmt, buf, 0, self.p, *self.reg) @classmethod def deserialize(cls, buf): - size = struct.calcsize('B') + size = struct.calcsize("B") try: - p = struct.unpack_from('B', buf, 0)[0] + p = struct.unpack_from("B", buf, 0)[0] except TypeError: - p = struct.unpack_from('B', buffer(buf), 0)[0] + p = struct.unpack_from("B", buffer(buf), 0)[0] h = cls(p) offset = size try: - h.reg = np.array(struct.unpack_from('%dB' % h.m, - buf, offset), dtype=np.int8) + h.reg = np.array( + struct.unpack_from("%dB" % h.m, buf, offset), dtype=np.int8 + ) except TypeError: - h.reg = np.array(struct.unpack_from('%dB' % h.m, - buffer(buf), offset), dtype=np.int8) + h.reg = np.array( + struct.unpack_from("%dB" % h.m, buffer(buf), offset), dtype=np.int8 + ) return h def __getstate__(self): @@ -280,23 +312,25 @@ def __getstate__(self): return buf def __setstate__(self, buf): - size = struct.calcsize('B') + size = struct.calcsize("B") try: - p = struct.unpack_from('B', buf, 0)[0] + p = struct.unpack_from("B", buf, 0)[0] except TypeError: - p = struct.unpack_from('B', buffer(buf), 0)[0] + p = struct.unpack_from("B", buffer(buf), 0)[0] self.__init__(p=p) offset = size try: - self.reg = np.array(struct.unpack_from('%dB' % self.m, - buf, offset), dtype=np.int8) + self.reg = np.array( + struct.unpack_from("%dB" % self.m, buf, offset), dtype=np.int8 + ) except TypeError: - self.reg = np.array(struct.unpack_from('%dB' % self.m, - buffer(buf), offset), dtype=np.int8) + self.reg = np.array( + struct.unpack_from("%dB" % self.m, buffer(buf), offset), dtype=np.int8 + ) class 
HyperLogLogPlusPlus(HyperLogLog): - ''' + """ HyperLogLog++ is an enhanced HyperLogLog `from Google `_. Main changes from the original HyperLogLog: @@ -306,26 +340,32 @@ class HyperLogLogPlusPlus(HyperLogLog): 3. Sparse representation (not implemented here) Args: - p (int, optional): The precision parameter. It is ignored if + p (int): The precision parameter. It is ignored if the `reg` is given. - reg (numpy.array, optional): The internal state. + reg (Optional[numpy.array]): The internal state. This argument is for initializing the HyperLogLog from an existing one. - hashfunc (optional): The hash function used by this MinHash. + hashfunc (Callable): The hash function used by this MinHash. It takes the input passed to the `update` method and returns an integer that can be encoded with 64 bits. The default hash function is based on SHA1 from hashlib_. hashobj (**deprecated**): This argument is deprecated since version 1.4.0. It is a no-op and has been replaced by `hashfunc`. - ''' + """ _hash_range_bit = 64 _hash_range_byte = 8 - def __init__(self, p=8, reg=None, hashfunc=sha1_hash64, - hashobj=None): - super(HyperLogLogPlusPlus, self).__init__(p=p, reg=reg, - hashfunc=hashfunc, hashobj=hashobj) + def __init__( + self, + p: int = 8, + reg: Optional[np.ndarray] = None, + hashfunc: Callable = sha1_hash64, + hashobj: Optional[object] = None, + ): + super(HyperLogLogPlusPlus, self).__init__( + p=p, reg=reg, hashfunc=hashfunc, hashobj=hashobj + ) def _get_threshold(self, p): return _thresholds[p - 4] @@ -333,10 +373,11 @@ def _get_threshold(self, p): def _estimate_bias(self, e, p): bias_vector = _bias[p - 4] estimate_vector = _raw_estimate[p - 4] - nearest_neighbors = np.argsort((e - estimate_vector)**2)[:6] + nearest_neighbors = np.argsort((e - estimate_vector) ** 2)[:6] return np.mean(bias_vector[nearest_neighbors]) - def count(self): + def count(self) -> float: + """Estimate the cardinality of the data values seen so far.""" num_zero = self.m - 
np.count_nonzero(self.reg) if num_zero > 0: # linear counting @@ -344,7 +385,7 @@ def count(self): if lc <= self._get_threshold(self.p): return lc # Use HyperLogLog estimation function - e = self.alpha * float(self.m ** 2) / np.sum(2.0**(-self.reg)) + e = self.alpha * float(self.m**2) / np.sum(2.0 ** (-self.reg)) if e <= 5 * self.m: return e - self._estimate_bias(e, self.p) else: diff --git a/datasketch/lean_minhash.py b/datasketch/lean_minhash.py index e205e47d..62568dad 100644 --- a/datasketch/lean_minhash.py +++ b/datasketch/lean_minhash.py @@ -1,10 +1,13 @@ +from __future__ import annotations import struct +from typing import Iterable import numpy as np from datasketch import MinHash + class LeanMinHash(MinHash): - '''Lean MinHash is MinHash with a smaller memory footprint + """Lean MinHash is MinHash with a smaller memory footprint and faster deserialization, but with its internal state frozen -- no `update()`. @@ -25,7 +28,7 @@ class LeanMinHash(MinHash): # Or between a lean MinHash and a MinHash lean_minhash.jaccard(minhash2) - + To create a lean MinHash from the hash values and seed of an existing MinHash: @@ -51,52 +54,56 @@ class LeanMinHash(MinHash): :class:`datasketch.MinHashLSHForest`, and :class:`datasketch.MinHashLSHEnsemble`. Args: - minhash (optional): The :class:`datasketch.MinHash` object used to + minhash (optional): The :class:`datasketch.MinHash` object used to initialize the LeanMinHash. If this is not set, then `seed` - and `hashvalues` must be set. - seed (optional): The random seed that controls the set of random + and `hashvalues` must be set. + seed (optional): The random seed that controls the set of random permutation functions generated for this LeanMinHash. This parameter must be used together with `hashvalues`. hashvalues (optional): The hash values used to inititialize the state of the LeanMinHash. This parameter must be used together with `seed`. 
- ''' + """ - __slots__ = ('seed', 'hashvalues') + __slots__ = ("seed", "hashvalues") def _initialize_slots(self, seed, hashvalues): - '''Initialize the slots of the LeanMinHash. + """Initialize the slots of the LeanMinHash. Args: seed (int): The random seed controls the set of random permutation functions generated for this LeanMinHash. - hashvalues: The hash values is the internal state of the LeanMinHash. - ''' + hashvalues (Iterable): The hash values is the internal state of the LeanMinHash. + """ self.seed = seed self.hashvalues = self._parse_hashvalues(hashvalues) - def __init__(self, minhash=None, seed=None, hashvalues=None): + def __init__( + self, minhash: MinHash = None, seed: int = None, hashvalues: Iterable = None + ): if minhash is not None: self._initialize_slots(minhash.seed, minhash.hashvalues) elif hashvalues is not None and seed is not None: self._initialize_slots(seed, hashvalues) else: - raise ValueError("Init parameters cannot be None: make sure " - "to set either minhash or both of hash values and seed") + raise ValueError( + "Init parameters cannot be None: make sure " + "to set either minhash or both of hash values and seed" + ) - def update(self, b): - '''This method is not available on a LeanMinHash. + def update(self, b) -> None: + """This method is not available on a LeanMinHash. Calling it raises a TypeError. - ''' + """ raise TypeError("Cannot update a LeanMinHash") - def copy(self): + def copy(self) -> LeanMinHash: lmh = object.__new__(LeanMinHash) lmh._initialize_slots(*self.__slots__) return lmh - def bytesize(self, byteorder='@'): - '''Compute the byte size after serialization. + def bytesize(self, byteorder="@") -> int: + """Compute the byte size after serialization. Args: byteorder (str, optional): This is byte order of the serialized data. Use one @@ -107,17 +114,17 @@ def bytesize(self, byteorder='@'): Returns: int: Size in number of bytes after serialization. 
- ''' + """ # Use 8 bytes to store the seed integer - seed_size = struct.calcsize(byteorder+'q') + seed_size = struct.calcsize(byteorder + "q") # Use 4 bytes to store the number of hash values - length_size = struct.calcsize(byteorder+'i') + length_size = struct.calcsize(byteorder + "i") # Use 4 bytes to store each hash value as we are using the lower 32 bit - hashvalue_size = struct.calcsize(byteorder+'I') + hashvalue_size = struct.calcsize(byteorder + "I") return seed_size + length_size + len(self) * hashvalue_size - def serialize(self, buf, byteorder='@'): - ''' + def serialize(self, buf, byteorder="@") -> None: + """ Serialize this lean MinHash and store the result in an allocated buffer. Args: @@ -158,17 +165,18 @@ def serialize(self, buf, byteorder='@'): .. _`buffer`: https://docs.python.org/3/c-api/buffer.html .. _`bytearray`: https://docs.python.org/3.6/library/functions.html#bytearray .. _`byteorder`: https://docs.python.org/3/library/struct.html - ''' + """ if len(buf) < self.bytesize(): - raise ValueError("The buffer does not have enough space\ - for holding this MinHash.") + raise ValueError( + "The buffer does not have enough space\ + for holding this MinHash." + ) fmt = "%sqi%dI" % (byteorder, len(self)) - struct.pack_into(fmt, buf, 0, - self.seed, len(self), *self.hashvalues) + struct.pack_into(fmt, buf, 0, self.seed, len(self), *self.hashvalues) @classmethod - def deserialize(cls, buf, byteorder='@'): - ''' + def deserialize(cls, buf, byteorder="@") -> LeanMinHash: + """ Deserialize a lean MinHash from a buffer. Args: @@ -189,7 +197,7 @@ def deserialize(cls, buf, byteorder='@'): .. 
code-block:: python lean_minhash = LeanMinHash.deserialize(buf) - ''' + """ fmt_seed_size = "%sqi" % byteorder fmt_hash = byteorder + "%dI" try: @@ -208,34 +216,36 @@ def deserialize(cls, buf, byteorder='@'): def __getstate__(self): buf = bytearray(self.bytesize()) fmt = "qi%dI" % len(self) - struct.pack_into(fmt, buf, 0, - self.seed, len(self), *self.hashvalues) + struct.pack_into(fmt, buf, 0, self.seed, len(self), *self.hashvalues) return buf def __setstate__(self, buf): try: - seed, num_perm = struct.unpack_from('qi', buf, 0) + seed, num_perm = struct.unpack_from("qi", buf, 0) except TypeError: - seed, num_perm = struct.unpack_from('qi', buffer(buf), 0) - offset = struct.calcsize('qi') + seed, num_perm = struct.unpack_from("qi", buffer(buf), 0) + offset = struct.calcsize("qi") try: - hashvalues = struct.unpack_from('%dI' % num_perm, buf, offset) + hashvalues = struct.unpack_from("%dI" % num_perm, buf, offset) except TypeError: - hashvalues = struct.unpack_from('%dI' % num_perm, buffer(buf), offset) + hashvalues = struct.unpack_from("%dI" % num_perm, buffer(buf), offset) self._initialize_slots(seed, hashvalues) - def __hash__(self): + def __hash__(self) -> int: return hash((self.seed, tuple(self.hashvalues))) @classmethod - def union(cls, *lmhs): + def union(cls, *lmhs: LeanMinHash) -> LeanMinHash: + """Create a new lean MinHash by unioning multiple lean MinHash.""" if len(lmhs) < 2: raise ValueError("Cannot union less than 2 MinHash") num_perm = len(lmhs[0]) seed = lmhs[0].seed if any((seed != m.seed or num_perm != len(m)) for m in lmhs): - raise ValueError("The unioning MinHash must have the\ - same seed, number of permutation functions.") + raise ValueError( + "The unioning MinHash must have the\ + same seed, number of permutation functions." 
+ ) hashvalues = np.minimum.reduce([m.hashvalues for m in lmhs]) lmh = object.__new__(LeanMinHash) diff --git a/datasketch/lsh.py b/datasketch/lsh.py index 7edd3f03..f77e36e3 100644 --- a/datasketch/lsh.py +++ b/datasketch/lsh.py @@ -51,15 +51,15 @@ class MinHashLSH(object): threshold (float): The Jaccard similarity threshold between 0.0 and 1.0. The initialized MinHash LSH will be optimized for the threshold by minizing the false positive and false negative. - num_perm (Optional[int]): The number of permutation functions used + num_perm (int): The number of permutation functions used by the MinHash to be indexed. For weighted MinHash, this is the sample size (`sample_size`). - weights (Optional[Tuple[float, float]]): Used to adjust the relative importance of + weights (Tuple[float, float]): Used to adjust the relative importance of minimizing false positive and false negative when optimizing for the Jaccard similarity threshold. `weights` is a tuple in the format of :code:`(false_positive_weight, false_negative_weight)`. - params (Tuple[int, int]): The LSH parameters (i.e., number of bands and size + params (Optional[Tuple[int, int]]): The LSH parameters (i.e., number of bands and size of each bands). This is used to bypass the parameter optimization step in the constructor. `threshold` and `weights` will be ignored if this is given. 
diff --git a/datasketch/lshensemble.py b/datasketch/lshensemble.py index 698c0b89..3a3fcaf4 100644 --- a/datasketch/lshensemble.py +++ b/datasketch/lshensemble.py @@ -1,18 +1,20 @@ from collections import deque, Counter import struct +from typing import Dict, Generator, Hashable, Iterable, Optional, Tuple import numpy as np +from datasketch.minhash import MinHash from datasketch.storage import _random_name from datasketch.lsh import integrate, MinHashLSH from datasketch.lshensemble_partition import optimal_partitions def _false_positive_probability(threshold, b, r, xq): - ''' + """ Compute the false positive probability given the containment threshold. xq is the ratio of x/q. - ''' - _probability = lambda t : 1 - (1 - (t/(1 + xq - t))**float(r))**float(b) + """ + _probability = lambda t: 1 - (1 - (t / (1 + xq - t)) ** float(r)) ** float(b) if xq >= threshold: a, err = integrate(_probability, 0.0, threshold) return a @@ -21,10 +23,10 @@ def _false_positive_probability(threshold, b, r, xq): def _false_negative_probability(threshold, b, r, xq): - ''' + """ Compute the false negative probability given the containment threshold - ''' - _probability = lambda t : 1 - (1 - (1 - (t/(1 + xq - t))**float(r))**float(b)) + """ + _probability = lambda t: 1 - (1 - (1 - (t / (1 + xq - t)) ** float(r)) ** float(b)) if xq >= 1.0: a, err = integrate(_probability, threshold, 1.0) return a @@ -34,22 +36,23 @@ def _false_negative_probability(threshold, b, r, xq): return 0.0 -def _optimal_param(threshold, num_perm, max_r, xq, false_positive_weight, - false_negative_weight): - ''' +def _optimal_param( + threshold, num_perm, max_r, xq, false_positive_weight, false_negative_weight +): + """ Compute the optimal parameters that minimizes the weighted sum of probabilities of false positive and false negative. xq is the ratio of x/q. 
- ''' + """ min_error = float("inf") opt = (0, 0) - for b in range(1, num_perm+1): - for r in range(1, max_r+1): - if b*r > num_perm: + for b in range(1, num_perm + 1): + for r in range(1, max_r + 1): + if b * r > num_perm: continue fp = _false_positive_probability(threshold, b, r, xq) fn = _false_negative_probability(threshold, b, r, xq) - error = fp*false_positive_weight + fn*false_negative_weight + error = fp * false_positive_weight + fn * false_negative_weight if error < min_error: min_error = error opt = (b, r) @@ -57,7 +60,7 @@ def _optimal_param(threshold, num_perm, max_r, xq, false_positive_weight, class MinHashLSHEnsemble(object): - ''' + """ The :ref:`minhash_lsh_ensemble` index. It supports :ref:`containment` queries. The implementation is based on @@ -67,23 +70,23 @@ class MinHashLSHEnsemble(object): threshold (float): The Containment threshold between 0.0 and 1.0. The initialized LSH Ensemble will be optimized for the threshold by minizing the false positive and false negative. - num_perm (int, optional): The number of permutation functions used + num_perm (int): The number of permutation functions used by the MinHash to be indexed. For weighted MinHash, this is the sample size (`sample_size`). - num_part (int, optional): The number of partitions in LSH Ensemble. - m (int, optional): The memory usage factor: an LSH Ensemble uses approximately + num_part (int): The number of partitions in LSH Ensemble. + m (int): The memory usage factor: an LSH Ensemble uses approximately `m` times more memory space than a MinHash LSH with the same number of sets indexed. The higher the `m` the better the accuracy. - weights (tuple, optional): Used to adjust the relative importance of + weights (Tuple[float, float]): Used to adjust the relative importance of minizing false positive and false negative when optimizing for the Containment threshold. Similar to the `weights` parameter in :class:`datasketch.MinHashLSH`. 
- storage_config (dict, optional): Type of storage service to use for storing + storage_config (Optional[Dict]): Type of storage service to use for storing hashtables and keys. `basename` is an optional property whose value will be used as the prefix to stored keys. If this is not set, a random string will be generated instead. If you set this, you will be responsible for ensuring there are no key collisions. - prepickle (bool, optional): If True, all keys are pickled to bytes before + prepickle (Optional[bool]): If True, all keys are pickled to bytes before insertion. If None, a default value is chosen based on the `storage_config`. @@ -101,10 +104,18 @@ class MinHashLSHEnsemble(object): .. _`Go implementation`: https://github.com/ekzhu/lshensemble .. _`the paper`: http://www.vldb.org/pvldb/vol9/p1185-zhu.pdf - ''' + """ - def __init__(self, threshold=0.9, num_perm=128, num_part=16, m=8, - weights=(0.5,0.5), storage_config=None, prepickle=None): + def __init__( + self, + threshold: float = 0.9, + num_perm: int = 128, + num_part: int = 16, + m: int = 8, + weights: Tuple[float, float] = (0.5, 0.5), + storage_config: Optional[Dict] = None, + prepickle: Optional[bool] = None, + ) -> None: if threshold > 1.0 or threshold < 0.0: raise ValueError("threshold must be in [0.0, 1.0]") if num_perm < 2: @@ -122,27 +133,45 @@ def __init__(self, threshold=0.9, num_perm=128, num_part=16, m=8, self.m = m rs = self._init_optimal_params(weights) # Initialize multiple LSH indexes for each partition - storage_config = {'type': 'dict'} if not storage_config else storage_config - basename = storage_config.get('basename', _random_name(11)) + storage_config = {"type": "dict"} if not storage_config else storage_config + basename = storage_config.get("basename", _random_name(11)) self.indexes = [ - dict((r, MinHashLSH( - num_perm=self.h, - params=(int(self.h/r), r), - storage_config=self._get_storage_config( - basename, storage_config, partition, r), - prepickle=prepickle)) for r in rs) - 
for partition in range(0, num_part)] + dict( + ( + r, + MinHashLSH( + num_perm=self.h, + params=(int(self.h / r), r), + storage_config=self._get_storage_config( + basename, storage_config, partition, r + ), + prepickle=prepickle, + ), + ) + for r in rs + ) + for partition in range(0, num_part) + ] self.lowers = [None for _ in self.indexes] self.uppers = [None for _ in self.indexes] def _init_optimal_params(self, weights): false_positive_weight, false_negative_weight = weights self.xqs = np.exp(np.linspace(-5, 5, 10)) - self.params = np.array([_optimal_param(self.threshold, self.h, self.m, - xq, - false_positive_weight, - false_negative_weight) - for xq in self.xqs], dtype=int) + self.params = np.array( + [ + _optimal_param( + self.threshold, + self.h, + self.m, + xq, + false_positive_weight, + false_negative_weight, + ) + for xq in self.xqs + ], + dtype=int, + ) # Find all unique r rs = set() for _, r in self.params: @@ -150,7 +179,7 @@ def _init_optimal_params(self, weights): return rs def _get_optimal_param(self, x, q): - i = np.searchsorted(self.xqs, float(x)/float(q), side='left') + i = np.searchsorted(self.xqs, float(x) / float(q), side="left") if i == len(self.params): i = i - 1 return self.params[i] @@ -158,23 +187,26 @@ def _get_optimal_param(self, x, q): def _get_storage_config(self, basename, base_config, partition, r): config = dict(base_config) config["basename"] = b"-".join( - [basename, struct.pack('>H', partition), struct.pack('>H', r)]) + [basename, struct.pack(">H", partition), struct.pack(">H", r)] + ) return config - def index(self, entries): - ''' + def index(self, entries: Iterable[Tuple[Hashable, MinHash, int]]) -> None: + """ Index all sets given their keys, MinHashes, and sizes. It can be called only once after the index is created. 
Args: - entries (`iterable` of `tuple`): An iterable of tuples, each must be - in the form of `(key, minhash, size)`, where `key` is the unique - identifier of a set, `minhash` is the MinHash of the set, - and `size` is the size or number of unique items in the set. - - Note: - `size` must be positive. - ''' + entries (Iterable[Tuple[Hashable, MinHash, int]]): An iterable of + tuples, each must be in the form of ``(key, minhash, size)``, + where ``key`` is the unique + identifier of a set, ``minhash`` is the MinHash of the set, + and ``size`` is the size or number of unique items in the set. + + Raises: + ValueError: If the index is not empty or ``entries`` is empty. + + """ if not self.is_empty(): raise ValueError("Cannot call index again on a non-empty index") if not isinstance(entries, list): @@ -187,13 +219,12 @@ def index(self, entries): if len(entries) == 0: raise ValueError("entries is empty") # Create optimal partitions. - sizes, counts = np.array(sorted( - Counter(e[2] for e in entries).most_common())).T + sizes, counts = np.array(sorted(Counter(e[2] for e in entries).most_common())).T partitions = optimal_partitions(sizes, counts, len(self.indexes)) for i, (lower, upper) in enumerate(partitions): self.lowers[i], self.uppers[i] = lower, upper # Insert into partitions. - entries.sort(key=lambda e : e[2]) + entries.sort(key=lambda e: e[2]) curr_part = 0 for key, minhash, size in entries: if size > self.uppers[curr_part]: @@ -201,19 +232,19 @@ def index(self, entries): for r in self.indexes[curr_part]: self.indexes[curr_part][r].insert(key, minhash) - def query(self, minhash, size): - ''' + def query(self, minhash: MinHash, size: int) -> Generator[Hashable, None, None]: + """ Giving the MinHash and size of the query set, retrieve keys that references sets with containment with respect to the query set greater than the threshold. Args: - minhash (datasketch.MinHash): The MinHash of the query set. + minhash (MinHash): The MinHash of the query set. 
size (int): The size (number of unique items) of the query set. Returns: - `iterator` of keys. - ''' + Generator[Hashable, None, None]: an iterator of keys. + """ for i, index in enumerate(self.indexes): u = self.uppers[i] if u is None: @@ -222,34 +253,34 @@ def query(self, minhash, size): for key in index[r]._query_b(minhash, b): yield key - def __contains__(self, key): - ''' + def __contains__(self, key: Hashable) -> bool: + """ Args: key (hashable): The unique identifier of a set. Returns: bool: True only if the key exists in the index. - ''' - return any(any(key in index[r] for r in index) - for index in self.indexes) + """ + return any(any(key in index[r] for r in index) for index in self.indexes) - def is_empty(self): - ''' + def is_empty(self) -> bool: + """ Returns: bool: Check if the index is empty. - ''' - return all(all(index[r].is_empty() for r in index) - for index in self.indexes) + """ + return all(all(index[r].is_empty() for r in index) for index in self.indexes) if __name__ == "__main__": import numpy as np + xqs = np.exp(np.linspace(-5, 5, 10)) threshold = 0.5 max_r = 8 num_perm = 256 false_negative_weight, false_positive_weight = 0.5, 0.5 for xq in xqs: - b, r = _optimal_param(threshold, num_perm, max_r, xq, - false_positive_weight, false_negative_weight) + b, r = _optimal_param( + threshold, num_perm, max_r, xq, false_positive_weight, false_negative_weight + ) print("threshold: %.2f, xq: %.3f, b: %d, r: %d" % (threshold, xq, b, r)) diff --git a/datasketch/lshforest.py b/datasketch/lshforest.py index bee6f29c..9f3455ba 100644 --- a/datasketch/lshforest.py +++ b/datasketch/lshforest.py @@ -1,8 +1,11 @@ from collections import defaultdict +from typing import Hashable, List + +from datasketch.minhash import MinHash class MinHashLSHForest(object): - ''' + """ The LSH Forest for MinHash. It supports top-k query in Jaccard similarity. Instead of using prefix trees as the `original paper @@ -11,18 +14,18 @@ class MinHashLSHForest(object): hash table. 
Args: - num_perm (int, optional): The number of permutation functions used + num_perm (int): The number of permutation functions used by the MinHash to be indexed. For weighted MinHash, this is the sample size (`sample_size`). - l (int, optional): The number of prefix trees as described in the + l (int): The number of prefix trees as described in the paper. Note: The MinHash LSH Forest also works with weighted Jaccard similarity and weighted MinHash without modification. - ''' + """ - def __init__(self, num_perm=128, l=8): + def __init__(self, num_perm: int = 128, l: int = 8) -> None: if l <= 0 or num_perm <= 0: raise ValueError("num_perm and l must be positive") if l > num_perm: @@ -32,51 +35,54 @@ def __init__(self, num_perm=128, l=8): # Maximum depth of the prefix tree self.k = int(num_perm / l) self.hashtables = [defaultdict(list) for _ in range(self.l)] - self.hashranges = [(i*self.k, (i+1)*self.k) for i in range(self.l)] + self.hashranges = [(i * self.k, (i + 1) * self.k) for i in range(self.l)] self.keys = dict() # This is the sorted array implementation for the prefix trees self.sorted_hashtables = [[] for _ in range(self.l)] - def add(self, key, minhash): - ''' + def add(self, key: Hashable, minhash: MinHash) -> None: + """ Add a unique key, together with a MinHash (or weighted MinHash) of the set referenced by the key. Note: The key won't be searchbale until the - :func:`datasketch.MinHashLSHForest.index` method is called. + :meth:`index` method is called. Args: - key (hashable): The unique identifier of the set. - minhash (datasketch.MinHash): The MinHash of the set. - ''' - if len(minhash) < self.k*self.l: + key (Hashable): The unique identifier of the set. + minhash (MinHash): The MinHash of the set. 
+ """ + if len(minhash) < self.k * self.l: raise ValueError("The num_perm of MinHash out of range") if key in self.keys: raise ValueError("The given key has already been added") - self.keys[key] = [self._H(minhash.hashvalues[start:end]) - for start, end in self.hashranges] + self.keys[key] = [ + self._H(minhash.hashvalues[start:end]) for start, end in self.hashranges + ] for H, hashtable in zip(self.keys[key], self.hashtables): hashtable[H].append(key) - def index(self): - ''' + def index(self) -> None: + """ Index all the keys added so far and make them searchable. - ''' + """ for i, hashtable in enumerate(self.hashtables): self.sorted_hashtables[i] = [H for H in hashtable.keys()] self.sorted_hashtables[i].sort() def _query(self, minhash, r, b): - if r > self.k or r <=0 or b > self.l or b <= 0: + if r > self.k or r <= 0 or b > self.l or b <= 0: raise ValueError("parameter outside range") # Generate prefixes of concatenated hash values - hps = [self._H(minhash.hashvalues[start:start+r]) - for start, _ in self.hashranges] + hps = [ + self._H(minhash.hashvalues[start : start + r]) + for start, _ in self.hashranges + ] # Set the prefix length for look-ups in the sorted hash values list prefix_size = len(hps[0]) for ht, hp, hashtable in zip(self.sorted_hashtables, hps, self.hashtables): - i = self._binary_search(len(ht), lambda x : ht[x][:prefix_size] >= hp) + i = self._binary_search(len(ht), lambda x: ht[x][:prefix_size] >= hp) if i < len(ht) and ht[i][:prefix_size] == hp: j = i while j < len(ht) and ht[j][:prefix_size] == hp: @@ -84,33 +90,33 @@ def _query(self, minhash, r, b): yield key j += 1 - def query(self, minhash, k): - ''' + def query(self, minhash: MinHash, k: int) -> List[Hashable]: + """ Return the approximate top-k keys that have the (approximately) highest Jaccard similarities to the query set. Args: - minhash (datasketch.MinHash): The MinHash of the query set. + minhash (MinHash): The MinHash of the query set. 
k (int): The maximum number of keys to return. Returns: - `list` of at most k keys. - + List[Hashable]: list of at most k keys. + Note: - Tip for improving accuracy: + Tip for improving accuracy: you can use a multiple of `k` (e.g., `2*k`) in the argument, - compute the exact (or approximate using MinHash) Jaccard + compute the exact (or approximate using MinHash) Jaccard similarities of the sets referenced by the returned keys, - from which you then take the final top-k. This is often called - "post-processing". Because the total number of similarity + from which you then take the final top-k. This is often called + "post-processing". Because the total number of similarity computations is still bounded by a constant multiple of `k`, the performance won't degrade too much -- however you do have to keep - the original sets (or MinHashes) around some where so that you + the original sets (or MinHashes) around somewhere so that you can make references to them. - ''' + """ if k <= 0: raise ValueError("k must be positive") - if len(minhash) < self.k*self.l: + if len(minhash) < self.k * self.l: raise ValueError("The num_perm of MinHash out of range") results = set() r = self.k @@ -123,9 +129,9 @@ def query(self, minhash, k): return list(results) def _binary_search(self, n, func): - ''' + """ https://golang.org/src/sort/search.go?s=2247:2287#L49 - ''' + """ i, j = 0, n while i < j: h = int(i + (j - i) / 2) @@ -135,22 +141,22 @@ def _binary_search(self, n, func): j = h return i - def is_empty(self): - ''' + def is_empty(self) -> bool: + """ Check whether there is any searchable keys in the index. Note that keys won't be searchable until `index` is called. Returns: bool: True if there is no searchable key in the index. 
- ''' + """ return any(len(t) == 0 for t in self.sorted_hashtables) def _H(self, hs): return bytes(hs.byteswap().data) - def __contains__(self, key): - ''' + def __contains__(self, key: Hashable) -> bool: + """ Returns: bool: True only if the key has been added to the index. - ''' + """ return key in self.keys diff --git a/datasketch/minhash.py b/datasketch/minhash.py index bc554bce..2b8db892 100644 --- a/datasketch/minhash.py +++ b/datasketch/minhash.py @@ -20,11 +20,11 @@ class MinHash(object): `Jaccard similarity`_ between sets. Args: - num_perm (Optional[int]): Number of random permutation functions. + num_perm (int): Number of random permutation functions. It will be ignored if `hashvalues` is not None. - seed (Optional[int]): The random seed controls the set of random + seed (int): The random seed controls the set of random permutation functions generated for this MinHash. - hashfunc (Optional[Callable]): The hash function used by + hashfunc (Callable): The hash function used by this MinHash. It takes the input passed to the :meth:`update` method and returns an integer that can be encoded with 32 bits. @@ -69,10 +69,10 @@ def __init__( num_perm: int = 128, seed: int = 1, hashfunc: Callable = sha1_hash32, - hashobj=None, # Deprecated. + hashobj: Optional[object] = None, # Deprecated. hashvalues: Optional[Iterable] = None, permutations: Optional[Tuple[Iterable, Iterable]] = None, - ): + ) -> None: if hashvalues is not None: num_perm = len(hashvalues) if num_perm > _hash_range: @@ -378,7 +378,7 @@ def bulk(cls, b: Iterable, **minhash_kwargs) -> List[MinHash]: return list(cls.generator(b, **minhash_kwargs)) @classmethod - def generator(cls, b: Iterable, **minhash_kwargs) -> Generator[MinHash]: + def generator(cls, b: Iterable, **minhash_kwargs) -> Generator[MinHash, None, None]: """Compute MinHashes in a generator. This method avoids unnecessary overhead when initializing many minhashes by reusing the initialized state. 
@@ -390,7 +390,7 @@ def generator(cls, b: Iterable, **minhash_kwargs) -> Generator[MinHash]: will be used for all minhashes. Returns: - Generator[MinHash]: a generator of computed MinHashes. + Generator[MinHash, None, None]: a generator of computed MinHashes. Example: diff --git a/datasketch/weighted_minhash.py b/datasketch/weighted_minhash.py index ad6dacc4..f7b24555 100644 --- a/datasketch/weighted_minhash.py +++ b/datasketch/weighted_minhash.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Union, List import collections.abc @@ -5,43 +6,50 @@ import numpy as np import scipy as sp -import scipy.sparse class WeightedMinHash(object): - '''New weighted MinHash is generated by - :class:`datasketch.WeightedMinHashGenerator`. - You can also initialize weighted MinHash by using the state + """New weighted MinHash is generated by + :class:`WeightedMinHashGenerator`. + You can also initialize a weighted MinHash by using the state from an existing one. Args: seed (int): The random seed used to generate this weighted MinHash. - hashvalues: The internal state of this weighted MinHash. - ''' + hashvalues (numpy.ndarray): The internal state of this weighted MinHash. + """ - def __init__(self, seed, hashvalues): + def __init__(self, seed: int, hashvalues: np.ndarray) -> None: self.seed = seed self.hashvalues = hashvalues - def jaccard(self, other): - '''Estimate the `weighted Jaccard similarity`_ between the + def jaccard(self, other: WeightedMinHash) -> float: + """Estimate the `weighted Jaccard similarity`_ between the multi-sets represented by this weighted MinHash and the other. Args: - other (datasketch.WeightedMinHash): The other weighted MinHash. + other (WeightedMinHash): The other weighted MinHash. Returns: float: The weighted Jaccard similarity between 0.0 and 1.0. + Raises: + ValueError: If the two weighted MinHash objects have different + seeds or different numbers of hash values. + .. 
_`weighted Jaccard similarity`: http://mathoverflow.net/questions/123339/weighted-jaccard-similarity - ''' + """ if other.seed != self.seed: - raise ValueError("Cannot compute Jaccard given WeightedMinHash objects with\ - different seeds") + raise ValueError( + "Cannot compute Jaccard given WeightedMinHash objects with\ + different seeds" + ) if len(self) != len(other): - raise ValueError("Cannot compute Jaccard given WeightedMinHash objects with\ - different numbers of hash values") + raise ValueError( + "Cannot compute Jaccard given WeightedMinHash objects with\ + different numbers of hash values" + ) # Check how many pairs of (k, t) hashvalues are equal intersection = 0 for this, that in zip(self.hashvalues, other.hashvalues): @@ -49,44 +57,46 @@ def jaccard(self, other): intersection += 1 return float(intersection) / float(len(self)) - def digest(self): - '''Export the hash values, which is the internal state of the + def digest(self) -> np.ndarray: + """Export the hash values, which is the internal state of the weighted MinHash. Returns: - numpy.array: The hash values which is a Numpy array. - ''' + numpy.ndarray: The hash values which is a Numpy array. + """ return copy.copy(self.hashvalues) - def copy(self): - ''' + def copy(self) -> WeightedMinHash: + """ Returns: - datasketch.WeightedMinHash: A copy of this weighted MinHash by exporting + WeightedMinHash: A copy of this weighted MinHash by exporting its state. - ''' + """ return WeightedMinHash(self.seed, self.digest()) - def __len__(self): - ''' + def __len__(self) -> int: + """ Returns: int: The number of hash values. - ''' + """ return len(self.hashvalues) - def __eq__(self, other): - ''' + def __eq__(self, other) -> bool: + """ Returns: bool: If their seeds and hash values are both equal then two are equivalent. 
- ''' - return type(self) is type(other) and \ - self.seed == other.seed and \ - np.array_equal(self.hashvalues, other.hashvalues) + """ + return ( + type(self) is type(other) + and self.seed == other.seed + and np.array_equal(self.hashvalues, other.hashvalues) + ) class WeightedMinHashGenerator(object): - '''The weighted MinHash generator is used for creating - new :class:`datasketch.WeightedMinHash` objects. + """The weighted MinHash generator is used for creating + new :class:`WeightedMinHash` objects. This weighted MinHash implementation is based on Sergey Ioffe's paper, `Improved Consistent Sampling, Weighted Minhash and L1 Sketching @@ -94,33 +104,35 @@ class WeightedMinHashGenerator(object): Args: dim (int): The number of dimensions of the input Jaccard vectors. - sample_size (int, optional): The number of samples to use for creating + sample_size (int): The number of samples to use for creating weighted MinHash. seed (int): The random seed to use for generating permutation functions. - ''' + """ - def __init__(self, dim, sample_size=128, seed=1): + def __init__(self, dim: int, sample_size: int = 128, seed: int = 1) -> None: self.dim = dim self.sample_size = sample_size self.seed = seed generator = np.random.RandomState(seed=seed) self.rs = generator.gamma(2, 1, (sample_size, dim)).astype(np.float32) - self.ln_cs = np.log(generator.gamma(2, 1, (sample_size, dim))).astype(np.float32) + self.ln_cs = np.log(generator.gamma(2, 1, (sample_size, dim))).astype( + np.float32 + ) self.betas = generator.uniform(0, 1, (sample_size, dim)).astype(np.float32) - def minhash(self, v): - '''Create a new weighted MinHash given a weighted Jaccard vector. + def minhash(self, v: np.ndarray) -> WeightedMinHash: + """Create a new weighted MinHash given a weighted Jaccard vector. Each dimension is an integer frequency of the corresponding element in the multi-set represented by the vector. Args: v (numpy.ndarray): The Jaccard vector. 
- + Returns: - datasketch.WeightedMinHash: The weighted MinHash. - ''' + WeightedMinHash: The weighted MinHash. + """ if not isinstance(v, collections.abc.Iterable): raise TypeError("Input vector must be an iterable") if not len(v) == self.dim: @@ -130,7 +142,7 @@ def minhash(self, v): elif v.dtype != np.float32: v = v.astype(np.float32) hashvalues = np.zeros((self.sample_size, 2), dtype=int) - vzeros = (v == 0) + vzeros = v == 0 if vzeros.all(): raise ValueError("Input is all zeros") v[vzeros] = np.nan @@ -143,26 +155,27 @@ def minhash(self, v): hashvalues[i][0], hashvalues[i][1] = k, int(t[k]) return WeightedMinHash(self.seed, hashvalues) - def minhash_many(self, X : Union[sp.sparse.spmatrix, np.ndarray]) \ - -> List[Union[WeightedMinHash, None]]: - '''Create new WeightedMinHash instances given a matrix of weighted - Jaccard vectors. In the input matrix X, each row corresponds to a + def minhash_many( + self, X: Union[sp.sparse.spmatrix, np.ndarray] + ) -> List[Union[WeightedMinHash, None]]: + """Create new WeightedMinHash instances given a matrix of weighted + Jaccard vectors. In the input matrix X, each row corresponds to a multi-set, and each column stores the integer frequency of the element of a dimension. - Note: - This method is experimental and does not yield the same MinHash - hash values as :func:`~datasketch.WeightedMinHashGenerator.minhash`. + Note: + This method is experimental and does not yield the same MinHash + hash values as :meth:`minhash`. Args: X (Union[scipy.sparse.spmatrix, numpy.ndarray]): A matrix of Jaccard vectors (rows). Returns: - List[Union[datasketch.WeightedMinHash, None]] - A list of length X.shape[0]. - Each element is either a WeightedMinHash instance or None + List[Union[WeightedMinHash, None]] - A list of length X.shape[0]. + Each element is either a :class:`WeightedMinHash` instance or None (if the original row in X is empty). 
- ''' + """ # Input validation if not isinstance(X, (sp.sparse.spmatrix, np.ndarray)): @@ -191,15 +204,15 @@ def minhash_many(self, X : Union[sp.sparse.spmatrix, np.ndarray]) \ it_doc, doc_begin, doc_end = None, 0, rowends[0] # Generate temporary data - rs_cidx = np.array(self.rs, copy=True)[:, cidx] #sample_size x dims - betas_cidx = np.array(self.betas, copy=True)[:, cidx] #sample_size x dims - ln_cs_cidx = np.array(self.ln_cs, copy=True)[:, cidx] #sample_size x dims + rs_cidx = np.array(self.rs, copy=True)[:, cidx] # sample_size x dims + betas_cidx = np.array(self.betas, copy=True)[:, cidx] # sample_size x dims + ln_cs_cidx = np.array(self.ln_cs, copy=True)[:, cidx] # sample_size x dims log_data = np.log(X[ridx, cidx].getA1()) - log_data = np.vstack([log_data] * self.sample_size) #sample_size x dims + log_data = np.vstack([log_data] * self.sample_size) # sample_size x dims # Unary transformations - t = np.floor(log_data / rs_cidx + betas_cidx) #sample_size x dims + t = np.floor(log_data / rs_cidx + betas_cidx) # sample_size x dims ln_y = (t - betas_cidx + 1) * rs_cidx ln_a = ln_cs_cidx - ln_y @@ -218,8 +231,10 @@ def minhash_many(self, X : Union[sp.sparse.spmatrix, np.ndarray]) \ all_hashvalues[it_doc] = np.zeros((self.sample_size, 2), dtype=int) hashvalues = all_hashvalues[it_doc] - hashvalues[:, 0], hashvalues[:, 1] = \ - doc_k, t[np.arange(self.sample_size), doc_begin + doc_argmin] + hashvalues[:, 0], hashvalues[:, 1] = ( + doc_k, + t[np.arange(self.sample_size), doc_begin + doc_argmin], + ) doc_begin = doc_end