A Work-Stealing Scheduler

When you have hundreds or thousands of tasks to execute and only a handful of CPU cores, how do you distribute the work? A naïve approach is to put every task in a single shared queue, but this creates a bottleneck, since every worker must compete for access to that one queue.

A work-stealing scheduler solves this problem through decentralization. Each worker maintains a local deque (double-ended queue) of tasks. Workers execute tasks from one end of their own deque, and a worker that runs out of tasks takes some from the other end of another worker's deque. This design minimizes contention while balancing load automatically, and appears throughout high-performance computing: Go's runtime scheduler uses it to distribute goroutines across threads, Java's fork/join framework uses it to run parallel divide-and-conquer algorithms, and Tokio (Rust's async runtime) uses it to schedule futures across worker threads.

The Work-Stealing Pattern

A work-stealing system has five key characteristics:

  1. Each worker has a local deque of tasks.
  2. Those tasks are independent of each other.
  3. Workers pop tasks from the private end of their deque.
  4. Idle workers take tasks from the public end of other workers' deques.
  5. Running tasks can create new child tasks.

The key idea is asymmetry: the owning worker operates on one end of its deque (usually called the bottom) while other workers (called thieves) steal from the other end (the top). This reduces contention because the owner and thieves don't compete for the same task unless the deque is almost empty.
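Here is a tiny sketch of that asymmetry using Python's collections.deque, assuming (as the WorkerDeque class below does with a list) that the right end is the bottom and the left end is the top:

```python
from collections import deque

# Tasks are pushed in creation order: "A" is the oldest, "D" the newest.
tasks = deque(["A", "B", "C", "D"])

# The owner works at the bottom (right end), so it gets the newest task (LIFO).
owner_task = tasks.pop()

# A thief works at the top (left end), so it gets the oldest task (FIFO).
thief_task = tasks.popleft()

print(owner_task, thief_task, list(tasks))  # D A ['B', 'C']
```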

Let's start with the task representation:

import random  # used by the workers and the scheduler below
from dataclasses import dataclass


@dataclass
class Task:
    """A unit of work to be executed."""

    task_id: str
    duration: float
    parent_id: str | None = None  # For nested tasks

    def __str__(self):
        return f"Task({self.task_id})"

Each task has an ID, a duration to simulate CPU-bound work, and an optional parent task ID that records which task spawned it.

Each worker maintains a deque. In our simulation, we'll use a simple list-based deque:

class WorkerDeque:
    """Double-ended queue for tasks with stealing support."""

    def __init__(self):
        self.tasks: list[Task] = []

    def push_bottom(self, task: Task):
        """Owner pushes task to bottom (private end)."""
        self.tasks.append(task)

    def pop_bottom(self) -> Task | None:
        """Owner pops task from bottom."""
        return self.tasks.pop() if self.tasks else None

    def steal_top(self) -> Task | None:
        """Thief steals task from top (public end)."""
        return self.tasks.pop(0) if self.tasks else None

    def is_empty(self) -> bool:
        """Check if deque is empty."""
        return len(self.tasks) == 0

    def size(self) -> int:
        """Return number of tasks."""
        return len(self.tasks)

A production system would use something more sophisticated than a Python list (typically a lock-free deque that lets the owner and thieves operate on opposite ends at the same time), but our simulation focuses on the algorithmic pattern rather than low-level synchronization.
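For readers who want to try the pattern with real threads, here is a minimal sketch of one modest step up, a hypothetical LockedWorkerDeque with the same methods as WorkerDeque. It replaces the list with collections.deque, which supports O(1) appends and pops at both ends (list.pop(0), by contrast, shifts every remaining element), and guards every operation with a single threading.Lock. Production runtimes generally avoid this design: because the owner and all thieves share one lock, they contend for it on every operation.

```python
import threading
from collections import deque


class LockedWorkerDeque:
    """Hypothetical thread-safe variant of WorkerDeque (one lock, both ends)."""

    def __init__(self):
        self._tasks = deque()  # O(1) append/pop at both ends
        self._lock = threading.Lock()  # serializes the owner and all thieves

    def push_bottom(self, task):
        with self._lock:
            self._tasks.append(task)

    def pop_bottom(self):
        with self._lock:
            return self._tasks.pop() if self._tasks else None

    def steal_top(self):
        with self._lock:
            return self._tasks.popleft() if self._tasks else None

    def is_empty(self) -> bool:
        with self._lock:
            return not self._tasks

    def size(self) -> int:
        with self._lock:
            return len(self._tasks)


d = LockedWorkerDeque()
for name in ["T1", "T2", "T3"]:
    d.push_bottom(name)
print(d.pop_bottom(), d.steal_top(), d.size())  # T3 T1 1
```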

A worker executes tasks from its local deque and steals when idle. We start by setting up its members:
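
class Worker(Process):
    """Executes tasks from its own deque and steals from others when idle."""

    def init(self, worker_id: int, scheduler: "WorkStealingScheduler", verbose: bool = True):
        self.worker_id = worker_id
        self.scheduler = scheduler
        self.verbose = verbose
        self.deque = WorkerDeque()
        self.current_task: Task | None = None
        self.tasks_executed = 0
        self.tasks_stolen = 0
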
and then define its behavior:

    async def run(self):
        """Main worker loop: execute local tasks or steal."""
        while True:
            # Try to get a task from local deque
            task = self.deque.pop_bottom()

            if task:
                # Execute local task
                await self.execute_task(task)
            else:
                # No local work, try stealing
                stolen = await self.try_steal()
                if stolen:
                    await self.execute_task(stolen)
                else:
                    # No work available anywhere, wait a bit
                    await self.timeout(0.1)

As the code above shows, the worker continuously tries to execute tasks. If its local deque is empty, it attempts to steal from other workers. If stealing fails, it waits briefly before trying again.
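Stripped of the simulation's timing, each pass through that loop makes one decision. The sketch below captures it as a hypothetical next_task function over plain lists (so it can run on its own), using WorkerDeque's orientation in which the end of a list is the bottom:

```python
import random


def next_task(own: list, others: list, rng: random.Random):
    """Hypothetical one-step version of the decision made in Worker.run.

    Returns (source, task), where source is "local", "stolen", or "idle".
    """
    if own:
        return "local", own.pop()  # newest local task first
    victims = list(others)
    rng.shuffle(victims)  # probe the other workers in random order
    for victim in victims:
        if victim:
            return "stolen", victim.pop(0)  # oldest task from the victim's top
    return "idle", None  # nothing anywhere: the caller waits, then tries again


rng = random.Random(42)
mine, peers = ["T1", "T2"], [["T3", "T4"], []]
print(next_task(mine, peers, rng))  # ('local', 'T2')
```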

Executing a task is relatively straightforward:

    async def execute_task(self, task: Task):
        """Execute a task."""
        self.current_task = task
        self.tasks_executed += 1
        if self.verbose:
            print(
                f"[{self.now:.1f}] Worker {self.worker_id}: "
                f"Executing {task.task_id} (queue size: {self.deque.size()})"
            )

        await self.timeout(task.duration)

        if self.verbose:
            print(f"[{self.now:.1f}] Worker {self.worker_id}: Completed {task.task_id}")
        self.current_task = None

Stealing a task from another worker is somewhat more interesting. The most important detail is that each thief checks the other workers in random order, so that idle workers don't all converge on the same victim and steal attempts are spread evenly across the system:

    async def try_steal(self) -> Task | None:
        """Try to steal a task from another worker."""
        targets = [w for w in self.scheduler.workers if w != self]
        if not targets:
            return None

        random.shuffle(targets)

        for target in targets:
            task = target.deque.steal_top()
            if task:
                self.tasks_stolen += 1
                if self.verbose:
                    print(
                        f"[{self.now:.1f}] Worker {self.worker_id}: "
                        f"Stole {task.task_id} from Worker {target.worker_id}"
                    )
                return task

        return None
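To see why the shuffle matters, the following sketch (a hypothetical first_probes helper, independent of the simulation) counts which worker each thief probes first. With a fixed scan order, every thief except worker 0 probes worker 0 first, so it becomes a hot spot; shuffling spreads the first probes evenly.

```python
import random
from collections import Counter


def first_probes(num_workers: int, rounds: int, shuffle: bool, seed: int = 1) -> Counter:
    """Count how often each worker is the first victim a thief probes.

    In every round, each worker acts as a thief once and builds its
    target list the way try_steal does, optionally shuffling it.
    """
    rng = random.Random(seed)
    counts = Counter()
    for _ in range(rounds):
        for thief in range(num_workers):
            targets = [w for w in range(num_workers) if w != thief]
            if shuffle:
                rng.shuffle(targets)
            counts[targets[0]] += 1
    return counts


print(first_probes(8, 1000, shuffle=False))  # Counter({0: 7000, 1: 1000})
print(first_probes(8, 1000, shuffle=True))   # roughly 1000 first probes per worker
```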

The Scheduler

The scheduler coordinates workers and provides task submission:

class WorkStealingScheduler:
    """Scheduler that coordinates work-stealing workers."""

    def __init__(
        self,
        env: Environment,
        num_workers: int,
        verbose: bool = True,
        worker_cls: type = Worker,
    ):
        self.env = env
        self.num_workers = num_workers
        self.verbose = verbose
        self.workers: list = []
        self.task_counter = 0

        # Create workers
        for i in range(num_workers):
            worker = worker_cls(env, i, self, verbose)
            self.workers.append(worker)

    def submit_task(self, duration: float, parent_id: str | None = None) -> Task:
        """Submit a task to a random worker."""
        self.task_counter += 1
        task = Task(
            task_id=f"T{self.task_counter}",
            duration=duration,
            parent_id=parent_id,
        )

        worker = random.choice(self.workers)
        worker.deque.push_bottom(task)

        if self.verbose:
            print(
                f"[{self.env.now:.1f}] Submitted {task.task_id} "
                f"to Worker {worker.worker_id}"
            )

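        return task

    def get_statistics(self):
        """Print overall and per-worker statistics."""
        total_executed = sum(w.tasks_executed for w in self.workers)
        total_stolen = sum(w.tasks_stolen for w in self.workers)
        steal_rate = 100 * total_stolen / total_executed if total_executed else 0.0
        print("=== Statistics ===")
        print(f"Total tasks executed: {total_executed}")
        print(f"Total tasks stolen: {total_stolen}")
        print(f"Steal rate: {steal_rate:.1f}%")
        for w in self.workers:
            print(
                f"Worker {w.worker_id}: executed={w.tasks_executed}, "
                f"stolen={w.tasks_stolen}, queue={w.deque.size()}"
            )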

We can now run a simple simulation that submits ten tasks of random duration to randomly chosen workers:

def run_basic_simulation():
    """Basic work-stealing simulation."""
    env = Environment()

    scheduler = WorkStealingScheduler(env, num_workers=3)

    for i in range(10):
        scheduler.submit_task(duration=random.uniform(0.5, 2.0))

    env.run(until=20)

    scheduler.get_statistics()

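In the run shown here, random assignment happened to give the three workers four, three, and three tasks. No worker emptied its deque while another still had tasks waiting, so nothing was stolen and the steal rate (the percentage of executed tasks that were stolen) is 0%: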

[0.0] Submitted T1 to Worker 2
[0.0] Submitted T2 to Worker 2
[0.0] Submitted T3 to Worker 0
[0.0] Submitted T4 to Worker 1
[0.0] Submitted T5 to Worker 0
[0.0] Submitted T6 to Worker 1
[0.0] Submitted T7 to Worker 0
[0.0] Submitted T8 to Worker 0
[0.0] Submitted T9 to Worker 2
[0.0] Submitted T10 to Worker 1
...more...
=== Statistics ===
Total tasks executed: 10
Total tasks stolen: 0
Steal rate: 0.0%
Worker 0: executed=4, stolen=0, queue=0
Worker 1: executed=3, stolen=0, queue=0
Worker 2: executed=3, stolen=0, queue=0

Nested Task Spawning

A common extension of work stealing is to support divide-and-conquer algorithms by allowing tasks to spawn subtasks. To explore this, we first create a process that submits new tasks at regular intervals rather than all at once:

class TaskGenerator(Process):
    """Submits the initial tasks one at a time, 0.5 time units apart."""

    def init(self, scheduler: WorkStealingScheduler, num_initial_tasks: int):
        self.scheduler = scheduler
        self.num_initial_tasks = num_initial_tasks

    async def run(self):
        """Generate initial tasks."""
        for i in range(self.num_initial_tasks):
            self.scheduler.submit_task(duration=random.uniform(1.0, 3.0))
            await self.timeout(0.5)

and then create a worker that spawns subtasks with some random probability (in our case, 30%):

class WorkerWithSpawning(Worker):
    """Worker that can spawn child tasks during execution."""

    async def execute_task(self, task: Task):
        """Execute task and possibly spawn children."""
        self.current_task = task
        self.tasks_executed += 1

        print(f"[{self.now:.1f}] Worker {self.worker_id}: Executing {task.task_id}")

        # Do half the work
        await self.timeout(task.duration / 2)

        # Randomly spawn child tasks (simulating divide-and-conquer)
        if random.random() < 0.3:  # 30% chance
            num_children = random.randint(1, 3)
            for i in range(num_children):
                child = Task(
                    task_id=f"{task.task_id}.{i}",
                    duration=random.uniform(0.3, 1.0),
                    parent_id=task.task_id,
                )
                self.spawn_task(child)

        # Finish the work
        await self.timeout(task.duration / 2)

        print(f"[{self.now:.1f}] Worker {self.worker_id}: Completed {task.task_id}")

        self.current_task = None

    def spawn_task(self, task: Task):
        """Spawn a new task (called by executing task)."""
        self.deque.push_bottom(task)
        print(f"[{self.now:.1f}] Worker {self.worker_id}: Spawned {task.task_id}")
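Because spawn_task pushes children onto the bottom of the spawning worker's own deque, the owner tends to run its newest pieces of work itself while thieves take the oldest ones from the top. In our simulation the children are random, but in a real divide-and-conquer algorithm the oldest entries are also the largest. The sketch below illustrates this under assumptions of our own: a task is a hypothetical half-open range (lo, hi) of work items, and any range larger than grain is split in half instead of being run.

```python
from collections import deque


def owner_step(tasks: deque, grain: int):
    """Pop the newest range; split it if it is too big, otherwise 'run' it.

    Returns the range that was run, or None if the range was split instead.
    """
    lo, hi = tasks.pop()  # bottom: the most recently pushed range
    if hi - lo > grain:
        mid = (lo + hi) // 2
        tasks.append((mid, hi))  # pushed first, so it ends up closer to the top
        tasks.append((lo, mid))  # pushed last, so the owner pops it next
        return None
    return (lo, hi)


tasks = deque([(0, 16)])
for _ in range(3):
    owner_step(tasks, grain=2)
print(list(tasks))                 # [(8, 16), (4, 8), (2, 4), (0, 2)]
print(tasks.popleft())             # a thief takes (8, 16): half of all the work
print(owner_step(tasks, grain=2))  # the owner runs (0, 2)
```

Each split leaves one untouched half near the top of the deque, so a single steal of the oldest entry hands the thief half of the whole problem while the owner keeps chipping away at small pieces.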

The final step is to write a scheduler that creates these workers:
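
class SchedulerWithSpawning(WorkStealingScheduler):
    """Scheduler whose workers can spawn child tasks."""

    def __init__(self, env: Environment, num_workers: int, verbose: bool = True):
        super().__init__(env, num_workers, verbose, worker_cls=WorkerWithSpawning)
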
Our simulation looks similar to our first one:

def run_spawning_simulation():
    """Demonstrate nested task spawning."""
    env = Environment()

    # Create scheduler with spawning workers
    scheduler = SchedulerWithSpawning(env, num_workers=4)

    # Generate initial tasks
    TaskGenerator(env, scheduler, num_initial_tasks=5)

    # Run simulation
    env.run(until=30)

    # Print statistics
    scheduler.get_statistics()

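Its output shows that stealing spreads the work even though tasks arrive one at a time and four of the five are submitted to Worker 1 (in this particular run, none of the tasks happened to spawn children):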

[0.0] Submitted T1 to Worker 1
[0.1] Worker 0: Stole T1 from Worker 1
[0.1] Worker 0: Executing T1
[0.5] Submitted T2 to Worker 1
[0.5] Worker 1: Executing T2
[1.0] Submitted T3 to Worker 1
[1.1] Worker 2: Stole T3 from Worker 1
[1.1] Worker 2: Executing T3
[1.5] Submitted T4 to Worker 3
[1.5] Worker 3: Executing T4
[2.0] Worker 0: Completed T1
[2.0] Submitted T5 to Worker 1
[2.1] Worker 0: Stole T5 from Worker 1
[2.1] Worker 0: Executing T5
[2.4] Worker 2: Completed T3
[3.3] Worker 0: Completed T5
[3.4] Worker 1: Completed T2
[4.5] Worker 3: Completed T4

...more...
=== Statistics ===
Total tasks executed: 5
Total tasks stolen: 3
Steal rate: 60.0%
Worker 0: executed=2, stolen=2, queue=0
Worker 1: executed=1, stolen=0, queue=0
Worker 2: executed=1, stolen=1, queue=0
Worker 3: executed=1, stolen=0, queue=0

Load Balancing Strategies

What effect does target selection strategy have on performance? To find out, we can create a worker that uses adaptive target selection, i.e., one that always tries to steal from whichever peer currently has the longest queue:

class AdaptiveWorker(Worker):
    """Worker with adaptive target selection."""

    def init(self, worker_id: int, scheduler: "WorkStealingScheduler", verbose: bool = True):
        super().init(worker_id, scheduler, verbose)  # pass verbose through
        self.steal_attempts = 0
        self.failed_steals = 0

    async def try_steal(self) -> Task | None:
        """Try to steal with adaptive target selection."""
        self.steal_attempts += 1

        # Try workers with largest queues first
        targets = [w for w in self.scheduler.workers if w != self]
        targets.sort(key=lambda w: w.deque.size(), reverse=True)

        for target in targets:
            if target.deque.size() > 0:
                task = target.deque.steal_top()
                if task:
                    self.tasks_stolen += 1
                    print(
                        f"[{self.now:.1f}] Worker {self.worker_id}: "
                        f"Stole {task.task_id} from Worker {target.worker_id} "
                        f"(target queue: {target.deque.size()})"
                    )
                    return task

        self.failed_steals += 1
        return None

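To test the adaptive worker, we start with a deliberately skewed load in which Worker 0 holds most of the tasks. Idle workers repeatedly steal from Worker 0 because its queue is the longest, and by the end of the run every worker has executed three or four tasks: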

Initial load distribution:
  Worker 0: 12 tasks
  Worker 1: 1 tasks
  Worker 2: 2 tasks
  Worker 3: 0 tasks
[0.0] Worker 0: Executing T12 (queue size: 11)
[0.0] Worker 1: Executing T13 (queue size: 0)
[0.0] Worker 2: Executing T15 (queue size: 1)
[0.0] Worker 3: Stole T1 from Worker 0 (target queue: 10)
[0.0] Worker 3: Executing T1 (queue size: 0)
[1.1] Worker 0: Completed T12
[1.1] Worker 0: Executing T11 (queue size: 9)
[1.4] Worker 3: Completed T1
[1.4] Worker 3: Stole T2 from Worker 0 (target queue: 8)
[1.4] Worker 3: Executing T2 (queue size: 0)
[1.5] Worker 2: Completed T15
[1.5] Worker 2: Executing T14 (queue size: 0)
[1.9] Worker 1: Completed T13
[1.9] Worker 1: Stole T3 from Worker 0 (target queue: 7)
[1.9] Worker 1: Executing T3 (queue size: 0)
...more...
=== Statistics ===
Total tasks executed: 15
Total tasks stolen: 8
Steal rate: 53.3%
Worker 0: executed=4, stolen=0, queue=0
Worker 1: executed=3, stolen=2, queue=0
Worker 2: executed=4, stolen=2, queue=0
Worker 3: executed=4, stolen=4, queue=0

Task Granularity

The granularity of tasks (i.e., how much work is in each one) has a big impact on performance. Many small tasks create lots of scheduling overhead, while a few large tasks cause load imbalance. Using the code we have written so far, we can divide the same 50 seconds of total work among four workers as tasks of 0.1, 0.5, or 2.0 seconds each:

=== Performance Analysis ===
Granularity: 0.1s
Total work: 50.0s
Wall time: 13.00s
Speedup: 3.85x
Efficiency: 96.2%

=== Statistics ===
Total tasks executed: 500
Total tasks stolen: 14
Steal rate: 2.8%
Worker 0: executed=125, stolen=0, queue=0
Worker 1: executed=125, stolen=7, queue=0
Worker 2: executed=125, stolen=2, queue=0
Worker 3: executed=125, stolen=5, queue=0

=== Performance Analysis ===
Granularity: 0.5s
Total work: 50.0s
Wall time: 13.00s
Speedup: 3.85x
Efficiency: 96.2%

=== Statistics ===
Total tasks executed: 100
Total tasks stolen: 13
Steal rate: 13.0%
Worker 0: executed=25, stolen=2, queue=0
Worker 1: executed=25, stolen=8, queue=0
Worker 2: executed=25, stolen=3, queue=0
Worker 3: executed=25, stolen=0, queue=0

=== Performance Analysis ===
Granularity: 2.0s
Total work: 50.0s
Wall time: 15.00s
Speedup: 3.33x
Efficiency: 83.3%

=== Statistics ===
Total tasks executed: 25
Total tasks stolen: 1
Steal rate: 4.0%
Worker 0: executed=7, stolen=1, queue=0
Worker 1: executed=6, stolen=0, queue=0
Worker 2: executed=6, stolen=0, queue=0
Worker 3: executed=6, stolen=0, queue=0
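The Speedup and Efficiency figures above are consistent with speedup = total work / wall time and efficiency = speedup / number of workers. The sketch below recomputes them from the reported wall times and adds an idealized lower bound on wall time. The bound is our own assumption, not part of the simulation: equal-length tasks, four workers, and zero cost for scheduling and stealing, so the busiest worker runs ceil(tasks / workers) tasks back to back.

```python
import math


def analyze(total_work: float, granularity: float, workers: int, wall_time: float):
    """Recompute speedup and efficiency, plus an idealized wall-time bound.

    The bound assumes equal-length tasks and free scheduling and stealing:
    the busiest worker runs ceil(num_tasks / workers) tasks back to back.
    """
    num_tasks = round(total_work / granularity)
    speedup = total_work / wall_time
    efficiency = speedup / workers
    ideal_wall = math.ceil(num_tasks / workers) * granularity
    return num_tasks, speedup, efficiency, ideal_wall


# Wall times are the ones reported in the runs above.
for g, wall in [(0.1, 13.0), (0.5, 13.0), (2.0, 15.0)]:
    n, s, e, ideal = analyze(50.0, g, 4, wall)
    print(f"{g}s: {n} tasks, speedup {s:.2f}x, efficiency {e:.1%}, bound {ideal:.1f}s")
```

Every simulated wall time sits somewhat above its bound (13.0 s against 12.5 s for the two finer settings, 15.0 s against 14.0 s for the coarsest).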

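The 2.0-second setting is the least efficient because 25 tasks can't be divided evenly among four workers: Worker 0 executes seven tasks while the others execute six, so three workers sit idle at the end of the run. The finer settings show no penalty for their much larger number of tasks because our simulation doesn't charge anything for scheduling or stealing; in a real system, that per-task overhead would eventually outweigh the benefit of better balance.

Our implementations demonstrate the core concepts of work stealing, but production systems go further. In particular, they try to prevent livelock by limiting how long a worker searches for victims, and use exponential backoff rather than spinning continuously when trying to steal work. (Our workers sit between these extremes: they wait a fixed 0.1 time units after every failed round of stealing.)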
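Here is a minimal sketch of both ideas, assuming a hypothetical steal_with_backoff helper: it gives up after max_attempts failed attempts and doubles its wait after each failure, up to cap. The parameter names and default values are illustrative choices, not taken from any particular runtime.

```python
import time


def steal_with_backoff(try_steal, sleep=time.sleep, max_attempts=6, base=0.01, cap=0.5):
    """Hypothetical bounded steal loop with exponential backoff.

    try_steal() returns a task or None, and sleep(seconds) waits; passing them
    in keeps the policy independent of threads, asyncio, or a simulator.
    """
    delay = base
    for attempt in range(max_attempts):
        task = try_steal()
        if task is not None:
            return task
        if attempt < max_attempts - 1:  # no point waiting after the last try
            sleep(delay)
            delay = min(delay * 2, cap)  # double the wait, up to the cap
    return None  # give up; the caller can park the worker until new work arrives


# Usage with a fake victim that has work only on the fourth try.
waits = []
answers = iter([None, None, None, "T7"])
print(steal_with_backoff(lambda: next(answers), sleep=waits.append))  # T7
print(waits)  # [0.01, 0.02, 0.04]
```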

Exercises

FIXME: add exercises.