Skip to content

Commit 61bfed2

Browse files
committed
fix
1 parent 5fd03c2 commit 61bfed2

File tree

2 files changed

+95
-50
lines changed

2 files changed

+95
-50
lines changed

src/storage/invertedindex/search/blockmax_wand_iterator.cppm

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ private:
5050
void UpdateScoreUpperBoundPrefixSums();
5151
bool ShouldSkipSort() const;
5252
void OptimizedPartialSort(size_t limit);
53-
bool TryFastPivotEstimation(float threshold, size_t& estimated_pivot);
53+
bool TryFastPivotEstimation(float threshold, size_t &estimated_pivot);
5454

5555
// block max info
5656
RowID common_block_min_possible_doc_id_{}; // not always exist
@@ -61,17 +61,17 @@ private:
6161
size_t pivot_;
6262

6363
// Enhanced optimization for many keywords
64-
static constexpr u32 SORT_SKIP_THRESHOLD = 15; // Reduced threshold for better balance
65-
static constexpr u32 LAZY_SORT_INTERVAL = 3; // More frequent sorting for accuracy
64+
static constexpr u32 SORT_SKIP_THRESHOLD = 15; // Reduced threshold for better balance
65+
static constexpr u32 LAZY_SORT_INTERVAL = 3; // More frequent sorting for accuracy
6666
static constexpr u32 FAST_PIVOT_THRESHOLD = 50; // Use fast estimation for very large sets
6767
static constexpr u32 PARTIAL_SORT_FACTOR = 3; // Sort only top 1/3 for large sets
68-
69-
std::vector<f32> score_ub_prefix_sums_; // Prefix sums for fast pivot calculation
70-
std::vector<size_t> iterator_indices_; // Cached indices for avoiding pointer chasing
68+
69+
std::vector<f32> score_ub_prefix_sums_; // Prefix sums for fast pivot calculation
70+
std::vector<size_t> iterator_indices_; // Cached indices for avoiding pointer chasing
7171
bool prefix_sums_valid_ = false;
7272
bool indices_valid_ = false;
7373
u32 iterations_since_sort_ = 0;
74-
u32 consecutive_skips_ = 0; // Track consecutive sort skips
74+
u32 consecutive_skips_ = 0; // Track consecutive sort skips
7575

7676
// bm25 score cache
7777
bool bm25_score_cached_ = false;

src/storage/invertedindex/search/blockmax_wand_iterator_impl.cpp

Lines changed: 88 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import :logger;
2626
import :infinity_exception;
2727

2828
import third_party;
29-
3029
import internal_types;
3130

3231
namespace infinity {
@@ -72,28 +71,37 @@ BlockMaxWandIterator::BlockMaxWandIterator(std::vector<std::unique_ptr<DocIterat
7271
// Initialize optimization structures for many keywords
7372
score_ub_prefix_sums_.reserve(num_iterators + 1);
7473
iterator_indices_.reserve(num_iterators);
75-
74+
7675
// Pre-populate iterator indices for cache-friendly access
7776
for (size_t i = 0; i < num_iterators; i++) {
7877
iterator_indices_.push_back(i);
7978
}
8079
indices_valid_ = true;
81-
80+
8281
UpdateScoreUpperBoundPrefixSums();
8382
}
8483

85-
// Highly optimized pivot calculation with multiple strategies
84+
// Highly optimized pivot calculation with multiple strategies and improved safety
8685
size_t BlockMaxWandIterator::FindPivotOptimized(float threshold) {
8786
const size_t num_iterators = sorted_iterators_.size();
88-
87+
88+
// Safety check for empty iterator list
89+
if (num_iterators == 0) {
90+
return 0;
91+
}
92+
8993
// For very large sets, try fast estimation first
9094
if (num_iterators > FAST_PIVOT_THRESHOLD) {
9195
size_t estimated_pivot;
9296
if (TryFastPivotEstimation(threshold, estimated_pivot)) {
93-
return estimated_pivot;
97+
// Double-check the estimated pivot for safety
98+
if (estimated_pivot < num_iterators) {
99+
return estimated_pivot;
100+
}
94101
}
95102
}
96-
103+
104+
// Ensure prefix sums are valid
97105
if (!prefix_sums_valid_) {
98106
UpdateScoreUpperBoundPrefixSums();
99107
}
@@ -108,19 +116,21 @@ size_t BlockMaxWandIterator::FindPivotOptimized(float threshold) {
108116
left = mid + 1;
109117
}
110118
}
111-
return left;
119+
120+
// Clamp the result to at most num_iterators (note: that is one past the last valid iterator index)
121+
return std::min(left, num_iterators);
112122
}
113123

114124
void BlockMaxWandIterator::UpdateScoreUpperBoundPrefixSums() {
115125
const size_t num_iterators = sorted_iterators_.size();
116-
126+
117127
// Resize only if necessary to avoid memory allocations
118128
if (score_ub_prefix_sums_.size() != num_iterators + 1) {
119129
score_ub_prefix_sums_.resize(num_iterators + 1);
120130
}
121-
131+
122132
score_ub_prefix_sums_[0] = 0.0f;
123-
133+
124134
// Sequential running-sum accumulation of each iterator's BM25 score upper bound
125135
for (size_t i = 0; i < num_iterators; i++) {
126136
score_ub_prefix_sums_[i + 1] = score_ub_prefix_sums_[i] + sorted_iterators_[i]->BM25ScoreUpperBound();
@@ -142,43 +152,78 @@ bool BlockMaxWandIterator::ShouldSkipSort() const {
142152
if (consecutive_skips_ > 10) {
143153
adaptive_interval = std::min(adaptive_interval + 2, 8u);
144154
}
145-
155+
146156
return iterations_since_sort_ < adaptive_interval;
147157
}
148158

149159
// Optimized partial sort that only sorts what we need
150160
void BlockMaxWandIterator::OptimizedPartialSort(size_t limit) {
151161
const size_t num_iterators = sorted_iterators_.size();
152162
if (limit >= num_iterators) {
153-
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(),
154-
[](const auto &a, const auto &b) { return a->DocID() < b->DocID(); });
163+
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(), [](const auto &a, const auto &b) { return a->DocID() < b->DocID(); });
155164
} else {
156-
std::partial_sort(sorted_iterators_.begin(),
157-
sorted_iterators_.begin() + limit,
158-
sorted_iterators_.end(),
159-
[](const auto &a, const auto &b) { return a->DocID() < b->DocID(); });
165+
std::partial_sort(sorted_iterators_.begin(), sorted_iterators_.begin() + limit, sorted_iterators_.end(), [](const auto &a, const auto &b) {
166+
return a->DocID() < b->DocID();
167+
});
160168
}
161169
}
162170

163-
// Fast pivot estimation for very large keyword sets
164-
bool BlockMaxWandIterator::TryFastPivotEstimation(float threshold, size_t& estimated_pivot) {
171+
// Aggressively optimized pivot estimation with safety guarantees
172+
bool BlockMaxWandIterator::TryFastPivotEstimation(float threshold, size_t &estimated_pivot) {
165173
const size_t num_iterators = sorted_iterators_.size();
166-
167-
// Sample-based estimation: check every Nth iterator
168-
const size_t sample_step = std::max(1UL, num_iterators / 20); // Sample ~5% of iterators
169-
float accumulated_score = 0.0f;
170-
171-
for (size_t i = 0; i < num_iterators; i += sample_step) {
172-
accumulated_score += sorted_iterators_[i]->BM25ScoreUpperBound();
173-
if (accumulated_score > threshold) {
174-
// Estimate the actual pivot position
175-
estimated_pivot = std::min(i + sample_step, num_iterators - 1);
174+
if (num_iterators <= FAST_PIVOT_THRESHOLD) {
175+
return false;
176+
}
177+
178+
// Level 1: check the first 5 iterators in their current order (hot path)
179+
constexpr size_t TOP_K = 5;
180+
float sum = 0.0f;
181+
for (size_t i = 0; i < std::min(TOP_K, num_iterators); ++i) {
182+
sum += sorted_iterators_[i]->BM25ScoreUpperBound();
183+
if (sum > threshold) {
184+
estimated_pivot = i;
176185
return true;
177186
}
178187
}
179-
180-
estimated_pivot = num_iterators;
181-
return false;
188+
189+
// Level 2: Strided sampling with early termination
190+
const size_t stride = std::max<size_t>(1, num_iterators / 20); // 5% sampling
191+
float prev_sum = sum;
192+
size_t prev_i = TOP_K;
193+
194+
for (size_t i = TOP_K; i < num_iterators; i += stride) {
195+
prev_sum = sum;
196+
prev_i = i;
197+
sum += sorted_iterators_[i]->BM25ScoreUpperBound();
198+
199+
if (sum > threshold) {
200+
// Forward scan over [prev_i, i) to refine the pivot; NOTE(review): prev_i was just set to i above, so this range is empty
201+
for (size_t j = prev_i; j < i; ++j) {
202+
prev_sum += sorted_iterators_[j]->BM25ScoreUpperBound();
203+
if (prev_sum > threshold) {
204+
estimated_pivot = j;
205+
return true;
206+
}
207+
}
208+
estimated_pivot = i;
209+
return true;
210+
}
211+
212+
// Early termination: assumes no remaining iterator has a larger upper bound than this one; NOTE(review): sort order is by DocID, confirm
213+
float max_remaining = (num_iterators - i) * sorted_iterators_[i]->BM25ScoreUpperBound();
214+
if (sum + max_remaining <= threshold) {
215+
break;
216+
}
217+
}
218+
219+
// Fallback: Use prefix sums if available
220+
if (prefix_sums_valid_) {
221+
estimated_pivot = std::lower_bound(score_ub_prefix_sums_.begin(), score_ub_prefix_sums_.end(), threshold) - score_ub_prefix_sums_.begin();
222+
estimated_pivot = std::min(estimated_pivot, num_iterators - 1);
223+
return true;
224+
}
225+
226+
return false; // Fall back to standard method
182227
}
183228

184229
void BlockMaxWandIterator::UpdateScoreThreshold(const float threshold) {
@@ -221,16 +266,15 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
221266
if (should_sort) {
222267
next_sort_cnt_++;
223268
consecutive_skips_ = 0; // Reset consecutive skip counter
224-
269+
225270
if (num_iterators > SORT_SKIP_THRESHOLD) {
226271
// For large keyword sets, use intelligent partial sort
227272
size_t sort_limit = num_iterators / PARTIAL_SORT_FACTOR + 5;
228273
sort_limit = std::min(sort_limit, num_iterators);
229274
OptimizedPartialSort(sort_limit);
230275
} else {
231276
// For smaller keyword sets, use full sort
232-
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(),
233-
[](const auto &a, const auto &b) { return a->DocID() < b->DocID(); });
277+
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(), [](const auto &a, const auto &b) { return a->DocID() < b->DocID(); });
234278
}
235279
iterations_since_sort_ = 0;
236280
prefix_sums_valid_ = false; // Invalidate prefix sums after sorting
@@ -250,12 +294,12 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
250294
bm25_score_upper_bound_ -= sorted_iterators_[i]->BM25ScoreUpperBound();
251295
write_pos = i;
252296
}
253-
297+
254298
if (write_pos < num_iterators) {
255299
sorted_iterators_.erase(sorted_iterators_.begin() + write_pos, sorted_iterators_.end());
256300
num_iterators = sorted_iterators_.size();
257301
prefix_sums_valid_ = false; // Invalidate prefix sums when iterators are removed
258-
indices_valid_ = false; // Invalidate indices cache
302+
indices_valid_ = false; // Invalidate indices cache
259303
}
260304
if (bm25_score_upper_bound_ <= threshold_) [[unlikely]] {
261305
doc_id_ = INVALID_ROWID;
@@ -276,13 +320,14 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
276320
size_t estimated_limit = std::min(num_iterators, num_iterators / 4 + 10);
277321
OptimizedPartialSort(estimated_limit);
278322
} else {
279-
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(),
280-
[](const auto &a, const auto &b) { return a->DocID() < b->DocID(); });
323+
std::sort(sorted_iterators_.begin(), sorted_iterators_.end(), [](const auto &a, const auto &b) {
324+
return a->DocID() < b->DocID();
325+
});
281326
}
282327
iterations_since_sort_ = 0;
283328
prefix_sums_valid_ = false;
284329
}
285-
330+
286331
// Use optimized pivot calculation even for smaller sets
287332
if (prefix_sums_valid_ || num_iterators > 30) {
288333
pivot = FindPivotOptimized(threshold_);
@@ -313,7 +358,7 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
313358
float sum_score_bm = 0.0f;
314359
bool found_exhausted_it = false;
315360
size_t exhausted_idx = 0;
316-
361+
317362
for (size_t i = 0; i <= pivot; i++) {
318363
bool ok = sorted_iterators_[i]->NextShallow(d);
319364
if (ok) [[likely]] {
@@ -329,7 +374,7 @@ bool BlockMaxWandIterator::Next(RowID doc_id) {
329374
break;
330375
}
331376
}
332-
377+
333378
if (found_exhausted_it) [[unlikely]] {
334379
// Remove exhausted iterator and update bounds
335380
bm25_score_upper_bound_ -= sorted_iterators_[exhausted_idx]->BM25ScoreUpperBound();

0 commit comments

Comments
 (0)