Airflow ECS dbt operator
import logging
import pprint
import sys

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.plugins_manager import AirflowPlugin
from airflow.utils.decorators import apply_defaults
from airflow.contrib.hooks.aws_hook import AwsHook

logger = logging.getLogger('airflow.dbt_operator')

class DbtOperator(BaseOperator):
    """Run a dbt command as an ECS Fargate task and fetch its CloudWatch logs."""

    ui_color = '#D6EAF8'
    client = None
    arn = None

    @apply_defaults
    def __init__(self,
                 command,
                 target='dev',
                 dbt_models=None,
                 dbt_exclude=None,
                 subnets=None,
                 security_groups=None,
                 aws_connection_id='aws_default',
                 region_name='eu-west-1',
                 cluster='dbt-serverless',
                 task_definition='dbt-serverless-task',
                 log_group_name='/aws/ecs/dbt-serverless',
                 log_stream_name='dbt',  # TODO to check
                 **kwargs):
        super(DbtOperator, self).__init__(**kwargs)
        self.command = command
        self.dbt_models = dbt_models
        self.target = target
        self.dbt_exclude = dbt_exclude
        self.aws_conn_id = aws_connection_id
        self.region_name = region_name
        self.cluster = cluster
        self.task_definition = task_definition
        self.log_group_name = log_group_name
        self.log_stream_name = log_stream_name
        # default to None above and fall back here, to avoid mutable default arguments
        self.subnets = subnets or []
        self.security_groups = security_groups or []
        self.hook = self.get_hook()

    def execute(self, context):
        container_command = ['dbt', f'{self.command}', '--target', f'{self.target}']
        if self.dbt_models is not None:
            container_command.extend(['--models', f'{self.dbt_models}'])
        if self.dbt_exclude is not None:
            container_command.extend(['--exclude', f'{self.dbt_exclude}'])
        overrides = {
            'containerOverrides': [
                {
                    'name': 'dbt',
                    'command': container_command
                }
            ]
        }
        logger.info(f'Running ECS Task - Task definition: {self.task_definition} - on cluster {self.cluster}')
        logger.debug('ECSOperator overrides: %s', overrides)
        self.client = self.hook.get_client_type(
            'ecs',
            region_name=self.region_name
        )
        response = self.client.run_task(
            cluster=self.cluster,
            taskDefinition=self.task_definition,
            launchType='FARGATE',
            overrides=overrides,
            startedBy=f'{self.target}_{self.command}',
            networkConfiguration={'awsvpcConfiguration': {
                'subnets': self.subnets,
                'assignPublicIp': 'ENABLED',  # keep it enabled, otherwise Fargate fails to pull the image
                'securityGroups': self.security_groups
            }}
        )
        failures = response['failures']
        if len(failures) > 0:
            raise AirflowException(response)

        logger.info(f'ECS Task {self.task_definition} started')
        logger.debug('ECS Task started: %s', pprint.pformat(response))

        self.arn = response['tasks'][0]['taskArn']
        # store the ECS task id under its own name; don't shadow BaseOperator.task_id,
        # which Airflow itself relies on
        self.ecs_task_id = self.arn.split('/')[1]
        self._wait_for_task_ended()
        self._check_success_task()
        logger.debug('ECS Task has been successfully executed: %s', pprint.pformat(response))
        logger.info('Retrieving logs from Cloudwatch')
        self._get_cloudwatch_logs()
        logger.info(f'{self.ecs_task_id} task has been successfully executed in ECS cluster {self.cluster}')

    def _wait_for_task_ended(self):
        waiter = self.client.get_waiter('tasks_stopped')
        waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
        waiter.wait(
            cluster=self.cluster,
            tasks=[self.arn]
        )

    def _check_success_task(self):
        response = self.client.describe_tasks(
            cluster=self.cluster,
            tasks=[self.arn]
        )
        logger.info(f'ECS Task {self.ecs_task_id} stopped')
        logger.debug('ECS Task stopped, check status: %s', pprint.pformat(response))
        if len(response.get('failures', [])) > 0:
            raise AirflowException(response)

        for task in response['tasks']:
            containers = task['containers']
            for container in containers:
                if container.get('lastStatus') == 'STOPPED' and container['exitCode'] != 0:
                    self._get_cloudwatch_logs()
                    raise AirflowException('This task is not in success state {}'.format(task))
                elif container.get('lastStatus') == 'PENDING':
                    self._get_cloudwatch_logs()
                    raise AirflowException('This task is still pending {}'.format(task))
                elif 'error' in container.get('reason', '').lower():
                    self._get_cloudwatch_logs()
                    raise AirflowException('This container encountered an error during launch: {}'.format(
                        container.get('reason', '').lower()))

    def _get_cloudwatch_logs(self):
        try:
            cloudwatch_client = self.hook.get_client_type(
                'logs',
                region_name=self.region_name
            )
            raw_logs = cloudwatch_client.get_log_events(
                logGroupName=self.log_group_name,
                logStreamName=f'{self.log_stream_name}/{self.ecs_task_id}',
                startFromHead=True
            )
            for event in raw_logs.get('events'):
                logger.info(f'{event.get("message")}')
        except Exception as error:
            logger.error(f'There was an error fetching Cloudwatch logs for task {self.ecs_task_id}')
            logger.error(error)
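
    # NOTE (assumption, matching the author's "TODO to check" above): with the ECS
    # `awslogs` log driver, streams are named
    # '<awslogs-stream-prefix>/<container-name>/<ecs-task-id>', so the lookup in
    # _get_cloudwatch_logs resolves only if log_stream_name covers the
    # '<prefix>/<container-name>' part for your task definition's logConfiguration;
    # adjust it if your streams are named differently.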

    def get_hook(self):
        return AwsHook(
            aws_conn_id=self.aws_conn_id
        )

    def on_kill(self):
        if not self.arn:
            # nothing to stop: the ECS task was never started
            return
        response = self.client.stop_task(
            cluster=self.cluster,
            task=self.arn,
            reason='Task killed by the user')
        logger.info('Task killed by the user')
        logger.debug(pprint.pformat(response))

class DbtPlugin(AirflowPlugin):
    name = 'dbt_plugin'
    operators = [DbtOperator]
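
A minimal sketch of how this operator could be wired into a DAG. The dag_id, schedule, model selector, subnet and security-group IDs below are placeholders, not values from the gist; in Airflow 1.10 the plugin above also makes the operator importable from airflow.operators.dbt_plugin.

# Example DAG using DbtOperator (a sketch; all IDs and values are placeholders)
from datetime import datetime

from airflow import DAG
from airflow.operators.dbt_plugin import DbtOperator  # exposed by DbtPlugin in Airflow 1.10

with DAG(dag_id='dbt_daily',
         start_date=datetime(2019, 1, 1),
         schedule_interval='@daily',
         catchup=False) as dag:

    dbt_run = DbtOperator(
        task_id='dbt_run',
        command='run',
        target='dev',
        dbt_models='my_model+',           # placeholder dbt model selector
        subnets=['subnet-00000000'],      # placeholder subnet ID
        security_groups=['sg-00000000'],  # placeholder security group ID
    )

    dbt_test = DbtOperator(
        task_id='dbt_test',
        command='test',
        target='dev',
        subnets=['subnet-00000000'],
        security_groups=['sg-00000000'],
    )

    dbt_run >> dbt_test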