CyclicBarrier是一个同步辅助类,允许一组线程互相等待,直到到达某个公共屏障点 (common barrier point)。因为该 barrier 在释放等待线程后可以重用,所以称它为循环 的 barrier。
简单示例:
public class CyclicBarrierLearn {
private static int SIZE = 5;
private static CyclicBarrier cb;
public static void main(String[] args) {
cb = new CyclicBarrier(SIZE);
// 新建5个任务
for(int i=0; i<SIZE; i++)
new InnerThread().start();
}
static class InnerThread extends Thread{
public void run() {
try {
System.out.println(Thread.currentThread().getName() + " wait for CyclicBarrier.");
// 将cb的参与者数量加1
cb.await();
// cb的参与者数量等于5时,才继续往后执行
System.out.println(Thread.currentThread().getName() + " continued.");
}catch(Exception e){
e.printStackTrace();
}
}
}
}
运行结果:当所有线程都准备好之后再执行
接下来我们看一下CyclicBarrier的实现机制。
首先在CyclicBarrier中没有设置volatile修饰的state锁标识位,而是定义了一个普通int变量count,但是对count的修改操作是通过ReentrantLock独占锁来实现的,和我们经常见到的volatile修改的state的功能类似。线程的阻塞与唤起是使用Condition完成的,参考博客并发编程--并发编程包Condition条件
wait函数
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
}
}
dowait中完成了对count线程个数值的修改,当前线程的阻塞,和当count值为0时对阻塞线程的唤起操作,这样所有的线程就都可以并发运行了。
首先CyclicBarrier在初始化时会初始化一个Count值,每次执行await方法时会对count进行减一操作,当count不为0时表示线程还没有达到count值,此时线程会调用Condition.await操作将当前线程阻塞,当count值等于0时,首先会执行当前线程,执行完当前线程时会在generation方法中调用Condition.signalAll函数完成对所有阻塞线程的唤起,这样就完成了线程屏障的作用。
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
//使用独占锁
lock.lock();
try {
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
//将count值进行减减操作,独占锁下完成
int index = --count;
if (index == 0) { // tripped
//当index等于0时,表明所有的线程都已经准备好,接下来可以执行线程了
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
//执行线程
command.run();
ranAction = true;
//唤起其他线程开始执行
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
//当index不等于0时
for (;;) {
try {
if (!timed)
//将当前线程阻塞
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
generation函数中的操作是调用Condition.signalAll函数唤起所有阻塞线程,并将count值恢复到初始值
private void nextGeneration() {
//唤起所有阻塞线程
trip.signalAll();
//将count值恢复到初始值
count = parties;
generation = new Generation();
}
CyclicBarrier源码:
public class CyclicBarrier {
private static class Generation {
boolean broken = false;
}
private final ReentrantLock lock = new ReentrantLock();
private final Condition trip = lock.newCondition();
/** The number of parties */
private final int parties;
/* The command to run when tripped */
private final Runnable barrierCommand;
/** The current generation */
private Generation generation = new Generation();
//设置count值,每次对count进行减减操作
private int count;
private void nextGeneration() {
//唤起所有阻塞线程,使用Condition的signalAll方法
trip.signalAll();
count = parties; //将count值恢复到初始值
generation = new Generation();
}
//将count重新设置为初始化值
private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
final ReentrantLock lock = this.lock;
//互斥锁进行操作
lock.lock();
try {
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
if (Thread.interrupted()) {
breakBarrier();
throw new InterruptedException();
}
int index = --count;
//当index等于0时唤起所以阻塞的线程
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
//唤起所有阻塞线程
nextGeneration();
return 0;
} finally {
if (!ranAction)
breakBarrier();
}
}
// loop until tripped, broken, interrupted, or timed out
//当index大于0时,将线程进行阻塞
for (;;) {
try {
if (!timed)
//使用Condition的await方法
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
this.count = parties;
this.barrierCommand = barrierAction;
}
public CyclicBarrier(int parties) {
this(parties, null);
}
public int getParties() {
return parties;
}
//阻塞线程
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
public int await(long timeout, TimeUnit unit)
throws InterruptedException,
BrokenBarrierException,
TimeoutException {
return dowait(true, unit.toNanos(timeout));
}
public boolean isBroken() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return generation.broken;
} finally {
lock.unlock();
}
}
//重置,重置后count值为构造函数的初始化值,可以继续使用CyclicBarrier
public void reset() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
breakBarrier(); // break the current generation
nextGeneration(); // start a new generation
} finally {
lock.unlock();
}
}
//获取阻塞线程的数目
public int getNumberWaiting() {
final ReentrantLock lock = this.lock;
lock.lock();
try {
return parties - count;
} finally {
lock.unlock();
}
}
}