Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion .generator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,12 @@ def encode(self, obj):
JINJA_ENV.filters["safe_snake_case"] = safe_snake_case
JINJA_ENV.globals["format_data_with_schema"] = format_data_with_schema
JINJA_ENV.globals["format_parameters"] = format_parameters
JINJA_ENV.globals["package"] = "datadog_api_client"

PYTHON_EXAMPLE_J2 = JINJA_ENV.get_template("example.j2")

DATADOG_EXAMPLES_J2 = {
"aws.py": JINJA_ENV.get_template("example_aws.j2")
}

def pytest_bdd_after_scenario(request, feature, scenario):
try:
Expand Down Expand Up @@ -137,6 +140,19 @@ def pytest_bdd_after_scenario(request, feature, scenario):
with output.open("w") as f:
f.write(data)

for file_name, template in DATADOG_EXAMPLES_J2.items():
output = ROOT_PATH / "examples" / "datadog" / file_name
output.parent.mkdir(parents=True, exist_ok=True)

data = template.render(
context=context,
version=version,
scenario=scenario,
operation_spec=operation_spec.spec,
)
with output.open("w") as f:
f.write(data)


def pytest_bdd_apply_tag(tag, function):
"""Register tags as custom markers and skip test for '@skip' ones."""
Expand Down
2 changes: 2 additions & 0 deletions .generator/src/generator/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def cli(specs, output):
"exceptions.py": env.get_template("exceptions.j2"),
"model_utils.py": env.get_template("model_utils.j2"),
"rest.py": env.get_template("rest.j2"),
"delegated_auth.py": env.get_template("delegated_auth.j2"),
"aws.py": env.get_template("aws.j2"),
}

top_package = output / PACKAGE_NAME
Expand Down
74 changes: 62 additions & 12 deletions .generator/src/generator/templates/api_client.j2
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,40 @@ class ApiClient:
return "application/json"
return content_types[0]

def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None:
"""Use delegated token authentication if configured.

:param headers: Header parameters dict to be updated.
:raises: ApiValueError if delegated token authentication fails
"""
from datetime import datetime
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought: Is this necessary, or can it be rolled up into a more global import?

from {{ package }}.delegated_auth import DelegatedTokenConfig

# Check if we have cached credentials
if not hasattr(self.configuration, '_delegated_token_credentials'):
self.configuration._delegated_token_credentials = None

# Check if we need to get or refresh the token
if (self.configuration._delegated_token_credentials is None or
self.configuration._delegated_token_credentials.is_expired()):

# Create config for the provider
config = DelegatedTokenConfig(
org_uuid=self.configuration.delegated_auth_org_uuid,
provider="aws", # This could be made configurable
provider_auth=self.configuration.delegated_auth_provider
)

# Get new token from provider, passing the API configuration
try:
self.configuration._delegated_token_credentials = self.configuration.delegated_auth_provider.authenticate(config, self.configuration)
except Exception as e:
raise ApiValueError(f"Failed to get delegated token: {str(e)}")

# Set the Authorization header with the delegated token
token = self.configuration._delegated_token_credentials.delegated_token
headers["Authorization"] = f"Bearer {token}"


class ThreadedApiClient(ApiClient):

Expand Down Expand Up @@ -824,18 +858,34 @@ class Endpoint:
if not self.settings["auth"]:
return

for auth in self.settings["auth"]:
auth_setting = self.api_client.configuration.auth_settings().get(auth)
if auth_setting:
if auth_setting["in"] == "header":
if auth_setting["type"] != "http-signature":
if auth_setting["value"] is None:
raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"]))
headers[auth_setting["key"]] = auth_setting["value"]
elif auth_setting["in"] == "query":
queries.append((auth_setting["key"], auth_setting["value"]))
else:
raise ApiValueError("Authentication token must be in `query` or `header`")
# check if endpoint uses appKeyAuth and if delegated token config is available
has_app_key_auth = "appKeyAuth" in self.settings["auth"]

# Check if delegated auth is configured (using our actual attributes)
has_delegated_auth = (
hasattr(self.api_client.configuration, 'delegated_auth_provider') and
self.api_client.configuration.delegated_auth_provider is not None and
hasattr(self.api_client.configuration, 'delegated_auth_org_uuid') and
self.api_client.configuration.delegated_auth_org_uuid is not None
)

if has_app_key_auth and has_delegated_auth:
# Use delegated token authentication
self.api_client.use_delegated_token_auth(headers)
else:
# Use regular authentication
for auth in self.settings["auth"]:
auth_setting = self.api_client.configuration.auth_settings().get(auth)
if auth_setting:
if auth_setting["in"] == "header":
if auth_setting["type"] != "http-signature":
if auth_setting["value"] is None:
raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"]))
headers[auth_setting["key"]] = auth_setting["value"]
elif auth_setting["in"] == "query":
queries.append((auth_setting["key"], auth_setting["value"]))
else:
raise ApiValueError("Authentication token must be in `query` or `header`")


def user_agent() -> str:
Expand Down
262 changes: 262 additions & 0 deletions .generator/src/generator/templates/aws.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
{% include "api_info.j2" %}

import base64
import hashlib
import hmac
import json
import os
from datetime import datetime
from typing import Dict, List, Optional
from urllib.parse import quote

from {{ package }}.configuration import Configuration
from {{ package }}.delegated_auth import DelegatedTokenProvider, DelegatedTokenConfig, DelegatedTokenCredentials, get_delegated_token
from {{ package }}.exceptions import ApiValueError


# AWS specific constants
AWS_ACCESS_KEY_ID_NAME = "AWS_ACCESS_KEY_ID"
AWS_SECRET_ACCESS_KEY_NAME = "AWS_SECRET_ACCESS_KEY"
AWS_SESSION_TOKEN_NAME = "AWS_SESSION_TOKEN"

AMZ_DATE_HEADER = "X-Amz-Date"
AMZ_TOKEN_HEADER = "X-Amz-Security-Token"
AMZ_DATE_FORMAT = "%Y%m%d"
AMZ_DATE_TIME_FORMAT = "%Y%m%dT%H%M%SZ"
DEFAULT_REGION = "us-east-1"
DEFAULT_STS_HOST = "sts.amazonaws.com"
REGIONAL_STS_HOST = "sts.{}.amazonaws.com"
SERVICE = "sts"
ALGORITHM = "AWS4-HMAC-SHA256"
AWS4_REQUEST = "aws4_request"
GET_CALLER_IDENTITY_BODY = "Action=GetCallerIdentity&Version=2011-06-15"

# Common Headers
ORG_ID_HEADER = "x-ddog-org-id"
HOST_HEADER = "host"
APPLICATION_FORM = "application/x-www-form-urlencoded; charset=utf-8"

PROVIDER_AWS = "aws"


class AWSCredentials:
"""AWS credentials for authentication."""

def __init__(self, access_key_id: str, secret_access_key: str, session_token: str):
self.access_key_id = access_key_id
self.secret_access_key = secret_access_key
self.session_token = session_token


class SigningData:
"""Data structure for AWS signing information."""

def __init__(self, headers_encoded: str, body_encoded: str, url_encoded: str, method: str):
self.headers_encoded = headers_encoded
self.body_encoded = body_encoded
self.url_encoded = url_encoded
self.method = method


class AWSAuth(DelegatedTokenProvider):
"""AWS authentication provider for delegated tokens."""

def __init__(self, aws_region: Optional[str] = None):
self.aws_region = aws_region

def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials:
"""Authenticate using AWS credentials and return delegated token credentials.

:param config: Delegated token configuration
:param api_config: API client configuration with host and other settings
:return: DelegatedTokenCredentials object
:raises: ApiValueError if authentication fails
"""
# Check org UUID first
if not config or not config.org_uuid:
raise ApiValueError("Missing org UUID in config")

# Get local AWS Credentials
creds = self.get_credentials()

# Use the credentials to generate the signing data
data = self.generate_aws_auth_data(config.org_uuid, creds)

# Generate the auth string passed to the token endpoint
auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}"

# Pass the api_config to get_delegated_token to use the correct host
auth_response = get_delegated_token(config.org_uuid, auth_string, api_config)
return auth_response

def get_credentials(self) -> AWSCredentials:
"""Get AWS credentials from environment variables.

:return: AWSCredentials object
:raises: ApiValueError if credentials are missing
"""
access_key = os.getenv(AWS_ACCESS_KEY_ID_NAME)
secret_key = os.getenv(AWS_SECRET_ACCESS_KEY_NAME)
session_token = os.getenv(AWS_SESSION_TOKEN_NAME)

if not access_key or not secret_key or not session_token:
raise ApiValueError("Missing AWS credentials. Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables.")

return AWSCredentials(
access_key_id=access_key,
secret_access_key=secret_key,
session_token=session_token
)

def _get_connection_parameters(self) -> tuple[str, str, str]:
"""Get connection parameters for AWS STS.

:return: Tuple of (sts_full_url, region, host)
"""
region = self.aws_region or DEFAULT_REGION

if self.aws_region:
host = REGIONAL_STS_HOST.format(region)
else:
host = DEFAULT_STS_HOST

sts_full_url = f"https://{host}"
return sts_full_url, region, host

def generate_aws_auth_data(self, org_uuid: str, creds: AWSCredentials) -> SigningData:
"""Generate AWS authentication data for signing.

:param org_uuid: Organization UUID
:param creds: AWS credentials
:return: SigningData object
:raises: ApiValueError if generation fails
"""
if not org_uuid:
raise ApiValueError("Missing org UUID")

if not creds or not creds.access_key_id or not creds.secret_access_key or not creds.session_token:
raise ApiValueError("Missing AWS credentials")

sts_full_url, region, host = self._get_connection_parameters()

now = datetime.utcnow()

request_body = GET_CALLER_IDENTITY_BODY
payload_hash = hashlib.sha256(request_body.encode('utf-8')).hexdigest()

# Create the headers that factor into the signing algorithm
header_map = {
"Content-Length": [str(len(request_body))],
"Content-Type": [APPLICATION_FORM],
AMZ_DATE_HEADER: [now.strftime(AMZ_DATE_TIME_FORMAT)],
ORG_ID_HEADER: [org_uuid],
AMZ_TOKEN_HEADER: [creds.session_token],
HOST_HEADER: [host],
}

# Create canonical headers
header_arr = []
signed_headers_arr = []

for k, v in header_map.items():
lowered_header_name = k.lower()
header_arr.append(f"{lowered_header_name}:{','.join(v)}")
signed_headers_arr.append(lowered_header_name)

header_arr.sort()
signed_headers_arr.sort()
signed_headers = ";".join(signed_headers_arr)

canonical_request = "\n".join([
"POST",
"/",
"", # No query string
"\n".join(header_arr) + "\n",
signed_headers,
payload_hash,
])

# Create the string to sign
hash_canonical_request = hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()
credential_scope = "/".join([
now.strftime(AMZ_DATE_FORMAT),
region,
SERVICE,
AWS4_REQUEST,
])

string_to_sign = self._make_signature(
now,
credential_scope,
hash_canonical_request,
region,
SERVICE,
creds.secret_access_key,
ALGORITHM,
)

# Create the authorization header
credential = f"{creds.access_key_id}/{credential_scope}"
auth_header = f"{ALGORITHM} Credential={credential}, SignedHeaders={signed_headers}, Signature={string_to_sign}"

header_map["Authorization"] = [auth_header]
header_map["User-Agent"] = [self._get_user_agent()]

headers_json = json.dumps(header_map, separators=(',', ':'))

return SigningData(
headers_encoded=base64.b64encode(headers_json.encode('utf-8')).decode('utf-8'),
body_encoded=base64.b64encode(request_body.encode('utf-8')).decode('utf-8'),
method="POST",
url_encoded=base64.b64encode(sts_full_url.encode('utf-8')).decode('utf-8')
)

def _make_signature(self, t: datetime, credential_scope: str, payload_hash: str,
region: str, service: str, secret_access_key: str, algorithm: str) -> str:
"""Create AWS signature.

:param t: Current datetime
:param credential_scope: Credential scope string
:param payload_hash: Hash of the canonical request
:param region: AWS region
:param service: AWS service name
:param secret_access_key: AWS secret access key
:param algorithm: Signing algorithm
:return: Signature string
"""
# Create the string to sign
string_to_sign = "\n".join([
algorithm,
t.strftime(AMZ_DATE_TIME_FORMAT),
credential_scope,
payload_hash,
])

# Create the signing key
k_date = self._hmac256(t.strftime(AMZ_DATE_FORMAT), f"AWS4{secret_access_key}".encode('utf-8'))
k_region = self._hmac256(region, k_date)
k_service = self._hmac256(service, k_region)
k_signing = self._hmac256(AWS4_REQUEST, k_service)

# Sign the string
signature = self._hmac256(string_to_sign, k_signing)
return signature.hex()

def _hmac256(self, data: str, key: bytes) -> bytes:
"""Create HMAC-SHA256 hash.

:param data: Data to hash
:param key: Key for HMAC
:return: HMAC hash bytes
"""
return hmac.new(key, data.encode('utf-8'), hashlib.sha256).digest()

def _get_user_agent(self) -> str:
"""Get user agent string.

:return: User agent string
"""
import platform
from {{ package }}.version import __version__

return f"datadog-api-client-python/{__version__} (python {platform.python_version()}; os {platform.system()}; arch {platform.machine()})"
Loading