41 changes: 30 additions & 11 deletions metaflow/decorators.py
@@ -108,6 +108,8 @@ class Decorator(object):

name = "NONAME"
defaults = {}
# `allow_multiple` allows multiple decorators of the same type to be attached to a step or flow.
allow_multiple = False

def __init__(self, attributes=None, statically_defined=False):
self.attributes = self.defaults.copy()
@@ -255,9 +257,6 @@ class MyDecorator(StepDecorator):
pass them around with every lifecycle call.
"""

# `allow_multiple` allows setting many decorators of the same type to a step.
allow_multiple = False

def step_init(
self, flow, graph, step_name, decorators, environment, flow_datastore, logger
):
@@ -403,12 +402,17 @@ def _base_flow_decorator(decofunc, *args, **kwargs):
if isinstance(cls, type) and issubclass(cls, FlowSpec):
# flow decorators add attributes in the class dictionary,
# _flow_decorators.
if decofunc.name in cls._flow_decorators:
if decofunc.name in cls._flow_decorators and not decofunc.allow_multiple:
raise DuplicateFlowDecoratorException(decofunc.name)
else:
cls._flow_decorators[decofunc.name] = decofunc(
attributes=kwargs, statically_defined=True
)
deco_instance = decofunc(attributes=kwargs, statically_defined=True)
if decofunc.allow_multiple:
if decofunc.name not in cls._flow_decorators:
cls._flow_decorators[decofunc.name] = [deco_instance]
else:
cls._flow_decorators[decofunc.name].append(deco_instance)
else:
cls._flow_decorators[decofunc.name] = deco_instance
else:
raise BadFlowDecoratorException(decofunc.name)
return cls
@@ -503,11 +507,26 @@ def _attach_decorators_to_step(step, decospecs):
def _init_flow_decorators(
flow, graph, environment, flow_datastore, metadata, logger, echo, deco_options
):
# Decorators that can be specified multiple times are stored as lists in the _flow_decorators dictionary.
for deco in flow._flow_decorators.values():
opts = {option: deco_options[option] for option in deco.options}
deco.flow_init(
flow, graph, environment, flow_datastore, metadata, logger, echo, opts
)
if type(deco) == list:
for rd in deco:
opts = {option: deco_options[option] for option in rd.options}
rd.flow_init(
flow,
graph,
environment,
flow_datastore,
metadata,
logger,
echo,
opts,
)
else:
opts = {option: deco_options[option] for option in deco.options}
deco.flow_init(
flow, graph, environment, flow_datastore, metadata, logger, echo, opts
)


def _init_step_decorators(flow, graph, environment, flow_datastore, logger):
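For context, here is a minimal sketch (using a hypothetical `_Deco` class and a plain dict, not Metaflow's real types) of the storage behavior the new `allow_multiple` branch introduces in `_base_flow_decorator`:

class _Deco:
    name = "demo_sensor"
    allow_multiple = True

def _register(registry, deco):
    # Mirrors the branch added above: allow_multiple decorators accumulate
    # in a list, while other flow decorators must stay unique per flow.
    if deco.allow_multiple:
        registry.setdefault(deco.name, []).append(deco)
    else:
        if deco.name in registry:
            raise ValueError("duplicate flow decorator: %s" % deco.name)
        registry[deco.name] = deco

registry = {}
_register(registry, _Deco())
_register(registry, _Deco())
assert isinstance(registry["demo_sensor"], list) and len(registry["demo_sensor"]) == 2

This is also why `_init_flow_decorators` now has to handle both shapes, checking for a list before calling `flow_init`.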
4 changes: 3 additions & 1 deletion metaflow/plugins/__init__.py
@@ -174,11 +174,13 @@ def get_plugin_cli():
from .project_decorator import ProjectDecorator


from .airflow.sensors import SUPPORTED_SENSORS

FLOW_DECORATORS = [
CondaFlowDecorator,
ScheduleDecorator,
ProjectDecorator,
]
] + SUPPORTED_SENSORS
_merge_lists(FLOW_DECORATORS, _ext_plugins["FLOW_DECORATORS"], "name")

# Cards
18 changes: 18 additions & 0 deletions metaflow/plugins/airflow/airflow.py
@@ -35,6 +35,7 @@

from . import airflow_utils
from .exception import AirflowException
from .sensors import SUPPORTED_SENSORS
from .airflow_utils import (
TASK_ID_XCOM_KEY,
AirflowTask,
@@ -88,6 +89,7 @@ def __init__(
self.username = username
self.max_workers = max_workers
self.description = description
self._depends_on_upstream_sensors = False
self._file_path = file_path
_, self.graph_structure = self.graph.output_steps()
self.worker_pool = worker_pool
@@ -584,6 +586,17 @@ def _step_cli(self, node, paths, code_package_url, user_code_retries):
cmds.append(" ".join(entrypoint + top_level + step))
return cmds

def _collect_flow_sensors(self):
decos_lists = [
self.flow._flow_decorators.get(s.name)
for s in SUPPORTED_SENSORS
if self.flow._flow_decorators.get(s.name) is not None
]
af_tasks = [deco.create_task() for decos in decos_lists for deco in decos]
if len(af_tasks) > 0:
self._depends_on_upstream_sensors = True
return af_tasks

def _contains_foreach(self):
for node in self.graph:
if node.type == "foreach":
@@ -638,6 +651,7 @@ def _visit(node, workflow, exit_node=None):
if self.workflow_timeout is not None and self.schedule is not None:
airflow_dag_args["dagrun_timeout"] = dict(seconds=self.workflow_timeout)

appending_sensors = self._collect_flow_sensors()
workflow = Workflow(
dag_id=self.name,
default_args=self._create_defaults(),
@@ -658,6 +672,10 @@ def _visit(node, workflow, exit_node=None):
workflow = _visit(self.graph["start"], workflow)

workflow.set_parameters(self.parameters)
if len(appending_sensors) > 0:
for s in appending_sensors:
workflow.add_state(s)
workflow.graph_structure.insert(0, [[s.name] for s in appending_sensors])
return self._to_airflow_dag_file(workflow.to_dict())

def _to_airflow_dag_file(self, json_dag):
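As a rough illustration of the sensor wiring above (the step and sensor names below are made up, and the nesting is a toy stand-in for the real `graph_structure`), the `insert(0, ...)` call prepends the sensor tasks as a new first layer, and `_depends_on_upstream_sensors` signals to the DAG-building code that the Metaflow steps should run downstream of them:

# Hypothetical compiled structure before sensors are added.
graph_structure = [["start"], ["a", "b"], ["join"], ["end"]]
sensor_names = ["ExternalTaskSensor-ab12cd", "S3KeySensor-ef34gh"]

# Same operation as shown above.
graph_structure.insert(0, [[name] for name in sensor_names])

# graph_structure is now:
# [[["ExternalTaskSensor-ab12cd"], ["S3KeySensor-ef34gh"]],
#  ["start"], ["a", "b"], ["join"], ["end"]]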
10 changes: 2 additions & 8 deletions metaflow/plugins/airflow/airflow_cli.py
@@ -322,7 +322,6 @@ def make_flow(


def _validate_foreach_constraints(graph):
# Todo :Invoke this function when we integrate `foreach`s
def traverse_graph(node, state):
if node.type == "foreach" and node.is_inside_foreach:
raise NotSupportedException(
@@ -338,7 +337,7 @@ def traverse_graph(node, state):
if node.type == "linear" and node.is_inside_foreach:
state["foreach_stack"].append(node.name)

if len(state["foreach_stack"]) > 2:
if "foreach_stack" in state and len(state["foreach_stack"]) > 2:
raise NotSupportedException(
"The foreach step *%s* created by step *%s* needs to have an immediate join step. "
"Step *%s* is invalid since it is a linear step with a foreach. "
@@ -378,18 +377,13 @@ def _validate_workflow(flow, graph, flow_datastore, metadata, workflow_timeout):
"A default value is required for parameters when deploying flows on Airflow."
)
# check for other compute related decorators.
_validate_foreach_constraints(graph)
for node in graph:
if node.parallel_foreach:
raise AirflowException(
"Deploying flows with @parallel decorator(s) "
"to Airflow is not supported currently."
)

if node.type == "foreach":
raise NotSupportedException(
"Step *%s* is a foreach step and Foreach steps are not currently supported with Airflow."
% node.name
)
if any([d.name == "batch" for d in node.decorators]):
raise NotSupportedException(
"Step *%s* is marked for execution on AWS Batch with Airflow which isn't currently supported."
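With the blanket foreach rejection removed and `_validate_foreach_constraints` wired into `_validate_workflow`, a single-level foreach whose body joins back immediately is the shape this validation is meant to accept; nested foreaches are still rejected. A hedged sketch of such a flow (the flow and step names are illustrative, not taken from this PR):

from metaflow import FlowSpec, step

class ForeachDemoFlow(FlowSpec):
    @step
    def start(self):
        self.items = ["a", "b", "c"]
        self.next(self.process, foreach="items")

    @step
    def process(self):
        # One linear step inside the foreach, followed by an immediate join.
        self.result = self.input.upper()
        self.next(self.join)

    @step
    def join(self, inputs):
        self.results = [i.result for i in inputs]
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    ForeachDemoFlow()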
63 changes: 63 additions & 0 deletions metaflow/plugins/airflow/airflow_utils.py
@@ -44,6 +44,10 @@ class IncompatibleKubernetesProviderVersionException(Exception):
) % (sys.executable, KUBERNETES_PROVIDER_FOREACH_VERSION)


class AirflowSensorNotFound(Exception):
headline = "Sensor package not found"


def create_absolute_version_number(version):
abs_version = None
# For all digits
@@ -189,6 +193,16 @@ def pathspec(cls, flowname, is_foreach=False):
)


class SensorNames:
EXTERNAL_TASK_SENSOR = "ExternalTaskSensor"
S3_SENSOR = "S3KeySensor"
SQL_SENSOR = "SQLSensor"

@classmethod
def get_supported_sensors(cls):
return list(cls.__dict__.values())


def run_id_creator(val):
# join `[dag-id,run-id]` of airflow dag.
return hashlib.md5("-".join([str(x) for x in val]).encode("utf-8")).hexdigest()[
@@ -375,6 +389,46 @@ def _kubernetes_pod_operator_args(operator_args):
return args


def _parse_sensor_args(name, kwargs):
if name == SensorNames.EXTERNAL_TASK_SENSOR:
if "execution_delta" in kwargs:
if type(kwargs["execution_delta"]) == dict:
kwargs["execution_delta"] = timedelta(**kwargs["execution_delta"])
else:
del kwargs["execution_delta"]
return kwargs


def _get_sensor(name):
if name == SensorNames.EXTERNAL_TASK_SENSOR:
# ExternalTaskSensor uses the execution_date of a DAG to
# determine the appropriate DAG run to wait for.
# This is set to the exact date the current DAG gets executed on.
# For example, if "DagA" (upstream DAG) got scheduled at
# 12 Jan 4:00 PM PDT, then "DagB"'s (current DAG) task sensor will
# look for a "DagA" run that executed at 12 Jan 4:00 PM PDT **exactly**.
# The sensor also accepts an `execution_timeout` argument.
from airflow.sensors.external_task_sensor import ExternalTaskSensor

return ExternalTaskSensor
elif name == SensorNames.S3_SENSOR:
try:
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor
except ImportError:
raise AirflowSensorNotFound(
"This DAG requires a `S3KeySensor`. "
"Install the Airflow AWS provider using : "
"`pip install apache-airflow-providers-amazon`"
)
return S3KeySensor
elif name == SensorNames.SQL_SENSOR:
from airflow.sensors.sql import SqlSensor

return SqlSensor


def get_metaflow_kubernetes_operator():
try:
from airflow.contrib.operators.kubernetes_pod_operator import (
@@ -493,6 +547,13 @@ def set_operator_args(self, **kwargs):
self._operator_args = kwargs
return self

def _make_sensor(self):
TaskSensor = _get_sensor(self._operator_type)
return TaskSensor(
task_id=self.name,
**_parse_sensor_args(self._operator_type, self._operator_args)
)

def to_dict(self):
return {
"name": self.name,
@@ -541,6 +602,8 @@ def to_task(self):
return self._kubernetes_task()
else:
return self._kubernetes_mapper_task()
elif self._operator_type in SensorNames.get_supported_sensors():
return self._make_sensor()


class Workflow(object):
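A small usage sketch of the `_parse_sensor_args` helper added above (the `external_dag_id` value is illustrative): the JSON-friendly dict form of `execution_delta` is rebuilt into a `timedelta` before the sensor class is instantiated:

from datetime import timedelta

kwargs = {"external_dag_id": "upstream_dag", "execution_delta": {"hours": 1}}
parsed = _parse_sensor_args(SensorNames.EXTERNAL_TASK_SENSOR, kwargs)
assert parsed["execution_delta"] == timedelta(hours=1)
# A non-dict execution_delta (e.g. None) is dropped from the kwargs instead.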
9 changes: 9 additions & 0 deletions metaflow/plugins/airflow/sensors/__init__.py
@@ -0,0 +1,9 @@
from .external_task_sensor import ExternalTaskSensorDecorator
from .s3_sensor import S3KeySensorDecorator
from .sql_sensor import SQLSensorDecorator

SUPPORTED_SENSORS = [
ExternalTaskSensorDecorator,
S3KeySensorDecorator,
SQLSensorDecorator,
]
74 changes: 74 additions & 0 deletions metaflow/plugins/airflow/sensors/base_sensor.py
@@ -0,0 +1,74 @@
import uuid
from metaflow.decorators import FlowDecorator
from ..exception import AirflowException
from ..airflow_utils import AirflowTask, id_creator, TASK_ID_HASH_LEN


class AirflowSensorDecorator(FlowDecorator):
"""
Base class for all Airflow sensor decorators.
"""

allow_multiple = True

defaults = dict(
timeout=3600,
poke_interval=60,
mode="reschedule",
exponential_backoff=True,
pool=None,
soft_fail=False,
name=None,
description=None,
)

operator_type = None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._airflow_task_name = None
self._id = str(uuid.uuid4())

def serialize_operator_args(self):
"""
Subclasses can override this to convert decorator arguments into
Airflow-serializable task arguments.
"""
task_args = dict(**self.attributes)
del task_args["name"]
if task_args["description"] is not None:
task_args["doc"] = task_args["description"]
del task_args["description"]
task_args["do_xcom_push"] = True
return task_args

def create_task(self):
task_args = self.serialize_operator_args()
return AirflowTask(
self._airflow_task_name,
operator_type=self.operator_type,
).set_operator_args(**{k: v for k, v in task_args.items() if v is not None})

def validate(self):
"""
Validate if the arguments for the sensor are correct.
"""
# If there is no name set then auto-generate the name. This is done because there can be more than
# one `AirflowSensorDecorator` of the same type.
if self.attributes["name"] is None:
deco_index = [
d._id
for d in self._flow_decorators
if issubclass(d.__class__, AirflowSensorDecorator)
].index(self._id)
self._airflow_task_name = "%s-%s" % (
self.operator_type,
id_creator([self.operator_type, str(deco_index)], TASK_ID_HASH_LEN),
)
else:
self._airflow_task_name = self.attributes["name"]

def flow_init(
self, flow, graph, environment, flow_datastore, metadata, logger, echo, options
):
self.validate()
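To show how the pieces of the base class fit together, here is a hedged sketch of what a concrete sensor decorator built on it could look like. The decorator name, the extra defaults, and the validation rule are illustrative and are not the actual `S3KeySensorDecorator` from this PR:

from metaflow.plugins.airflow.airflow_utils import SensorNames
from metaflow.plugins.airflow.exception import AirflowException
from metaflow.plugins.airflow.sensors.base_sensor import AirflowSensorDecorator


class MyS3SensorDecorator(AirflowSensorDecorator):
    # Hypothetical decorator name; the real sensors define their own.
    name = "my_airflow_s3_sensor"
    operator_type = SensorNames.S3_SENSOR  # maps to Airflow's S3KeySensor

    # Base defaults (timeout, poke_interval, mode, ...) plus sensor-specific args.
    defaults = dict(
        AirflowSensorDecorator.defaults,
        bucket_key=None,
        bucket_name=None,
    )

    def validate(self):
        if self.attributes["bucket_key"] is None:
            raise AirflowException("`bucket_key` is required for @%s." % self.name)
        super().validate()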