@@ -26,7 +26,6 @@ import :logger;
26
26
import :infinity_exception;
27
27
28
28
import third_party;
29
-
30
29
import internal_types;
31
30
32
31
namespace infinity {
@@ -72,28 +71,37 @@ BlockMaxWandIterator::BlockMaxWandIterator(std::vector<std::unique_ptr<DocIterat
72
71
// Initialize optimization structures for many keywords
73
72
score_ub_prefix_sums_.reserve (num_iterators + 1 );
74
73
iterator_indices_.reserve (num_iterators);
75
-
74
+
76
75
// Pre-populate iterator indices for cache-friendly access
77
76
for (size_t i = 0 ; i < num_iterators; i++) {
78
77
iterator_indices_.push_back (i);
79
78
}
80
79
indices_valid_ = true ;
81
-
80
+
82
81
UpdateScoreUpperBoundPrefixSums ();
83
82
}
84
83
85
- // Highly optimized pivot calculation with multiple strategies
84
+ // Highly optimized pivot calculation with multiple strategies and improved safety
86
85
size_t BlockMaxWandIterator::FindPivotOptimized (float threshold) {
87
86
const size_t num_iterators = sorted_iterators_.size ();
88
-
87
+
88
+ // Safety check for empty iterator list
89
+ if (num_iterators == 0 ) {
90
+ return 0 ;
91
+ }
92
+
89
93
// For very large sets, try fast estimation first
90
94
if (num_iterators > FAST_PIVOT_THRESHOLD) {
91
95
size_t estimated_pivot;
92
96
if (TryFastPivotEstimation (threshold, estimated_pivot)) {
93
- return estimated_pivot;
97
+ // Double-check the estimated pivot for safety
98
+ if (estimated_pivot < num_iterators) {
99
+ return estimated_pivot;
100
+ }
94
101
}
95
102
}
96
-
103
+
104
+ // Ensure prefix sums are valid
97
105
if (!prefix_sums_valid_) {
98
106
UpdateScoreUpperBoundPrefixSums ();
99
107
}
@@ -108,19 +116,21 @@ size_t BlockMaxWandIterator::FindPivotOptimized(float threshold) {
108
116
left = mid + 1 ;
109
117
}
110
118
}
111
- return left;
119
+
120
+ // Safety check - ensure we don't return an out-of-bounds pivot
121
+ return std::min (left, num_iterators);
112
122
}
113
123
114
124
void BlockMaxWandIterator::UpdateScoreUpperBoundPrefixSums () {
115
125
const size_t num_iterators = sorted_iterators_.size ();
116
-
126
+
117
127
// Resize only if necessary to avoid memory allocations
118
128
if (score_ub_prefix_sums_.size () != num_iterators + 1 ) {
119
129
score_ub_prefix_sums_.resize (num_iterators + 1 );
120
130
}
121
-
131
+
122
132
score_ub_prefix_sums_[0 ] = 0 .0f ;
123
-
133
+
124
134
// Vectorized accumulation for better performance
125
135
for (size_t i = 0 ; i < num_iterators; i++) {
126
136
score_ub_prefix_sums_[i + 1 ] = score_ub_prefix_sums_[i] + sorted_iterators_[i]->BM25ScoreUpperBound ();
@@ -142,43 +152,78 @@ bool BlockMaxWandIterator::ShouldSkipSort() const {
142
152
if (consecutive_skips_ > 10 ) {
143
153
adaptive_interval = std::min (adaptive_interval + 2 , 8u );
144
154
}
145
-
155
+
146
156
return iterations_since_sort_ < adaptive_interval;
147
157
}
148
158
149
159
// Optimized partial sort that only sorts what we need
150
160
void BlockMaxWandIterator::OptimizedPartialSort (size_t limit) {
151
161
const size_t num_iterators = sorted_iterators_.size ();
152
162
if (limit >= num_iterators) {
153
- std::sort (sorted_iterators_.begin (), sorted_iterators_.end (),
154
- [](const auto &a, const auto &b) { return a->DocID () < b->DocID (); });
163
+ std::sort (sorted_iterators_.begin (), sorted_iterators_.end (), [](const auto &a, const auto &b) { return a->DocID () < b->DocID (); });
155
164
} else {
156
- std::partial_sort (sorted_iterators_.begin (),
157
- sorted_iterators_.begin () + limit,
158
- sorted_iterators_.end (),
159
- [](const auto &a, const auto &b) { return a->DocID () < b->DocID (); });
165
+ std::partial_sort (sorted_iterators_.begin (), sorted_iterators_.begin () + limit, sorted_iterators_.end (), [](const auto &a, const auto &b) {
166
+ return a->DocID () < b->DocID ();
167
+ });
160
168
}
161
169
}
162
170
163
- // Fast pivot estimation for very large keyword sets
164
- bool BlockMaxWandIterator::TryFastPivotEstimation (float threshold, size_t & estimated_pivot) {
171
+ // Aggressively optimized pivot estimation with safety guarantees
172
+ bool BlockMaxWandIterator::TryFastPivotEstimation (float threshold, size_t & estimated_pivot) {
165
173
const size_t num_iterators = sorted_iterators_.size ();
166
-
167
- // Sample-based estimation: check every Nth iterator
168
- const size_t sample_step = std::max (1UL , num_iterators / 20 ); // Sample ~5% of iterators
169
- float accumulated_score = 0 .0f ;
170
-
171
- for (size_t i = 0 ; i < num_iterators; i += sample_step) {
172
- accumulated_score += sorted_iterators_[i]->BM25ScoreUpperBound ();
173
- if (accumulated_score > threshold) {
174
- // Estimate the actual pivot position
175
- estimated_pivot = std::min (i + sample_step, num_iterators - 1 );
174
+ if (num_iterators <= FAST_PIVOT_THRESHOLD) {
175
+ return false ;
176
+ }
177
+
178
+ // Level 1: Top 5 high-score terms check (hot path)
179
+ constexpr size_t TOP_K = 5 ;
180
+ float sum = 0 .0f ;
181
+ for (size_t i = 0 ; i < std::min (TOP_K, num_iterators); ++i) {
182
+ sum += sorted_iterators_[i]->BM25ScoreUpperBound ();
183
+ if (sum > threshold) {
184
+ estimated_pivot = i;
176
185
return true ;
177
186
}
178
187
}
179
-
180
- estimated_pivot = num_iterators;
181
- return false ;
188
+
189
+ // Level 2: Strided sampling with early termination
190
+ const size_t stride = std::max<size_t >(1 , num_iterators / 20 ); // 5% sampling
191
+ float prev_sum = sum;
192
+ size_t prev_i = TOP_K;
193
+
194
+ for (size_t i = TOP_K; i < num_iterators; i += stride) {
195
+ prev_sum = sum;
196
+ prev_i = i;
197
+ sum += sorted_iterators_[i]->BM25ScoreUpperBound ();
198
+
199
+ if (sum > threshold) {
200
+ // Linear search backward to find exact pivot
201
+ for (size_t j = prev_i; j < i; ++j) {
202
+ prev_sum += sorted_iterators_[j]->BM25ScoreUpperBound ();
203
+ if (prev_sum > threshold) {
204
+ estimated_pivot = j;
205
+ return true ;
206
+ }
207
+ }
208
+ estimated_pivot = i;
209
+ return true ;
210
+ }
211
+
212
+ // Early termination if remaining terms cannot reach threshold
213
+ float max_remaining = (num_iterators - i) * sorted_iterators_[i]->BM25ScoreUpperBound ();
214
+ if (sum + max_remaining <= threshold) {
215
+ break ;
216
+ }
217
+ }
218
+
219
+ // Fallback: Use prefix sums if available
220
+ if (prefix_sums_valid_) {
221
+ estimated_pivot = std::lower_bound (score_ub_prefix_sums_.begin (), score_ub_prefix_sums_.end (), threshold) - score_ub_prefix_sums_.begin ();
222
+ estimated_pivot = std::min (estimated_pivot, num_iterators - 1 );
223
+ return true ;
224
+ }
225
+
226
+ return false ; // Fall back to standard method
182
227
}
183
228
184
229
void BlockMaxWandIterator::UpdateScoreThreshold (const float threshold) {
@@ -221,16 +266,15 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
221
266
if (should_sort) {
222
267
next_sort_cnt_++;
223
268
consecutive_skips_ = 0 ; // Reset consecutive skip counter
224
-
269
+
225
270
if (num_iterators > SORT_SKIP_THRESHOLD) {
226
271
// For large keyword sets, use intelligent partial sort
227
272
size_t sort_limit = num_iterators / PARTIAL_SORT_FACTOR + 5 ;
228
273
sort_limit = std::min (sort_limit, num_iterators);
229
274
OptimizedPartialSort (sort_limit);
230
275
} else {
231
276
// For smaller keyword sets, use full sort
232
- std::sort (sorted_iterators_.begin (), sorted_iterators_.end (),
233
- [](const auto &a, const auto &b) { return a->DocID () < b->DocID (); });
277
+ std::sort (sorted_iterators_.begin (), sorted_iterators_.end (), [](const auto &a, const auto &b) { return a->DocID () < b->DocID (); });
234
278
}
235
279
iterations_since_sort_ = 0 ;
236
280
prefix_sums_valid_ = false ; // Invalidate prefix sums after sorting
@@ -250,12 +294,12 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
250
294
bm25_score_upper_bound_ -= sorted_iterators_[i]->BM25ScoreUpperBound ();
251
295
write_pos = i;
252
296
}
253
-
297
+
254
298
if (write_pos < num_iterators) {
255
299
sorted_iterators_.erase (sorted_iterators_.begin () + write_pos, sorted_iterators_.end ());
256
300
num_iterators = sorted_iterators_.size ();
257
301
prefix_sums_valid_ = false ; // Invalidate prefix sums when iterators are removed
258
- indices_valid_ = false ; // Invalidate indices cache
302
+ indices_valid_ = false ; // Invalidate indices cache
259
303
}
260
304
if (bm25_score_upper_bound_ <= threshold_) [[unlikely]] {
261
305
doc_id_ = INVALID_ROWID;
@@ -276,13 +320,14 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
276
320
size_t estimated_limit = std::min (num_iterators, num_iterators / 4 + 10 );
277
321
OptimizedPartialSort (estimated_limit);
278
322
} else {
279
- std::sort (sorted_iterators_.begin (), sorted_iterators_.end (),
280
- [](const auto &a, const auto &b) { return a->DocID () < b->DocID (); });
323
+ std::sort (sorted_iterators_.begin (), sorted_iterators_.end (), [](const auto &a, const auto &b) {
324
+ return a->DocID () < b->DocID ();
325
+ });
281
326
}
282
327
iterations_since_sort_ = 0 ;
283
328
prefix_sums_valid_ = false ;
284
329
}
285
-
330
+
286
331
// Use optimized pivot calculation even for smaller sets
287
332
if (prefix_sums_valid_ || num_iterators > 30 ) {
288
333
pivot = FindPivotOptimized (threshold_);
@@ -313,7 +358,7 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
313
358
float sum_score_bm = 0 .0f ;
314
359
bool found_exhausted_it = false ;
315
360
size_t exhausted_idx = 0 ;
316
-
361
+
317
362
for (size_t i = 0 ; i <= pivot; i++) {
318
363
bool ok = sorted_iterators_[i]->NextShallow (d);
319
364
if (ok) [[likely]] {
@@ -329,7 +374,7 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
329
374
break ;
330
375
}
331
376
}
332
-
377
+
333
378
if (found_exhausted_it) [[unlikely]] {
334
379
// Remove exhausted iterator and update bounds
335
380
bm25_score_upper_bound_ -= sorted_iterators_[exhausted_idx]->BM25ScoreUpperBound ();
0 commit comments