diff --git a/.generator/conftest.py b/.generator/conftest.py index 56a1e24057..ada05747ec 100644 --- a/.generator/conftest.py +++ b/.generator/conftest.py @@ -82,9 +82,12 @@ def encode(self, obj): JINJA_ENV.filters["safe_snake_case"] = safe_snake_case JINJA_ENV.globals["format_data_with_schema"] = format_data_with_schema JINJA_ENV.globals["format_parameters"] = format_parameters +JINJA_ENV.globals["package"] = "datadog_api_client" PYTHON_EXAMPLE_J2 = JINJA_ENV.get_template("example.j2") - +DATADOG_EXAMPLES_J2 = { + "aws.py": JINJA_ENV.get_template("example_aws.j2") +} def pytest_bdd_after_scenario(request, feature, scenario): try: @@ -137,6 +140,19 @@ def pytest_bdd_after_scenario(request, feature, scenario): with output.open("w") as f: f.write(data) + for file_name, template in DATADOG_EXAMPLES_J2.items(): + output = ROOT_PATH / "examples" / "datadog" / file_name + output.parent.mkdir(parents=True, exist_ok=True) + + data = template.render( + context=context, + version=version, + scenario=scenario, + operation_spec=operation_spec.spec, + ) + with output.open("w") as f: + f.write(data) + def pytest_bdd_apply_tag(tag, function): """Register tags as custom markers and skip test for '@skip' ones.""" diff --git a/.generator/src/generator/cli.py b/.generator/src/generator/cli.py index 215d0cc156..982f6a9eac 100644 --- a/.generator/src/generator/cli.py +++ b/.generator/src/generator/cli.py @@ -72,6 +72,8 @@ def cli(specs, output): "exceptions.py": env.get_template("exceptions.j2"), "model_utils.py": env.get_template("model_utils.j2"), "rest.py": env.get_template("rest.j2"), + "delegated_auth.py": env.get_template("delegated_auth.j2"), + "aws.py": env.get_template("aws.j2"), } top_package = output / PACKAGE_NAME diff --git a/.generator/src/generator/templates/api_client.j2 b/.generator/src/generator/templates/api_client.j2 index 77b7b93904..877d0781ed 100644 --- a/.generator/src/generator/templates/api_client.j2 +++ b/.generator/src/generator/templates/api_client.j2 @@ -454,6 +454,40 @@ class ApiClient: return "application/json" return content_types[0] + def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: + """Use delegated token authentication if configured. + + :param headers: Header parameters dict to be updated. + :raises: ApiValueError if delegated token authentication fails + """ + from datetime import datetime + from {{ package }}.delegated_auth import DelegatedTokenConfig + + # Check if we have cached credentials + if not hasattr(self.configuration, '_delegated_token_credentials'): + self.configuration._delegated_token_credentials = None + + # Check if we need to get or refresh the token + if (self.configuration._delegated_token_credentials is None or + self.configuration._delegated_token_credentials.is_expired()): + + # Create config for the provider + config = DelegatedTokenConfig( + org_uuid=self.configuration.delegated_auth_org_uuid, + provider="aws", # This could be made configurable + provider_auth=self.configuration.delegated_auth_provider + ) + + # Get new token from provider, passing the API configuration + try: + self.configuration._delegated_token_credentials = self.configuration.delegated_auth_provider.authenticate(config, self.configuration) + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + # Set the Authorization header with the delegated token + token = self.configuration._delegated_token_credentials.delegated_token + headers["Authorization"] = f"Bearer {token}" + class ThreadedApiClient(ApiClient): @@ -824,18 +858,34 @@ class Endpoint: if not self.settings["auth"]: return - for auth in self.settings["auth"]: - auth_setting = self.api_client.configuration.auth_settings().get(auth) - if auth_setting: - if auth_setting["in"] == "header": - if auth_setting["type"] != "http-signature": - if auth_setting["value"] is None: - raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) - headers[auth_setting["key"]] = auth_setting["value"] - elif auth_setting["in"] == "query": - queries.append((auth_setting["key"], auth_setting["value"])) - else: - raise ApiValueError("Authentication token must be in `query` or `header`") + # check if endpoint uses appKeyAuth and if delegated token config is available + has_app_key_auth = "appKeyAuth" in self.settings["auth"] + + # Check if delegated auth is configured (using our actual attributes) + has_delegated_auth = ( + hasattr(self.api_client.configuration, 'delegated_auth_provider') and + self.api_client.configuration.delegated_auth_provider is not None and + hasattr(self.api_client.configuration, 'delegated_auth_org_uuid') and + self.api_client.configuration.delegated_auth_org_uuid is not None + ) + + if has_app_key_auth and has_delegated_auth: + # Use delegated token authentication + self.api_client.use_delegated_token_auth(headers) + else: + # Use regular authentication + for auth in self.settings["auth"]: + auth_setting = self.api_client.configuration.auth_settings().get(auth) + if auth_setting: + if auth_setting["in"] == "header": + if auth_setting["type"] != "http-signature": + if auth_setting["value"] is None: + raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) + headers[auth_setting["key"]] = auth_setting["value"] + elif auth_setting["in"] == "query": + queries.append((auth_setting["key"], auth_setting["value"])) + else: + raise ApiValueError("Authentication token must be in `query` or `header`") def user_agent() -> str: diff --git a/.generator/src/generator/templates/aws.j2 b/.generator/src/generator/templates/aws.j2 new file mode 100644 index 0000000000..579f0e971d --- /dev/null +++ b/.generator/src/generator/templates/aws.j2 @@ -0,0 +1,262 @@ +{% include "api_info.j2" %} + +import base64 +import hashlib +import hmac +import json +import os +from datetime import datetime +from typing import Dict, List, Optional +from urllib.parse import quote + +from {{ package }}.configuration import Configuration +from {{ package }}.delegated_auth import DelegatedTokenProvider, DelegatedTokenConfig, DelegatedTokenCredentials, get_delegated_token +from {{ package }}.exceptions import ApiValueError + + +# AWS specific constants +AWS_ACCESS_KEY_ID_NAME = "AWS_ACCESS_KEY_ID" +AWS_SECRET_ACCESS_KEY_NAME = "AWS_SECRET_ACCESS_KEY" +AWS_SESSION_TOKEN_NAME = "AWS_SESSION_TOKEN" + +AMZ_DATE_HEADER = "X-Amz-Date" +AMZ_TOKEN_HEADER = "X-Amz-Security-Token" +AMZ_DATE_FORMAT = "%Y%m%d" +AMZ_DATE_TIME_FORMAT = "%Y%m%dT%H%M%SZ" +DEFAULT_REGION = "us-east-1" +DEFAULT_STS_HOST = "sts.amazonaws.com" +REGIONAL_STS_HOST = "sts.{}.amazonaws.com" +SERVICE = "sts" +ALGORITHM = "AWS4-HMAC-SHA256" +AWS4_REQUEST = "aws4_request" +GET_CALLER_IDENTITY_BODY = "Action=GetCallerIdentity&Version=2011-06-15" + +# Common Headers +ORG_ID_HEADER = "x-ddog-org-id" +HOST_HEADER = "host" +APPLICATION_FORM = "application/x-www-form-urlencoded; charset=utf-8" + +PROVIDER_AWS = "aws" + + +class AWSCredentials: + """AWS credentials for authentication.""" + + def __init__(self, access_key_id: str, secret_access_key: str, session_token: str): + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.session_token = session_token + + +class SigningData: + """Data structure for AWS signing information.""" + + def __init__(self, headers_encoded: str, body_encoded: str, url_encoded: str, method: str): + self.headers_encoded = headers_encoded + self.body_encoded = body_encoded + self.url_encoded = url_encoded + self.method = method + + +class AWSAuth(DelegatedTokenProvider): + """AWS authentication provider for delegated tokens.""" + + def __init__(self, aws_region: Optional[str] = None): + self.aws_region = aws_region + + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: + """Authenticate using AWS credentials and return delegated token credentials. + + :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings + :return: DelegatedTokenCredentials object + :raises: ApiValueError if authentication fails + """ + # Check org UUID first + if not config or not config.org_uuid: + raise ApiValueError("Missing org UUID in config") + + # Get local AWS Credentials + creds = self.get_credentials() + + # Use the credentials to generate the signing data + data = self.generate_aws_auth_data(config.org_uuid, creds) + + # Generate the auth string passed to the token endpoint + auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}" + + # Pass the api_config to get_delegated_token to use the correct host + auth_response = get_delegated_token(config.org_uuid, auth_string, api_config) + return auth_response + + def get_credentials(self) -> AWSCredentials: + """Get AWS credentials from environment variables. + + :return: AWSCredentials object + :raises: ApiValueError if credentials are missing + """ + access_key = os.getenv(AWS_ACCESS_KEY_ID_NAME) + secret_key = os.getenv(AWS_SECRET_ACCESS_KEY_NAME) + session_token = os.getenv(AWS_SESSION_TOKEN_NAME) + + if not access_key or not secret_key or not session_token: + raise ApiValueError("Missing AWS credentials. Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables.") + + return AWSCredentials( + access_key_id=access_key, + secret_access_key=secret_key, + session_token=session_token + ) + + def _get_connection_parameters(self) -> tuple[str, str, str]: + """Get connection parameters for AWS STS. + + :return: Tuple of (sts_full_url, region, host) + """ + region = self.aws_region or DEFAULT_REGION + + if self.aws_region: + host = REGIONAL_STS_HOST.format(region) + else: + host = DEFAULT_STS_HOST + + sts_full_url = f"https://{host}" + return sts_full_url, region, host + + def generate_aws_auth_data(self, org_uuid: str, creds: AWSCredentials) -> SigningData: + """Generate AWS authentication data for signing. + + :param org_uuid: Organization UUID + :param creds: AWS credentials + :return: SigningData object + :raises: ApiValueError if generation fails + """ + if not org_uuid: + raise ApiValueError("Missing org UUID") + + if not creds or not creds.access_key_id or not creds.secret_access_key or not creds.session_token: + raise ApiValueError("Missing AWS credentials") + + sts_full_url, region, host = self._get_connection_parameters() + + now = datetime.utcnow() + + request_body = GET_CALLER_IDENTITY_BODY + payload_hash = hashlib.sha256(request_body.encode('utf-8')).hexdigest() + + # Create the headers that factor into the signing algorithm + header_map = { + "Content-Length": [str(len(request_body))], + "Content-Type": [APPLICATION_FORM], + AMZ_DATE_HEADER: [now.strftime(AMZ_DATE_TIME_FORMAT)], + ORG_ID_HEADER: [org_uuid], + AMZ_TOKEN_HEADER: [creds.session_token], + HOST_HEADER: [host], + } + + # Create canonical headers + header_arr = [] + signed_headers_arr = [] + + for k, v in header_map.items(): + lowered_header_name = k.lower() + header_arr.append(f"{lowered_header_name}:{','.join(v)}") + signed_headers_arr.append(lowered_header_name) + + header_arr.sort() + signed_headers_arr.sort() + signed_headers = ";".join(signed_headers_arr) + + canonical_request = "\n".join([ + "POST", + "/", + "", # No query string + "\n".join(header_arr) + "\n", + signed_headers, + payload_hash, + ]) + + # Create the string to sign + hash_canonical_request = hashlib.sha256(canonical_request.encode('utf-8')).hexdigest() + credential_scope = "/".join([ + now.strftime(AMZ_DATE_FORMAT), + region, + SERVICE, + AWS4_REQUEST, + ]) + + string_to_sign = self._make_signature( + now, + credential_scope, + hash_canonical_request, + region, + SERVICE, + creds.secret_access_key, + ALGORITHM, + ) + + # Create the authorization header + credential = f"{creds.access_key_id}/{credential_scope}" + auth_header = f"{ALGORITHM} Credential={credential}, SignedHeaders={signed_headers}, Signature={string_to_sign}" + + header_map["Authorization"] = [auth_header] + header_map["User-Agent"] = [self._get_user_agent()] + + headers_json = json.dumps(header_map, separators=(',', ':')) + + return SigningData( + headers_encoded=base64.b64encode(headers_json.encode('utf-8')).decode('utf-8'), + body_encoded=base64.b64encode(request_body.encode('utf-8')).decode('utf-8'), + method="POST", + url_encoded=base64.b64encode(sts_full_url.encode('utf-8')).decode('utf-8') + ) + + def _make_signature(self, t: datetime, credential_scope: str, payload_hash: str, + region: str, service: str, secret_access_key: str, algorithm: str) -> str: + """Create AWS signature. + + :param t: Current datetime + :param credential_scope: Credential scope string + :param payload_hash: Hash of the canonical request + :param region: AWS region + :param service: AWS service name + :param secret_access_key: AWS secret access key + :param algorithm: Signing algorithm + :return: Signature string + """ + # Create the string to sign + string_to_sign = "\n".join([ + algorithm, + t.strftime(AMZ_DATE_TIME_FORMAT), + credential_scope, + payload_hash, + ]) + + # Create the signing key + k_date = self._hmac256(t.strftime(AMZ_DATE_FORMAT), f"AWS4{secret_access_key}".encode('utf-8')) + k_region = self._hmac256(region, k_date) + k_service = self._hmac256(service, k_region) + k_signing = self._hmac256(AWS4_REQUEST, k_service) + + # Sign the string + signature = self._hmac256(string_to_sign, k_signing) + return signature.hex() + + def _hmac256(self, data: str, key: bytes) -> bytes: + """Create HMAC-SHA256 hash. + + :param data: Data to hash + :param key: Key for HMAC + :return: HMAC hash bytes + """ + return hmac.new(key, data.encode('utf-8'), hashlib.sha256).digest() + + def _get_user_agent(self) -> str: + """Get user agent string. + + :return: User agent string + """ + import platform + from {{ package }}.version import __version__ + + return f"datadog-api-client-python/{__version__} (python {platform.python_version()}; os {platform.system()}; arch {platform.machine()})" diff --git a/.generator/src/generator/templates/configuration.j2 b/.generator/src/generator/templates/configuration.j2 index c261378471..3fb24de09b 100644 --- a/.generator/src/generator/templates/configuration.j2 +++ b/.generator/src/generator/templates/configuration.j2 @@ -245,6 +245,9 @@ class Configuration: {%- endfor %} }) + # Delegated token configuration + self.delegated_token_config = None + # Load default values from environment if "DD_SITE" in os.environ: self.server_variables["site"] = os.environ["DD_SITE"] diff --git a/.generator/src/generator/templates/delegated_auth.j2 b/.generator/src/generator/templates/delegated_auth.j2 new file mode 100644 index 0000000000..3ea48bd072 --- /dev/null +++ b/.generator/src/generator/templates/delegated_auth.j2 @@ -0,0 +1,149 @@ +{% include "api_info.j2" %} + +import json +import time +from datetime import datetime, timedelta +from typing import Optional +from urllib.parse import urljoin + +from {{ package }} import rest +from {{ package }}.configuration import Configuration +from {{ package }}.exceptions import ApiValueError + + +TOKEN_URL_ENDPOINT = "/api/v2/delegated-token" +AUTHORIZATION_TYPE = "Delegated" +APPLICATION_JSON = "application/json" + + +class DelegatedTokenCredentials: + """Credentials for delegated token authentication.""" + + def __init__(self, org_uuid: str, delegated_token: str, delegated_proof: str, expiration: datetime): + self.org_uuid = org_uuid + self.delegated_token = delegated_token + self.delegated_proof = delegated_proof + self.expiration = expiration + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return datetime.now() >= self.expiration + + +class DelegatedTokenConfig: + """Configuration for delegated token authentication.""" + + def __init__(self, org_uuid: str, provider: str, provider_auth: 'DelegatedTokenProvider'): + self.org_uuid = org_uuid + self.provider = provider + self.provider_auth = provider_auth + + +class DelegatedTokenProvider: + """Abstract base class for delegated token providers.""" + + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: + """Authenticate and return delegated token credentials. + + :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings + :return: DelegatedTokenCredentials object + """ + raise NotImplementedError("Subclasses must implement authenticate method") + + +def get_delegated_token(org_uuid: str, delegated_auth_proof: str, config: Configuration) -> DelegatedTokenCredentials: + """Get a delegated token from the Datadog API. + + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :param config: Configuration object with host and other settings + :return: DelegatedTokenCredentials object + :raises: ApiValueError if the request fails + """ + url = get_delegated_token_url(config) + + # Create REST client + rest_client = rest.RESTClientObject(config) + + headers = { + "Content-Type": APPLICATION_JSON, + "Authorization": f"{AUTHORIZATION_TYPE} {delegated_auth_proof}", + "Content-Length": "0" + } + + try: + response = rest_client.request( + method="POST", + url=url, + headers=headers, + body="", + preload_content=True + ) + + if response.status != 200: + raise ApiValueError(f"Failed to get token: {response.status}") + + response_data = response.data.decode('utf-8') + creds = parse_delegated_token_response(response_data, org_uuid, delegated_auth_proof) + return creds + + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + +def parse_delegated_token_response(response_data: str, org_uuid: str, delegated_auth_proof: str) -> DelegatedTokenCredentials: + """Parse the delegated token response. + + :param response_data: JSON response data as string + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :return: DelegatedTokenCredentials object + :raises: ApiValueError if parsing fails + """ + try: + token_response = json.loads(response_data) + except json.JSONDecodeError as e: + raise ApiValueError(f"Failed to parse token response: {str(e)}") + + # Get attributes from the response + data_response = token_response.get("data") + if not data_response: + raise ApiValueError(f"Failed to get data from response: {token_response}") + + attributes = data_response.get("attributes") + if not attributes: + raise ApiValueError(f"Failed to get attributes from response: {token_response}") + + # Get the access token from the response + token = attributes.get("access_token") + if not token: + raise ApiValueError(f"Failed to get token from response: {token_response}") + + # get expiration time from the response, defualt to 15 min + expiration_time = datetime.now() + timedelta(minutes=15) + expires_str = attributes.get("expires") + if expires_str: + try: + expiration_int = int(expires_str) + expiration_time = datetime.fromtimestamp(expiration_int) + except (ValueError, TypeError): + # Use default expiration if parsing fails + pass + + return DelegatedTokenCredentials( + org_uuid=org_uuid, + delegated_token=token, + delegated_proof=delegated_auth_proof, + expiration=expiration_time + ) + + +def get_delegated_token_url(config: Configuration) -> str: + """Get the URL for the delegated token endpoint. + + :param config: Configuration object + :return: Full URL for the delegated token endpoint + """ + base_url = config.host + return urljoin(base_url, TOKEN_URL_ENDPOINT) diff --git a/.generator/src/generator/templates/example_aws.j2 b/.generator/src/generator/templates/example_aws.j2 new file mode 100644 index 0000000000..4e4dc7eb98 --- /dev/null +++ b/.generator/src/generator/templates/example_aws.j2 @@ -0,0 +1,66 @@ +""" +Example of using AWS authentication with the Datadog API client. + +This example shows how to configure the client to use AWS credentials +for authentication instead of API keys. +""" + +import os +from {{ package }} import ApiClient, Configuration +from {{ package }}.aws import AWSAuth +from {{ package }}.v2.api.teams_api import TeamsApi + + +def main(): + # Set up AWS credentials in environment variables + # These would typically be set by your AWS environment (EC2 instance role, ECS task role, etc.) + # or explicitly set in your environment: + # os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key-id" + # os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-access-key" + # os.environ["AWS_SESSION_TOKEN"] = "your-session-token" + + # Verify AWS credentials are available + if not all([ + os.getenv("AWS_ACCESS_KEY_ID"), + os.getenv("AWS_SECRET_ACCESS_KEY"), + os.getenv("AWS_SESSION_TOKEN") + ]): + print("Error: AWS credentials not found in environment variables.") + print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") + return + + + org_uuid = os.getenv("DD_TEST_ORG_UUID") + if not org_uuid: + print("Error: DD_TEST_ORG_UUID environment variable not set.") + print("Please set your Datadog organization UUID") + return + + # Create AWS authentication provider + # Optionally specify AWS region (defaults to us-east-1) + aws_region = os.getenv("AWS_REGION", "us-east-1") + aws_auth = AWSAuth(aws_region=aws_region) + + # Create configuration with AWS authentication + configuration = Configuration(host="https://dd.datad0g.com") + configuration.delegated_auth_provider = aws_auth + configuration.delegated_auth_org_uuid = org_uuid + + # Create API client with the configuration + with ApiClient(configuration) as api_client: + # Create API instance - using TeamsApi as an example + api_instance = TeamsApi(api_client) + + try: + # Test the authentication by listing teams + # The client will automatically use AWS authentication for this call + api_response = api_instance.list_teams() + print("Authentication successful!") + print(f"Found {len(api_response.data)} teams") + + except Exception as e: + print(f"Authentication failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/docs/datadog_api_client.rst b/docs/datadog_api_client.rst index 6d1e001200..c8c1a345bd 100644 --- a/docs/datadog_api_client.rst +++ b/docs/datadog_api_client.rst @@ -20,6 +20,13 @@ datadog\_api\_client.api\_client module :members: :show-inheritance: +datadog\_api\_client.aws module +------------------------------- + +.. automodule:: datadog_api_client.aws + :members: + :show-inheritance: + datadog\_api\_client.configuration module ----------------------------------------- @@ -27,6 +34,13 @@ datadog\_api\_client.configuration module :members: :show-inheritance: +datadog\_api\_client.delegated\_auth module +------------------------------------------- + +.. automodule:: datadog_api_client.delegated_auth + :members: + :show-inheritance: + datadog\_api\_client.exceptions module -------------------------------------- diff --git a/examples/datadog/aws.py b/examples/datadog/aws.py new file mode 100644 index 0000000000..20d967319a --- /dev/null +++ b/examples/datadog/aws.py @@ -0,0 +1,61 @@ +""" +Example of using AWS authentication with the Datadog API client. + +This example shows how to configure the client to use AWS credentials +for authentication instead of API keys. +""" + +import os +from datadog_api_client import ApiClient, Configuration +from datadog_api_client.aws import AWSAuth +from datadog_api_client.v2.api.teams_api import TeamsApi + + +def main(): + # Set up AWS credentials in environment variables + # These would typically be set by your AWS environment (EC2 instance role, ECS task role, etc.) + # or explicitly set in your environment: + # os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key-id" + # os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-access-key" + # os.environ["AWS_SESSION_TOKEN"] = "your-session-token" + + # Verify AWS credentials are available + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY"), os.getenv("AWS_SESSION_TOKEN")]): + print("Error: AWS credentials not found in environment variables.") + print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") + return + + org_uuid = os.getenv("DD_TEST_ORG_UUID") + if not org_uuid: + print("Error: DD_TEST_ORG_UUID environment variable not set.") + print("Please set your Datadog organization UUID") + return + + # Create AWS authentication provider + # Optionally specify AWS region (defaults to us-east-1) + aws_region = os.getenv("AWS_REGION", "us-east-1") + aws_auth = AWSAuth(aws_region=aws_region) + + # Create configuration with AWS authentication + configuration = Configuration(host="https://dd.datad0g.com") + configuration.delegated_auth_provider = aws_auth + configuration.delegated_auth_org_uuid = org_uuid + + # Create API client with the configuration + with ApiClient(configuration) as api_client: + # Create API instance - using TeamsApi as an example + api_instance = TeamsApi(api_client) + + try: + # Test the authentication by listing teams + # The client will automatically use AWS authentication for this call + api_response = api_instance.list_teams() + print("Authentication successful!") + print(f"Found {len(api_response.data)} teams") + + except Exception as e: + print(f"Authentication failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index d9794ede9c..76f1db40d3 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -454,6 +454,42 @@ def select_header_content_type(self, content_types: List[str]) -> str: return "application/json" return content_types[0] + def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: + """Use delegated token authentication if configured. + + :param headers: Header parameters dict to be updated. + :raises: ApiValueError if delegated token authentication fails + """ + from datadog_api_client.delegated_auth import DelegatedTokenConfig + + # Check if we have cached credentials + if not hasattr(self.configuration, "_delegated_token_credentials"): + self.configuration._delegated_token_credentials = None + + # Check if we need to get or refresh the token + if ( + self.configuration._delegated_token_credentials is None + or self.configuration._delegated_token_credentials.is_expired() + ): + # Create config for the provider + config = DelegatedTokenConfig( + org_uuid=self.configuration.delegated_auth_org_uuid, + provider="aws", # This could be made configurable + provider_auth=self.configuration.delegated_auth_provider, + ) + + # Get new token from provider, passing the API configuration + try: + self.configuration._delegated_token_credentials = ( + self.configuration.delegated_auth_provider.authenticate(config, self.configuration) + ) + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + # Set the Authorization header with the delegated token + token = self.configuration._delegated_token_credentials.delegated_token + headers["Authorization"] = f"Bearer {token}" + class ThreadedApiClient(ApiClient): _pool = None @@ -822,18 +858,34 @@ def update_params_for_auth(self, headers, queries) -> None: if not self.settings["auth"]: return - for auth in self.settings["auth"]: - auth_setting = self.api_client.configuration.auth_settings().get(auth) - if auth_setting: - if auth_setting["in"] == "header": - if auth_setting["type"] != "http-signature": - if auth_setting["value"] is None: - raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) - headers[auth_setting["key"]] = auth_setting["value"] - elif auth_setting["in"] == "query": - queries.append((auth_setting["key"], auth_setting["value"])) - else: - raise ApiValueError("Authentication token must be in `query` or `header`") + # check if endpoint uses appKeyAuth and if delegated token config is available + has_app_key_auth = "appKeyAuth" in self.settings["auth"] + + # Check if delegated auth is configured (using our actual attributes) + has_delegated_auth = ( + hasattr(self.api_client.configuration, "delegated_auth_provider") + and self.api_client.configuration.delegated_auth_provider is not None + and hasattr(self.api_client.configuration, "delegated_auth_org_uuid") + and self.api_client.configuration.delegated_auth_org_uuid is not None + ) + + if has_app_key_auth and has_delegated_auth: + # Use delegated token authentication + self.api_client.use_delegated_token_auth(headers) + else: + # Use regular authentication + for auth in self.settings["auth"]: + auth_setting = self.api_client.configuration.auth_settings().get(auth) + if auth_setting: + if auth_setting["in"] == "header": + if auth_setting["type"] != "http-signature": + if auth_setting["value"] is None: + raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) + headers[auth_setting["key"]] = auth_setting["value"] + elif auth_setting["in"] == "query": + queries.append((auth_setting["key"], auth_setting["value"])) + else: + raise ApiValueError("Authentication token must be in `query` or `header`") def user_agent() -> str: diff --git a/src/datadog_api_client/aws.py b/src/datadog_api_client/aws.py new file mode 100644 index 0000000000..48087e6c83 --- /dev/null +++ b/src/datadog_api_client/aws.py @@ -0,0 +1,280 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# This product includes software developed at Datadog (https://www.datadoghq.com/). +# Copyright 2019-Present Datadog, Inc. + +import base64 +import hashlib +import hmac +import json +import os +from datetime import datetime +from typing import Optional + +from datadog_api_client.configuration import Configuration +from datadog_api_client.delegated_auth import ( + DelegatedTokenProvider, + DelegatedTokenConfig, + DelegatedTokenCredentials, + get_delegated_token, +) +from datadog_api_client.exceptions import ApiValueError + + +# AWS specific constants +AWS_ACCESS_KEY_ID_NAME = "AWS_ACCESS_KEY_ID" +AWS_SECRET_ACCESS_KEY_NAME = "AWS_SECRET_ACCESS_KEY" +AWS_SESSION_TOKEN_NAME = "AWS_SESSION_TOKEN" + +AMZ_DATE_HEADER = "X-Amz-Date" +AMZ_TOKEN_HEADER = "X-Amz-Security-Token" +AMZ_DATE_FORMAT = "%Y%m%d" +AMZ_DATE_TIME_FORMAT = "%Y%m%dT%H%M%SZ" +DEFAULT_REGION = "us-east-1" +DEFAULT_STS_HOST = "sts.amazonaws.com" +REGIONAL_STS_HOST = "sts.{}.amazonaws.com" +SERVICE = "sts" +ALGORITHM = "AWS4-HMAC-SHA256" +AWS4_REQUEST = "aws4_request" +GET_CALLER_IDENTITY_BODY = "Action=GetCallerIdentity&Version=2011-06-15" + +# Common Headers +ORG_ID_HEADER = "x-ddog-org-id" +HOST_HEADER = "host" +APPLICATION_FORM = "application/x-www-form-urlencoded; charset=utf-8" + +PROVIDER_AWS = "aws" + + +class AWSCredentials: + """AWS credentials for authentication.""" + + def __init__(self, access_key_id: str, secret_access_key: str, session_token: str): + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.session_token = session_token + + +class SigningData: + """Data structure for AWS signing information.""" + + def __init__(self, headers_encoded: str, body_encoded: str, url_encoded: str, method: str): + self.headers_encoded = headers_encoded + self.body_encoded = body_encoded + self.url_encoded = url_encoded + self.method = method + + +class AWSAuth(DelegatedTokenProvider): + """AWS authentication provider for delegated tokens.""" + + def __init__(self, aws_region: Optional[str] = None): + self.aws_region = aws_region + + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: + """Authenticate using AWS credentials and return delegated token credentials. + + :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings + :return: DelegatedTokenCredentials object + :raises: ApiValueError if authentication fails + """ + # Check org UUID first + if not config or not config.org_uuid: + raise ApiValueError("Missing org UUID in config") + + # Get local AWS Credentials + creds = self.get_credentials() + + # Use the credentials to generate the signing data + data = self.generate_aws_auth_data(config.org_uuid, creds) + + # Generate the auth string passed to the token endpoint + auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}" + + # Pass the api_config to get_delegated_token to use the correct host + auth_response = get_delegated_token(config.org_uuid, auth_string, api_config) + return auth_response + + def get_credentials(self) -> AWSCredentials: + """Get AWS credentials from environment variables. + + :return: AWSCredentials object + :raises: ApiValueError if credentials are missing + """ + access_key = os.getenv(AWS_ACCESS_KEY_ID_NAME) + secret_key = os.getenv(AWS_SECRET_ACCESS_KEY_NAME) + session_token = os.getenv(AWS_SESSION_TOKEN_NAME) + + if not access_key or not secret_key or not session_token: + raise ApiValueError( + "Missing AWS credentials. Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables." + ) + + return AWSCredentials(access_key_id=access_key, secret_access_key=secret_key, session_token=session_token) + + def _get_connection_parameters(self) -> tuple[str, str, str]: + """Get connection parameters for AWS STS. + + :return: Tuple of (sts_full_url, region, host) + """ + region = self.aws_region or DEFAULT_REGION + + if self.aws_region: + host = REGIONAL_STS_HOST.format(region) + else: + host = DEFAULT_STS_HOST + + sts_full_url = f"https://{host}" + return sts_full_url, region, host + + def generate_aws_auth_data(self, org_uuid: str, creds: AWSCredentials) -> SigningData: + """Generate AWS authentication data for signing. + + :param org_uuid: Organization UUID + :param creds: AWS credentials + :return: SigningData object + :raises: ApiValueError if generation fails + """ + if not org_uuid: + raise ApiValueError("Missing org UUID") + + if not creds or not creds.access_key_id or not creds.secret_access_key or not creds.session_token: + raise ApiValueError("Missing AWS credentials") + + sts_full_url, region, host = self._get_connection_parameters() + + now = datetime.utcnow() + + request_body = GET_CALLER_IDENTITY_BODY + payload_hash = hashlib.sha256(request_body.encode("utf-8")).hexdigest() + + # Create the headers that factor into the signing algorithm + header_map = { + "Content-Length": [str(len(request_body))], + "Content-Type": [APPLICATION_FORM], + AMZ_DATE_HEADER: [now.strftime(AMZ_DATE_TIME_FORMAT)], + ORG_ID_HEADER: [org_uuid], + AMZ_TOKEN_HEADER: [creds.session_token], + HOST_HEADER: [host], + } + + # Create canonical headers + header_arr = [] + signed_headers_arr = [] + + for k, v in header_map.items(): + lowered_header_name = k.lower() + header_arr.append(f"{lowered_header_name}:{','.join(v)}") + signed_headers_arr.append(lowered_header_name) + + header_arr.sort() + signed_headers_arr.sort() + signed_headers = ";".join(signed_headers_arr) + + canonical_request = "\n".join( + [ + "POST", + "/", + "", # No query string + "\n".join(header_arr) + "\n", + signed_headers, + payload_hash, + ] + ) + + # Create the string to sign + hash_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest() + credential_scope = "/".join( + [ + now.strftime(AMZ_DATE_FORMAT), + region, + SERVICE, + AWS4_REQUEST, + ] + ) + + string_to_sign = self._make_signature( + now, + credential_scope, + hash_canonical_request, + region, + SERVICE, + creds.secret_access_key, + ALGORITHM, + ) + + # Create the authorization header + credential = f"{creds.access_key_id}/{credential_scope}" + auth_header = f"{ALGORITHM} Credential={credential}, SignedHeaders={signed_headers}, Signature={string_to_sign}" + + header_map["Authorization"] = [auth_header] + header_map["User-Agent"] = [self._get_user_agent()] + + headers_json = json.dumps(header_map, separators=(",", ":")) + + return SigningData( + headers_encoded=base64.b64encode(headers_json.encode("utf-8")).decode("utf-8"), + body_encoded=base64.b64encode(request_body.encode("utf-8")).decode("utf-8"), + method="POST", + url_encoded=base64.b64encode(sts_full_url.encode("utf-8")).decode("utf-8"), + ) + + def _make_signature( + self, + t: datetime, + credential_scope: str, + payload_hash: str, + region: str, + service: str, + secret_access_key: str, + algorithm: str, + ) -> str: + """Create AWS signature. + + :param t: Current datetime + :param credential_scope: Credential scope string + :param payload_hash: Hash of the canonical request + :param region: AWS region + :param service: AWS service name + :param secret_access_key: AWS secret access key + :param algorithm: Signing algorithm + :return: Signature string + """ + # Create the string to sign + string_to_sign = "\n".join( + [ + algorithm, + t.strftime(AMZ_DATE_TIME_FORMAT), + credential_scope, + payload_hash, + ] + ) + + # Create the signing key + k_date = self._hmac256(t.strftime(AMZ_DATE_FORMAT), f"AWS4{secret_access_key}".encode("utf-8")) + k_region = self._hmac256(region, k_date) + k_service = self._hmac256(service, k_region) + k_signing = self._hmac256(AWS4_REQUEST, k_service) + + # Sign the string + signature = self._hmac256(string_to_sign, k_signing) + return signature.hex() + + def _hmac256(self, data: str, key: bytes) -> bytes: + """Create HMAC-SHA256 hash. + + :param data: Data to hash + :param key: Key for HMAC + :return: HMAC hash bytes + """ + return hmac.new(key, data.encode("utf-8"), hashlib.sha256).digest() + + def _get_user_agent(self) -> str: + """Get user agent string. + + :return: User agent string + """ + import platform + from datadog_api_client.version import __version__ + + return f"datadog-api-client-python/{__version__} (python {platform.python_version()}; os {platform.system()}; arch {platform.machine()})" diff --git a/src/datadog_api_client/configuration.py b/src/datadog_api_client/configuration.py index d03a705882..f46296add0 100644 --- a/src/datadog_api_client/configuration.py +++ b/src/datadog_api_client/configuration.py @@ -344,6 +344,9 @@ def __init__( } ) + # Delegated token configuration + self.delegated_token_config = None + # Load default values from environment if "DD_SITE" in os.environ: self.server_variables["site"] = os.environ["DD_SITE"] diff --git a/src/datadog_api_client/delegated_auth.py b/src/datadog_api_client/delegated_auth.py new file mode 100644 index 0000000000..3ba4652b62 --- /dev/null +++ b/src/datadog_api_client/delegated_auth.py @@ -0,0 +1,142 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# This product includes software developed at Datadog (https://www.datadoghq.com/). +# Copyright 2019-Present Datadog, Inc. + +import json +from datetime import datetime, timedelta +from urllib.parse import urljoin + +from datadog_api_client import rest +from datadog_api_client.configuration import Configuration +from datadog_api_client.exceptions import ApiValueError + + +TOKEN_URL_ENDPOINT = "/api/v2/delegated-token" +AUTHORIZATION_TYPE = "Delegated" +APPLICATION_JSON = "application/json" + + +class DelegatedTokenCredentials: + """Credentials for delegated token authentication.""" + + def __init__(self, org_uuid: str, delegated_token: str, delegated_proof: str, expiration: datetime): + self.org_uuid = org_uuid + self.delegated_token = delegated_token + self.delegated_proof = delegated_proof + self.expiration = expiration + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return datetime.now() >= self.expiration + + +class DelegatedTokenConfig: + """Configuration for delegated token authentication.""" + + def __init__(self, org_uuid: str, provider: str, provider_auth: "DelegatedTokenProvider"): + self.org_uuid = org_uuid + self.provider = provider + self.provider_auth = provider_auth + + +class DelegatedTokenProvider: + """Abstract base class for delegated token providers.""" + + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: + """Authenticate and return delegated token credentials. + + :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings + :return: DelegatedTokenCredentials object + """ + raise NotImplementedError("Subclasses must implement authenticate method") + + +def get_delegated_token(org_uuid: str, delegated_auth_proof: str, config: Configuration) -> DelegatedTokenCredentials: + """Get a delegated token from the Datadog API. + + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :param config: Configuration object with host and other settings + :return: DelegatedTokenCredentials object + :raises: ApiValueError if the request fails + """ + url = get_delegated_token_url(config) + + # Create REST client + rest_client = rest.RESTClientObject(config) + + headers = { + "Content-Type": APPLICATION_JSON, + "Authorization": f"{AUTHORIZATION_TYPE} {delegated_auth_proof}", + "Content-Length": "0", + } + + try: + response = rest_client.request(method="POST", url=url, headers=headers, body="", preload_content=True) + + if response.status != 200: + raise ApiValueError(f"Failed to get token: {response.status}") + + response_data = response.data.decode("utf-8") + creds = parse_delegated_token_response(response_data, org_uuid, delegated_auth_proof) + return creds + + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + +def parse_delegated_token_response( + response_data: str, org_uuid: str, delegated_auth_proof: str +) -> DelegatedTokenCredentials: + """Parse the delegated token response. + + :param response_data: JSON response data as string + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :return: DelegatedTokenCredentials object + :raises: ApiValueError if parsing fails + """ + try: + token_response = json.loads(response_data) + except json.JSONDecodeError as e: + raise ApiValueError(f"Failed to parse token response: {str(e)}") + + # Get attributes from the response + data_response = token_response.get("data") + if not data_response: + raise ApiValueError(f"Failed to get data from response: {token_response}") + + attributes = data_response.get("attributes") + if not attributes: + raise ApiValueError(f"Failed to get attributes from response: {token_response}") + + # Get the access token from the response + token = attributes.get("access_token") + if not token: + raise ApiValueError(f"Failed to get token from response: {token_response}") + + # get expiration time from the response, defualt to 15 min + expiration_time = datetime.now() + timedelta(minutes=15) + expires_str = attributes.get("expires") + if expires_str: + try: + expiration_int = int(expires_str) + expiration_time = datetime.fromtimestamp(expiration_int) + except (ValueError, TypeError): + # Use default expiration if parsing fails + pass + + return DelegatedTokenCredentials( + org_uuid=org_uuid, delegated_token=token, delegated_proof=delegated_auth_proof, expiration=expiration_time + ) + + +def get_delegated_token_url(config: Configuration) -> str: + """Get the URL for the delegated token endpoint. + + :param config: Configuration object + :return: Full URL for the delegated token endpoint + """ + base_url = config.host + return urljoin(base_url, TOKEN_URL_ENDPOINT) diff --git a/tests/aws_test.py b/tests/aws_test.py new file mode 100644 index 0000000000..62d3bf0626 --- /dev/null +++ b/tests/aws_test.py @@ -0,0 +1,210 @@ +"""Tests for AWS authentication functionality.""" + +import os +import pytest +from unittest.mock import Mock, patch + +from datadog_api_client.aws import ( + AWSAuth, + AWSCredentials, + SigningData, + PROVIDER_AWS, +) +from datadog_api_client.delegated_auth import DelegatedTokenConfig +from datadog_api_client.exceptions import ApiValueError + + +class TestAWSCredentials: + """Test AWSCredentials class.""" + + def test_init(self): + """Test initialization of AWSCredentials.""" + access_key = "test-access-key" + secret_key = "test-secret-key" + session_token = "test-session-token" + + creds = AWSCredentials(access_key, secret_key, session_token) + + assert creds.access_key_id == access_key + assert creds.secret_access_key == secret_key + assert creds.session_token == session_token + + +class TestSigningData: + """Test SigningData class.""" + + def test_init(self): + """Test initialization of SigningData.""" + headers_encoded = "encoded-headers" + body_encoded = "encoded-body" + url_encoded = "encoded-url" + method = "POST" + + data = SigningData(headers_encoded, body_encoded, url_encoded, method) + + assert data.headers_encoded == headers_encoded + assert data.body_encoded == body_encoded + assert data.url_encoded == url_encoded + assert data.method == method + + +class TestAWSAuth: + """Test AWSAuth class.""" + + def test_init_default_region(self): + """Test initialization with default region.""" + auth = AWSAuth() + assert auth.aws_region is None + + def test_init_custom_region(self): + """Test initialization with custom region.""" + region = "us-west-2" + auth = AWSAuth(aws_region=region) + assert auth.aws_region == region + + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test-access-key", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + "AWS_SESSION_TOKEN": "test-session-token", + }, + ) + def test_get_credentials_success(self): + """Test successful credential retrieval from environment.""" + auth = AWSAuth() + creds = auth.get_credentials() + + assert creds.access_key_id == "test-access-key" + assert creds.secret_access_key == "test-secret-key" + assert creds.session_token == "test-session-token" + + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_missing_access_key(self): + """Test credential retrieval with missing access key.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.get_credentials() + + @patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "test-key"}, clear=True) + def test_get_credentials_missing_secret_key(self): + """Test credential retrieval with missing secret key.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.get_credentials() + + @patch.dict( + os.environ, {"AWS_ACCESS_KEY_ID": "test-access-key", "AWS_SECRET_ACCESS_KEY": "test-secret-key"}, clear=True + ) + def test_get_credentials_missing_session_token(self): + """Test credential retrieval with missing session token.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.get_credentials() + + def test_get_connection_parameters_default_region(self): + """Test connection parameters with default region.""" + auth = AWSAuth() + url, region, host = auth._get_connection_parameters() + + assert url == "https://sts.amazonaws.com" + assert region == "us-east-1" + assert host == "sts.amazonaws.com" + + def test_get_connection_parameters_custom_region(self): + """Test connection parameters with custom region.""" + auth = AWSAuth(aws_region="eu-west-1") + url, region, host = auth._get_connection_parameters() + + assert url == "https://sts.eu-west-1.amazonaws.com" + assert region == "eu-west-1" + assert host == "sts.eu-west-1.amazonaws.com" + + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "AWS_SESSION_TOKEN": "test-session-token", + }, + ) + def test_generate_aws_auth_data(self): + """Test AWS authentication data generation.""" + auth = AWSAuth() + org_uuid = "test-org-uuid" + creds = auth.get_credentials() + + data = auth.generate_aws_auth_data(org_uuid, creds) + + assert isinstance(data, SigningData) + assert data.method == "POST" + assert data.headers_encoded # Should be base64 encoded + assert data.body_encoded # Should be base64 encoded + assert data.url_encoded # Should be base64 encoded + + def test_generate_aws_auth_data_missing_org_uuid(self): + """Test auth data generation with missing org UUID.""" + auth = AWSAuth() + creds = AWSCredentials("key", "secret", "token") + + with pytest.raises(ApiValueError, match="Missing org UUID"): + auth.generate_aws_auth_data("", creds) + + def test_generate_aws_auth_data_missing_credentials(self): + """Test auth data generation with missing credentials.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.generate_aws_auth_data("org-uuid", None) + + def test_hmac256(self): + """Test HMAC-SHA256 generation.""" + auth = AWSAuth() + data = "test-data" + key = b"test-key" + + result = auth._hmac256(data, key) + + assert isinstance(result, bytes) + assert len(result) == 32 # SHA256 produces 32 bytes + + @patch("datadog_api_client.aws.get_delegated_token") + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test-access-key", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + "AWS_SESSION_TOKEN": "test-session-token", + }, + ) + def test_authenticate_success(self, mock_get_delegated_token): + """Test successful authentication.""" + # Mock the delegated token response + mock_credentials = Mock() + mock_get_delegated_token.return_value = mock_credentials + + auth = AWSAuth() + config = DelegatedTokenConfig("test-org-uuid", PROVIDER_AWS, auth) + + result = auth.authenticate(config) + + assert result == mock_credentials + mock_get_delegated_token.assert_called_once() + + def test_authenticate_missing_org_uuid(self): + """Test authentication with missing org UUID.""" + auth = AWSAuth() + config = DelegatedTokenConfig("", PROVIDER_AWS, auth) + + with pytest.raises(ApiValueError, match="Missing org UUID in config"): + auth.authenticate(config) + + def test_authenticate_missing_config(self): + """Test authentication with missing config.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing org UUID in config"): + auth.authenticate(None) diff --git a/tests/client_test.py b/tests/client_test.py new file mode 100644 index 0000000000..f56e232cc2 --- /dev/null +++ b/tests/client_test.py @@ -0,0 +1,355 @@ +"""Client integration tests for delegated authentication functionality. + +This test module provides comprehensive integration tests for the ApiClient +with delegated token authentication, similar to the Go client tests. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from datadog_api_client.api_client import ApiClient +from datadog_api_client.configuration import Configuration +from datadog_api_client.delegated_auth import ( + DelegatedTokenCredentials, + DelegatedTokenConfig, + DelegatedTokenProvider, +) +from datadog_api_client.aws import AWSAuth, PROVIDER_AWS +from datadog_api_client.exceptions import ApiValueError + + +FAKE_TOKEN = "fake-token" +FAKE_ORG_UUID = "1234" +FAKE_PROOF = "proof" + + +class MockDelegatedTokenProvider(DelegatedTokenProvider): + """Mock delegated token provider for testing.""" + + def __init__(self, token=FAKE_TOKEN, org_uuid=FAKE_ORG_UUID, proof=FAKE_PROOF, expiration_minutes=10): + self.token = token + self.org_uuid = org_uuid + self.proof = proof + self.expiration_minutes = expiration_minutes + self.authenticate_calls = [] + + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + """Mock authenticate method.""" + self.authenticate_calls.append(config) + expiration = datetime.now() + timedelta(minutes=self.expiration_minutes) + return DelegatedTokenCredentials( + org_uuid=self.org_uuid, delegated_token=self.token, delegated_proof=self.proof, expiration=expiration + ) + + +class ExpiredMockProvider(MockDelegatedTokenProvider): + """Mock provider that returns expired tokens.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.expiration_minutes = -16 # Expired 16 minutes ago + + +class TestClientDelegatedAuthentication: + """Test client integration with delegated authentication.""" + + def test_delegated_pre_authenticate(self): + """Test delegated token pre-authentication flow.""" + # Create mock provider + mock_provider = MockDelegatedTokenProvider() + + # Create configuration with delegated token config + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=mock_provider + ) + + # Create API client + api_client = ApiClient(config) + + # Test _get_delegated_token method + token = api_client._get_delegated_token() + assert token is not None + assert token.delegated_token == FAKE_TOKEN + assert token.org_uuid == FAKE_ORG_UUID + assert token.delegated_proof == FAKE_PROOF + assert not token.is_expired() + + # Verify provider was called + assert len(mock_provider.authenticate_calls) == 1 + assert mock_provider.authenticate_calls[0].org_uuid == FAKE_ORG_UUID + + def test_delegated_no_pre_authenticate(self): + """Test delegated token authentication without pre-authentication.""" + # Create mock provider + mock_provider = MockDelegatedTokenProvider() + + # Create configuration with delegated token config + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=mock_provider + ) + + # Create API client + api_client = ApiClient(config) + + # Test use_delegated_token_auth (simulates API call header generation) + headers = {} + api_client.use_delegated_token_auth(headers) + + # Verify headers were set correctly + assert "Authorization" in headers + assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" + + # Verify provider was called once for token generation + assert len(mock_provider.authenticate_calls) == 1 + + def test_delegated_re_authenticate(self): + """Test delegated token re-authentication when token expires.""" + # Create two mock providers - first with expired token, second with valid token + expired_provider = ExpiredMockProvider() + valid_provider = MockDelegatedTokenProvider() + + # Create configuration + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=expired_provider + ) + + # Create API client and get initial (expired) token + api_client = ApiClient(config) + + # First call gets expired token + headers = {} + api_client.use_delegated_token_auth(headers) + assert "Authorization" in headers + assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" + + # Verify expired token was created + assert hasattr(api_client, "_delegated_token_credentials") + assert api_client._delegated_token_credentials.is_expired() + + # Switch to valid provider for re-authentication + config.delegated_token_config.provider_auth = valid_provider + + # Second call should re-authenticate due to expired token + headers2 = {} + api_client.use_delegated_token_auth(headers2) + assert "Authorization" in headers2 + assert headers2["Authorization"] == f"Bearer {FAKE_TOKEN}" + + # Verify new token is not expired + assert not api_client._delegated_token_credentials.is_expired() + + # Verify both providers were called + assert len(expired_provider.authenticate_calls) == 1 + assert len(valid_provider.authenticate_calls) == 1 + + def test_delegated_token_caching(self): + """Test that delegated tokens are cached and reused when not expired.""" + # Create mock provider + mock_provider = MockDelegatedTokenProvider() + + # Create configuration + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=mock_provider + ) + + # Create API client + api_client = ApiClient(config) + + # Multiple calls to use_delegated_token_auth + headers1 = {} + api_client.use_delegated_token_auth(headers1) + + headers2 = {} + api_client.use_delegated_token_auth(headers2) + + headers3 = {} + api_client.use_delegated_token_auth(headers3) + + # All should have same token + assert headers1["Authorization"] == headers2["Authorization"] == headers3["Authorization"] + + # Provider should only be called once (token cached) + assert len(mock_provider.authenticate_calls) == 1 + + @patch.dict( + "os.environ", + { + "AWS_ACCESS_KEY_ID": "fake-access-key-id", + "AWS_SECRET_ACCESS_KEY": "fake-secret-access-key", + "AWS_SESSION_TOKEN": "fake-session-token", + }, + ) + def test_delegated_with_aws_provider(self): + """Test delegated authentication with real AWS provider (mocked token endpoint).""" + # Create AWS auth provider + aws_auth = AWSAuth() + + # Create configuration + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=aws_auth + ) + + # Mock the get_delegated_token function called by AWS provider + mock_creds = DelegatedTokenCredentials( + org_uuid=FAKE_ORG_UUID, + delegated_token=FAKE_TOKEN, + delegated_proof=FAKE_PROOF, + expiration=datetime.now() + timedelta(minutes=15), + ) + + with patch("datadog_api_client.aws.get_delegated_token", return_value=mock_creds): + # Create API client + api_client = ApiClient(config) + + # Test token retrieval + token = api_client._get_delegated_token() + assert token.delegated_token == FAKE_TOKEN + assert token.org_uuid == FAKE_ORG_UUID + + # Test header generation + headers = {} + api_client.use_delegated_token_auth(headers) + assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" + + def test_api_key_authentication_comparison(self): + """Test traditional API key authentication for comparison with delegated auth.""" + # Create configuration with API keys + config = Configuration() + config.api_key = {"apiKeyAuth": "test-api-key", "appKeyAuth": "test-app-key"} + + # Create API client + api_client = ApiClient(config) + + # Test that no delegated auth is used when not configured + headers = {} + api_client.use_delegated_token_auth(headers) + + # Headers should be empty (no delegated token) + assert "Authorization" not in headers + + # This simulates how API keys would be added in actual API calls + # (API key authentication is handled separately in the client) + + def test_delegated_auth_error_handling(self): + """Test error handling in delegated authentication.""" + # Test with no configuration + config = Configuration() + api_client = ApiClient(config) + + with pytest.raises(ApiValueError, match="Delegated token configuration is not set"): + api_client._get_delegated_token() + + # Test with provider that raises an exception + class FailingProvider(DelegatedTokenProvider): + def authenticate(self, config): + raise Exception("Authentication failed") + + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=FailingProvider() + ) + + api_client = ApiClient(config) + with pytest.raises(ApiValueError, match="Failed to get delegated token"): + api_client._get_delegated_token() + + def test_multiple_api_clients_isolation(self): + """Test that multiple API clients with different configs work independently.""" + # Create two different mock providers + provider1 = MockDelegatedTokenProvider(token="token1", org_uuid="org1") + provider2 = MockDelegatedTokenProvider(token="token2", org_uuid="org2") + + # Create two configurations + config1 = Configuration() + config1.delegated_token_config = DelegatedTokenConfig( + org_uuid="org1", provider=PROVIDER_AWS, provider_auth=provider1 + ) + + config2 = Configuration() + config2.delegated_token_config = DelegatedTokenConfig( + org_uuid="org2", provider=PROVIDER_AWS, provider_auth=provider2 + ) + + # Create two API clients + client1 = ApiClient(config1) + client2 = ApiClient(config2) + + # Test that they get different tokens + token1 = client1._get_delegated_token() + token2 = client2._get_delegated_token() + + assert token1.delegated_token == "token1" + assert token1.org_uuid == "org1" + assert token2.delegated_token == "token2" + assert token2.org_uuid == "org2" + + # Test header generation for both + headers1 = {} + client1.use_delegated_token_auth(headers1) + + headers2 = {} + client2.use_delegated_token_auth(headers2) + + assert headers1["Authorization"] == "Bearer token1" + assert headers2["Authorization"] == "Bearer token2" + + +class TestClientDelegatedAuthenticationWithRealMocks: + """Test client integration with more realistic mocking scenarios.""" + + def test_token_refresh_sequence(self): + """Test the complete token refresh sequence over time.""" + # Create a provider that tracks call count + call_count = 0 + + class CountingProvider(DelegatedTokenProvider): + def authenticate(self, config): + nonlocal call_count + call_count += 1 + + # First call returns short-lived token, second call returns long-lived + if call_count == 1: + expiration = datetime.now() + timedelta(seconds=1) # Very short lived + else: + expiration = datetime.now() + timedelta(minutes=30) # Long lived + + return DelegatedTokenCredentials( + org_uuid=config.org_uuid, + delegated_token=f"token-{call_count}", + delegated_proof="proof", + expiration=expiration, + ) + + # Setup + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=CountingProvider() + ) + + api_client = ApiClient(config) + + # First call + headers1 = {} + api_client.use_delegated_token_auth(headers1) + assert headers1["Authorization"] == "Bearer token-1" + assert call_count == 1 + + # Immediate second call should reuse token + headers2 = {} + api_client.use_delegated_token_auth(headers2) + assert headers2["Authorization"] == "Bearer token-1" + assert call_count == 1 # No new call + + # Force expiration by manipulating the cached token + api_client._delegated_token_credentials.expiration = datetime.now() - timedelta(minutes=1) + + # Third call should refresh due to expiration + headers3 = {} + api_client.use_delegated_token_auth(headers3) + assert headers3["Authorization"] == "Bearer token-2" + assert call_count == 2 # New call made diff --git a/tests/delegated_auth_test.py b/tests/delegated_auth_test.py new file mode 100644 index 0000000000..a6c4451ab4 --- /dev/null +++ b/tests/delegated_auth_test.py @@ -0,0 +1,140 @@ +"""Tests for delegated authentication functionality.""" + +import json +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch + +from datadog_api_client.delegated_auth import ( + DelegatedTokenCredentials, + DelegatedTokenConfig, + DelegatedTokenProvider, + get_delegated_token, + parse_delegated_token_response, + get_delegated_token_url, +) +from datadog_api_client.configuration import Configuration +from datadog_api_client.exceptions import ApiValueError + + +class TestDelegatedTokenCredentials: + """Test DelegatedTokenCredentials class.""" + + def test_init(self): + """Test initialization of DelegatedTokenCredentials.""" + org_uuid = "test-org-uuid" + token = "test-token" + proof = "test-proof" + expiration = datetime.now() + timedelta(minutes=15) + + creds = DelegatedTokenCredentials(org_uuid, token, proof, expiration) + + assert creds.org_uuid == org_uuid + assert creds.delegated_token == token + assert creds.delegated_proof == proof + assert creds.expiration == expiration + + def test_is_expired_false(self): + """Test is_expired returns False for non-expired token.""" + expiration = datetime.now() + timedelta(minutes=15) + creds = DelegatedTokenCredentials("org", "token", "proof", expiration) + + assert not creds.is_expired() + + def test_is_expired_true(self): + """Test is_expired returns True for expired token.""" + expiration = datetime.now() - timedelta(minutes=1) + creds = DelegatedTokenCredentials("org", "token", "proof", expiration) + + assert creds.is_expired() + + +class TestDelegatedTokenConfig: + """Test DelegatedTokenConfig class.""" + + def test_init(self): + """Test initialization of DelegatedTokenConfig.""" + org_uuid = "test-org-uuid" + provider = "aws" + provider_auth = Mock(spec=DelegatedTokenProvider) + + config = DelegatedTokenConfig(org_uuid, provider, provider_auth) + + assert config.org_uuid == org_uuid + assert config.provider == provider + assert config.provider_auth == provider_auth + + +class TestGetDelegatedTokenUrl: + """Test get_delegated_token_url function.""" + + def test_get_delegated_token_url(self): + """Test URL construction.""" + config = Configuration() + config.host = "https://api.datadoghq.com" + + url = get_delegated_token_url(config) + + assert url == "https://api.datadoghq.com/api/v2/delegated-token" + + +class TestParseDelegatedTokenResponse: + """Test parse_delegated_token_response function.""" + + def test_parse_valid_response(self): + """Test parsing valid token response.""" + org_uuid = "test-org-uuid" + proof = "test-proof" + token = "test-access-token" + expires = str(int((datetime.now() + timedelta(minutes=15)).timestamp())) + + response_data = json.dumps({"data": {"attributes": {"access_token": token, "expires": expires}}}) + + creds = parse_delegated_token_response(response_data, org_uuid, proof) + + assert creds.org_uuid == org_uuid + assert creds.delegated_token == token + assert creds.delegated_proof == proof + assert isinstance(creds.expiration, datetime) + + def test_parse_invalid_json(self): + """Test parsing invalid JSON response.""" + with pytest.raises(ApiValueError, match="Failed to parse token response"): + parse_delegated_token_response("invalid json", "org", "proof") + + +class TestGetDelegatedToken: + """Test get_delegated_token function.""" + + @patch("datadog_api_client.delegated_auth.rest.RESTClientObject") + def test_get_delegated_token_success(self, mock_rest_client_class): + """Test successful token retrieval.""" + mock_rest_client = Mock() + mock_rest_client_class.return_value = mock_rest_client + + mock_response = Mock() + mock_response.status = 200 + mock_response.data = json.dumps({"data": {"attributes": {"access_token": "test-token"}}}).encode("utf-8") + mock_rest_client.request.return_value = mock_response + + org_uuid = "test-org-uuid" + proof = "test-proof" + + creds = get_delegated_token(org_uuid, proof) + + assert creds.org_uuid == org_uuid + assert creds.delegated_token == "test-token" + assert creds.delegated_proof == proof + + @patch("datadog_api_client.delegated_auth.rest.RESTClientObject") + def test_get_delegated_token_http_error(self, mock_rest_client_class): + """Test token retrieval with HTTP error.""" + mock_rest_client = Mock() + mock_rest_client_class.return_value = mock_rest_client + + mock_response = Mock() + mock_response.status = 401 + mock_rest_client.request.return_value = mock_response + + with pytest.raises(ApiValueError, match="Failed to get token: 401"): + get_delegated_token("org", "proof")