并发实战:任务调度器
并发实战:任务调度器
想象一个快递分拣中心:传送带不断送来包裹(任务),多个分拣员(goroutine)同时工作,有人负责扫描、有人负责分类、有人负责装车。每个环节通过流水线(channel)传递包裹,调度员(scheduler)监控全局进度,一旦某个环节出问题就及时处理。这就是并发任务调度器的运作方式——多个工人协同完成大量任务,既高效又可靠。
本节课我们将综合前面学到的 goroutine、channel、select、sync 包和 context.Context,从零开发一个功能完整的并发任务调度器。
项目需求
我们需要构建一个任务调度器,满足以下要求:
- 并发执行:支持多个 worker 同时处理任务
- 任务提交:外部可以动态提交任务
- 超时控制:单个任务执行超时后自动取消
- 失败重试:任务失败后自动重试,最多重试 N 次
- 结果收集:并发安全地收集所有任务的执行结果
- 优雅关闭:接收到取消信号后,等待正在执行的任务完成再退出
系统设计
整体架构分为三层:
┌─────────────┐ ┌──────────────┐ ┌─────────────┐
│ 任务提交 │────▶│ 调度引擎 │────▶│ Worker 池 │
│ (Producer) │ │ (Scheduler) │ │ (Workers) │
└─────────────┘ └──────────────┘ └─────────────┘
│ │
▼ ▼
┌──────────────┐ ┌─────────────┐
│ 结果收集 │◀────│ 任务执行 │
│ (Collector) │ │ (Executor) │
└──────────────┘ └─────────────┘
核心组件:
| 组件 | 职责 | 实现方式 |
|---|---|---|
Task |
表示一个可执行的任务 | 结构体 + 函数类型 |
Result |
保存任务执行结果 | 结构体 + channel |
Worker |
执行具体的任务 | goroutine |
Scheduler |
协调任务分发与执行 | channel + select |
ResultCollector |
并发安全地收集结果 | sync.Mutex |
完整代码实现
1. 定义数据结构
package main
import (
"context"
"fmt"
"math/rand"
"sync"
"time"
)
// Task 表示一个待执行的任务
type Task struct {
ID int // 任务唯一标识
Payload string // 任务数据
Execute func(ctx context.Context) (string, error) // 任务执行函数
}
// Result 表示任务的执行结果
type Result struct {
TaskID int // 对应的任务 ID
Output string // 执行成功时的输出
Error error // 执行失败时的错误
Attempts int // 实际尝试次数
}
// ResultCollector 并发安全地收集任务结果
type ResultCollector struct {
mu sync.Mutex
results []Result
}
// Add 添加一个结果(并发安全)
func (rc *ResultCollector) Add(r Result) {
rc.mu.Lock()
defer rc.mu.Unlock()
rc.results = append(rc.results, r)
}
// All 返回所有已收集的结果
func (rc *ResultCollector) All() []Result {
rc.mu.Lock()
defer rc.mu.Unlock()
// 返回副本,避免外部修改
out := make([]Result, len(rc.results))
copy(out, rc.results)
return out
}
2. 实现 Worker
// Worker 从任务通道接收任务并执行,结果发送到结果通道
func Worker(
ctx context.Context,
id int,
tasks <-chan Task,
results chan<- Result,
maxRetries int,
wg *sync.WaitGroup,
) {
defer wg.Done()
for task := range tasks {
// 检查是否已取消
select {
case <-ctx.Done():
fmt.Printf("[Worker %d] 上下文已取消,退出\n", id)
return
default:
}
fmt.Printf("[Worker %d] 开始执行任务 #%d\n", id, task.ID)
var output string
var err error
attempts := 0
// 带重试的执行逻辑
for attempts < maxRetries {
attempts++
// 为每个任务创建带超时的子上下文(最多 2 秒)
taskCtx, taskCancel := context.WithTimeout(ctx, 2*time.Second)
output, err = task.Execute(taskCtx)
taskCancel()
if err == nil {
break // 成功,跳出重试循环
}
fmt.Printf("[Worker %d] 任务 #%d 第 %d 次失败: %v\n",
id, task.ID, attempts, err)
if attempts < maxRetries {
// 等待一段时间再重试,同时监听取消信号
select {
case <-ctx.Done():
fmt.Printf("[Worker %d] 重试等待中收到取消信号\n", id)
results <- Result{TaskID: task.ID, Error: ctx.Err(), Attempts: attempts}
return
case <-time.After(time.Duration(attempts) * 500 * time.Millisecond):
// 指数退避:第1次等500ms,第2次等1s,第3次等1.5s ...
}
}
}
results <- Result{
TaskID: task.ID,
Output: output,
Error: err,
Attempts: attempts,
}
if err != nil {
fmt.Printf("[Worker %d] 任务 #%d 最终失败(尝试 %d 次)\n", id, task.ID, attempts)
} else {
fmt.Printf("[Worker %d] 任务 #%d 完成\n", id, task.ID)
}
}
fmt.Printf("[Worker %d] 任务通道已关闭,退出\n", id)
}
3. 实现调度器
// Scheduler 是任务调度引擎
type Scheduler struct {
workerCount int // worker 数量
maxRetries int // 最大重试次数
taskTimeout time.Duration // 单任务超时
}
// NewScheduler 创建一个新的调度器
func NewScheduler(workerCount, maxRetries int, taskTimeout time.Duration) *Scheduler {
return &Scheduler{
workerCount: workerCount,
maxRetries: maxRetries,
taskTimeout: taskTimeout,
}
}
// Run 启动调度器,处理给定的任务列表,返回所有结果
func (s *Scheduler) Run(ctx context.Context, taskList []Task) []Result {
// 创建通道
tasks := make(chan Task, len(taskList))
results := make(chan Result, len(taskList))
// 结果收集器
collector := &ResultCollector{}
// 启动结果收集 goroutine
var collectorWg sync.WaitGroup
collectorWg.Add(1)
go func() {
defer collectorWg.Done()
for r := range results {
collector.Add(r)
}
}()
// 启动 worker 池
var workerWg sync.WaitGroup
for i := 1; i <= s.workerCount; i++ {
workerWg.Add(1)
go Worker(ctx, i, tasks, results, s.maxRetries, &workerWg)
}
// 提交所有任务
go func() {
for _, task := range taskList {
select {
case tasks <- task:
fmt.Printf("[调度器] 已提交任务 #%d\n", task.ID)
case <-ctx.Done():
fmt.Println("[调度器] 上下文取消,停止提交任务")
break
}
}
close(tasks) // 关闭任务通道,通知 worker 不再有新任务
}()
// 等待所有 worker 完成
workerWg.Wait()
close(results) // 关闭结果通道
// 等待收集器处理完所有结果
collectorWg.Wait()
return collector.All()
}
4. 模拟任务与主函数
// createDemoTasks 创建一组模拟任务,部分会失败、部分会超时
func createDemoTasks(count int) []Task {
tasks := make([]Task, count)
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := 0; i < count; i++ {
id := i + 1
behavior := rng.Intn(4) // 0-3,决定任务行为
switch behavior {
case 0:
// 正常完成
tasks[i] = Task{
ID: id,
Payload: fmt.Sprintf("正常任务-%d", id),
Execute: func(ctx context.Context) (string, error) {
time.Sleep(time.Duration(100+rng.Intn(400)) * time.Millisecond)
return fmt.Sprintf("任务 #%d 成功", id), nil
},
}
case 1:
// 偶尔失败(第1次失败,第2次成功)
failCount := 0
var mu sync.Mutex
tasks[i] = Task{
ID: id,
Payload: fmt.Sprintf("不稳定任务-%d", id),
Execute: func(ctx context.Context) (string, error) {
mu.Lock()
failCount++
current := failCount
mu.Unlock()
time.Sleep(time.Duration(50+rng.Intn(200)) * time.Millisecond)
if current <= 1 {
return nil, fmt.Errorf("任务 #%d 模拟失败", id)
}
return fmt.Sprintf("任务 #%d 第 %d 次尝试成功", id, current), nil
},
}
case 2:
// 执行时间过长(会触发超时)
tasks[i] = Task{
ID: id,
Payload: fmt.Sprintf("慢任务-%d", id),
Execute: func(ctx context.Context) (string, error) {
select {
case <-time.After(5 * time.Second): // 远超 2 秒超时
return fmt.Sprintf("任务 #%d 完成", id), nil
case <-ctx.Done():
return "", fmt.Errorf("任务 #%d 被取消: %w", id, ctx.Err())
}
},
}
default:
// 始终失败
tasks[i] = Task{
ID: id,
Payload: fmt.Sprintf("失败任务-%d", id),
Execute: func(ctx context.Context) (string, error) {
time.Sleep(time.Duration(50+rng.Intn(150)) * time.Millisecond)
return nil, fmt.Errorf("任务 #%d 不可恢复的错误", id)
},
}
}
}
return tasks
}
func main() {
fmt.Println("=== 并发任务调度器 ===")
fmt.Println()
// 创建支持取消的上下文
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// 监听操作系统中断信号(Ctrl+C),触发优雅关闭
go func() {
// 模拟:8 秒后发送取消信号(实际项目中可用 signal.Notify)
time.Sleep(8 * time.Second)
fmt.Println("\n[主程序] 收到取消信号,开始优雅关闭...")
cancel()
}()
// 创建调度器:3 个 worker,最多重试 3 次,单任务超时 2 秒
scheduler := NewScheduler(3, 3, 2*time.Second)
// 生成 10 个模拟任务
taskList := createDemoTasks(10)
// 执行调度
fmt.Printf("[主程序] 提交 %d 个任务,启动 %d 个 worker\n\n", len(taskList), 3)
start := time.Now()
results := scheduler.Run(ctx, taskList)
elapsed := time.Since(start)
// 输出结果统计
fmt.Println("\n=============================")
fmt.Println(" 执行结果汇总 ")
fmt.Println("=============================")
successCount := 0
failCount := 0
cancelCount := 0
for _, r := range results {
status := "✓"
detail := r.Output
if r.Error != nil {
if ctx.Err() != nil {
status = "⊘"
detail = "已取消"
cancelCount++
} else {
status = "✗"
detail = r.Error.Error()
failCount++
}
} else {
successCount++
}
fmt.Printf(" %s 任务 #%02d | 尝试: %d 次 | %s\n",
status, r.TaskID, r.Attempts, detail)
}
fmt.Println("-----------------------------")
fmt.Printf(" 总计: %d | 成功: %d | 失败: %d | 取消: %d\n",
len(results), successCount, failCount, cancelCount)
fmt.Printf(" 耗时: %v\n", elapsed)
}
5. 运行效果
go run main.go
输出示例(每次运行结果随机):
=== 并发任务调度器 ===
[主程序] 提交 10 个任务,启动 3 个 worker
[调度器] 已提交任务 #1
[调度器] 已提交任务 #2
[调度器] 已提交任务 #3
[调度器] 已提交任务 #4
[调度器] 已提交任务 #5
[调度器] 已提交任务 #6
[调度器] 已提交任务 #7
[调度器] 已提交任务 #8
[调度器] 已提交任务 #9
[调度器] 已提交任务 #10
[Worker 1] 开始执行任务 #1
[Worker 2] 开始执行任务 #2
[Worker 3] 开始执行任务 #3
[Worker 1] 任务 #1 完成
[Worker 1] 开始执行任务 #4
[Worker 2] 任务 #2 第 1 次失败: 任务 #2 模拟失败
[Worker 3] 任务 #3 完成
[Worker 3] 开始执行任务 #5
[Worker 2] 开始执行任务 #6
[Worker 2] 任务 #6 第 1 次失败: 任务 #6 不可恢复的错误
[Worker 3] 任务 #5 完成
[Worker 3] 开始执行任务 #7
[Worker 2] 任务 #6 第 2 次失败: 任务 #6 不可恢复的错误
[Worker 1] 任务 #4 第 1 次失败: 任务 #4 模拟失败
[Worker 2] 任务 #6 第 3 次失败: 任务 #6 不可恢复的错误
[Worker 2] 任务 #6 最终失败(尝试 3 次)
[Worker 2] 开始执行任务 #8
...
[Worker 1] 任务通道已关闭,退出
[Worker 2] 任务通道已关闭,退出
[Worker 3] 任务通道已关闭,退出
=============================
执行结果汇总
=============================
✓ 任务 #01 | 尝试: 1 次 | 任务 #1 成功
⊘ 任务 #02 | 尝试: 2 次 | 已取消
✓ 任务 #03 | 尝试: 1 次 | 任务 #3 成功
⊘ 任务 #04 | 尝试: 2 次 | 已取消
✓ 任务 #05 | 尝试: 1 次 | 任务 #5 成功
✗ 任务 #06 | 尝试: 3 次 | 任务 #6 不可恢复的错误
✓ 任务 #07 | 尝试: 1 次 | 任务 #7 成功
⊘ 任务 #08 | 尝试: 1 次 | 已取消
✓ 任务 #09 | 尝试: 1 次 | 任务 #9 成功
⊘ 任务 #10 | 尝试: 1 次 | 已取消
-----------------------------
总计: 10 | 成功: 5 | 失败: 1 | 取消: 4
耗时: 8.012s
代码解析
核心并发模式
1. Fan-out / Fan-in(扇出/扇入)
这是本调度器的核心模式:
┌─ Worker 1 ─┐
Tasks ─────┼─ Worker 2 ─┼───── Results
└─ Worker 3 ─┘
- 扇出(Fan-out):多个 worker 从同一个
taskschannel 读取,自动实现任务分发 - 扇入(Fan-in):所有 worker 将结果写入同一个
resultschannel,汇聚结果
2. context.Context 级联取消
根 Context (main)
└─ WithCancel
├─ Worker 1 的 WithTimeout
│ └─ taskCtx(单任务超时 2s)
├─ Worker 2 的 WithTimeout
│ └─ taskCtx(单任务超时 2s)
└─ Worker 3 的 WithTimeout
└─ taskCtx(单任务超时 2s)
当主程序调用 cancel() 时,所有子 context 都会收到取消信号,实现级联取消。
3. WaitGroup 同步等待
var wg sync.WaitGroup
wg.Add(1) // 每启动一个 worker,计数 +1
go func() {
defer wg.Done() // worker 退出时计数 -1
// ...
}()
wg.Wait() // 阻塞直到计数归零
4. select 多路复用
Worker 中同时监听多个信号源:
select {
case <-ctx.Done(): // 上下文取消
return
case <-time.After(d): // 重试退避
// 继续重试
}
重试退避策略
本例采用线性退避:每次重试等待时间 = attempts × 500ms。生产环境中推荐指数退避:
backoff := time.Duration(1<<uint(attempts)) * 100 * time.Millisecond
// 第1次: 200ms, 第2次: 400ms, 第3次: 800ms
还可以加入随机抖动(jitter)防止多个任务同时重试造成"惊群效应"。
❓ 常见问题
1. 为什么 Worker 用 range tasks 而不是直接检查 ctx.Done()?
range tasks 在 channel 关闭后自动退出循环,是标准的生产者-消费者模式。ctx.Done() 用于在等待或重试期间响应取消信号。两者配合使用:
- channel 关闭 → 正常退出(所有任务已提交完毕)
- context 取消 → 提前退出(外部要求终止)
如果只用 ctx.Done(),需要额外处理 channel 关闭的情况;只用 range 则无法响应取消。
2. results channel 的缓冲区大小怎么确定?
本例中 make(chan Result, len(taskList)) 使用任务总数作为缓冲区大小,确保所有结果都能写入而不阻塞。如果任务数量非常大(百万级),可以:
- 使用无缓冲 channel + 独立的收集 goroutine
- 分批处理,每批结束后重置 channel
- 使用
sync.Map或带锁的切片替代 channel
3. 如何实现真正的指数退避?
func exponentialBackoff(attempt int) time.Duration {
base := 100 * time.Millisecond
max := 10 * time.Second
backoff := base * time.Duration(1<<uint(attempt))
if backoff > max {
backoff = max
}
// 加入随机抖动:±25%
jitter := time.Duration(rand.Int63n(int64(backoff) / 2))
return backoff - backoff/4 + jitter
}
4. 如何限制同时执行的任务数量(并发度控制)?
本例通过 worker 数量自然限制了并发度。另一种方式是使用信号量:
sem := make(chan struct{}, 10) // 最多 10 个并发
for _, task := range taskList {
sem <- struct{}{} // 获取许可
go func(t Task) {
defer func() { <-sem }() // 释放许可
t.Execute(ctx)
}(task)
}
📖 小节
本节课我们综合运用了 Go 并发编程的核心技术:
| 技术 | 应用场景 | 关键代码 |
|---|---|---|
| goroutine | Worker 并行执行任务 | go Worker(...) |
| channel | 任务分发与结果收集 | tasks <-chan Task |
| select | 多路复用:取消、超时、退避 | select { case <-ctx.Done(): ... } |
| sync.WaitGroup | 等待所有 worker 完成 | wg.Wait() |
| sync.Mutex | 并发安全的结果收集 | rc.mu.Lock() |
| context.Context | 级联取消与超时控制 | context.WithTimeout(ctx, 2*time.Second) |
设计原则回顾:
- 使用 channel 传递数据,不要通过共享内存通信
- 使用 context 传播取消信号,贯穿整个调用链
- WaitGroup 用于等待一组 goroutine 完成
- Mutex 用于保护共享状态(本例中的结果收集器)
- 关闭 channel 作为广播信号,通知所有消费者
📝 作业
作业 1:添加优先级队列
修改调度器,支持任务优先级(高/中/低)。高优先级任务应优先被 worker 处理。
提示:可以使用多个 channel 或实现一个优先级队列结构。
作业 2:实现速率限制
为调度器添加速率限制功能,例如每秒最多执行 5 个任务(令牌桶算法)。
提示:使用 time.Ticker 或第三方库 golang.org/x/time/rate。
limiter := rate.NewLimiter(5, 1) // 每秒 5 个,突发 1 个
for task := range tasks {
limiter.Wait(ctx) // 等待获取令牌
task.Execute(ctx)
}
作业 3:添加任务进度回调
为调度器添加一个 OnProgress 回调函数,在每个任务完成时调用,报告当前进度(已完成/总数)。
type Scheduler struct {
// ... 其他字段
OnProgress func(completed, total int)
}



