Cancelling hung-up requests in uvicorn

Let’s say I have a FastAPI app. What happens to it when someone calls it up, asks a tough question, and then hangs up? By default, the app keeps processing the request until it completes, even though nobody is left to receive the response. This isn’t always desirable, especially when the request is non-mutating (e.g. GET/HEAD/OPTIONS) and the work is simply wasted.

It turns out that this is fairly straightforward to solve, at least for ASGI apps served by uvicorn, which covers many FastAPI deployments. asyncio calls a protocol’s connection_lost() method (defined on BaseProtocol) when the connection closes, and uvicorn’s HTTP protocol classes implement it, so we can subclass the protocol and hook in there. The only tricky bit is figuring out which tasks to cancel: uvicorn tracks request tasks in a single set shared across the entire server, so we intercept the set.add() calls each protocol instance makes to learn which tasks belong to which connection.
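The interception trick can be tried on its own with nothing but asyncio. What follows is a minimal sketch under stated assumptions, not uvicorn’s code: TrackingSet, mine and demo are hypothetical names, a plain set stands in for uvicorn’s server-wide task set, and asyncio.sleep(10) stands in for a slow request handler. Each wrapper forwards add() to the shared set and also remembers the task in its own WeakSet, so one "connection" can cancel its tasks without touching the other’s.

```python
import asyncio
from weakref import WeakSet


class TrackingSet:
    """Hypothetical miniature of TwoLayerSet: forwards to a shared set but
    also remembers the items this particular wrapper added."""

    def __init__(self, shared: set):
        self._shared = shared
        self.mine = WeakSet()  # per-"connection" view; doesn't keep tasks alive

    def add(self, item):
        self._shared.add(item)
        self.mine.add(item)

    def __getattr__(self, attr):
        # Only called for names not defined on the class, e.g. .discard().
        return getattr(self._shared, attr)


async def demo():
    shared = set()                # stands in for uvicorn's server-wide set
    conn_a = TrackingSet(shared)  # one wrapper per connection
    conn_b = TrackingSet(shared)

    # asyncio.sleep(10) stands in for a slow request handler.
    task_a = asyncio.create_task(asyncio.sleep(10))
    task_b = asyncio.create_task(asyncio.sleep(10))
    for conn, task in ((conn_a, task_a), (conn_b, task_b)):
        conn.add(task)
        task.add_done_callback(conn.discard)  # forwarded to shared.discard

    # Connection A hangs up: cancel only the tasks A added.
    for task in conn_a.mine:
        task.cancel()
    await asyncio.wait([task_a])

    result = (task_a.cancelled(), task_b.done(), len(shared))

    task_b.cancel()  # clean up before the event loop closes
    await asyncio.wait([task_b])
    return result


print(asyncio.run(demo()))  # → (True, False, 1)
```

The shared set still holds the other connection’s task afterwards, which is the property the real TwoLayerSet relies on.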

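Because asyncio delivers cancellation by raising asyncio.CancelledError inside the task (a subclass of BaseException since Python 3.8, so a generic except Exception clause won’t swallow it), you can also catch that exception and take the opportunity to cancel any downstream API calls you may have made. This is an exercise left to the reader.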
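One possible shape for that exercise, as a sketch only: FakeJobService and its submit(), wait_for() and cancel() methods are hypothetical stand-ins for a real downstream API client, and task.cancel() plays the role of connection_lost(). The point is that cancelling our task only stops the local await; a remote job keeps running unless we explicitly tell the service. The handler then re-raises so outer layers (such as a middleware) still see the cancellation.

```python
import asyncio


class FakeJobService:
    """Hypothetical downstream API client; a real one would make HTTP calls."""

    def __init__(self):
        self.cancelled_jobs = []

    async def submit(self, job_id: str) -> None:
        pass  # pretend a long-running job was started remotely

    async def wait_for(self, job_id: str) -> str:
        await asyncio.sleep(10)  # the remote job takes a long time
        return f"result of {job_id}"

    async def cancel(self, job_id: str) -> None:
        self.cancelled_jobs.append(job_id)  # pretend we told the remote side


async def handler(service: FakeJobService, job_id: str) -> str:
    await service.submit(job_id)
    try:
        return await service.wait_for(job_id)
    except asyncio.CancelledError:
        # Our task was cancelled (e.g. the client hung up). That stops the
        # local await, but the remote job would keep running, so ask the
        # service to cancel it, then re-raise so the task stays cancelled.
        await service.cancel(job_id)
        raise


async def main():
    service = FakeJobService()
    task = asyncio.create_task(handler(service, "job-1"))
    await asyncio.sleep(0.1)  # let the handler reach wait_for()
    task.cancel()             # what connection_lost() does in the listing
    try:
        await task
    except asyncio.CancelledError:
        pass
    return service.cancelled_jobs, task.cancelled()


print(asyncio.run(main()))  # → (['job-1'], True)
```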

import asyncio
import sys
from uuid import uuid4
from weakref import WeakSet

import uvicorn
from fastapi import FastAPI
from uvicorn.protocols.http.auto import AutoHTTPProtocol


# Uvicorn uses a single set to track tasks for all of its connections.
# We want to cancel only the ones associated with this connection.
# So we wrap it in a two-layer set that intercepts self.tasks.add() in the
# protocol and also records each task in a per-connection WeakSet.
class TwoLayerSet:
    def __init__(self, underlying: set):
        self._underlying = underlying
        self._overlying = WeakSet()

    def add(self, item):
        self._underlying.add(item)
        self._overlying.add(item)

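    def __getattr__(self, attr):
        # Forward everything else (e.g. discard()) to the shared set. Note that
        # dunder methods such as len() or `in` bypass __getattr__ and are not
        # forwarded; uvicorn's protocols appear to only call add() and discard().
        return getattr(self._underlying, attr)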


class RequestCancellingProtocol(AutoHTTPProtocol):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tasks = TwoLayerSet(self.tasks)

    def connection_lost(self, exc):
        super().connection_lost(exc)
        # uvicorn handles pipelined requests on a connection one at a time, so
        # at most one task in `self.tasks._overlying` should still be running.
        # Calling cancel() on a task that has already finished is a no-op.
        for task in self.tasks._overlying:
            task.cancel()


app = FastAPI()


# Catches the CancelledError raised into cancelled request handlers. This
# mostly exists to prevent ugly stack traces from reaching stderr.
class CancellationMiddleware:
    def __init__(self, app):
        self._app = app

    async def __call__(self, scope, receive, send):
        unique_id = uuid4()
        try:
            if scope["type"] == "http":
                print(f"Request handler {unique_id} began", file=sys.stderr)
            await self._app(scope, receive, send)
            if scope["type"] == "http":
                print(f"Request handler {unique_id} finished", file=sys.stderr)
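        except asyncio.CancelledError:
            # I'd like to catch a more specific error, but it seems like there's no
            # straightforward way to subclass CancelledError when a task is cancelled.
            # The error is deliberately not re-raised: the task is finishing anyway,
            # and swallowing it here is what keeps the stack trace out of stderr.
            print(f"Request handler {unique_id} was cancelled", file=sys.stderr)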


@app.get("/")
async def hello_world():
    await asyncio.sleep(4)
    return "Thank you for waiting. Hello world."


async def main():
    app_ = CancellationMiddleware(app)
    config = uvicorn.Config(app_, http=RequestCancellingProtocol)
    server = uvicorn.Server(config)
    await server.serve()


if __name__ == "__main__":
    raise SystemExit(asyncio.run(main()))
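To watch it work, you need a client that hangs up before the 4-second handler finishes. A minimal sketch, assuming the server is running on uvicorn’s default host and port (127.0.0.1:8000); hang_up_early is a hypothetical helper, and closing the socket after one second simulates the caller hanging up. Something like curl --max-time 1 http://127.0.0.1:8000/ should do the same job.

```python
import asyncio


async def hang_up_early(host: str = "127.0.0.1", port: int = 8000,
                        wait: float = 1.0) -> None:
    """Send GET / and close the connection before the response arrives."""
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(f"GET / HTTP/1.1\r\nHost: {host}\r\n\r\n".encode())
    await writer.drain()
    await asyncio.sleep(wait)  # shorter than the handler's 4-second sleep
    writer.close()             # hang up; the server sees EOF and drops the connection
    await writer.wait_closed()


# Usage, with the server above running in another terminal:
#   asyncio.run(hang_up_early())
```

With the cancelling protocol in place, the server’s stderr should show "began" followed by "was cancelled" about a second later, rather than "finished" after four.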