Cancelling hung-up requests in uvicorn

Let’s say I have a FastAPI app. What happens when someone calls it up, asks a tough question, and then hangs up? By default, a server like uvicorn lets the app keep processing the request until it completes, even though nobody is left to receive the response. This isn’t always desirable, especially when the request is non-mutative (e.g. GET/HEAD/OPTIONS).

It turns out that this is fairly straightforward to solve, at least for ASGI apps hosted in uvicorn. asyncio’s BaseProtocol already defines a connection_lost() callback, which the event loop invokes when the connection closes, so we can just hook into it. The only tricky bit is figuring out which tasks to kill. Uvicorn tracks the tasks for every connection in a single server-wide set, so we intercept each protocol’s calls to that set’s add() method to learn which tasks belong to which connection.
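The same idea can be seen without uvicorn at all. The following is a minimal, stdlib-only sketch, not uvicorn’s code: SERVER_TASKS is a hypothetical stand-in for uvicorn’s server-wide task set, and SlowReplyProtocol is a hypothetical protocol that starts one slow task per chunk of data. Each protocol also records its own tasks in a WeakSet, so connection_lost() can cancel just those.

```python
import asyncio
from weakref import WeakSet

# Hypothetical stand-in for uvicorn's single, server-wide task set.
SERVER_TASKS: set[asyncio.Task] = set()
# Records what happened to each handler, for demonstration purposes.
EVENTS: list[str] = []


class SlowReplyProtocol(asyncio.Protocol):
    """Replies after a delay, unless the client hangs up first."""

    def connection_made(self, transport):
        self.transport = transport
        self.my_tasks: WeakSet[asyncio.Task] = WeakSet()  # this connection only

    def data_received(self, data):
        task = asyncio.get_running_loop().create_task(self.handle(data))
        SERVER_TASKS.add(task)  # server-wide bookkeeping, as uvicorn does
        task.add_done_callback(SERVER_TASKS.discard)
        self.my_tasks.add(task)  # per-connection view used for cancellation

    async def handle(self, data):
        try:
            await asyncio.sleep(2)  # pretend this is expensive work
            self.transport.write(b"done: " + data)
            EVENTS.append("finished")
        except asyncio.CancelledError:
            EVENTS.append("cancelled")
            raise  # keep the task in the cancelled state

    def connection_lost(self, exc):
        # The client went away: cancel only the tasks this connection started.
        for task in self.my_tasks:
            task.cancel()


async def demo() -> list[str]:
    EVENTS.clear()
    loop = asyncio.get_running_loop()
    server = await loop.create_server(SlowReplyProtocol, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    _reader, writer = await asyncio.open_connection("127.0.0.1", port)
    writer.write(b"tough question")
    await writer.drain()
    await asyncio.sleep(0.1)  # let the server start the handler
    writer.close()  # hang up before the reply is ready
    await writer.wait_closed()
    await asyncio.sleep(0.1)  # give connection_lost() and the cancel time to run
    server.close()
    await server.wait_closed()
    return list(EVENTS)


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Keeping the per-connection view in a WeakSet means it never keeps a finished task alive; the strong reference lives in the shared set until the done-callback discards it.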

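Because asyncio delivers cancellation by raising asyncio.CancelledError inside the task (a subclass of BaseException since Python 3.8, so ordinary except Exception blocks won’t swallow it), you can also catch that exception and take the opportunity to cancel any downstream API calls you may have made. This is left as an exercise for the reader.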
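One way to approach that exercise is sketched below, under loud assumptions: FakeDownstream and its start_job, wait_for_result and cancel_job methods are hypothetical stand-ins for whatever remote API your handler actually calls. The handler catches asyncio.CancelledError, asks the downstream service to abandon the job, and then re-raises so the task still ends up cancelled.

```python
import asyncio


class FakeDownstream:
    """Hypothetical remote service; a real one would make network calls."""

    def __init__(self):
        self.cancelled_jobs: list[str] = []

    async def start_job(self, question: str) -> str:
        await asyncio.sleep(0)  # pretend to submit the job
        return f"job-{len(question)}"

    async def wait_for_result(self, job_id: str) -> str:
        await asyncio.sleep(10)  # pretend the job is slow
        return f"result of {job_id}"

    async def cancel_job(self, job_id: str) -> None:
        await asyncio.sleep(0)  # pretend to send a cancel request
        self.cancelled_jobs.append(job_id)


async def handle_request(downstream: FakeDownstream, question: str) -> str:
    job_id = await downstream.start_job(question)
    try:
        return await downstream.wait_for_result(job_id)
    except asyncio.CancelledError:
        # The client hung up: release the downstream work, then keep cancelling.
        await downstream.cancel_job(job_id)
        raise


async def demo() -> list[str]:
    downstream = FakeDownstream()
    task = asyncio.create_task(handle_request(downstream, "tough question"))
    await asyncio.sleep(0.05)  # let the handler start the job
    task.cancel()  # stands in for what connection_lost() does in the script
    try:
        await task
    except asyncio.CancelledError:
        pass
    return downstream.cancelled_jobs


if __name__ == "__main__":
    print(asyncio.run(demo()))  # ['job-14']
```

Re-raising matters: it lets outer layers, such as the CancellationMiddleware in the script below, still see the cancellation. Note that the cleanup call is awaited inside the except block, so a second cancel() arriving during that await would interrupt it as well.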

#!/usr/bin/env python3.12
# Copyright 2024 Josh Snyder
# Licensed under https://www.apache.org/licenses/LICENSE-2.0.txt
import asyncio
import sys
from uuid import uuid4
from weakref import WeakSet

import uvicorn  # 0.34.0
from fastapi import FastAPI  # 0.115.6
from uvicorn.protocols.http.auto import AutoHTTPProtocol


# Uvicorn uses a single set (shared through its server state) to track the
# tasks for all of its connections. We want to cancel only the ones associated
# with this connection, so we wrap that set in a two-layer set that intercepts
# self.tasks.add() in the protocol.
class TwoLayerSet:
    def __init__(self, underlying: set):
        self._underlying = underlying
        self._overlying = WeakSet()

    def add(self, item):
        self._underlying.add(item)
        self._overlying.add(item)

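    # Forward everything else (e.g. the discard() that uvicorn registers as a
    # task done-callback) to the shared set. Dunder methods such as __len__
    # and __iter__ bypass __getattr__, so they are not forwarded.
    def __getattr__(self, attr):
        return getattr(self._underlying, attr)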


class RequestCancellingProtocol(AutoHTTPProtocol):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tasks = TwoLayerSet(self.tasks)

    def connection_lost(self, *args, **kwargs):
        super().connection_lost(*args, **kwargs)
        # `self.tasks._overlying` should hold at most one unfinished task,
        # because uvicorn processes pipelined HTTP requests on a connection
        # one at a time.
        for task in self.tasks._overlying:
            task.cancel()


app = FastAPI()


# This mostly exists to keep ugly stack traces from reaching stderr for every
# request that gets cancelled.
class CancellationMiddleware:
    def __init__(self, app):
        self._app = app

    async def __call__(self, scope, receive, send):
        unique_id = uuid4()
        try:
            if scope["type"] == "http":
                print(f"Request handler {unique_id} began", file=sys.stderr)
            await self._app(scope, receive, send)
            if scope["type"] == "http":
                print(f"Request handler {unique_id} finished", file=sys.stderr)
        except asyncio.CancelledError:
            # I'd like to catch a more specific error, but it seems like there's
            # no straightforward way to subclass CancelledError when a task is
            # cancelled.
            print(f"Request handler {unique_id} was cancelled", file=sys.stderr)


@app.get("/")
async def hello_world():
    await asyncio.sleep(4)
    return "Thank you for waiting. Hello world."


async def main():
    app_ = CancellationMiddleware(app)
    config = uvicorn.Config(app_, http=RequestCancellingProtocol)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    raise SystemExit(asyncio.run(main()))
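To try it out, run the script and make a request that gives up before the four-second handler finishes. The stdlib-only client below is one way to do that from Python; hang_up_early is a hypothetical helper, and its defaults assume uvicorn’s default host and port (127.0.0.1:8000).

```python
import asyncio


async def hang_up_early(host: str = "127.0.0.1", port: int = 8000,
                        path: str = "/", delay: float = 1.0) -> bytes:
    """Send a GET, wait up to `delay` seconds, then close the connection.

    Returns whatever bytes arrived before hanging up (normally b"").
    """
    reader, writer = await asyncio.open_connection(host, port)
    request = f"GET {path} HTTP/1.1\r\nHost: {host}\r\n\r\n"
    writer.write(request.encode("ascii"))
    await writer.drain()
    try:
        # Read whatever the server sends within `delay`, then give up.
        received = await asyncio.wait_for(reader.read(65536), timeout=delay)
    except asyncio.TimeoutError:
        received = b""
    writer.close()  # hang up; the server's connection_lost() should fire
    await writer.wait_closed()
    return received
```

Calling asyncio.run(hang_up_early()) while the server is running should, if everything is wired up as described, make CancellationMiddleware report the request handler as cancelled rather than finished.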