diff --git a/metaflow/decorators.py b/metaflow/decorators.py
index 8723367b79b..80024ffa6eb 100644
--- a/metaflow/decorators.py
+++ b/metaflow/decorators.py
@@ -108,6 +108,8 @@ class Decorator(object):

     name = "NONAME"
     defaults = {}
+    # `allow_multiple` allows setting many decorators of the same type to a step/flow.
+    allow_multiple = False

     def __init__(self, attributes=None, statically_defined=False):
         self.attributes = self.defaults.copy()
@@ -255,9 +257,6 @@ class MyDecorator(StepDecorator):
     pass them around with every lifecycle call.
     """

-    # `allow_multiple` allows setting many decorators of the same type to a step.
-    allow_multiple = False
-
     def step_init(
         self, flow, graph, step_name, decorators, environment, flow_datastore, logger
     ):
@@ -403,12 +402,17 @@ def _base_flow_decorator(decofunc, *args, **kwargs):
         if isinstance(cls, type) and issubclass(cls, FlowSpec):
             # flow decorators add attributes in the class dictionary,
             # _flow_decorators.
-            if decofunc.name in cls._flow_decorators:
+            if decofunc.name in cls._flow_decorators and not decofunc.allow_multiple:
                 raise DuplicateFlowDecoratorException(decofunc.name)
             else:
-                cls._flow_decorators[decofunc.name] = decofunc(
-                    attributes=kwargs, statically_defined=True
-                )
+                deco_instance = decofunc(attributes=kwargs, statically_defined=True)
+                if decofunc.allow_multiple:
+                    if decofunc.name not in cls._flow_decorators:
+                        cls._flow_decorators[decofunc.name] = [deco_instance]
+                    else:
+                        cls._flow_decorators[decofunc.name].append(deco_instance)
+                else:
+                    cls._flow_decorators[decofunc.name] = deco_instance
         else:
             raise BadFlowDecoratorException(decofunc.name)
         return cls
@@ -503,11 +507,26 @@ def _attach_decorators_to_step(step, decospecs):
 def _init_flow_decorators(
     flow, graph, environment, flow_datastore, metadata, logger, echo, deco_options
 ):
+    # Certain decorators can be specified multiple times and exist as lists in the _flow_decorators dictionary.
     for deco in flow._flow_decorators.values():
-        opts = {option: deco_options[option] for option in deco.options}
-        deco.flow_init(
-            flow, graph, environment, flow_datastore, metadata, logger, echo, opts
-        )
+        if type(deco) == list:
+            for rd in deco:
+                opts = {option: deco_options[option] for option in rd.options}
+                rd.flow_init(
+                    flow,
+                    graph,
+                    environment,
+                    flow_datastore,
+                    metadata,
+                    logger,
+                    echo,
+                    opts,
+                )
+        else:
+            opts = {option: deco_options[option] for option in deco.options}
+            deco.flow_init(
+                flow, graph, environment, flow_datastore, metadata, logger, echo, opts
+            )


 def _init_step_decorators(flow, graph, environment, flow_datastore, logger):
diff --git a/metaflow/plugins/__init__.py b/metaflow/plugins/__init__.py
index bb5c52ea746..460298a4fe7 100644
--- a/metaflow/plugins/__init__.py
+++ b/metaflow/plugins/__init__.py
@@ -174,11 +174,13 @@ def get_plugin_cli():

 from .project_decorator import ProjectDecorator

+from .airflow.sensors import SUPPORTED_SENSORS
+
 FLOW_DECORATORS = [
     CondaFlowDecorator,
     ScheduleDecorator,
     ProjectDecorator,
-]
+] + SUPPORTED_SENSORS
 _merge_lists(FLOW_DECORATORS, _ext_plugins["FLOW_DECORATORS"], "name")

 # Cards
diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py
index 5480a79c59e..88b8b3a338c 100644
--- a/metaflow/plugins/airflow/airflow.py
+++ b/metaflow/plugins/airflow/airflow.py
@@ -35,6 +35,7 @@
 from . import airflow_utils
 from .exception import AirflowException
+from .sensors import SUPPORTED_SENSORS
 from .airflow_utils import (
     TASK_ID_XCOM_KEY,
     AirflowTask,
@@ -88,6 +89,7 @@ def __init__(
         self.username = username
         self.max_workers = max_workers
         self.description = description
+        self._depends_on_upstream_sensors = False
         self._file_path = file_path
         _, self.graph_structure = self.graph.output_steps()
         self.worker_pool = worker_pool
@@ -584,6 +586,17 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
         cmds.append(" ".join(entrypoint + top_level + step))
         return cmds

+    def _collect_flow_sensors(self):
+        decos_lists = [
+            self.flow._flow_decorators.get(s.name)
+            for s in SUPPORTED_SENSORS
+            if self.flow._flow_decorators.get(s.name) is not None
+        ]
+        af_tasks = [deco.create_task() for decos in decos_lists for deco in decos]
+        if len(af_tasks) > 0:
+            self._depends_on_upstream_sensors = True
+        return af_tasks
+
     def _contains_foreach(self):
         for node in self.graph:
             if node.type == "foreach":
@@ -638,6 +651,7 @@ def _visit(node, workflow, exit_node=None):
         if self.workflow_timeout is not None and self.schedule is not None:
             airflow_dag_args["dagrun_timeout"] = dict(seconds=self.workflow_timeout)

+        appending_sensors = self._collect_flow_sensors()
         workflow = Workflow(
             dag_id=self.name,
             default_args=self._create_defaults(),
@@ -658,6 +672,10 @@ def _visit(node, workflow, exit_node=None):

         workflow = _visit(self.graph["start"], workflow)
         workflow.set_parameters(self.parameters)
+        if len(appending_sensors) > 0:
+            for s in appending_sensors:
+                workflow.add_state(s)
+            workflow.graph_structure.insert(0, [[s.name] for s in appending_sensors])
         return self._to_airflow_dag_file(workflow.to_dict())

     def _to_airflow_dag_file(self, json_dag):
diff --git a/metaflow/plugins/airflow/airflow_cli.py b/metaflow/plugins/airflow/airflow_cli.py
index 5ac676978c2..1f48a1fa481 100644
--- a/metaflow/plugins/airflow/airflow_cli.py
+++ b/metaflow/plugins/airflow/airflow_cli.py
@@ -322,7 +322,6 @@ def make_flow(


 def _validate_foreach_constraints(graph):
-    # Todo :Invoke this function when we integrate `foreach`s
     def traverse_graph(node, state):
         if node.type == "foreach" and node.is_inside_foreach:
             raise NotSupportedException(
@@ -338,7 +337,7 @@ def traverse_graph(node, state):

         if node.type == "linear" and node.is_inside_foreach:
             state["foreach_stack"].append(node.name)
-            if len(state["foreach_stack"]) > 2:
+            if "foreach_stack" in state and len(state["foreach_stack"]) > 2:
                 raise NotSupportedException(
                     "The foreach step *%s* created by step *%s* needs to have an immediate join step. "
                     "Step *%s* is invalid since it is a linear step with a foreach. "
@@ -378,18 +377,13 @@ def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout):
                 "A default value is required for parameters when deploying flows on Airflow."
             )
     # check for other compute related decorators.
+    _validate_foreach_constraints(graph)
     for node in graph:
         if node.parallel_foreach:
             raise AirflowException(
                 "Deploying flows with @parallel decorator(s) "
                 "to Airflow is not supported currently."
             )
-
-        if node.type == "foreach":
-            raise NotSupportedException(
-                "Step *%s* is a foreach step and Foreach steps are not currently supported with Airflow."
-                % node.name
-            )
         if any([d.name == "batch" for d in node.decorators]):
             raise NotSupportedException(
                 "Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported."
diff --git a/metaflow/plugins/airflow/airflow_utils.py b/metaflow/plugins/airflow/airflow_utils.py
index 26e544e8d61..c94b6734818 100644
--- a/metaflow/plugins/airflow/airflow_utils.py
+++ b/metaflow/plugins/airflow/airflow_utils.py
@@ -44,6 +44,10 @@ class IncompatibleKubernetesProviderVersionException(Exception):
         ) % (sys.executable, KUBERNETES_PROVIDER_FOREACH_VERSION)


+class AirflowSensorNotFound(Exception):
+    headline = "Sensor package not found"
+
+
 def create_absolute_version_number(version):
     abs_version = None
     # For all digits
@@ -189,6 +193,16 @@ def pathspec(cls, flowname, is_foreach=False):
         )


+class SensorNames:
+    EXTERNAL_TASK_SENSOR = "ExternalTaskSensor"
+    S3_SENSOR = "S3KeySensor"
+    SQL_SENSOR = "SQLSensor"
+
+    @classmethod
+    def get_supported_sensors(cls):
+        return list(cls.__dict__.values())
+
+
 def run_id_creator(val):
     # join `[dag-id,run-id]` of airflow dag.
     return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[
@@ -375,6 +389,46 @@ def _kubernetes_pod_operator_args(operator_args):
     return args


+def _parse_sensor_args(name, kwargs):
+    if name == SensorNames.EXTERNAL_TASK_SENSOR:
+        if "execution_delta" in kwargs:
+            if type(kwargs["execution_delta"]) == dict:
+                kwargs["execution_delta"] = timedelta(**kwargs["execution_delta"])
+            else:
+                del kwargs["execution_delta"]
+    return kwargs
+
+
+def _get_sensor(name):
+    # from airflow import XComArg
+    # XComArg()
+    if name == SensorNames.EXTERNAL_TASK_SENSOR:
+        # ExternalTaskSensor uses the execution_date of a DAG run to
+        # determine the appropriate upstream DAG run.
+        # This is set to the exact date the current DAG gets executed on.
+        # For example, if "DagA" (the upstream DAG) got scheduled at
+        # 12 Jan 4:00 PM PDT, then "DagB"'s (the current DAG) task sensor will
+        # look for a "DagA" run that got executed at 12 Jan 4:00 PM PDT **exactly**.
+        # They also support an `execution_timeout` argument that bounds how long the sensor runs.
+        from airflow.sensors.external_task_sensor import ExternalTaskSensor
+
+        return ExternalTaskSensor
+    elif name == SensorNames.S3_SENSOR:
+        try:
+            from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
+        except ImportError:
+            raise AirflowSensorNotFound(
+                "This DAG requires an `S3KeySensor`. "
" + "Install the Airflow AWS provider using : " + "`pip install apache-airflow-providers-amazon`" + ) + return S3KeySensor + elif name == SensorNames.SQL_SENSOR: + from airflow.sensors.sql import SqlSensor + + return SqlSensor + + def get_metaflow_kubernetes_operator(): try: from airflow.contrib.operators.kubernetes_pod_operator import ( @@ -493,6 +547,13 @@ def set_operator_args(self, **kwargs): self._operator_args = kwargs return self + def _make_sensor(self): + TaskSensor = _get_sensor(self._operator_type) + return TaskSensor( + task_id=self.name, + **_parse_sensor_args(self._operator_type, self._operator_args) + ) + def to_dict(self): return { "name": self.name, @@ -541,6 +602,8 @@ def to_task(self): return self._kubernetes_task() else: return self._kubernetes_mapper_task() + elif self._operator_type in SensorNames.get_supported_sensors(): + return self._make_sensor() class Workflow(object): diff --git a/metaflow/plugins/airflow/sensors/__init__.py b/metaflow/plugins/airflow/sensors/__init__.py new file mode 100644 index 00000000000..02952d0c9a4 --- /dev/null +++ b/metaflow/plugins/airflow/sensors/__init__.py @@ -0,0 +1,9 @@ +from .external_task_sensor import ExternalTaskSensorDecorator +from .s3_sensor import S3KeySensorDecorator +from .sql_sensor import SQLSensorDecorator + +SUPPORTED_SENSORS = [ + ExternalTaskSensorDecorator, + S3KeySensorDecorator, + SQLSensorDecorator, +] diff --git a/metaflow/plugins/airflow/sensors/base_sensor.py b/metaflow/plugins/airflow/sensors/base_sensor.py new file mode 100644 index 00000000000..9412072cd23 --- /dev/null +++ b/metaflow/plugins/airflow/sensors/base_sensor.py @@ -0,0 +1,74 @@ +import uuid +from metaflow.decorators import FlowDecorator +from ..exception import AirflowException +from ..airflow_utils import AirflowTask, id_creator, TASK_ID_HASH_LEN + + +class AirflowSensorDecorator(FlowDecorator): + """ + Base class for all Airflow sensor decorators. + """ + + allow_multiple = True + + defaults = dict( + timeout=3600, + poke_interval=60, + mode="reschedule", + exponential_backoff=True, + pool=None, + soft_fail=False, + name=None, + description=None, + ) + + operator_type = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._airflow_task_name = None + self._id = str(uuid.uuid4()) + + def serialize_operator_args(self): + """ + Subclasses will parse the decorator arguments to + Airflow task serializable arguments. + """ + task_args = dict(**self.attributes) + del task_args["name"] + if task_args["description"] is not None: + task_args["doc"] = task_args["description"] + del task_args["description"] + task_args["do_xcom_push"] = True + return task_args + + def create_task(self): + task_args = self.serialize_operator_args() + return AirflowTask( + self._airflow_task_name, + operator_type=self.operator_type, + ).set_operator_args(**{k: v for k, v in task_args.items() if v is not None}) + + def validate(self): + """ + Validate if the arguments for the sensor are correct. + """ + # If there is no name set then auto-generate the name. This is done because there can be more than + # one `AirflowSensorDecorator` of the same type. 
+ if self.attributes["name"] is None: + deco_index = [ + d._id + for d in self._flow_decorators + if issubclass(d.__class__, AirflowSensorDecorator) + ].index(self._id) + self._airflow_task_name = "%s-%s" % ( + self.operator_type, + id_creator([self.operator_type, str(deco_index)], TASK_ID_HASH_LEN), + ) + else: + self._airflow_task_name = self.attributes["name"] + + def flow_init( + self, flow, graph, environment, flow_datastore, metadata, logger, echo, options + ): + self.validate() diff --git a/metaflow/plugins/airflow/sensors/external_task_sensor.py b/metaflow/plugins/airflow/sensors/external_task_sensor.py new file mode 100644 index 00000000000..649edba706c --- /dev/null +++ b/metaflow/plugins/airflow/sensors/external_task_sensor.py @@ -0,0 +1,94 @@ +from .base_sensor import AirflowSensorDecorator +from ..airflow_utils import SensorNames +from ..exception import AirflowException +from datetime import timedelta + + +AIRFLOW_STATES = dict( + QUEUED="queued", + RUNNING="running", + SUCCESS="success", + SHUTDOWN="shutdown", # External request to shut down, + FAILED="failed", + UP_FOR_RETRY="up_for_retry", + UP_FOR_RESCHEDULE="up_for_reschedule", + UPSTREAM_FAILED="upstream_failed", + SKIPPED="skipped", +) + + +class ExternalTaskSensorDecorator(AirflowSensorDecorator): + operator_type = SensorNames.EXTERNAL_TASK_SENSOR + # Docs: + # https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/external_task/index.html#airflow.sensors.external_task.ExternalTaskSensor + name = "airflow_external_task_sensor" + defaults = dict( + **AirflowSensorDecorator.defaults, + external_dag_id=None, + external_task_ids=None, + allowed_states=[AIRFLOW_STATES["SUCCESS"]], + failed_states=None, + execution_delta=None, + check_existence=True, + # We cannot add `execution_date_fn` as it requires a python callable. + # Passing around a python callable is non-trivial since we are passing a + # callable from metaflow-code to airflow python script. Since we cannot + # transfer dependencies of the callable, we cannot gaurentee that the callable + # behave exactly as the user expects + ) + + def serialize_operator_args(self): + task_args = super().serialize_operator_args() + if task_args["execution_delta"] is not None: + task_args["execution_delta"] = dict( + seconds=task_args["execution_delta"].total_seconds() + ) + return task_args + + def validate(self): + if self.attributes["external_dag_id"] is None: + raise AirflowException( + "`%s` argument of `@%s`cannot be `None`." + % ("external_dag_id", self.name) + ) + + if type(self.attributes["allowed_states"]) == str: + if self.attributes["allowed_states"] not in list(AIRFLOW_STATES.values()): + raise AirflowException( + "`%s` is an invalid input of `%s` for `@%s`. Accepted values are %s" + % ( + str(self.attributes["allowed_states"]), + "allowed_states", + self.name, + ", ".join(list(AIRFLOW_STATES.values())), + ) + ) + elif type(self.attributes["allowed_states"]) == list: + enum_not_matched = [ + x + for x in self.attributes["allowed_states"] + if x not in list(AIRFLOW_STATES.values()) + ] + if len(enum_not_matched) > 0: + raise AirflowException( + "`%s` is an invalid input of `%s` for `@%s`. 
+                    % (
+                        str(" OR ".join(["'%s'" % i for i in enum_not_matched])),
+                        "allowed_states",
+                        self.name,
+                        ", ".join(list(AIRFLOW_STATES.values())),
+                    )
+                )
+        else:
+            self.attributes["allowed_states"] = [AIRFLOW_STATES["SUCCESS"]]
+
+        if self.attributes["execution_delta"] is not None:
+            if not isinstance(self.attributes["execution_delta"], timedelta):
+                raise AirflowException(
+                    "`%s` is an invalid input type of `execution_delta` for `@%s`. Accepted type is `datetime.timedelta`"
+                    % (
+                        str(type(self.attributes["execution_delta"])),
+                        self.name,
+                    )
+                )
+        super().validate()
diff --git a/metaflow/plugins/airflow/sensors/s3_sensor.py b/metaflow/plugins/airflow/sensors/s3_sensor.py
new file mode 100644
index 00000000000..b4f7ae5b6de
--- /dev/null
+++ b/metaflow/plugins/airflow/sensors/s3_sensor.py
@@ -0,0 +1,26 @@
+from .base_sensor import AirflowSensorDecorator
+from ..airflow_utils import SensorNames
+from ..exception import AirflowException
+
+
+class S3KeySensorDecorator(AirflowSensorDecorator):
+    name = "airflow_s3_key_sensor"
+    operator_type = SensorNames.S3_SENSOR
+    # Arg specification can be found here:
+    # https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/_api/airflow/providers/amazon/aws/sensors/s3/index.html#airflow.providers.amazon.aws.sensors.s3.S3KeySensor
+    defaults = dict(
+        **AirflowSensorDecorator.defaults,
+        bucket_key=None,  # Required
+        bucket_name=None,
+        wildcard_match=False,
+        aws_conn_id=None,
+        verify=None,  # `verify (Optional[Union[str, bool]])` Whether or not to verify SSL certificates for the S3 connection.
+        # `verify` is an Airflow variable.
+    )
+
+    def validate(self):
+        if self.attributes["bucket_key"] is None:
+            raise AirflowException(
+                "`bucket_key` for `@%s` cannot be empty." % (self.name)
+            )
+        super().validate()
diff --git a/metaflow/plugins/airflow/sensors/sql_sensor.py b/metaflow/plugins/airflow/sensors/sql_sensor.py
new file mode 100644
index 00000000000..c97c41b283e
--- /dev/null
+++ b/metaflow/plugins/airflow/sensors/sql_sensor.py
@@ -0,0 +1,30 @@
+from .base_sensor import AirflowSensorDecorator
+from ..airflow_utils import SensorNames
+from ..exception import AirflowException
+
+
+class SQLSensorDecorator(AirflowSensorDecorator):
+    name = "airflow_sql_sensor"
+    operator_type = SensorNames.SQL_SENSOR
+    # Arg specification can be found here:
+    # https://airflow.apache.org/docs/apache-airflow/stable/_api/airflow/sensors/sql/index.html#airflow.sensors.sql.SqlSensor
+    defaults = dict(
+        **AirflowSensorDecorator.defaults,
+        conn_id=None,
+        sql=None,
+        # success = None,  # success/failure require callables. Won't be supported at start since they are not serialization friendly.
+        # failure = None,
+        parameters=None,
+        fail_on_empty=True,
+    )
+
+    def validate(self):
+        if self.attributes["conn_id"] is None:
+            raise AirflowException(
+                "`%s` argument of `@%s` cannot be `None`." % ("conn_id", self.name)
+            )
+        if self.attributes["sql"] is None:
+            raise AirflowException(
+                "`%s` argument of `@%s` cannot be `None`." % ("sql", self.name)
+            )
+        super().validate()
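
Taken together, the changes above add three flow-level sensor decorators. A minimal usage sketch follows; the flow, bucket key, and upstream DAG id are hypothetical, and it assumes the decorators are exposed from the top-level `metaflow` package the way other flow decorators such as `@schedule` are:

```python
from metaflow import FlowSpec, step, airflow_external_task_sensor, airflow_s3_key_sensor


# Sensors are flow-level decorators; because AirflowSensorDecorator sets
# `allow_multiple = True`, several of them (even of the same type) can be
# stacked on a single flow.
@airflow_s3_key_sensor(bucket_key="s3://my-bucket/daily/data.csv", timeout=1800)
@airflow_external_task_sensor(external_dag_id="upstream_dag", name="wait_for_upstream")
class SensorGatedFlow(FlowSpec):
    @step
    def start(self):
        self.next(self.end)

    @step
    def end(self):
        pass
```

When such a flow is compiled for Airflow (e.g. with the existing `airflow create` command handled by airflow_cli.py), `_collect_flow_sensors` turns each decorator into an Airflow sensor task and `workflow.graph_structure.insert(0, ...)` prepends those tasks to the DAG, so the `start` task only runs after every sensor has succeeded.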
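One non-obvious detail worth calling out: because the compiled DAG is shipped as a JSON-serializable dict (see `_to_airflow_dag_file`), `ExternalTaskSensorDecorator.serialize_operator_args` flattens a `datetime.timedelta` into `dict(seconds=...)`, and `_parse_sensor_args` rebuilds it inside the generated DAG file before the sensor is instantiated. A small illustrative sketch of that round trip (values are made up):

```python
from datetime import timedelta

# Metaflow side: the user-facing timedelta is flattened to a JSON-friendly dict.
execution_delta = timedelta(hours=1)
serialized = {"execution_delta": {"seconds": execution_delta.total_seconds()}}

# Generated-DAG side: _parse_sensor_args reconstructs the timedelta before the
# ExternalTaskSensor is instantiated.
rebuilt = timedelta(**serialized["execution_delta"])
assert rebuilt == execution_delta
```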