Skip to content

Commit a2cbcbe

Browse files
committed
Allow reordering of joins for regular match subquery
Temporarily disable reorder for hash join. Re-enable hash-join reorder. Undo test change. Correctly populate schema cardinality multiplier in cardinality updater. Temporarily raise buffer pool size for lsqb test.
1 parent aa27c83 commit a2cbcbe

File tree

7 files changed

+117
-4
lines changed

7 files changed

+117
-4
lines changed

src/include/optimizer/cardinality_updater.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class CardinalityUpdater : public LogicalOperatorVisitor {
2727
void visitOperatorDefault(planner::LogicalOperator* op);
2828
void visitScanNodeTable(planner::LogicalOperator* op) override;
2929
void visitExtend(planner::LogicalOperator* op) override;
30+
void visitRecursiveExtend(planner::LogicalOperator* op) override;
3031
void visitHashJoin(planner::LogicalOperator* op) override;
3132
void visitCrossProduct(planner::LogicalOperator* op) override;
3233
void visitIntersect(planner::LogicalOperator* op) override;

src/optimizer/cardinality_updater.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "planner/join_order/cardinality_estimator.h"
44
#include "planner/operator/extend/logical_extend.h"
5+
#include "planner/operator/extend/logical_recursive_extend.h"
56
#include "planner/operator/logical_aggregate.h"
67
#include "planner/operator/logical_filter.h"
78
#include "planner/operator/logical_flatten.h"
@@ -19,6 +20,8 @@ void CardinalityUpdater::visitOperator(planner::LogicalOperator* op) {
1920
for (auto i = 0u; i < op->getNumChildren(); ++i) {
2021
visitOperator(op->getChild(i).get());
2122
}
23+
// we need to recompute the cardinality multipliers for each factorized group
24+
op->computeFactorizedSchema();
2225
visitOperatorSwitchWithDefault(op);
2326
}
2427

@@ -83,6 +86,18 @@ void CardinalityUpdater::visitExtend(planner::LogicalOperator* op) {
8386
const auto extensionRate = cardinalityEstimator.getExtensionRate(*extend.getRel(),
8487
*extend.getBoundNode(), transaction);
8588
extend.setCardinality(cardinalityEstimator.estimateExtend(extensionRate, *op->getChild(0)));
89+
auto group = extend.getSchema()->getGroup(extend.getNbrNode()->getInternalID());
90+
group->setMultiplier(extensionRate);
91+
}
92+
93+
// Updates the cardinality of a LogicalRecursiveExtend operator and stores the extension rate as
// the multiplier of the schema group that holds the neighbour node's internal ID.
void CardinalityUpdater::visitRecursiveExtend(planner::LogicalOperator* op) {
// transaction is forwarded to getExtensionRate below, so assert it is set first.
94+
KU_ASSERT(transaction);
95+
auto& extend = op->cast<planner::LogicalRecursiveExtend&>();
// Extension rate for this rel starting from the bound node, as reported by the estimator.
96+
const auto extensionRate = cardinalityEstimator.getExtensionRate(*extend.getRel(),
97+
*extend.getBoundNode(), transaction);
// The new cardinality is estimated from the extension rate and the operator's first child.
98+
extend.setCardinality(cardinalityEstimator.estimateExtend(extensionRate, *op->getChild(0)));
// Record the same rate on the neighbour node's group so the schema multiplier matches.
99+
auto group = extend.getSchema()->getGroup(extend.getNbrNode()->getInternalID());
100+
group->setMultiplier(extensionRate);
86101
}
87102

88103
void CardinalityUpdater::visitHashJoin(planner::LogicalOperator* op) {

src/optimizer/factorization_rewriter.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "optimizer/factorization_rewriter.h"
22

33
#include "binder/expression_visitor.h"
4+
#include "planner/operator/extend/logical_extend.h"
45
#include "planner/operator/extend/logical_recursive_extend.h"
56
#include "planner/operator/factorization/flatten_resolver.h"
67
#include "planner/operator/logical_accumulate.h"

src/planner/plan/plan_subquery.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "binder/expression/property_expression.h"
33
#include "binder/expression/subquery_expression.h"
44
#include "binder/expression_visitor.h"
5+
#include "planner/join_order/cost_model.h"
56
#include "planner/operator/factorization/flatten_resolver.h"
67
#include "planner/planner.h"
78

@@ -100,6 +101,17 @@ void Planner::planOptionalMatch(const QueryGraphCollection& queryGraphCollection
100101
}
101102
}
102103

104+
// Picks the argument order for joining leftPlan and rightPlan, then appends the join.
//
// estimateJoinCostFunc is called once per ordering: (leftPlan, rightPlan) and
// (rightPlan, leftPlan). appendJoinFunc is then invoked with the cheaper ordering as its first
// two arguments. In both branches leftPlan is passed as the third argument, so the caller's
// leftPlan is what appendJoinFunc receives as its last parameter.
template<std::invocable<LogicalPlan&, LogicalPlan&, LogicalPlan&> AppendJoinFunc,
105+
std::invocable<LogicalPlan&, LogicalPlan&> EstimateJoinCostFunc>
106+
static void planRegularMatchJoinOrder(LogicalPlan& leftPlan, LogicalPlan& rightPlan,
107+
const AppendJoinFunc& appendJoinFunc, const EstimateJoinCostFunc& estimateJoinCostFunc) {
// Compare the cost of both orderings; on a tie (<=) the original (left, right) order is kept.
108+
if (estimateJoinCostFunc(leftPlan, rightPlan) <= estimateJoinCostFunc(rightPlan, leftPlan)) {
109+
appendJoinFunc(leftPlan, rightPlan, leftPlan);
110+
} else {
// The swapped ordering is strictly cheaper; the third argument is still leftPlan.
111+
appendJoinFunc(rightPlan, leftPlan, leftPlan);
112+
}
113+
}
114+
103115
void Planner::planRegularMatch(const QueryGraphCollection& queryGraphCollection,
104116
const expression_vector& predicates, LogicalPlan& leftPlan) {
105117
expression_vector predicatesToPushDown, predicatesToPullUp;
@@ -124,7 +136,15 @@ void Planner::planRegularMatch(const QueryGraphCollection& queryGraphCollection,
124136
if (leftPlan.hasUpdate()) {
125137
appendCrossProduct(*rightPlan, leftPlan, leftPlan);
126138
} else {
127-
appendCrossProduct(leftPlan, *rightPlan, leftPlan);
139+
planRegularMatchJoinOrder(
140+
leftPlan, *rightPlan,
141+
[this](LogicalPlan& leftPlan, LogicalPlan& rightPlan, LogicalPlan& resultPlan) {
142+
appendCrossProduct(leftPlan, rightPlan, resultPlan);
143+
},
144+
[](LogicalPlan&, LogicalPlan& rightPlan) {
145+
// we want to minimize the cardinality of the build plan
146+
return rightPlan.getCardinality();
147+
});
128148
}
129149
} else {
130150
// TODO(Xiyang): there is a question regarding if we want to plan as a correlated subquery
@@ -137,7 +157,16 @@ void Planner::planRegularMatch(const QueryGraphCollection& queryGraphCollection,
137157
if (leftPlan.hasUpdate()) {
138158
appendHashJoin(joinNodeIDs, JoinType::INNER, *rightPlan, leftPlan, leftPlan);
139159
} else {
140-
appendHashJoin(joinNodeIDs, JoinType::INNER, leftPlan, *rightPlan, leftPlan);
160+
planRegularMatchJoinOrder(
161+
leftPlan, *rightPlan,
162+
[this, &joinNodeIDs](LogicalPlan& leftPlan, LogicalPlan& rightPlan,
163+
LogicalPlan& resultPlan) {
164+
appendHashJoin(joinNodeIDs, JoinType::INNER, leftPlan, rightPlan, resultPlan);
165+
},
166+
[&joinNodeIDs](LogicalPlan& leftPlan, LogicalPlan& rightPlan) {
167+
return CostModel::computeHashJoinCost(joinNodeIDs, leftPlan, rightPlan);
168+
});
169+
// appendHashJoin(joinNodeIDs, JoinType::INNER, leftPlan, *rightPlan, leftPlan);
141170
}
142171
}
143172
for (auto& predicate : predicatesToPullUp) {

test/planner/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
1-
add_kuzu_test(planner_tests cardinality_test.cpp)
1+
# planner_tests now lists planner_test.cpp alongside cardinality_test.cpp.
# NOTE(review): add_kuzu_test is defined elsewhere; presumably both files become sources of the
# planner_tests target - confirm against the macro definition.
add_kuzu_test(planner_tests
2+
cardinality_test.cpp
3+
planner_test.cpp)

test/planner/planner_test.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include "graph_test/graph_test.h"
2+
#include "planner/operator/logical_plan_util.h"
3+
#include "test_runner/test_runner.h"
4+
5+
namespace kuzu {
6+
namespace testing {
7+
8+
// Test fixture (derived from DBTest) with helpers for obtaining and inspecting logical plans.
class PlannerTest : public DBTest {
9+
public:
// Input directory for the fixture: the "dataset/tinysnb/" path under the kuzu root.
10+
std::string getInputDir() override {
11+
return TestHelper::appendKuzuRootPath("dataset/tinysnb/");
12+
}
13+
// Returns LogicalPlanUtil::encodeJoin applied to the logical plan obtained for `query`.
14+
std::string getEncodedPlan(const std::string& query) {
15+
return planner::LogicalPlanUtil::encodeJoin(*getRoot(query));
16+
}
// Returns the logical plan that TestRunner::getLogicalPlan yields for `query` using *conn.
// Note that, despite the name, the return value is the LogicalPlan, not a root operator.
17+
std::unique_ptr<planner::LogicalPlan> getRoot(const std::string& query) {
18+
return TestRunner::getLogicalPlan(query, *conn);
19+
}
// Follows child 0 repeatedly from `op` until reaching an operator with no children.
// Returns {parent of that operator, that operator}; the parent is nullptr (the default) when
// `op` itself has no children. Only the first child of each operator is visited.
20+
std::pair<planner::LogicalOperator*, planner::LogicalOperator*> getSource(
21+
planner::LogicalOperator* op, planner::LogicalOperator* parent = nullptr) {
22+
if (op->getNumChildren() == 0) {
23+
return {parent, op};
24+
}
25+
return getSource(op->getChild(0).get(), op);
26+
}
// Returns the first operator on the child-0 chain starting at `op` whose operator type equals
// `type`, or nullptr if the chain ends without a match. Other children are not searched.
27+
planner::LogicalOperator* getOpWithType(planner::LogicalOperator* op,
28+
planner::LogicalOperatorType type) {
29+
if (op->getOperatorType() == type) {
30+
return op;
31+
}
32+
if (op->getNumChildren() == 0) {
33+
return nullptr;
34+
}
35+
return getOpWithType(op->getChild(0).get(), type);
36+
}
37+
};
38+
39+
// Checks the join order chosen when a MATCH clause follows a WITH clause, by comparing the
// encoded plan string returned by getEncodedPlan against an expected encoding.
TEST_F(PlannerTest, TestSubqueryJoinOrder) {
40+
// Cross Product
41+
{
42+
// We should pick the smaller table to be on the build side
43+
auto query = "MATCH (a:person) WITH a MATCH (b:organisation) RETURN *";
44+
EXPECT_STREQ("CP(){S(a)}{S(b)}", getEncodedPlan(query).c_str());
// Matching the two tables in the opposite textual order must give the same encoded plan.
45+
auto queryFlipped = "MATCH (b:organisation) WITH b MATCH (a:person) RETURN *";
46+
EXPECT_STREQ("CP(){S(a)}{S(b)}", getEncodedPlan(queryFlipped).c_str());
47+
}
48+
49+
// Hash Join
50+
{
51+
// cardinality(person) > cardinality(studyAt)
52+
// scan(person) should go on probe side
53+
auto query = "MATCH (a:person) WITH a MATCH (a)-[s:studyAt]->(b:organisation) RETURN *";
54+
EXPECT_STREQ("HJ(a._ID){S(a)}{HJ(b._ID){S(b)}{E(b)S(a)}}", getEncodedPlan(query).c_str());
55+
56+
// cardinality(organisation) < cardinality(studyAt) + cardinality(workAt)
57+
// scan(organisation) should go on build side
58+
auto queryFlipped =
59+
"MATCH (b:organisation) WITH b MATCH (a:person)-[s:studyAt|:workAt]->(b) RETURN *";
60+
EXPECT_STREQ("HJ(b._ID){E(b)S(a)}{S(b)}", getEncodedPlan(queryFlipped).c_str());
61+
}
62+
}
63+
64+
} // namespace testing
65+
} // namespace kuzu

test/test_files/lsqb/lsqb_queries.test

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-DATASET CSV lsqb-sf01
2-
-BUFFER_POOL_SIZE 1073741824
2+
-BUFFER_POOL_SIZE 4294967296
33

44
--
55

0 commit comments

Comments
 (0)