diff --git a/warehouse/locale/messages.pot b/warehouse/locale/messages.pot
index fe2fe3cefa88..e396b50bed4e 100644
--- a/warehouse/locale/messages.pot
+++ b/warehouse/locale/messages.pot
@@ -393,7 +393,7 @@ msgstr ""
 msgid "Select project"
 msgstr ""
 
-#: warehouse/manage/forms.py:507 warehouse/oidc/forms/_core.py:24
+#: warehouse/manage/forms.py:507 warehouse/oidc/forms/_core.py:36
 #: warehouse/oidc/forms/gitlab.py:47
 msgid "Specify project name"
 msgstr ""
@@ -676,45 +676,45 @@ msgstr ""
 msgid "Expired invitation for '${username}' deleted."
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:26 warehouse/oidc/forms/_core.py:37
+#: warehouse/oidc/forms/_core.py:38 warehouse/oidc/forms/_core.py:49
 #: warehouse/oidc/forms/gitlab.py:50 warehouse/oidc/forms/gitlab.py:54
 msgid "Invalid project name"
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:55
+#: warehouse/oidc/forms/_core.py:68
 #, python-brace-format
 msgid ""
 "This project already exists: use the project's publishing settings here to "
 "create a Trusted Publisher for it."
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:64
+#: warehouse/oidc/forms/_core.py:77
 msgid "This project already exists."
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:69
+#: warehouse/oidc/forms/_core.py:82
 msgid "This project name isn't allowed"
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:73
+#: warehouse/oidc/forms/_core.py:86
 msgid "This project name is too similar to an existing project"
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:78
+#: warehouse/oidc/forms/_core.py:91
 msgid ""
 "This project name isn't allowed (conflict with the Python standard "
 "library module name)"
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:106 warehouse/oidc/forms/_core.py:117
+#: warehouse/oidc/forms/_core.py:119 warehouse/oidc/forms/_core.py:130
 msgid "Specify a publisher ID"
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:107 warehouse/oidc/forms/_core.py:118
+#: warehouse/oidc/forms/_core.py:120 warehouse/oidc/forms/_core.py:131
 msgid "Publisher must be specified by ID"
 msgstr ""
 
-#: warehouse/oidc/forms/_core.py:123
+#: warehouse/oidc/forms/_core.py:136
 msgid "Specify an environment name"
 msgstr ""
 
diff --git a/warehouse/macaroons/services.py b/warehouse/macaroons/services.py
index 52e80ba442ef..fe7763c02ba8 100644
--- a/warehouse/macaroons/services.py
+++ b/warehouse/macaroons/services.py
@@ -140,14 +140,14 @@ def verify(self, raw_macaroon: str, request, context, permission) -> bool:
 
     def create_macaroon(
         self,
-        location,
-        description,
-        scopes,
+        location: str,
+        description: str,
+        scopes: list[caveats.Caveat],
         *,
-        user_id=None,
-        oidc_publisher_id=None,
-        additional=None,
-    ):
+        user_id: uuid.UUID | None = None,
+        oidc_publisher_id: str | None = None,
+        additional: dict[str, typing.Any] | None = None,
+    ) -> tuple[str, Macaroon]:
         """
         Returns a tuple of a new raw (serialized) macaroon and its DB model.
         The description provided is not embedded into the macaroon, only stored
diff --git a/warehouse/oidc/__init__.py b/warehouse/oidc/__init__.py
index 0f1135966808..6d06917f3754 100644
--- a/warehouse/oidc/__init__.py
+++ b/warehouse/oidc/__init__.py
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
+import typing
 
 from celery.schedules import crontab
@@ -12,8 +15,11 @@
     GOOGLE_OIDC_ISSUER_URL,
 )
 
+if typing.TYPE_CHECKING:
+    from pyramid.config import Configurator
+
 
-def includeme(config):
+def includeme(config: Configurator) -> None:
     oidc_publisher_service_class = config.maybe_dotted(
         config.registry.settings["oidc.backend"]
     )
diff --git a/warehouse/oidc/forms/_core.py b/warehouse/oidc/forms/_core.py
index ba57efebc0eb..ca76cd2864f9 100644
--- a/warehouse/oidc/forms/_core.py
+++ b/warehouse/oidc/forms/_core.py
@@ -1,5 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
+import typing
+
 import markupsafe
 import structlog
 import wtforms
@@ -15,10 +19,18 @@
 )
 from warehouse.utils.project import PROJECT_NAME_RE
 
+if typing.TYPE_CHECKING:
+    from warehouse.accounts.models import User
+
 log = structlog.get_logger()
 
 
 class PendingPublisherMixin:
+    # Attributes that must be provided by subclasses
+    _user: User
+    _check_project_name: typing.Callable[[str], None]
+    _route_url: typing.Callable[..., str]
+
     project_name = wtforms.StringField(
         validators=[
             wtforms.validators.InputRequired(message=_("Specify project name")),
@@ -28,7 +40,7 @@ class PendingPublisherMixin:
         ]
     )
 
-    def validate_project_name(self, field):
+    def validate_project_name(self, field: wtforms.Field) -> None:
         project_name = field.data
 
         try:
@@ -39,7 +51,8 @@ def validate_project_name(self, field):
             # If the user owns the existing project, the error message includes a
             # link to the project settings that the user can modify.
             if self._user in e.existing_project.owners:
-                url_params = {name: value for name, value in self.data.items() if value}
+                # Mixin doesn't inherit from wtforms.Form but composed classes do
+                url_params = {name: value for name, value in self.data.items() if value}  # type: ignore[attr-defined]  # noqa: E501
                 url_params["provider"] = {self.provider}
                 url = self._route_url(
                     "manage.project.settings.publishing",
diff --git a/warehouse/oidc/forms/activestate.py b/warehouse/oidc/forms/activestate.py
index 9486360b2631..adb07a7cbdce 100644
--- a/warehouse/oidc/forms/activestate.py
+++ b/warehouse/oidc/forms/activestate.py
@@ -31,14 +31,14 @@ class GqlResponse(TypedDict):
     errors: list[dict[str, Any]]
 
 
-def _no_double_dashes(form, field):
+def _no_double_dashes(_form: wtforms.Form, field: wtforms.Field) -> None:
     if _DOUBLE_DASHES.search(field.data):
         raise wtforms.validators.ValidationError(
             _("Double dashes are not allowed in the name")
         )
 
 
-def _no_leading_or_trailing_dashes(form, field):
+def _no_leading_or_trailing_dashes(_form: wtforms.Form, field: wtforms.Field) -> None:
     if field.data.startswith("-") or field.data.endswith("-"):
         raise wtforms.validators.ValidationError(
             _("Leading or trailing dashes are not allowed in the name")
@@ -150,7 +150,7 @@ def process_org_response(response: GqlResponse) -> None:
             _GRAPHQL_GET_ORGANIZATION, {"orgname": org_url_name}, process_org_response
         )
 
-    def validate_organization(self, field):
+    def validate_organization(self, field: wtforms.Field) -> None:
         self._lookup_organization(field.data)
 
     def _lookup_actor(self, actor: str) -> UserResponse:
@@ -170,7 +170,7 @@ def process_actor_response(response: GqlResponse) -> UserResponse:
             _GRAPHQL_GET_ACTOR, {"username": actor}, process_actor_response
         )
 
-    def validate_actor(self, field):
+    def validate_actor(self, field: wtforms.Field) -> None:
         actor = field.data
 
         actor_info = self._lookup_actor(actor)
diff --git a/warehouse/oidc/forms/github.py b/warehouse/oidc/forms/github.py
index ffafb27c408f..0572ed1096e6 100644
--- a/warehouse/oidc/forms/github.py
+++ b/warehouse/oidc/forms/github.py
@@ -45,16 +45,16 @@ class GitHubPublisherBase(wtforms.Form):
     # https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
     environment = wtforms.StringField(validators=[wtforms.validators.Optional()])
 
-    def __init__(self, *args, api_token, **kwargs):
+    def __init__(self, *args, api_token: str, **kwargs):
         super().__init__(*args, **kwargs)
         self._api_token = api_token
 
-    def _headers_auth(self):
+    def _headers_auth(self) -> dict[str, str]:
         if not self._api_token:
             return {}
         return {"Authorization": f"token {self._api_token}"}
 
-    def _lookup_owner(self, owner):
+    def _lookup_owner(self, owner: str) -> dict[str, str | int]:
         # To actually validate the owner, we ask GitHub's API about them.
         # We can't do this for the repository, since it might be private.
         try:
@@ -113,7 +113,7 @@ def _lookup_owner(self, owner):
 
         return response.json()
 
-    def validate_owner(self, field):
+    def validate_owner(self, field: wtforms.Field) -> None:
         owner = field.data
 
         # We pre-filter owners with a regex, to avoid loading GitHub's API
@@ -129,7 +129,7 @@ def validate_owner(self, field):
         self.normalized_owner = owner_info["login"]
         self.owner_id = owner_info["id"]
 
-    def validate_workflow_filename(self, field):
+    def validate_workflow_filename(self, field: wtforms.Field) -> None:
         workflow_filename = field.data
 
         if not (
@@ -144,7 +144,7 @@ def validate_workflow_filename(self, field):
             _("Workflow filename must be a filename only, without directories")
         )
 
-    def validate_environment(self, field):
+    def validate_environment(self, field: wtforms.Field) -> None:
         environment = field.data
 
         if not environment:
@@ -174,7 +174,7 @@ def validate_environment(self, field):
         )
 
     @property
-    def normalized_environment(self):
+    def normalized_environment(self) -> str:
         # The only normalization is due to case-insensitivity.
         #
         # NOTE: We explicitly do not compare `self.environment.data` to None,
diff --git a/warehouse/oidc/forms/gitlab.py b/warehouse/oidc/forms/gitlab.py
index 492d39baadf8..00b100cf5a51 100644
--- a/warehouse/oidc/forms/gitlab.py
+++ b/warehouse/oidc/forms/gitlab.py
@@ -76,7 +76,7 @@ class GitLabPublisherBase(wtforms.Form):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
-    def validate_workflow_filepath(self, field):
+    def validate_workflow_filepath(self, field: wtforms.Field) -> None:
         workflow_filepath = field.data
 
         if not (
@@ -91,7 +91,7 @@ def validate_workflow_filepath(self, field):
         )
 
     @property
-    def normalized_environment(self):
+    def normalized_environment(self) -> str:
         # NOTE: We explicitly do not compare `self.environment.data` to None,
         # since it might also be falsey via an empty string (or might be
         # only whitespace, which we also treat as a None case).
diff --git a/warehouse/oidc/interfaces.py b/warehouse/oidc/interfaces.py
index 103c70a7283a..260e32ab7b31 100644
--- a/warehouse/oidc/interfaces.py
+++ b/warehouse/oidc/interfaces.py
@@ -9,14 +9,14 @@
 from warehouse.rate_limiting.interfaces import RateLimiterException
 
 if TYPE_CHECKING:
-    from warehouse.oidc.models import PendingOIDCPublisher
+    from warehouse.oidc.models import OIDCPublisher, PendingOIDCPublisher
     from warehouse.packaging.models import Project
 
-SignedClaims = NewType("SignedClaims", dict[str, Any])
+SignedClaims = NewType("SignedClaims", dict[str, Any])  # TODO: narrow this down
 
 
 class IOIDCPublisherService(Interface):
-    def verify_jwt_signature(unverified_token: str):
+    def verify_jwt_signature(unverified_token: str) -> SignedClaims | None:
         """
         Verify the given JWT's signature, returning its signed claims if valid.
         If the signature is invalid, `None` is returned.
@@ -26,7 +26,9 @@ def verify_jwt_signature(unverified_token: str):
         """
         pass
 
-    def find_publisher(signed_claims: SignedClaims, *, pending: bool = False):
+    def find_publisher(
+        signed_claims: SignedClaims, *, pending: bool = False
+    ) -> OIDCPublisher | PendingOIDCPublisher | None:
         """
         Given a mapping of signed claims produced by `verify_jwt_signature`,
         attempt to find and return either a `OIDCPublisher` or `PendingOIDCPublisher`
@@ -38,7 +40,7 @@ def find_publisher(signed_claims: SignedClaims, *, pending: bool = False):
 
     def reify_pending_publisher(
         pending_publisher: PendingOIDCPublisher, project: Project
-    ):
+    ) -> OIDCPublisher:
         """
         Reify the given pending `PendingOIDCPublisher` into an `OIDCPublisher`,
         adding it to the given project (presumed newly created) in the process.
diff --git a/warehouse/oidc/models/_core.py b/warehouse/oidc/models/_core.py
index 6ae8ab43bc7e..de7472f3da70 100644
--- a/warehouse/oidc/models/_core.py
+++ b/warehouse/oidc/models/_core.py
@@ -4,12 +4,13 @@
 
 from collections.abc import Callable
 from typing import TYPE_CHECKING, Any, Self, TypedDict, TypeVar, Unpack
+from uuid import UUID
 
 import rfc3986
 import sentry_sdk
 
 from sqlalchemy import ForeignKey, Index, String, UniqueConstraint, func, orm
-from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.dialects.postgresql import UUID as PG_UUID
 from sqlalchemy.orm import Mapped, mapped_column
 
 from warehouse import db
@@ -19,6 +20,7 @@
 
 if TYPE_CHECKING:
     from pypi_attestations import Publisher
+    from sqlalchemy.orm import Session
 
     from warehouse.accounts.models import User
     from warehouse.macaroons.models import Macaroon
@@ -77,7 +79,7 @@ def wrapper(
 
 def check_existing_jti(
     _ground_truth,
-    signed_claim,
+    signed_claim: str,
     _all_signed_claims,
     **kwargs: Unpack[CheckNamedArguments],
 ) -> bool:
@@ -99,14 +101,14 @@ class OIDCPublisherProjectAssociation(db.Model):
     __tablename__ = "oidc_publisher_project_association"
     __table_args__ = (UniqueConstraint("oidc_publisher_id", "project_id"),)
 
-    oidc_publisher_id = mapped_column(
-        UUID(as_uuid=True),
+    oidc_publisher_id: Mapped[UUID] = mapped_column(
+        PG_UUID(as_uuid=True),
         ForeignKey("oidc_publishers.id"),
         nullable=False,
         primary_key=True,
     )
-    project_id = mapped_column(
-        UUID(as_uuid=True),
+    project_id: Mapped[UUID] = mapped_column(
+        PG_UUID(as_uuid=True),
         ForeignKey("projects.id", onupdate="CASCADE", ondelete="CASCADE"),
         nullable=False,
         primary_key=True,
@@ -122,7 +124,7 @@ class OIDCPublisherMixin:
     # Each hierarchy of OIDC publishers (both `OIDCPublisher` and
    # `PendingOIDCPublisher`) use a `discriminator` column for model
     # polymorphism, but the two are not mutually polymorphic at the DB level.
-    discriminator = mapped_column(String)
+    discriminator: Mapped[str | None] = mapped_column(String)
 
     # A map of claim names to "check" functions, each of which
     # has the signature `check(ground-truth, signed-claim, all-signed-claims) -> bool`.
@@ -160,7 +162,7 @@ class OIDCPublisherMixin:
     # the most optional constraints satisfied.
     #
     @classmethod
-    def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self:
+    def lookup_by_claims(cls, session: Session, signed_claims: SignedClaims) -> Self:
         raise NotImplementedError
 
     @classmethod
@@ -224,7 +226,7 @@ def check_claims_existence(cls, signed_claims: SignedClaims) -> None:
 
     def verify_claims(
         self, signed_claims: SignedClaims, publisher_service: OIDCPublisherService
-    ):
+    ) -> bool:
         """
         Given a JWT that has been successfully decoded (checked for a valid
         signature and basic claims), verify it against the more specific
@@ -339,7 +341,7 @@ def verify_url(self, url: str) -> bool:
             url=url,
         )
 
-    def exists(self, session) -> bool:  # pragma: no cover
+    def exists(self, session: Session) -> bool:  # pragma: no cover
         """
         Check if the publisher exists in the database
         """
@@ -380,9 +382,9 @@ class PendingOIDCPublisher(OIDCPublisherMixin, db.Model):
 
     __tablename__ = "pending_oidc_publishers"
 
-    project_name = mapped_column(String, nullable=False)
-    added_by_id = mapped_column(
-        UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True
+    project_name: Mapped[str] = mapped_column(String, nullable=False)
+    added_by_id: Mapped[UUID] = mapped_column(
+        PG_UUID(as_uuid=True), ForeignKey("users.id"), nullable=False, index=True
     )
     added_by: Mapped[User] = orm.relationship(back_populates="pending_oidc_publishers")
 
@@ -397,7 +399,7 @@ class PendingOIDCPublisher(OIDCPublisherMixin, db.Model):
         "polymorphic_on": OIDCPublisherMixin.discriminator,
     }
 
-    def reify(self, session):  # pragma: no cover
+    def reify(self, session: Session) -> OIDCPublisher:  # pragma: no cover
         """
         Return an equivalent "normal" OIDC publisher model for this pending
         publisher, deleting the pending publisher in the process.
diff --git a/warehouse/oidc/models/activestate.py b/warehouse/oidc/models/activestate.py
index 650568517bc4..6e091cc3422a 100644
--- a/warehouse/oidc/models/activestate.py
+++ b/warehouse/oidc/models/activestate.py
@@ -1,12 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
+import typing
 import urllib
 
 from typing import Any, Self
+from uuid import UUID
 
 from sqlalchemy import ForeignKey, String, UniqueConstraint, and_, exists
-from sqlalchemy.dialects.postgresql import UUID
-from sqlalchemy.orm import Query, mapped_column
+from sqlalchemy.dialects.postgresql import UUID as PG_UUID
+from sqlalchemy.orm import Mapped, Query, mapped_column
 
 import warehouse.oidc.models._core as oidccore
 
@@ -18,6 +22,9 @@
     PendingOIDCPublisher,
 )
 
+if typing.TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
 ACTIVESTATE_OIDC_ISSUER_URL = "https://platform.activestate.com/api/v1/oauth/oidc"
 
 _ACTIVESTATE_URL = "https://platform.activestate.com"
@@ -53,13 +60,13 @@ class ActiveStatePublisherMixin:
     Common functionality for both pending and concrete ActiveState OIDC publishers.
""" - organization = mapped_column(String, nullable=False) - activestate_project_name = mapped_column(String, nullable=False) - actor = mapped_column(String, nullable=False) + organization: Mapped[str] = mapped_column(String, nullable=False) + activestate_project_name: Mapped[str] = mapped_column(String, nullable=False) + actor: Mapped[str] = mapped_column(String, nullable=False) # 'actor' (The ActiveState platform username) is obtained from the user # while configuring the publisher We'll make an api call to ActiveState to # get the 'actor_id' - actor_id = mapped_column(String, nullable=False) + actor_id: Mapped[str] = mapped_column(String, nullable=False) __required_verifiable_claims__: dict[str, CheckClaimCallable[Any]] = { "sub": _check_sub, @@ -106,13 +113,13 @@ def publisher_base_url(self) -> str: def publisher_url(self, claims: SignedClaims | None = None) -> str: return self.publisher_base_url - def stored_claims(self, claims=None): + def stored_claims(self, claims: SignedClaims | None = None) -> dict: return {} def __str__(self) -> str: return self.publisher_url() - def exists(self, session) -> bool: + def exists(self, session: Session) -> bool: return session.query( exists().where( and_( @@ -135,7 +142,7 @@ def admin_details(self) -> list[tuple[str, str]]: ] @classmethod - def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: + def lookup_by_claims(cls, session: Session, signed_claims: SignedClaims) -> Self: query: Query = Query(cls).filter_by( organization=signed_claims["organization"], activestate_project_name=signed_claims["project"], @@ -161,8 +168,8 @@ class ActiveStatePublisher(ActiveStatePublisherMixin, OIDCPublisher): ), ) - id = mapped_column( - UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True ) @@ -178,11 +185,11 @@ class PendingActiveStatePublisher(ActiveStatePublisherMixin, PendingOIDCPublishe ), ) - id = mapped_column( - UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True ) - def reify(self, session): + def reify(self, session: Session) -> ActiveStatePublisher: """ Returns a `ActiveStatePublisher` for this `PendingActiveStatePublisher`, deleting the `PendingActiveStatePublisher` in the process. 
diff --git a/warehouse/oidc/models/github.py b/warehouse/oidc/models/github.py
index 0f984e32918b..121e024604a6 100644
--- a/warehouse/oidc/models/github.py
+++ b/warehouse/oidc/models/github.py
@@ -1,14 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import re
+import typing
 
 from typing import Any, Self
+from uuid import UUID
 
 from more_itertools import first_true
 from pypi_attestations import GitHubPublisher as GitHubIdentity, Publisher
 from sqlalchemy import ForeignKey, String, UniqueConstraint, and_, exists
-from sqlalchemy.dialects.postgresql import UUID
-from sqlalchemy.orm import Query, mapped_column
+from sqlalchemy.dialects.postgresql import UUID as PG_UUID
+from sqlalchemy.orm import Mapped, Query, mapped_column
 
 from warehouse.oidc.errors import InvalidPublisherError
 from warehouse.oidc.interfaces import SignedClaims
@@ -21,6 +25,9 @@
 )
 from warehouse.oidc.urls import verify_url_from_reference
 
+if typing.TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
 GITHUB_OIDC_ISSUER_URL = "https://token.actions.githubusercontent.com"
 
 # This expression matches the workflow filename component of a GitHub
@@ -48,7 +55,9 @@ def _extract_workflow_filename(workflow_ref: str) -> str | None:
         return None
 
 
-def _check_repository(ground_truth, signed_claim, _all_signed_claims, **_kwargs):
+def _check_repository(
+    ground_truth: str, signed_claim: str, _all_signed_claims: SignedClaims, **_kwargs
+) -> bool:
     # Defensive: GitHub should never give us an empty repository claim.
     if not signed_claim:
         return False
@@ -57,7 +66,9 @@ def _check_repository(ground_truth, signed_claim, _all_signed_claims, **_kwargs)
     return signed_claim.lower() == ground_truth.lower()
 
 
-def _check_job_workflow_ref(ground_truth, signed_claim, all_signed_claims, **_kwargs):
+def _check_job_workflow_ref(
+    ground_truth: str, signed_claim: str, all_signed_claims: SignedClaims, **_kwargs
+) -> bool:
     # We expect a string formatted as follows:
     # OWNER/REPO/.github/workflows/WORKFLOW.yml@REF
     # where REF is the value of either the `ref` or `sha` claims.
@@ -88,7 +99,12 @@ def _check_job_workflow_ref(ground_truth, signed_claim, all_signed_claims, **_kw
     return True
 
 
-def _check_environment(ground_truth, signed_claim, _all_signed_claims, **_kwargs):
+def _check_environment(
+    ground_truth: str,
+    signed_claim: str | None,
+    _all_signed_claims: SignedClaims,
+    **_kwargs,
+) -> bool:
     # When there is an environment, we expect a case-insensitive string.
     # https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment
     # For tokens that are generated outside of an environment, the claim will
@@ -110,7 +126,9 @@ def _check_environment(ground_truth, signed_claim, _all_signed_claims, **_kwargs
     return ground_truth.lower() == signed_claim.lower()
 
 
-def _check_sub(ground_truth, signed_claim, _all_signed_claims, **_kwargs):
+def _check_sub(
+    ground_truth: str, signed_claim: str, _all_signed_claims: SignedClaims, **_kwargs
+) -> bool:
     # We expect a string formatted as follows:
     # repo:ORG/REPO[:OPTIONAL-STUFF]
     # where :OPTIONAL-STUFF is a concatenation of other job context
@@ -139,11 +157,11 @@ class GitHubPublisherMixin:
     Common functionality for both pending and concrete GitHub OIDC publishers.
""" - repository_name = mapped_column(String, nullable=False) - repository_owner = mapped_column(String, nullable=False) - repository_owner_id = mapped_column(String, nullable=False) - workflow_filename = mapped_column(String, nullable=False) - environment = mapped_column(String, nullable=False) + repository_name: Mapped[str] = mapped_column(String, nullable=False) + repository_owner: Mapped[str] = mapped_column(String, nullable=False) + repository_owner_id: Mapped[str] = mapped_column(String, nullable=False) + workflow_filename: Mapped[str] = mapped_column(String, nullable=False) + environment: Mapped[str] = mapped_column(String, nullable=False) __required_verifiable_claims__: dict[str, CheckClaimCallable[Any]] = { "sub": _check_sub, @@ -204,7 +222,7 @@ def _get_publisher_for_environment( return None @classmethod - def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: + def lookup_by_claims(cls, session: Session, signed_claims: SignedClaims) -> Self: repository = signed_claims["repository"] repository_owner, repository_name = repository.split("/", 1) job_workflow_ref = signed_claims["job_workflow_ref"] @@ -229,27 +247,27 @@ def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: raise InvalidPublisherError("Publisher with matching claims was not found") @property - def _workflow_slug(self): + def _workflow_slug(self) -> str: return f".github/workflows/{self.workflow_filename}" @property - def publisher_name(self): + def publisher_name(self) -> str: return "GitHub" @property - def repository(self): + def repository(self) -> str: return f"{self.repository_owner}/{self.repository_name}" @property - def job_workflow_ref(self): + def job_workflow_ref(self) -> str: return f"{self.repository}/{self._workflow_slug}" @property - def sub(self): + def sub(self) -> str: return f"repo:{self.repository}" @property - def publisher_base_url(self): + def publisher_base_url(self) -> str: return f"https://github.com/{self.repository}" @property @@ -257,7 +275,7 @@ def jti(self) -> str: """Placeholder value for JTI.""" return "placeholder" - def publisher_url(self, claims=None): + def publisher_url(self, claims: SignedClaims | None = None) -> str: base = self.publisher_base_url sha = claims.get("sha") if claims else None @@ -273,14 +291,14 @@ def attestation_identity(self) -> Publisher | None: environment=self.environment if self.environment else None, ) - def stored_claims(self, claims=None): - claims = claims if claims else {} - return {"ref": claims.get("ref"), "sha": claims.get("sha")} + def stored_claims(self, claims: SignedClaims | None = None) -> dict: + claims_obj = claims if claims else {} + return {"ref": claims_obj.get("ref"), "sha": claims_obj.get("sha")} - def __str__(self): + def __str__(self) -> str: return self.workflow_filename - def exists(self, session) -> bool: + def exists(self, session: Session) -> bool: return session.query( exists().where( and_( @@ -318,11 +336,11 @@ class GitHubPublisher(GitHubPublisherMixin, OIDCPublisher): ), ) - id = mapped_column( - UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True ) - def verify_url(self, url: str): + def verify_url(self, url: str) -> bool: """ Verify a given URL against this GitHub's publisher information @@ -375,11 +393,11 @@ class PendingGitHubPublisher(GitHubPublisherMixin, PendingOIDCPublisher): ), ) - id = mapped_column( - UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), 
primary_key=True + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True ) - def reify(self, session): + def reify(self, session: Session) -> GitHubPublisher: """ Returns a `GitHubPublisher` for this `PendingGitHubPublisher`, deleting the `PendingGitHubPublisher` in the process. diff --git a/warehouse/oidc/models/gitlab.py b/warehouse/oidc/models/gitlab.py index 446d0138026c..16383a037db9 100644 --- a/warehouse/oidc/models/gitlab.py +++ b/warehouse/oidc/models/gitlab.py @@ -1,14 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + import re +import typing from typing import Any, Self +from uuid import UUID from more_itertools import first_true from pypi_attestations import GitLabPublisher as GitLabIdentity, Publisher from sqlalchemy import ForeignKey, String, UniqueConstraint, and_, exists -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Query, mapped_column +from sqlalchemy.dialects.postgresql import UUID as PG_UUID +from sqlalchemy.orm import Mapped, Query, mapped_column from warehouse.oidc.errors import InvalidPublisherError from warehouse.oidc.interfaces import SignedClaims @@ -20,6 +24,9 @@ ) from warehouse.oidc.urls import verify_url_from_reference +if typing.TYPE_CHECKING: + from sqlalchemy.orm import Session + GITLAB_OIDC_ISSUER_URL = "https://gitlab.com" # This expression matches the workflow filepath component of a GitLab @@ -51,7 +58,13 @@ def _extract_workflow_filepath(ci_config_ref_uri: str) -> str | None: return None -def _check_project_path(ground_truth, signed_claim, _all_signed_claims, **_kwargs): +def _check_project_path( + ground_truth: str, + signed_claim: str | None, + _all_signed_claims: SignedClaims, + **_kwargs, +) -> bool: + # Defensive: GitLab should never give us an empty project_path claim. if not signed_claim: return False @@ -60,7 +73,12 @@ def _check_project_path(ground_truth, signed_claim, _all_signed_claims, **_kwarg return signed_claim.lower() == ground_truth.lower() -def _check_ci_config_ref_uri(ground_truth, signed_claim, all_signed_claims, **_kwargs): +def _check_ci_config_ref_uri( + ground_truth: str, + signed_claim: str | None, + all_signed_claims: SignedClaims, + **_kwargs, +) -> bool: # We expect a string formatted as follows: # gitlab.com/OWNER/REPO//WORKFLOW_PATH/WORKFLOW_FILE.yml@REF # where REF is the value of the `ref_path` claim. @@ -86,7 +104,12 @@ def _check_ci_config_ref_uri(ground_truth, signed_claim, all_signed_claims, **_k return True -def _check_environment(ground_truth, signed_claim, _all_signed_claims, **_kwargs): +def _check_environment( + ground_truth: str, + signed_claim: str | None, + _all_signed_claims: SignedClaims, + **_kwargs, +) -> bool: # When there is an environment, we expect a string. # For tokens that are generated outside of an environment, the claim will # be missing. @@ -105,7 +128,12 @@ def _check_environment(ground_truth, signed_claim, _all_signed_claims, **_kwargs return ground_truth == signed_claim -def _check_sub(ground_truth, signed_claim, _all_signed_claims, **_kwargs): +def _check_sub( + ground_truth: str, + signed_claim: str | None, + _all_signed_claims: SignedClaims, + **_kwargs, +) -> bool: # We expect a string formatted as follows: # project_path:NAMESPACE/PROJECT[:OPTIONAL-STUFF] # where :OPTIONAL-STUFF is a concatenation of other job context @@ -134,10 +162,10 @@ class GitLabPublisherMixin: Common functionality for both pending and concrete GitLab OIDC publishers. 
""" - namespace = mapped_column(String, nullable=False) - project = mapped_column(String, nullable=False) - workflow_filepath = mapped_column(String, nullable=False) - environment = mapped_column(String, nullable=False) + namespace: Mapped[str] = mapped_column(String, nullable=False) + project: Mapped[str] = mapped_column(String, nullable=False) + workflow_filepath: Mapped[str] = mapped_column(String, nullable=False) + environment: Mapped[str] = mapped_column(String, nullable=False) __required_verifiable_claims__: dict[str, CheckClaimCallable[Any]] = { "sub": _check_sub, @@ -204,7 +232,7 @@ def _get_publisher_for_environment( return None @classmethod - def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: + def lookup_by_claims(cls, session: Session, signed_claims: SignedClaims) -> Self: project_path = signed_claims["project_path"] ci_config_ref_uri = signed_claims["ci_config_ref_uri"] namespace, project = project_path.rsplit("/", 1) @@ -227,23 +255,23 @@ def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: raise InvalidPublisherError("Publisher with matching claims was not found") @property - def project_path(self): + def project_path(self) -> str: return f"{self.namespace}/{self.project}" @property - def sub(self): + def sub(self) -> str: return f"project_path:{self.project_path}" @property - def ci_config_ref_uri(self): + def ci_config_ref_uri(self) -> str: return f"gitlab.com/{self.project_path}//{self.workflow_filepath}" @property - def publisher_name(self): + def publisher_name(self) -> str: return "GitLab" @property - def publisher_base_url(self): + def publisher_base_url(self) -> str: return f"https://gitlab.com/{self.project_path}" @property @@ -251,7 +279,7 @@ def jti(self) -> str: """Placeholder value for JTI.""" return "placeholder" - def publisher_url(self, claims=None): + def publisher_url(self, claims: SignedClaims | None = None) -> str | None: base = self.publisher_base_url return f"{base}/commit/{claims['sha']}" if claims else base @@ -263,14 +291,14 @@ def attestation_identity(self) -> Publisher | None: environment=self.environment if self.environment else None, ) - def stored_claims(self, claims=None): - claims = claims if claims else {} - return {"ref_path": claims.get("ref_path"), "sha": claims.get("sha")} + def stored_claims(self, claims: SignedClaims | None = None) -> dict: + claims_obj = claims if claims else {} + return {"ref_path": claims_obj.get("ref_path"), "sha": claims_obj.get("sha")} - def __str__(self): + def __str__(self) -> str: return self.workflow_filepath - def exists(self, session) -> bool: + def exists(self, session: Session) -> bool: return session.query( exists().where( and_( @@ -307,11 +335,11 @@ class GitLabPublisher(GitLabPublisherMixin, OIDCPublisher): ), ) - id = mapped_column( - UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True ) - def verify_url(self, url: str): + def verify_url(self, url: str) -> bool: """ Verify a given URL against this GitLab's publisher information @@ -382,11 +410,11 @@ class PendingGitLabPublisher(GitLabPublisherMixin, PendingOIDCPublisher): ), ) - id = mapped_column( - UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True + id: Mapped[UUID] = mapped_column( + PG_UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True ) - def reify(self, session): + def reify(self, session: Session) -> GitLabPublisher: """ Returns a 
`GitLabPublisher` for this `PendingGitLabPublisher`, deleting the `PendingGitLabPublisher` in the process. diff --git a/warehouse/oidc/models/google.py b/warehouse/oidc/models/google.py index f5b55b44e92b..f3ee5847f0c8 100644 --- a/warehouse/oidc/models/google.py +++ b/warehouse/oidc/models/google.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import typing + from typing import Any, Self +from uuid import UUID from more_itertools import first_true from pypi_attestations import GooglePublisher as GoogleIdentity, Publisher from sqlalchemy import ForeignKey, String, UniqueConstraint, and_, exists -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import Query, mapped_column +from sqlalchemy.dialects.postgresql import UUID as PG_UUID +from sqlalchemy.orm import Mapped, Query, mapped_column from warehouse.oidc.errors import InvalidPublisherError from warehouse.oidc.interfaces import SignedClaims @@ -18,6 +23,10 @@ check_claim_invariant, ) +if typing.TYPE_CHECKING: + from sqlalchemy.orm import Session + + GOOGLE_OIDC_ISSUER_URL = "https://accounts.google.com" @@ -46,8 +55,8 @@ class GooglePublisherMixin: providers. """ - email = mapped_column(String, nullable=False) - sub = mapped_column(String, nullable=True) + email: Mapped[str] = mapped_column(String, nullable=False) + sub: Mapped[str] = mapped_column(String, nullable=True) __required_verifiable_claims__: dict[str, CheckClaimCallable[Any]] = { "email": check_claim_binary(str.__eq__), @@ -61,7 +70,7 @@ class GooglePublisherMixin: __unchecked_claims__ = {"azp", "google"} @classmethod - def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: + def lookup_by_claims(cls, session: Session, signed_claims: SignedClaims) -> Self: query: Query = Query(cls).filter_by(email=signed_claims["email"]) publishers = query.with_session(session).all() @@ -77,33 +86,33 @@ def lookup_by_claims(cls, session, signed_claims: SignedClaims) -> Self: raise InvalidPublisherError("Publisher with matching claims was not found") @property - def publisher_name(self): + def publisher_name(self) -> str: return "Google" @property - def publisher_base_url(self): + def publisher_base_url(self) -> None: return None - def publisher_url(self, claims=None): + def publisher_url(self, claims: SignedClaims | None = None) -> None: return None @property def attestation_identity(self) -> Publisher | None: return GoogleIdentity(email=self.email) - def stored_claims(self, claims=None): + def stored_claims(self, claims: SignedClaims | None = None) -> dict: return {} @property - def email_verified(self): + def email_verified(self) -> bool: # We don't consider a claim set valid unless `email_verified` is true; # no other states are possible. 
         return True
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.email
 
-    def exists(self, session) -> bool:
+    def exists(self, session: Session) -> bool:
         return session.query(
             exists().where(
                 and_(
@@ -135,8 +144,8 @@ class GooglePublisher(GooglePublisherMixin, OIDCPublisher):
         ),
     )
 
-    id = mapped_column(
-        UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True
+    id: Mapped[UUID] = mapped_column(
+        PG_UUID(as_uuid=True), ForeignKey(OIDCPublisher.id), primary_key=True
     )
 
 
@@ -151,11 +160,11 @@ class PendingGooglePublisher(GooglePublisherMixin, PendingOIDCPublisher):
         ),
     )
 
-    id = mapped_column(
-        UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True
+    id: Mapped[UUID] = mapped_column(
+        PG_UUID(as_uuid=True), ForeignKey(PendingOIDCPublisher.id), primary_key=True
     )
 
-    def reify(self, session):
+    def reify(self, session: Session) -> GooglePublisher:
         """
         Returns a `GooglePublisher` for this `PendingGooglePublisher`,
         deleting the `PendingGooglePublisher` in the process.
diff --git a/warehouse/oidc/services.py b/warehouse/oidc/services.py
index e46b985e4e19..69694124e552 100644
--- a/warehouse/oidc/services.py
+++ b/warehouse/oidc/services.py
@@ -1,6 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
+from __future__ import annotations
+
 import json
+import typing
 import warnings
 
 import jwt
@@ -17,10 +20,24 @@
 from warehouse.oidc.utils import find_publisher_by_issuer
 from warehouse.utils.exceptions import InsecureOIDCPublisherWarning
 
 
+if typing.TYPE_CHECKING:
+    from pyramid.request import Request
+    from sqlalchemy.orm import Session
+
+    from warehouse.packaging import Project
+
 @implementer(IOIDCPublisherService)
 class NullOIDCPublisherService:
-    def __init__(self, session, publisher, issuer_url, audience, cache_url, metrics):
+    def __init__(
+        self,
+        session: Session,
+        publisher: str,
+        issuer_url: str,
+        audience: str,
+        cache_url: str,
+        metrics: IMetricsService,
+    ):
         warnings.warn(
             "NullOIDCPublisherService is intended only for use in development, "
             "you should not use it in production due to the lack of actual "
@@ -64,7 +81,9 @@ def find_publisher(
             self.db, self.issuer_url, signed_claims, pending=pending
         )
 
-    def reify_pending_publisher(self, pending_publisher, project):
+    def reify_pending_publisher(
+        self, pending_publisher: PendingOIDCPublisher, project: Project
+    ) -> OIDCPublisher:
         new_publisher = pending_publisher.reify(self.db)
         project.oidc_publishers.append(new_publisher)
         return new_publisher
@@ -86,7 +105,15 @@ def store_jwt_identifier(self, jti: str, expiration: int) -> None:
 
 @implementer(IOIDCPublisherService)
 class OIDCPublisherService:
-    def __init__(self, session, publisher, issuer_url, audience, cache_url, metrics):
+    def __init__(
+        self,
+        session: Session,
+        publisher: str,
+        issuer_url: str,
+        audience: str,
+        cache_url: str,
+        metrics: IMetricsService,
+    ):
         self.db = session
         self.publisher = publisher
         self.issuer_url = issuer_url
@@ -97,7 +124,7 @@ def __init__(self, session, publisher, issuer_url, audience, cache_url, metrics)
         self._publisher_jwk_key = f"/warehouse/oidc/jwks/{self.publisher}"
         self._publisher_timeout_key = f"{self._publisher_jwk_key}/timeout"
 
-    def _store_keyset(self, keys):
+    def _store_keyset(self, keys: dict) -> None:
         """
         Store the given keyset for the given publisher, setting the
         timeout key in the process.
@@ -107,7 +134,7 @@ def _store_keyset(self, keys):
             r.set(self._publisher_jwk_key, json.dumps(keys))
             r.setex(self._publisher_timeout_key, 60, "placeholder")
 
-    def _get_keyset(self):
+    def _get_keyset(self) -> tuple[dict[str, dict], bool]:
         """
         Return the cached keyset for the given publisher, or an empty
         keyset if no keys are currently cached.
@@ -121,7 +148,7 @@ def _get_keyset(self):
         else:
             return ({}, timeout)
 
-    def _refresh_keyset(self):
+    def _refresh_keyset(self) -> dict[str, dict]:
         """
         Attempt to refresh the keyset from the OIDC publisher, assuming
         no timeout is in effect.
@@ -196,7 +223,7 @@ def _refresh_keyset(self):
 
         return keys
 
-    def _get_key(self, key_id):
+    def _get_key(self, key_id: str) -> jwt.PyJWK | None:
         """
         Return a JWK for the given key ID, or None if the key can't be found
         in this publisher's keyset.
@@ -328,19 +355,26 @@ def find_publisher(
             )
             raise e
 
-    def reify_pending_publisher(self, pending_publisher, project) -> OIDCPublisher:
+    def reify_pending_publisher(
+        self, pending_publisher: PendingOIDCPublisher, project: Project
+    ) -> OIDCPublisher:
         new_publisher = pending_publisher.reify(self.db)
         project.oidc_publishers.append(new_publisher)
         return new_publisher
 
 
 class OIDCPublisherServiceFactory:
-    def __init__(self, publisher, issuer_url, service_class=OIDCPublisherService):
+    def __init__(
+        self,
+        publisher: str,
+        issuer_url: str,
+        service_class=OIDCPublisherService,  # TODO: Unclear how to correctly type this
+    ):
         self.publisher = publisher
         self.issuer_url = issuer_url
         self.service_class = service_class
 
-    def __call__(self, _context, request):
+    def __call__(self, _context, request: Request) -> OIDCPublisherService:
         cache_url = request.registry.settings["oidc.jwk_cache_url"]
         audience = request.registry.settings["warehouse.oidc.audience"]
         metrics = request.find_service(IMetricsService, context=None)
@@ -354,7 +388,7 @@ def __call__(self, _context, request):
             metrics,
         )
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         if not isinstance(other, OIDCPublisherServiceFactory):
             return NotImplemented
 
diff --git a/warehouse/oidc/utils.py b/warehouse/oidc/utils.py
index 0a73d73dfa39..8dd02e5168c0 100644
--- a/warehouse/oidc/utils.py
+++ b/warehouse/oidc/utils.py
@@ -2,6 +2,8 @@
 
 from __future__ import annotations
 
+import typing
+
 from dataclasses import dataclass
 
 from pyramid.authorization import Authenticated
@@ -26,6 +28,10 @@
     PendingOIDCPublisher,
 )
 
+if typing.TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+
 OIDC_ISSUER_SERVICE_NAMES = {
     GITHUB_OIDC_ISSUER_URL: "github",
     GITLAB_OIDC_ISSUER_URL: "gitlab",
@@ -61,7 +67,11 @@
 
 
 def find_publisher_by_issuer(
-    session, issuer_url: str, signed_claims: SignedClaims, *, pending: bool = False
+    session: Session,
+    issuer_url: str,
+    signed_claims: SignedClaims,
+    *,
+    pending: bool = False,
 ) -> OIDCPublisher | PendingOIDCPublisher:
     """
     Given an OIDC issuer URL and a dictionary of claims that have been verified
diff --git a/warehouse/packaging/tasks.py b/warehouse/packaging/tasks.py
index 0240847bc6de..397845688d96 100644
--- a/warehouse/packaging/tasks.py
+++ b/warehouse/packaging/tasks.py
@@ -298,7 +298,7 @@ def update_release_description(_task, request, release_id):
     retry_jitter=False,
     max_retries=5,
 )
-def update_bigquery_release_files(task, request, dist_metadata):
+def update_bigquery_release_files(task, request, dist_metadata) -> None:
     """
     Adds release file metadata to public BigQuery database
     """
@@ -317,7 +317,7 @@ def update_bigquery_release_files(task, request, dist_metadata):
     # Using the schema to populate the data allows us to automatically
     # set the values to their respective fields rather than assigning
     # values individually
-    json_rows: dict = {}
+    json_row: dict = {}
 
     for sch in table_schema:
         field_data = dist_metadata.get(sch.name, None)
@@ -330,7 +330,7 @@ def update_bigquery_release_files(task, request, dist_metadata):
             field_data = None
 
         if field_data is None and sch.mode == "REPEATED":
-            json_rows[sch.name] = []
+            json_row[sch.name] = []
         elif field_data and sch.mode == "REPEATED":
             # Currently, some of the metadata fields such as
             # the 'platform' tag are incorrectly classified as a
@@ -339,12 +339,12 @@ def update_bigquery_release_files(task, request, dist_metadata):
             # This extra check can be removed once
             # https://github.com/pypi/warehouse/issues/8257 is fixed
             if isinstance(field_data, str):
-                json_rows[sch.name] = [field_data]
+                json_row[sch.name] = [field_data]
             else:
-                json_rows[sch.name] = list(field_data)
+                json_row[sch.name] = list(field_data)
         else:
-            json_rows[sch.name] = field_data
-    json_rows = [json_rows]
+            json_row[sch.name] = field_data
+    json_rows = [json_row]
 
     bq.insert_rows_json(
         table=table_name, json_rows=json_rows, timeout=5.0, retry=None
diff --git a/warehouse/tasks.py b/warehouse/tasks.py
index bad2e6e51bc9..d44ca0730822 100644
--- a/warehouse/tasks.py
+++ b/warehouse/tasks.py
@@ -24,6 +24,7 @@
 from warehouse.metrics import IMetricsService
 
 if typing.TYPE_CHECKING:
+    from pyramid.config import Configurator
     from pyramid.request import Request
 
 # We need to trick Celery into supporting rediss:// URLs which is how redis-py
@@ -251,10 +252,10 @@ def add_task():
     config.action(None, add_task, order=100)
 
 
-def includeme(config):
+def includeme(config: Configurator) -> None:
     s = config.registry.settings
 
-    broker_transport_options: dict[str, str | dict] = {}
+    broker_transport_options: dict[str, str | int | dict] = {}
 
     broker_url = s["celery.broker_redis_url"]
 
diff --git a/warehouse/utils/enum.py b/warehouse/utils/enum.py
index 1e91f7ebe4d6..faa4de1d84e9 100644
--- a/warehouse/utils/enum.py
+++ b/warehouse/utils/enum.py
@@ -2,6 +2,8 @@
 
 import enum
 
+from typing import Self
+
 
 class StrLabelEnum(str, enum.Enum):
     """Base class for Enum with string value and display label."""
@@ -9,7 +11,7 @@ class StrLabelEnum(str, enum.Enum):
     label: str
 
     # Name = "value", _("Label")
-    def __new__(cls, value: str, label: str):
+    def __new__(cls, value: str, label: str) -> Self:
         obj = str.__new__(cls, value)
         obj._value_ = value
         obj.label = label