"""Rate-limiting primitives for kiln-generated FastAPI projects.
This module's runtime dependency on ``slowapi`` and ``limits`` is
gated behind the ``rate-limit`` extra. Install with::
pip install 'kiln-generator[rate-limit]'
# or: uv add 'kiln-generator[rate-limit]'
The pieces:
* :class:`RateLimitBucketMixin` -- a SQLAlchemy mixin supplying the
three columns every counter row needs (``key``, ``hits``,
``expires_at``). Same idiom as :class:`fsh_lib.files.FileMixin`:
the consumer subclasses it on their own model so they own the
table and we own the columns.
* :class:`PostgresStorage` -- a ``limits``-compatible synchronous
storage backend backed by a small dedicated SQLAlchemy engine.
``slowapi``'s enforcement path calls ``limiter.hit(...)``
synchronously (not awaited) so an async storage cannot satisfy
it; we use a separate sync engine targeting the same Postgres
database the rest of the app talks to.
* :func:`build_limiter` -- factory that constructs a slowapi
``slowapi.Limiter`` and wires our :class:`PostgresStorage`
in as its backing store, swapping out the placeholder
``memory://`` storage slowapi creates internally.
* :func:`default_key_func` -- the per-request rate-limit key
callable used by default (client IP).
"""
from __future__ import annotations
import datetime as _dt
from typing import TYPE_CHECKING, Any
from limits.storage import Storage
from limits.strategies import STRATEGIES
from slowapi import Limiter
from sqlalchemy import (
BigInteger,
DateTime,
String,
case,
create_engine,
delete,
select,
text,
)
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Mapped, mapped_column, sessionmaker
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
from fastapi import Request
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
[docs]
class RateLimitBucketMixin:
    """Column mixin for a rate-limit counter table.

    Mix it into a model declared on your own SQLAlchemy ``Base`` to
    pick up the storage columns:

    .. code-block:: python

        from fsh_lib.rate_limit import RateLimitBucketMixin

        class RateLimitBucket(Base, RateLimitBucketMixin):
            __tablename__ = "rate_limit_buckets"

    In contrast to :class:`fsh_lib.files.FileMixin`, :attr:`key` (the
    limit identifier derived from the ``key_func`` and the limit
    string) is the natural primary key, so it is declared with
    ``primary_key=True`` and the consumer needs no separate PK plugin
    to use this mixin.

    Migrating the table is the consumer's job; ``be`` doesn't
    generate Alembic migrations.
    """

    key: Mapped[str] = mapped_column(String(512), primary_key=True)
    """Rate-limit key, assembled by ``slowapi`` from the route, the
    ``key_func`` output, and the limit string."""

    hits: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    """Number of hits counted in the current window."""

    expires_at: Mapped[_dt.datetime] = mapped_column(
        DateTime(timezone=True),
        index=True,
        nullable=False,
    )
    """End of the current window.  A row whose ``expires_at`` is
    earlier than ``now()`` is stale and is reset by the next hit."""
[docs]
def default_key_func(request: Request) -> str:
    """Return the client IP as the rate-limit key, or ``unknown``.

    This is the key callable used when
    :attr:`~be.config.schema.RateLimitConfig.key_func` is left
    unset.  It intentionally trusts no request header, so behind a
    trusted proxy you will almost certainly want to configure a
    ``key_func`` that reads ``X-Forwarded-For`` instead.
    """
    client = request.client
    return "unknown" if client is None else client.host
[docs]
class PostgresStorage(Storage):
    """Synchronous ``limits`` storage that keeps counters in Postgres.

    ``slowapi`` enforces limits by calling ``limiter.hit(...)``
    without awaiting it, which rules out an async storage.  This
    backend therefore runs on its own synchronous SQLAlchemy engine
    aimed at the same Postgres database as the rest of the app: a
    separate connection pool over the same data.

    Every hit is a single ``INSERT ... ON CONFLICT DO UPDATE``:

    * no row for the key yet -- a row is inserted;
    * row with a live window --
      :attr:`~RateLimitBucketMixin.hits` is incremented;
    * row with an expired window --
      :attr:`~RateLimitBucketMixin.hits` restarts at ``amount`` and
      :attr:`~RateLimitBucketMixin.expires_at` moves forward.
    """

    STORAGE_SCHEME = ["postgres-rate-limit"]  # noqa: RUF012
    """URI scheme under which this storage registers.

    Instantiation never goes through it -- :func:`build_limiter`
    builds the storage directly and patches it onto the slowapi
    limiter -- but ``limits`` requires every storage subclass to
    carry the attribute.  It is left without a
    :data:`~typing.ClassVar` annotation to mirror the base
    ``Storage`` declaration; a ``ClassVar`` override would conflict
    with it at type-check time.
    """

    def __init__(
        self,
        *,
        model: type[RateLimitBucketMixin],
        session_maker: Callable[[], Session],
    ) -> None:
        """Create the storage around a bucket model and session factory.

        Args:
            model: Bucket model class supplied by the consumer; it
                must mix in :class:`RateLimitBucketMixin`.
            session_maker: Callable taking no arguments and returning
                a synchronous SQLAlchemy ``Session`` -- usually a
                configured ``sessionmaker``.
        """
        # The ``uri`` argument only matters for the registry-driven
        # construction path, which this storage never takes, so
        # ``None`` is as good a value as any.
        super().__init__(uri=None, wrap_exceptions=False)
        self.model = model
        self._session_maker = session_maker

    @property
    def base_exceptions(self) -> Any:
        """Exception type ``limits`` should regard as a storage failure.

        The base class declares this as
        ``type[Exception] | tuple[type[Exception], ...]``.  A single
        class is returned here, and the annotation is ``Any`` so that
        autodoc does not try to cross-reference ``type`` (which
        collides with three unrelated ``type:`` discriminator fields
        in the be schema).

        Returns :class:`~sqlalchemy.exc.SQLAlchemyError`.
        """
        return SQLAlchemyError

    def _now(self) -> _dt.datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return _dt.datetime.now(tz=_dt.UTC)

    def incr(self, key: str, expiry: int, amount: int = 1) -> int:
        """Add *amount* to *key*, starting a new window if it is stale.

        Args:
            key: Rate-limit key.
            expiry: Length of the window, in seconds.
            amount: Size of the increment (1 unless specified).

        Returns:
            The counter value once the increment has been applied.
        """
        moment = self._now()
        bucket = self.model
        window_end = moment + _dt.timedelta(seconds=expiry)
        fresh_row = pg_insert(bucket).values(
            key=key,
            hits=amount,
            expires_at=window_end,
        )
        # On a key collision the CASE helpers decide between
        # restarting an expired window (counter reset, expiry moved
        # forward) and simply adding to a live one.
        upsert = fresh_row.on_conflict_do_update(
            index_elements=["key"],
            set_={
                "hits": _case_when_expired(bucket, moment, amount),
                "expires_at": _case_when_expired_expiry(
                    bucket, moment, window_end
                ),
            },
        ).returning(bucket.hits)
        with self._session_maker() as session:
            new_total = session.execute(upsert).scalar_one()
            session.commit()
        return int(new_total)

    def get(self, key: str) -> int:
        """Return the counter for *key*; 0 if absent or stale."""
        moment = self._now()
        bucket = self.model
        query = select(bucket.hits, bucket.expires_at).where(
            bucket.key == key
        )
        with self._session_maker() as session:
            row = session.execute(query).first()
        if row is None:
            return 0
        count, window_end = row
        return 0 if window_end < moment else int(count)

    def get_expiry(self, key: str) -> float:
        """Return when the window for *key* ends, as a UNIX timestamp.

        When no row exists the current time is returned.  A timestamp
        in the past is read by ``limits`` as "no active window".
        """
        bucket = self.model
        query = select(bucket.expires_at).where(bucket.key == key)
        with self._session_maker() as session:
            window_end = session.execute(query).scalar_one_or_none()
            moment = self._now() if window_end is None else window_end
            return moment.timestamp()

    def check(self) -> bool:
        """Report whether the database can be reached.

        ``limits`` calls this opportunistically after an earlier call
        has raised.  The probe is a plain ``SELECT 1`` so it stays
        cheap and never touches the bucket table.
        """
        probe = text("SELECT 1")
        try:
            with self._session_maker() as session:
                session.execute(probe)
        except SQLAlchemyError:
            return False
        else:
            return True

    def reset(self) -> int | None:
        """Remove all counter rows and return how many were deleted."""
        wipe = delete(self.model)
        with self._session_maker() as session:
            outcome = session.execute(wipe)
            session.commit()
            # For DML ``execute`` hands back a CursorResult whose
            # ``rowcount`` is the number of affected rows, but its
            # static type is the wider ``Result``.  Going through
            # ``getattr`` sidesteps the type-check false positive
            # while keeping the runtime behaviour.
            deleted: int | None = getattr(outcome, "rowcount", None)
            return deleted

    def clear(self, key: str) -> None:
        """Remove the counter row for *key*; a missing row is a no-op."""
        bucket = self.model
        removal = delete(bucket).where(bucket.key == key)
        with self._session_maker() as session:
            session.execute(removal)
            session.commit()
def _case_when_expired(
    cls: type[RateLimitBucketMixin],
    now: _dt.datetime,
    amount: int,
) -> Any:
    """Build the CASE for ``hits``: restart at ``amount`` or add to it.

    The expression yields ``amount`` when the row's ``expires_at`` is
    earlier than *now*, and ``hits + amount`` otherwise.
    """
    window_is_stale = cls.expires_at < now
    accumulated = cls.hits + amount
    return case((window_is_stale, amount), else_=accumulated)
def _case_when_expired_expiry(
    cls: type[RateLimitBucketMixin],
    now: _dt.datetime,
    new_expiry: _dt.datetime,
) -> Any:
    """Build the CASE for ``expires_at``: move it only when stale.

    The expression yields *new_expiry* when the row's ``expires_at``
    is earlier than *now*, and the stored ``expires_at`` otherwise.
    """
    window_is_stale = cls.expires_at < now
    return case((window_is_stale, new_expiry), else_=cls.expires_at)
[docs]
def build_limiter(
    *,
    model: type[RateLimitBucketMixin],
    sync_url: str | None = None,
    key_func: Callable[[Request], str] | None = None,
    default_limits: Iterable[str] = (),
    headers_enabled: bool = True,
    engine: Engine | None = None,
) -> Limiter:
    """Build a slowapi ``slowapi.Limiter`` backed by Postgres.

    The returned limiter has its ``_storage`` and ``_limiter``
    fields swapped out for our :class:`PostgresStorage` -- slowapi
    constructs a placeholder ``memory://`` storage internally
    because its public API only takes a URI, and we replace it
    rather than going through URI dispatch (the storage needs
    Python objects -- the bucket model and a sessionmaker -- that
    don't round-trip through a URI).

    Args:
        model: The consumer's bucket model class (must mix in
            :class:`RateLimitBucketMixin`).
        sync_url: A *synchronous* Postgres DSN for the
            rate-limit storage.  The app's main async DSN
            (``postgresql+asyncpg://...``) is fine to reuse with
            the ``+asyncpg`` driver tag stripped.  May be omitted
            when *engine* is supplied; it is ignored in that case.
        key_func: Per-request key callable.  Defaults to
            :func:`default_key_func` (client IP).
        default_limits: Iterable of limit strings applied to every
            route that doesn't have its own ``@limiter.limit(...)``.
            A bare string is treated as a single limit string
            rather than being iterated character by character.
        headers_enabled: Whether slowapi emits ``X-RateLimit-*``
            response headers.
        engine: Pre-built sync engine.  Optional escape hatch for
            tests / custom pools; production callers leave it
            ``None`` and let the helper build one from *sync_url*.

    Returns:
        A configured slowapi ``slowapi.Limiter``.

    Raises:
        ValueError: If neither *sync_url* nor *engine* is given.
    """
    if engine is None:
        if sync_url is None:
            msg = "build_limiter() requires either 'sync_url' or 'engine'"
            raise ValueError(msg)
        engine = create_engine(sync_url, future=True, pool_pre_ping=True)

    # A bare string is itself an ``Iterable[str]``; without this
    # guard ``list("5/minute")`` would explode it into one-character
    # "limits".
    if isinstance(default_limits, str):
        default_limits = [default_limits]

    session_maker = sessionmaker(engine, expire_on_commit=False)
    storage = PostgresStorage(model=model, session_maker=session_maker)
    limiter = Limiter(
        key_func=key_func or default_key_func,
        default_limits=list(default_limits),
        headers_enabled=headers_enabled,
        # Placeholder; we replace ``_storage`` below.  slowapi's
        # ``__init__`` insists on building one storage up front.
        storage_uri="memory://",
    )
    # Swap in the real storage and rebuild the strategy that wraps
    # it.  ``fixed-window`` matches slowapi's default strategy.
    limiter._storage = storage  # noqa: SLF001
    limiter._limiter = STRATEGIES["fixed-window"](storage)  # noqa: SLF001
    return limiter