等待多个并发事件

Java 并发 API 提供了一个类 CountDownLatch,它可以使多个线程等待直到一组操作完成。实例化这个类需要一个整型参数,该参数代表了线程希望等待的操作的个数。当一个线程希望等待这些操作执行完成时,可以调用 CountDownLatch 对象的 await(), 这个方法会将调用线程休眠,直到所等待的操作全部结束。而当一个操作结束时,应该调用 CountDownLatch 对象的 countDown()方法,该方法会将 CounDownLatch 对象 内部的属性计数器的值减 1,表示一个操作的完成。当 CountDownLatch 对象内部的计数器的值为 0 时,表示所有操作都完成了。这时,CountDownLatch 对象将唤起所有因 调用其 await()方法而休眠的线程。

源码分析

构造方法


public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}


构造方法需要传入一个count,也就是初始次数。

Sync类


private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

// 构造方法,初始化同步状态为 count
Sync(int count) {
setState(count);
}

// 获取当前同步状态的值,即计数器的值
int getCount() {
return getState();
}

// 尝试获取共享锁,成功返回 1,失败返回 -1
protected int tryAcquireShared(int acquires) {
// 如果同步状态为 0,表示所有线程都已经完成工作,等待线程可以获取锁并继续执行
return (getState() == 0) ? 1 : -1;
}

// 释放共享锁,将同步状态减 1,并在计数器变为 0 时唤醒等待线程
protected boolean tryReleaseShared(int releases) {
for (;;) {
int c = getState();
// 如果计数器已经为 0,无法继续减少,返回 false
if (c == 0)
return false;
int nextc = c - 1;
// 使用 compareAndSetState 方法原子性地修改同步状态的值
if (compareAndSetState(c, nextc))
// 如果计数器的值已经为 0,返回 true,唤醒所有等待线程
return nextc == 0;
}
}
}


Sync 类的实现是 CountDownLatch 类实现线程同步的关键所在,它通过管理计数器来实现线程的等待和唤醒。 Sync类继承了 AbstractQueuedSynchronizer 类,通过重写 tryAcquireShared 和 tryReleaseShared 方法来实现线程同步。

Sync 类中的 setState 方法用于初始化同步状态,getState 方法用于获取同步状态的值,即计数器的值。tryAcquireShared 方法用于尝试获取共享锁。tryReleaseShared 方法用于释放共享锁,将计数器的值减 1,并在计数器变为 0 时唤醒所有等待线程。

在 tryReleaseShared 方法中,使用 compareAndSetState 方法原子性地修改同步状态的值,保证计数器的减少是线程安全的。compareAndSetState 方法是 AbstractQueuedSynchronizer 类的一个原子性操作,用于判断当前同步状态的值是否等于预期值,如果相等则将同步状态的值设置为新值,返回 true;否则返回 false。

await()方法

public void await() throws InterruptedException {
// 调用 Sync 类的 acquireSharedInterruptibly 方法等待共享锁
sync.acquireSharedInterruptibly(1);
}

await() 方法是 CountDownLatch 类的一个阻塞方法,用于让当前线程等待计数器的值变为 0,才能继续往下执行。

在方法实现中,await() 方法调用 Sync 类的 acquireSharedInterruptibly 方法来等待共享锁。这个方法会尝试获取共享锁,如果共享锁不可用,则当前线程会被阻塞。

acquireSharedInterruptibly 方法会不断地调用 tryAcquireShared 方法来尝试获取共享锁,直到获取成功或者线程被中断。如果线程被中断,则会抛出 InterruptedException 异常。

在这段代码中,方法调用的参数为 1,表示要获取 1 个共享锁,这也就意味着线程需要等待计数器的值变为 0 才能继续执行。

acquireSharedInterruptibly(int arg)方法

public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如果线程已被中断,抛出 InterruptedException 异常
if (Thread.interrupted())
throw new InterruptedException();
// 如果获取共享锁失败,则阻塞当前线程并等待共享锁可用
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}

acquireSharedInterruptibly() 方法是 AQS(AbstractQueuedSynchronizer)类中的一个阻塞方法,用于在获取共享锁时阻塞当前线程,并等待共享锁可用。如果线程被中断,则会抛出 InterruptedException 异常。

参数 arg 表示要获取的共享锁数量。

tryAcquireShared() 方法尝试获取共享锁。如果获取失败,则调用 doAcquireSharedInterruptibly() 方法阻塞当前线程,并等待共享锁可用。

doAcquireSharedInterruptibly() 方法会调用 acquireSharedInterruptibly() 方法中定义的 tryAcquireShared() 方法来尝试获取共享锁。如果获取失败,则会把当前线程加入到等待共享锁的队列中,然后阻塞当前线程并等待共享锁可用。如果线程被中断,则会抛出 InterruptedException 异常。

doAcquireSharedInterruptibly(int arg)方法

/**
* Acquires in shared interruptible mode.
* @param arg the acquire argument
*/
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
// 将当前线程封装成 Node 节点并加入等待队列
final Node node = addWaiter(Node.SHARED);
// 标志变量,用于判断等待过程是否出现异常
boolean failed = true;
try {
// 循环尝试获取共享锁
for (;;) {
// 获取当前节点的前驱节点
final Node p = node.predecessor();
// 如果前驱节点是头节点,说明当前节点可以尝试获取共享锁
if (p == head) {
// 尝试获取共享锁
int r = tryAcquireShared(arg);
// 如果获取成功,设置当前节点为头节点并唤醒后继节点
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
// 如果当前节点无法获取共享锁,则进行线程等待
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
// 如果线程被中断,抛出 InterruptedException 异常
throw new InterruptedException();
}
} finally {
// 如果等待过程出现异常,将当前节点从等待队列中移除
if (failed)
cancelAcquire(node);
}
}

doAcquireSharedInterruptibly() 方法是 AQS(AbstractQueuedSynchronizer)类中的一个私有方法,用于实现在共享模式下的可中断等待。如果共享锁不可用,则当前线程会被阻塞,并等待共享锁可用。

首先会将当前线程封装成 Node 节点并加入等待队列中。然后,在一个无限循环中尝试获取共享锁。

循环中,先获取当前节点的前驱节点,如果前驱节点是头节点,则说明当前节点可以尝试获取共享锁。调用 tryAcquireShared() 方法尝试获取共享锁,如果获取成功,则将当前节点设置为头节点并唤醒后继节点,然后返回。

如果当前节点无法获取共享锁,则调用 shouldParkAfterFailedAcquire() 方法判断是否需要将当前线程加入到等待队列中,如果需要,则调用 parkAndCheckInterrupt() 方法阻塞当前线程并等待共享锁可用。如果线程被中断,则会抛出 InterruptedException 异常。

如果等待过程出现异常,会将当前节点从等待队列中移除。

参数 arg 表示要获取的共享锁数量。

countDown()方法

// java.util.concurrent.CountDownLatch.countDown()
public void countDown() {
// 调用AQS的释放共享锁方法
sync.releaseShared(1);
}
// java.util.concurrent.locks.AbstractQueuedSynchronizer.releaseShared()
public final boolean releaseShared(int arg) {
// 尝试释放共享锁,如果成功了,就唤醒排队的线程
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}

countDown()方法,会释放一个共享锁,也就是count的次数会减1。

根据上面Sync的源码,我们知道,tryReleaseShared()每次会把count的次数减1,当其减为0的时候返回true,这时候才会唤醒等待的线程。

使用案例

这里我们模拟一个视频会议场景,等待5个议会者全部到达后,开始视频会议,大致流程如下。

【线程同步工具】CountDownLatch源码解析_共享锁

Videoconference类

/**
* 这个类实现了视频会议的控制器
*
* 它使用CountDownLatch来控制所有参与者的到达。
*/
public class Videoconference implements Runnable{

/**
* 这个类使用CountDownLatch来控制所有参与者的到达。
*/
private final CountDownLatch controller;

/**
* 类的构造函数。初始化CountDownLatch。
* @param number 视频会议中的参与者数
*/
public Videoconference(int number) {
controller=new CountDownLatch(number);
}

/**
* 每个参与者加入视频会议时都会调用此方法
* @param name 参与者的名称
*/
public void arrive(String name){
System.out.printf("%s 已到达。\n",name);
// 这个方法使用countDown方法来递减CountDownLatch的内部计数器
controller.countDown();
System.out.printf("VideoConference: 还需要等待 %d 个参与者。\n",controller.getCount());
}

/**
* 这是视频会议控制器的主方法。它等待所有参与者,然后开始会议。
*/
@Override
public void run() {
System.out.printf("VideoConference: 初始化: %d 个参与者。\n",controller.getCount());
try {
// 等待所有参与者
System.out.println("等待所有参与者: " + Thread.currentThread().getName());
controller.await();
// 开始会议
System.out.printf("VideoConference: 所有参与者都已到齐。\n");
System.out.printf("VideoConference: 让我们开始...\n");
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}

该类实现了视频会议的控制器,使用CountDownLatch来控制所有参与者的到达。类中定义了一个私有的CountDownLatch对象controller,用于计数参与者的到达。类中的构造函数初始化了CountDownLatch,

而arrive方法用于每个参与者加入视频会议时调用,以减少CountDownLatch的内部计数器。run方法是VideoConference类的主要方法,等待所有参与者到达,然后开始会议。

Participant类

/**
* 这个类实现了 VideoConference 中的参与者
*
*/
public class Participant implements Runnable {

/**
* 参与者将要参加的 VideoConference
*/

private Videoconference conference;

/**
* 参与者的名称,仅用于日志记录
*/
private String name;

/**
* 类的构造方法。初始化其属性
* @param conference 参与者将要参加的 VideoConference
* @param name 参与者的名称
*/
public Participant(Videoconference conference, String name) {
this.conference=conference;
this.name=name;
}

/**
* 参与者的核心方法。等待一段随机时间后加入 VideoConference
*/
@Override
public void run() {
Long duration=(long)(Math.random()*10);
try {
TimeUnit.SECONDS.sleep(duration);
} catch (InterruptedException e) {
e.printStackTrace();
}
conference.arrive(name);

}
}

这个类有两个属性:一个是参与的 Videoconference 类对象,另一个是这个参与者的名称。它有一个构造函数,接受 Videoconference 对象和名称作为参数,并使用它们来初始化对象的属性。在它的 run() 方法中,参与者会等待一个随机的时间,然后加入 VideoConference,通过调用 VideoConference 类中的 arrive() 方法实现。这个方法会向控制器对象发出信号,表示有一个参与者已经加入了 VideoConference。

Main类

/**
* 示例的主类。创建、初始化和执行所有示例所需的对象。
*/
public class Main {

/**
* 示例的主方法
* @param args
*/
public static void main(String[] args) {

// 创建一个具有5个参与者的VideoConference
Videoconference conference = new Videoconference(5);

// 创建一个线程来运行VideoConference并启动它
Thread threadConference = new Thread(conference);
threadConference.start();

// 创建十个参与者,为每个参与者创建一个线程并启动它们
for (int i = 0; i < 5; i++) {
Participant p = new Participant(conference, "Participant " + i);
Thread t = new Thread(p);
t.start();
}
}
}

在主方法中,首先创建了一个具有5个参与者的Videoconference对象,并使用它创建了一个新线程Thread(名为threadConference),并将其作为Videoconference对象的参数来启动会议。然后,通过循环创建了5个Participant对象,每个对象都有一个不同的名称("Participant 0"、"Participant 1" 等等)。对于每个参与者,都创建了一个新线程,并使用Participant对象作为参数来启动该线程。

其他说明

CountDownLatch与Thread.join()有何不同?

答:Thread.join()是在主线程中调用的,它只能等待被调用的线程结束了才会通知主线程,而CountDownLatch则不同,它的countDown()方法可以在线程执行的任意时刻调用,灵活性更大。

作者简介

鑫茂,深圳,Java开发工程师。

喜读思维方法、哲学心理学以及历史等方面的书,偶尔写些文字。

希望通过文章,结识更多同道中人。