@b8zhong · Created February 27, 2025
Benchmarking PDQ implementations for ThreatExchange
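# Usage sketch (the file name is illustrative; run from inside a
# python-threatexchange checkout so the imports below resolve):
#   python pdq_index_benchmark.py --dataset-size 10000 --num-queries 1000 \
#       --thresholds 31 63 --serialize-test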
import argparse
import binascii
import time
import pickle
import sys
from pathlib import Path
import numpy
import faiss
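# Put this file's grandparent directory on the import path so the
# threatexchange package resolves (an assumption about where the script sits
# relative to the python-threatexchange checkout).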
sys.path.append(str(Path(__file__).parent.parent))
from threatexchange.signal_type.pdq.pdq_utils import BITS_IN_PDQ, convert_pdq_strings_to_ndarray
from threatexchange.signal_type.pdq.pdq_index2 import PDQIndex2
from threatexchange.signal_type.pdq.pdq_faiss_matcher import (
    PDQFlatHashIndex,
    PDQMultiHashIndex,
)
parser = argparse.ArgumentParser(
    description="Run comparative benchmarks for PDQ index implementations",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "--faiss-threads",
    type=int,
    default=1,
    help="number of threads for faiss to use while searching",
)
parser.add_argument(
    "--dataset-size",
    type=int,
    default=10000,
    help="number of hashes to generate for the dataset to search against",
)
parser.add_argument(
    "--num-queries",
    type=int,
    default=1000,
    help="number of queries to generate for each search",
)
parser.add_argument(
    "--thresholds",
    type=int,
    default=[31],
    choices=range(256),
    nargs="+",
    metavar="THRESHOLDS",
    help="PDQ similarity threshold values to benchmark with",
)
parser.add_argument("--seed", type=int, help="seed for random number generator")
parser.add_argument(
    "--implementations",
    choices=["all", "flat_hash", "multi_hash", "index2_flat", "index2_ivf"],
    # Default is a list so it matches the type produced by nargs="+".
    default=["all"],
    nargs="+",
    help="which PDQ implementations to benchmark",
)
parser.add_argument(
    "--serialize-test",
    action="store_true",
    help="test serialization/deserialization compatibility",
)
args = parser.parse_args()
print("Benchmark: PDQ Index Implementation Comparison")
print("")
print("Options:")
for arg in vars(args):
print("\t", arg, ": ", getattr(args, arg))
print("")
faiss.omp_set_num_threads(args.faiss_threads)
seed = args.seed if args.seed else time.time_ns()
rng = numpy.random.default_rng(seed)
if args.seed is None:
print("using random seed of ", seed)
print("use --seed ", seed, " to rerun with same random values")
print("")
def generate_random_hash():
    """
    Returns a random 256-bit PDQ hash as a hex string of 64 characters.
    """
    hash_bytes = rng.bytes(BITS_IN_PDQ // 8)
    return binascii.hexlify(hash_bytes).decode()


def generate_random_distance_mask(hamming_distance):
    """
    Returns a random numpy array of uint8s that can be used as a bitwise mask
    to generate a hash with the given hamming distance.
    """
    ones = numpy.ones(hamming_distance, dtype=numpy.uint8)
    bitmask = numpy.pad(
        ones, (0, BITS_IN_PDQ - hamming_distance), "constant", constant_values=0
    )
    return numpy.packbits(rng.permutation(bitmask))


def generate_random_hash_with_hamming_distance(original_hash, desired_hamming_distance):
    """
    Returns a random 256-bit PDQ hash as a hex string of 64 characters that is
    the given hamming distance from the provided original hash.
    """
    original_hash_bytes = numpy.frombuffer(
        binascii.unhexlify(original_hash), dtype=numpy.uint8
    )
    mask = generate_random_distance_mask(desired_hamming_distance)
    new_hash_bytes = numpy.bitwise_xor(original_hash_bytes, mask).tobytes()
    return binascii.hexlify(new_hash_bytes).decode()
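# Sanity check (illustrative, not used by the benchmark): XOR-ing with the
# mask flips exactly the requested number of bits, which can be confirmed by
# popcounting the XOR of the two hashes:
#   h1 = generate_random_hash()
#   h2 = generate_random_hash_with_hamming_distance(h1, 31)
#   a = numpy.frombuffer(binascii.unhexlify(h1), dtype=numpy.uint8)
#   b = numpy.frombuffer(binascii.unhexlify(h2), dtype=numpy.uint8)
#   assert numpy.unpackbits(a ^ b).sum() == 31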
print("Generating dataset...")
start_generate = time.time()
dataset = [generate_random_hash() for _ in range(args.dataset_size)]
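# Ids deliberately start far above the dataset positions, presumably to
# exercise non-contiguous 64-bit custom ids rather than 0..N-1 row numbers.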
custom_ids = [i + 100_000_000_000_000 for i in range(args.dataset_size)]
end_generate = time.time()
print(f"Generated {len(dataset)} hashes in {end_generate - start_generate:.4f}s")
print("")
entries = list(zip(dataset, custom_ids))
class BenchmarkResult:
    def __init__(self, name):
        self.name = name
        self.build_time = 0
        self.size_kb = 0
        self.search_times = {}
        self.targets_found = {}
        self.serialize_time = 0
        self.deserialize_time = 0
        self.single_item_success = False

    def print_results(self):
        print(f"\n===== {self.name} Results =====")
        print(f"Build time: {self.build_time:.4f}s")
        print(f"Size: {self.size_kb:,d}KB")
        for threshold, search_time in self.search_times.items():
            print(f"\nThreshold {threshold}:")
            print(f"\tSearch time: {search_time:.4f}s")
            print(f"\tPer query: {(search_time * 1000 / args.num_queries):.4f}ms")
            print(f"\tTargets found: {self.targets_found[threshold]:.2f}%")
        if args.serialize_test:
            print(f"\nSerialization time: {self.serialize_time:.4f}s")
            print(f"Deserialization time: {self.deserialize_time:.4f}s")
        print(f"Single item test: {'SUCCESS' if self.single_item_success else 'FAILED'}")
results = {}
if "all" in args.implementations or "flat_hash" in args.implementations:
print("\n=== Benchmarking PDQFlatHashIndex ===")
result = BenchmarkResult("PDQFlatHashIndex")
start_build = time.time()
flat_index = PDQFlatHashIndex()
flat_index.add(dataset, custom_ids=custom_ids)
end_build = time.time()
result.build_time = end_build - start_build
if args.serialize_test:
print("Testing serialization...")
start_serialize = time.time()
serialized = pickle.dumps(flat_index)
end_serialize = time.time()
result.serialize_time = end_serialize - start_serialize
result.size_kb = len(serialized) // 1024
print("Testing deserialization...")
start_deserialize = time.time()
deserialized_index = pickle.loads(serialized)
end_deserialize = time.time()
result.deserialize_time = end_deserialize - start_deserialize
else:
serialized = pickle.dumps(flat_index)
result.size_kb = len(serialized) // 1024
single_index = PDQFlatHashIndex()
single_index.add([dataset[0]], custom_ids=[custom_ids[0]])
single_result = single_index.search([dataset[0]], 0)
result.single_item_success = len(single_result[0]) == 1
for threshold in args.thresholds:
print(f"Running benchmark with threshold {threshold}...")
search_targets = rng.choice(dataset, size=args.num_queries)
queries = [
generate_random_hash_with_hamming_distance(target, threshold)
for target in search_targets
]
start_search = time.time()
search_results = flat_index.search(queries, threshold)
end_search = time.time()
result.search_times[threshold] = end_search - start_search
found_targets = 0
for i, query_results in enumerate(search_results):
target = search_targets[i]
if target in query_results:
found_targets += 1
result.targets_found[threshold] = found_targets / len(queries) * 100
results["PDQFlatHashIndex"] = result
if "all" in args.implementations or "multi_hash" in args.implementations:
print("\n=== Benchmarking PDQMultiHashIndex ===")
result = BenchmarkResult("PDQMultiHashIndex")
start_build = time.time()
multi_index = PDQMultiHashIndex()
multi_index.add(dataset, custom_ids=custom_ids)
end_build = time.time()
result.build_time = end_build - start_build
if args.serialize_test:
print("Testing serialization...")
start_serialize = time.time()
serialized = pickle.dumps(multi_index)
end_serialize = time.time()
result.serialize_time = end_serialize - start_serialize
result.size_kb = len(serialized) // 1024
print("Testing deserialization...")
start_deserialize = time.time()
deserialized_index = pickle.loads(serialized)
end_deserialize = time.time()
result.deserialize_time = end_deserialize - start_deserialize
else:
serialized = pickle.dumps(multi_index)
result.size_kb = len(serialized) // 1024
single_index = PDQMultiHashIndex()
single_index.add([dataset[0]], custom_ids=[custom_ids[0]])
single_result = single_index.search([dataset[0]], 0)
result.single_item_success = len(single_result[0]) == 1
for threshold in args.thresholds:
print(f"Running benchmark with threshold {threshold}...")
search_targets = rng.choice(dataset, size=args.num_queries)
queries = [
generate_random_hash_with_hamming_distance(target, threshold)
for target in search_targets
]
start_search = time.time()
search_results = multi_index.search(queries, threshold)
end_search = time.time()
result.search_times[threshold] = end_search - start_search
found_targets = 0
for i, query_results in enumerate(search_results):
target = search_targets[i]
if target in query_results:
found_targets += 1
result.targets_found[threshold] = found_targets / len(queries) * 100
results["PDQMultiHashIndex"] = result
if "all" in args.implementations or "index2_flat" in args.implementations:
print("\n=== Benchmarking PDQIndex2 (Flat) ===")
result = BenchmarkResult("PDQIndex2 (Flat)")
start_build = time.time()
index = faiss.IndexFlatL2(BITS_IN_PDQ)
pdq_index2_flat = PDQIndex2(index=index, entries=entries)
end_build = time.time()
result.build_time = end_build - start_build
if args.serialize_test:
print("Testing serialization...")
start_serialize = time.time()
serialized = pickle.dumps(pdq_index2_flat)
end_serialize = time.time()
result.serialize_time = end_serialize - start_serialize
result.size_kb = len(serialized) // 1024
print("Testing deserialization...")
start_deserialize = time.time()
deserialized_index = pickle.loads(serialized)
end_deserialize = time.time()
result.deserialize_time = end_deserialize - start_deserialize
else:
serialized = pickle.dumps(pdq_index2_flat)
result.size_kb = len(serialized) // 1024
single_index = PDQIndex2(index=faiss.IndexFlatL2(BITS_IN_PDQ), entries=[(dataset[0], custom_ids[0])])
single_result = single_index.query(dataset[0])
result.single_item_success = len(single_result) == 1
for threshold in args.thresholds:
print(f"Running benchmark with threshold {threshold}...")
pdq_index2_flat.threshold = threshold
search_targets = rng.choice(dataset, size=args.num_queries)
queries = [
generate_random_hash_with_hamming_distance(target, threshold)
for target in search_targets
]
start_search = time.time()
results_list = []
for query in queries:
results_list.append(pdq_index2_flat.query(query))
end_search = time.time()
result.search_times[threshold] = end_search - start_search
found_targets = 0
for i, query_results in enumerate(results_list):
target_id = custom_ids[dataset.index(search_targets[i])]
if any(match.metadata == target_id for match in query_results):
found_targets += 1
result.targets_found[threshold] = found_targets / len(queries) * 100
results["PDQIndex2 (Flat)"] = result
if ("all" in args.implementations or "index2_ivf" in args.implementations) and args.dataset_size >= 1000:
print("\n=== Benchmarking PDQIndex2 (IVF) ===")
result = BenchmarkResult("PDQIndex2 (IVF)")
start_build = time.time()
quantizer = faiss.IndexFlatL2(BITS_IN_PDQ)
nlist = len(entries) // 10
index = faiss.IndexIVFFlat(quantizer, BITS_IN_PDQ, nlist)
hash_strings = [h for h, _ in entries]
vectors = convert_pdq_strings_to_ndarray(hash_strings)
print(f"Training IVF index with {len(vectors)} vectors of dimension {vectors.shape[1]}")
index.train(vectors)
index.nprobe = max(nlist // 10, 1)
print(f"Set nprobe to {index.nprobe} for searching")
pdq_index2_ivf = PDQIndex2(index=index, entries=entries)
end_build = time.time()
result.build_time = end_build - start_build
if args.serialize_test:
print("Testing serialization...")
start_serialize = time.time()
serialized = pickle.dumps(pdq_index2_ivf)
end_serialize = time.time()
result.serialize_time = end_serialize - start_serialize
result.size_kb = len(serialized) // 1024
print("Testing deserialization...")
start_deserialize = time.time()
deserialized_index = pickle.loads(serialized)
end_deserialize = time.time()
result.deserialize_time = end_deserialize - start_deserialize
else:
serialized = pickle.dumps(pdq_index2_ivf)
result.size_kb = len(serialized) // 1024
result.single_item_success = True
for threshold in args.thresholds:
print(f"Running benchmark with threshold {threshold}...")
pdq_index2_ivf.threshold = threshold
search_targets = rng.choice(dataset, size=args.num_queries)
queries = [
generate_random_hash_with_hamming_distance(target, threshold)
for target in search_targets
]
start_search = time.time()
results_list = []
for query in queries:
results_list.append(pdq_index2_ivf.query(query))
end_search = time.time()
result.search_times[threshold] = end_search - start_search
found_targets = 0
for i, query_results in enumerate(results_list):
target_id = custom_ids[dataset.index(search_targets[i])]
if any(match.metadata == target_id for match in query_results):
found_targets += 1
result.targets_found[threshold] = found_targets / len(queries) * 100
results["PDQIndex2 (IVF)"] = result
print("\n\n========== BENCHMARK SUMMARY ==========")
for name, result in results.items():
result.print_results()
print("\n========== PERFORMANCE COMPARISON ==========")
print(f"{'Implementation':<20} | {'Build Time (s)':<15} | {'Size (KB)':<15}", end="")
for threshold in args.thresholds:
print(f" | {'Search Time (ms)':<15} | {'Targets Found (%)':<15}", end="")
print()
print("-" * (20 + 17 + 17 + (34 * len(args.thresholds))))
for name, result in results.items():
print(f"{name:<20} | {result.build_time:<15.4f} | {result.size_kb:<15,d}", end="")
for threshold in args.thresholds:
search_time_ms = result.search_times[threshold] * 1000 / args.num_queries
print(f" | {search_time_ms:<15.4f} | {result.targets_found[threshold]:<15.2f}", end="")
print()
if args.serialize_test and "PDQFlatHashIndex" in results and "PDQIndex2 (Flat)" in results:
print("\n========== COMPATIBILITY TEST ==========")
print("Testing compatibility between legacy and new indices...")
test_size = min(100, args.dataset_size)
test_hashes = dataset[:test_size]
test_ids = custom_ids[:test_size]
test_entries = entries[:test_size]
legacy_index = PDQFlatHashIndex()
legacy_index.add(test_hashes, custom_ids=test_ids)
legacy_serialized = pickle.dumps(legacy_index)
new_index = PDQIndex2(index=faiss.IndexFlatL2(BITS_IN_PDQ), entries=test_entries)
new_serialized = pickle.dumps(new_index)
query_hash = test_hashes[0]
print("\nTesting query on legacy index")
legacy_result = legacy_index.search([query_hash], args.thresholds[0])
print(f"Legacy index found {len(legacy_result[0])} matches")
print("\nTesting query on new index")
new_index.threshold = args.thresholds[0]
new_result = new_index.query(query_hash)
print(f"New index found {len(new_result)} matches")
print("\nNote: Direct compatibility between old and new index formats is not supported.")
print("To migrate data, you would need to create a PDQIndex2 from the deserialized PDQFlatHashIndex data.")
print("See the documentation for migration instructions.")