# Source code for fsh_lib.auth

"""JWT auth primitives for kiln-generated FastAPI projects.

A session is a Pydantic model dumped into JWT claims.  Tokens
travel over one or both of two *sources*:

* ``"bearer"`` -- ``Authorization`` header; API clients.
* ``"cookie"`` -- ``httpOnly`` cookie; browser frontends (out of
  reach of JS so XSS can't steal it).

The signing secret lives in an env var (caller-named, typically
``JWT_SECRET``) so generated source never embeds a key.
"""

# NOTE: ``session_auth`` and the transport ``extract_dep`` helpers
# below build inner functions annotated with
# ``Annotated[..., Depends(<closure-local>)]``.  pydantic's
# ``TypeAdapter`` calls ``typing.get_type_hints`` against those
# inner functions when FastAPI builds the OpenAPI schema; closure
# locals aren't in ``__globals__``, so any stringified annotation
# (PEP 563) fails to resolve and 500s the schema build.  PEP 749's
# default deferred-but-lazy evaluation in 3.14 keeps annotations as
# real objects, preserving the closure scope -- but only as long as
# nothing forces them back to strings.  ``collections.abc`` must
# therefore be imported at runtime (not under ``TYPE_CHECKING``) so
# the same closure-local resolution can find ``Awaitable``,
# ``Callable``, and ``Sequence`` at request time.

import datetime
import os
from collections.abc import (  # noqa: TC003 -- runtime, see NOTE above
    Awaitable,
    Callable,
    Sequence,
)
from typing import Annotated, Any, Literal, Protocol

import jwt
from fastapi import Cookie, Depends, HTTPException, Response, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, ValidationError

# Thirty minutes, spelled out in seconds.
DEFAULT_TOKEN_TTL = datetime.timedelta(seconds=30 * 60)
"""Default ``exp`` stamped on tokens when the caller doesn't set one."""

# Names of the places a token can ride on a request/response pair.
Source = Literal["bearer", "cookie"]
# Values accepted for the cookie ``SameSite`` attribute.
SameSite = Literal["lax", "strict", "none"]


class LoginResponse(BaseModel):
    """Login response body used when the token travels as a bearer token.

    OAuth2-shaped: an ``access_token`` plus its ``token_type``.
    """

    # The signed JWT.
    access_token: str
    # Fixed to "bearer"; no other token type is ever issued here.
    token_type: Literal["bearer"] = "bearer"  # noqa: S105 -- not a secret
class OkResponse(BaseModel):
    """Acknowledgement body for cookie-only login and for every logout."""

    # Constant ``True``; the body carries no other information.
    ok: Literal[True] = True
class SessionStore(Protocol):
    """Async hook pair for server-side session state.

    Plugging a store in makes the otherwise stateless JWT flow stateful
    (deny-lists, server-side session tables, ...).  Each hook receives
    the complete session model, so the implementation decides which
    identity claim to key on (``jti`` is the typical pick) and
    ``fsh_lib.auth`` never has to know.  Both hooks are coroutines so an
    implementation may query a database.
    """

    async def is_revoked(self, session: BaseModel) -> bool:
        """Report whether *session* is dead; ``True`` fails the request with HTTP 401."""
        ...

    async def revoke(self, session: BaseModel) -> None:
        """Kill *session*.  Implementations must tolerate repeated calls."""
        ...
def _auth_error(detail: str) -> HTTPException:
    """Build an HTTP 401 carrying *detail* and a ``Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _unauthorized() -> HTTPException:
    """401 for a token that is missing or fails JWT validation."""
    return _auth_error("Not authenticated")


def _revoked() -> HTTPException:
    """401 for a JWT-valid token that the :class:`SessionStore` rejected."""
    return _auth_error("Session revoked")


class _Transport:
    """Base class for one way a JWT travels across a request/response pair.

    Concrete subclasses are registered in :data:`_TRANSPORTS` under their
    :data:`Source` name.  Supporting another source (say, an API-key
    header) means writing a subclass and adding it to that mapping; the
    public functions need no changes.
    """

    @classmethod
    def from_config(cls, **kwargs: Any) -> _Transport:
        """Construct an instance from the loose config kwargs."""
        raise NotImplementedError

    def extract_dep(self) -> Callable[..., Awaitable[str | None]]:
        """Return a FastAPI dependency that yields the raw token or ``None``."""
        raise NotImplementedError

    def emit(
        self,
        response: Response,
        token: str,
        ttl: datetime.timedelta,
    ) -> LoginResponse | None:
        """Attach *token* to this transport at login time.

        A returned :class:`LoginResponse` becomes the response body (the
        bearer case).  ``None`` means the token went into headers only
        (the cookie case).
        """
        raise NotImplementedError

    def clear(self, response: Response) -> None:
        """Remove the token from this transport on logout (may be a no-op)."""
        raise NotImplementedError


class _BearerTransport(_Transport):
    """Token carried in the ``Authorization: Bearer`` header.

    ``token_url`` is used only to advertise the scheme in OpenAPI through
    :class:`OAuth2PasswordBearer`; runtime extraction reads the header.
    :func:`issue_session` and :func:`clear_session` never call
    :meth:`extract_dep`, so they construct this transport with ``None``.
    """

    def __init__(self, token_url: str | None) -> None:
        self._oauth: OAuth2PasswordBearer | None = None
        if token_url is not None:
            # auto_error=False: a missing header yields None rather than an
            # immediate 401, so the cookie source still gets its chance.
            self._oauth = OAuth2PasswordBearer(
                tokenUrl=token_url,
                auto_error=False,
            )

    @classmethod
    def from_config(cls, **kwargs: Any) -> _BearerTransport:
        token_url = kwargs.get("token_url")
        return cls(token_url)

    def extract_dep(self) -> Callable[..., Awaitable[str | None]]:
        oauth = self._oauth
        if oauth is None:  # pragma: no cover -- session_auth pre-guards
            msg = "token_url is required for bearer extraction"
            raise ValueError(msg)

        async def _extract(
            bearer: Annotated[str | None, Depends(oauth)] = None,
        ) -> str | None:
            return bearer

        return _extract

    def emit(
        self,
        response: Response,  # noqa: ARG002
        token: str,
        ttl: datetime.timedelta,  # noqa: ARG002
    ) -> LoginResponse | None:
        # The token goes back in the body; nothing is set on the response.
        return LoginResponse(access_token=token)

    def clear(self, response: Response) -> None:  # noqa: ARG002
        """Nothing to undo server-side: bearer clients simply drop the token."""


class _CookieTransport(_Transport):
    """Token carried in an ``httpOnly`` cookie.

    :meth:`emit` and :meth:`clear` must send identical ``secure`` /
    ``samesite`` attributes.  Browsers won't overwrite an existing cookie
    when either differs, so both go through :meth:`_cookie_attrs`.
    """

    def __init__(
        self,
        name: str,
        *,
        secure: bool = True,
        samesite: SameSite = "lax",
    ) -> None:
        self._name, self._secure, self._samesite = name, secure, samesite

    @classmethod
    def from_config(cls, **kwargs: Any) -> _CookieTransport:
        if (name := kwargs.get("cookie_name")) is None:
            msg = "cookie_name is required when 'cookie' is in sources"
            raise ValueError(msg)
        secure = kwargs.get("cookie_secure", True)
        samesite = kwargs.get("cookie_samesite", "lax")
        return cls(name, secure=secure, samesite=samesite)

    def _cookie_attrs(self) -> dict[str, Any]:
        """Return the cookie kwargs shared by ``set_cookie`` and ``delete_cookie``."""
        return {
            "key": self._name,
            "httponly": True,
            "secure": self._secure,
            "samesite": self._samesite,
        }

    def extract_dep(self) -> Callable[..., Awaitable[str | None]]:
        cookie_name = self._name

        async def _extract(
            cookie: Annotated[str | None, Cookie(alias=cookie_name)] = None,
        ) -> str | None:
            return cookie

        return _extract

    def emit(
        self,
        response: Response,
        token: str,
        ttl: datetime.timedelta,
    ) -> LoginResponse | None:
        # Cookie lifetime is derived from *ttl*, truncated to whole seconds.
        response.set_cookie(
            value=token,
            max_age=int(ttl.total_seconds()),
            **self._cookie_attrs(),
        )
        return None

    def clear(self, response: Response) -> None:
        response.delete_cookie(**self._cookie_attrs())


_TRANSPORTS: dict[Source, type[_Transport]] = {
    "bearer": _BearerTransport,
    "cookie": _CookieTransport,
}


async def _no_token() -> str | None:
    """Extractor stand-in for a source that isn't configured; always ``None``.

    Keeps :func:`session_auth`'s ``(bearer, cookie)`` signature uniform no
    matter which sources are in use.  A plain ``Depends(_no_token)`` adds
    no security scheme, so OpenAPI still advertises only configured sources.
    """
    return None


def _build_transports(
    sources: Sequence[Source],
    **config: Any,
) -> dict[Source, _Transport]:
    """Instantiate one transport per entry in *sources*, keyed by source name.

    Construction goes through :data:`_TRANSPORTS`, so each subclass applies
    its own config rules in :meth:`_Transport.from_config`.
    ``session_auth`` still checks the bearer-needs-``token_url`` case itself
    because ``issue_session`` / ``clear_session`` build bearer transports
    without one (they never call :meth:`extract_dep`).

    Raises:
        ValueError: for unknown source names or an empty *sources*.
    """
    unknown = sorted({src for src in sources if src not in _TRANSPORTS})
    if unknown:
        msg = f"unknown source(s): {unknown}"
        raise ValueError(msg)
    if not sources:
        msg = f"sources must contain at least one of {sorted(_TRANSPORTS)}"
        raise ValueError(msg)
    built: dict[Source, _Transport] = {}
    for src in sources:
        built[src] = _TRANSPORTS[src].from_config(**config)
    return built
def encode_jwt(
    payload: dict[str, Any],
    *,
    secret_env: str,
    algorithm: str,
    ttl: datetime.timedelta = DEFAULT_TOKEN_TTL,
) -> str:
    """Sign *payload* as a JWT and return the encoded token.

    An ``exp`` claim of ``now (UTC) + ttl`` is stamped when *payload* has
    no ``exp`` **or** has ``exp=None``.  The second case is what
    ``model_dump`` yields for a session schema with an optional, unset
    ``exp`` field.  ``setdefault`` would keep that ``None`` and mint a
    token with a ``null`` expiry.  *payload* itself is never mutated.

    Raises:
        KeyError: if the environment variable named *secret_env* is unset.
    """
    claims = dict(payload)  # shallow copy so the caller's dict is untouched
    if claims.get("exp") is None:
        claims["exp"] = datetime.datetime.now(tz=datetime.UTC) + ttl
    return jwt.encode(claims, os.environ[secret_env], algorithm=algorithm)
def decode_jwt(
    token: str,
    *,
    secret_env: str,
    algorithm: str,
) -> dict[str, Any]:
    """Verify *token* and return its claims.

    Only *algorithm* is accepted.  Every failure becomes HTTP 401:
    - the secret env var is unset or empty;
    - the token fails :mod:`jwt` validation (bad signature, expired,
      malformed, ...).
    """
    secret = os.environ.get(secret_env)
    if not secret:
        # Fail closed: with no secret nothing can be verified.
        raise _unauthorized()
    try:
        claims = jwt.decode(token, secret, algorithms=[algorithm])
    except jwt.InvalidTokenError as exc:
        raise _unauthorized() from exc
    return claims
[docs] def session_auth[SessionT: BaseModel]( schema: type[SessionT], sources: Sequence[Source], *, secret_env: str, algorithm: str, token_url: str | None = None, cookie_name: str | None = None, store: SessionStore | None = None, ) -> Callable[..., Awaitable[SessionT]]: """Build a FastAPI dep that yields a validated *schema* instance. The returned callable takes one parameter per supported transport; configured sources plug in their real extractors, unconfigured ones get a no-token shim (returns ``None``). The first non-``None`` token wins. Claims parse through :meth:`~pydantic.BaseModel.model_validate` so handlers get the full model, not a raw dict. *store*, when supplied, turns every authenticated request into a deny-list check -- avoids a wrapper dep on the consumer side. """ transports = _build_transports( sources, token_url=token_url, cookie_name=cookie_name, ) if "bearer" in sources and token_url is None: msg = "token_url is required when 'bearer' is in sources" raise ValueError(msg) bearer_transport = transports.get("bearer") cookie_transport = transports.get("cookie") bearer_ext = ( bearer_transport.extract_dep() if bearer_transport else _no_token ) cookie_ext = ( cookie_transport.extract_dep() if cookie_transport else _no_token ) async def resolve(token: str | None) -> SessionT: if token is None: raise _unauthorized() claims = decode_jwt(token, secret_env=secret_env, algorithm=algorithm) session = schema.model_validate(claims) if store is not None and await store.is_revoked(session): raise _revoked() return session async def get_session( bearer: Annotated[str | None, Depends(bearer_ext)] = None, cookie: Annotated[str | None, Depends(cookie_ext)] = None, ) -> SessionT: return await resolve(bearer or cookie) return get_session
def issue_session(
    response: Response,
    session: BaseModel | None,
    *,
    sources: Sequence[Source],
    secret_env: str,
    algorithm: str,
    ttl: datetime.timedelta = DEFAULT_TOKEN_TTL,
    cookie_name: str | None = None,
    cookie_secure: bool = True,
    cookie_samesite: SameSite = "lax",
) -> LoginResponse | OkResponse:
    """Mint a JWT for *session* and hand it to every configured transport.

    *session* is ``None`` when the upstream ``validate_login`` rejected the
    attempt (no such user, failed password check, ...).  That is answered
    with HTTP 401 ``Invalid credentials``.

    Returns the :class:`LoginResponse` produced by the bearer transport
    when ``"bearer"`` is among *sources*, otherwise an :class:`OkResponse`.
    """
    if session is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Build (and thereby validate) the transports before touching the secret.
    transports = _build_transports(
        sources,
        cookie_name=cookie_name,
        cookie_secure=cookie_secure,
        cookie_samesite=cookie_samesite,
    )
    claims = session.model_dump(mode="json")
    token = encode_jwt(
        claims,
        secret_env=secret_env,
        algorithm=algorithm,
        ttl=ttl,
    )

    body: LoginResponse | OkResponse | None = None
    for transport in transports.values():
        emitted = transport.emit(response, token, ttl)
        if emitted:
            body = emitted
    return OkResponse() if body is None else body
def clear_session(
    response: Response,
    *,
    sources: Sequence[Source],
    cookie_name: str | None = None,
    cookie_secure: bool = True,
    cookie_samesite: SameSite = "lax",
) -> OkResponse:
    """Log out: delete the session cookie if configured, then acknowledge.

    For bearer there is nothing to undo server-side.  *cookie_secure* and
    *cookie_samesite* must equal the values :func:`issue_session` used,
    because browsers won't overwrite an existing cookie when either differs.
    """
    config = {
        "cookie_name": cookie_name,
        "cookie_secure": cookie_secure,
        "cookie_samesite": cookie_samesite,
    }
    for transport in _build_transports(sources, **config).values():
        transport.clear(response)
    return OkResponse()