|
1 | 1 | #include "binder/query/query_graph_label_analyzer.h"
|
2 | 2 |
|
3 | 3 | #include "catalog/catalog.h"
|
| 4 | +#include "catalog/catalog_entry/node_table_catalog_entry.h" |
4 | 5 | #include "catalog/catalog_entry/rel_table_catalog_entry.h"
|
5 | 6 | #include "common/exception/binder.h"
|
6 | 7 | #include "common/string_format.h"
|
@@ -29,7 +30,7 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
|
29 | 30 | if (queryRel->isRecursive()) {
|
30 | 31 | continue;
|
31 | 32 | }
|
32 |
| - common::table_id_set_t candidates; |
| 33 | + table_id_set_t candidates; |
33 | 34 | std::unordered_set<std::string> candidateNamesSet;
|
34 | 35 | auto isSrcConnect = *queryRel->getSrcNode() == node;
|
35 | 36 | auto isDstConnect = *queryRel->getDstNode() == node;
|
@@ -94,49 +95,231 @@ void QueryGraphLabelAnalyzer::pruneNode(const QueryGraph& graph, NodeExpression&
|
94 | 95 | }
|
95 | 96 | }
|
96 | 97 |
|
97 |
| -void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { |
98 |
| - if (rel.isRecursive()) { |
99 |
| - return; |
100 |
| - } |
| 98 | +std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::pruneNonRecursiveRel( |
| 99 | + const std::vector<TableCatalogEntry*>& relEntries, const table_id_set_t& srcTableIDSet, |
| 100 | + const table_id_set_t& dstTableIDSet, const RelDirectionType directionType) const { |
| 101 | + |
| 102 | + auto forwardPruningFunc = [&](table_id_t srcTableID, table_id_t dstTableID) { |
| 103 | + return srcTableIDSet.contains(srcTableID) && dstTableIDSet.contains(dstTableID); |
| 104 | + }; |
101 | 105 | std::vector<TableCatalogEntry*> prunedEntries;
|
102 |
| - if (rel.getDirectionType() == RelDirectionType::BOTH) { |
103 |
| - table_id_set_t srcBoundTableIDSet; |
104 |
| - table_id_set_t dstBoundTableIDSet; |
105 |
| - for (auto entry : rel.getSrcNode()->getEntries()) { |
106 |
| - srcBoundTableIDSet.insert(entry->getTableID()); |
| 106 | + for (auto& entry : relEntries) { |
| 107 | + auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
| 108 | + auto srcTableID = relEntry.getSrcTableID(); |
| 109 | + auto dstTableID = relEntry.getDstTableID(); |
| 110 | + auto satisfyForwardPruning = forwardPruningFunc(srcTableID, dstTableID); |
| 111 | + if (directionType == RelDirectionType::BOTH) { |
| 112 | + if (satisfyForwardPruning || |
| 113 | + (dstTableIDSet.contains(srcTableID) && srcTableIDSet.contains(dstTableID))) { |
| 114 | + prunedEntries.push_back(entry); |
| 115 | + } |
| 116 | + } else { |
| 117 | + if (satisfyForwardPruning) { |
| 118 | + prunedEntries.push_back(entry); |
| 119 | + } |
107 | 120 | }
|
108 |
| - for (auto entry : rel.getDstNode()->getEntries()) { |
109 |
| - dstBoundTableIDSet.insert(entry->getTableID()); |
| 121 | + } |
| 122 | + return prunedEntries; |
| 123 | +} |
| 124 | + |
| 125 | +table_id_set_t QueryGraphLabelAnalyzer::collectRelNodes(const RelDataDirection direction, |
| 126 | + std::vector<TableCatalogEntry*> relEntries) const { |
| 127 | + table_id_set_t nodeIDs; |
| 128 | + for (const auto& entry : relEntries) { |
| 129 | + const auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
| 130 | + if (direction == RelDataDirection::FWD) { |
| 131 | + nodeIDs.insert(relEntry.getDstTableID()); |
| 132 | + } else if (direction == RelDataDirection::BWD) { |
| 133 | + nodeIDs.insert(relEntry.getSrcTableID()); |
| 134 | + } else { |
| 135 | + KU_UNREACHABLE; |
110 | 136 | }
|
111 |
| - for (auto& entry : rel.getEntries()) { |
112 |
| - auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
113 |
| - auto srcTableID = relEntry.getSrcTableID(); |
114 |
| - auto dstTableID = relEntry.getDstTableID(); |
115 |
| - if ((srcBoundTableIDSet.contains(srcTableID) && |
116 |
| - dstBoundTableIDSet.contains(dstTableID)) || |
117 |
| - (dstBoundTableIDSet.contains(srcTableID) && |
118 |
| - srcBoundTableIDSet.contains(dstTableID))) { |
119 |
| - prunedEntries.push_back(entry); |
| 137 | + } |
| 138 | + return nodeIDs; |
| 139 | +} |
| 140 | + |
| 141 | +std::pair<std::vector<table_id_set_t>, std::vector<table_id_set_t>> |
| 142 | +QueryGraphLabelAnalyzer::pruneRecursiveRel(const std::vector<TableCatalogEntry*>& relEntries, |
| 143 | + const table_id_set_t srcTableIDSet, const table_id_set_t dstTableIDSet, size_t lowerBound, |
| 144 | + size_t upperBound, RelDirectionType relDirectionType) const { |
| 145 | + // src-->[dst,[rels]] |
| 146 | + std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>> |
| 147 | + stepFromLeftGraph; |
| 148 | + std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>> |
| 149 | + stepFromRightGraph; |
| 150 | + for (auto entry : relEntries) { |
| 151 | + auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
| 152 | + auto srcTableID = relEntry.getSrcTableID(); |
| 153 | + auto dstTableID = relEntry.getDstTableID(); |
| 154 | + stepFromLeftGraph[srcTableID][dstTableID].push_back(relEntry.getTableID()); |
| 155 | + stepFromRightGraph[dstTableID][srcTableID].push_back(relEntry.getTableID()); |
| 156 | + if (relDirectionType == RelDirectionType::BOTH) { |
| 157 | + stepFromLeftGraph[dstTableID][srcTableID].push_back(relEntry.getTableID()); |
| 158 | + stepFromRightGraph[srcTableID][dstTableID].push_back(relEntry.getTableID()); |
| 159 | + } |
| 160 | + } |
| 161 | + auto stepFromLeft = |
| 162 | + pruneRecursiveRel(stepFromLeftGraph, srcTableIDSet, dstTableIDSet, lowerBound, upperBound); |
| 163 | + auto stepFromRight = |
| 164 | + pruneRecursiveRel(stepFromRightGraph, dstTableIDSet, srcTableIDSet, lowerBound, upperBound); |
| 165 | + return {stepFromLeft, stepFromRight}; |
| 166 | +} |
| 167 | + |
| 168 | +std::vector<table_id_set_t> QueryGraphLabelAnalyzer::pruneRecursiveRel( |
| 169 | + const std::unordered_map<table_id_t, std::unordered_map<table_id_t, table_id_vector_t>>& graph, |
| 170 | + const table_id_set_t& startTableIDSet, const table_id_set_t& endTableIDSet, size_t lowerBound, |
| 171 | + size_t upperBound) const { |
| 172 | + |
| 173 | + // There may be multiple rel type between src and dst |
| 174 | + using PATH = std::vector<table_id_vector_t>; |
| 175 | + std::vector<std::vector<PATH>> paths(upperBound + 1); |
| 176 | + |
| 177 | + std::function<void(table_id_t, size_t, std::vector<table_id_vector_t>&)> dfs; |
| 178 | + dfs = [&](table_id_t curTableID, size_t currentLength, std::vector<table_id_vector_t>& path) { |
| 179 | + if (currentLength >= lowerBound && currentLength <= upperBound && |
| 180 | + endTableIDSet.contains(curTableID)) { |
| 181 | + paths[currentLength].push_back(path); |
| 182 | + } |
| 183 | + if (currentLength >= upperBound) { |
| 184 | + return; |
| 185 | + } |
| 186 | + if (graph.contains(curTableID)) { |
| 187 | + for (const auto& [neighbor, relTableID] : graph.at(curTableID)) { |
| 188 | + path.push_back(relTableID); |
| 189 | + dfs(neighbor, currentLength + 1, path); |
| 190 | + path.pop_back(); |
120 | 191 | }
|
121 | 192 | }
|
122 |
| - } else { |
123 |
| - auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); |
124 |
| - auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); |
125 |
| - for (auto& entry : rel.getEntries()) { |
126 |
| - auto& relEntry = entry->constCast<RelTableCatalogEntry>(); |
127 |
| - auto srcTableID = relEntry.getSrcTableID(); |
128 |
| - auto dstTableID = relEntry.getDstTableID(); |
129 |
| - if (!srcTableIDSet.contains(srcTableID) || !dstTableIDSet.contains(dstTableID)) { |
130 |
| - continue; |
| 193 | + }; |
| 194 | + for (auto startTableID : startTableIDSet) { |
| 195 | + std::vector<table_id_vector_t> path; |
| 196 | + dfs(startTableID, 0, path); |
| 197 | + } |
| 198 | + // merge |
| 199 | + std::vector<table_id_set_t> stepActiveTableIDs(upperBound); |
| 200 | + for (auto i = lowerBound; i <= upperBound; ++i) { |
| 201 | + for (auto path : paths[i]) { |
| 202 | + for (auto j = 0u; j < i; ++j) { |
| 203 | + auto rels = path[j]; |
| 204 | + stepActiveTableIDs[j].insert(rels.begin(), rels.end()); |
131 | 205 | }
|
132 |
| - prunedEntries.push_back(entry); |
133 | 206 | }
|
134 | 207 | }
|
135 |
| - rel.setEntries(prunedEntries); |
| 208 | + return stepActiveTableIDs; |
| 209 | +} |
| 210 | + |
| 211 | +std::vector<TableCatalogEntry*> QueryGraphLabelAnalyzer::getTableCatalogEntries( |
| 212 | + table_id_set_t tableIDs) const { |
| 213 | + std::vector<TableCatalogEntry*> relEntries; |
| 214 | + for (const auto& tableID : tableIDs) { |
| 215 | + relEntries.push_back(catalog->getTableCatalogEntry(tx, tableID)); |
| 216 | + } |
| 217 | + return relEntries; |
| 218 | +} |
| 219 | + |
| 220 | +std::vector<table_id_t> QueryGraphLabelAnalyzer::getNodeTableIDs() const { |
| 221 | + std::vector<table_id_t> nodeTableIDs; |
| 222 | + for (auto node_table_entry : catalog->getNodeTableEntries(tx)) { |
| 223 | + nodeTableIDs.push_back(node_table_entry->getTableID()); |
| 224 | + } |
| 225 | + return nodeTableIDs; |
| 226 | +} |
| 227 | + |
| 228 | +std::unordered_set<TableCatalogEntry*> QueryGraphLabelAnalyzer::mergeTableIDs( |
| 229 | + const std::vector<table_id_set_t>& v1, const std::vector<table_id_set_t>& v2) const { |
| 230 | + std::unordered_set<table_id_t> temp; |
| 231 | + for (auto tableIDs : v1) { |
| 232 | + temp.insert(tableIDs.begin(), tableIDs.end()); |
| 233 | + } |
| 234 | + for (auto tableIDs : v2) { |
| 235 | + temp.insert(tableIDs.begin(), tableIDs.end()); |
| 236 | + } |
| 237 | + std::unordered_set<TableCatalogEntry*> ans; |
| 238 | + for (table_id_t tableID : temp) { |
| 239 | + ans.emplace(catalog->getTableCatalogEntry(tx, tableID)); |
| 240 | + } |
| 241 | + return ans; |
| 242 | +} |
| 243 | + |
| 244 | +static std::vector<catalog::TableCatalogEntry*> intersectEntries( |
| 245 | + std::vector<catalog::TableCatalogEntry*> v1, std::vector<catalog::TableCatalogEntry*> v2) { |
| 246 | + std::sort(v1.begin(), v1.end()); |
| 247 | + std::sort(v2.begin(), v2.end()); |
| 248 | + std::vector<catalog::TableCatalogEntry*> intersection; |
| 249 | + std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(), |
| 250 | + std::back_inserter(intersection)); |
| 251 | + return intersection; |
| 252 | +} |
| 253 | + |
| 254 | +static bool isSameTableCatalogEntryVector(std::vector<TableCatalogEntry*> v1, |
| 255 | + std::vector<TableCatalogEntry*> v2) { |
| 256 | + auto compareFunc = [](TableCatalogEntry* a, TableCatalogEntry* b) { |
| 257 | + return a->getTableID() < b->getTableID(); |
| 258 | + }; |
| 259 | + std::sort(v1.begin(), v1.end(), compareFunc); |
| 260 | + std::sort(v2.begin(), v2.end(), compareFunc); |
| 261 | + return std::equal(v1.begin(), v1.end(), v2.begin(), v2.end()); |
| 262 | +} |
| 263 | + |
| 264 | +void QueryGraphLabelAnalyzer::pruneRel(RelExpression& rel) const { |
| 265 | + auto srcTableIDSet = rel.getSrcNode()->getTableIDsSet(); |
| 266 | + auto dstTableIDSet = rel.getDstNode()->getTableIDsSet(); |
| 267 | + if (rel.isRecursive()) { |
| 268 | + auto nodeTableIDs = getNodeTableIDs(); |
| 269 | + // there is no label on both sides |
| 270 | + if (rel.getUpperBound() == 0 || (srcTableIDSet.size() == dstTableIDSet.size() && |
| 271 | + dstTableIDSet.size() == nodeTableIDs.size())) { |
| 272 | + return; |
| 273 | + } |
| 274 | + |
| 275 | + auto [stepFromLeftTableIDs, stepFromRightTableIDs] = |
| 276 | + pruneRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, rel.getLowerBound(), |
| 277 | + rel.getUpperBound(), rel.getDirectionType()); |
| 278 | + auto recursiveInfo = rel.getRecursiveInfo(); |
| 279 | + recursiveInfo->stepFromLeftActivationRelInfos = stepFromLeftTableIDs; |
| 280 | + recursiveInfo->stepFromRightActivationRelInfos = stepFromRightTableIDs; |
| 281 | + // todo we need reset rel entries? |
| 282 | + auto temp = mergeTableIDs(stepFromLeftTableIDs, stepFromRightTableIDs); |
| 283 | + std::vector<TableCatalogEntry*> newRelEntries{temp.begin(), temp.end()}; |
| 284 | + if (!isSameTableCatalogEntryVector(newRelEntries, rel.getEntries())) { |
| 285 | + rel.setEntries(newRelEntries); |
| 286 | + recursiveInfo->rel->setEntries(newRelEntries); |
| 287 | + // update src&dst entries |
| 288 | + auto forwardRelNodes = collectRelNodes(RelDataDirection::BWD, |
| 289 | + getTableCatalogEntries(stepFromLeftTableIDs.front())); |
| 290 | + |
| 291 | + std::unordered_set<table_id_t> backwardRelNodes; |
| 292 | + for (auto i = rel.getLowerBound(); i <= rel.getUpperBound(); ++i) { |
| 293 | + if (i == 0) { |
| 294 | + continue; |
| 295 | + } |
| 296 | + const auto relSrcNodes = collectRelNodes(RelDataDirection::FWD, |
| 297 | + getTableCatalogEntries(stepFromLeftTableIDs.at(i - 1))); |
| 298 | + backwardRelNodes.insert(relSrcNodes.begin(), relSrcNodes.end()); |
| 299 | + } |
| 300 | + |
| 301 | + if (rel.getDirectionType() == RelDirectionType::BOTH) { |
| 302 | + forwardRelNodes.insert(backwardRelNodes.begin(), backwardRelNodes.end()); |
| 303 | + backwardRelNodes = forwardRelNodes; |
| 304 | + } |
| 305 | + |
| 306 | + auto newSrcNodeEntries = intersectEntries(rel.getSrcNode()->getEntries(), |
| 307 | + getTableCatalogEntries({forwardRelNodes.begin(), forwardRelNodes.end()})); |
| 308 | + rel.getSrcNode()->setEntries(newSrcNodeEntries); |
| 309 | + |
| 310 | + auto newDstNodeEntries = intersectEntries(rel.getDstNode()->getEntries(), |
| 311 | + getTableCatalogEntries({backwardRelNodes.begin(), backwardRelNodes.end()})); |
| 312 | + rel.getDstNode()->setEntries(newDstNodeEntries); |
| 313 | + } |
| 314 | + } else { |
| 315 | + auto prunedEntries = pruneNonRecursiveRel(rel.getEntries(), srcTableIDSet, dstTableIDSet, |
| 316 | + rel.getDirectionType()); |
| 317 | + rel.setEntries(prunedEntries); |
| 318 | + } |
136 | 319 | // Note the pruning for node should guarantee the following exception won't be triggered.
|
137 | 320 | // For safety (and consistency) reason, we still write the check but skip coverage check.
|
138 | 321 | // LCOV_EXCL_START
|
139 |
| - if (prunedEntries.empty()) { |
| 322 | + if (rel.getEntries().empty()) { |
140 | 323 | if (throwOnViolate) {
|
141 | 324 | throw BinderException(stringFormat("Cannot find a label for relationship {} that "
|
142 | 325 | "connects to all of its neighbour nodes.",
|
|
0 commit comments