diff --git a/src/spaceone/core/model/mongo_model/__init__.py b/src/spaceone/core/model/mongo_model/__init__.py index 4543040..0cba501 100644 --- a/src/spaceone/core/model/mongo_model/__init__.py +++ b/src/spaceone/core/model/mongo_model/__init__.py @@ -590,9 +590,11 @@ def _make_unwind_project_stage(only: list): } @classmethod - def _stat_with_unwind( + def _stat_with_pipeline( cls, - unwind: list, + lookup: list = None, + unwind: dict = None, + add_fields: dict = None, only: list = None, filter: list = None, filter_or: list = None, @@ -600,39 +602,48 @@ def _stat_with_unwind( page: dict = None, target: str = None, ): - if only is None: - raise ERROR_DB_QUERY(reason="unwind option requires only option.") + if unwind: + if only is None: + raise ERROR_DB_QUERY(reason="unwind option requires only option.") - if not isinstance(unwind, dict): - raise ERROR_DB_QUERY(reason="unwind option should be dict type.") + if not isinstance(unwind, dict): + raise ERROR_DB_QUERY(reason="unwind option should be dict type.") - if "path" not in unwind: - raise ERROR_DB_QUERY(reason="unwind option should have path key.") + if "path" not in unwind: + raise ERROR_DB_QUERY(reason="unwind option should have path key.") - unwind_path = unwind["path"] - aggregate = [{"unwind": unwind}] + aggregate = [] - # Add project stage - project_fields = [] - for key in only: - project_fields.append( + if lookup: + for lu in lookup: + aggregate.append({"lookup": lu}) + + if unwind: + aggregate.append({"unwind": unwind}) + + if add_fields: + aggregate.append({"add_fields": add_fields}) + + if only: + project_fields = [] + for key in only: + project_fields.append( + { + "key": key, + "name": key, + } + ) + + aggregate.append( { - "key": key, - "name": key, + "project": { + "exclude_keys": True, + "only_keys": True, + "fields": project_fields, + } } ) - aggregate.append( - { - "project": { - "exclude_keys": True, - "only_keys": True, - "fields": project_fields, - } - } - ) - - # Add sort stage if sort: aggregate.append({"sort": sort}) @@ -641,7 +652,7 @@ def _stat_with_unwind( filter=filter, filter_or=filter_or, page=page, - tageet=target, + target=target, allow_disk_use=True, ) @@ -649,13 +660,15 @@ def _stat_with_unwind( vos = [] total_count = response.get("total_count", 0) for result in response.get("results", []): - unwind_data = utils.get_dict_value(result, unwind_path) - result = utils.change_dict_value(result, unwind_path, [unwind_data]) + if unwind: + unwind_path = unwind["path"] + unwind_data = utils.get_dict_value(result, unwind_path) + result = utils.change_dict_value(result, unwind_path, [unwind_data]) vo = cls(**result) vos.append(vo) except Exception as e: - raise ERROR_DB_QUERY(reason=f"Failed to convert unwind result: {e}") + raise ERROR_DB_QUERY(reason=f"Failed to convert pipeline result: {e}") return vos, total_count @@ -672,7 +685,9 @@ def query( minimal=False, include_count=True, count_only=False, + lookup=None, unwind=None, + add_fields=None, reference_filter=None, target=None, hint=None, @@ -683,9 +698,17 @@ def query( sort = sort or [] page = page or {} - if unwind: - return cls._stat_with_unwind( - unwind, only, filter, filter_or, sort, page, target + if unwind or lookup or add_fields: + return cls._stat_with_pipeline( + lookup=lookup, + unwind=unwind, + add_fields=add_fields, + only=only, + filter=filter, + filter_or=filter_or, + sort=sort, + page=page, + target=target, ) else: @@ -1075,6 +1098,44 @@ def _make_match_rule(cls, options): return {"$match": match_options} + @classmethod + def _make_lookup_rule(cls, options): + return {"$lookup": options} + + @classmethod + def _make_add_fields_rule(cls, options): + add_fields_options = {} + + for field, conditional in options.items(): + add_fields_options.update( + {field: cls._process_conditional_expression(conditional)} + ) + + return {"$addFields": add_fields_options} + + @classmethod + def _process_conditional_expression(cls, expression): + if isinstance(expression, dict): + if_expression = expression["if"] + + if isinstance(if_expression, dict): + replaced = {} + for k, v in if_expression.items(): + new_k = k.replace("__", "$") + replaced[new_k] = v + + if_expression = replaced + + return { + "$cond": { + "if": if_expression, + "then": cls._process_conditional_expression(expression["then"]), + "else": cls._process_conditional_expression(expression["else"]), + } + } + + return expression + @classmethod def _make_aggregate_rules(cls, aggregate): _aggregate_rules = [] @@ -1116,6 +1177,12 @@ def _make_aggregate_rules(cls, aggregate): elif "match" in stage: rule = cls._make_match_rule(stage["match"]) _aggregate_rules.append(rule) + elif "lookup" in stage: + rule = cls._make_lookup_rule(stage["lookup"]) + _aggregate_rules.append(rule) + elif "add_fields" in stage: + rule = cls._make_add_fields_rule(stage["add_fields"]) + _aggregate_rules.append(rule) else: raise ERROR_REQUIRED_PARAMETER( key="aggregate.unwind or aggregate.group or " @@ -1514,7 +1581,9 @@ def analyze( sort=None, start=None, end=None, + lookup=None, unwind=None, + add_fields=None, date_field="date", date_field_format="%Y-%m-%d", reference_filter=None, @@ -1552,9 +1621,16 @@ def analyze( aggregate = [] + if lookup: + for lu in lookup: + aggregate.append({"lookup": lu}) + if unwind: aggregate.append({"unwind": unwind}) + if add_fields: + aggregate.append({"add_fields": add_fields}) + aggregate.append({"group": {"keys": group_keys, "fields": group_fields}}) query = {