import asyncio
import inspect
import os
from pathlib import Path
__all__ = ["API"]
import uvicorn
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.errors import ServerErrorMiddleware
from starlette.middleware.exceptions import ExceptionMiddleware
from starlette.middleware.gzip import GZipMiddleware
from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.trustedhost import TrustedHostMiddleware
from . import status_codes
from .background import BackgroundQueue
from .formats import get_formats
from .models import Request, Response
from .routes import Router
from .staticfiles import StaticFiles
from .statics import DEFAULT_CORS_PARAMS, DEFAULT_OPENAPI_THEME, DEFAULT_SECRET_KEY
from .templates import Templates
[docs]
class API:
    """The primary web-service class.

    Only a subset of the constructor options is listed here; the
    ``__init__`` docstring documents every keyword argument.

    :param static_dir: The directory to use for static files. Will be created for you if it doesn't already exist.
    :param templates_dir: The directory to use for templates. Will be created for you if it doesn't already exist.
    :param auto_escape: If ``True``, HTML and XML templates will automatically be escaped.
    :param enable_hsts: If ``True``, redirect all HTTP requests to HTTPS.
    :param gzip: If ``True`` (the default), compress responses with GZip.
    :param openapi_theme: OpenAPI documentation theme, must be one of ``elements``, ``rapidoc``, ``redoc``, ``swagger_ui``
    """  # noqa: E501

    # Re-export the package's ``status_codes`` on the class so it can be
    # reached through an instance or the class (e.g. ``API.status_codes``).
    status_codes = status_codes
def __init__(
    self,
    *,
    debug=False,
    title=None,
    version=None,
    description=None,
    terms_of_service=None,
    contact=None,
    license=None,  # noqa: A002
    openapi=None,
    openapi_route="/schema.yml",
    static_dir="static",
    static_route="/static",
    templates_dir="templates",
    auto_escape=True,
    secret_key=DEFAULT_SECRET_KEY,
    enable_hsts=False,
    docs_route=None,
    cors=False,
    cors_params=DEFAULT_CORS_PARAMS,
    allowed_hosts=None,
    openapi_theme=DEFAULT_OPENAPI_THEME,
    lifespan=None,
    gzip=True,
    request_id=False,
    enable_logging=False,
):
    """Create a new Responder API instance.

    All arguments are keyword-only.

    :param debug: If ``True``, enable debug mode with verbose error pages.
    :param title: The title of the API, used in OpenAPI documentation.
    :param version: The version string for the API (e.g. ``"1.0"``).
    :param description: A longer description of the API for OpenAPI docs.
    :param terms_of_service: URL to the API's terms of service.
    :param contact: Contact information dict for the API (``name``, ``url``, ``email``).
    :param license: License information dict (``name``, ``url``).
    :param openapi: The OpenAPI version string (e.g. ``"3.0.2"``). Enables OpenAPI schema generation.
    :param openapi_route: The URL path for the OpenAPI schema (default ``"/schema.yml"``).
    :param static_dir: Directory for static files. Set to ``None`` to disable. Created automatically if missing.
    :param static_route: URL prefix for serving static files (default ``"/static"``).
        ``None`` is treated as ``""`` when static files are enabled.
    :param templates_dir: Directory for Jinja2 templates (default ``"templates"``).
    :param auto_escape: If ``True``, auto-escape HTML/XML in templates.
    :param secret_key: Secret key for signing cookie-based sessions. **Always set this in production.**
    :param enable_hsts: If ``True``, redirect all HTTP requests to HTTPS.
    :param docs_route: URL path for interactive API docs (e.g. ``"/docs"``). Enables OpenAPI if not already set.
    :param cors: If ``True``, enable CORS middleware.
    :param cors_params: Dict of CORS configuration (``allow_origins``, ``allow_methods``, etc.).
    :param allowed_hosts: List of allowed hostnames (e.g. ``["example.com"]``). Defaults to ``["*"]``.
    :param openapi_theme: Documentation UI theme: ``"swagger_ui"``, ``"redoc"``, ``"rapidoc"``, or ``"elements"``.
    :param lifespan: An async context manager for startup/shutdown logic.
    :param gzip: If ``True`` (the default), compress responses with GZip.
    :param request_id: If ``True``, add ``X-Request-ID`` headers to all responses.
    :param enable_logging: If ``True``, enable structured logging with per-request context (request ID, method, path, client IP).
    :raises ImportError: If ``openapi`` or ``docs_route`` is given but the
        OpenAPI extension cannot be imported.
    """  # noqa: E501
    self.background = BackgroundQueue()
    self.secret_key = secret_key
    self.router = Router(lifespan=lifespan)

    # Static files are optional: with ``static_dir=None`` the path is not
    # resolved here and nothing is created or mounted further down.
    if static_dir is not None:
        if static_route is None:
            static_route = ""
        static_dir = Path(static_dir).resolve()
    self.static_dir = static_dir
    self.static_route = static_route

    self.hsts_enabled = enable_hsts
    self.cors = cors
    self.cors_params = cors_params
    self.debug = debug

    # A missing or empty host list falls back to accepting every host.
    if not allowed_hosts:
        allowed_hosts = ["*"]
    self.allowed_hosts = allowed_hosts

    if self.static_dir is not None:
        self.static_dir.mkdir(parents=True, exist_ok=True)
        self.mount(self.static_route, self.static_app)

    self.formats = get_formats()
    # Cached test client; populated on the first ``session()`` call.
    self._session = None
    self.default_endpoint = None

    # Build the ASGI stack. ``add_middleware`` wraps the current ``self.app``,
    # so every later call becomes an outer layer around the earlier ones.
    # The order of the calls below is therefore significant.
    self.app = ExceptionMiddleware(self.router, debug=debug)
    if gzip:
        self.add_middleware(GZipMiddleware)
    if self.hsts_enabled:
        self.add_middleware(HTTPSRedirectMiddleware)
    self.add_middleware(TrustedHostMiddleware, allowed_hosts=self.allowed_hosts)
    if self.cors:
        self.add_middleware(CORSMiddleware, **self.cors_params)
    self.add_middleware(ServerErrorMiddleware, debug=debug)
    self.add_middleware(SessionMiddleware, secret_key=self.secret_key)

    # ``self.openapi`` only exists when OpenAPI was requested; other methods
    # probe for it with ``hasattr``.
    if openapi or docs_route:
        try:
            from .ext.openapi import OpenAPISchema
        except ImportError as ex:
            raise ImportError(
                "The dependencies for the OpenAPI extension are not installed. "
                "Install them using: pip install responder"
            ) from ex
        self.openapi = OpenAPISchema(
            app=self,
            title=title,
            version=version,
            openapi=openapi,
            docs_route=docs_route,
            description=description,
            terms_of_service=terms_of_service,
            contact=contact,
            license=license,
            openapi_route=openapi_route,
            static_route=static_route,
            openapi_theme=openapi_theme,
        )

    # NOTE(review): ``auto_escape`` is accepted by this constructor but is not
    # referenced anywhere in it, including here -- confirm whether it should
    # be forwarded to ``Templates``.
    self.templates = Templates(directory=templates_dir)

    # Standalone request-ID support: echo the caller's ``X-Request-ID`` header
    # or generate a UUID4. Skipped when logging is enabled.
    # NOTE(review): presumably ``LoggingMiddleware`` provides the request ID
    # in that case -- confirm.
    if request_id and not enable_logging:
        import uuid as _uuid

        def _add_request_id(req, resp):
            rid = req.headers.get("X-Request-ID", str(_uuid.uuid4()))
            resp.headers["X-Request-ID"] = rid

        self.router.after_request(_add_request_id)

    if enable_logging:
        import logging as _logging

        from .ext.logging import LoggingMiddleware, get_logger, setup_logging

        # Debug mode lowers the log threshold to DEBUG; otherwise INFO.
        log_level = _logging.DEBUG if debug else _logging.INFO
        setup_logging(level=log_level)
        # Added last, so it is the outermost layer of the stack built above.
        self.add_middleware(LoggingMiddleware)
        self.log = get_logger("responder.app")
    else:
        import logging as _logging

        # Plain stdlib logger so ``self.log`` is always available.
        self.log = _logging.getLogger("responder.app")
@property
def requests(self):
    """Return the test client bound to this application.

    Delegates to ``session()`` with its default arguments.
    """
    client = self.session()
    return client
@property
def static_app(self):
    """Return the ``StaticFiles`` application for ``static_dir``.

    The application is built on first access and cached on the instance as
    ``_static_app``; later accesses return the cached object.
    """
    try:
        return self._static_app
    except AttributeError:
        # Not built yet: fall through and create it.
        pass
    assert self.static_dir is not None
    self._static_app = StaticFiles(directory=self.static_dir)
    return self._static_app
[docs]
def before_request(self, websocket=False):
    """Decorator that registers a hook to run before every request.

    If the hook sets ``resp.status_code``, the route handler is skipped
    and the response is sent immediately (short-circuiting).

    :param websocket: If ``True``, register as a WebSocket before-request hook instead of HTTP.

    The decorated function is handed to the router and returned unchanged.

    Usage::

        @api.before_request()
        def check_auth(req, resp):
            if "Authorization" not in req.headers:
                resp.status_code = 401
                resp.media = {"error": "unauthorized"}
    """  # noqa: E501

    def register(hook):
        self.router.before_request(hook, websocket=websocket)
        return hook

    return register
[docs]
def after_request(self):
    """Decorator that registers a hook to run after every request.

    The decorated function is handed to the router and returned unchanged.

    Usage::

        @api.after_request()
        def add_request_id(req, resp):
            resp.headers["X-Request-ID"] = str(uuid.uuid4())
    """

    def register(hook):
        self.router.after_request(hook)
        return hook

    return register
[docs]
def add_middleware(self, middleware_cls, **middleware_config):
    """Wrap the application in another layer of ASGI middleware.

    The current ``self.app`` is passed to ``middleware_cls`` as its first
    argument and the result replaces ``self.app``. Each call therefore adds
    an outer layer: the last middleware added runs first.

    :param middleware_cls: A Starlette-compatible middleware class.
    :param middleware_config: Keyword arguments passed to the middleware constructor.

    Usage::

        from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware
        api.add_middleware(HTTPSRedirectMiddleware)
    """
    wrapped = self.app
    self.app = middleware_cls(wrapped, **middleware_config)
[docs]
def exception_handler(self, exception_cls):
    """Decorator that registers a handler for a specific exception type.

    The handler is called as ``func(req, resp, exc)`` and may be a plain
    function or a coroutine function. If it leaves ``resp.status_code``
    unset, the response is sent with status 500.

    :param exception_cls: The exception class the handler is registered for.

    Usage::

        @api.exception_handler(ValueError)
        async def handle_value_error(req, resp, exc):
            resp.status_code = 400
            resp.media = {"error": str(exc)}
    """

    def register(func):
        async def _handler(request, exc):
            from starlette.responses import Response as StarletteResp

            # Rebuild Responder's request/response pair from the raw request.
            req = Request(request.scope, request.receive, formats=get_formats())
            resp = Response(req=req, formats=get_formats())

            outcome = func(req, resp, exc)
            if inspect.iscoroutinefunction(func):
                await outcome

            if resp.status_code is None:
                resp.status_code = 500
            body, headers = await resp.body
            return StarletteResp(
                content=body, status_code=resp.status_code, headers=headers
            )

        # Record the handler on the router, creating the registry on first use.
        registry = getattr(self.router, "_exception_handlers", {})
        registry[exception_cls] = _handler
        self.router._exception_handlers = registry

        # Walk inwards through the ``.app`` chain to the first
        # ExceptionMiddleware layer (if any) and register the handler there too.
        layer = self.app
        while layer is not None and not isinstance(layer, ExceptionMiddleware):
            layer = getattr(layer, "app", None)
        if layer is not None:
            layer.add_exception_handler(exception_cls, _handler)

        return func

    return register
[docs]
def schema(self, name, **options):
    """Decorator that registers a schema with the OpenAPI extension.

    The decorated object is passed to ``self.openapi.add_schema`` under
    ``name`` and returned unchanged.

    :param name: The name to register the schema under.
    :param options: Extra keyword arguments forwarded to ``add_schema``.
    :raises AttributeError: When the decorator is applied to an API that was
        created without ``openapi`` or ``docs_route``, because the OpenAPI
        extension is only attached when one of them is given.

    Usage::

        from marshmallow import Schema, fields

        @api.schema("Pet")
        class PetSchema(Schema):
            name = fields.Str()
    """

    def decorator(f):
        openapi = getattr(self, "openapi", None)
        if openapi is None:
            # Same exception type as the bare attribute lookup would raise,
            # but with a message that says how to fix it.
            raise AttributeError(
                "OpenAPI support is not enabled; create the API with "
                "`openapi=` or `docs_route=` before using `schema()`."
            )
        openapi.add_schema(name=name, schema=f, **options)
        return f

    return decorator
[docs]
def path_matches_route(self, path):
    """Return the first registered route that matches ``path``, or ``None``.

    Routes are tried in the order the router lists them.

    :param path: The path portion of a URL, to test all known routes against.
    """  # noqa: E501 (Line too long)

    def _hits(candidate):
        matched, _ = candidate.matches(path)
        return matched

    return next(filter(_hits, self.router.routes), None)
[docs]
def add_route(
    self,
    route=None,
    endpoint=None,
    *,
    default=False,
    static=True,
    check_existing=True,
    websocket=False,
    before_request=False,
    methods=None,
):
    """Adds a route to the API.

    :param route: A string representation of the route.
    :param endpoint: The endpoint for the route -- can be a callable, or a class.
    :param default: If ``True``, all unknown requests will route to this view.
    :param static: If ``True``, and no endpoint was passed, render "static/index.html".
                   Also, it will become a default route.
    :param check_existing: Forwarded to the router unchanged.
    :param websocket: Forwarded to the router unchanged.
    :param before_request: Forwarded to the router unchanged.
    :param methods: Optional list of HTTP methods (e.g. ``["GET", "POST"]``).
    :raises AssertionError: If the static fallback is needed (``static`` is
        true and no endpoint was given) but static files are disabled.
    """  # noqa: E501
    # The static directory is only required when falling back to the static
    # index page. Checking it unconditionally made every route registration
    # fail on an API created with ``static_dir=None``, because ``static``
    # defaults to ``True``.
    if static and not endpoint:
        assert self.static_dir is not None, (
            "a static fallback route requires `static_dir` to be set"
        )
        endpoint = self._static_response
        default = True

    self.router.add_route(
        route,
        endpoint,
        default=default,
        websocket=websocket,
        before_request=before_request,
        check_existing=check_existing,
        methods=methods,
    )
async def _static_response(self, req, resp):
    """Respond with ``index.html`` from the static directory, or a 404.

    When the file exists its text becomes ``resp.html``; otherwise the
    response gets status 404 and the text ``"Not found."``.
    """
    assert self.static_dir is not None
    index = (self.static_dir / "index.html").resolve()
    if not index.exists():
        resp.status_code = status_codes.HTTP_404  # type: ignore[attr-defined]
        resp.text = "Not found."
        return
    resp.html = index.read_text()
[docs]
def redirect(
    self,
    resp,
    location,
    *,
    set_text=True,
    status_code=status_codes.HTTP_301,  # type: ignore[attr-defined]
):
    """Turn ``resp`` into a redirect to ``location``.

    Thin wrapper around ``resp.redirect``; the response is mutated in place
    and nothing is returned.

    :param resp: The Response to mutate.
    :param location: The location of the redirect.
    :param set_text: If ``True``, sets the Redirect body content automatically.
    :param status_code: an `API.status_codes` attribute, or an integer,
                        representing the HTTP status code of the redirect.
    """
    redirect_options = {"set_text": set_text, "status_code": status_code}
    resp.redirect(location, **redirect_options)
[docs]
def on_event(self, event_type: str, **args):
    """Decorator that registers a function or coroutine for a lifecycle event.

    Supported events: startup, shutdown

    The decorated callable is passed to ``add_event_handler`` together with
    any extra keyword arguments, and is returned unchanged.

    NOTE(review): ``add_event_handler`` in this class accepts no extra
    keyword arguments, so passing any here raises ``TypeError`` when the
    decorator is applied.

    Usage::

        @api.on_event('startup')
        async def open_database_connection_pool():
            ...

        @api.on_event('shutdown')
        async def close_database_connection_pool():
            ...
    """

    def register(handler):
        self.add_event_handler(event_type, handler, **args)
        return handler

    return register
[docs]
def add_event_handler(self, event_type, handler):
    """Register ``handler`` with the router for the given event.

    :param event_type: A string in ("startup", "shutdown")
    :param handler: The function to run. Can be either a function or a coroutine.
    """
    register = self.router.add_event_handler
    register(event_type, handler)
[docs]
def route(self, route=None, *, request_model=None, response_model=None, **options):
    """Decorator that registers a function or class as a route.

    When ``request_model`` / ``response_model`` are given they are stored on
    the endpoint as ``_request_model`` / ``_response_model``. If this API has
    an ``openapi`` attribute, each model is also added to it as a schema
    named after the model class. Remaining keyword arguments are forwarded
    to ``add_route``.

    Usage::

        @api.route("/hello")
        def hello(req, resp):
            resp.text = "hello, world!"

    With Pydantic models for OpenAPI documentation::

        from pydantic import BaseModel

        class ItemIn(BaseModel):
            name: str
            price: float

        class ItemOut(BaseModel):
            id: int
            name: str
            price: float

        @api.route("/items", methods=["POST"],
                   request_model=ItemIn, response_model=ItemOut)
        async def create_item(req, resp):
            data = await req.media()
            resp.media = {"id": 1, **data}
    """

    def register(endpoint):
        # Request model first, then response model.
        attached_models = (
            ("_request_model", request_model),
            ("_response_model", response_model),
        )
        for attr_name, model in attached_models:
            if model is None:
                continue
            setattr(endpoint, attr_name, model)
            if hasattr(self, "openapi"):
                self.openapi.add_schema(model.__name__, model, check_existing=False)
        self.add_route(route, endpoint, **options)
        return endpoint

    return register
[docs]
def graphql(self, route="/graphql", *, schema):
    """Mount a GraphQL API at the given route.

    A ``GraphQLView`` is built for ``schema`` and registered with
    ``add_route``.

    Usage::

        import graphene

        class Query(graphene.ObjectType):
            hello = graphene.String(name=graphene.String(default_value="stranger"))

            def resolve_hello(self, info, name):
                return f"Hello {name}"

        api.graphql("/graphql", schema=graphene.Schema(query=Query))

    :param route: The URL path for the GraphQL endpoint.
    :param schema: A Graphene schema instance.
    """
    from .ext.graphql import GraphQLView

    view = GraphQLView(api=self, schema=schema)
    self.add_route(route, view)
[docs]
def mount(self, route, app):
    """Mount a WSGI / ASGI application at a given route.

    The application is recorded in the router's ``apps`` mapping under
    ``route``.

    :param route: String representation of the route to be used
                  (shouldn't be parameterized).
    :param app: The other WSGI / ASGI app.
    """
    mounted = {route: app}
    self.router.apps.update(mounted)
[docs]
def session(self, base_url="http://;"):
    """Return the testing HTTP client for this application.

    The client is a Starlette ``TestClient`` wrapping this API. It is
    created on the first call and cached, so ``base_url`` only takes effect
    on that first call.

    :param base_url: The base URL for the test client.
    """
    if self._session is not None:
        return self._session

    from starlette.testclient import TestClient

    self._session = TestClient(self, base_url=base_url)
    return self._session
[docs]
def url_for(self, endpoint, **params):
    """Return the rendered URL for an endpoint's route.

    Delegates to the router's ``url_for``.

    :param endpoint: The route endpoint you're searching for.
    :param params: Data to pass into the URL generator (for parameterized URLs).
    """
    resolve = self.router.url_for
    return resolve(endpoint, **params)
[docs]
def template(self, filename, *args, **kwargs):
    r"""Render a Jinja2 template file and return the result.

    Delegates to ``self.templates.render``.

    :param filename: The filename of the jinja2 template, in ``templates_dir``.
    :param \*args: Data to pass into the template.
    :param \*\*kwargs: Data to pass into the template.
    """
    render = self.templates.render
    return render(filename, *args, **kwargs)
[docs]
def template_string(self, source, *args, **kwargs):
    r"""Render a Jinja2 template string and return the result.

    Delegates to ``self.templates.render_string``.

    :param source: The template to use, a Jinja2 template string.
    :param \*args: Data to pass into the template.
    :param \*\*kwargs: Data to pass into the template.
    """
    render_string = self.templates.render_string
    return render_string(source, *args, **kwargs)
[docs]
def serve(self, *, address=None, port=None, debug=False, **options):
    """
    Run the application with uvicorn.

    If the ``PORT`` environment variable is set, its value is used as the port
    (overriding ``port``) and the address defaults to ``0.0.0.0`` unless one
    was given explicitly.

    :param address: The address to bind to. Defaults to ``127.0.0.1`` when ``PORT`` is not set.
    :param port: The port to bind to. Defaults to ``5042`` when neither this nor ``PORT`` is provided.
    :param debug: Whether to run application in debug mode. When true, ``log_level`` is forced to ``"debug"``.
    :param options: Additional keyword arguments to send to ``uvicorn.run()``.
    """  # noqa: E501
    env_port = os.environ.get("PORT")
    if env_port is not None:
        if address is None:
            address = "0.0.0.0"  # noqa: S104
        port = int(env_port)

    address = "127.0.0.1" if address is None else address
    port = 5042 if port is None else port

    if debug:
        options["log_level"] = "debug"

    uvicorn.run(self, host=address, port=port, **options)
[docs]
def run(self, **kwargs):
    """Run the application via ``serve``, defaulting ``debug`` to ``self.debug``.

    An explicit ``debug`` keyword takes precedence over the instance setting.

    :param kwargs: Keyword arguments passed through to ``serve``.
    """  # noqa: E501
    forwarded = kwargs if "debug" in kwargs else {**kwargs, "debug": self.debug}
    self.serve(**forwarded)
[docs]
def group(self, prefix):
    """Return a ``RouteGroup`` bound to this API with a shared URL prefix.

    :param prefix: The URL prefix shared by the group's routes.

    Usage::

        v1 = api.group("/v1")

        @v1.route("/users")
        def list_users(req, resp):
            resp.media = []

        @v1.route("/users/{id:int}")
        def get_user(req, resp, *, id):
            resp.media = {"id": id}
    """
    route_group = RouteGroup(prefix=prefix, api=self)
    return route_group
async def __call__(self, scope, receive, send):
    """ASGI entry point: hand the call to the wrapped application stack."""
    asgi_app = self.app
    await asgi_app(scope, receive, send)
[docs]
class RouteGroup:
    """A group of routes with a shared URL prefix.

    :param api: The API object routes are registered on.
    :param prefix: URL prefix for the group; trailing slashes are stripped.
    """

    def __init__(self, api, prefix):
        self.api = api
        self.prefix = prefix.rstrip("/")

    def route(self, route=None, **options):
        """Return the API's route decorator for ``prefix + route``.

        :param route: Path appended to the group prefix. When omitted, the
            prefix itself is used (or ``None`` if the prefix is empty), so
            the literal text ``"None"`` never ends up in the path.
        :param options: Keyword arguments forwarded to the API's ``route``.
        """
        if route is None:
            # Formatting ``None`` into the f-string would produce "<prefix>None".
            full_route = self.prefix or None
        else:
            full_route = f"{self.prefix}{route}"
        return self.api.route(full_route, **options)

    def before_request(self, **kwargs):
        """Return the API's ``before_request`` decorator.

        This delegates straight to the API without passing the prefix, so
        the hook is registered on the API rather than scoped to this group.
        """
        return self.api.before_request(**kwargs)