A Work-Stealing Scheduler
How do you distribute work when you have hundreds or thousands of tasks to execute and only a handful of CPU cores? A naïve approach is to use a single shared queue, but this creates a bottleneck: every worker must compete for access to that queue.
A work-stealing scheduler solves this problem through decentralization. Each worker maintains a local deque of tasks. Workers execute tasks from one end of their own deque, but a worker that runs out of tasks can take some from the other end of another worker's deque. This design minimizes contention while still providing load balancing, and it appears throughout high-performance computing: Go's runtime scheduler uses it to distribute goroutines across threads, Java's fork/join framework uses it for parallel divide-and-conquer algorithms, and Tokio (Rust's async runtime) uses it to schedule futures across worker threads.
The Work-Stealing Pattern
A work-stealing system has five parts:
- Each worker has a local deque of tasks.
- Those tasks are independent of each other.
- Workers pop tasks from the private end of their deque.
- Idle workers take tasks from the public end of other workers' deques.
- Running tasks can create new child tasks.
The key idea is asymmetry: the owning worker operates on one end of its deque (usually called the bottom), while other workers (called thieves) steal from the other end (the top). This reduces contention because owners and thieves don't compete for the same task unless the deque is almost empty.
Let's start with the task representation:
import random
from dataclasses import dataclass

@dataclass
class Task:
    """A unit of work to be executed."""
    task_id: str
    duration: float
    parent_id: str | None = None  # For nested tasks

    def __str__(self):
        return f"Task({self.task_id})"
Each task has an ID, a duration to simulate CPU-bound work, and an optional parent task ID for tracking task dependencies.
Each worker maintains a deque. In our simulation, we'll use a simple list-based deque:
class WorkerDeque:
    """Double-ended queue for tasks with stealing support."""

    def __init__(self):
        self.tasks: list[Task] = []

    def push_bottom(self, task: Task):
        """Owner pushes task to bottom (private end)."""
        self.tasks.append(task)

    def pop_bottom(self) -> Task | None:
        """Owner pops task from bottom."""
        return self.tasks.pop() if self.tasks else None

    def steal_top(self) -> Task | None:
        """Thief steals task from top (public end)."""
        return self.tasks.pop(0) if self.tasks else None

    def is_empty(self) -> bool:
        """Check if deque is empty."""
        return len(self.tasks) == 0

    def size(self) -> int:
        """Return number of tasks."""
        return len(self.tasks)
A production system would use something more sophisticated than a simple Python list to manage the deque, such as a lock-free structure like the Chase-Lev deque, but our simulation focuses on the algorithmic pattern rather than low-level synchronization.
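To make the synchronization concern concrete, here is a sketch of what even a minimally thread-safe variant might look like, using a single lock (a hypothetical `LockedDeque`; real runtimes avoid the lock entirely with lock-free designs):

import threading
from collections import deque

class LockedDeque:
    """Illustrative thread-safe deque guarded by one lock."""

    def __init__(self):
        self._tasks: deque = deque()
        self._lock = threading.Lock()

    def push_bottom(self, task: Task):
        with self._lock:
            self._tasks.append(task)

    def pop_bottom(self) -> Task | None:
        with self._lock:
            return self._tasks.pop() if self._tasks else None

    def steal_top(self) -> Task | None:
        with self._lock:
            return self._tasks.popleft() if self._tasks else None

One lock per deque is already far better than one lock on a shared global queue, since an owner and a thief only collide when they touch the same worker's deque.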
A worker executes tasks from its local deque and steals when idle. We start by setting up its members:
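Here is a minimal sketch, assuming the simulation framework's `Process` base class forwards its extra constructor arguments to `init`; the fields mirror how `run`, `execute_task`, and `try_steal` use them:

class Worker(Process):
    """A worker that executes tasks from its own deque and steals when idle."""

    def init(self, worker_id: int, scheduler: "WorkStealingScheduler", verbose: bool = True):
        self.worker_id = worker_id
        self.scheduler = scheduler
        self.verbose = verbose
        self.deque = WorkerDeque()  # local task deque
        self.current_task: Task | None = None  # task currently being executed
        self.tasks_executed = 0  # statistics: number of tasks run
        self.tasks_stolen = 0  # statistics: number of successful steals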
and then define its behavior:
async def run(self):
    """Main worker loop: execute local tasks or steal."""
    while True:
        # Try to get a task from local deque
        task = self.deque.pop_bottom()
        if task:
            # Execute local task
            await self.execute_task(task)
        else:
            # No local work, try stealing
            stolen = await self.try_steal()
            if stolen:
                await self.execute_task(stolen)
            else:
                # No work available anywhere, wait a bit
                await self.timeout(0.1)
As the code above shows, the worker continuously tries to execute tasks. If its local deque is empty, it attempts to steal from other workers. If stealing fails, it waits briefly before trying again.
Executing a task is relatively straightforward:
async def execute_task(self, task: Task):
    """Execute a task."""
    self.current_task = task
    self.tasks_executed += 1
    if self.verbose:
        print(
            f"[{self.now:.1f}] Worker {self.worker_id}: "
            f"Executing {task.task_id} (queue size: {self.deque.size()})"
        )
    await self.timeout(task.duration)
    if self.verbose:
        print(f"[{self.now:.1f}] Worker {self.worker_id}: Completed {task.task_id}")
    self.current_task = None
Stealing a task from another worker is somewhat more interesting. The most important part is that we randomize the order in which we check the other workers, which spreads the stealing load as evenly as possible:
async def try_steal(self) -> Task | None:
    """Try to steal a task from another worker."""
    targets = [w for w in self.scheduler.workers if w != self]
    if not targets:
        return None
    random.shuffle(targets)
    for target in targets:
        task = target.deque.steal_top()
        if task:
            self.tasks_stolen += 1
            if self.verbose:
                print(
                    f"[{self.now:.1f}] Worker {self.worker_id}: "
                    f"Stole {task.task_id} from Worker {target.worker_id}"
                )
            return task
    return None
The Scheduler
The scheduler coordinates workers and provides task submission:
class WorkStealingScheduler:
    """Scheduler that coordinates work-stealing workers."""

    def __init__(
        self,
        env: Environment,
        num_workers: int,
        verbose: bool = True,
        worker_cls: type = Worker,
    ):
        self.env = env
        self.num_workers = num_workers
        self.verbose = verbose
        self.workers: list = []
        self.task_counter = 0
        # Create workers
        for i in range(num_workers):
            worker = worker_cls(env, i, self, verbose)
            self.workers.append(worker)

    def submit_task(self, duration: float, parent_id: str | None = None) -> Task:
        """Submit a task to a random worker."""
        self.task_counter += 1
        task = Task(
            task_id=f"T{self.task_counter}",
            duration=duration,
            parent_id=parent_id,
        )
        worker = random.choice(self.workers)
        worker.deque.push_bottom(task)
        if self.verbose:
            print(
                f"[{self.env.now:.1f}] Submitted {task.task_id} "
                f"to Worker {worker.worker_id}"
            )
        return task
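The simulations below also call a `get_statistics` method on the scheduler; here is a minimal sketch that aggregates the per-worker counters into the report format shown in the output:

def get_statistics(self):
    """Print aggregate and per-worker statistics."""
    total_executed = sum(w.tasks_executed for w in self.workers)
    total_stolen = sum(w.tasks_stolen for w in self.workers)
    steal_rate = total_stolen / total_executed if total_executed else 0.0
    print("=== Statistics ===")
    print(f"Total tasks executed: {total_executed}")
    print(f"Total tasks stolen: {total_stolen}")
    print(f"Steal rate: {steal_rate:.1%}")
    for w in self.workers:
        print(
            f"Worker {w.worker_id}: executed={w.tasks_executed}, "
            f"stolen={w.tasks_stolen}, queue={w.deque.size()}"
        )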
We can create a simple simulation with load imbalance to see it in action:
def run_basic_simulation():
    """Basic work-stealing simulation."""
    env = Environment()
    scheduler = WorkStealingScheduler(env, num_workers=3)
    for _ in range(10):
        scheduler.submit_task(duration=random.uniform(0.5, 2.0))
    env.run(until=20)
    scheduler.get_statistics()
The output shows workers executing their tasks. In this particular run the random assignment happened to spread the tasks evenly, so no worker ran out of local work and no stealing occurred; the steal rate in the statistics records how much load balancing took place:
[0.0] Submitted T1 to Worker 2
[0.0] Submitted T2 to Worker 2
[0.0] Submitted T3 to Worker 0
[0.0] Submitted T4 to Worker 1
[0.0] Submitted T5 to Worker 0
[0.0] Submitted T6 to Worker 1
[0.0] Submitted T7 to Worker 0
[0.0] Submitted T8 to Worker 0
[0.0] Submitted T9 to Worker 2
[0.0] Submitted T10 to Worker 1
...more...
=== Statistics ===
Total tasks executed: 10
Total tasks stolen: 0
Steal rate: 0.0%
Worker 0: executed=4, stolen=0, queue=0
Worker 1: executed=3, stolen=0, queue=0
Worker 2: executed=3, stolen=0, queue=0
Nested Task Spawning
A common extension of work-stealing is to support divide-and-conquer algorithms by allowing tasks to spawn subtasks. To explore this, we can create a task generator:
class TaskGenerator(Process):
    """Generates tasks including ones that spawn subtasks."""

    def init(self, scheduler: WorkStealingScheduler, num_initial_tasks: int):
        self.scheduler = scheduler
        self.num_initial_tasks = num_initial_tasks

    async def run(self):
        """Generate initial tasks."""
        for _ in range(self.num_initial_tasks):
            self.scheduler.submit_task(duration=random.uniform(1.0, 3.0))
            await self.timeout(0.5)
and then create a worker that spawns subtasks with some random probability (in our case, 30%):
class WorkerWithSpawning(Worker):
    """Worker that can spawn child tasks during execution."""

    async def execute_task(self, task: Task):
        """Execute task and possibly spawn children."""
        self.current_task = task
        self.tasks_executed += 1
        print(f"[{self.now:.1f}] Worker {self.worker_id}: Executing {task.task_id}")
        # Do half the work
        await self.timeout(task.duration / 2)
        # Randomly spawn child tasks (simulating divide-and-conquer)
        if random.random() < 0.3:  # 30% chance
            num_children = random.randint(1, 3)
            for i in range(num_children):
                child = Task(
                    task_id=f"{task.task_id}.{i}",
                    duration=random.uniform(0.3, 1.0),
                    parent_id=task.task_id,
                )
                self.spawn_task(child)
        # Finish the work
        await self.timeout(task.duration / 2)
        print(f"[{self.now:.1f}] Worker {self.worker_id}: Completed {task.task_id}")
        self.current_task = None

    def spawn_task(self, task: Task):
        """Spawn a new task (called by executing task)."""
        self.deque.push_bottom(task)
        print(f"[{self.now:.1f}] Worker {self.worker_id}: Spawned {task.task_id}")
The final step is to write a scheduler that creates these workers:
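Because the base scheduler already accepts a worker class, all the new scheduler has to do is pass `WorkerWithSpawning` along:

class SchedulerWithSpawning(WorkStealingScheduler):
    """Scheduler whose workers can spawn child tasks."""

    def __init__(self, env: Environment, num_workers: int, verbose: bool = True):
        super().__init__(env, num_workers, verbose, worker_cls=WorkerWithSpawning)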
Our simulation looks similar to our first one:
def run_spawning_simulation():
    """Demonstrate nested task spawning."""
    env = Environment()
    # Create scheduler with spawning workers
    scheduler = SchedulerWithSpawning(env, num_workers=4)
    # Generate initial tasks
    TaskGenerator(env, scheduler, num_initial_tasks=5)
    # Run simulation
    env.run(until=30)
    # Print statistics
    scheduler.get_statistics()
Its output shows the scheduler balancing the load even when tasks arrive irregularly over time (in this particular run, no task happened to spawn children, so stealing did all of the balancing):
[0.0] Submitted T1 to Worker 1
[0.1] Worker 0: Stole T1 from Worker 1
[0.1] Worker 0: Executing T1
[0.5] Submitted T2 to Worker 1
[0.5] Worker 1: Executing T2
[1.0] Submitted T3 to Worker 1
[1.1] Worker 2: Stole T3 from Worker 1
[1.1] Worker 2: Executing T3
[1.5] Submitted T4 to Worker 3
[1.5] Worker 3: Executing T4
[2.0] Worker 0: Completed T1
[2.0] Submitted T5 to Worker 1
[2.1] Worker 0: Stole T5 from Worker 1
[2.1] Worker 0: Executing T5
[2.4] Worker 2: Completed T3
[3.3] Worker 0: Completed T5
[3.4] Worker 1: Completed T2
[4.5] Worker 3: Completed T4
...more...
=== Statistics ===
Total tasks executed: 5
Total tasks stolen: 3
Steal rate: 60.0%
Worker 0: executed=2, stolen=2, queue=0
Worker 1: executed=1, stolen=0, queue=0
Worker 2: executed=1, stolen=1, queue=0
Worker 3: executed=1, stolen=0, queue=0
Load Balancing Strategies
What effect does the target-selection strategy have on performance? To find out, we can create a worker that uses adaptive target selection, i.e., one that steals tasks from the largest of its peers' queues:
class AdaptiveWorker(Worker):
    """Worker with adaptive target selection."""

    def init(self, worker_id: int, scheduler: "WorkStealingScheduler", verbose: bool = True):
        super().init(worker_id, scheduler, verbose)
        self.steal_attempts = 0
        self.failed_steals = 0

    async def try_steal(self) -> Task | None:
        """Try to steal with adaptive target selection."""
        self.steal_attempts += 1
        # Try workers with largest queues first
        targets = [w for w in self.scheduler.workers if w != self]
        targets.sort(key=lambda w: w.deque.size(), reverse=True)
        for target in targets:
            if target.deque.size() > 0:
                task = target.deque.steal_top()
                if task:
                    self.tasks_stolen += 1
                    print(
                        f"[{self.now:.1f}] Worker {self.worker_id}: "
                        f"Stole {task.task_id} from Worker {target.worker_id} "
                        f"(target queue: {target.deque.size()})"
                    )
                    return task
        self.failed_steals += 1
        return None
Unsurprisingly, this leads to better load balancing:
Initial load distribution:
Worker 0: 12 tasks
Worker 1: 1 tasks
Worker 2: 2 tasks
Worker 3: 0 tasks
[0.0] Worker 0: Executing T12 (queue size: 11)
[0.0] Worker 1: Executing T13 (queue size: 0)
[0.0] Worker 2: Executing T15 (queue size: 1)
[0.0] Worker 3: Stole T1 from Worker 0 (target queue: 10)
[0.0] Worker 3: Executing T1 (queue size: 0)
[1.1] Worker 0: Completed T12
[1.1] Worker 0: Executing T11 (queue size: 9)
[1.4] Worker 3: Completed T1
[1.4] Worker 3: Stole T2 from Worker 0 (target queue: 8)
[1.4] Worker 3: Executing T2 (queue size: 0)
[1.5] Worker 2: Completed T15
[1.5] Worker 2: Executing T14 (queue size: 0)
[1.9] Worker 1: Completed T13
[1.9] Worker 1: Stole T3 from Worker 0 (target queue: 7)
[1.9] Worker 1: Executing T3 (queue size: 0)
...more...
=== Statistics ===
Total tasks executed: 15
Total tasks stolen: 8
Steal rate: 53.3%
Worker 0: executed=4, stolen=0, queue=0
Worker 1: executed=3, stolen=2, queue=0
Worker 2: executed=4, stolen=2, queue=0
Worker 3: executed=4, stolen=4, queue=0
Task Granularity
The granularity of tasks, i.e., how much work each one contains, has a big impact on performance. Many small tasks create a lot of scheduling overhead, while a few large tasks cause load imbalance. Using the code we have written so far, we can easily experiment with the effect of changing task size.
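A minimal driver for this experiment might look as follows. It is a sketch under two assumptions: the environment exposes the current simulation time as `env.now`, and each worker records the completion time of its latest task in a hypothetical `scheduler.last_completion` field (a one-line addition to `execute_task`):

def run_granularity_experiment(granularity: float, total_work: float = 50.0, num_workers: int = 4):
    """Measure speedup and efficiency for a given task granularity."""
    env = Environment()
    scheduler = WorkStealingScheduler(env, num_workers=num_workers, verbose=False)
    scheduler.last_completion = 0.0  # hypothetical field, updated by workers
    for _ in range(int(total_work / granularity)):
        scheduler.submit_task(duration=granularity)
    # The serial execution time is an upper bound on the parallel wall
    # time, so it is a safe simulation horizon.
    env.run(until=total_work)
    wall_time = scheduler.last_completion
    speedup = total_work / wall_time
    print("=== Performance Analysis ===")
    print(f"Granularity: {granularity}s")
    print(f"Total work: {total_work:.1f}s")
    print(f"Wall time: {wall_time:.2f}s")
    print(f"Speedup: {speedup:.2f}x")
    print(f"Efficiency: {speedup / num_workers:.1%}")
    scheduler.get_statistics()

Running it for granularities of 0.1s, 0.5s, and 2.0s produces output like this: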
=== Performance Analysis ===
Granularity: 0.1s
Total work: 50.0s
Wall time: 13.00s
Speedup: 3.85x
Efficiency: 96.2%
=== Statistics ===
Total tasks executed: 500
Total tasks stolen: 14
Steal rate: 2.8%
Worker 0: executed=125, stolen=0, queue=0
Worker 1: executed=125, stolen=7, queue=0
Worker 2: executed=125, stolen=2, queue=0
Worker 3: executed=125, stolen=5, queue=0
=== Performance Analysis ===
Granularity: 0.5s
Total work: 50.0s
Wall time: 13.00s
Speedup: 3.85x
Efficiency: 96.2%
=== Statistics ===
Total tasks executed: 100
Total tasks stolen: 13
Steal rate: 13.0%
Worker 0: executed=25, stolen=2, queue=0
Worker 1: executed=25, stolen=8, queue=0
Worker 2: executed=25, stolen=3, queue=0
Worker 3: executed=25, stolen=0, queue=0
=== Performance Analysis ===
Granularity: 2.0s
Total work: 50.0s
Wall time: 15.00s
Speedup: 3.33x
Efficiency: 83.3%
=== Statistics ===
Total tasks executed: 25
Total tasks stolen: 1
Steal rate: 4.0%
Worker 0: executed=7, stolen=1, queue=0
Worker 1: executed=6, stolen=0, queue=0
Worker 2: executed=6, stolen=0, queue=0
Worker 3: executed=6, stolen=0, queue=0
Our implementations demonstrate the core concepts of work stealing, but production systems go further. In particular, they try to prevent livelock by limiting how long a worker searches for victims, and use exponential backoff rather than spinning continuously when trying to steal work.
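For example, here is one way the worker loop from earlier could back off exponentially instead of polling at a fixed 0.1s interval (the constants are arbitrary choices):

async def run(self):
    """Worker loop that backs off exponentially when no work is found."""
    delay = 0.01  # initial backoff delay (arbitrary)
    while True:
        task = self.deque.pop_bottom()
        if task is None:
            task = await self.try_steal()
        if task:
            await self.execute_task(task)
            delay = 0.01  # reset the backoff after doing useful work
        else:
            await self.timeout(delay)
            delay = min(delay * 2, 1.0)  # double the wait, capped at 1s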
Exercises
FIXME: add exercises.