diff --git a/src/svsbench/search.py b/src/svsbench/search.py index b797e15..0507a29 100644 --- a/src/svsbench/search.py +++ b/src/svsbench/search.py @@ -127,6 +127,11 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace: choices=STR_TO_CALIBRATE_SEARCH_BUFFER.keys(), default="all", ) + parser.add_argument( + "--shuffle_ground_truth", + help="Shuffle ground truth. See also svsbench.build --shuffle", + action="store_true", + ) return parser.parse_args(argv) @@ -159,6 +164,8 @@ def search( lvq_strategy: svs.LVQStrategy | None = None, train_prefetchers: bool = True, search_buffer_optimization: svs.VamanaSearchBufferOptimization = svs.VamanaSearchBufferOptimization.All, + shuffle_ground_truth: bool = False, + seed: int = 42, ): logger.info({"search_args": locals()}) logger.info(utils.read_system_config()) @@ -209,6 +216,8 @@ def search( query = svs.read_vecs(str(query_path)) ground_truth = svs.read_vecs(str(ground_truth_path)) + if shuffle_ground_truth: + np.random.default_rng(seed).shuffle(ground_truth) for batch_size_idx, batch_size in enumerate(batch_sizes): index.num_threads = min(max_threads, batch_size) @@ -372,6 +381,8 @@ def main(argv: str | None = None) -> None: search_buffer_optimization=STR_TO_CALIBRATE_SEARCH_BUFFER[ args.calibrate_search_buffer ], + shuffle_ground_truth=args.shuffle_ground_truth, + seed=args.seed, )