@@ -82,18 +82,27 @@ BlockMaxWandIterator::BlockMaxWandIterator(std::vector<std::unique_ptr<DocIterat
82
82
UpdateScoreUpperBoundPrefixSums ();
83
83
}
84
84
85
- // Highly optimized pivot calculation with multiple strategies
85
+ // Highly optimized pivot calculation with multiple strategies and improved safety
86
86
size_t BlockMaxWandIterator::FindPivotOptimized (float threshold) {
87
87
const size_t num_iterators = sorted_iterators_.size ();
88
88
89
+ // Safety check for empty iterator list
90
+ if (num_iterators == 0 ) {
91
+ return 0 ;
92
+ }
93
+
89
94
// For very large sets, try fast estimation first
90
95
if (num_iterators > FAST_PIVOT_THRESHOLD) {
91
96
size_t estimated_pivot;
92
97
if (TryFastPivotEstimation (threshold, estimated_pivot)) {
93
- return estimated_pivot;
98
+ // Double-check the estimated pivot for safety
99
+ if (estimated_pivot < num_iterators) {
100
+ return estimated_pivot;
101
+ }
94
102
}
95
103
}
96
104
105
+ // Ensure prefix sums are valid
97
106
if (!prefix_sums_valid_) {
98
107
UpdateScoreUpperBoundPrefixSums ();
99
108
}
@@ -108,7 +117,9 @@ size_t BlockMaxWandIterator::FindPivotOptimized(float threshold) {
108
117
left = mid + 1 ;
109
118
}
110
119
}
111
- return left;
120
+
121
+ // Safety check - ensure we don't return an out-of-bounds pivot
122
+ return std::min (left, num_iterators);
112
123
}
113
124
114
125
void BlockMaxWandIterator::UpdateScoreUpperBoundPrefixSums () {
@@ -160,25 +171,64 @@ void BlockMaxWandIterator::OptimizedPartialSort(size_t limit) {
160
171
}
161
172
}
162
173
163
- // Fast pivot estimation for very large keyword sets
174
+ // Aggressively optimized pivot estimation with safety guarantees
164
175
bool BlockMaxWandIterator::TryFastPivotEstimation (float threshold, size_t & estimated_pivot) {
165
176
const size_t num_iterators = sorted_iterators_.size ();
166
-
167
- // Sample-based estimation: check every Nth iterator
168
- const size_t sample_step = std::max (1UL , num_iterators / 20 ); // Sample ~5% of iterators
169
- float accumulated_score = 0 .0f ;
170
-
171
- for (size_t i = 0 ; i < num_iterators; i += sample_step) {
172
- accumulated_score += sorted_iterators_[i]->BM25ScoreUpperBound ();
173
- if (accumulated_score > threshold) {
174
- // Estimate the actual pivot position
175
- estimated_pivot = std::min (i + sample_step, num_iterators - 1 );
177
+ if (num_iterators <= FAST_PIVOT_THRESHOLD) {
178
+ return false ;
179
+ }
180
+
181
+ // Level 1: Top 5 high-score terms check (hot path)
182
+ constexpr size_t TOP_K = 5 ;
183
+ float sum = 0 .0f ;
184
+ for (size_t i = 0 ; i < std::min (TOP_K, num_iterators); ++i) {
185
+ sum += sorted_iterators_[i]->BM25ScoreUpperBound ();
186
+ if (sum > threshold) {
187
+ estimated_pivot = i;
176
188
return true ;
177
189
}
178
190
}
179
-
180
- estimated_pivot = num_iterators;
181
- return false ;
191
+
192
+ // Level 2: Strided sampling with early termination
193
+ const size_t stride = std::max<size_t >(1 , num_iterators / 20 ); // 5% sampling
194
+ float prev_sum = sum;
195
+ size_t prev_i = TOP_K;
196
+
197
+ for (size_t i = TOP_K; i < num_iterators; i += stride) {
198
+ prev_sum = sum;
199
+ prev_i = i;
200
+ sum += sorted_iterators_[i]->BM25ScoreUpperBound ();
201
+
202
+ if (sum > threshold) {
203
+ // Linear search backward to find exact pivot
204
+ for (size_t j = prev_i; j < i; ++j) {
205
+ prev_sum += sorted_iterators_[j]->BM25ScoreUpperBound ();
206
+ if (prev_sum > threshold) {
207
+ estimated_pivot = j;
208
+ return true ;
209
+ }
210
+ }
211
+ estimated_pivot = i;
212
+ return true ;
213
+ }
214
+
215
+ // Early termination if remaining terms cannot reach threshold
216
+ float max_remaining = (num_iterators - i) * sorted_iterators_[i]->BM25ScoreUpperBound ();
217
+ if (sum + max_remaining <= threshold) {
218
+ break ;
219
+ }
220
+ }
221
+
222
+ // Fallback: Use prefix sums if available
223
+ if (prefix_sums_valid_) {
224
+ estimated_pivot = std::lower_bound (score_ub_prefix_sums_.begin (),
225
+ score_ub_prefix_sums_.end (),
226
+ threshold) - score_ub_prefix_sums_.begin ();
227
+ estimated_pivot = std::min (estimated_pivot, num_iterators - 1 );
228
+ return true ;
229
+ }
230
+
231
+ return false ; // Fall back to standard method
182
232
}
183
233
184
234
void BlockMaxWandIterator::UpdateScoreThreshold (const float threshold) {
0 commit comments