From f249e4106c88f57cf1af418f7c66f5fb15859f6c Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Thu, 25 Sep 2025 11:05:12 -0400 Subject: [PATCH 01/10] changed template files + generate --- .generator/src/generator/cli.py | 2 + .../src/generator/templates/api_client.j2 | 70 +++- .generator/src/generator/templates/aws.j2 | 258 ++++++++++++ .../src/generator/templates/configuration.j2 | 3 + .../src/generator/templates/delegated_auth.j2 | 145 +++++++ .../src/generator/templates/example_aws.j2 | 73 ++++ examples/datadog/aws.py | 78 ++++ src/datadog_api_client/api_client.py | 75 +++- src/datadog_api_client/aws.py | 277 +++++++++++++ src/datadog_api_client/configuration.py | 3 + src/datadog_api_client/delegated_auth.py | 138 +++++++ tests/client_test.py | 376 ++++++++++++++++++ tests/test_aws.py | 210 ++++++++++ tests/test_delegated_auth.py | 140 +++++++ 14 files changed, 1824 insertions(+), 24 deletions(-) create mode 100644 .generator/src/generator/templates/aws.j2 create mode 100644 .generator/src/generator/templates/delegated_auth.j2 create mode 100644 .generator/src/generator/templates/example_aws.j2 create mode 100644 examples/datadog/aws.py create mode 100644 src/datadog_api_client/aws.py create mode 100644 src/datadog_api_client/delegated_auth.py create mode 100644 tests/client_test.py create mode 100644 tests/test_aws.py create mode 100644 tests/test_delegated_auth.py diff --git a/.generator/src/generator/cli.py b/.generator/src/generator/cli.py index 215d0cc156..982f6a9eac 100644 --- a/.generator/src/generator/cli.py +++ b/.generator/src/generator/cli.py @@ -72,6 +72,8 @@ def cli(specs, output): "exceptions.py": env.get_template("exceptions.j2"), "model_utils.py": env.get_template("model_utils.j2"), "rest.py": env.get_template("rest.j2"), + "delegated_auth.py": env.get_template("delegated_auth.j2"), + "aws.py": env.get_template("aws.j2"), } top_package = output / PACKAGE_NAME diff --git a/.generator/src/generator/templates/api_client.j2 b/.generator/src/generator/templates/api_client.j2 index 77b7b93904..68b5c4203b 100644 --- a/.generator/src/generator/templates/api_client.j2 +++ b/.generator/src/generator/templates/api_client.j2 @@ -454,6 +454,44 @@ class ApiClient: return "application/json" return content_types[0] + def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: + """Use delegated token authentication if configured. + + :param headers: Header parameters dict to be updated. + :raises: ApiValueError if delegated token authentication fails + """ + if not self.configuration.delegated_token_config: + return + + from {{ package }}.delegated_auth import DelegatedTokenCredentials + from datetime import datetime + + # Get or create delegated token credentials + if not hasattr(self, '_delegated_token_credentials') or self._delegated_token_credentials is None: + self._delegated_token_credentials = self._get_delegated_token() + elif self._delegated_token_credentials.is_expired(): + # Token is expired, get a new one + self._delegated_token_credentials = self._get_delegated_token() + + # Set the Authorization header with the delegated token + headers["Authorization"] = f"Bearer {self._delegated_token_credentials.delegated_token}" + + def _get_delegated_token(self) -> 'DelegatedTokenCredentials': + """Get a new delegated token using the configured provider. + + :return: DelegatedTokenCredentials object + :raises: ApiValueError if token retrieval fails + """ + if not self.configuration.delegated_token_config: + raise ApiValueError("Delegated token configuration is not set") + + try: + return self.configuration.delegated_token_config.provider_auth.authenticate( + self.configuration.delegated_token_config + ) + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + class ThreadedApiClient(ApiClient): @@ -824,18 +862,26 @@ class Endpoint: if not self.settings["auth"]: return - for auth in self.settings["auth"]: - auth_setting = self.api_client.configuration.auth_settings().get(auth) - if auth_setting: - if auth_setting["in"] == "header": - if auth_setting["type"] != "http-signature": - if auth_setting["value"] is None: - raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) - headers[auth_setting["key"]] = auth_setting["value"] - elif auth_setting["in"] == "query": - queries.append((auth_setting["key"], auth_setting["value"])) - else: - raise ApiValueError("Authentication token must be in `query` or `header`") + # Check if this endpoint uses appKeyAuth and if delegated token config is available + has_app_key_auth = "appKeyAuth" in self.settings["auth"] + + if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: + # Use delegated token authentication + self.api_client.use_delegated_token_auth(headers) + else: + # Use regular authentication + for auth in self.settings["auth"]: + auth_setting = self.api_client.configuration.auth_settings().get(auth) + if auth_setting: + if auth_setting["in"] == "header": + if auth_setting["type"] != "http-signature": + if auth_setting["value"] is None: + raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) + headers[auth_setting["key"]] = auth_setting["value"] + elif auth_setting["in"] == "query": + queries.append((auth_setting["key"], auth_setting["value"])) + else: + raise ApiValueError("Authentication token must be in `query` or `header`") def user_agent() -> str: diff --git a/.generator/src/generator/templates/aws.j2 b/.generator/src/generator/templates/aws.j2 new file mode 100644 index 0000000000..853c461153 --- /dev/null +++ b/.generator/src/generator/templates/aws.j2 @@ -0,0 +1,258 @@ +{% include "api_info.j2" %} + +import base64 +import hashlib +import hmac +import json +import os +from datetime import datetime +from typing import Dict, List, Optional +from urllib.parse import quote + +from {{ package }}.delegated_auth import DelegatedTokenProvider, DelegatedTokenConfig, DelegatedTokenCredentials, get_delegated_token +from {{ package }}.exceptions import ApiValueError + + +# AWS specific constants +AWS_ACCESS_KEY_ID_NAME = "AWS_ACCESS_KEY_ID" +AWS_SECRET_ACCESS_KEY_NAME = "AWS_SECRET_ACCESS_KEY" +AWS_SESSION_TOKEN_NAME = "AWS_SESSION_TOKEN" + +AMZ_DATE_HEADER = "X-Amz-Date" +AMZ_TOKEN_HEADER = "X-Amz-Security-Token" +AMZ_DATE_FORMAT = "%Y%m%d" +AMZ_DATE_TIME_FORMAT = "%Y%m%dT%H%M%SZ" +DEFAULT_REGION = "us-east-1" +DEFAULT_STS_HOST = "sts.amazonaws.com" +REGIONAL_STS_HOST = "sts.{}.amazonaws.com" +SERVICE = "sts" +ALGORITHM = "AWS4-HMAC-SHA256" +AWS4_REQUEST = "aws4_request" +GET_CALLER_IDENTITY_BODY = "Action=GetCallerIdentity&Version=2011-06-15" + +# Common Headers +ORG_ID_HEADER = "x-ddog-org-id" +HOST_HEADER = "host" +APPLICATION_FORM = "application/x-www-form-urlencoded; charset=utf-8" + +PROVIDER_AWS = "aws" + + +class AWSCredentials: + """AWS credentials for authentication.""" + + def __init__(self, access_key_id: str, secret_access_key: str, session_token: str): + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.session_token = session_token + + +class SigningData: + """Data structure for AWS signing information.""" + + def __init__(self, headers_encoded: str, body_encoded: str, url_encoded: str, method: str): + self.headers_encoded = headers_encoded + self.body_encoded = body_encoded + self.url_encoded = url_encoded + self.method = method + + +class AWSAuth(DelegatedTokenProvider): + """AWS authentication provider for delegated tokens.""" + + def __init__(self, aws_region: Optional[str] = None): + self.aws_region = aws_region + + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + """Authenticate using AWS credentials and return delegated token credentials. + + :param config: Delegated token configuration + :return: DelegatedTokenCredentials object + :raises: ApiValueError if authentication fails + """ + # Get local AWS Credentials + creds = self.get_credentials() + + if not config or not config.org_uuid: + raise ApiValueError("Missing org UUID in config") + + # Use the credentials to generate the signing data + data = self.generate_aws_auth_data(config.org_uuid, creds) + + # Generate the auth string passed to the token endpoint + auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}" + + auth_response = get_delegated_token(config.org_uuid, auth_string) + return auth_response + + def get_credentials(self) -> AWSCredentials: + """Get AWS credentials from environment variables. + + :return: AWSCredentials object + :raises: ApiValueError if credentials are missing + """ + access_key = os.getenv(AWS_ACCESS_KEY_ID_NAME) + secret_key = os.getenv(AWS_SECRET_ACCESS_KEY_NAME) + session_token = os.getenv(AWS_SESSION_TOKEN_NAME) + + if not access_key or not secret_key or not session_token: + raise ApiValueError("Missing AWS credentials. Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables.") + + return AWSCredentials( + access_key_id=access_key, + secret_access_key=secret_key, + session_token=session_token + ) + + def _get_connection_parameters(self) -> tuple[str, str, str]: + """Get connection parameters for AWS STS. + + :return: Tuple of (sts_full_url, region, host) + """ + region = self.aws_region or DEFAULT_REGION + + if self.aws_region: + host = REGIONAL_STS_HOST.format(region) + else: + host = DEFAULT_STS_HOST + + sts_full_url = f"https://{host}" + return sts_full_url, region, host + + def generate_aws_auth_data(self, org_uuid: str, creds: AWSCredentials) -> SigningData: + """Generate AWS authentication data for signing. + + :param org_uuid: Organization UUID + :param creds: AWS credentials + :return: SigningData object + :raises: ApiValueError if generation fails + """ + if not org_uuid: + raise ApiValueError("Missing org UUID") + + if not creds or not creds.access_key_id or not creds.secret_access_key or not creds.session_token: + raise ApiValueError("Missing AWS credentials") + + sts_full_url, region, host = self._get_connection_parameters() + + now = datetime.utcnow() + + request_body = GET_CALLER_IDENTITY_BODY + payload_hash = hashlib.sha256(request_body.encode('utf-8')).hexdigest() + + # Create the headers that factor into the signing algorithm + header_map = { + "Content-Length": [str(len(request_body))], + "Content-Type": [APPLICATION_FORM], + AMZ_DATE_HEADER: [now.strftime(AMZ_DATE_TIME_FORMAT)], + ORG_ID_HEADER: [org_uuid], + AMZ_TOKEN_HEADER: [creds.session_token], + HOST_HEADER: [host], + } + + # Create canonical headers + header_arr = [] + signed_headers_arr = [] + + for k, v in header_map.items(): + lowered_header_name = k.lower() + header_arr.append(f"{lowered_header_name}:{','.join(v)}") + signed_headers_arr.append(lowered_header_name) + + header_arr.sort() + signed_headers_arr.sort() + signed_headers = ";".join(signed_headers_arr) + + canonical_request = "\n".join([ + "POST", + "/", + "", # No query string + "\n".join(header_arr) + "\n", + signed_headers, + payload_hash, + ]) + + # Create the string to sign + hash_canonical_request = hashlib.sha256(canonical_request.encode('utf-8')).hexdigest() + credential_scope = "/".join([ + now.strftime(AMZ_DATE_FORMAT), + region, + SERVICE, + AWS4_REQUEST, + ]) + + string_to_sign = self._make_signature( + now, + credential_scope, + hash_canonical_request, + region, + SERVICE, + creds.secret_access_key, + ALGORITHM, + ) + + # Create the authorization header + credential = f"{creds.access_key_id}/{credential_scope}" + auth_header = f"{ALGORITHM} Credential={credential}, SignedHeaders={signed_headers}, Signature={string_to_sign}" + + header_map["Authorization"] = [auth_header] + header_map["User-Agent"] = [self._get_user_agent()] + + headers_json = json.dumps(header_map, separators=(',', ':')) + + return SigningData( + headers_encoded=base64.b64encode(headers_json.encode('utf-8')).decode('utf-8'), + body_encoded=base64.b64encode(request_body.encode('utf-8')).decode('utf-8'), + method="POST", + url_encoded=base64.b64encode(sts_full_url.encode('utf-8')).decode('utf-8') + ) + + def _make_signature(self, t: datetime, credential_scope: str, payload_hash: str, + region: str, service: str, secret_access_key: str, algorithm: str) -> str: + """Create AWS signature. + + :param t: Current datetime + :param credential_scope: Credential scope string + :param payload_hash: Hash of the canonical request + :param region: AWS region + :param service: AWS service name + :param secret_access_key: AWS secret access key + :param algorithm: Signing algorithm + :return: Signature string + """ + # Create the string to sign + string_to_sign = "\n".join([ + algorithm, + t.strftime(AMZ_DATE_TIME_FORMAT), + credential_scope, + payload_hash, + ]) + + # Create the signing key + k_date = self._hmac256(t.strftime(AMZ_DATE_FORMAT), f"AWS4{secret_access_key}".encode('utf-8')) + k_region = self._hmac256(region, k_date) + k_service = self._hmac256(service, k_region) + k_signing = self._hmac256(AWS4_REQUEST, k_service) + + # Sign the string + signature = self._hmac256(string_to_sign, k_signing) + return signature.hex() + + def _hmac256(self, data: str, key: bytes) -> bytes: + """Create HMAC-SHA256 hash. + + :param data: Data to hash + :param key: Key for HMAC + :return: HMAC hash bytes + """ + return hmac.new(key, data.encode('utf-8'), hashlib.sha256).digest() + + def _get_user_agent(self) -> str: + """Get user agent string. + + :return: User agent string + """ + import platform + from {{ package }}.version import __version__ + + return f"datadog-api-client-python/{__version__} (python {platform.python_version()}; os {platform.system()}; arch {platform.machine()})" diff --git a/.generator/src/generator/templates/configuration.j2 b/.generator/src/generator/templates/configuration.j2 index c261378471..3fb24de09b 100644 --- a/.generator/src/generator/templates/configuration.j2 +++ b/.generator/src/generator/templates/configuration.j2 @@ -245,6 +245,9 @@ class Configuration: {%- endfor %} }) + # Delegated token configuration + self.delegated_token_config = None + # Load default values from environment if "DD_SITE" in os.environ: self.server_variables["site"] = os.environ["DD_SITE"] diff --git a/.generator/src/generator/templates/delegated_auth.j2 b/.generator/src/generator/templates/delegated_auth.j2 new file mode 100644 index 0000000000..e57d612199 --- /dev/null +++ b/.generator/src/generator/templates/delegated_auth.j2 @@ -0,0 +1,145 @@ +{% include "api_info.j2" %} + +import json +import time +from datetime import datetime, timedelta +from typing import Optional +from urllib.parse import urljoin + +from {{ package }} import rest +from {{ package }}.configuration import Configuration +from {{ package }}.exceptions import ApiValueError + + +TOKEN_URL_ENDPOINT = "/api/v2/delegated-token" +AUTHORIZATION_TYPE = "Delegated" +APPLICATION_JSON = "application/json" + + +class DelegatedTokenCredentials: + """Credentials for delegated token authentication.""" + + def __init__(self, org_uuid: str, delegated_token: str, delegated_proof: str, expiration: datetime): + self.org_uuid = org_uuid + self.delegated_token = delegated_token + self.delegated_proof = delegated_proof + self.expiration = expiration + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return datetime.now() >= self.expiration + + +class DelegatedTokenConfig: + """Configuration for delegated token authentication.""" + + def __init__(self, org_uuid: str, provider: str, provider_auth: 'DelegatedTokenProvider'): + self.org_uuid = org_uuid + self.provider = provider + self.provider_auth = provider_auth + + +class DelegatedTokenProvider: + """Abstract base class for delegated token providers.""" + + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + """Authenticate and return delegated token credentials.""" + raise NotImplementedError("Subclasses must implement authenticate method") + + +def get_delegated_token(org_uuid: str, delegated_auth_proof: str) -> DelegatedTokenCredentials: + """Get a delegated token from the Datadog API. + + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :return: DelegatedTokenCredentials object + :raises: ApiValueError if the request fails + """ + config = Configuration() + url = get_delegated_token_url(config) + + # Create REST client + rest_client = rest.RESTClientObject(config) + + headers = { + "Content-Type": APPLICATION_JSON, + "Authorization": f"{AUTHORIZATION_TYPE} {delegated_auth_proof}", + "Content-Length": "0" + } + + try: + response = rest_client.request( + method="POST", + url=url, + headers=headers, + body="", + preload_content=True + ) + + if response.status != 200: + raise ApiValueError(f"Failed to get token: {response.status}") + + response_data = response.data.decode('utf-8') + creds = parse_delegated_token_response(response_data, org_uuid, delegated_auth_proof) + return creds + + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + +def parse_delegated_token_response(response_data: str, org_uuid: str, delegated_auth_proof: str) -> DelegatedTokenCredentials: + """Parse the delegated token response. + + :param response_data: JSON response data as string + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :return: DelegatedTokenCredentials object + :raises: ApiValueError if parsing fails + """ + try: + token_response = json.loads(response_data) + except json.JSONDecodeError as e: + raise ApiValueError(f"Failed to parse token response: {str(e)}") + + # Get attributes from the response + data_response = token_response.get("data") + if not data_response: + raise ApiValueError(f"Failed to get data from response: {token_response}") + + attributes = data_response.get("attributes") + if not attributes: + raise ApiValueError(f"Failed to get attributes from response: {token_response}") + + # Get the access token from the response + token = attributes.get("access_token") + if not token: + raise ApiValueError(f"Failed to get token from response: {token_response}") + + # Get the expiration time from the response + # Default to 15 minutes if the expiration time is not set + expiration_time = datetime.now() + timedelta(minutes=15) + expires_str = attributes.get("expires") + if expires_str: + try: + expiration_int = int(expires_str) + expiration_time = datetime.fromtimestamp(expiration_int) + except (ValueError, TypeError): + # Use default expiration if parsing fails + pass + + return DelegatedTokenCredentials( + org_uuid=org_uuid, + delegated_token=token, + delegated_proof=delegated_auth_proof, + expiration=expiration_time + ) + + +def get_delegated_token_url(config: Configuration) -> str: + """Get the URL for the delegated token endpoint. + + :param config: Configuration object + :return: Full URL for the delegated token endpoint + """ + base_url = config.host + return urljoin(base_url, TOKEN_URL_ENDPOINT) diff --git a/.generator/src/generator/templates/example_aws.j2 b/.generator/src/generator/templates/example_aws.j2 new file mode 100644 index 0000000000..f76247d682 --- /dev/null +++ b/.generator/src/generator/templates/example_aws.j2 @@ -0,0 +1,73 @@ +""" +Example of using AWS authentication with the Datadog API client. + +This example shows how to configure the client to use AWS credentials +for authentication instead of API keys. +""" + +import os +from {{ package }} import ApiClient, Configuration +from {{ package }}.aws import AWSAuth +from {{ package }}.delegated_auth import DelegatedTokenConfig +from {{ package }}.{{ version }}.api.authentication_api import AuthenticationApi + + +def main(): + # Set up AWS credentials in environment variables + # These would typically be set by your AWS environment (EC2 instance role, ECS task role, etc.) + # or explicitly set in your environment: + # os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key-id" + # os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-access-key" + # os.environ["AWS_SESSION_TOKEN"] = "your-session-token" + + # Verify AWS credentials are available + if not all([ + os.getenv("AWS_ACCESS_KEY_ID"), + os.getenv("AWS_SECRET_ACCESS_KEY"), + os.getenv("AWS_SESSION_TOKEN") + ]): + print("Error: AWS credentials not found in environment variables.") + print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") + return + + # Your Datadog organization UUID + # This should be provided by your Datadog administrator + org_uuid = os.getenv("DD_ORG_UUID") + if not org_uuid: + print("Error: DD_ORG_UUID environment variable not set.") + print("Please set your Datadog organization UUID") + return + + # Create AWS authentication provider + # Optionally specify AWS region (defaults to us-east-1) + aws_region = os.getenv("AWS_REGION", "us-east-1") + aws_auth = AWSAuth(aws_region=aws_region) + + # Create delegated token configuration + delegated_config = DelegatedTokenConfig( + org_uuid=org_uuid, + provider="aws", + provider_auth=aws_auth + ) + + # Create configuration and set delegated token config + configuration = Configuration() + configuration.delegated_token_config = delegated_config + + # Create API client with the configuration + with ApiClient(configuration) as api_client: + # Create API instance + api_instance = AuthenticationApi(api_client) + + try: + # Test the authentication by validating credentials + api_response = api_instance.validate() + print("Authentication successful!") + print(f"Valid: {api_response.valid}") + + except Exception as e: + print(f"Authentication failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/datadog/aws.py b/examples/datadog/aws.py new file mode 100644 index 0000000000..88d7800bba --- /dev/null +++ b/examples/datadog/aws.py @@ -0,0 +1,78 @@ +""" +Example of using AWS authentication with the Datadog API client. + +This example shows how to configure the client to use AWS credentials +for authentication instead of API keys. This is ideal for environments +like AWS Ray where you want to use temporary credentials. +""" + +import os +from datadog_api_client import ApiClient, Configuration +from datadog_api_client.aws import AWSAuth +from datadog_api_client.delegated_auth import DelegatedTokenConfig +from datadog_api_client.v1.api.authentication_api import AuthenticationApi + + +def main(): + # Set up AWS credentials in environment variables + # These would typically be set by your AWS environment (EC2 instance role, ECS task role, etc.) + # or explicitly set in your environment: + # os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key-id" + # os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-access-key" + # os.environ["AWS_SESSION_TOKEN"] = "your-session-token" + + # Verify AWS credentials are available + if not all([ + os.getenv("AWS_ACCESS_KEY_ID"), + os.getenv("AWS_SECRET_ACCESS_KEY"), + os.getenv("AWS_SESSION_TOKEN") + ]): + print("Error: AWS credentials not found in environment variables.") + print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") + return + + # Your Datadog organization UUID + # This should be provided by your Datadog administrator + org_uuid = os.getenv("DD_ORG_UUID") + if not org_uuid: + print("Error: DD_ORG_UUID environment variable not set.") + print("Please set your Datadog organization UUID") + return + + # Create AWS authentication provider + # Optionally specify AWS region (defaults to us-east-1) + aws_region = os.getenv("AWS_REGION", "us-east-1") + aws_auth = AWSAuth(aws_region=aws_region) + + # Create delegated token configuration + delegated_config = DelegatedTokenConfig( + org_uuid=org_uuid, + provider="aws", + provider_auth=aws_auth + ) + + # Create configuration and set delegated token config + configuration = Configuration() + configuration.delegated_token_config = delegated_config + + # Optional: Set Datadog site if different from default + site = os.getenv("DD_SITE", "datadoghq.com") + configuration.server_variables["site"] = site + + # Create API client with cloud authentication + with ApiClient(configuration) as api_client: + # Create API instance + api_instance = AuthenticationApi(api_client) + + try: + # Test the authentication by validating the token + response = api_instance.validate() + print("✅ Authentication successful!") + print(f"Valid: {response.valid}") + + except Exception as e: + print(f"❌ Authentication failed: {e}") + + +if __name__ == "__main__": + main() diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index d9794ede9c..326bc2d6b4 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -454,6 +454,49 @@ def select_header_content_type(self, content_types: List[str]) -> str: return "application/json" return content_types[0] + def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: + """Use delegated token authentication if configured. + + :param headers: Header parameters dict to be updated. + :raises: ApiValueError if delegated token authentication fails + """ + if not self.configuration.delegated_token_config: + return + + # Get or create delegated token credentials + if not hasattr(self, "_delegated_token_credentials") or self._delegated_token_credentials is None: + self._delegated_token_credentials = self._get_delegated_token() + elif self._delegated_token_credentials.is_expired(): + # Token is expired, get a new one + self._delegated_token_credentials = self._get_delegated_token() + + # Set the Authorization header with the delegated token + headers["Authorization"] = f"Bearer {self._delegated_token_credentials.delegated_token}" + + def get_delegated_token(self) -> "DelegatedTokenCredentials": + """Get a delegated token using the configured provider (public API). + + :return: DelegatedTokenCredentials object + :raises: ApiValueError if token retrieval fails + """ + return self._get_delegated_token() + + def _get_delegated_token(self) -> "DelegatedTokenCredentials": + """Get a new delegated token using the configured provider. + + :return: DelegatedTokenCredentials object + :raises: ApiValueError if token retrieval fails + """ + if not self.configuration.delegated_token_config: + raise ApiValueError("Delegated token configuration is not set") + + try: + return self.configuration.delegated_token_config.provider_auth.authenticate( + self.configuration.delegated_token_config + ) + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + class ThreadedApiClient(ApiClient): _pool = None @@ -822,18 +865,26 @@ def update_params_for_auth(self, headers, queries) -> None: if not self.settings["auth"]: return - for auth in self.settings["auth"]: - auth_setting = self.api_client.configuration.auth_settings().get(auth) - if auth_setting: - if auth_setting["in"] == "header": - if auth_setting["type"] != "http-signature": - if auth_setting["value"] is None: - raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) - headers[auth_setting["key"]] = auth_setting["value"] - elif auth_setting["in"] == "query": - queries.append((auth_setting["key"], auth_setting["value"])) - else: - raise ApiValueError("Authentication token must be in `query` or `header`") + # Check if this endpoint uses appKeyAuth and if delegated token config is available + has_app_key_auth = "appKeyAuth" in self.settings["auth"] + + if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: + # Use delegated token authentication + self.api_client.use_delegated_token_auth(headers) + else: + # Use regular authentication + for auth in self.settings["auth"]: + auth_setting = self.api_client.configuration.auth_settings().get(auth) + if auth_setting: + if auth_setting["in"] == "header": + if auth_setting["type"] != "http-signature": + if auth_setting["value"] is None: + raise ApiValueError("Invalid authentication token for {}".format(auth_setting["key"])) + headers[auth_setting["key"]] = auth_setting["value"] + elif auth_setting["in"] == "query": + queries.append((auth_setting["key"], auth_setting["value"])) + else: + raise ApiValueError("Authentication token must be in `query` or `header`") def user_agent() -> str: diff --git a/src/datadog_api_client/aws.py b/src/datadog_api_client/aws.py new file mode 100644 index 0000000000..0101f72f9f --- /dev/null +++ b/src/datadog_api_client/aws.py @@ -0,0 +1,277 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# This product includes software developed at Datadog (https://www.datadoghq.com/). +# Copyright 2019-Present Datadog, Inc. + +import base64 +import hashlib +import hmac +import json +import os +from datetime import datetime +from typing import Optional + +from datadog_api_client.delegated_auth import ( + DelegatedTokenProvider, + DelegatedTokenConfig, + DelegatedTokenCredentials, + get_delegated_token, +) +from datadog_api_client.exceptions import ApiValueError + + +# AWS specific constants +AWS_ACCESS_KEY_ID_NAME = "AWS_ACCESS_KEY_ID" +AWS_SECRET_ACCESS_KEY_NAME = "AWS_SECRET_ACCESS_KEY" +AWS_SESSION_TOKEN_NAME = "AWS_SESSION_TOKEN" + +AMZ_DATE_HEADER = "X-Amz-Date" +AMZ_TOKEN_HEADER = "X-Amz-Security-Token" +AMZ_DATE_FORMAT = "%Y%m%d" +AMZ_DATE_TIME_FORMAT = "%Y%m%dT%H%M%SZ" +DEFAULT_REGION = "us-east-1" +DEFAULT_STS_HOST = "sts.amazonaws.com" +REGIONAL_STS_HOST = "sts.{}.amazonaws.com" +SERVICE = "sts" +ALGORITHM = "AWS4-HMAC-SHA256" +AWS4_REQUEST = "aws4_request" +GET_CALLER_IDENTITY_BODY = "Action=GetCallerIdentity&Version=2011-06-15" + +# Common Headers +ORG_ID_HEADER = "x-ddog-org-id" +HOST_HEADER = "host" +APPLICATION_FORM = "application/x-www-form-urlencoded; charset=utf-8" + +PROVIDER_AWS = "aws" + + +class AWSCredentials: + """AWS credentials for authentication.""" + + def __init__(self, access_key_id: str, secret_access_key: str, session_token: str): + self.access_key_id = access_key_id + self.secret_access_key = secret_access_key + self.session_token = session_token + + +class SigningData: + """Data structure for AWS signing information.""" + + def __init__(self, headers_encoded: str, body_encoded: str, url_encoded: str, method: str): + self.headers_encoded = headers_encoded + self.body_encoded = body_encoded + self.url_encoded = url_encoded + self.method = method + + +class AWSAuth(DelegatedTokenProvider): + """AWS authentication provider for delegated tokens.""" + + def __init__(self, aws_region: Optional[str] = None): + self.aws_region = aws_region + + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + """Authenticate using AWS credentials and return delegated token credentials. + + :param config: Delegated token configuration + :return: DelegatedTokenCredentials object + :raises: ApiValueError if authentication fails + """ + # Check config first before attempting to get credentials + if not config or not config.org_uuid: + raise ApiValueError("Missing org UUID in config") + + # Get local AWS Credentials + creds = self.get_credentials() + + # Use the credentials to generate the signing data + data = self.generate_aws_auth_data(config.org_uuid, creds) + + # Generate the auth string passed to the token endpoint + auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}" + + auth_response = get_delegated_token(config.org_uuid, auth_string) + return auth_response + + def get_credentials(self) -> AWSCredentials: + """Get AWS credentials from environment variables. + + :return: AWSCredentials object + :raises: ApiValueError if credentials are missing + """ + access_key = os.getenv(AWS_ACCESS_KEY_ID_NAME) + secret_key = os.getenv(AWS_SECRET_ACCESS_KEY_NAME) + session_token = os.getenv(AWS_SESSION_TOKEN_NAME) + + if not access_key or not secret_key or not session_token: + raise ApiValueError( + "Missing AWS credentials. Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN environment variables." + ) + + return AWSCredentials(access_key_id=access_key, secret_access_key=secret_key, session_token=session_token) + + def _get_connection_parameters(self) -> tuple[str, str, str]: + """Get connection parameters for AWS STS. + + :return: Tuple of (sts_full_url, region, host) + """ + region = self.aws_region or DEFAULT_REGION + + if self.aws_region: + host = REGIONAL_STS_HOST.format(region) + else: + host = DEFAULT_STS_HOST + + sts_full_url = f"https://{host}" + return sts_full_url, region, host + + def generate_aws_auth_data(self, org_uuid: str, creds: AWSCredentials) -> SigningData: + """Generate AWS authentication data for signing. + + :param org_uuid: Organization UUID + :param creds: AWS credentials + :return: SigningData object + :raises: ApiValueError if generation fails + """ + if not org_uuid: + raise ApiValueError("Missing org UUID") + + if not creds or not creds.access_key_id or not creds.secret_access_key or not creds.session_token: + raise ApiValueError("Missing AWS credentials") + + sts_full_url, region, host = self._get_connection_parameters() + + now = datetime.utcnow() + + request_body = GET_CALLER_IDENTITY_BODY + payload_hash = hashlib.sha256(request_body.encode("utf-8")).hexdigest() + + # Create the headers that factor into the signing algorithm + header_map = { + "Content-Length": [str(len(request_body))], + "Content-Type": [APPLICATION_FORM], + AMZ_DATE_HEADER: [now.strftime(AMZ_DATE_TIME_FORMAT)], + ORG_ID_HEADER: [org_uuid], + AMZ_TOKEN_HEADER: [creds.session_token], + HOST_HEADER: [host], + } + + # Create canonical headers + header_arr = [] + signed_headers_arr = [] + + for k, v in header_map.items(): + lowered_header_name = k.lower() + header_arr.append(f"{lowered_header_name}:{','.join(v)}") + signed_headers_arr.append(lowered_header_name) + + header_arr.sort() + signed_headers_arr.sort() + signed_headers = ";".join(signed_headers_arr) + + canonical_request = "\n".join( + [ + "POST", + "/", + "", # No query string + "\n".join(header_arr) + "\n", + signed_headers, + payload_hash, + ] + ) + + # Create the string to sign + hash_canonical_request = hashlib.sha256(canonical_request.encode("utf-8")).hexdigest() + credential_scope = "/".join( + [ + now.strftime(AMZ_DATE_FORMAT), + region, + SERVICE, + AWS4_REQUEST, + ] + ) + + string_to_sign = self._make_signature( + now, + credential_scope, + hash_canonical_request, + region, + SERVICE, + creds.secret_access_key, + ALGORITHM, + ) + + # Create the authorization header + credential = f"{creds.access_key_id}/{credential_scope}" + auth_header = f"{ALGORITHM} Credential={credential}, SignedHeaders={signed_headers}, Signature={string_to_sign}" + + header_map["Authorization"] = [auth_header] + header_map["User-Agent"] = [self._get_user_agent()] + + headers_json = json.dumps(header_map, separators=(",", ":")) + + return SigningData( + headers_encoded=base64.b64encode(headers_json.encode("utf-8")).decode("utf-8"), + body_encoded=base64.b64encode(request_body.encode("utf-8")).decode("utf-8"), + method="POST", + url_encoded=base64.b64encode(sts_full_url.encode("utf-8")).decode("utf-8"), + ) + + def _make_signature( + self, + t: datetime, + credential_scope: str, + payload_hash: str, + region: str, + service: str, + secret_access_key: str, + algorithm: str, + ) -> str: + """Create AWS signature. + + :param t: Current datetime + :param credential_scope: Credential scope string + :param payload_hash: Hash of the canonical request + :param region: AWS region + :param service: AWS service name + :param secret_access_key: AWS secret access key + :param algorithm: Signing algorithm + :return: Signature string + """ + # Create the string to sign + string_to_sign = "\n".join( + [ + algorithm, + t.strftime(AMZ_DATE_TIME_FORMAT), + credential_scope, + payload_hash, + ] + ) + + # Create the signing key + k_date = self._hmac256(t.strftime(AMZ_DATE_FORMAT), f"AWS4{secret_access_key}".encode("utf-8")) + k_region = self._hmac256(region, k_date) + k_service = self._hmac256(service, k_region) + k_signing = self._hmac256(AWS4_REQUEST, k_service) + + # Sign the string + signature = self._hmac256(string_to_sign, k_signing) + return signature.hex() + + def _hmac256(self, data: str, key: bytes) -> bytes: + """Create HMAC-SHA256 hash. + + :param data: Data to hash + :param key: Key for HMAC + :return: HMAC hash bytes + """ + return hmac.new(key, data.encode("utf-8"), hashlib.sha256).digest() + + def _get_user_agent(self) -> str: + """Get user agent string. + + :return: User agent string + """ + import platform + from datadog_api_client.version import __version__ + + return f"datadog-api-client-python/{__version__} (python {platform.python_version()}; os {platform.system()}; arch {platform.machine()})" diff --git a/src/datadog_api_client/configuration.py b/src/datadog_api_client/configuration.py index d03a705882..f46296add0 100644 --- a/src/datadog_api_client/configuration.py +++ b/src/datadog_api_client/configuration.py @@ -344,6 +344,9 @@ def __init__( } ) + # Delegated token configuration + self.delegated_token_config = None + # Load default values from environment if "DD_SITE" in os.environ: self.server_variables["site"] = os.environ["DD_SITE"] diff --git a/src/datadog_api_client/delegated_auth.py b/src/datadog_api_client/delegated_auth.py new file mode 100644 index 0000000000..33fd2b6deb --- /dev/null +++ b/src/datadog_api_client/delegated_auth.py @@ -0,0 +1,138 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# This product includes software developed at Datadog (https://www.datadoghq.com/). +# Copyright 2019-Present Datadog, Inc. + +import json +from datetime import datetime, timedelta +from urllib.parse import urljoin + +from datadog_api_client import rest +from datadog_api_client.configuration import Configuration +from datadog_api_client.exceptions import ApiValueError + + +TOKEN_URL_ENDPOINT = "/api/v2/delegated-token" +AUTHORIZATION_TYPE = "Delegated" +APPLICATION_JSON = "application/json" + + +class DelegatedTokenCredentials: + """Credentials for delegated token authentication.""" + + def __init__(self, org_uuid: str, delegated_token: str, delegated_proof: str, expiration: datetime): + self.org_uuid = org_uuid + self.delegated_token = delegated_token + self.delegated_proof = delegated_proof + self.expiration = expiration + + def is_expired(self) -> bool: + """Check if the token is expired.""" + return datetime.now() >= self.expiration + + +class DelegatedTokenConfig: + """Configuration for delegated token authentication.""" + + def __init__(self, org_uuid: str, provider: str, provider_auth: "DelegatedTokenProvider"): + self.org_uuid = org_uuid + self.provider = provider + self.provider_auth = provider_auth + + +class DelegatedTokenProvider: + """Abstract base class for delegated token providers.""" + + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + """Authenticate and return delegated token credentials.""" + raise NotImplementedError("Subclasses must implement authenticate method") + + +def get_delegated_token(org_uuid: str, delegated_auth_proof: str) -> DelegatedTokenCredentials: + """Get a delegated token from the Datadog API. + + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :return: DelegatedTokenCredentials object + :raises: ApiValueError if the request fails + """ + config = Configuration() + url = get_delegated_token_url(config) + + # Create REST client + rest_client = rest.RESTClientObject(config) + + headers = { + "Content-Type": APPLICATION_JSON, + "Authorization": f"{AUTHORIZATION_TYPE} {delegated_auth_proof}", + "Content-Length": "0", + } + + try: + response = rest_client.request(method="POST", url=url, headers=headers, body="", preload_content=True) + + if response.status != 200: + raise ApiValueError(f"Failed to get token: {response.status}") + + response_data = response.data.decode("utf-8") + creds = parse_delegated_token_response(response_data, org_uuid, delegated_auth_proof) + return creds + + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + +def parse_delegated_token_response( + response_data: str, org_uuid: str, delegated_auth_proof: str +) -> DelegatedTokenCredentials: + """Parse the delegated token response. + + :param response_data: JSON response data as string + :param org_uuid: Organization UUID + :param delegated_auth_proof: Authentication proof string + :return: DelegatedTokenCredentials object + :raises: ApiValueError if parsing fails + """ + try: + token_response = json.loads(response_data) + except json.JSONDecodeError as e: + raise ApiValueError(f"Failed to parse token response: {str(e)}") + + # Get attributes from the response + data_response = token_response.get("data") + if not data_response: + raise ApiValueError(f"Failed to get data from response: {token_response}") + + attributes = data_response.get("attributes") + if not attributes: + raise ApiValueError(f"Failed to get attributes from response: {token_response}") + + # Get the access token from the response + token = attributes.get("access_token") + if not token: + raise ApiValueError(f"Failed to get token from response: {token_response}") + + # Get the expiration time from the response + # Default to 15 minutes if the expiration time is not set + expiration_time = datetime.now() + timedelta(minutes=15) + expires_str = attributes.get("expires") + if expires_str: + try: + expiration_int = int(expires_str) + expiration_time = datetime.fromtimestamp(expiration_int) + except (ValueError, TypeError): + # Use default expiration if parsing fails + pass + + return DelegatedTokenCredentials( + org_uuid=org_uuid, delegated_token=token, delegated_proof=delegated_auth_proof, expiration=expiration_time + ) + + +def get_delegated_token_url(config: Configuration) -> str: + """Get the URL for the delegated token endpoint. + + :param config: Configuration object + :return: Full URL for the delegated token endpoint + """ + base_url = config.host + return urljoin(base_url, TOKEN_URL_ENDPOINT) diff --git a/tests/client_test.py b/tests/client_test.py new file mode 100644 index 0000000000..77e682a52e --- /dev/null +++ b/tests/client_test.py @@ -0,0 +1,376 @@ +"""Client integration tests for delegated authentication functionality. + +This test module provides comprehensive integration tests for the ApiClient +with delegated token authentication, similar to the Go client tests. +""" + +import pytest +from datetime import datetime, timedelta +from unittest.mock import patch + +from datadog_api_client.api_client import ApiClient +from datadog_api_client.configuration import Configuration +from datadog_api_client.delegated_auth import ( + DelegatedTokenCredentials, + DelegatedTokenConfig, + DelegatedTokenProvider, +) +from datadog_api_client.aws import AWSAuth, PROVIDER_AWS +from datadog_api_client.exceptions import ApiValueError + + +FAKE_TOKEN = "fake-token" +FAKE_ORG_UUID = "1234" +FAKE_PROOF = "proof" + + +class MockDelegatedTokenProvider(DelegatedTokenProvider): + """Mock delegated token provider for testing.""" + + def __init__(self, token=FAKE_TOKEN, org_uuid=FAKE_ORG_UUID, proof=FAKE_PROOF, expiration_minutes=10): + self.token = token + self.org_uuid = org_uuid + self.proof = proof + self.expiration_minutes = expiration_minutes + self.authenticate_calls = [] + + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + """Mock authenticate method.""" + self.authenticate_calls.append(config) + expiration = datetime.now() + timedelta(minutes=self.expiration_minutes) + return DelegatedTokenCredentials( + org_uuid=self.org_uuid, + delegated_token=self.token, + delegated_proof=self.proof, + expiration=expiration + ) + + +class ExpiredMockProvider(MockDelegatedTokenProvider): + """Mock provider that returns expired tokens.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.expiration_minutes = -16 # Expired 16 minutes ago + + +class TestClientDelegatedAuthentication: + """Test client integration with delegated authentication.""" + + def test_delegated_pre_authenticate(self): + """Test delegated token pre-authentication flow.""" + # Create mock provider + mock_provider = MockDelegatedTokenProvider() + + # Create configuration with delegated token config + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=mock_provider + ) + + # Create API client + api_client = ApiClient(config) + + # Test get_delegated_token method + token = api_client.get_delegated_token() + assert token is not None + assert token.delegated_token == FAKE_TOKEN + assert token.org_uuid == FAKE_ORG_UUID + assert token.delegated_proof == FAKE_PROOF + assert not token.is_expired() + + # Verify provider was called + assert len(mock_provider.authenticate_calls) == 1 + assert mock_provider.authenticate_calls[0].org_uuid == FAKE_ORG_UUID + + def test_delegated_no_pre_authenticate(self): + """Test delegated token authentication without pre-authentication.""" + # Create mock provider + mock_provider = MockDelegatedTokenProvider() + + # Create configuration with delegated token config + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=mock_provider + ) + + # Create API client + api_client = ApiClient(config) + + # Test use_delegated_token_auth (simulates API call header generation) + headers = {} + api_client.use_delegated_token_auth(headers) + + # Verify headers were set correctly + assert "Authorization" in headers + assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" + + # Verify provider was called once for token generation + assert len(mock_provider.authenticate_calls) == 1 + + def test_delegated_re_authenticate(self): + """Test delegated token re-authentication when token expires.""" + # Create two mock providers - first with expired token, second with valid token + expired_provider = ExpiredMockProvider() + valid_provider = MockDelegatedTokenProvider() + + # Create configuration + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=expired_provider + ) + + # Create API client and get initial (expired) token + api_client = ApiClient(config) + + # First call gets expired token + headers = {} + api_client.use_delegated_token_auth(headers) + assert "Authorization" in headers + assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" + + # Verify expired token was created + assert hasattr(api_client, '_delegated_token_credentials') + assert api_client._delegated_token_credentials.is_expired() + + # Switch to valid provider for re-authentication + config.delegated_token_config.provider_auth = valid_provider + + # Second call should re-authenticate due to expired token + headers2 = {} + api_client.use_delegated_token_auth(headers2) + assert "Authorization" in headers2 + assert headers2["Authorization"] == f"Bearer {FAKE_TOKEN}" + + # Verify new token is not expired + assert not api_client._delegated_token_credentials.is_expired() + + # Verify both providers were called + assert len(expired_provider.authenticate_calls) == 1 + assert len(valid_provider.authenticate_calls) == 1 + + def test_delegated_token_caching(self): + """Test that delegated tokens are cached and reused when not expired.""" + # Create mock provider + mock_provider = MockDelegatedTokenProvider() + + # Create configuration + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=mock_provider + ) + + # Create API client + api_client = ApiClient(config) + + # Multiple calls to use_delegated_token_auth + headers1 = {} + api_client.use_delegated_token_auth(headers1) + + headers2 = {} + api_client.use_delegated_token_auth(headers2) + + headers3 = {} + api_client.use_delegated_token_auth(headers3) + + # All should have same token + assert headers1["Authorization"] == headers2["Authorization"] == headers3["Authorization"] + + # Provider should only be called once (token cached) + assert len(mock_provider.authenticate_calls) == 1 + + @patch.dict('os.environ', { + 'AWS_ACCESS_KEY_ID': 'fake-access-key-id', + 'AWS_SECRET_ACCESS_KEY': 'fake-secret-access-key', + 'AWS_SESSION_TOKEN': 'fake-session-token' + }) + def test_delegated_with_aws_provider(self): + """Test delegated authentication with real AWS provider (mocked token endpoint).""" + # Create AWS auth provider + aws_auth = AWSAuth() + + # Create configuration + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=aws_auth + ) + + # Mock the get_delegated_token function called by AWS provider + mock_creds = DelegatedTokenCredentials( + org_uuid=FAKE_ORG_UUID, + delegated_token=FAKE_TOKEN, + delegated_proof=FAKE_PROOF, + expiration=datetime.now() + timedelta(minutes=15) + ) + + with patch('datadog_api_client.aws.get_delegated_token', return_value=mock_creds): + # Create API client + api_client = ApiClient(config) + + # Test token retrieval + token = api_client.get_delegated_token() + assert token.delegated_token == FAKE_TOKEN + assert token.org_uuid == FAKE_ORG_UUID + + # Test header generation + headers = {} + api_client.use_delegated_token_auth(headers) + assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" + + def test_api_key_authentication_comparison(self): + """Test traditional API key authentication for comparison with delegated auth.""" + # Create configuration with API keys + config = Configuration() + config.api_key = { + "apiKeyAuth": "test-api-key", + "appKeyAuth": "test-app-key" + } + + # Create API client + api_client = ApiClient(config) + + # Test that no delegated auth is used when not configured + headers = {} + api_client.use_delegated_token_auth(headers) + + # Headers should be empty (no delegated token) + assert "Authorization" not in headers + + # This simulates how API keys would be added in actual API calls + # (API key authentication is handled separately in the client) + + def test_delegated_auth_error_handling(self): + """Test error handling in delegated authentication.""" + # Test with no configuration + config = Configuration() + api_client = ApiClient(config) + + with pytest.raises(ApiValueError, match="Delegated token configuration is not set"): + api_client.get_delegated_token() + + # Test with provider that raises an exception + class FailingProvider(DelegatedTokenProvider): + def authenticate(self, config): + raise Exception("Authentication failed") + + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=FailingProvider() + ) + + api_client = ApiClient(config) + with pytest.raises(ApiValueError, match="Failed to get delegated token"): + api_client.get_delegated_token() + + def test_multiple_api_clients_isolation(self): + """Test that multiple API clients with different configs work independently.""" + # Create two different mock providers + provider1 = MockDelegatedTokenProvider(token="token1", org_uuid="org1") + provider2 = MockDelegatedTokenProvider(token="token2", org_uuid="org2") + + # Create two configurations + config1 = Configuration() + config1.delegated_token_config = DelegatedTokenConfig( + org_uuid="org1", + provider=PROVIDER_AWS, + provider_auth=provider1 + ) + + config2 = Configuration() + config2.delegated_token_config = DelegatedTokenConfig( + org_uuid="org2", + provider=PROVIDER_AWS, + provider_auth=provider2 + ) + + # Create two API clients + client1 = ApiClient(config1) + client2 = ApiClient(config2) + + # Test that they get different tokens + token1 = client1.get_delegated_token() + token2 = client2.get_delegated_token() + + assert token1.delegated_token == "token1" + assert token1.org_uuid == "org1" + assert token2.delegated_token == "token2" + assert token2.org_uuid == "org2" + + # Test header generation for both + headers1 = {} + client1.use_delegated_token_auth(headers1) + + headers2 = {} + client2.use_delegated_token_auth(headers2) + + assert headers1["Authorization"] == "Bearer token1" + assert headers2["Authorization"] == "Bearer token2" + + +class TestClientDelegatedAuthenticationWithRealMocks: + """Test client integration with more realistic mocking scenarios.""" + + def test_token_refresh_sequence(self): + """Test the complete token refresh sequence over time.""" + # Create a provider that tracks call count + call_count = 0 + + class CountingProvider(DelegatedTokenProvider): + def authenticate(self, config): + nonlocal call_count + call_count += 1 + + # First call returns short-lived token, second call returns long-lived + if call_count == 1: + expiration = datetime.now() + timedelta(seconds=1) # Very short lived + else: + expiration = datetime.now() + timedelta(minutes=30) # Long lived + + return DelegatedTokenCredentials( + org_uuid=config.org_uuid, + delegated_token=f"token-{call_count}", + delegated_proof="proof", + expiration=expiration + ) + + # Setup + config = Configuration() + config.delegated_token_config = DelegatedTokenConfig( + org_uuid=FAKE_ORG_UUID, + provider=PROVIDER_AWS, + provider_auth=CountingProvider() + ) + + api_client = ApiClient(config) + + # First call + headers1 = {} + api_client.use_delegated_token_auth(headers1) + assert headers1["Authorization"] == "Bearer token-1" + assert call_count == 1 + + # Immediate second call should reuse token + headers2 = {} + api_client.use_delegated_token_auth(headers2) + assert headers2["Authorization"] == "Bearer token-1" + assert call_count == 1 # No new call + + # Force expiration by manipulating the cached token + api_client._delegated_token_credentials.expiration = datetime.now() - timedelta(minutes=1) + + # Third call should refresh due to expiration + headers3 = {} + api_client.use_delegated_token_auth(headers3) + assert headers3["Authorization"] == "Bearer token-2" + assert call_count == 2 # New call made diff --git a/tests/test_aws.py b/tests/test_aws.py new file mode 100644 index 0000000000..62d3bf0626 --- /dev/null +++ b/tests/test_aws.py @@ -0,0 +1,210 @@ +"""Tests for AWS authentication functionality.""" + +import os +import pytest +from unittest.mock import Mock, patch + +from datadog_api_client.aws import ( + AWSAuth, + AWSCredentials, + SigningData, + PROVIDER_AWS, +) +from datadog_api_client.delegated_auth import DelegatedTokenConfig +from datadog_api_client.exceptions import ApiValueError + + +class TestAWSCredentials: + """Test AWSCredentials class.""" + + def test_init(self): + """Test initialization of AWSCredentials.""" + access_key = "test-access-key" + secret_key = "test-secret-key" + session_token = "test-session-token" + + creds = AWSCredentials(access_key, secret_key, session_token) + + assert creds.access_key_id == access_key + assert creds.secret_access_key == secret_key + assert creds.session_token == session_token + + +class TestSigningData: + """Test SigningData class.""" + + def test_init(self): + """Test initialization of SigningData.""" + headers_encoded = "encoded-headers" + body_encoded = "encoded-body" + url_encoded = "encoded-url" + method = "POST" + + data = SigningData(headers_encoded, body_encoded, url_encoded, method) + + assert data.headers_encoded == headers_encoded + assert data.body_encoded == body_encoded + assert data.url_encoded == url_encoded + assert data.method == method + + +class TestAWSAuth: + """Test AWSAuth class.""" + + def test_init_default_region(self): + """Test initialization with default region.""" + auth = AWSAuth() + assert auth.aws_region is None + + def test_init_custom_region(self): + """Test initialization with custom region.""" + region = "us-west-2" + auth = AWSAuth(aws_region=region) + assert auth.aws_region == region + + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test-access-key", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + "AWS_SESSION_TOKEN": "test-session-token", + }, + ) + def test_get_credentials_success(self): + """Test successful credential retrieval from environment.""" + auth = AWSAuth() + creds = auth.get_credentials() + + assert creds.access_key_id == "test-access-key" + assert creds.secret_access_key == "test-secret-key" + assert creds.session_token == "test-session-token" + + @patch.dict(os.environ, {}, clear=True) + def test_get_credentials_missing_access_key(self): + """Test credential retrieval with missing access key.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.get_credentials() + + @patch.dict(os.environ, {"AWS_ACCESS_KEY_ID": "test-key"}, clear=True) + def test_get_credentials_missing_secret_key(self): + """Test credential retrieval with missing secret key.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.get_credentials() + + @patch.dict( + os.environ, {"AWS_ACCESS_KEY_ID": "test-access-key", "AWS_SECRET_ACCESS_KEY": "test-secret-key"}, clear=True + ) + def test_get_credentials_missing_session_token(self): + """Test credential retrieval with missing session token.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.get_credentials() + + def test_get_connection_parameters_default_region(self): + """Test connection parameters with default region.""" + auth = AWSAuth() + url, region, host = auth._get_connection_parameters() + + assert url == "https://sts.amazonaws.com" + assert region == "us-east-1" + assert host == "sts.amazonaws.com" + + def test_get_connection_parameters_custom_region(self): + """Test connection parameters with custom region.""" + auth = AWSAuth(aws_region="eu-west-1") + url, region, host = auth._get_connection_parameters() + + assert url == "https://sts.eu-west-1.amazonaws.com" + assert region == "eu-west-1" + assert host == "sts.eu-west-1.amazonaws.com" + + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "AKIAIOSFODNN7EXAMPLE", + "AWS_SECRET_ACCESS_KEY": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "AWS_SESSION_TOKEN": "test-session-token", + }, + ) + def test_generate_aws_auth_data(self): + """Test AWS authentication data generation.""" + auth = AWSAuth() + org_uuid = "test-org-uuid" + creds = auth.get_credentials() + + data = auth.generate_aws_auth_data(org_uuid, creds) + + assert isinstance(data, SigningData) + assert data.method == "POST" + assert data.headers_encoded # Should be base64 encoded + assert data.body_encoded # Should be base64 encoded + assert data.url_encoded # Should be base64 encoded + + def test_generate_aws_auth_data_missing_org_uuid(self): + """Test auth data generation with missing org UUID.""" + auth = AWSAuth() + creds = AWSCredentials("key", "secret", "token") + + with pytest.raises(ApiValueError, match="Missing org UUID"): + auth.generate_aws_auth_data("", creds) + + def test_generate_aws_auth_data_missing_credentials(self): + """Test auth data generation with missing credentials.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing AWS credentials"): + auth.generate_aws_auth_data("org-uuid", None) + + def test_hmac256(self): + """Test HMAC-SHA256 generation.""" + auth = AWSAuth() + data = "test-data" + key = b"test-key" + + result = auth._hmac256(data, key) + + assert isinstance(result, bytes) + assert len(result) == 32 # SHA256 produces 32 bytes + + @patch("datadog_api_client.aws.get_delegated_token") + @patch.dict( + os.environ, + { + "AWS_ACCESS_KEY_ID": "test-access-key", + "AWS_SECRET_ACCESS_KEY": "test-secret-key", + "AWS_SESSION_TOKEN": "test-session-token", + }, + ) + def test_authenticate_success(self, mock_get_delegated_token): + """Test successful authentication.""" + # Mock the delegated token response + mock_credentials = Mock() + mock_get_delegated_token.return_value = mock_credentials + + auth = AWSAuth() + config = DelegatedTokenConfig("test-org-uuid", PROVIDER_AWS, auth) + + result = auth.authenticate(config) + + assert result == mock_credentials + mock_get_delegated_token.assert_called_once() + + def test_authenticate_missing_org_uuid(self): + """Test authentication with missing org UUID.""" + auth = AWSAuth() + config = DelegatedTokenConfig("", PROVIDER_AWS, auth) + + with pytest.raises(ApiValueError, match="Missing org UUID in config"): + auth.authenticate(config) + + def test_authenticate_missing_config(self): + """Test authentication with missing config.""" + auth = AWSAuth() + + with pytest.raises(ApiValueError, match="Missing org UUID in config"): + auth.authenticate(None) diff --git a/tests/test_delegated_auth.py b/tests/test_delegated_auth.py new file mode 100644 index 0000000000..a6c4451ab4 --- /dev/null +++ b/tests/test_delegated_auth.py @@ -0,0 +1,140 @@ +"""Tests for delegated authentication functionality.""" + +import json +import pytest +from datetime import datetime, timedelta +from unittest.mock import Mock, patch + +from datadog_api_client.delegated_auth import ( + DelegatedTokenCredentials, + DelegatedTokenConfig, + DelegatedTokenProvider, + get_delegated_token, + parse_delegated_token_response, + get_delegated_token_url, +) +from datadog_api_client.configuration import Configuration +from datadog_api_client.exceptions import ApiValueError + + +class TestDelegatedTokenCredentials: + """Test DelegatedTokenCredentials class.""" + + def test_init(self): + """Test initialization of DelegatedTokenCredentials.""" + org_uuid = "test-org-uuid" + token = "test-token" + proof = "test-proof" + expiration = datetime.now() + timedelta(minutes=15) + + creds = DelegatedTokenCredentials(org_uuid, token, proof, expiration) + + assert creds.org_uuid == org_uuid + assert creds.delegated_token == token + assert creds.delegated_proof == proof + assert creds.expiration == expiration + + def test_is_expired_false(self): + """Test is_expired returns False for non-expired token.""" + expiration = datetime.now() + timedelta(minutes=15) + creds = DelegatedTokenCredentials("org", "token", "proof", expiration) + + assert not creds.is_expired() + + def test_is_expired_true(self): + """Test is_expired returns True for expired token.""" + expiration = datetime.now() - timedelta(minutes=1) + creds = DelegatedTokenCredentials("org", "token", "proof", expiration) + + assert creds.is_expired() + + +class TestDelegatedTokenConfig: + """Test DelegatedTokenConfig class.""" + + def test_init(self): + """Test initialization of DelegatedTokenConfig.""" + org_uuid = "test-org-uuid" + provider = "aws" + provider_auth = Mock(spec=DelegatedTokenProvider) + + config = DelegatedTokenConfig(org_uuid, provider, provider_auth) + + assert config.org_uuid == org_uuid + assert config.provider == provider + assert config.provider_auth == provider_auth + + +class TestGetDelegatedTokenUrl: + """Test get_delegated_token_url function.""" + + def test_get_delegated_token_url(self): + """Test URL construction.""" + config = Configuration() + config.host = "https://api.datadoghq.com" + + url = get_delegated_token_url(config) + + assert url == "https://api.datadoghq.com/api/v2/delegated-token" + + +class TestParseDelegatedTokenResponse: + """Test parse_delegated_token_response function.""" + + def test_parse_valid_response(self): + """Test parsing valid token response.""" + org_uuid = "test-org-uuid" + proof = "test-proof" + token = "test-access-token" + expires = str(int((datetime.now() + timedelta(minutes=15)).timestamp())) + + response_data = json.dumps({"data": {"attributes": {"access_token": token, "expires": expires}}}) + + creds = parse_delegated_token_response(response_data, org_uuid, proof) + + assert creds.org_uuid == org_uuid + assert creds.delegated_token == token + assert creds.delegated_proof == proof + assert isinstance(creds.expiration, datetime) + + def test_parse_invalid_json(self): + """Test parsing invalid JSON response.""" + with pytest.raises(ApiValueError, match="Failed to parse token response"): + parse_delegated_token_response("invalid json", "org", "proof") + + +class TestGetDelegatedToken: + """Test get_delegated_token function.""" + + @patch("datadog_api_client.delegated_auth.rest.RESTClientObject") + def test_get_delegated_token_success(self, mock_rest_client_class): + """Test successful token retrieval.""" + mock_rest_client = Mock() + mock_rest_client_class.return_value = mock_rest_client + + mock_response = Mock() + mock_response.status = 200 + mock_response.data = json.dumps({"data": {"attributes": {"access_token": "test-token"}}}).encode("utf-8") + mock_rest_client.request.return_value = mock_response + + org_uuid = "test-org-uuid" + proof = "test-proof" + + creds = get_delegated_token(org_uuid, proof) + + assert creds.org_uuid == org_uuid + assert creds.delegated_token == "test-token" + assert creds.delegated_proof == proof + + @patch("datadog_api_client.delegated_auth.rest.RESTClientObject") + def test_get_delegated_token_http_error(self, mock_rest_client_class): + """Test token retrieval with HTTP error.""" + mock_rest_client = Mock() + mock_rest_client_class.return_value = mock_rest_client + + mock_response = Mock() + mock_response.status = 401 + mock_rest_client.request.return_value = mock_response + + with pytest.raises(ApiValueError, match="Failed to get token: 401"): + get_delegated_token("org", "proof") From 119ff6346c94711de1f921268b9eda00e20f652d Mon Sep 17 00:00:00 2001 From: juskeeratanand Date: Fri, 26 Sep 2025 15:04:43 -0400 Subject: [PATCH 02/10] rename file --- tests/{test_aws.py => aws_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_aws.py => aws_test.py} (100%) diff --git a/tests/test_aws.py b/tests/aws_test.py similarity index 100% rename from tests/test_aws.py rename to tests/aws_test.py From 73832a63333a71108fa8d64bc41d416302ba1141 Mon Sep 17 00:00:00 2001 From: juskeeratanand Date: Fri, 26 Sep 2025 15:06:17 -0400 Subject: [PATCH 03/10] rename files to match go client --- tests/{test_delegated_auth.py => delegated_auth_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_delegated_auth.py => delegated_auth_test.py} (100%) diff --git a/tests/test_delegated_auth.py b/tests/delegated_auth_test.py similarity index 100% rename from tests/test_delegated_auth.py rename to tests/delegated_auth_test.py From dab1f44d8486b4abeb2133259b6c90388d87f081 Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Mon, 29 Sep 2025 13:07:28 -0400 Subject: [PATCH 04/10] fix aws tests --- .generator/conftest.py | 18 +- .generator/src/generator/templates/aws.j2 | 11 +- docs/datadog_api_client.rst | 56 -- examples/datadog/aws.py | 49 +- src/datadog_api_client/api_client.py | 8 - src/datadog_api_client/aws.py | 2 +- tests/client_test.py | 199 +++--- tests/conftest.py | 745 ---------------------- 8 files changed, 131 insertions(+), 957 deletions(-) delete mode 100644 docs/datadog_api_client.rst delete mode 100644 tests/conftest.py diff --git a/.generator/conftest.py b/.generator/conftest.py index 56a1e24057..ada05747ec 100644 --- a/.generator/conftest.py +++ b/.generator/conftest.py @@ -82,9 +82,12 @@ def encode(self, obj): JINJA_ENV.filters["safe_snake_case"] = safe_snake_case JINJA_ENV.globals["format_data_with_schema"] = format_data_with_schema JINJA_ENV.globals["format_parameters"] = format_parameters +JINJA_ENV.globals["package"] = "datadog_api_client" PYTHON_EXAMPLE_J2 = JINJA_ENV.get_template("example.j2") - +DATADOG_EXAMPLES_J2 = { + "aws.py": JINJA_ENV.get_template("example_aws.j2") +} def pytest_bdd_after_scenario(request, feature, scenario): try: @@ -137,6 +140,19 @@ def pytest_bdd_after_scenario(request, feature, scenario): with output.open("w") as f: f.write(data) + for file_name, template in DATADOG_EXAMPLES_J2.items(): + output = ROOT_PATH / "examples" / "datadog" / file_name + output.parent.mkdir(parents=True, exist_ok=True) + + data = template.render( + context=context, + version=version, + scenario=scenario, + operation_spec=operation_spec.spec, + ) + with output.open("w") as f: + f.write(data) + def pytest_bdd_apply_tag(tag, function): """Register tags as custom markers and skip test for '@skip' ones.""" diff --git a/.generator/src/generator/templates/aws.j2 b/.generator/src/generator/templates/aws.j2 index 853c461153..a8518fa468 100644 --- a/.generator/src/generator/templates/aws.j2 +++ b/.generator/src/generator/templates/aws.j2 @@ -65,17 +65,18 @@ class AWSAuth(DelegatedTokenProvider): def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: """Authenticate using AWS credentials and return delegated token credentials. - + :param config: Delegated token configuration :return: DelegatedTokenCredentials object :raises: ApiValueError if authentication fails """ - # Get local AWS Credentials - creds = self.get_credentials() - + # Check org UUID first if not config or not config.org_uuid: raise ApiValueError("Missing org UUID in config") - + + # Get local AWS Credentials + creds = self.get_credentials() + # Use the credentials to generate the signing data data = self.generate_aws_auth_data(config.org_uuid, creds) diff --git a/docs/datadog_api_client.rst b/docs/datadog_api_client.rst deleted file mode 100644 index 6d1e001200..0000000000 --- a/docs/datadog_api_client.rst +++ /dev/null @@ -1,56 +0,0 @@ -datadog\_api\_client package -============================ - -Subpackages ------------ - -.. toctree:: - :maxdepth: 4 - - datadog_api_client.v1 - datadog_api_client.v2 - -Submodules ----------- - -datadog\_api\_client.api\_client module ---------------------------------------- - -.. automodule:: datadog_api_client.api_client - :members: - :show-inheritance: - -datadog\_api\_client.configuration module ------------------------------------------ - -.. automodule:: datadog_api_client.configuration - :members: - :show-inheritance: - -datadog\_api\_client.exceptions module --------------------------------------- - -.. automodule:: datadog_api_client.exceptions - :members: - :show-inheritance: - -datadog\_api\_client.model\_utils module ----------------------------------------- - -.. automodule:: datadog_api_client.model_utils - :members: - :show-inheritance: - -datadog\_api\_client.rest module --------------------------------- - -.. automodule:: datadog_api_client.rest - :members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: datadog_api_client - :members: - :show-inheritance: diff --git a/examples/datadog/aws.py b/examples/datadog/aws.py index 88d7800bba..010e9e0b55 100644 --- a/examples/datadog/aws.py +++ b/examples/datadog/aws.py @@ -2,15 +2,14 @@ Example of using AWS authentication with the Datadog API client. This example shows how to configure the client to use AWS credentials -for authentication instead of API keys. This is ideal for environments -like AWS Ray where you want to use temporary credentials. +for authentication instead of API keys. """ import os from datadog_api_client import ApiClient, Configuration from datadog_api_client.aws import AWSAuth from datadog_api_client.delegated_auth import DelegatedTokenConfig -from datadog_api_client.v1.api.authentication_api import AuthenticationApi +from datadog_api_client.v2.api.authentication_api import AuthenticationApi def main(): @@ -20,17 +19,13 @@ def main(): # os.environ["AWS_ACCESS_KEY_ID"] = "your-access-key-id" # os.environ["AWS_SECRET_ACCESS_KEY"] = "your-secret-access-key" # os.environ["AWS_SESSION_TOKEN"] = "your-session-token" - + # Verify AWS credentials are available - if not all([ - os.getenv("AWS_ACCESS_KEY_ID"), - os.getenv("AWS_SECRET_ACCESS_KEY"), - os.getenv("AWS_SESSION_TOKEN") - ]): + if not all([os.getenv("AWS_ACCESS_KEY_ID"), os.getenv("AWS_SECRET_ACCESS_KEY"), os.getenv("AWS_SESSION_TOKEN")]): print("Error: AWS credentials not found in environment variables.") print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") return - + # Your Datadog organization UUID # This should be provided by your Datadog administrator org_uuid = os.getenv("DD_ORG_UUID") @@ -38,40 +33,32 @@ def main(): print("Error: DD_ORG_UUID environment variable not set.") print("Please set your Datadog organization UUID") return - + # Create AWS authentication provider # Optionally specify AWS region (defaults to us-east-1) aws_region = os.getenv("AWS_REGION", "us-east-1") aws_auth = AWSAuth(aws_region=aws_region) - + # Create delegated token configuration - delegated_config = DelegatedTokenConfig( - org_uuid=org_uuid, - provider="aws", - provider_auth=aws_auth - ) - + delegated_config = DelegatedTokenConfig(org_uuid=org_uuid, provider="aws", provider_auth=aws_auth) + # Create configuration and set delegated token config configuration = Configuration() configuration.delegated_token_config = delegated_config - - # Optional: Set Datadog site if different from default - site = os.getenv("DD_SITE", "datadoghq.com") - configuration.server_variables["site"] = site - - # Create API client with cloud authentication + + # Create API client with the configuration with ApiClient(configuration) as api_client: # Create API instance api_instance = AuthenticationApi(api_client) - + try: - # Test the authentication by validating the token - response = api_instance.validate() - print("✅ Authentication successful!") - print(f"Valid: {response.valid}") - + # Test the authentication by validating credentials + api_response = api_instance.validate() + print("Authentication successful!") + print(f"Valid: {api_response.valid}") + except Exception as e: - print(f"❌ Authentication failed: {e}") + print(f"Authentication failed: {e}") if __name__ == "__main__": diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index 326bc2d6b4..3906b8ab1e 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -473,14 +473,6 @@ def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: # Set the Authorization header with the delegated token headers["Authorization"] = f"Bearer {self._delegated_token_credentials.delegated_token}" - def get_delegated_token(self) -> "DelegatedTokenCredentials": - """Get a delegated token using the configured provider (public API). - - :return: DelegatedTokenCredentials object - :raises: ApiValueError if token retrieval fails - """ - return self._get_delegated_token() - def _get_delegated_token(self) -> "DelegatedTokenCredentials": """Get a new delegated token using the configured provider. diff --git a/src/datadog_api_client/aws.py b/src/datadog_api_client/aws.py index 0101f72f9f..abb98810e3 100644 --- a/src/datadog_api_client/aws.py +++ b/src/datadog_api_client/aws.py @@ -76,7 +76,7 @@ def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredential :return: DelegatedTokenCredentials object :raises: ApiValueError if authentication fails """ - # Check config first before attempting to get credentials + # Check org UUID first if not config or not config.org_uuid: raise ApiValueError("Missing org UUID in config") diff --git a/tests/client_test.py b/tests/client_test.py index 77e682a52e..f56e232cc2 100644 --- a/tests/client_test.py +++ b/tests/client_test.py @@ -26,29 +26,26 @@ class MockDelegatedTokenProvider(DelegatedTokenProvider): """Mock delegated token provider for testing.""" - + def __init__(self, token=FAKE_TOKEN, org_uuid=FAKE_ORG_UUID, proof=FAKE_PROOF, expiration_minutes=10): self.token = token self.org_uuid = org_uuid self.proof = proof self.expiration_minutes = expiration_minutes self.authenticate_calls = [] - + def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: """Mock authenticate method.""" self.authenticate_calls.append(config) expiration = datetime.now() + timedelta(minutes=self.expiration_minutes) return DelegatedTokenCredentials( - org_uuid=self.org_uuid, - delegated_token=self.token, - delegated_proof=self.proof, - expiration=expiration + org_uuid=self.org_uuid, delegated_token=self.token, delegated_proof=self.proof, expiration=expiration ) class ExpiredMockProvider(MockDelegatedTokenProvider): """Mock provider that returns expired tokens.""" - + def __init__(self, **kwargs): super().__init__(**kwargs) self.expiration_minutes = -16 # Expired 16 minutes ago @@ -61,26 +58,24 @@ def test_delegated_pre_authenticate(self): """Test delegated token pre-authentication flow.""" # Create mock provider mock_provider = MockDelegatedTokenProvider() - + # Create configuration with delegated token config config = Configuration() config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=mock_provider + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=mock_provider ) - + # Create API client api_client = ApiClient(config) - - # Test get_delegated_token method - token = api_client.get_delegated_token() + + # Test _get_delegated_token method + token = api_client._get_delegated_token() assert token is not None assert token.delegated_token == FAKE_TOKEN assert token.org_uuid == FAKE_ORG_UUID assert token.delegated_proof == FAKE_PROOF assert not token.is_expired() - + # Verify provider was called assert len(mock_provider.authenticate_calls) == 1 assert mock_provider.authenticate_calls[0].org_uuid == FAKE_ORG_UUID @@ -89,26 +84,24 @@ def test_delegated_no_pre_authenticate(self): """Test delegated token authentication without pre-authentication.""" # Create mock provider mock_provider = MockDelegatedTokenProvider() - + # Create configuration with delegated token config config = Configuration() config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=mock_provider + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=mock_provider ) - + # Create API client api_client = ApiClient(config) - + # Test use_delegated_token_auth (simulates API call header generation) headers = {} api_client.use_delegated_token_auth(headers) - + # Verify headers were set correctly assert "Authorization" in headers assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" - + # Verify provider was called once for token generation assert len(mock_provider.authenticate_calls) == 1 @@ -117,40 +110,38 @@ def test_delegated_re_authenticate(self): # Create two mock providers - first with expired token, second with valid token expired_provider = ExpiredMockProvider() valid_provider = MockDelegatedTokenProvider() - + # Create configuration config = Configuration() config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=expired_provider + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=expired_provider ) - + # Create API client and get initial (expired) token api_client = ApiClient(config) - + # First call gets expired token headers = {} api_client.use_delegated_token_auth(headers) assert "Authorization" in headers assert headers["Authorization"] == f"Bearer {FAKE_TOKEN}" - + # Verify expired token was created - assert hasattr(api_client, '_delegated_token_credentials') + assert hasattr(api_client, "_delegated_token_credentials") assert api_client._delegated_token_credentials.is_expired() - + # Switch to valid provider for re-authentication config.delegated_token_config.provider_auth = valid_provider - + # Second call should re-authenticate due to expired token headers2 = {} api_client.use_delegated_token_auth(headers2) assert "Authorization" in headers2 assert headers2["Authorization"] == f"Bearer {FAKE_TOKEN}" - + # Verify new token is not expired assert not api_client._delegated_token_credentials.is_expired() - + # Verify both providers were called assert len(expired_provider.authenticate_calls) == 1 assert len(valid_provider.authenticate_calls) == 1 @@ -159,69 +150,68 @@ def test_delegated_token_caching(self): """Test that delegated tokens are cached and reused when not expired.""" # Create mock provider mock_provider = MockDelegatedTokenProvider() - + # Create configuration config = Configuration() config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=mock_provider + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=mock_provider ) - + # Create API client api_client = ApiClient(config) - + # Multiple calls to use_delegated_token_auth headers1 = {} api_client.use_delegated_token_auth(headers1) - + headers2 = {} api_client.use_delegated_token_auth(headers2) - + headers3 = {} api_client.use_delegated_token_auth(headers3) - + # All should have same token assert headers1["Authorization"] == headers2["Authorization"] == headers3["Authorization"] - + # Provider should only be called once (token cached) assert len(mock_provider.authenticate_calls) == 1 - @patch.dict('os.environ', { - 'AWS_ACCESS_KEY_ID': 'fake-access-key-id', - 'AWS_SECRET_ACCESS_KEY': 'fake-secret-access-key', - 'AWS_SESSION_TOKEN': 'fake-session-token' - }) + @patch.dict( + "os.environ", + { + "AWS_ACCESS_KEY_ID": "fake-access-key-id", + "AWS_SECRET_ACCESS_KEY": "fake-secret-access-key", + "AWS_SESSION_TOKEN": "fake-session-token", + }, + ) def test_delegated_with_aws_provider(self): """Test delegated authentication with real AWS provider (mocked token endpoint).""" # Create AWS auth provider aws_auth = AWSAuth() - + # Create configuration config = Configuration() config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=aws_auth + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=aws_auth ) - + # Mock the get_delegated_token function called by AWS provider mock_creds = DelegatedTokenCredentials( org_uuid=FAKE_ORG_UUID, delegated_token=FAKE_TOKEN, delegated_proof=FAKE_PROOF, - expiration=datetime.now() + timedelta(minutes=15) + expiration=datetime.now() + timedelta(minutes=15), ) - - with patch('datadog_api_client.aws.get_delegated_token', return_value=mock_creds): + + with patch("datadog_api_client.aws.get_delegated_token", return_value=mock_creds): # Create API client api_client = ApiClient(config) - + # Test token retrieval - token = api_client.get_delegated_token() + token = api_client._get_delegated_token() assert token.delegated_token == FAKE_TOKEN assert token.org_uuid == FAKE_ORG_UUID - + # Test header generation headers = {} api_client.use_delegated_token_auth(headers) @@ -231,144 +221,133 @@ def test_api_key_authentication_comparison(self): """Test traditional API key authentication for comparison with delegated auth.""" # Create configuration with API keys config = Configuration() - config.api_key = { - "apiKeyAuth": "test-api-key", - "appKeyAuth": "test-app-key" - } - + config.api_key = {"apiKeyAuth": "test-api-key", "appKeyAuth": "test-app-key"} + # Create API client api_client = ApiClient(config) - + # Test that no delegated auth is used when not configured headers = {} api_client.use_delegated_token_auth(headers) - + # Headers should be empty (no delegated token) assert "Authorization" not in headers - + # This simulates how API keys would be added in actual API calls # (API key authentication is handled separately in the client) - + def test_delegated_auth_error_handling(self): """Test error handling in delegated authentication.""" # Test with no configuration config = Configuration() api_client = ApiClient(config) - + with pytest.raises(ApiValueError, match="Delegated token configuration is not set"): - api_client.get_delegated_token() - + api_client._get_delegated_token() + # Test with provider that raises an exception class FailingProvider(DelegatedTokenProvider): def authenticate(self, config): raise Exception("Authentication failed") - + config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=FailingProvider() + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=FailingProvider() ) - + api_client = ApiClient(config) with pytest.raises(ApiValueError, match="Failed to get delegated token"): - api_client.get_delegated_token() + api_client._get_delegated_token() def test_multiple_api_clients_isolation(self): """Test that multiple API clients with different configs work independently.""" # Create two different mock providers provider1 = MockDelegatedTokenProvider(token="token1", org_uuid="org1") provider2 = MockDelegatedTokenProvider(token="token2", org_uuid="org2") - + # Create two configurations config1 = Configuration() config1.delegated_token_config = DelegatedTokenConfig( - org_uuid="org1", - provider=PROVIDER_AWS, - provider_auth=provider1 + org_uuid="org1", provider=PROVIDER_AWS, provider_auth=provider1 ) - + config2 = Configuration() config2.delegated_token_config = DelegatedTokenConfig( - org_uuid="org2", - provider=PROVIDER_AWS, - provider_auth=provider2 + org_uuid="org2", provider=PROVIDER_AWS, provider_auth=provider2 ) - + # Create two API clients client1 = ApiClient(config1) client2 = ApiClient(config2) - + # Test that they get different tokens - token1 = client1.get_delegated_token() - token2 = client2.get_delegated_token() - + token1 = client1._get_delegated_token() + token2 = client2._get_delegated_token() + assert token1.delegated_token == "token1" assert token1.org_uuid == "org1" assert token2.delegated_token == "token2" assert token2.org_uuid == "org2" - + # Test header generation for both headers1 = {} client1.use_delegated_token_auth(headers1) - + headers2 = {} client2.use_delegated_token_auth(headers2) - + assert headers1["Authorization"] == "Bearer token1" assert headers2["Authorization"] == "Bearer token2" class TestClientDelegatedAuthenticationWithRealMocks: """Test client integration with more realistic mocking scenarios.""" - + def test_token_refresh_sequence(self): """Test the complete token refresh sequence over time.""" # Create a provider that tracks call count call_count = 0 - + class CountingProvider(DelegatedTokenProvider): def authenticate(self, config): nonlocal call_count call_count += 1 - + # First call returns short-lived token, second call returns long-lived if call_count == 1: expiration = datetime.now() + timedelta(seconds=1) # Very short lived else: expiration = datetime.now() + timedelta(minutes=30) # Long lived - + return DelegatedTokenCredentials( org_uuid=config.org_uuid, delegated_token=f"token-{call_count}", delegated_proof="proof", - expiration=expiration + expiration=expiration, ) - + # Setup config = Configuration() config.delegated_token_config = DelegatedTokenConfig( - org_uuid=FAKE_ORG_UUID, - provider=PROVIDER_AWS, - provider_auth=CountingProvider() + org_uuid=FAKE_ORG_UUID, provider=PROVIDER_AWS, provider_auth=CountingProvider() ) - + api_client = ApiClient(config) - + # First call headers1 = {} api_client.use_delegated_token_auth(headers1) assert headers1["Authorization"] == "Bearer token-1" assert call_count == 1 - + # Immediate second call should reuse token headers2 = {} api_client.use_delegated_token_auth(headers2) assert headers2["Authorization"] == "Bearer token-1" assert call_count == 1 # No new call - + # Force expiration by manipulating the cached token api_client._delegated_token_credentials.expiration = datetime.now() - timedelta(minutes=1) - + # Third call should refresh due to expiration headers3 = {} api_client.use_delegated_token_auth(headers3) diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index d0a1727dd1..0000000000 --- a/tests/conftest.py +++ /dev/null @@ -1,745 +0,0 @@ -# coding=utf-8 -"""Define basic fixtures.""" - -import os -import hashlib - -RECORD = os.getenv("RECORD", "false").lower() -SLEEP_AFTER_REQUEST = int(os.getenv("SLEEP_AFTER_REQUEST", "0")) - -# First patch urllib -tracer = None -try: - from ddtrace import patch, tracer - - patch(urllib3=True) - - from pytest import hookimpl - - @hookimpl(hookwrapper=True) - def pytest_terminal_summary(terminalreporter, exitstatus, config): - yield # do normal output - - ci_pipeline_id = os.getenv("GITHUB_RUN_ID", None) - dd_service = os.getenv("DD_SERVICE", None) - if ci_pipeline_id and dd_service: - terminalreporter.ensure_newline() - terminalreporter.section("test reports", purple=True, bold=True) - terminalreporter.line( - "* View test APM traces and detailed time reports on Datadog (can take a few minutes to become available):" - ) - terminalreporter.line( - "* https://app.datadoghq.com/ci/test-runs?query=" - "%40test.service%3A{}%20%40ci.pipeline.id%3A{}&index=citest".format(dd_service, ci_pipeline_id) - ) - -except ImportError: - if os.getenv("CI", "false") == "true" and RECORD == "none": - raise - -import importlib -import functools -import json -import logging -import pathlib -import re -import time -import warnings -from datetime import datetime - -import pytest -from dateutil.relativedelta import relativedelta -from jinja2 import Template, Environment, meta -from pytest_bdd import given, parsers, then, when - -from datadog_api_client import exceptions -from datadog_api_client.api_client import ApiClient -from datadog_api_client.configuration import Configuration -from datadog_api_client.model_utils import OpenApiModel, file_type, data_to_dict - -logging.basicConfig() - -with (pathlib.Path(__file__).parent.parent / ".generator" / "src" / "generator" / "replacement.json").open() as f: - EDGE_CASES = json.load(f) - -PATTERN_ALPHANUM = re.compile(r"[^A-Za-z0-9]+") -PATTERN_DOUBLE_UNDERSCORE = re.compile(r"__+") -PATTERN_LEADING_ALPHA = re.compile(r"(.)([A-Z][a-z]+)") -PATTERN_FOLLOWING_ALPHA = re.compile(r"([a-z0-9])([A-Z])") -PATTERN_WHITESPACE = re.compile(r"\W") -PATTERN_INDEX = re.compile(r"\[([0-9]*)\]") - - -def sleep_after_request(f): - """Sleep after each request.""" - if RECORD == "false" or SLEEP_AFTER_REQUEST <= 0: - return f - - @functools.wraps(f) - def wrapper(*args, **kwargs): - result = f(*args, **kwargs) - time.sleep(SLEEP_AFTER_REQUEST) - return result - - return wrapper - - -def escape_reserved_keyword(word): - """Escape reserved language keywords like openapi generator does it. - - :param word: Word to escape - :return: The escaped word if it was a reserved keyword, the word unchanged otherwise - """ - reserved_keywords = ["from"] - if word in reserved_keywords: - return f"_{word}" - return word - - -def pytest_bdd_after_scenario(request, feature, scenario): - try: - ctx = request.getfixturevalue("context") - except Exception: - return - for undo in reversed(ctx["undo_operations"]): - undo() - - -def pytest_bdd_apply_tag(tag, function): - """Register tags as custom markers and skip test for '@skip' ones.""" - skip_tags = {"skip", "skip-python"} - if RECORD != "none": - # ignore integration-only scenarios if the recording is enabled - skip_tags.add("integration-only") - if RECORD != "false": - skip_tags.add("replay-only") - - if tag in skip_tags: - marker = pytest.mark.skip(reason=f"skipped because of '{tag} in {skip_tags}") - marker(function) - return True - - -def snake_case(value): - for token, replacement in EDGE_CASES.items(): - value = value.replace(token, replacement) - s1 = PATTERN_LEADING_ALPHA.sub(r"\1_\2", value) - s1 = PATTERN_FOLLOWING_ALPHA.sub(r"\1_\2", s1).lower() - s1 = PATTERN_WHITESPACE.sub("_", s1) - s1 = s1.rstrip("_") - return PATTERN_DOUBLE_UNDERSCORE.sub("_", s1) - - -def glom(value, path): - from glom import glom as g - - # replace foo[index].bar by foo.index.bar - path = PATTERN_INDEX.sub(r".\1", path) - if not isinstance(value, dict): - path = ".".join(snake_case(p) for p in path.split(".")) - - # Support top level array indexing - path = re.sub(r"^[.]+", "", path) - - return g(value, path) if path else value - - -def _get_prefix(request): - test_class = request.cls - if test_class: - main = "{}.{}".format(test_class.__name__, request.node.name) - else: - base_name = request.node.__scenario_report__.scenario.name - main = PATTERN_ALPHANUM.sub("_", base_name)[:100] - prefix = "Test-Python" if _disable_recording() else "Test" - return f"{prefix}-{main}" - - -@pytest.fixture -def unique(request, freezed_time): - prefix = _get_prefix(request) - return f"{prefix}-{int(freezed_time.timestamp())}" - - -def relative_time(freezed_time, iso): - time_re = re.compile(r"now( *([+-]) *(\d+)([smhdMy]))?") - - def func(arg): - ret = freezed_time - m = time_re.match(arg) - if m: - if m.group(1): - sign = m.group(2) - num = int(sign + m.group(3)) - unit = m.group(4) - if unit == "s": - ret += relativedelta(seconds=num) - elif unit == "m": - ret += relativedelta(minutes=num) - elif unit == "h": - ret += relativedelta(hours=num) - elif unit == "d": - ret += relativedelta(days=num) - elif unit == "M": - ret += relativedelta(months=num) - elif unit == "y": - ret += relativedelta(years=num) - if iso: - return ret.replace(tzinfo=None) # return datetime object and not string - # NOTE this is not a full ISO 8601 format, but it's enough for our needs - # return ret.strftime('%Y-%m-%dT%H:%M:%S') + ret.strftime('.%f')[:4] + 'Z' - - return int(ret.timestamp()) - return "" - - return func - - -def generate_uuid(freezed_time): - freezed_time_string = str(freezed_time.timestamp()) - return freezed_time_string[:8] + "-0000-0000-0000-" + freezed_time_string[:10] + "00" - - -@pytest.fixture -def context(vcr, unique, freezed_time): - """ - Return a mapping with all defined fixtures, all objects created by `given` steps, - and the undo operations to perform after a test scenario. - """ - unique_hash = hashlib.sha256(unique.encode("utf-8")).hexdigest()[:16] - - # Dirty fix as on_call cassette and API use the `Z` format instead of `+00:00` - is_iso_with_timezone_indicator = "on_call" in unique - - ctx = { - "undo_operations": [], - "unique": unique, - "unique_lower": unique.lower(), - "unique_upper": unique.upper(), - "unique_alnum": PATTERN_ALPHANUM.sub("", unique), - "unique_lower_alnum": PATTERN_ALPHANUM.sub("", unique).lower(), - "unique_upper_alnum": PATTERN_ALPHANUM.sub("", unique).upper(), - "unique_hash": unique_hash, - "timestamp": relative_time(freezed_time, False), - "timeISO": relative_time(freezed_time, True), - "uuid": generate_uuid(freezed_time), - } - - yield ctx - - -@pytest.fixture(scope="session") -def record_mode(request): - """Manage compatibility with DD client libraries.""" - return {"false": "none", "true": "rewrite", "none": "new_episodes"}[RECORD] - - -def _disable_recording(): - """Disable VCR.py integration.""" - return RECORD == "none" - - -@pytest.fixture(scope="session") -def disable_recording(request): - """Disable VCR.py integration. This overrides a pytest-recording fixture.""" - return _disable_recording() - - -@pytest.fixture -def vcr_config(): - config = dict( - filter_headers=( - "DD-API-KEY", - "DD-APPLICATION-KEY", - "User-Agent", - "Accept-Encoding", - ), - match_on=[ - "method", - "scheme", - "host", - "port", - "path", - "query", - "body", - "headers", - ], - ) - if tracer: - from urllib.parse import urlparse - - if hasattr(tracer._writer, "agent_url"): - config["ignore_hosts"] = [urlparse(tracer._writer.agent_url).hostname] - else: - config["ignore_hosts"] = [urlparse(tracer._writer.intake_url).hostname] - - return config - - -@pytest.fixture -def default_cassette_name(default_cassette_name): - return PATTERN_DOUBLE_UNDERSCORE.sub("_", default_cassette_name) - - -@pytest.fixture -def freezed_time(default_cassette_name, record_mode, vcr): - from dateutil import parser - - if record_mode in {"new_episodes", "rewrite"}: - tzinfo = datetime.now().astimezone().tzinfo - freeze_at = datetime.now().replace(tzinfo=tzinfo).isoformat() - if record_mode == "rewrite": - pathlib.Path(vcr._path).parent.mkdir(parents=True, exist_ok=True) - with pathlib.Path(vcr._path).with_suffix(".frozen").open("w+") as f: - f.write(freeze_at) - else: - freeze_file = pathlib.Path(vcr._path).with_suffix(".frozen") - if not freeze_file.exists(): - msg = ( - "Time file '{}' not found: create one setting `RECORD=true` or " "ignore it using `RECORD=none`".format( - freeze_file - ) - ) - raise RuntimeError(msg) - with freeze_file.open("r") as f: - freeze_at = f.readline().strip() - - if not pathlib.Path(vcr._path).exists(): - msg = ( - "Cassette '{}' not found: create one setting `RECORD=true` or " "ignore it using `RECORD=none`".format( - vcr._path - ) - ) - raise RuntimeError(msg) - - return parser.isoparse(freeze_at) - - -def pytest_recording_configure(config, vcr): - from vcr import matchers - from vcr.util import read_body - - is_text_json = matchers._header_checker("text/json") - transformer = matchers._transform_json - - def body(r1, r2): - if is_text_json(r1.headers) and is_text_json(r2.headers): - assert transformer(read_body(r1)) == transformer(read_body(r2)) - else: - matchers.body(r1, r2) - - vcr.matchers["body"] = body - - -@given('a valid "apiKeyAuth" key in the system') -def a_valid_api_key(configuration): - """a valid API key.""" - configuration.api_key["apiKeyAuth"] = os.getenv("DD_TEST_CLIENT_API_KEY", "fake") - - -@given('a valid "appKeyAuth" key in the system') -def a_valid_application_key(configuration): - """a valid Application key.""" - configuration.api_key["appKeyAuth"] = os.getenv("DD_TEST_CLIENT_APP_KEY", "fake") - - -@pytest.fixture(scope="module") -def package_name(api_version): - return "datadog_api_client." + api_version - - -@pytest.fixture(scope="module") -def undo_operations(): - result = {} - for f in pathlib.Path(os.path.dirname(__file__)).rglob("undo.json"): - version = f.parent.parent.name - with f.open() as fp: - data = json.load(fp) - result[version] = {} - for operation_id, settings in data.items(): - undo_settings = settings.get("undo") - undo_settings["base_tag"] = settings.get("tag") - result[version][snake_case(operation_id)] = undo_settings - - return result - - -def build_configuration(): - c = Configuration(return_http_data_only=False, spec_property_naming=True) - c.connection_pool_maxsize = 0 - c.debug = debug = os.getenv("DEBUG") in {"true", "1", "yes", "on"} - c.enable_retry = True - if debug: # enable vcr logs for DEBUG=true - vcr_log = logging.getLogger("vcr") - vcr_log.setLevel(logging.INFO) - if "DD_TEST_SITE" in os.environ: - c.server_index = 2 - c.server_variables["site"] = os.environ["DD_TEST_SITE"] - return c - - -@pytest.fixture -def configuration(): - return build_configuration() - - -@pytest.fixture -def client(configuration): - with ApiClient(configuration) as api_client: - yield api_client - - -def _api_name(value): - value = re.sub(r"[^a-zA-Z0-9]", "", value) - return value + "Api" - - -@given(parsers.parse('an instance of "{name}" API')) -def api(context, package_name, client, name): - """Return an API instance.""" - module_name = snake_case(name) - package = importlib.import_module(f"{package_name}.api.{module_name}_api") - context["api"] = { - "api": getattr(package, _api_name(name))(client), - "package": package_name, - "calls": [], - } - - -@given(parsers.parse('operation "{name}" enabled')) -def operation_enabled(client, name): - """Enable the unstable operation specific in the clause.""" - client.configuration.unstable_operations[snake_case(name)] = True - - -@given(parsers.parse('new "{name}" request')) -def api_request(configuration, context, name): - """Call an endpoint.""" - api = context["api"] - context["api_request"] = { - "api": api["api"], - "request": getattr(api["api"], snake_case(name)), - "args": [], - "kwargs": {}, - "response": (None, None, None), - } - - -@given(parsers.parse("body with value {data}")) -def request_body(context, data): - """Set request body.""" - tpl = Template(data).render(**context) - context["api_request"]["kwargs"]["body"] = tpl - - -@given(parsers.parse('body from file "{path}"')) -def request_body_from_file(context, path, package_name): - """Set request body.""" - version = package_name.split(".")[-1] - with open(os.path.join(os.path.dirname(__file__), version, "features", path)) as f: - data = f.read() - tpl = Template(data).render(**context) - context["api_request"]["kwargs"]["body"] = tpl - - -@given(parsers.parse('request contains "{name}" parameter from "{path}"')) -def request_parameter(context, name, path): - """Set request parameter.""" - context["api_request"]["kwargs"][escape_reserved_keyword(snake_case(name))] = json.dumps(glom(context, path)) - - -@given(parsers.parse('request contains "{name}" parameter with value {value}')) -def request_parameter_with_value(context, name, value): - """Set request parameter.""" - tpl = Template(value).render(**context) - context["api_request"]["kwargs"][escape_reserved_keyword(snake_case(name))] = tpl - - -def assert_no_unparsed(data): - if isinstance(data, list): - for item in data: - assert_no_unparsed(item) - elif isinstance(data, dict): - for item in data.values(): - assert_no_unparsed(item) - elif isinstance(data, OpenApiModel): - assert not data._unparsed - for attr in data._data_store.values(): - assert_no_unparsed(attr) - - -def build_given(version, operation): - @sleep_after_request - def wrapper(context, undo): - name = operation["tag"].replace(" ", "") - module_name = snake_case(operation["tag"]) - operation_name = snake_case(operation["operationId"]) - package_name = f"datadog_api_client.{version}" - - # make sure we have a fresh instance of API client and configuration - configuration = build_configuration() - configuration.api_key["apiKeyAuth"] = os.getenv("DD_TEST_CLIENT_API_KEY", "fake") - configuration.api_key["appKeyAuth"] = os.getenv("DD_TEST_CLIENT_APP_KEY", "fake") - configuration.check_input_type = False - configuration.return_http_data_only = True - - # enable unstable operation - if operation_name in configuration.unstable_operations: - configuration.unstable_operations[operation_name] = True - - package = importlib.import_module(f"{package_name}.api.{module_name}_api") - with ApiClient(configuration) as client: - api = getattr(package, _api_name(name))(client) - operation_method = getattr(api, operation_name) - params_map = getattr(api, f"_{operation_name}_endpoint").params_map - - # perform operation - def build_param(p): - openapi_types = params_map[p["name"]]["openapi_types"] - if "value" in p: - if openapi_types == (file_type,): - filepath = os.path.join( - os.path.dirname(__file__), - version, - "features", - json.loads(Template(p["value"]).render(**context)), - ) - return open(filepath) - return client.deserialize(Template(p["value"]).render(**context), openapi_types, True) - if "source" in p: - return glom(context, p["source"]) - - kwargs = { - escape_reserved_keyword(snake_case(p["name"])): build_param(p) for p in operation.get("parameters", []) - } - result = operation_method(**kwargs) - request_body = kwargs.get("body", "") - - # register undo method - def undo_operation(): - return undo(api, version, operation_name, result, request_body, client=client) - - if tracer: - undo_operation = tracer.wrap(name="undo", resource=operation["step"])(undo_operation) - - context["undo_operations"].append(undo_operation) - - # optional re-shaping - if "source" in operation: - result = glom(result, operation["source"]) - - # store response in fixtures - result_body_json = data_to_dict(result) - context[operation["key"]] = result_body_json - - return wrapper - - -for f in pathlib.Path(os.path.dirname(__file__)).rglob("given.json"): - version = f.parent.parent.name - with f.open() as fp: - for settings in json.load(fp): - given(settings["step"])(build_given(version, settings)) - - -def extract_parameters(kwargs, data, parameter): - if "source" in parameter: - kwargs[parameter["name"]] = glom(data, parameter["source"]) - elif "template" in parameter: - variables = meta.find_undeclared_variables(Environment().parse(parameter["template"])) - ctx = {} - for var in variables: - ctx[var] = glom(data, var) - kwargs[parameter["name"]] = json.loads(Template(parameter["template"]).render(**ctx)) - - -@pytest.fixture -def undo(package_name, undo_operations, client): - """Clean after operation.""" - - def cleanup(api, version, operation_id, response, request, client=client): - operation = undo_operations.get(version, {}).get(operation_id) - if operation_id is None: - raise NotImplementedError((version, operation_id)) - - if operation["type"] is None: - raise NotImplementedError((version, operation_id)) - - if operation["type"] != "unsafe": - return - - # If Undo tag is not the same as the the operation tag. - # For example, Service Accounts use the DisableUser operation to undo, which is part of Users. - if "tag" in operation and operation["base_tag"] != operation["tag"]: - undo_tag = operation["tag"] - undo_name = undo_tag.replace(" ", "") - undo_module_name = snake_case(undo_tag) - undo_package = importlib.import_module(f"{package_name}.api.{undo_module_name}_api") - api = getattr(undo_package, _api_name(undo_name))(client) - - operation_name = snake_case(operation["operationId"]) - method = getattr(api, operation_name) - kwargs = {} - parameters = operation.get("parameters", []) - for parameter in parameters: - if "origin" not in parameter or parameter["origin"] == "response": - extract_parameters(kwargs, response, parameter) - elif parameter["origin"] == "request": - extract_parameters(kwargs, request, parameter) - if operation_name in client.configuration.unstable_operations: - client.configuration.unstable_operations[operation_name] = True - - try: - method(**kwargs) - except exceptions.ApiException as e: - warnings.warn(f"failed undo: {e}") - - yield cleanup - - -@when("the request is sent") -def execute_request(undo, context, client, api_version, request): - """Execute the prepared request.""" - api_request = context["api_request"] - - params_map = getattr(api_request["api"], f'_{api_request["request"].__name__}_endpoint').params_map - for k, v in api_request["kwargs"].items(): - openapi_types = params_map[k]["openapi_types"] - if openapi_types == (file_type,): - filepath = os.path.join(os.path.dirname(__file__), api_version, "features", json.loads(v)) - # We let the GC collects it, this shouldn't be an issue - api_request["kwargs"][k] = open(filepath) - else: - api_request["kwargs"][k] = client.deserialize(v, openapi_types, True) - - try: - response = api_request["request"](*api_request["args"], **api_request["kwargs"]) - # Reserialise the response body to JSON to facilitate test assertions - response_body_json = data_to_dict(response[0]) - api_request["response"] = [response_body_json, response[1], response[2]] - except exceptions.ApiException as e: - # If we have an exception, make a stub response object to use for assertions - # Instead of finding the response class of the method, we use the fact that all - # responses returned have an ordered response of body|status|headers - api_request["response"] = [e.body, e.status, e.headers] - return - - if "skip-validation" not in request.node.__scenario_report__.scenario.tags: - assert_no_unparsed(response[0]) - - api = api_request["api"] - operation_id = api_request["request"].__name__ - response = api_request["response"][0] - request_body = api_request.get("kwargs", {}).get("body", "") - - def undo_operation(): - return undo(api, api_version, operation_id, response, request_body) - - if tracer: - undo_operation = tracer.wrap(name="undo", resource="execute request")(undo_operation) - - context["undo_operations"].append(undo_operation) - - -@when("the request with pagination is sent") -def execute_request_with_pagination(undo, context, client, api_version): - """Execute the prepared paginated request.""" - api_request = context["api_request"] - - params_map = getattr(api_request["api"], f'_{api_request["request"].__name__}_endpoint').params_map - for k, v in api_request["kwargs"].items(): - api_request["kwargs"][k] = client.deserialize(v, params_map[k]["openapi_types"], True) - - kwargs = api_request["kwargs"] - client.configuration.return_http_data_only = True - method = getattr(api_request["api"], f"{api_request['request'].__name__}_with_pagination") - try: - response = list(method(*api_request["args"], **kwargs)) - # Reserialise the response body to JSON to facilitate test assertions - response_body_json = data_to_dict(response) - api_request["response"] = [response_body_json, 200, None] - except exceptions.ApiException as e: - # If we have an exception, make a stub response object to use for assertions - # Instead of finding the response class of the method, we use the fact that all - # responses returned have an ordered response of body|status|headers - api_request["response"] = [e.body, e.status, e.headers] - finally: - client.configuration.return_http_data_only = False - - -@then(parsers.parse("the response status is {status:d} {description}")) -def the_status_is(context, status, description): - """Check the status.""" - assert status == context["api_request"]["response"][1] - - -@then(parsers.parse('the response "{response_path}" is equal to {value}')) -def expect_equal(context, response_path, value): - response_value = glom(context["api_request"]["response"][0], response_path) - test_value = json.loads(Template(value).render(**context)) - assert test_value == response_value - - -@then(parsers.parse('the response "{response_path}" has the same value as "{fixture_path}"')) -def expect_equal_value(context, response_path, fixture_path): - fixture_value = glom(context, fixture_path) - response_value = glom(context["api_request"]["response"][0], response_path) - assert fixture_value == response_value - - -@then(parsers.parse('the response "{response_path}" has length {fixture_length:d}')) -def expect_equal_length(context, response_path, fixture_length): - response_value = glom(context["api_request"]["response"][0], response_path) - assert fixture_length == len(response_value) - - -@then(parsers.parse("the response has {fixture_length:d} items")) -def expect_equal_response_items(context, fixture_length): - response = context["api_request"]["response"][0] - assert fixture_length == len(response) - - -@then(parsers.parse('the response "{response_path}" is false')) -def expect_false(context, response_path): - response_value = glom(context["api_request"]["response"][0], response_path) - assert not response_value - - -@then(parsers.parse('the response "{response_path}" has field "{field}"')) -def expect_response_has_field(context, response_path, field): - """Check that a response path has field.""" - response_value = glom(context["api_request"]["response"][0], response_path) - assert field in response_value - - -@then(parsers.parse('the response "{response_path}" does not have field "{field}"')) -def expect_response_does_not_have_field(context, response_path, field): - """Check that a response path does not have field.""" - response_value = glom(context["api_request"]["response"][0], response_path) - assert field not in response_value - - -@then(parsers.parse('the response "{response_path}" has item with field "{key_path}" with value {value}')) -def expect_array_contains_object(context, response_path, key_path, value): - from glom.core import PathAccessError - - response_value = glom(context["api_request"]["response"][0], response_path) - test_value = json.loads(Template(value).render(**context)) - for response_item in response_value: - try: - response_item_value = glom(response_item, key_path) - if response_item_value == test_value: - return - except PathAccessError: - pass - raise AssertionError(f'could not find key value pair in object array: "{key_path}": "{test_value}"') - - -@then(parsers.parse('the response "{response_path}" array contains value {value}')) -def expect_array_contains_object(context, response_path, value): - response_value = glom(context["api_request"]["response"][0], response_path) - test_value = json.loads(Template(value).render(**context)) - for response_item in response_value: - if response_item == test_value: - return - raise AssertionError(f"could not find value in array: {test_value}") From a5ac2959f5c966e052927bf75db0463642174a12 Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Mon, 29 Sep 2025 13:10:45 -0400 Subject: [PATCH 05/10] fix conftest --- tests/conftest.py | 745 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 745 insertions(+) create mode 100644 tests/conftest.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..d0a1727dd1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,745 @@ +# coding=utf-8 +"""Define basic fixtures.""" + +import os +import hashlib + +RECORD = os.getenv("RECORD", "false").lower() +SLEEP_AFTER_REQUEST = int(os.getenv("SLEEP_AFTER_REQUEST", "0")) + +# First patch urllib +tracer = None +try: + from ddtrace import patch, tracer + + patch(urllib3=True) + + from pytest import hookimpl + + @hookimpl(hookwrapper=True) + def pytest_terminal_summary(terminalreporter, exitstatus, config): + yield # do normal output + + ci_pipeline_id = os.getenv("GITHUB_RUN_ID", None) + dd_service = os.getenv("DD_SERVICE", None) + if ci_pipeline_id and dd_service: + terminalreporter.ensure_newline() + terminalreporter.section("test reports", purple=True, bold=True) + terminalreporter.line( + "* View test APM traces and detailed time reports on Datadog (can take a few minutes to become available):" + ) + terminalreporter.line( + "* https://app.datadoghq.com/ci/test-runs?query=" + "%40test.service%3A{}%20%40ci.pipeline.id%3A{}&index=citest".format(dd_service, ci_pipeline_id) + ) + +except ImportError: + if os.getenv("CI", "false") == "true" and RECORD == "none": + raise + +import importlib +import functools +import json +import logging +import pathlib +import re +import time +import warnings +from datetime import datetime + +import pytest +from dateutil.relativedelta import relativedelta +from jinja2 import Template, Environment, meta +from pytest_bdd import given, parsers, then, when + +from datadog_api_client import exceptions +from datadog_api_client.api_client import ApiClient +from datadog_api_client.configuration import Configuration +from datadog_api_client.model_utils import OpenApiModel, file_type, data_to_dict + +logging.basicConfig() + +with (pathlib.Path(__file__).parent.parent / ".generator" / "src" / "generator" / "replacement.json").open() as f: + EDGE_CASES = json.load(f) + +PATTERN_ALPHANUM = re.compile(r"[^A-Za-z0-9]+") +PATTERN_DOUBLE_UNDERSCORE = re.compile(r"__+") +PATTERN_LEADING_ALPHA = re.compile(r"(.)([A-Z][a-z]+)") +PATTERN_FOLLOWING_ALPHA = re.compile(r"([a-z0-9])([A-Z])") +PATTERN_WHITESPACE = re.compile(r"\W") +PATTERN_INDEX = re.compile(r"\[([0-9]*)\]") + + +def sleep_after_request(f): + """Sleep after each request.""" + if RECORD == "false" or SLEEP_AFTER_REQUEST <= 0: + return f + + @functools.wraps(f) + def wrapper(*args, **kwargs): + result = f(*args, **kwargs) + time.sleep(SLEEP_AFTER_REQUEST) + return result + + return wrapper + + +def escape_reserved_keyword(word): + """Escape reserved language keywords like openapi generator does it. + + :param word: Word to escape + :return: The escaped word if it was a reserved keyword, the word unchanged otherwise + """ + reserved_keywords = ["from"] + if word in reserved_keywords: + return f"_{word}" + return word + + +def pytest_bdd_after_scenario(request, feature, scenario): + try: + ctx = request.getfixturevalue("context") + except Exception: + return + for undo in reversed(ctx["undo_operations"]): + undo() + + +def pytest_bdd_apply_tag(tag, function): + """Register tags as custom markers and skip test for '@skip' ones.""" + skip_tags = {"skip", "skip-python"} + if RECORD != "none": + # ignore integration-only scenarios if the recording is enabled + skip_tags.add("integration-only") + if RECORD != "false": + skip_tags.add("replay-only") + + if tag in skip_tags: + marker = pytest.mark.skip(reason=f"skipped because of '{tag} in {skip_tags}") + marker(function) + return True + + +def snake_case(value): + for token, replacement in EDGE_CASES.items(): + value = value.replace(token, replacement) + s1 = PATTERN_LEADING_ALPHA.sub(r"\1_\2", value) + s1 = PATTERN_FOLLOWING_ALPHA.sub(r"\1_\2", s1).lower() + s1 = PATTERN_WHITESPACE.sub("_", s1) + s1 = s1.rstrip("_") + return PATTERN_DOUBLE_UNDERSCORE.sub("_", s1) + + +def glom(value, path): + from glom import glom as g + + # replace foo[index].bar by foo.index.bar + path = PATTERN_INDEX.sub(r".\1", path) + if not isinstance(value, dict): + path = ".".join(snake_case(p) for p in path.split(".")) + + # Support top level array indexing + path = re.sub(r"^[.]+", "", path) + + return g(value, path) if path else value + + +def _get_prefix(request): + test_class = request.cls + if test_class: + main = "{}.{}".format(test_class.__name__, request.node.name) + else: + base_name = request.node.__scenario_report__.scenario.name + main = PATTERN_ALPHANUM.sub("_", base_name)[:100] + prefix = "Test-Python" if _disable_recording() else "Test" + return f"{prefix}-{main}" + + +@pytest.fixture +def unique(request, freezed_time): + prefix = _get_prefix(request) + return f"{prefix}-{int(freezed_time.timestamp())}" + + +def relative_time(freezed_time, iso): + time_re = re.compile(r"now( *([+-]) *(\d+)([smhdMy]))?") + + def func(arg): + ret = freezed_time + m = time_re.match(arg) + if m: + if m.group(1): + sign = m.group(2) + num = int(sign + m.group(3)) + unit = m.group(4) + if unit == "s": + ret += relativedelta(seconds=num) + elif unit == "m": + ret += relativedelta(minutes=num) + elif unit == "h": + ret += relativedelta(hours=num) + elif unit == "d": + ret += relativedelta(days=num) + elif unit == "M": + ret += relativedelta(months=num) + elif unit == "y": + ret += relativedelta(years=num) + if iso: + return ret.replace(tzinfo=None) # return datetime object and not string + # NOTE this is not a full ISO 8601 format, but it's enough for our needs + # return ret.strftime('%Y-%m-%dT%H:%M:%S') + ret.strftime('.%f')[:4] + 'Z' + + return int(ret.timestamp()) + return "" + + return func + + +def generate_uuid(freezed_time): + freezed_time_string = str(freezed_time.timestamp()) + return freezed_time_string[:8] + "-0000-0000-0000-" + freezed_time_string[:10] + "00" + + +@pytest.fixture +def context(vcr, unique, freezed_time): + """ + Return a mapping with all defined fixtures, all objects created by `given` steps, + and the undo operations to perform after a test scenario. + """ + unique_hash = hashlib.sha256(unique.encode("utf-8")).hexdigest()[:16] + + # Dirty fix as on_call cassette and API use the `Z` format instead of `+00:00` + is_iso_with_timezone_indicator = "on_call" in unique + + ctx = { + "undo_operations": [], + "unique": unique, + "unique_lower": unique.lower(), + "unique_upper": unique.upper(), + "unique_alnum": PATTERN_ALPHANUM.sub("", unique), + "unique_lower_alnum": PATTERN_ALPHANUM.sub("", unique).lower(), + "unique_upper_alnum": PATTERN_ALPHANUM.sub("", unique).upper(), + "unique_hash": unique_hash, + "timestamp": relative_time(freezed_time, False), + "timeISO": relative_time(freezed_time, True), + "uuid": generate_uuid(freezed_time), + } + + yield ctx + + +@pytest.fixture(scope="session") +def record_mode(request): + """Manage compatibility with DD client libraries.""" + return {"false": "none", "true": "rewrite", "none": "new_episodes"}[RECORD] + + +def _disable_recording(): + """Disable VCR.py integration.""" + return RECORD == "none" + + +@pytest.fixture(scope="session") +def disable_recording(request): + """Disable VCR.py integration. This overrides a pytest-recording fixture.""" + return _disable_recording() + + +@pytest.fixture +def vcr_config(): + config = dict( + filter_headers=( + "DD-API-KEY", + "DD-APPLICATION-KEY", + "User-Agent", + "Accept-Encoding", + ), + match_on=[ + "method", + "scheme", + "host", + "port", + "path", + "query", + "body", + "headers", + ], + ) + if tracer: + from urllib.parse import urlparse + + if hasattr(tracer._writer, "agent_url"): + config["ignore_hosts"] = [urlparse(tracer._writer.agent_url).hostname] + else: + config["ignore_hosts"] = [urlparse(tracer._writer.intake_url).hostname] + + return config + + +@pytest.fixture +def default_cassette_name(default_cassette_name): + return PATTERN_DOUBLE_UNDERSCORE.sub("_", default_cassette_name) + + +@pytest.fixture +def freezed_time(default_cassette_name, record_mode, vcr): + from dateutil import parser + + if record_mode in {"new_episodes", "rewrite"}: + tzinfo = datetime.now().astimezone().tzinfo + freeze_at = datetime.now().replace(tzinfo=tzinfo).isoformat() + if record_mode == "rewrite": + pathlib.Path(vcr._path).parent.mkdir(parents=True, exist_ok=True) + with pathlib.Path(vcr._path).with_suffix(".frozen").open("w+") as f: + f.write(freeze_at) + else: + freeze_file = pathlib.Path(vcr._path).with_suffix(".frozen") + if not freeze_file.exists(): + msg = ( + "Time file '{}' not found: create one setting `RECORD=true` or " "ignore it using `RECORD=none`".format( + freeze_file + ) + ) + raise RuntimeError(msg) + with freeze_file.open("r") as f: + freeze_at = f.readline().strip() + + if not pathlib.Path(vcr._path).exists(): + msg = ( + "Cassette '{}' not found: create one setting `RECORD=true` or " "ignore it using `RECORD=none`".format( + vcr._path + ) + ) + raise RuntimeError(msg) + + return parser.isoparse(freeze_at) + + +def pytest_recording_configure(config, vcr): + from vcr import matchers + from vcr.util import read_body + + is_text_json = matchers._header_checker("text/json") + transformer = matchers._transform_json + + def body(r1, r2): + if is_text_json(r1.headers) and is_text_json(r2.headers): + assert transformer(read_body(r1)) == transformer(read_body(r2)) + else: + matchers.body(r1, r2) + + vcr.matchers["body"] = body + + +@given('a valid "apiKeyAuth" key in the system') +def a_valid_api_key(configuration): + """a valid API key.""" + configuration.api_key["apiKeyAuth"] = os.getenv("DD_TEST_CLIENT_API_KEY", "fake") + + +@given('a valid "appKeyAuth" key in the system') +def a_valid_application_key(configuration): + """a valid Application key.""" + configuration.api_key["appKeyAuth"] = os.getenv("DD_TEST_CLIENT_APP_KEY", "fake") + + +@pytest.fixture(scope="module") +def package_name(api_version): + return "datadog_api_client." + api_version + + +@pytest.fixture(scope="module") +def undo_operations(): + result = {} + for f in pathlib.Path(os.path.dirname(__file__)).rglob("undo.json"): + version = f.parent.parent.name + with f.open() as fp: + data = json.load(fp) + result[version] = {} + for operation_id, settings in data.items(): + undo_settings = settings.get("undo") + undo_settings["base_tag"] = settings.get("tag") + result[version][snake_case(operation_id)] = undo_settings + + return result + + +def build_configuration(): + c = Configuration(return_http_data_only=False, spec_property_naming=True) + c.connection_pool_maxsize = 0 + c.debug = debug = os.getenv("DEBUG") in {"true", "1", "yes", "on"} + c.enable_retry = True + if debug: # enable vcr logs for DEBUG=true + vcr_log = logging.getLogger("vcr") + vcr_log.setLevel(logging.INFO) + if "DD_TEST_SITE" in os.environ: + c.server_index = 2 + c.server_variables["site"] = os.environ["DD_TEST_SITE"] + return c + + +@pytest.fixture +def configuration(): + return build_configuration() + + +@pytest.fixture +def client(configuration): + with ApiClient(configuration) as api_client: + yield api_client + + +def _api_name(value): + value = re.sub(r"[^a-zA-Z0-9]", "", value) + return value + "Api" + + +@given(parsers.parse('an instance of "{name}" API')) +def api(context, package_name, client, name): + """Return an API instance.""" + module_name = snake_case(name) + package = importlib.import_module(f"{package_name}.api.{module_name}_api") + context["api"] = { + "api": getattr(package, _api_name(name))(client), + "package": package_name, + "calls": [], + } + + +@given(parsers.parse('operation "{name}" enabled')) +def operation_enabled(client, name): + """Enable the unstable operation specific in the clause.""" + client.configuration.unstable_operations[snake_case(name)] = True + + +@given(parsers.parse('new "{name}" request')) +def api_request(configuration, context, name): + """Call an endpoint.""" + api = context["api"] + context["api_request"] = { + "api": api["api"], + "request": getattr(api["api"], snake_case(name)), + "args": [], + "kwargs": {}, + "response": (None, None, None), + } + + +@given(parsers.parse("body with value {data}")) +def request_body(context, data): + """Set request body.""" + tpl = Template(data).render(**context) + context["api_request"]["kwargs"]["body"] = tpl + + +@given(parsers.parse('body from file "{path}"')) +def request_body_from_file(context, path, package_name): + """Set request body.""" + version = package_name.split(".")[-1] + with open(os.path.join(os.path.dirname(__file__), version, "features", path)) as f: + data = f.read() + tpl = Template(data).render(**context) + context["api_request"]["kwargs"]["body"] = tpl + + +@given(parsers.parse('request contains "{name}" parameter from "{path}"')) +def request_parameter(context, name, path): + """Set request parameter.""" + context["api_request"]["kwargs"][escape_reserved_keyword(snake_case(name))] = json.dumps(glom(context, path)) + + +@given(parsers.parse('request contains "{name}" parameter with value {value}')) +def request_parameter_with_value(context, name, value): + """Set request parameter.""" + tpl = Template(value).render(**context) + context["api_request"]["kwargs"][escape_reserved_keyword(snake_case(name))] = tpl + + +def assert_no_unparsed(data): + if isinstance(data, list): + for item in data: + assert_no_unparsed(item) + elif isinstance(data, dict): + for item in data.values(): + assert_no_unparsed(item) + elif isinstance(data, OpenApiModel): + assert not data._unparsed + for attr in data._data_store.values(): + assert_no_unparsed(attr) + + +def build_given(version, operation): + @sleep_after_request + def wrapper(context, undo): + name = operation["tag"].replace(" ", "") + module_name = snake_case(operation["tag"]) + operation_name = snake_case(operation["operationId"]) + package_name = f"datadog_api_client.{version}" + + # make sure we have a fresh instance of API client and configuration + configuration = build_configuration() + configuration.api_key["apiKeyAuth"] = os.getenv("DD_TEST_CLIENT_API_KEY", "fake") + configuration.api_key["appKeyAuth"] = os.getenv("DD_TEST_CLIENT_APP_KEY", "fake") + configuration.check_input_type = False + configuration.return_http_data_only = True + + # enable unstable operation + if operation_name in configuration.unstable_operations: + configuration.unstable_operations[operation_name] = True + + package = importlib.import_module(f"{package_name}.api.{module_name}_api") + with ApiClient(configuration) as client: + api = getattr(package, _api_name(name))(client) + operation_method = getattr(api, operation_name) + params_map = getattr(api, f"_{operation_name}_endpoint").params_map + + # perform operation + def build_param(p): + openapi_types = params_map[p["name"]]["openapi_types"] + if "value" in p: + if openapi_types == (file_type,): + filepath = os.path.join( + os.path.dirname(__file__), + version, + "features", + json.loads(Template(p["value"]).render(**context)), + ) + return open(filepath) + return client.deserialize(Template(p["value"]).render(**context), openapi_types, True) + if "source" in p: + return glom(context, p["source"]) + + kwargs = { + escape_reserved_keyword(snake_case(p["name"])): build_param(p) for p in operation.get("parameters", []) + } + result = operation_method(**kwargs) + request_body = kwargs.get("body", "") + + # register undo method + def undo_operation(): + return undo(api, version, operation_name, result, request_body, client=client) + + if tracer: + undo_operation = tracer.wrap(name="undo", resource=operation["step"])(undo_operation) + + context["undo_operations"].append(undo_operation) + + # optional re-shaping + if "source" in operation: + result = glom(result, operation["source"]) + + # store response in fixtures + result_body_json = data_to_dict(result) + context[operation["key"]] = result_body_json + + return wrapper + + +for f in pathlib.Path(os.path.dirname(__file__)).rglob("given.json"): + version = f.parent.parent.name + with f.open() as fp: + for settings in json.load(fp): + given(settings["step"])(build_given(version, settings)) + + +def extract_parameters(kwargs, data, parameter): + if "source" in parameter: + kwargs[parameter["name"]] = glom(data, parameter["source"]) + elif "template" in parameter: + variables = meta.find_undeclared_variables(Environment().parse(parameter["template"])) + ctx = {} + for var in variables: + ctx[var] = glom(data, var) + kwargs[parameter["name"]] = json.loads(Template(parameter["template"]).render(**ctx)) + + +@pytest.fixture +def undo(package_name, undo_operations, client): + """Clean after operation.""" + + def cleanup(api, version, operation_id, response, request, client=client): + operation = undo_operations.get(version, {}).get(operation_id) + if operation_id is None: + raise NotImplementedError((version, operation_id)) + + if operation["type"] is None: + raise NotImplementedError((version, operation_id)) + + if operation["type"] != "unsafe": + return + + # If Undo tag is not the same as the the operation tag. + # For example, Service Accounts use the DisableUser operation to undo, which is part of Users. + if "tag" in operation and operation["base_tag"] != operation["tag"]: + undo_tag = operation["tag"] + undo_name = undo_tag.replace(" ", "") + undo_module_name = snake_case(undo_tag) + undo_package = importlib.import_module(f"{package_name}.api.{undo_module_name}_api") + api = getattr(undo_package, _api_name(undo_name))(client) + + operation_name = snake_case(operation["operationId"]) + method = getattr(api, operation_name) + kwargs = {} + parameters = operation.get("parameters", []) + for parameter in parameters: + if "origin" not in parameter or parameter["origin"] == "response": + extract_parameters(kwargs, response, parameter) + elif parameter["origin"] == "request": + extract_parameters(kwargs, request, parameter) + if operation_name in client.configuration.unstable_operations: + client.configuration.unstable_operations[operation_name] = True + + try: + method(**kwargs) + except exceptions.ApiException as e: + warnings.warn(f"failed undo: {e}") + + yield cleanup + + +@when("the request is sent") +def execute_request(undo, context, client, api_version, request): + """Execute the prepared request.""" + api_request = context["api_request"] + + params_map = getattr(api_request["api"], f'_{api_request["request"].__name__}_endpoint').params_map + for k, v in api_request["kwargs"].items(): + openapi_types = params_map[k]["openapi_types"] + if openapi_types == (file_type,): + filepath = os.path.join(os.path.dirname(__file__), api_version, "features", json.loads(v)) + # We let the GC collects it, this shouldn't be an issue + api_request["kwargs"][k] = open(filepath) + else: + api_request["kwargs"][k] = client.deserialize(v, openapi_types, True) + + try: + response = api_request["request"](*api_request["args"], **api_request["kwargs"]) + # Reserialise the response body to JSON to facilitate test assertions + response_body_json = data_to_dict(response[0]) + api_request["response"] = [response_body_json, response[1], response[2]] + except exceptions.ApiException as e: + # If we have an exception, make a stub response object to use for assertions + # Instead of finding the response class of the method, we use the fact that all + # responses returned have an ordered response of body|status|headers + api_request["response"] = [e.body, e.status, e.headers] + return + + if "skip-validation" not in request.node.__scenario_report__.scenario.tags: + assert_no_unparsed(response[0]) + + api = api_request["api"] + operation_id = api_request["request"].__name__ + response = api_request["response"][0] + request_body = api_request.get("kwargs", {}).get("body", "") + + def undo_operation(): + return undo(api, api_version, operation_id, response, request_body) + + if tracer: + undo_operation = tracer.wrap(name="undo", resource="execute request")(undo_operation) + + context["undo_operations"].append(undo_operation) + + +@when("the request with pagination is sent") +def execute_request_with_pagination(undo, context, client, api_version): + """Execute the prepared paginated request.""" + api_request = context["api_request"] + + params_map = getattr(api_request["api"], f'_{api_request["request"].__name__}_endpoint').params_map + for k, v in api_request["kwargs"].items(): + api_request["kwargs"][k] = client.deserialize(v, params_map[k]["openapi_types"], True) + + kwargs = api_request["kwargs"] + client.configuration.return_http_data_only = True + method = getattr(api_request["api"], f"{api_request['request'].__name__}_with_pagination") + try: + response = list(method(*api_request["args"], **kwargs)) + # Reserialise the response body to JSON to facilitate test assertions + response_body_json = data_to_dict(response) + api_request["response"] = [response_body_json, 200, None] + except exceptions.ApiException as e: + # If we have an exception, make a stub response object to use for assertions + # Instead of finding the response class of the method, we use the fact that all + # responses returned have an ordered response of body|status|headers + api_request["response"] = [e.body, e.status, e.headers] + finally: + client.configuration.return_http_data_only = False + + +@then(parsers.parse("the response status is {status:d} {description}")) +def the_status_is(context, status, description): + """Check the status.""" + assert status == context["api_request"]["response"][1] + + +@then(parsers.parse('the response "{response_path}" is equal to {value}')) +def expect_equal(context, response_path, value): + response_value = glom(context["api_request"]["response"][0], response_path) + test_value = json.loads(Template(value).render(**context)) + assert test_value == response_value + + +@then(parsers.parse('the response "{response_path}" has the same value as "{fixture_path}"')) +def expect_equal_value(context, response_path, fixture_path): + fixture_value = glom(context, fixture_path) + response_value = glom(context["api_request"]["response"][0], response_path) + assert fixture_value == response_value + + +@then(parsers.parse('the response "{response_path}" has length {fixture_length:d}')) +def expect_equal_length(context, response_path, fixture_length): + response_value = glom(context["api_request"]["response"][0], response_path) + assert fixture_length == len(response_value) + + +@then(parsers.parse("the response has {fixture_length:d} items")) +def expect_equal_response_items(context, fixture_length): + response = context["api_request"]["response"][0] + assert fixture_length == len(response) + + +@then(parsers.parse('the response "{response_path}" is false')) +def expect_false(context, response_path): + response_value = glom(context["api_request"]["response"][0], response_path) + assert not response_value + + +@then(parsers.parse('the response "{response_path}" has field "{field}"')) +def expect_response_has_field(context, response_path, field): + """Check that a response path has field.""" + response_value = glom(context["api_request"]["response"][0], response_path) + assert field in response_value + + +@then(parsers.parse('the response "{response_path}" does not have field "{field}"')) +def expect_response_does_not_have_field(context, response_path, field): + """Check that a response path does not have field.""" + response_value = glom(context["api_request"]["response"][0], response_path) + assert field not in response_value + + +@then(parsers.parse('the response "{response_path}" has item with field "{key_path}" with value {value}')) +def expect_array_contains_object(context, response_path, key_path, value): + from glom.core import PathAccessError + + response_value = glom(context["api_request"]["response"][0], response_path) + test_value = json.loads(Template(value).render(**context)) + for response_item in response_value: + try: + response_item_value = glom(response_item, key_path) + if response_item_value == test_value: + return + except PathAccessError: + pass + raise AssertionError(f'could not find key value pair in object array: "{key_path}": "{test_value}"') + + +@then(parsers.parse('the response "{response_path}" array contains value {value}')) +def expect_array_contains_object(context, response_path, value): + response_value = glom(context["api_request"]["response"][0], response_path) + test_value = json.loads(Template(value).render(**context)) + for response_item in response_value: + if response_item == test_value: + return + raise AssertionError(f"could not find value in array: {test_value}") From f0e61979ad5871ba235102ee3315c12827a0aaf4 Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Mon, 29 Sep 2025 13:12:11 -0400 Subject: [PATCH 06/10] Restore docs/datadog_api_client.rst file --- docs/datadog_api_client.rst | 56 +++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 docs/datadog_api_client.rst diff --git a/docs/datadog_api_client.rst b/docs/datadog_api_client.rst new file mode 100644 index 0000000000..6d1e001200 --- /dev/null +++ b/docs/datadog_api_client.rst @@ -0,0 +1,56 @@ +datadog\_api\_client package +============================ + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + datadog_api_client.v1 + datadog_api_client.v2 + +Submodules +---------- + +datadog\_api\_client.api\_client module +--------------------------------------- + +.. automodule:: datadog_api_client.api_client + :members: + :show-inheritance: + +datadog\_api\_client.configuration module +----------------------------------------- + +.. automodule:: datadog_api_client.configuration + :members: + :show-inheritance: + +datadog\_api\_client.exceptions module +-------------------------------------- + +.. automodule:: datadog_api_client.exceptions + :members: + :show-inheritance: + +datadog\_api\_client.model\_utils module +---------------------------------------- + +.. automodule:: datadog_api_client.model_utils + :members: + :show-inheritance: + +datadog\_api\_client.rest module +-------------------------------- + +.. automodule:: datadog_api_client.rest + :members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: datadog_api_client + :members: + :show-inheritance: From eaa857f126d16021025d88525aefda9a725ba95f Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Mon, 29 Sep 2025 13:18:34 -0400 Subject: [PATCH 07/10] regen --- .generator/src/generator/templates/api_client.j2 | 3 +-- .../src/generator/templates/delegated_auth.j2 | 3 +-- docs/datadog_api_client.rst | 14 ++++++++++++++ src/datadog_api_client/api_client.py | 3 +-- src/datadog_api_client/delegated_auth.py | 3 +-- 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/.generator/src/generator/templates/api_client.j2 b/.generator/src/generator/templates/api_client.j2 index 68b5c4203b..cf8f285ce3 100644 --- a/.generator/src/generator/templates/api_client.j2 +++ b/.generator/src/generator/templates/api_client.j2 @@ -466,7 +466,6 @@ class ApiClient: from {{ package }}.delegated_auth import DelegatedTokenCredentials from datetime import datetime - # Get or create delegated token credentials if not hasattr(self, '_delegated_token_credentials') or self._delegated_token_credentials is None: self._delegated_token_credentials = self._get_delegated_token() elif self._delegated_token_credentials.is_expired(): @@ -862,7 +861,7 @@ class Endpoint: if not self.settings["auth"]: return - # Check if this endpoint uses appKeyAuth and if delegated token config is available + # check if endpoint uses appKeyAuth and if delegated token config is available has_app_key_auth = "appKeyAuth" in self.settings["auth"] if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: diff --git a/.generator/src/generator/templates/delegated_auth.j2 b/.generator/src/generator/templates/delegated_auth.j2 index e57d612199..c1cc1a0b7c 100644 --- a/.generator/src/generator/templates/delegated_auth.j2 +++ b/.generator/src/generator/templates/delegated_auth.j2 @@ -115,8 +115,7 @@ def parse_delegated_token_response(response_data: str, org_uuid: str, delegated_ if not token: raise ApiValueError(f"Failed to get token from response: {token_response}") - # Get the expiration time from the response - # Default to 15 minutes if the expiration time is not set + # get expiration time from the response, defualt to 15 min expiration_time = datetime.now() + timedelta(minutes=15) expires_str = attributes.get("expires") if expires_str: diff --git a/docs/datadog_api_client.rst b/docs/datadog_api_client.rst index 6d1e001200..c8c1a345bd 100644 --- a/docs/datadog_api_client.rst +++ b/docs/datadog_api_client.rst @@ -20,6 +20,13 @@ datadog\_api\_client.api\_client module :members: :show-inheritance: +datadog\_api\_client.aws module +------------------------------- + +.. automodule:: datadog_api_client.aws + :members: + :show-inheritance: + datadog\_api\_client.configuration module ----------------------------------------- @@ -27,6 +34,13 @@ datadog\_api\_client.configuration module :members: :show-inheritance: +datadog\_api\_client.delegated\_auth module +------------------------------------------- + +.. automodule:: datadog_api_client.delegated_auth + :members: + :show-inheritance: + datadog\_api\_client.exceptions module -------------------------------------- diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index 3906b8ab1e..1be27a5494 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -463,7 +463,6 @@ def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: if not self.configuration.delegated_token_config: return - # Get or create delegated token credentials if not hasattr(self, "_delegated_token_credentials") or self._delegated_token_credentials is None: self._delegated_token_credentials = self._get_delegated_token() elif self._delegated_token_credentials.is_expired(): @@ -857,7 +856,7 @@ def update_params_for_auth(self, headers, queries) -> None: if not self.settings["auth"]: return - # Check if this endpoint uses appKeyAuth and if delegated token config is available + # check if endpoint uses appKeyAuth and if delegated token config is available has_app_key_auth = "appKeyAuth" in self.settings["auth"] if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: diff --git a/src/datadog_api_client/delegated_auth.py b/src/datadog_api_client/delegated_auth.py index 33fd2b6deb..4d7bc4540d 100644 --- a/src/datadog_api_client/delegated_auth.py +++ b/src/datadog_api_client/delegated_auth.py @@ -111,8 +111,7 @@ def parse_delegated_token_response( if not token: raise ApiValueError(f"Failed to get token from response: {token_response}") - # Get the expiration time from the response - # Default to 15 minutes if the expiration time is not set + # get expiration time from the response, defualt to 15 min expiration_time = datetime.now() + timedelta(minutes=15) expires_str = attributes.get("expires") if expires_str: From 848f934e8b6da905e82e3dd55c3acc19b9e8e6b0 Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Mon, 29 Sep 2025 14:54:21 -0400 Subject: [PATCH 08/10] print header --- src/datadog_api_client/api_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index 1be27a5494..2760e64697 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -862,6 +862,7 @@ def update_params_for_auth(self, headers, queries) -> None: if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: # Use delegated token authentication self.api_client.use_delegated_token_auth(headers) + print(f"headers: {headers}") else: # Use regular authentication for auth in self.settings["auth"]: From 19bef193d206bb28acafc67040316bf1749a44d0 Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Mon, 29 Sep 2025 16:10:36 -0400 Subject: [PATCH 09/10] fix headers --- .../src/generator/templates/api_client.j2 | 63 ++++++++++--------- .../src/generator/templates/example_aws.j2 | 33 ++++------ examples/datadog/aws.py | 28 ++++----- src/datadog_api_client/api_client.py | 61 ++++++++++-------- 4 files changed, 94 insertions(+), 91 deletions(-) diff --git a/.generator/src/generator/templates/api_client.j2 b/.generator/src/generator/templates/api_client.j2 index cf8f285ce3..086580d4a7 100644 --- a/.generator/src/generator/templates/api_client.j2 +++ b/.generator/src/generator/templates/api_client.j2 @@ -460,36 +460,33 @@ class ApiClient: :param headers: Header parameters dict to be updated. :raises: ApiValueError if delegated token authentication fails """ - if not self.configuration.delegated_token_config: - return - - from {{ package }}.delegated_auth import DelegatedTokenCredentials from datetime import datetime - - if not hasattr(self, '_delegated_token_credentials') or self._delegated_token_credentials is None: - self._delegated_token_credentials = self._get_delegated_token() - elif self._delegated_token_credentials.is_expired(): - # Token is expired, get a new one - self._delegated_token_credentials = self._get_delegated_token() - - # Set the Authorization header with the delegated token - headers["Authorization"] = f"Bearer {self._delegated_token_credentials.delegated_token}" - - def _get_delegated_token(self) -> 'DelegatedTokenCredentials': - """Get a new delegated token using the configured provider. - - :return: DelegatedTokenCredentials object - :raises: ApiValueError if token retrieval fails - """ - if not self.configuration.delegated_token_config: - raise ApiValueError("Delegated token configuration is not set") - - try: - return self.configuration.delegated_token_config.provider_auth.authenticate( - self.configuration.delegated_token_config + from {{ package }}.delegated_auth import DelegatedTokenConfig + + # Check if we have cached credentials + if not hasattr(self.configuration, '_delegated_token_credentials'): + self.configuration._delegated_token_credentials = None + + # Check if we need to get or refresh the token + if (self.configuration._delegated_token_credentials is None or + self.configuration._delegated_token_credentials.is_expired()): + + # Create config for the provider + config = DelegatedTokenConfig( + org_uuid=self.configuration.delegated_auth_org_uuid, + provider="aws", # This could be made configurable + provider_auth=self.configuration.delegated_auth_provider ) - except Exception as e: - raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + # Get new token from provider + try: + self.configuration._delegated_token_credentials = self.configuration.delegated_auth_provider.authenticate(config) + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") + + # Set the Authorization header with the delegated token + token = self.configuration._delegated_token_credentials.delegated_token + headers["Authorization"] = f"Bearer {token}" class ThreadedApiClient(ApiClient): @@ -864,7 +861,15 @@ class Endpoint: # check if endpoint uses appKeyAuth and if delegated token config is available has_app_key_auth = "appKeyAuth" in self.settings["auth"] - if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: + # Check if delegated auth is configured (using our actual attributes) + has_delegated_auth = ( + hasattr(self.api_client.configuration, 'delegated_auth_provider') and + self.api_client.configuration.delegated_auth_provider is not None and + hasattr(self.api_client.configuration, 'delegated_auth_org_uuid') and + self.api_client.configuration.delegated_auth_org_uuid is not None + ) + + if has_app_key_auth and has_delegated_auth: # Use delegated token authentication self.api_client.use_delegated_token_auth(headers) else: diff --git a/.generator/src/generator/templates/example_aws.j2 b/.generator/src/generator/templates/example_aws.j2 index f76247d682..7d127cb3fd 100644 --- a/.generator/src/generator/templates/example_aws.j2 +++ b/.generator/src/generator/templates/example_aws.j2 @@ -8,8 +8,7 @@ for authentication instead of API keys. import os from {{ package }} import ApiClient, Configuration from {{ package }}.aws import AWSAuth -from {{ package }}.delegated_auth import DelegatedTokenConfig -from {{ package }}.{{ version }}.api.authentication_api import AuthenticationApi +from {{ package }}.v2.api.teams_api import TeamsApi def main(): @@ -30,11 +29,10 @@ def main(): print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") return - # Your Datadog organization UUID - # This should be provided by your Datadog administrator - org_uuid = os.getenv("DD_ORG_UUID") + + org_uuid = os.getenv("DD_TEST_ORG_UUID") if not org_uuid: - print("Error: DD_ORG_UUID environment variable not set.") + print("Error: DD_TEST_ORG_UUID environment variable not set.") print("Please set your Datadog organization UUID") return @@ -43,27 +41,22 @@ def main(): aws_region = os.getenv("AWS_REGION", "us-east-1") aws_auth = AWSAuth(aws_region=aws_region) - # Create delegated token configuration - delegated_config = DelegatedTokenConfig( - org_uuid=org_uuid, - provider="aws", - provider_auth=aws_auth - ) - - # Create configuration and set delegated token config + # Create configuration with AWS authentication configuration = Configuration() - configuration.delegated_token_config = delegated_config + configuration.delegated_auth_provider = aws_auth + configuration.delegated_auth_org_uuid = org_uuid # Create API client with the configuration with ApiClient(configuration) as api_client: - # Create API instance - api_instance = AuthenticationApi(api_client) + # Create API instance - using TeamsApi as an example + api_instance = TeamsApi(api_client) try: - # Test the authentication by validating credentials - api_response = api_instance.validate() + # Test the authentication by listing teams + # The client will automatically use AWS authentication for this call + api_response = api_instance.list_teams() print("Authentication successful!") - print(f"Valid: {api_response.valid}") + print(f"Found {len(api_response.data)} teams") except Exception as e: print(f"Authentication failed: {e}") diff --git a/examples/datadog/aws.py b/examples/datadog/aws.py index 010e9e0b55..b7418b615f 100644 --- a/examples/datadog/aws.py +++ b/examples/datadog/aws.py @@ -8,8 +8,7 @@ import os from datadog_api_client import ApiClient, Configuration from datadog_api_client.aws import AWSAuth -from datadog_api_client.delegated_auth import DelegatedTokenConfig -from datadog_api_client.v2.api.authentication_api import AuthenticationApi +from datadog_api_client.v2.api.teams_api import TeamsApi def main(): @@ -26,11 +25,9 @@ def main(): print("Please set AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and AWS_SESSION_TOKEN") return - # Your Datadog organization UUID - # This should be provided by your Datadog administrator - org_uuid = os.getenv("DD_ORG_UUID") + org_uuid = os.getenv("DD_TEST_ORG_UUID") if not org_uuid: - print("Error: DD_ORG_UUID environment variable not set.") + print("Error: DD_TEST_ORG_UUID environment variable not set.") print("Please set your Datadog organization UUID") return @@ -39,23 +36,22 @@ def main(): aws_region = os.getenv("AWS_REGION", "us-east-1") aws_auth = AWSAuth(aws_region=aws_region) - # Create delegated token configuration - delegated_config = DelegatedTokenConfig(org_uuid=org_uuid, provider="aws", provider_auth=aws_auth) - - # Create configuration and set delegated token config + # Create configuration with AWS authentication configuration = Configuration() - configuration.delegated_token_config = delegated_config + configuration.delegated_auth_provider = aws_auth + configuration.delegated_auth_org_uuid = org_uuid # Create API client with the configuration with ApiClient(configuration) as api_client: - # Create API instance - api_instance = AuthenticationApi(api_client) + # Create API instance - using TeamsApi as an example + api_instance = TeamsApi(api_client) try: - # Test the authentication by validating credentials - api_response = api_instance.validate() + # Test the authentication by listing teams + # The client will automatically use AWS authentication for this call + api_response = api_instance.list_teams() print("Authentication successful!") - print(f"Valid: {api_response.valid}") + print(f"Found {len(api_response.data)} teams") except Exception as e: print(f"Authentication failed: {e}") diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index 2760e64697..9e5a6b7dc4 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -460,33 +460,35 @@ def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: :param headers: Header parameters dict to be updated. :raises: ApiValueError if delegated token authentication fails """ - if not self.configuration.delegated_token_config: - return + from datadog_api_client.delegated_auth import DelegatedTokenConfig + + # Check if we have cached credentials + if not hasattr(self.configuration, "_delegated_token_credentials"): + self.configuration._delegated_token_credentials = None + + # Check if we need to get or refresh the token + if ( + self.configuration._delegated_token_credentials is None + or self.configuration._delegated_token_credentials.is_expired() + ): + # Create config for the provider + config = DelegatedTokenConfig( + org_uuid=self.configuration.delegated_auth_org_uuid, + provider="aws", # This could be made configurable + provider_auth=self.configuration.delegated_auth_provider, + ) - if not hasattr(self, "_delegated_token_credentials") or self._delegated_token_credentials is None: - self._delegated_token_credentials = self._get_delegated_token() - elif self._delegated_token_credentials.is_expired(): - # Token is expired, get a new one - self._delegated_token_credentials = self._get_delegated_token() + # Get new token from provider + try: + self.configuration._delegated_token_credentials = ( + self.configuration.delegated_auth_provider.authenticate(config) + ) + except Exception as e: + raise ApiValueError(f"Failed to get delegated token: {str(e)}") # Set the Authorization header with the delegated token - headers["Authorization"] = f"Bearer {self._delegated_token_credentials.delegated_token}" - - def _get_delegated_token(self) -> "DelegatedTokenCredentials": - """Get a new delegated token using the configured provider. - - :return: DelegatedTokenCredentials object - :raises: ApiValueError if token retrieval fails - """ - if not self.configuration.delegated_token_config: - raise ApiValueError("Delegated token configuration is not set") - - try: - return self.configuration.delegated_token_config.provider_auth.authenticate( - self.configuration.delegated_token_config - ) - except Exception as e: - raise ApiValueError(f"Failed to get delegated token: {str(e)}") + token = self.configuration._delegated_token_credentials.delegated_token + headers["Authorization"] = f"Bearer {token}" class ThreadedApiClient(ApiClient): @@ -859,10 +861,17 @@ def update_params_for_auth(self, headers, queries) -> None: # check if endpoint uses appKeyAuth and if delegated token config is available has_app_key_auth = "appKeyAuth" in self.settings["auth"] - if has_app_key_auth and self.api_client.configuration.delegated_token_config is not None: + # Check if delegated auth is configured (using our actual attributes) + has_delegated_auth = ( + hasattr(self.api_client.configuration, "delegated_auth_provider") + and self.api_client.configuration.delegated_auth_provider is not None + and hasattr(self.api_client.configuration, "delegated_auth_org_uuid") + and self.api_client.configuration.delegated_auth_org_uuid is not None + ) + + if has_app_key_auth and has_delegated_auth: # Use delegated token authentication self.api_client.use_delegated_token_auth(headers) - print(f"headers: {headers}") else: # Use regular authentication for auth in self.settings["auth"]: From 6d3ae95d7c75fdd0a4a6ad457933693155612920 Mon Sep 17 00:00:00 2001 From: Juskeerat Anand Date: Tue, 30 Sep 2025 14:25:55 -0400 Subject: [PATCH 10/10] fix config propogation --- .generator/src/generator/templates/api_client.j2 | 4 ++-- .generator/src/generator/templates/aws.j2 | 7 +++++-- .../src/generator/templates/delegated_auth.j2 | 13 +++++++++---- .generator/src/generator/templates/example_aws.j2 | 2 +- examples/datadog/aws.py | 2 +- src/datadog_api_client/api_client.py | 4 ++-- src/datadog_api_client/aws.py | 7 +++++-- src/datadog_api_client/delegated_auth.py | 13 +++++++++---- 8 files changed, 34 insertions(+), 18 deletions(-) diff --git a/.generator/src/generator/templates/api_client.j2 b/.generator/src/generator/templates/api_client.j2 index 086580d4a7..877d0781ed 100644 --- a/.generator/src/generator/templates/api_client.j2 +++ b/.generator/src/generator/templates/api_client.j2 @@ -478,9 +478,9 @@ class ApiClient: provider_auth=self.configuration.delegated_auth_provider ) - # Get new token from provider + # Get new token from provider, passing the API configuration try: - self.configuration._delegated_token_credentials = self.configuration.delegated_auth_provider.authenticate(config) + self.configuration._delegated_token_credentials = self.configuration.delegated_auth_provider.authenticate(config, self.configuration) except Exception as e: raise ApiValueError(f"Failed to get delegated token: {str(e)}") diff --git a/.generator/src/generator/templates/aws.j2 b/.generator/src/generator/templates/aws.j2 index a8518fa468..579f0e971d 100644 --- a/.generator/src/generator/templates/aws.j2 +++ b/.generator/src/generator/templates/aws.j2 @@ -9,6 +9,7 @@ from datetime import datetime from typing import Dict, List, Optional from urllib.parse import quote +from {{ package }}.configuration import Configuration from {{ package }}.delegated_auth import DelegatedTokenProvider, DelegatedTokenConfig, DelegatedTokenCredentials, get_delegated_token from {{ package }}.exceptions import ApiValueError @@ -63,10 +64,11 @@ class AWSAuth(DelegatedTokenProvider): def __init__(self, aws_region: Optional[str] = None): self.aws_region = aws_region - def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: """Authenticate using AWS credentials and return delegated token credentials. :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings :return: DelegatedTokenCredentials object :raises: ApiValueError if authentication fails """ @@ -83,7 +85,8 @@ class AWSAuth(DelegatedTokenProvider): # Generate the auth string passed to the token endpoint auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}" - auth_response = get_delegated_token(config.org_uuid, auth_string) + # Pass the api_config to get_delegated_token to use the correct host + auth_response = get_delegated_token(config.org_uuid, auth_string, api_config) return auth_response def get_credentials(self) -> AWSCredentials: diff --git a/.generator/src/generator/templates/delegated_auth.j2 b/.generator/src/generator/templates/delegated_auth.j2 index c1cc1a0b7c..3ea48bd072 100644 --- a/.generator/src/generator/templates/delegated_auth.j2 +++ b/.generator/src/generator/templates/delegated_auth.j2 @@ -42,20 +42,25 @@ class DelegatedTokenConfig: class DelegatedTokenProvider: """Abstract base class for delegated token providers.""" - def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: - """Authenticate and return delegated token credentials.""" + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: + """Authenticate and return delegated token credentials. + + :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings + :return: DelegatedTokenCredentials object + """ raise NotImplementedError("Subclasses must implement authenticate method") -def get_delegated_token(org_uuid: str, delegated_auth_proof: str) -> DelegatedTokenCredentials: +def get_delegated_token(org_uuid: str, delegated_auth_proof: str, config: Configuration) -> DelegatedTokenCredentials: """Get a delegated token from the Datadog API. :param org_uuid: Organization UUID :param delegated_auth_proof: Authentication proof string + :param config: Configuration object with host and other settings :return: DelegatedTokenCredentials object :raises: ApiValueError if the request fails """ - config = Configuration() url = get_delegated_token_url(config) # Create REST client diff --git a/.generator/src/generator/templates/example_aws.j2 b/.generator/src/generator/templates/example_aws.j2 index 7d127cb3fd..4e4dc7eb98 100644 --- a/.generator/src/generator/templates/example_aws.j2 +++ b/.generator/src/generator/templates/example_aws.j2 @@ -42,7 +42,7 @@ def main(): aws_auth = AWSAuth(aws_region=aws_region) # Create configuration with AWS authentication - configuration = Configuration() + configuration = Configuration(host="https://dd.datad0g.com") configuration.delegated_auth_provider = aws_auth configuration.delegated_auth_org_uuid = org_uuid diff --git a/examples/datadog/aws.py b/examples/datadog/aws.py index b7418b615f..20d967319a 100644 --- a/examples/datadog/aws.py +++ b/examples/datadog/aws.py @@ -37,7 +37,7 @@ def main(): aws_auth = AWSAuth(aws_region=aws_region) # Create configuration with AWS authentication - configuration = Configuration() + configuration = Configuration(host="https://dd.datad0g.com") configuration.delegated_auth_provider = aws_auth configuration.delegated_auth_org_uuid = org_uuid diff --git a/src/datadog_api_client/api_client.py b/src/datadog_api_client/api_client.py index 9e5a6b7dc4..76f1db40d3 100644 --- a/src/datadog_api_client/api_client.py +++ b/src/datadog_api_client/api_client.py @@ -478,10 +478,10 @@ def use_delegated_token_auth(self, headers: Dict[str, Any]) -> None: provider_auth=self.configuration.delegated_auth_provider, ) - # Get new token from provider + # Get new token from provider, passing the API configuration try: self.configuration._delegated_token_credentials = ( - self.configuration.delegated_auth_provider.authenticate(config) + self.configuration.delegated_auth_provider.authenticate(config, self.configuration) ) except Exception as e: raise ApiValueError(f"Failed to get delegated token: {str(e)}") diff --git a/src/datadog_api_client/aws.py b/src/datadog_api_client/aws.py index abb98810e3..48087e6c83 100644 --- a/src/datadog_api_client/aws.py +++ b/src/datadog_api_client/aws.py @@ -10,6 +10,7 @@ from datetime import datetime from typing import Optional +from datadog_api_client.configuration import Configuration from datadog_api_client.delegated_auth import ( DelegatedTokenProvider, DelegatedTokenConfig, @@ -69,10 +70,11 @@ class AWSAuth(DelegatedTokenProvider): def __init__(self, aws_region: Optional[str] = None): self.aws_region = aws_region - def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: """Authenticate using AWS credentials and return delegated token credentials. :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings :return: DelegatedTokenCredentials object :raises: ApiValueError if authentication fails """ @@ -89,7 +91,8 @@ def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredential # Generate the auth string passed to the token endpoint auth_string = f"{data.body_encoded}|{data.headers_encoded}|{data.method}|{data.url_encoded}" - auth_response = get_delegated_token(config.org_uuid, auth_string) + # Pass the api_config to get_delegated_token to use the correct host + auth_response = get_delegated_token(config.org_uuid, auth_string, api_config) return auth_response def get_credentials(self) -> AWSCredentials: diff --git a/src/datadog_api_client/delegated_auth.py b/src/datadog_api_client/delegated_auth.py index 4d7bc4540d..3ba4652b62 100644 --- a/src/datadog_api_client/delegated_auth.py +++ b/src/datadog_api_client/delegated_auth.py @@ -42,20 +42,25 @@ def __init__(self, org_uuid: str, provider: str, provider_auth: "DelegatedTokenP class DelegatedTokenProvider: """Abstract base class for delegated token providers.""" - def authenticate(self, config: DelegatedTokenConfig) -> DelegatedTokenCredentials: - """Authenticate and return delegated token credentials.""" + def authenticate(self, config: DelegatedTokenConfig, api_config: Configuration) -> DelegatedTokenCredentials: + """Authenticate and return delegated token credentials. + + :param config: Delegated token configuration + :param api_config: API client configuration with host and other settings + :return: DelegatedTokenCredentials object + """ raise NotImplementedError("Subclasses must implement authenticate method") -def get_delegated_token(org_uuid: str, delegated_auth_proof: str) -> DelegatedTokenCredentials: +def get_delegated_token(org_uuid: str, delegated_auth_proof: str, config: Configuration) -> DelegatedTokenCredentials: """Get a delegated token from the Datadog API. :param org_uuid: Organization UUID :param delegated_auth_proof: Authentication proof string + :param config: Configuration object with host and other settings :return: DelegatedTokenCredentials object :raises: ApiValueError if the request fails """ - config = Configuration() url = get_delegated_token_url(config) # Create REST client