Benchmarking PDQ implementations for ThreatExchange
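The script below generates a synthetic dataset of random PDQ hashes, builds each selected index implementation over it, and reports build time, serialized size, per-query search latency, and the percentage of known targets recovered at each threshold. A typical invocation might look like this (the filename benchmark_pdq.py is just a placeholder for wherever you save the gist):

python benchmark_pdq.py --dataset-size 50000 --num-queries 1000 --thresholds 16 31 --serialize-test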
import argparse
import binascii
import pickle
import sys
import time
from pathlib import Path

import faiss
import numpy

sys.path.append(str(Path(__file__).parent.parent))

from threatexchange.signal_type.pdq.pdq_utils import BITS_IN_PDQ, convert_pdq_strings_to_ndarray
from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.pdq_faiss_matcher import (
    PDQFlatHashIndex,
    PDQMultiHashIndex,
)
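# Note: the sys.path tweak above assumes this script lives one directory below
# the python-threatexchange package root (e.g. in a benchmarks/ folder); adjust
# the path if you keep the file elsewhere.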
parser = argparse.ArgumentParser(
    description="Run comparative benchmarks for PDQ index implementations",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--faiss-threads",
    type=int,
    default=1,
    help="number of threads for faiss to use while searching",
)
parser.add_argument(
    "--dataset-size",
    type=int,
    default=10000,
    help="number of hashes to generate for the dataset to search against",
)
parser.add_argument(
    "--num-queries",
    type=int,
    default=1000,
    help="number of queries to generate for each search",
)
parser.add_argument(
    "--thresholds",
    type=int,
    default=[31],
    choices=range(256),
    nargs="+",
    metavar="THRESHOLDS",
    help="PDQ similarity threshold values to benchmark with",
)
parser.add_argument("--seed", type=int, help="seed for random number generator")
parser.add_argument(
    "--implementations",
    choices=["all", "flat_hash", "multi_hash", "index2_flat", "index2_ivf"],
    default=["all"],  # a list, since nargs="+" also produces a list
    nargs="+",
    help="which PDQ implementations to benchmark",
)
parser.add_argument(
    "--serialize-test",
    action="store_true",
    help="test serialization/deserialization compatibility",
)
args = parser.parse_args()
print("Benchmark: PDQ Index Implementation Comparison") | |
print("") | |
print("Options:") | |
for arg in vars(args): | |
print("\t", arg, ": ", getattr(args, arg)) | |
print("") | |
faiss.omp_set_num_threads(args.faiss_threads) | |
seed = args.seed if args.seed else time.time_ns() | |
rng = numpy.random.default_rng(seed) | |
if args.seed is None: | |
print("using random seed of ", seed) | |
print("use --seed ", seed, " to rerun with same random values") | |
print("") | |
def generate_random_hash():
    """
    Returns a random 256-bit PDQ hash as a hex string of 64 characters.
    """
    hash_bytes = rng.bytes(BITS_IN_PDQ // 8)
    return binascii.hexlify(hash_bytes).decode()


def generate_random_distance_mask(hamming_distance):
    """
    Returns a random numpy array of uint8s that can be used as a bitwise mask
    to generate a hash with the given hamming distance.
    """
    # Build a 256-bit array with exactly `hamming_distance` ones, shuffle the
    # bit positions, then pack the bits back into 32 bytes.
    ones = numpy.ones(hamming_distance, dtype=numpy.uint8)
    bitmask = numpy.pad(
        ones, (0, BITS_IN_PDQ - hamming_distance), "constant", constant_values=0
    )
    return numpy.packbits(rng.permutation(bitmask))
def generate_random_hash_with_hamming_distance(original_hash, desired_hamming_distance):
    """
    Returns a random 256-bit PDQ hash as a hex string of 64 characters that is
    the given hamming distance from the provided original hash.
    """
    original_hash_bytes = numpy.frombuffer(
        binascii.unhexlify(original_hash), dtype=numpy.uint8
    )
    mask = generate_random_distance_mask(desired_hamming_distance)
    # XOR flips exactly the bits set in the mask, so the result is exactly
    # `desired_hamming_distance` bits away from the original.
    new_hash_bytes = numpy.bitwise_xor(original_hash_bytes, mask).tobytes()
    return binascii.hexlify(new_hash_bytes).decode()
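# A quick way to sanity-check the generators above: XOR-ing with a mask that
# has exactly k set bits must move a hash exactly k bits away. The helper
# below is only for illustration; it is not part of ThreatExchange and the
# benchmark never calls it.
def _hamming_distance_between_hex(hex_a, hex_b):
    a = numpy.unpackbits(numpy.frombuffer(binascii.unhexlify(hex_a), dtype=numpy.uint8))
    b = numpy.unpackbits(numpy.frombuffer(binascii.unhexlify(hex_b), dtype=numpy.uint8))
    return int(numpy.count_nonzero(a != b))
# Example:
#   original = generate_random_hash()
#   mutated = generate_random_hash_with_hamming_distance(original, 31)
#   assert _hamming_distance_between_hex(original, mutated) == 31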
print("Generating dataset...") | |
start_generate = time.time() | |
dataset = [generate_random_hash() for _ in range(args.dataset_size)] | |
custom_ids = [i + 100_000_000_000_000 for i in range(args.dataset_size)] | |
end_generate = time.time() | |
print(f"Generated {len(dataset)} hashes in {end_generate - start_generate:.4f}s") | |
print("") | |
entries = list(zip(dataset, custom_ids)) | |
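# PDQIndex2 is constructed from (hash, metadata) pairs, while the legacy
# matchers take parallel hash/custom_id lists, so both shapes are kept around.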
class BenchmarkResult:
    """Collects build, size, search, and serialization measurements for one implementation."""

    def __init__(self, name):
        self.name = name
        self.build_time = 0
        self.size_kb = 0
        self.search_times = {}
        self.targets_found = {}
        self.serialize_time = 0
        self.deserialize_time = 0
        self.single_item_success = False

    def print_results(self):
        print(f"\n===== {self.name} Results =====")
        print(f"Build time: {self.build_time:.4f}s")
        print(f"Size: {self.size_kb:,d}KB")
        for threshold, search_time in self.search_times.items():
            print(f"\nThreshold {threshold}:")
            print(f"\tSearch time: {search_time:.4f}s")
            print(f"\tPer query: {(search_time * 1000 / args.num_queries):.4f}ms")
            print(f"\tTargets found: {self.targets_found[threshold]:.2f}%")
        if args.serialize_test:
            print(f"\nSerialization time: {self.serialize_time:.4f}s")
            print(f"Deserialization time: {self.deserialize_time:.4f}s")
            print(f"Single item test: {'SUCCESS' if self.single_item_success else 'FAILED'}")
results = {}

if "all" in args.implementations or "flat_hash" in args.implementations:
    print("\n=== Benchmarking PDQFlatHashIndex ===")
    result = BenchmarkResult("PDQFlatHashIndex")

    start_build = time.time()
    flat_index = PDQFlatHashIndex()
    flat_index.add(dataset, custom_ids=custom_ids)
    end_build = time.time()
    result.build_time = end_build - start_build

    if args.serialize_test:
        print("Testing serialization...")
        start_serialize = time.time()
        serialized = pickle.dumps(flat_index)
        end_serialize = time.time()
        result.serialize_time = end_serialize - start_serialize
        result.size_kb = len(serialized) // 1024

        print("Testing deserialization...")
        start_deserialize = time.time()
        deserialized_index = pickle.loads(serialized)
        end_deserialize = time.time()
        result.deserialize_time = end_deserialize - start_deserialize
    else:
        serialized = pickle.dumps(flat_index)
        result.size_kb = len(serialized) // 1024

    # One-item round trip: a single-entry index should match its own hash.
    single_index = PDQFlatHashIndex()
    single_index.add([dataset[0]], custom_ids=[custom_ids[0]])
    single_result = single_index.search([dataset[0]], 0)
    result.single_item_success = len(single_result[0]) == 1

    for threshold in args.thresholds:
        print(f"Running benchmark with threshold {threshold}...")
        search_targets = rng.choice(dataset, size=args.num_queries)
        queries = [
            generate_random_hash_with_hamming_distance(target, threshold)
            for target in search_targets
        ]
        start_search = time.time()
        search_results = flat_index.search(queries, threshold)
        end_search = time.time()
        result.search_times[threshold] = end_search - start_search

        # Recall: each query was derived from a known target, so count how
        # often that target appears among its query's matches.
        found_targets = 0
        for i, query_results in enumerate(search_results):
            target = search_targets[i]
            if target in query_results:
                found_targets += 1
        result.targets_found[threshold] = found_targets / len(queries) * 100

    results["PDQFlatHashIndex"] = result
if "all" in args.implementations or "multi_hash" in args.implementations: | |
print("\n=== Benchmarking PDQMultiHashIndex ===") | |
result = BenchmarkResult("PDQMultiHashIndex") | |
start_build = time.time() | |
multi_index = PDQMultiHashIndex() | |
multi_index.add(dataset, custom_ids=custom_ids) | |
end_build = time.time() | |
result.build_time = end_build - start_build | |
if args.serialize_test: | |
print("Testing serialization...") | |
start_serialize = time.time() | |
serialized = pickle.dumps(multi_index) | |
end_serialize = time.time() | |
result.serialize_time = end_serialize - start_serialize | |
result.size_kb = len(serialized) // 1024 | |
print("Testing deserialization...") | |
start_deserialize = time.time() | |
deserialized_index = pickle.loads(serialized) | |
end_deserialize = time.time() | |
result.deserialize_time = end_deserialize - start_deserialize | |
else: | |
serialized = pickle.dumps(multi_index) | |
result.size_kb = len(serialized) // 1024 | |
single_index = PDQMultiHashIndex() | |
single_index.add([dataset[0]], custom_ids=[custom_ids[0]]) | |
single_result = single_index.search([dataset[0]], 0) | |
result.single_item_success = len(single_result[0]) == 1 | |
for threshold in args.thresholds: | |
print(f"Running benchmark with threshold {threshold}...") | |
search_targets = rng.choice(dataset, size=args.num_queries) | |
queries = [ | |
generate_random_hash_with_hamming_distance(target, threshold) | |
for target in search_targets | |
] | |
start_search = time.time() | |
search_results = multi_index.search(queries, threshold) | |
end_search = time.time() | |
result.search_times[threshold] = end_search - start_search | |
found_targets = 0 | |
for i, query_results in enumerate(search_results): | |
target = search_targets[i] | |
if target in query_results: | |
found_targets += 1 | |
result.targets_found[threshold] = found_targets / len(queries) * 100 | |
results["PDQMultiHashIndex"] = result | |
if "all" in args.implementations or "index2_flat" in args.implementations: | |
print("\n=== Benchmarking PDQIndex2 (Flat) ===") | |
result = BenchmarkResult("PDQIndex2 (Flat)") | |
start_build = time.time() | |
index = faiss.IndexFlatL2(BITS_IN_PDQ) | |
pdq_index2_flat = PDQIndex2(index=index, entries=entries) | |
end_build = time.time() | |
result.build_time = end_build - start_build | |
if args.serialize_test: | |
print("Testing serialization...") | |
start_serialize = time.time() | |
serialized = pickle.dumps(pdq_index2_flat) | |
end_serialize = time.time() | |
result.serialize_time = end_serialize - start_serialize | |
result.size_kb = len(serialized) // 1024 | |
print("Testing deserialization...") | |
start_deserialize = time.time() | |
deserialized_index = pickle.loads(serialized) | |
end_deserialize = time.time() | |
result.deserialize_time = end_deserialize - start_deserialize | |
else: | |
serialized = pickle.dumps(pdq_index2_flat) | |
result.size_kb = len(serialized) // 1024 | |
single_index = PDQIndex2(index=faiss.IndexFlatL2(BITS_IN_PDQ), entries=[(dataset[0], custom_ids[0])]) | |
single_result = single_index.query(dataset[0]) | |
result.single_item_success = len(single_result) == 1 | |
for threshold in args.thresholds: | |
print(f"Running benchmark with threshold {threshold}...") | |
pdq_index2_flat.threshold = threshold | |
search_targets = rng.choice(dataset, size=args.num_queries) | |
queries = [ | |
generate_random_hash_with_hamming_distance(target, threshold) | |
for target in search_targets | |
] | |
start_search = time.time() | |
results_list = [] | |
for query in queries: | |
results_list.append(pdq_index2_flat.query(query)) | |
end_search = time.time() | |
result.search_times[threshold] = end_search - start_search | |
found_targets = 0 | |
for i, query_results in enumerate(results_list): | |
target_id = custom_ids[dataset.index(search_targets[i])] | |
if any(match.metadata == target_id for match in query_results): | |
found_targets += 1 | |
result.targets_found[threshold] = found_targets / len(queries) * 100 | |
results["PDQIndex2 (Flat)"] = result | |
if ("all" in args.implementations or "index2_ivf" in args.implementations) and args.dataset_size >= 1000: | |
print("\n=== Benchmarking PDQIndex2 (IVF) ===") | |
result = BenchmarkResult("PDQIndex2 (IVF)") | |
start_build = time.time() | |
quantizer = faiss.IndexFlatL2(BITS_IN_PDQ) | |
nlist = len(entries) // 10 | |
index = faiss.IndexIVFFlat(quantizer, BITS_IN_PDQ, nlist) | |
hash_strings = [h for h, _ in entries] | |
vectors = convert_pdq_strings_to_ndarray(hash_strings) | |
print(f"Training IVF index with {len(vectors)} vectors of dimension {vectors.shape[1]}") | |
index.train(vectors) | |
index.nprobe = max(nlist // 10, 1) | |
print(f"Set nprobe to {index.nprobe} for searching") | |
pdq_index2_ivf = PDQIndex2(index=index, entries=entries) | |
end_build = time.time() | |
result.build_time = end_build - start_build | |
if args.serialize_test: | |
print("Testing serialization...") | |
start_serialize = time.time() | |
serialized = pickle.dumps(pdq_index2_ivf) | |
end_serialize = time.time() | |
result.serialize_time = end_serialize - start_serialize | |
result.size_kb = len(serialized) // 1024 | |
print("Testing deserialization...") | |
start_deserialize = time.time() | |
deserialized_index = pickle.loads(serialized) | |
end_deserialize = time.time() | |
result.deserialize_time = end_deserialize - start_deserialize | |
else: | |
serialized = pickle.dumps(pdq_index2_ivf) | |
result.size_kb = len(serialized) // 1024 | |
result.single_item_success = True | |
for threshold in args.thresholds: | |
print(f"Running benchmark with threshold {threshold}...") | |
pdq_index2_ivf.threshold = threshold | |
search_targets = rng.choice(dataset, size=args.num_queries) | |
queries = [ | |
generate_random_hash_with_hamming_distance(target, threshold) | |
for target in search_targets | |
] | |
start_search = time.time() | |
results_list = [] | |
for query in queries: | |
results_list.append(pdq_index2_ivf.query(query)) | |
end_search = time.time() | |
result.search_times[threshold] = end_search - start_search | |
found_targets = 0 | |
for i, query_results in enumerate(results_list): | |
target_id = custom_ids[dataset.index(search_targets[i])] | |
if any(match.metadata == target_id for match in query_results): | |
found_targets += 1 | |
result.targets_found[threshold] = found_targets / len(queries) * 100 | |
results["PDQIndex2 (IVF)"] = result | |
print("\n\n========== BENCHMARK SUMMARY ==========") | |
for name, result in results.items(): | |
result.print_results() | |
print("\n========== PERFORMANCE COMPARISON ==========") | |
print(f"{'Implementation':<20} | {'Build Time (s)':<15} | {'Size (KB)':<15}", end="") | |
for threshold in args.thresholds: | |
print(f" | {'Search Time (ms)':<15} | {'Targets Found (%)':<15}", end="") | |
print() | |
print("-" * (20 + 17 + 17 + (34 * len(args.thresholds)))) | |
for name, result in results.items(): | |
print(f"{name:<20} | {result.build_time:<15.4f} | {result.size_kb:<15,d}", end="") | |
for threshold in args.thresholds: | |
search_time_ms = result.search_times[threshold] * 1000 / args.num_queries | |
print(f" | {search_time_ms:<15.4f} | {result.targets_found[threshold]:<15.2f}", end="") | |
print() | |
if args.serialize_test and "PDQFlatHashIndex" in results and "PDQIndex2 (Flat)" in results: | |
print("\n========== COMPATIBILITY TEST ==========") | |
print("Testing compatibility between legacy and new indices...") | |
test_size = min(100, args.dataset_size) | |
test_hashes = dataset[:test_size] | |
test_ids = custom_ids[:test_size] | |
test_entries = entries[:test_size] | |
legacy_index = PDQFlatHashIndex() | |
legacy_index.add(test_hashes, custom_ids=test_ids) | |
legacy_serialized = pickle.dumps(legacy_index) | |
new_index = PDQIndex2(index=faiss.IndexFlatL2(BITS_IN_PDQ), entries=test_entries) | |
new_serialized = pickle.dumps(new_index) | |
query_hash = test_hashes[0] | |
print("\nTesting query on legacy index") | |
legacy_result = legacy_index.search([query_hash], args.thresholds[0]) | |
print(f"Legacy index found {len(legacy_result[0])} matches") | |
print("\nTesting query on new index") | |
new_index.threshold = args.thresholds[0] | |
new_result = new_index.query(query_hash) | |
print(f"New index found {len(new_result)} matches") | |
print("\nNote: Direct compatibility between old and new index formats is not supported.") | |
print("To migrate data, you would need to create a PDQIndex2 from the deserialized PDQFlatHashIndex data.") | |
print("See the documentation for migration instructions.") |