共享状态并发(Shared State Concurrency)
虽然消息传递是 Rust 推荐的并发方式,但有时共享状态(Shared State)是更合适的选择。Rust 提供了多种同步原语来安全地共享状态。
Mutex (互斥锁)
Mutex 是 "mutual exclusion" 的缩写,它确保在任意时刻,只允许一个线程访问某些数据。
基本 Mutex 使用
use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
let counter = Arc::new(Mutex::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
let mut num = counter.lock().unwrap();
*num += 1;
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("Result: {}", *counter.lock().unwrap());
}
处理锁中毒
use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
let data = Arc::new(Mutex::new(0));
let data_clone = Arc::clone(&data);
let handle = thread::spawn(move || {
let mut num = data_clone.lock().unwrap();
*num += 1;
panic!("线程 panic!"); // 这会导致锁中毒
});
// 等待线程完成(会 panic)
let _ = handle.join();
// 尝试获取锁
match data.lock() {
Ok(num) => println!("值: {}", *num),
Err(poisoned) => {
println!("锁被中毒了,但我们可以恢复数据");
let num = poisoned.into_inner();
println!("恢复的值: {}", *num);
}
}
}
RwLock (读写锁)
RwLock 允许多个读者或一个写者,适合读多写少的场景。
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;
fn main() {
let data = Arc::new(RwLock::new(5));
// 创建多个读者线程
let mut handles = vec![];
for i in 0..5 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
let num = data.read().unwrap();
println!("读者 {} 读取到: {}", i, *num);
thread::sleep(Duration::from_millis(100));
});
handles.push(handle);
}
// 创建一个写者线程
let data_writer = Arc::clone(&data);
let writer_handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(50));
let mut num = data_writer.write().unwrap();
*num += 1;
println!("写者更新值为: {}", *num);
});
for handle in handles {
handle.join().unwrap();
}
writer_handle.join().unwrap();
println!("最终值: {}", *data.read().unwrap());
}
原子类型
对于简单的数据类型,可以使用原子操作,这比锁更高效。
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
fn main() {
let counter = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
for _ in 0..1000 {
counter.fetch_add(1, Ordering::SeqCst);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("Result: {}", counter.load(Ordering::SeqCst));
}
原子操作的内存顺序
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
fn main() {
let data = Arc::new(AtomicUsize::new(0));
let flag = Arc::new(AtomicBool::new(false));
let data_clone = Arc::clone(&data);
let flag_clone = Arc::clone(&flag);
// 写者线程
let writer = thread::spawn(move || {
data_clone.store(42, Ordering::Relaxed);
flag_clone.store(true, Ordering::Release); // Release 语义
});
// 读者线程
let reader = thread::spawn(move || {
while !flag.load(Ordering::Acquire) { // Acquire 语义
thread::yield_now();
}
let value = data.load(Ordering::Relaxed);
println!("读取到的值: {}", value);
});
writer.join().unwrap();
reader.join().unwrap();
}
Condvar (条件变量)
条件变量用于线程间的协调,允许线程等待某个条件成立。
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;
fn main() {
let pair = Arc::new((Mutex::new(false), Condvar::new()));
let pair2 = Arc::clone(&pair);
// 等待线程
thread::spawn(move || {
let (lock, cvar) = &*pair2;
let mut started = lock.lock().unwrap();
while !*started {
println!("等待条件...");
started = cvar.wait(started).unwrap();
}
println!("条件满足,继续执行!");
});
// 主线程等待一段时间后设置条件
thread::sleep(Duration::from_millis(1000));
let (lock, cvar) = &*pair;
let mut started = lock.lock().unwrap();
*started = true;
cvar.notify_one();
thread::sleep(Duration::from_millis(100)); // 等待打印完成
}
生产者-消费者模式
use std::collections::VecDeque;
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;
struct Buffer<T> {
queue: Mutex<VecDeque<T>>,
not_empty: Condvar,
not_full: Condvar,
capacity: usize,
}
impl<T> Buffer<T> {
fn new(capacity: usize) -> Self {
Buffer {
queue: Mutex::new(VecDeque::new()),
not_empty: Condvar::new(),
not_full: Condvar::new(),
capacity,
}
}
fn put(&self, item: T) {
let mut queue = self.queue.lock().unwrap();
while queue.len() == self.capacity {
queue = self.not_full.wait(queue).unwrap();
}
queue.push_back(item);
self.not_empty.notify_one();
}
fn take(&self) -> T {
let mut queue = self.queue.lock().unwrap();
while queue.is_empty() {
queue = self.not_empty.wait(queue).unwrap();
}
let item = queue.pop_front().unwrap();
self.not_full.notify_one();
item
}
}
fn main() {
let buffer = Arc::new(Buffer::new(5));
// 生产者线程
let producer_buffer = Arc::clone(&buffer);
let producer = thread::spawn(move || {
for i in 1..=10 {
producer_buffer.put(i);
println!("生产: {}", i);
thread::sleep(Duration::from_millis(100));
}
});
// 消费者线程
let consumer_buffer = Arc::clone(&buffer);
let consumer = thread::spawn(move || {
for _ in 1..=10 {
let item = consumer_buffer.take();
println!("消费: {}", item);
thread::sleep(Duration::from_millis(150));
}
});
producer.join().unwrap();
consumer.join().unwrap();
}
Barrier (屏障)
Barrier 用于同步多个线程,让它们在某个点等待,直到所有线程都到达。
use std::sync::{Arc, Barrier};
use std::thread;
use std::time::Duration;
fn main() {
let n = 5;
let barrier = Arc::new(Barrier::new(n));
let mut handles = vec![];
for i in 0..n {
let barrier = Arc::clone(&barrier);
let handle = thread::spawn(move || {
println!("线程 {} 开始工作", i);
// 模拟不同的工作时间
thread::sleep(Duration::from_millis(100 * i as u64));
println!("线程 {} 完成工作,等待其他线程", i);
barrier.wait();
println!("线程 {} 继续执行", i);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
Once (一次性初始化)
Once 确保某个操作只执行一次,常用于单例模式。
use std::sync::Once;
static INIT: Once = Once::new();
static mut GLOBAL_DATA: Option<String> = None;
fn get_global_data() -> &'static str {
unsafe {
INIT.call_once(|| {
GLOBAL_DATA = Some("初始化的全局数据".to_string());
});
GLOBAL_DATA.as_ref().unwrap()
}
}
fn main() {
let data1 = get_global_data();
let data2 = get_global_data();
println!("数据1: {}", data1);
println!("数据2: {}", data2);
println!("地址相同: {}", std::ptr::eq(data1, data2));
}
实际应用示例
线程安全的计数器
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
struct ThreadSafeCounter {
value: Arc<Mutex<usize>>,
}
impl ThreadSafeCounter {
fn new() -> Self {
ThreadSafeCounter {
value: Arc::new(Mutex::new(0)),
}
}
fn increment(&self) {
let mut value = self.value.lock().unwrap();
*value += 1;
}
fn get(&self) -> usize {
*self.value.lock().unwrap()
}
fn add(&self, n: usize) {
let mut value = self.value.lock().unwrap();
*value += n;
}
}
impl Clone for ThreadSafeCounter {
fn clone(&self) -> Self {
ThreadSafeCounter {
value: Arc::clone(&self.value),
}
}
}
fn main() {
let counter = ThreadSafeCounter::new();
let mut handles = vec![];
let start = Instant::now();
// 创建多个线程同时操作计数器
for i in 0..10 {
let counter = counter.clone();
let handle = thread::spawn(move || {
for _ in 0..1000 {
counter.increment();
}
println!("线程 {} 完成", i);
});
handles.push(handle);
}
// 等待所有线程完成
for handle in handles {
handle.join().unwrap();
}
let duration = start.elapsed();
println!("最终计数: {}", counter.get());
println!("耗时: {:?}", duration);
}
缓存系统
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;
struct Cache<K, V> {
data: Arc<RwLock<HashMap<K, V>>>,
}
impl<K, V> Cache<K, V>
where
K: Clone + Eq + std::hash::Hash,
V: Clone,
{
fn new() -> Self {
Cache {
data: Arc::new(RwLock::new(HashMap::new())),
}
}
fn get(&self, key: &K) -> Option<V> {
let data = self.data.read().unwrap();
data.get(key).cloned()
}
fn insert(&self, key: K, value: V) {
let mut data = self.data.write().unwrap();
data.insert(key, value);
}
fn remove(&self, key: &K) -> Option<V> {
let mut data = self.data.write().unwrap();
data.remove(key)
}
fn len(&self) -> usize {
let data = self.data.read().unwrap();
data.len()
}
}
impl<K, V> Clone for Cache<K, V> {
fn clone(&self) -> Self {
Cache {
data: Arc::clone(&self.data),
}
}
}
fn main() {
let cache = Cache::new();
let mut handles = vec![];
// 写者线程
for i in 0..3 {
let cache = cache.clone();
let handle = thread::spawn(move || {
for j in 0..5 {
let key = format!("key_{}_{}", i, j);
let value = format!("value_{}_{}", i, j);
cache.insert(key.clone(), value);
println!("插入: {}", key);
thread::sleep(Duration::from_millis(10));
}
});
handles.push(handle);
}
// 读者线程
for i in 0..5 {
let cache = cache.clone();
let handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(50)); // 等待一些数据被插入
for j in 0..3 {
let key = format!("key_{}_{}", j, 0);
if let Some(value) = cache.get(&key) {
println!("读者 {} 读取 {}: {}", i, key, value);
}
thread::sleep(Duration::from_millis(20));
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("缓存大小: {}", cache.len());
}
工作窃取队列
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
struct WorkStealingQueue<T> {
queues: Vec<Arc<Mutex<VecDeque<T>>>>,
}
impl<T> WorkStealingQueue<T>
where
T: Send + 'static,
{
fn new(num_workers: usize) -> Self {
let mut queues = Vec::new();
for _ in 0..num_workers {
queues.push(Arc::new(Mutex::new(VecDeque::new())));
}
WorkStealingQueue { queues }
}
fn push(&self, worker_id: usize, item: T) {
self.queues[worker_id].lock().unwrap().push_back(item);
}
fn pop(&self, worker_id: usize) -> Option<T> {
self.queues[worker_id].lock().unwrap().pop_front()
}
fn steal(&self, worker_id: usize) -> Option<T> {
// 尝试从其他队列窃取工作
for (i, queue) in self.queues.iter().enumerate() {
if i != worker_id {
if let Ok(mut q) = queue.try_lock() {
if let Some(item) = q.pop_back() {
return Some(item);
}
}
}
}
None
}
}
fn main() {
let num_workers = 3;
let queue = Arc::new(WorkStealingQueue::new(num_workers));
// 向第一个工作者添加所有工作
for i in 1..=20 {
queue.push(0, i);
}
let mut handles = vec![];
for worker_id in 0..num_workers {
let queue = Arc::clone(&queue);
let handle = thread::spawn(move || {
let mut processed = 0;
loop {
// 首先尝试从自己的队列获取工作
if let Some(work) = queue.pop(worker_id) {
println!("工作者 {} 处理自己的工作: {}", worker_id, work);
processed += 1;
thread::sleep(Duration::from_millis(100));
} else if let Some(work) = queue.steal(worker_id) {
// 如果自己的队列为空,尝试窃取
println!("工作者 {} 窃取工作: {}", worker_id, work);
processed += 1;
thread::sleep(Duration::from_millis(100));
} else {
// 没有工作可做,稍等一下再试
thread::sleep(Duration::from_millis(50));
// 简化示例:如果连续几次都没有工作,就退出
if processed > 0 {
break;
}
}
}
println!("工作者 {} 完成,处理了 {} 个任务", worker_id, processed);
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
最佳实践
- 优先考虑消息传递:共享状态应该是最后的选择
- 选择合适的同步原语:
- 简单计数器:原子类型
- 读多写少:RwLock
- 复杂状态:Mutex
- 线程协调:Condvar, Barrier
- 避免死锁:
- 始终以相同顺序获取锁
- 尽快释放锁
- 考虑使用 try_lock
- 处理锁中毒:合理处理 panic 导致的锁中毒
- 性能考虑:
- 减少锁的粒度
- 使用无锁数据结构
- 考虑使用原子操作
use std::sync::{Arc, Mutex};
use std::thread;
// 好的实践示例:细粒度锁
struct BankAccount {
balance: Arc<Mutex<f64>>,
id: u32,
}
impl BankAccount {
fn new(id: u32, initial_balance: f64) -> Self {
BankAccount {
balance: Arc::new(Mutex::new(initial_balance)),
id,
}
}
fn transfer(&self, to: &BankAccount, amount: f64) -> Result<(), String> {
// 避免死锁:总是按 ID 顺序获取锁
let (first, second) = if self.id < to.id {
(&self.balance, &to.balance)
} else {
(&to.balance, &self.balance)
};
let mut first_balance = first.lock().unwrap();
let mut second_balance = second.lock().unwrap();
let (from_balance, to_balance) = if self.id < to.id {
(&mut *first_balance, &mut *second_balance)
} else {
(&mut *second_balance, &mut *first_balance)
};
if *from_balance >= amount {
*from_balance -= amount;
*to_balance += amount;
Ok(())
} else {
Err("余额不足".to_string())
}
}
fn get_balance(&self) -> f64 {
*self.balance.lock().unwrap()
}
}
fn main() {
let account1 = Arc::new(BankAccount::new(1, 1000.0));
let account2 = Arc::new(BankAccount::new(2, 500.0));
let mut handles = vec![];
// 多个线程同时进行转账
for i in 0..5 {
let acc1 = Arc::clone(&account1);
let acc2 = Arc::clone(&account2);
let handle = thread::spawn(move || {
if i % 2 == 0 {
acc1.transfer(&acc2, 10.0).unwrap();
} else {
acc2.transfer(&acc1, 5.0).unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("账户1余额: {}", account1.get_balance());
println!("账户2余额: {}", account2.get_balance());
}
共享状态并发虽然复杂,但在某些场景下是必要的。通过合理使用 Rust 提供的同步原语,可以编写出安全高效的并发程序。