线程(Threads)
Rust 提供了强大的并发编程支持,通过线程(Threads)可以让程序同时执行多个任务。Rust 的所有权系统配合 `Send`/`Sync` trait,能在编译期防止数据竞争(data race);但死锁、逻辑上的竞态条件等问题仍需要程序员自己避免。
创建线程
基本线程创建
use std::thread;
use std::time::Duration;
fn main() {
    // Spawn a detached thread: its JoinHandle is discarded, so nothing waits
    // for it. When `main` returns the whole process exits, so the spawned
    // thread may be cut off before it prints all nine lines.
    thread::spawn(|| {
        (1..10).for_each(|n| {
            println!("hi number {} from the spawned thread!", n);
            thread::sleep(Duration::from_millis(1));
        });
    });

    // Meanwhile the main thread prints its own, shorter sequence.
    (1..5).for_each(|n| {
        println!("hi number {} from the main thread!", n);
        thread::sleep(Duration::from_millis(1));
    });
}
等待线程完成
use std::thread;
use std::time::Duration;
fn main() {
    // Keep the JoinHandle this time so the main thread can wait for the worker.
    let worker = thread::spawn(|| {
        for n in 1..10 {
            println!("hi number {} from the spawned thread!", n);
            thread::sleep(Duration::from_millis(1));
        }
    });

    (1..5).for_each(|n| {
        println!("hi number {} from the main thread!", n);
        thread::sleep(Duration::from_millis(1));
    });

    // Block until the spawned thread finishes; `unwrap` panics if it panicked.
    worker.join().unwrap();
}
线程间数据传递
使用 move 闭包
use std::thread;
fn main() {
    let numbers = vec![1, 2, 3];
    // `move` transfers ownership of `numbers` into the closure. A plain borrow
    // would be rejected because `thread::spawn` requires a `'static` closure.
    thread::spawn(move || {
        println!("Here's a vector: {:?}", numbers);
    })
    .join()
    .unwrap();
}
返回值
use std::thread;
fn main() {
    // Whatever the closure evaluates to becomes the thread's result,
    // handed back to the caller by `join()`.
    let worker = thread::spawn(|| (1..=100).sum::<i32>());
    let total = worker.join().unwrap();
    println!("Sum: {}", total);
}
线程池
虽然标准库没有提供线程池,但我们可以创建一个简单的实现:
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
/// A boxed closure that can be sent to a worker thread and run exactly once.
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A fixed-size pool of worker threads that execute submitted closures.
///
/// Jobs travel over an mpsc channel whose single receiver is shared by all
/// workers behind `Arc<Mutex<_>>`. Dropping the pool closes the channel and
/// joins every worker, so all jobs submitted before the drop finish first.
pub struct ThreadPool {
    workers: Vec<Worker>,
    // `Option` so `Drop` can take and drop the sender, which closes the channel.
    sender: Option<mpsc::Sender<Job>>,
}

impl ThreadPool {
    /// Creates a pool with `size` worker threads.
    ///
    /// # Panics
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);
        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);
        for id in 0..size {
            workers.push(Worker::new(id, Arc::clone(&receiver)));
        }
        ThreadPool {
            workers,
            sender: Some(sender),
        }
    }

    /// Queues `f` to be run by the next idle worker.
    ///
    /// # Panics
    /// Panics if every worker thread has already exited, so the job cannot be
    /// delivered.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);
        self.sender
            .as_ref()
            .expect("sender is only taken in Drop")
            .send(job)
            .unwrap();
    }
}

impl Drop for ThreadPool {
    /// Closes the job channel and waits for every worker to drain it and exit.
    fn drop(&mut self) {
        // Dropping the only sender disconnects the channel. Each worker keeps
        // receiving buffered jobs, then gets `Err` from `recv` and leaves its loop.
        drop(self.sender.take());
        for worker in &mut self.workers {
            if let Some(handle) = worker.thread.take() {
                // A worker whose job panicked has already died; report it rather
                // than panicking inside `drop`.
                if handle.join().is_err() {
                    eprintln!("Worker {} panicked while running a job", worker.id);
                }
            }
        }
    }
}

/// One pool thread; `id` is used only in log messages.
struct Worker {
    id: usize,
    // `Option` so `ThreadPool::drop` can take the handle and join it.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawns a thread that repeatedly pulls jobs from the shared receiver
    /// until the channel is closed.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            // The MutexGuard is a temporary dropped at the end of this `let`
            // statement. The lock is therefore released before the job runs,
            // so other workers can take jobs concurrently.
            let message = receiver.lock().unwrap().recv();
            match message {
                Ok(job) => {
                    println!("Worker {} got a job; executing.", id);
                    job();
                }
                // All senders are gone: the pool is shutting down.
                Err(_) => break,
            }
        });
        Worker {
            id,
            thread: Some(thread),
        }
    }
}

fn main() {
    let pool = ThreadPool::new(4);
    for i in 0..8 {
        pool.execute(move || {
            println!("Task {} is running", i);
            thread::sleep(std::time::Duration::from_secs(1));
            println!("Task {} completed", i);
        });
    }
    // No fixed sleep is needed: dropping the pool joins every worker, and each
    // worker finishes all queued tasks before it exits.
    drop(pool);
}
线程局部存储
use std::cell::RefCell;
use std::thread;
thread_local! {
    // Every thread that touches FOO gets its own independent copy, starting at 1.
    static FOO: RefCell<u32> = RefCell::new(1);
}

fn main() {
    // Main thread: observe the initial value, then overwrite this thread's copy.
    FOO.with(|cell| {
        assert_eq!(*cell.borrow(), 1);
        *cell.borrow_mut() = 2;
    });

    // A new thread sees a fresh copy (still 1); its write of 3 stays local to it.
    // Join immediately so the check below runs after the thread is done.
    thread::spawn(|| {
        FOO.with(|cell| {
            assert_eq!(*cell.borrow(), 1);
            *cell.borrow_mut() = 3;
        });
    })
    .join()
    .unwrap();

    // The main thread's copy is untouched by the other thread's write.
    FOO.with(|cell| assert_eq!(*cell.borrow(), 2));
}
实际应用示例
并行计算
use std::thread;
use std::sync::mpsc;
/// Sums `numbers` by splitting it into contiguous chunks.
///
/// Each chunk is summed on its own thread, and the partial sums are collected
/// over an mpsc channel.
///
/// `num_threads` is an upper bound on the threads spawned. `0` is treated as
/// `1`, and no more threads than elements are spawned, so no thread gets an
/// empty chunk. An empty input returns `0`.
///
/// # Panics
/// Panics if a worker thread panics (e.g. on `i32` overflow in a debug build).
fn parallel_sum(numbers: Vec<i32>, num_threads: usize) -> i32 {
    let num_threads = num_threads.max(1).min(numbers.len().max(1));
    // Ceiling division, so the chunks cover every element in at most
    // `num_threads` pieces. `.max(1)` keeps `chunks` valid for empty input.
    let chunk_size = ((numbers.len() + num_threads - 1) / num_threads).max(1);

    let (tx, rx) = mpsc::channel();
    let mut handles = Vec::with_capacity(num_threads);
    for chunk in numbers.chunks(chunk_size) {
        let tx = tx.clone();
        // Owned copy: spawned threads cannot borrow from this stack frame.
        let chunk = chunk.to_vec();
        handles.push(thread::spawn(move || {
            let sum: i32 = chunk.iter().sum();
            tx.send(sum).unwrap();
        }));
    }
    // Close our own sender so the receive loop ends once every worker's sender is gone.
    drop(tx);

    let total: i32 = rx.iter().sum();
    // A worker that panicked drops its sender without sending, which would
    // silently shrink `total`. Joining re-raises that failure instead.
    for handle in handles {
        handle.join().expect("parallel_sum worker thread panicked");
    }
    total
}
fn main() {
    let input = (1..=1000).collect::<Vec<i32>>();
    let parallel = parallel_sum(input, 4);
    println!("并行计算结果: {}", parallel);

    // Cross-check against a plain sequential sum.
    let expected = (1..=1000).sum::<i32>();
    println!("期望结果: {}", expected);
    assert_eq!(parallel, expected);
}
生产者-消费者模式
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
fn main() {
    let (sender, receiver) = mpsc::channel();

    // Producer: sends 1..=10, pausing 100 ms after each item.
    let producer = thread::spawn(move || {
        (1..=10).for_each(|item| {
            println!("生产者: 生产项目 {}", item);
            sender.send(item).unwrap();
            thread::sleep(Duration::from_millis(100));
        });
        println!("生产者: 完成生产");
    });

    // Consumer: spends 150 ms per item, i.e. slower than production. The loop
    // ends once the producer's sender is dropped and the buffer is drained.
    let consumer = thread::spawn(move || {
        for item in receiver.iter() {
            println!("消费者: 消费项目 {}", item);
            thread::sleep(Duration::from_millis(150));
        }
        println!("消费者: 完成消费");
    });

    producer.join().unwrap();
    consumer.join().unwrap();
}
多生产者单消费者
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
fn main() {
    let (tx, rx) = mpsc::channel();

    // Three producer threads, each holding its own clone of the sender.
    (0..3).for_each(|producer_id| {
        let tx = tx.clone();
        thread::spawn(move || {
            for seq in 1..=5 {
                let item = format!("生产者{}-项目{}", producer_id, seq);
                println!("发送: {}", item);
                tx.send(item).unwrap();
                thread::sleep(Duration::from_millis(100));
            }
        });
    });

    // The channel closes only when *every* sender is dropped. Drop the original
    // here, or the consumer loop below would never terminate.
    drop(tx);

    let consumer = thread::spawn(move || {
        rx.iter().for_each(|received| {
            println!("接收: {}", received);
            thread::sleep(Duration::from_millis(50));
        });
    });
    consumer.join().unwrap();
}
工作窃取队列
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
/// Per-worker FIFO queues where an idle worker may take work from the others.
struct WorkStealingQueue {
    // One mutex-protected deque per worker, indexed by worker id.
    queues: Vec<Arc<Mutex<VecDeque<i32>>>>,
}

impl WorkStealingQueue {
    /// Creates `num_workers` empty queues.
    fn new(num_workers: usize) -> Self {
        let queues = (0..num_workers)
            .map(|_| Arc::new(Mutex::new(VecDeque::new())))
            .collect();
        WorkStealingQueue { queues }
    }

    /// Appends `work` to the back of queue `worker_id` (panics if out of range).
    fn add_work(&self, worker_id: usize, work: i32) {
        let mut own = self.queues[worker_id].lock().unwrap();
        own.push_back(work);
    }

    /// Returns the next item for `worker_id`.
    ///
    /// Takes from the front of the worker's own queue first. If that is empty,
    /// it takes from the back of the first non-empty other queue, scanning
    /// queues in index order. Only one lock is held at a time.
    fn steal_work(&self, worker_id: usize) -> Option<i32> {
        // The guard is a temporary, released at the end of this statement.
        let own = self.queues[worker_id].lock().unwrap().pop_front();
        if own.is_some() {
            return own;
        }
        self.queues
            .iter()
            .enumerate()
            .filter(|&(idx, _)| idx != worker_id)
            .find_map(|(_, queue)| queue.lock().unwrap().pop_back())
    }
}
fn main() {
    let num_workers = 3;
    let queue = Arc::new(WorkStealingQueue::new(num_workers));

    // All work goes onto worker 0's queue; workers 1 and 2 start with empty
    // queues and can only obtain work by stealing.
    (1..=10).for_each(|item| queue.add_work(0, item));

    // Spawn every worker before joining any of them.
    let handles: Vec<_> = (0..num_workers)
        .map(|worker_id| {
            let queue = Arc::clone(&queue);
            thread::spawn(move || {
                while let Some(work) = queue.steal_work(worker_id) {
                    println!("工作者 {} 处理工作 {}", worker_id, work);
                    thread::sleep(Duration::from_millis(100));
                }
                // Simplification: exit on the first empty poll. A real worker
                // would keep waiting for new work.
                println!("工作者 {} 没有工作,休息一下", worker_id);
                thread::sleep(Duration::from_millis(50));
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
}
线程安全的数据结构
原子计数器
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
fn main() {
    let counter = Arc::new(AtomicUsize::new(0));

    // Ten threads each add 1 a thousand times. `fetch_add` is an atomic
    // read-modify-write, so no increment is lost and no Mutex is needed.
    let handles: Vec<_> = (0..10)
        .map(|_| {
            let counter = Arc::clone(&counter);
            thread::spawn(move || {
                (0..1000).for_each(|_| {
                    counter.fetch_add(1, Ordering::SeqCst);
                });
            })
        })
        .collect();

    handles.into_iter().for_each(|h| h.join().unwrap());
    println!("最终计数: {}", counter.load(Ordering::SeqCst));
}
线程安全的队列
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
/// A FIFO queue whose `pop` blocks until an item is available.
struct ThreadSafeQueue<T> {
    queue: Mutex<VecDeque<T>>,
    // Signalled by `push` to wake one thread blocked in `pop`.
    condition: Condvar,
}

impl<T> ThreadSafeQueue<T> {
    /// Creates an empty queue.
    fn new() -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            condition: Condvar::new(),
        }
    }

    /// Appends `item` and wakes one waiting consumer (notifies while still
    /// holding the lock).
    fn push(&self, item: T) {
        let mut guard = self.queue.lock().unwrap();
        guard.push_back(item);
        self.condition.notify_one();
    }

    /// Removes and returns the front item, blocking while the queue is empty.
    fn pop(&self) -> T {
        let guard = self.queue.lock().unwrap();
        // `wait_while` re-checks the predicate after every wakeup, so spurious
        // wakeups are handled the same way as with a manual `while` loop.
        let mut guard = self
            .condition
            .wait_while(guard, |pending| pending.is_empty())
            .unwrap();
        guard.pop_front().unwrap()
    }

    /// Removes and returns the front item without blocking; `None` if empty.
    fn try_pop(&self) -> Option<T> {
        self.queue.lock().unwrap().pop_front()
    }
}
fn main() {
    let queue = Arc::new(ThreadSafeQueue::new());

    // Producer: pushes 1..=5, one item every 100 ms.
    let producer = {
        let queue = Arc::clone(&queue);
        thread::spawn(move || {
            for i in 1..=5 {
                queue.push(i);
                println!("生产: {}", i);
                thread::sleep(std::time::Duration::from_millis(100));
            }
        })
    };

    // Consumer: blocks in `pop` until each of the five items arrives.
    let consumer = {
        let queue = Arc::clone(&queue);
        thread::spawn(move || {
            (1..=5).for_each(|_| {
                let item = queue.pop();
                println!("消费: {}", item);
            });
        })
    };

    producer.join().unwrap();
    consumer.join().unwrap();
}
最佳实践
- 避免共享可变状态:优先使用消息传递
- 使用 Arc 和 Mutex:当必须共享状态时
- 避免死锁:始终以相同顺序获取锁
- 使用原子类型:对于简单的共享数据
- 合理设置线程数量:对于 CPU 密集型任务,通常等于 CPU 核心数(可通过标准库的 `std::thread::available_parallelism()` 获取)
use std::sync::{Arc, Mutex};
use std::thread;
// 好的实践示例
/// Doubles every element of `data` using one thread per available CPU.
///
/// The results come back in the same order as the input. Each thread
/// processes a contiguous slice of the shared, read-only input and returns its
/// partial result through its `JoinHandle`. Joining the handles in spawn order
/// and concatenating keeps the output ordered without any shared mutable state.
///
/// # Panics
/// Panics if a worker thread panics (e.g. `i32` overflow when doubling in a
/// debug build).
fn process_data_parallel(data: Vec<i32>) -> Vec<i32> {
    // Standard-library replacement for `num_cpus::get()`. Falls back to one
    // thread if the available parallelism cannot be queried.
    let num_threads = thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1);
    let chunk_size = data.len() / num_threads;
    // Immutable input only needs `Arc`, not `Arc<Mutex<_>>`.
    let data = Arc::new(data);

    // Spawn every worker before joining any, so they run concurrently.
    let handles: Vec<thread::JoinHandle<Vec<i32>>> = (0..num_threads)
        .map(|i| {
            let data = Arc::clone(&data);
            thread::spawn(move || {
                let start = i * chunk_size;
                // The last thread also takes the remainder of the integer division.
                let end = if i == num_threads - 1 {
                    data.len()
                } else {
                    (i + 1) * chunk_size
                };
                data[start..end].iter().map(|&x| x * 2).collect() // 简单的处理
            })
        })
        .collect();

    // Joining in spawn order reassembles the chunks in input order.
    let mut results = Vec::with_capacity(data.len());
    for handle in handles {
        results.extend(handle.join().expect("worker thread panicked"));
    }
    results
}
fn main() {
    // Build the input inline; the parameter type fixes it as Vec<i32>.
    let processed = process_data_parallel((1..=1000).collect());
    println!("处理了 {} 个元素", processed.len());
}
线程是 Rust 并发编程的基础,通过合理使用线程可以充分利用多核处理器的性能。在下一节中,我们将学习消息传递,这是 Rust 推荐的并发编程方式。