Achieving immortality with Python workers

Let’s say I run a gunicorn server with a typical ASGI app, like so:

$ gunicorn -w 4 -k uvicorn_worker.UvicornWorker app:app

If I send it a request and view its activity under strace, I see this:

[pid 160604] <... epoll_wait resumed>[{events=EPOLLIN, data={u32=5, u64=139977279143941}}], 2, 20000) = 1
[pid 160603] <... epoll_wait resumed>[{events=EPOLLIN, data={u32=5, u64=139977279143941}}], 2, 20000) = 1
[pid 160602] <... epoll_wait resumed>[{events=EPOLLIN, data={u32=5, u64=139977279143941}}], 2, 20000) = 1
[pid 160601] <... epoll_wait resumed>[{events=EPOLLIN, data={u32=5, u64=139977279143941}}], 2, 20000) = 1
[pid 160604] accept4(5,  <unfinished ...>
[pid 160603] accept4(5,  <unfinished ...>
[pid 160602] accept4(5,  <unfinished ...>
[pid 160601] accept4(5,  <unfinished ...>
[pid 160604] <... accept4 resumed>{sa_family=AF_INET, sin_port=htons(43136), sin_addr=inet_addr("127.0.0.1")}, [16], SOCK_CLOEXEC) = 12
[pid 160603] <... accept4 resumed>0x7ffe449e2d50, [16], SOCK_CLOEXEC) = -1 EAGAIN (Resource temporarily unavailable)
[pid 160602] <... accept4 resumed>0x7ffe449e2d50, [16], SOCK_CLOEXEC) = -1 EAGAIN (Resource temporarily unavailable)
[pid 160601] <... accept4 resumed>0x7ffe449e2d50, [16], SOCK_CLOEXEC) = -1 EAGAIN (Resource temporarily unavailable)

All four of the workers were woken by the OS when the request arrived, and three of them went home empty-handed. We did 4x the necessary work to accept this inbound request, and the waste gets worse with each worker we add to the mix. uWSGI has a writeup of how they solve this problem using the --thunder-lock option (a “thundering herd” lock). But is it really necessary to create a highly contended lock for all of the processes to bang against?

It turns out the answer is no. Since most people already put some kind of webserver (e.g. nginx) in front of their Python app, we can have that webserver do least-connections load-balancing across multiple Python workers, each listening on its own socket. This approach works nicely with Unix domain sockets, which are more performant than TCP over loopback and easier to work with.
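As a rough illustration of the webserver side, here is a minimal sketch that renders an nginx upstream block with one Unix socket per worker. The helper name `render_upstream`, the upstream name, and the socket paths are all hypothetical; `least_conn` and `server unix:...` are standard nginx upstream directives.

```python
from pathlib import Path


def render_upstream(name: str, socket_paths: list[Path]) -> str:
    """Render an nginx upstream block that balances across Unix sockets.

    Hypothetical helper: the upstream name and socket paths are illustrative.
    """
    lines = [f"upstream {name} {{", "    least_conn;"]
    # One `server` entry per worker socket
    lines += [f"    server unix:{path};" for path in socket_paths]
    lines.append("}")
    return "\n".join(lines)


# Example paths only; a real deployment would use wherever the workers bind.
upstream_config = render_upstream(
    "python_workers",
    [Path("/run/app/000"), Path("/run/app/001")],
)
print(upstream_config)
```

With one socket per worker, only the worker that owns a socket is woken for a connection to it, so there is no herd to thunder.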

Pre-fork initialization

What if our app takes some effort to boot up, or it needs to preload some data from over the network? Maybe (as a trivial example) someone put a time.sleep() into it.

import time
time.sleep(10)

async def app(scope, receive, send):
    assert scope['type'] == 'http'
    await send({
        'type': 'http.response.start',
        'status': 200,
        'headers': [(b'content-type', b'text/plain')],
    })
    await send({
        'type': 'http.response.body',
        'body': b"Hello world!",
    })

If you run your app with either of the invocations below, you’re going to be hitting that delay on each worker bootup.

$ uvicorn --workers 4 app:app
$ gunicorn -w 4 -k uvicorn_worker.UvicornWorker app:app

In such a situation, your workers are going to become precious to you, since the loss of a worker results in the loss of serving capacity (they have high TTFR, in other words). This regression is solved somewhat nicely by the --preload option to gunicorn’s CLI, or in uvicorn by initializing your app, forking, and then calling uvicorn.Server.serve(). uWSGI does pre-fork initialization by default, but sadly doesn’t appear to support ASGI out of the box.
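To make the idea concrete without depending on gunicorn or uvicorn, here is a stdlib-only sketch of pre-fork initialization. The names `expensive_init` and `run_workers` are hypothetical, and the short sleep stands in for real startup work. The parent pays the startup cost once, and every forked child reports the timestamp it inherited.

```python
import os
import time


def expensive_init() -> dict[str, float]:
    """Stand-in for slow startup work (hypothetical)."""
    time.sleep(0.2)
    return {"loaded_at": time.monotonic()}


def run_workers(count: int) -> list[bytes]:
    state = expensive_init()  # Runs once, in the parent, before any fork
    read_fd, write_fd = os.pipe()
    pids = []
    for index in range(count):
        pid = os.fork()
        if pid == 0:
            # Child: the initialized state is already in memory
            try:
                os.close(read_fd)
                line = f"{index}:{state['loaded_at']}\n".encode()
                os.write(write_fd, line)
            finally:
                os._exit(0)
        pids.append(pid)

    os.close(write_fd)
    for pid in pids:
        os.waitpid(pid, 0)
    with os.fdopen(read_fd, "rb") as reader:
        return reader.read().splitlines()


reports = run_workers(3)
for report in reports:
    print(report.decode())
```

Every child prints the same `loaded_at` value, which shows that the initialization ran once rather than once per worker.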

When we do pre-fork initialization, the initial process from which we fork request-handling workers is sometimes called a “zygote” (1, 2, 3). Zygoting has the following advantages:

  1. faults within a worker are extremely easy to recover from, since you can just launch a new one.
  2. once the zygote finishes its own initialization, it can call gc.freeze to mark all of the objects that exist as of that moment as ineligible for garbage collection. This can significantly cut down on the cost of garbage collection, since such objects are no longer scanned at all.
  3. memory pages in a forked process are shared between parent and child, until one of them performs a write. This can produce substantial savings, assuming a large quantity of shared immutable data. Unfortunately this is mostly untrue in CPython.
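The second point can be observed directly. This is a small sketch, assuming a stock CPython with the `gc.freeze()` family of functions (3.7 or later); the `preloaded` list is a made-up stand-in for data loaded by a zygote.

```python
import gc

# Stand-in for data the zygote loads at startup. Lists are always
# tracked by the cyclic GC, so each of these counts as a GC object.
preloaded = [[i] for i in range(1000)]

gc.collect()  # Clear out any garbage before freezing
gc.freeze()   # Move everything currently tracked to the permanent generation
frozen_count = gc.get_freeze_count()
print(f"{frozen_count:,} objects are now ignored by the tracing GC")

gc.unfreeze()  # Undo, so that this demo leaves the interpreter as it found it
unfrozen_count = gc.get_freeze_count()
```

In a real zygote you would call `gc.freeze()` once, after initialization and before forking, and never unfreeze.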

Garbage collection

In CPython, an object can be garbage collected in one of three ways:

  1. by its reference count falling to zero
  2. by the tracing garbage collector
  3. by the death of the process it resides within
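The first two mechanisms can be told apart with a weak reference. In this sketch the `Node` class is purely illustrative, and automatic collection is disabled so that the tracing GC only runs when asked.

```python
import gc
import weakref


class Node:
    pass


gc.disable()  # Keep the tracing GC from running on its own
try:
    # 1. No cycle: the object dies as soon as its refcount hits zero
    plain = Node()
    plain_ref = weakref.ref(plain)
    del plain
    freed_by_refcount = plain_ref() is None

    # 2. Self-referencing cycle: the refcount never reaches zero, so the
    #    object survives until the tracing GC finds it
    cyclic = Node()
    cyclic.me = cyclic
    cyclic_ref = weakref.ref(cyclic)
    del cyclic
    survived_refcount = cyclic_ref() is not None
    gc.collect()
    freed_by_tracing = cyclic_ref() is None
finally:
    gc.enable()

print(freed_by_refcount, survived_refcount, freed_by_tracing)
```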

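For reference, on my laptop the tracing GC costs roughly 36 nanoseconds per object, and fork costs roughly 19 ms per GiB of address space. These are the figures used in the worked example below.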

The time required to fork is roughly proportional to the size of the process’s total address space, as measured in pages. In the case of a process that was launched from a zygote which had done gc.freeze(), the GC-tracked objects will be only a minority of the process’s total address space in the request-handling (child) process. To give a worked example, let’s assume that a process is initialized with 2GiB of data from its parent process, and has accumulated 512MiB of heap objects that must be scanned by the tracing GC. In that case, the size threshold where forking is cheaper than GC is 509 bytes [2]. If the average GC-tracked object is larger than 509 bytes, then it is cheaper to GC; otherwise it is cheaper to throw away the worker and fork a new one. Some have called this approach the ultimate in garbage collection.
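The arithmetic behind that threshold can be checked in a few lines. The two cost constants are the per-object GC cost and per-GiB fork cost from footnote 2; the function name `break_even_object_size` is my own.

```python
MIB = 2**20
GIB = 2**30

# Costs taken from footnote 2; they will differ on other machines
GC_SECONDS_PER_OBJECT = 36e-9
FORK_SECONDS_PER_GIB = 19e-3


def break_even_object_size(heap_bytes: int, forkable_bytes: int) -> float:
    """Average object size (bytes) at which a GC pass costs as much as a fork."""
    fork_seconds = forkable_bytes / GIB * FORK_SECONDS_PER_GIB
    # GC time is (heap_bytes / object_size) * GC_SECONDS_PER_OBJECT.
    # Setting that equal to fork_seconds and solving for object_size:
    return heap_bytes * GC_SECONDS_PER_OBJECT / fork_seconds


threshold = break_even_object_size(512 * MIB, 2 * GIB)
print(f"break-even object size: {threshold:.1f} bytes")
```

Smaller objects mean more objects per byte of heap, so a GC pass over the same 512MiB gets more expensive while the fork cost stays fixed.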

As a very casual benchmark, I’m observing that the average GC-tracked object size in a Python heap is about 230 bytes. But it’s going to be heavily workload dependent, and you can see for yourself in your own program quite easily:

from collections import defaultdict
import gc
import sys

gc.collect()
gc_objects = gc.get_objects()
total = 0
counts = defaultdict(int)
for obj in gc_objects:
    size = sys.getsizeof(obj)
    counts[size] += 1
    total += size

print("GC-tracked objects")
print("  count: {:,}".format(len(gc_objects)))
print("  total size: {:,} bytes".format(total))
if gc_objects:
    print("  average: {}".format(total / len(gc_objects)))
print("  distribution: {}".format(sorted(counts.items())))

Immortality

I mentioned earlier that there are RAM savings to be had: if neither the child nor the parent process mutates a given page, then the operating system will allow them to share the same copy of that page. But if either of them writes, it will trigger a page fault, and the sharing will end. The gc.freeze() documentation says this:

This requires both avoiding creation of freed “holes” in memory pages in the parent process and ensuring that GC collections in child processes won’t touch the gc_refs counter of long-lived objects originating in the parent process.

Unfortunately the docs are wrong: while gc.freeze() is somewhat good at preventing anything from writing to the gc_refs field [3], it doesn’t change the behavior of Py_INCREF or Py_DECREF, which modify the ob_refcnt (reference count) field. If you have CPython objects which are smaller than a single page (which is typically the case), then simply reading them will result in writes to something approaching 100% of the underlying memory pages.
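You can watch the reference count move from pure Python. This sketch uses a throwaway `Record` class (hypothetical) and `sys.getrefcount`; merely binding a second name to the object, without changing any of its contents, is a write to its header.

```python
import sys


class Record:
    pass


record = Record()
before = sys.getrefcount(record)

# No attribute of `record` is modified here, but taking a new reference
# still increments ob_refcnt inside the object's memory.
alias = record
after = sys.getrefcount(record)

print(f"refcount went from {before} to {after}")
```

Assuming 4 KiB pages and objects of a couple hundred bytes each, a single page holds well over a dozen objects, so touching any one of them is enough to un-share the whole page.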

Instagram’s fork of CPython solves this problem by marking objects as “immortal”, which causes them to be ignored by those functions. The net effect is that the bytes-in-memory never change unless your own code writes to them, and of course they are ineligible for any kind of GC. This program will OOM under unmodified CPython, but will succeed when run on their fork:

#!/usr/bin/env -S systemd-run --user --scope -p MemoryMax=1200M python
import gc
import os
import sys
from typing import Self
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    cdr: Self | None

BIG_DATA = None
for i in range(12 * 1024 * 1024):
    BIG_DATA = Node(BIG_DATA)

gc.freeze()
try:
    gc.immortalize_heap()
except AttributeError:
    print(
        "couldn't immortalize heap: will consume extra memory",
        file=sys.stderr
    )
os.fork()

# As you can see, we never mutate any part of this linked list, but stock
# CPython will mutate its reference counts, which is enough to cause double
# the memory usage.
head = BIG_DATA
while head is not None:
    head = head.cdr

Putting it all together

The synthesis of all of the above is a Python program that launches and monitors its children, with each of those children getting its own unix socket to listen on.

# Copyright 2024 Josh Snyder
# Licensed under https://www.apache.org/licenses/LICENSE-2.0.txt
from functools import partial
from pathlib import Path
from typing import Callable, Iterator, Iterable
import asyncio
import contextlib
import dataclasses
import gc
import os
import shutil
import signal
import socket
import sys
import tempfile
import time
import uvicorn


# -- the webapp --

async def app(scope, receive, send):
    """A basic ASGI app."""

    if scope["type"] == "lifespan":
        while True:
            message = await receive()
            if message["type"] == "lifespan.startup":
                await send({"type": "lifespan.startup.complete"})
            elif message["type"] == "lifespan.shutdown":
                await send({"type": "lifespan.shutdown.complete"})
                # Shutdown has been acknowledged, so leave the lifespan loop
                return

    if scope["type"] == "http":
        await send(
            {
                "type": "http.response.start",
                "status": 200,
                "headers": [(b"content-type", b"text/plain")],
            }
        )
        await send(
            {
                "type": "http.response.body",
                "body": b"Hello world!\n",
            }
        )
        return

    raise NotImplementedError("scope.type", scope["type"])


uvicorn_server = uvicorn.Server(uvicorn.Config(app, lifespan="on"))

# -- basic utilities --

# All of the other signals get the default disposition
HANDLED_SIGNALS = {signal.SIGCHLD, signal.SIGHUP}


@contextlib.contextmanager
def mask_signals(sigset):
    old_mask = signal.pthread_sigmask(signal.SIG_BLOCK, sigset)
    try:
        yield
    finally:
        signal.pthread_sigmask(signal.SIG_SETMASK, old_mask)


@contextlib.contextmanager
def tempdir_for_pid(pid: int) -> Iterator[Path]:
    """A tempdir that only cleans up when it hasn't been forked"""
    path = Path(tempfile.mkdtemp())
    try:
        yield path
    finally:
        if os.getpid() == pid:
            shutil.rmtree(path)

# -- management of child processes --

@dataclasses.dataclass(frozen=True)
class ChildDispatch:
    """Encapsulates a task for the child to run."""

    # The file descriptors that the child should keep open
    fds_to_keep: tuple[int, ...]
    # What to actually run
    call: Callable[[], int]


def reap_zombies(pids: dict[int, ChildDispatch]) -> ChildDispatch | None:
    """If any of our child processes dies (becomes a zombie), this will reap
    it and restart the appropriate callable."""

    while True:
        # Wait for child processes. There must be at least one, but we
        # don't know how many there will be.
        pid, status, rusage = os.wait3(os.WNOHANG)
        if pid == 0:
            break

        try:
            dispatch = pids.pop(pid)
        except KeyError:
            # We reaped a process that we didn't launch
            continue

        dispatch = dataclasses.replace(
            dispatch,
            # Change the started_at kwarg on the previous partial
            call=partial(dispatch.call, started_at=time.perf_counter()),
        )

        # Relaunch any children that died
        pid = os.fork()
        if pid == 0:
            return dispatch

        pids[pid] = dispatch
        print(f"Spawned {pid}", file=sys.stderr)

    return None


def poll_once(pids: dict[int, ChildDispatch]) -> ChildDispatch | None:
    # We could also use signalfd and a selectors.poll/epoll object
    seen_signal = signal.sigwait(HANDLED_SIGNALS)

    if seen_signal == signal.SIGHUP:
        for pid in pids:
            os.kill(pid, signal.SIGHUP)

    if seen_signal == signal.SIGCHLD:
        # The return is needed for the case where reap_zombies decides that it
        # needs to fork.
        return reap_zombies(pids)

    return None


def bind_sockets(path: Path, count: int) -> list[int]:
    ret = []
    for i in range(count):
        name = str(path / f"{i:03x}")
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            s.bind(name)
            # We start listening immediately, but nobody will accept() until
            # we fork into processes.
            s.listen(1)

            ret.append(s.detach())
        finally:
            s.close()

    return ret



def fork_children(
    sockets: list[int], self_pidfd: int
) -> ChildDispatch | dict[int, ChildDispatch]:
    """Launches child processes.

    If we are now in a child process, we return a ChildDispatch.
    Otherwise we return a dict[pid, ChildDispatch] in case a child needs to
    be respawned."""

    ret = dict()
    for sock in sockets:
        # the format we return is:
        # (file descriptors to keep, callable to run)
        child_ret = ChildDispatch(
            (self_pidfd, sock),
            partial(
                child_process, self_pidfd, sock, started_at=time.perf_counter()
            ),
        )
        pid = os.fork()
        if pid == 0:
            return child_ret

        ret[pid] = child_ret

    return ret


def spawn_loop(stack: contextlib.ExitStack) -> ChildDispatch:
    # Mask the normal delivery of signals
    stack.enter_context(mask_signals(HANDLED_SIGNALS))
    my_pid = os.getpid()

    # Set up a pidfd. We'll give this pidfd to our children, and they'll use
    # it to monitor us (the parent process)
    self_pidfd = os.pidfd_open(my_pid)

    path = stack.enter_context(tempdir_for_pid(my_pid))
    cpu_count = os.cpu_count()
    assert cpu_count
    sockets = bind_sockets(path, cpu_count)
    fork_result = fork_children(sockets, self_pidfd)
    if isinstance(fork_result, ChildDispatch):
        return fork_result

    while True:
        ret = poll_once(fork_result)
        if ret is not None:
            return ret

async def run_uvicorn(listen_fd: int, started_at: float) -> None:
    # Our actual activity. This runs in the child (worker) processes.
    sock = socket.socket(fileno=listen_fd)
    elapsed = time.perf_counter() - started_at
    name = sock.getsockname()
    print(
        f"PID {os.getpid()} fork took {elapsed} seconds.\n"
        f"Query me at `curl --unix-socket \"{name}\" http://localhost`",
        file=sys.stderr
    )
    await uvicorn_server.serve([sock])


async def watch_parent(parent_fd: int) -> None:
    """Watches an fd `parent_fd`, and exits when it reports readiness.

    The intent of this function is to prevent us from outlasting our parent
    process.
    """
    loop = asyncio.get_running_loop()
    future = loop.create_future()
    loop.add_reader(parent_fd, future.set_result, None)
    try:
        await future
    finally:
        loop.remove_reader(parent_fd)


async def async_child_process(
    parent_fd: int, listen_fd: int, *, started_at: float
) -> int:
    loop = asyncio.get_running_loop()
    tasks = [
        loop.create_task(watch_parent(parent_fd)),
        loop.create_task(run_uvicorn(listen_fd, started_at)),
    ]
    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    return 0


def child_process(parent_fd: int, listen_fd: int, *, started_at: float) -> int:
    # The child process's contract is to:
    # 1) monitor the parent_fd for readiness, which indicates the parent has
    # exited.
    # 2) handle connections arriving to listen_fd.
    try:
        return asyncio.run(
            async_child_process(parent_fd, listen_fd, started_at=started_at)
        )
    except KeyboardInterrupt:
        pass

    # This will skip any atexit handlers
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    os.kill(os.getpid(), signal.SIGINT)
    return 0


def close_fds(keep: Iterable[int] = ()) -> None:
    """Closes all of the FDs not needed to do our job."""
    # This is necessary because we inherit all of the parent's FDs, but we
    # only want to be responsible for the ones it has assigned us.

    keep = iter(sorted(keep))
    from_fd = 3
    for to_fd in keep:
        # os.closerange is exclusive of its second argument, matching range()
        if to_fd > from_fd:
            os.closerange(from_fd, to_fd)
        from_fd = to_fd + 1
    else:
        os.closerange(from_fd, 0)


def wait_for_work() -> Callable[[], int]:
    # Runs in the main process and waits for it to fork into a child process.
    # It'll return a callback when it does, which we execute as our child
    # process.
    try:
        with contextlib.ExitStack() as stack:
            dispatch = spawn_loop(stack)
            if dispatch is None:
                # No work to do. We're done!
                raise SystemExit

        close_fds(dispatch.fds_to_keep)
        return dispatch.call
    except KeyboardInterrupt:
        raise


def main() -> int:
    # It's not sufficient to prevent de-CoWing, but it's slightly worthwhile to
    # freeze the existing heap.
    gc.freeze()
    # By exiting all the way out to the outermost stack frame, we ensure that
    # none of the parent's code (within `wait_for_work()`) is seen in stack
    # traces of the child, which makes them simpler to understand.
    callback = wait_for_work()
    return callback()


if __name__ == "__main__":
    raise SystemExit(main())
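Each worker prints a curl command for its socket. If you would rather query a worker from Python, here is a stdlib-only sketch; the class `UnixHTTPConnection` and the function `fetch_over_unix_socket` are names I made up, built on `http.client.HTTPConnection`.

```python
import http.client
import socket


class UnixHTTPConnection(http.client.HTTPConnection):
    """An HTTP connection that dials a Unix domain socket instead of TCP."""

    def __init__(self, socket_path: str, timeout: float = 5.0) -> None:
        # The host is only used for the Host header
        super().__init__("localhost", timeout=timeout)
        self.socket_path = socket_path

    def connect(self) -> None:
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        sock.settimeout(self.timeout)
        sock.connect(self.socket_path)
        self.sock = sock


def fetch_over_unix_socket(socket_path: str, url_path: str = "/") -> tuple[int, bytes]:
    """Issue a GET request and return (status, body)."""
    conn = UnixHTTPConnection(socket_path)
    try:
        conn.request("GET", url_path)
        response = conn.getresponse()
        return response.status, response.read()
    finally:
        conn.close()


# Usage, with a socket path copied from a worker's startup message:
#   status, body = fetch_over_unix_socket("/tmp/tmpabc123/000")
```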

  1. MIMALLOC_ALLOW_LARGE_OS_PAGES=1 PYTHONMALLOC=mimalloc. Fun fact: getrusage() does not report huge page consumption as part of RSS. 

  2. \( 512 \, \mathrm{MiB} \times \frac{36 \, \mathrm{nanoseconds}}{\mathrm{object}} \div \left( 2 \, \mathrm{GiB \, forkable \, address \, space} \times \frac{19 \, \mathrm{ms}}{\mathrm{GiB}} \right) \approx 509 \, \frac{\mathrm{bytes}}{\mathrm{object}} \)

  3. During GC, the gc_prev pointer, which constitutes half of the PyGC_Head intrusive doubly-linked list node, is converted into an integer field to store (and manipulate) the object’s inbound reference count. While it is serving this role, the list becomes a singly linked list, and gc_prev is also called gc_refs.