并发编程
Rust 的并发编程模型基于所有权系统,提供了内存安全的并发编程能力。Rust 通过类型系统在编译时防止数据竞争,让并发编程变得更加安全。
线程基础
创建线程
rust
use std::thread;
use std::time::Duration;
fn main() {
// 创建新线程
let handle = thread::spawn(|| {
for i in 1..10 {
println!("hi number {} from the spawned thread!", i);
thread::sleep(Duration::from_millis(1));
}
});
// 主线程执行
for i in 1..5 {
println!("hi number {} from the main thread!", i);
thread::sleep(Duration::from_millis(1));
}
// 等待子线程完成
handle.join().unwrap();
}
线程间数据传递
rust
use std::thread;
fn main() {
let v = vec![1, 2, 3];
// 使用 move 将所有权转移到线程
let handle = thread::spawn(move || {
println!("Here's a vector: {:?}", v);
});
// println!("v: {:?}", v); // 错误!v 已被移动
handle.join().unwrap();
}
消息传递
基本通道
rust
use std::sync::mpsc;
use std::thread;
fn main() {
// 创建通道
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let val = String::from("hi");
tx.send(val).unwrap();
// println!("val is {}", val); // 错误!val 已被移动
});
// 接收消息
let received = rx.recv().unwrap();
println!("Got: {}", received);
}
发送多个值
rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
fn main() {
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let vals = vec![
String::from("hi"),
String::from("from"),
String::from("the"),
String::from("thread"),
];
for val in vals {
tx.send(val).unwrap();
thread::sleep(Duration::from_secs(1));
}
});
// 接收多个消息
for received in rx {
println!("Got: {}", received);
}
}
多个发送者
rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
fn main() {
let (tx, rx) = mpsc::channel();
// 克隆发送者
let tx1 = tx.clone();
thread::spawn(move || {
let vals = vec![
String::from("hi"),
String::from("from"),
String::from("the"),
String::from("thread"),
];
for val in vals {
tx1.send(val).unwrap();
thread::sleep(Duration::from_secs(1));
}
});
thread::spawn(move || {
let vals = vec![
String::from("more"),
String::from("messages"),
String::from("for"),
String::from("you"),
];
for val in vals {
tx.send(val).unwrap();
thread::sleep(Duration::from_secs(1));
}
});
for received in rx {
println!("Got: {}", received);
}
}
共享状态并发
Mutex 互斥锁
rust
use std::sync::Mutex;
fn main() {
let m = Mutex::new(5);
{
let mut num = m.lock().unwrap();
*num = 6;
} // 锁在这里被释放
println!("m = {:?}", m);
}
多线程共享 Mutex
rust
use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
// 使用 Arc 在多个线程间共享 Mutex
let counter = Arc::new(Mutex::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
let mut num = counter.lock().unwrap();
*num += 1;
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("Result: {}", *counter.lock().unwrap());
}
RwLock 读写锁
rust
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;
fn main() {
let data = Arc::new(RwLock::new(vec![1, 2, 3]));
let mut handles = vec![];
// 多个读者
for i in 0..3 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
let reader = data.read().unwrap();
println!("Reader {}: {:?}", i, *reader);
thread::sleep(Duration::from_millis(100));
});
handles.push(handle);
}
// 一个写者
let data_writer = Arc::clone(&data);
let writer_handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(50));
let mut writer = data_writer.write().unwrap();
writer.push(4);
println!("Writer: added 4");
});
handles.push(writer_handle);
for handle in handles {
handle.join().unwrap();
}
println!("Final data: {:?}", *data.read().unwrap());
}
原子类型
基本原子操作
rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
fn main() {
let counter = Arc::new(AtomicUsize::new(0));
let mut handles = vec![];
for _ in 0..10 {
let counter = Arc::clone(&counter);
let handle = thread::spawn(move || {
for _ in 0..1000 {
counter.fetch_add(1, Ordering::SeqCst);
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("Result: {}", counter.load(Ordering::SeqCst));
}
不同的内存排序
rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;
fn main() {
let data = Arc::new(AtomicUsize::new(0));
let flag = Arc::new(AtomicBool::new(false));
let data_clone = Arc::clone(&data);
let flag_clone = Arc::clone(&flag);
// 写线程
let writer = thread::spawn(move || {
data_clone.store(42, Ordering::Relaxed);
flag_clone.store(true, Ordering::Release); // Release 语义
});
// 读线程
let reader = thread::spawn(move || {
while !flag.load(Ordering::Acquire) { // Acquire 语义
thread::yield_now();
}
let value = data.load(Ordering::Relaxed);
println!("Read value: {}", value);
});
writer.join().unwrap();
reader.join().unwrap();
}
线程池
简单线程池实现
rust
use std::sync::{mpsc, Arc, Mutex};
use std::thread;
type Job = Box<dyn FnOnce() + Send + 'static>;
pub struct ThreadPool {
workers: Vec<Worker>,
sender: mpsc::Sender<Job>,
}
impl ThreadPool {
pub fn new(size: usize) -> ThreadPool {
assert!(size > 0);
let (sender, receiver) = mpsc::channel();
let receiver = Arc::new(Mutex::new(receiver));
let mut workers = Vec::with_capacity(size);
for id in 0..size {
workers.push(Worker::new(id, Arc::clone(&receiver)));
}
ThreadPool { workers, sender }
}
pub fn execute<F>(&self, f: F)
where
F: FnOnce() + Send + 'static,
{
let job = Box::new(f);
self.sender.send(job).unwrap();
}
}
struct Worker {
id: usize,
thread: thread::JoinHandle<()>,
}
impl Worker {
fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
let thread = thread::spawn(move || loop {
let job = receiver.lock().unwrap().recv().unwrap();
println!("Worker {} got a job; executing.", id);
job();
});
Worker { id, thread }
}
}
fn main() {
let pool = ThreadPool::new(4);
for i in 0..8 {
pool.execute(move || {
println!("Task {} is running", i);
thread::sleep(std::time::Duration::from_secs(1));
println!("Task {} completed", i);
});
}
thread::sleep(std::time::Duration::from_secs(10));
}
同步原语
Barrier 屏障
rust
use std::sync::{Arc, Barrier};
use std::thread;
fn main() {
let mut handles = Vec::with_capacity(10);
let barrier = Arc::new(Barrier::new(10));
for i in 0..10 {
let c = Arc::clone(&barrier);
handles.push(thread::spawn(move || {
println!("Thread {} is working...", i);
thread::sleep(std::time::Duration::from_millis(i * 100));
println!("Thread {} finished work, waiting for others", i);
c.wait();
println!("Thread {} passed the barrier!", i);
}));
}
for handle in handles {
handle.join().unwrap();
}
}
Condvar 条件变量
rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;
fn main() {
let pair = Arc::new((Mutex::new(false), Condvar::new()));
let pair2 = Arc::clone(&pair);
// 等待线程
thread::spawn(move || {
let (lock, cvar) = &*pair2;
let mut started = lock.lock().unwrap();
while !*started {
println!("Waiting for condition...");
started = cvar.wait(started).unwrap();
}
println!("Condition met! Proceeding...");
});
// 通知线程
thread::sleep(Duration::from_secs(2));
let (lock, cvar) = &*pair;
let mut started = lock.lock().unwrap();
*started = true;
cvar.notify_one();
thread::sleep(Duration::from_secs(1));
}
Once 一次性初始化
rust
use std::sync::Once;
use std::thread;
static INIT: Once = Once::new();
static mut VAL: usize = 0;
fn get_value() -> usize {
unsafe {
INIT.call_once(|| {
VAL = expensive_computation();
});
VAL
}
}
fn expensive_computation() -> usize {
println!("Performing expensive computation...");
thread::sleep(std::time::Duration::from_secs(1));
42
}
fn main() {
let mut handles = vec![];
for i in 0..5 {
let handle = thread::spawn(move || {
println!("Thread {} got value: {}", i, get_value());
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
实际应用示例
并发下载器
rust
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
struct DownloadManager {
downloads: Arc<Mutex<Vec<Download>>>,
max_concurrent: usize,
}
#[derive(Debug, Clone)]
struct Download {
id: usize,
url: String,
status: DownloadStatus,
}
#[derive(Debug, Clone)]
enum DownloadStatus {
Pending,
InProgress,
Completed,
Failed,
}
impl DownloadManager {
fn new(max_concurrent: usize) -> Self {
DownloadManager {
downloads: Arc::new(Mutex::new(Vec::new())),
max_concurrent,
}
}
fn add_download(&self, url: String) -> usize {
let mut downloads = self.downloads.lock().unwrap();
let id = downloads.len();
downloads.push(Download {
id,
url,
status: DownloadStatus::Pending,
});
id
}
fn start_downloads(&self) {
let downloads = Arc::clone(&self.downloads);
for _ in 0..self.max_concurrent {
let downloads_clone = Arc::clone(&downloads);
thread::spawn(move || {
loop {
let download = {
let mut downloads = downloads_clone.lock().unwrap();
downloads.iter_mut()
.find(|d| matches!(d.status, DownloadStatus::Pending))
.map(|d| {
d.status = DownloadStatus::InProgress;
d.clone()
})
};
if let Some(download) = download {
println!("Starting download {}: {}", download.id, download.url);
// 模拟下载
thread::sleep(Duration::from_secs(2));
let mut downloads = downloads_clone.lock().unwrap();
if let Some(d) = downloads.iter_mut().find(|d| d.id == download.id) {
d.status = DownloadStatus::Completed;
println!("Completed download {}", download.id);
}
} else {
thread::sleep(Duration::from_millis(100));
}
}
});
}
}
fn get_status(&self) -> Vec<Download> {
self.downloads.lock().unwrap().clone()
}
}
fn main() {
let manager = DownloadManager::new(3);
// 添加下载任务
for i in 0..10 {
manager.add_download(format!("https://example.com/file{}.zip", i));
}
// 开始下载
manager.start_downloads();
// 监控进度
for _ in 0..15 {
thread::sleep(Duration::from_secs(1));
let downloads = manager.get_status();
let completed = downloads.iter().filter(|d| matches!(d.status, DownloadStatus::Completed)).count();
let in_progress = downloads.iter().filter(|d| matches!(d.status, DownloadStatus::InProgress)).count();
let pending = downloads.iter().filter(|d| matches!(d.status, DownloadStatus::Pending)).count();
println!("Progress: {} completed, {} in progress, {} pending", completed, in_progress, pending);
if completed == downloads.len() {
break;
}
}
}
生产者-消费者模式
rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use std::time::Duration;
use std::collections::VecDeque;
struct Buffer<T> {
queue: Mutex<VecDeque<T>>,
not_empty: Condvar,
not_full: Condvar,
capacity: usize,
}
impl<T> Buffer<T> {
fn new(capacity: usize) -> Self {
Buffer {
queue: Mutex::new(VecDeque::new()),
not_empty: Condvar::new(),
not_full: Condvar::new(),
capacity,
}
}
fn put(&self, item: T) {
let mut queue = self.queue.lock().unwrap();
while queue.len() == self.capacity {
queue = self.not_full.wait(queue).unwrap();
}
queue.push_back(item);
self.not_empty.notify_one();
}
fn take(&self) -> T {
let mut queue = self.queue.lock().unwrap();
while queue.is_empty() {
queue = self.not_empty.wait(queue).unwrap();
}
let item = queue.pop_front().unwrap();
self.not_full.notify_one();
item
}
}
fn main() {
let buffer = Arc::new(Buffer::new(5));
let mut handles = vec![];
// 生产者
for i in 0..3 {
let buffer = Arc::clone(&buffer);
let handle = thread::spawn(move || {
for j in 0..5 {
let item = format!("Producer {} - Item {}", i, j);
println!("Producing: {}", item);
buffer.put(item);
thread::sleep(Duration::from_millis(100));
}
});
handles.push(handle);
}
// 消费者
for i in 0..2 {
let buffer = Arc::clone(&buffer);
let handle = thread::spawn(move || {
for _ in 0..7 {
let item = buffer.take();
println!("Consumer {} consumed: {}", i, item);
thread::sleep(Duration::from_millis(150));
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
}
性能考虑
避免过度同步
rust
use std::sync::{Arc, Mutex};
use std::thread;
fn main() {
let data = Arc::new(Mutex::new(0));
let mut handles = vec![];
// 不好的做法:频繁加锁
for _ in 0..10 {
let data = Arc::clone(&data);
let handle = thread::spawn(move || {
for _ in 0..1000 {
let mut num = data.lock().unwrap();
*num += 1;
// 锁持有时间过长
}
});
handles.push(handle);
}
for handle in handles {
handle.join().unwrap();
}
println!("Bad approach result: {}", *data.lock().unwrap());
// 更好的做法:批量操作
let data2 = Arc::new(Mutex::new(0));
let mut handles2 = vec![];
for _ in 0..10 {
let data = Arc::clone(&data2);
let handle = thread::spawn(move || {
let mut local_sum = 0;
for _ in 0..1000 {
local_sum += 1;
}
let mut num = data.lock().unwrap();
*num += local_sum;
});
handles2.push(handle);
}
for handle in handles2 {
handle.join().unwrap();
}
println!("Better approach result: {}", *data2.lock().unwrap());
}
练习
练习 1:并发计数器
实现一个线程安全的计数器,支持多线程同时增减。
练习 2:任务队列
创建一个任务队列系统,支持多个工作线程处理任务。
练习 3:缓存系统
实现一个线程安全的 LRU 缓存。
练习 4:并发哈希表
设计一个支持并发读写的哈希表。
下一步
掌握了并发编程后,您可以继续学习:
并发编程是构建高性能 Rust 应用的重要技能!