Skip to content

Commit c5130de

Browse files
author
wangqiang
committed
Support gds var length label pruning
1 parent b89dd36 commit c5130de

21 files changed

+324
-53
lines changed

extension/fts/src/function/query_fts_index.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ void QueryFTSAlgorithm::exec(processor::ExecutionContext* executionContext) {
263263

264264
node_id_map_t<ScoreInfo> scores;
265265
auto edgeCompute = std::make_unique<QFTSEdgeCompute>(scores, dfs);
266-
auto compState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute),
266+
auto compState = GDSComputeState(std::move(frontierPair), std::move(edgeCompute), {},
267267
nullptr /* outputNodeMask */);
268268
GDSUtils::runFrontiersUntilConvergence(executionContext, compState, graph, ExtendDirection::FWD,
269269
1 /* maxIters */, QueryFTSAlgorithm::TERM_FREQUENCY_PROP_NAME);

src/binder/query/query_graph_label_analyzer.cpp

Lines changed: 216 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "binder/query/query_graph_label_analyzer.h"
22

33
#include "catalog/catalog.h"
4+
#include "catalog/catalog_entry/node_table_catalog_entry.h"
45
#include "catalog/catalog_entry/rel_table_catalog_entry.h"
56
#include "common/exception/binder.h"
67
#include "common/string_format.h"
@@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
2930
if (queryRel->isRecursive()) {
3031
continue;
3132
}
32-
common::table_id_set_t candidates;
33+
table_id_set_t candidates;
3334
std::unordered_set<std::string> candidateNamesSet;
3435
auto isSrcConnect = *queryRel->getSrcNode() == node;
3536
auto isDstConnect = *queryRel->getDstNode() == node;
@@ -94,49 +95,231 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
9495
}
9596
}
9697

97-
void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const {
98-
if (rel.isRecursive()) {
99-
return;
100-
}
98+
std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::pruneNonRecursiveRel(
99+
const std::vector<TableCatalogEntry*>& relEntries, const table_id_set_t& srcTableIDSet,
100+
const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const {
101+
102+
auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) {
103+
return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID);
104+
};
101105
std::vector<TableCatalogEntry*> prunedEntries;
102-
if (rel.getDirectionType() == RelDirectionType::BOTH) {
103-
table_id_set_t srcBoundTableIDSet;
104-
table_id_set_t dstBoundTableIDSet;
105-
for (auto entry : rel.getSrcNode()->getEntries()) {
106-
srcBoundTableIDSet.insert(entry->getTableID());
106+
for (auto& entry : relEntries) {
107+
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
108+
auto srcTableID = relEntry.getSrcTableID();
109+
auto dstTableID = relEntry.getDstTableID();
110+
auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID);
111+
if (directionType == RelDirectionType::BOTH) {
112+
if (satisfyForwardPruning ||
113+
(dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID))) {
114+
prunedEntries.push_back(entry);
115+
}
116+
} else {
117+
if (satisfyForwardPruning) {
118+
prunedEntries.push_back(entry);
119+
}
107120
}
108-
for (auto entry : rel.getDstNode()->getEntries()) {
109-
dstBoundTableIDSet.insert(entry->getTableID());
121+
}
122+
return prunedEntries;
123+
}
124+
125+
// Gathers the node table IDs found at one end of the given rel tables.
//
// direction   RelDataDirection::FWD collects each rel table's destination node table,
//             RelDataDirection::BWD collects its source node table. Any other value hits
//             KU_UNREACHABLE (only evaluated when relEntries is non-empty).
// relEntries  rel table entries to inspect.
table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction,
    std::vector<TableCatalogEntry*> relEntries) const {
    table_id_set_t endpointTableIDs;
    for (const auto& candidate : relEntries) {
        const auto& relTableEntry = candidate->constCast<RelTableCatalogEntry>();
        switch (direction) {
        case RelDataDirection::FWD: {
            endpointTableIDs.insert(relTableEntry.getDstTableID());
        } break;
        case RelDataDirection::BWD: {
            endpointTableIDs.insert(relTableEntry.getSrcTableID());
        } break;
        default:
            KU_UNREACHABLE;
        }
    }
    return endpointTableIDs;
}
140+
141+
std::pair<std::vector<table_id_set_t>, std::vector<table_id_set_t>>
142+
QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector<TableCatalogEntry*>& relEntries,
143+
const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound,
144+
size_t upperBound, RelDirectionType relDirectionType) const {
145+
// src-->[dst,[rels]]
146+
std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>
147+
stepFromLeftGraph;
148+
std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>
149+
stepFromRightGraph;
150+
for (auto entry : relEntries) {
151+
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
152+
auto srcTableID = relEntry.getSrcTableID();
153+
auto dstTableID = relEntry.getDstTableID();
154+
stepFromLeftGraph[srcTableID][dstTableID].push_back(relEntry.getTableID());
155+
stepFromRightGraph[dstTableID][srcTableID].push_back(relEntry.getTableID());
156+
if (relDirectionType == RelDirectionType::BOTH) {
157+
stepFromLeftGraph[dstTableID][srcTableID].push_back(relEntry.getTableID());
158+
stepFromRightGraph[srcTableID][dstTableID].push_back(relEntry.getTableID());
159+
}
160+
}
161+
auto stepFromLeft =
162+
pruneRecursiveRel(stepFromLeftGraph, srcTableIDSet, dstTableIDSet, lowerBound, upperBound);
163+
auto stepFromRight =
164+
pruneRecursiveRel(stepFromRightGraph, dstTableIDSet, srcTableIDSet, lowerBound, upperBound);
165+
return {stepFromLeft, stepFromRight};
166+
}
167+
168+
std::vector<table_id_set_t> QueryGraphLabelAnalyzer::pruneRecursiveRel(
169+
const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& graph,
170+
const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound,
171+
size_t upperBound) const {
172+
173+
// There may be multiple rel type between src and dst
174+
using PATH = std::vector<table_id_vector_t>;
175+
std::vector<std::vector<PATH>> paths(upperBound + 1);
176+
177+
std::function<void(table_id_t, size_t, std::vector<table_id_vector_t>&)> dfs;
178+
dfs = [&](table_id_t curTableID, size_t currentLength, std::vector<table_id_vector_t>& path) {
179+
if (currentLength >= lowerBound && currentLength <= upperBound &&
180+
endTableIDSet.contains(curTableID)) {
181+
paths[currentLength].push_back(path);
182+
}
183+
if (currentLength >= upperBound) {
184+
return;
185+
}
186+
if (graph.contains(curTableID)) {
187+
for (const auto& [neighbor, relTableID] : graph.at(curTableID)) {
188+
path.push_back(relTableID);
189+
dfs(neighbor, currentLength + 1, path);
190+
path.pop_back();
120191
}
121192
}
122-
} else {
123-
auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet();
124-
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
125-
for (auto& entry : rel.getEntries()) {
126-
auto& relEntry = entry->constCast<RelTableCatalogEntry>();
127-
auto srcTableID = relEntry.getSrcTableID();
128-
auto dstTableID = relEntry.getDstTableID();
129-
if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) {
130-
continue;
193+
};
194+
for (auto startTableID : startTableIDSet) {
195+
std::vector<table_id_vector_t> path;
196+
dfs(startTableID, 0, path);
197+
}
198+
// merge
199+
std::vector<table_id_set_t> stepActiveTableIDs(upperBound);
200+
for (auto i = lowerBound; i <= upperBound; ++i) {
201+
for (auto path : paths[i]) {
202+
for (auto j = 0u; j < i; ++j) {
203+
auto rels = path[j];
204+
stepActiveTableIDs[j].insert(rels.begin(), rels.end());
131205
}
132-
prunedEntries.push_back(entry);
133206
}
134207
}
135-
rel.setEntries(prunedEntries);
208+
return stepActiveTableIDs;
209+
}
210+
211+
// Resolves each table ID in tableIDs to its catalog entry via the analyzer's catalog and
// transaction. The output follows the iteration order of tableIDs.
std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::getTableCatalogEntries(
    table_id_set_t tableIDs) const {
    std::vector<TableCatalogEntry*> resolvedEntries;
    resolvedEntries.reserve(tableIDs.size());
    for (const auto& id : tableIDs) {
        resolvedEntries.push_back(catalog->getTableCatalogEntry(tx, id));
    }
    return resolvedEntries;
}
219+
220+
std::vector<table_id_t> QueryGraphLabelAnalyzer::getNodeTableIDs() const {
221+
std::vector<table_id_t> nodeTableIDs;
222+
for (auto node_table_entry : catalog->getNodeTableEntries(tx)) {
223+
nodeTableIDs.push_back(node_table_entry->getTableID());
224+
}
225+
return nodeTableIDs;
226+
}
227+
228+
std::unordered_set<TableCatalogEntry*> QueryGraphLabelAnalyzer::mergeTableIDs(
229+
const std::vector<table_id_set_t>& v1, const std::vector<table_id_set_t>& v2) const {
230+
std::unordered_set<table_id_t> temp;
231+
for (auto tableIDs : v1) {
232+
temp.insert(tableIDs.begin(), tableIDs.end());
233+
}
234+
for (auto tableIDs : v2) {
235+
temp.insert(tableIDs.begin(), tableIDs.end());
236+
}
237+
std::unordered_set<TableCatalogEntry*> ans;
238+
for (table_id_t tableID : temp) {
239+
ans.emplace(catalog->getTableCatalogEntry(tx, tableID));
240+
}
241+
return ans;
242+
}
243+
244+
static std::vector<catalog::TableCatalogEntry*> intersectEntries(
245+
std::vector<catalog::TableCatalogEntry*> v1, std::vector<catalog::TableCatalogEntry*> v2) {
246+
std::sort(v1.begin(), v1.end());
247+
std::sort(v2.begin(), v2.end());
248+
std::vector<catalog::TableCatalogEntry*> intersection;
249+
std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(),
250+
std::back_inserter(intersection));
251+
return intersection;
252+
}
253+
254+
// Tells whether v1 and v2 hold the same entries regardless of order.
// Both vectors are taken by value, sorted by table ID, and then compared element by element
// (pointer equality), so vectors of different length are never considered equal.
static bool isSameTableCatalogEntryVector(std::vector<TableCatalogEntry*> v1,
    std::vector<TableCatalogEntry*> v2) {
    const auto byTableID = [](TableCatalogEntry* lhs, TableCatalogEntry* rhs) {
        return lhs->getTableID() < rhs->getTableID();
    };
    std::sort(v1.begin(), v1.end(), byTableID);
    std::sort(v2.begin(), v2.end(), byTableID);
    return v1 == v2;
}
263+
264+
void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const {
265+
auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet();
266+
auto dstTableIDSet = rel.getDstNode()->getTableIDsSet();
267+
if (rel.isRecursive()) {
268+
auto nodeTableIDs = getNodeTableIDs();
269+
// there is no label on both sides
270+
if (rel.getUpperBound() == 0 || (srcTableIDSet.size() == dstTableIDSet.size() &&
271+
dstTableIDSet.size() == nodeTableIDs.size())) {
272+
return;
273+
}
274+
275+
auto [stepFromLeftTableIDs, stepFromRightTableIDs] =
276+
pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(),
277+
rel.getUpperBound(), rel.getDirectionType());
278+
auto recursiveInfo = rel.getRecursiveInfo();
279+
recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs;
280+
recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs;
281+
// todo we need reset rel entries?
282+
auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs);
283+
std::vector<TableCatalogEntry*> newRelEntries{temp.begin(), temp.end()};
284+
if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) {
285+
rel.setEntries(newRelEntries);
286+
recursiveInfo->rel->setEntries(newRelEntries);
287+
// update src&dst entries
288+
auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD,
289+
getTableCatalogEntries(stepFromLeftTableIDs.front()));
290+
291+
std::unordered_set<table_id_t> backwardRelNodes;
292+
for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) {
293+
if (i == 0) {
294+
continue;
295+
}
296+
const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD,
297+
getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1)));
298+
backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end());
299+
}
300+
301+
if (rel.getDirectionType() == RelDirectionType::BOTH) {
302+
forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end());
303+
backwardRelNodes = forwardRelNodes;
304+
}
305+
306+
auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(),
307+
getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()}));
308+
rel.getSrcNode()->setEntries(newSrcNodeEntries);
309+
310+
auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(),
311+
getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()}));
312+
rel.getDstNode()->setEntries(newDstNodeEntries);
313+
}
314+
} else {
315+
auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet,
316+
rel.getDirectionType());
317+
rel.setEntries(prunedEntries);
318+
}
136319
// Note the pruning for node should guarantee the following exception won't be triggered.
137320
// For safety (and consistency) reason, we still write the check but skip coverage check.
138321
// LCOV_EXCL_START
139-
if (prunedEntries.empty()) {
322+
if (rel.getEntries().empty()) {
140323
if (throwOnViolate) {
141324
throw BinderException(stringFormat("Cannot find a label for relationship {} that "
142325
"connects to all of its neighbour nodes.",

src/function/gds/all_shortest_paths.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ class AllSPDestinationsAlgorithm final : public SPAlgorithm {
249249
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(frontier);
250250
auto edgeCompute =
251251
std::make_unique<AllSPDestinationsEdgeCompute>(frontierPair.get(), multiplicities);
252-
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
252+
return RJCompState(std::move(frontierPair), std::move(edgeCompute), {},
253253
sharedState->getOutputNodeMaskMap(), std::move(output), std::move(outputWriter));
254254
}
255255
};
@@ -286,7 +286,7 @@ class AllSPPathsAlgorithm final : public SPAlgorithm {
286286
auto frontierPair = std::make_unique<SinglePathLengthsFrontierPair>(frontier);
287287
auto edgeCompute =
288288
std::make_unique<AllSPPathsEdgeCompute>(frontierPair.get(), output->bfsGraph.get());
289-
return RJCompState(std::move(frontierPair), std::move(edgeCompute),
289+
return RJCompState(std::move(frontierPair), std::move(edgeCompute), {},
290290
sharedState->getOutputNodeMaskMap(), std::move(output), std::move(outputWriter));
291291
}
292292
};

src/function/gds/degrees.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ struct DegreesUtils {
7575
frontierPair->setActiveNodesForNextIter();
7676
frontierPair->getNextSparseFrontier().disable();
7777
auto ec = std::make_unique<DegreeEdgeCompute>(degrees);
78-
auto computeState = GDSComputeState(std::move(frontierPair), std::move(ec), nullptr);
78+
auto computeState = GDSComputeState(std::move(frontierPair), std::move(ec), {}, nullptr);
7979
GDSUtils::runFrontiersUntilConvergence(context, computeState, graph, direction,
8080
1 /* maxIters */);
8181
}

src/function/gds/gds_utils.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ namespace function {
1919

2020
// Takes ownership of the frontier pair and the edge compute, and stores the per-step active rel
// table IDs together with the (non-owning) output node mask pointer.
// stepActiveRelTableIDs may be passed empty; callers in this commit pass `{}` when they have no
// per-step restriction.
GDSComputeState::GDSComputeState(std::unique_ptr<function::FrontierPair> frontierPair,
    std::unique_ptr<function::EdgeCompute> edgeCompute,
    std::vector<table_id_set_t> stepActiveRelTableIDs,
    processor::NodeOffsetMaskMap* outputNodeMask)
    : frontierPair{std::move(frontierPair)},
      edgeCompute{std::move(edgeCompute)},
      stepActiveRelTableIDs{std::move(stepActiveRelTableIDs)},
      outputNodeMask{outputNodeMask} {}
2525

2626
GDSComputeState::~GDSComputeState() = default;
2727

@@ -30,6 +30,20 @@ void GDSComputeState::beginFrontierComputeBetweenTables(common::table_id_t currT
3030
frontierPair->beginFrontierComputeBetweenTables(currTableID, nextTableID);
3131
}
3232

33+
// Returns the rel table IDs that are active at the given step.
//
// index  0-based step; values past the last stored step fall back to the last stored set.
// graph  consulted only when no per-step sets were supplied: in that case a single set holding
//        all of the graph's rel table IDs is stored in stepActiveRelTableIDs and used for every
//        step from then on.
common::table_id_set_t GDSComputeState::getActiveRelTableIDs(size_t index, Graph* graph) {
    if (stepActiveRelTableIDs.empty()) {
        // Lazily install the "everything is active" default.
        auto allRelTableIDs = graph->getRelTableIDs();
        common::table_id_set_t allActive;
        allActive.insert(allRelTableIDs.begin(), allRelTableIDs.end());
        stepActiveRelTableIDs.push_back(allActive);
    }
    const size_t lastStep = stepActiveRelTableIDs.size() - 1;
    return stepActiveRelTableIDs[index < lastStep ? index : lastStep];
}
46+
3347
static uint64_t getNumThreads(processor::ExecutionContext& context) {
3448
return context.clientContext->getCurrentSetting(main::ThreadsSetting::name)
3549
.getValue<uint64_t>();
@@ -77,7 +91,11 @@ void GDSUtils::runFrontiersUntilConvergence(processor::ExecutionContext* context
7791
compState.edgeCompute->terminate(*compState.outputNodeMask)) {
7892
break;
7993
}
80-
for (auto& [fromEntry, toEntry, relEntry] : graph->getRelFromToEntryInfos()) {
94+
95+
auto activeRelTableIDs =
96+
compState.getActiveRelTableIDs(frontierPair->getCurrentIter() - 1, graph);
97+
for (auto& [fromEntry, toEntry, relEntry] :
98+
graph->getRelFromToEntryInfos(activeRelTableIDs)) {
8199
switch (extendDirection) {
82100
case ExtendDirection::FWD: {
83101
compState.beginFrontierComputeBetweenTables(fromEntry->getTableID(),

src/function/gds/k_core_decomposition.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ class KCoreDecomposition final : public GDSAlgorithm {
253253
// Compute Core values
254254
auto removeVertexEdgeCompute = std::make_unique<RemoveVertexEdgeCompute>(&degrees);
255255
auto computeState = GDSComputeState(std::move(frontierPair),
256-
std::move(removeVertexEdgeCompute), sharedState->getOutputNodeMaskMap());
256+
std::move(removeVertexEdgeCompute), {}, sharedState->getOutputNodeMaskMap());
257257
auto coreValue = 0u;
258258
auto numNodes = graph->getNumNodes(clientContext->getTransaction());
259259
auto numNodesComputed = 0u;

src/function/gds/page_rank.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ class PageRank final : public GDSAlgorithm {
275275
std::make_unique<DoublePathLengthsFrontierPair>(currentFrontier, nextFrontier);
276276
frontierPair->setActiveNodesForNextIter();
277277
frontierPair->getNextSparseFrontier().disable();
278-
auto computeState =
279-
GDSComputeState(std::move(frontierPair), nullptr, sharedState->getOutputNodeMaskMap());
278+
auto computeState = GDSComputeState(std::move(frontierPair), nullptr, {},
279+
sharedState->getOutputNodeMaskMap());
280280
auto pNextUpdateConstant = (1 - pageRankBindData->dampingFactor) * ((double)1 / numNodes);
281281
while (currentIter < pageRankBindData->maxIteration) {
282282
auto ec = std::make_unique<PNextUpdateEdgeCompute>(&degrees, pCurrent, pNext);

0 commit comments

Comments
 (0)