CyclicBarrier的使用与源码分析

CyclicBarrier俗称栅栏,它可以让多个线程之间互相等待,直到所有线程都到达同一个同步点,然后再继续一起执行。

CyclicBarrier默认的构造方法是CyclicBarrier(int parties),其参数parties表示屏障拦截的线程数量,每个线程都会调用await方法告诉CyclicBarrier我已经到达了屏障,然后当前线程被阻塞,当所有线程都到达了屏障,所有线程都被唤醒然后继续往下执行。

CyclicBarrier的常用方法

方法名

说明

CyclicBarrier(int parties)

创建一个给定parties数量的CyclicBarrier,parties参数表示屏障拦截的线程数量

CyclicBarrier(int parties, Runnable barrierAction)

当所有线程都到达了屏障后执行给定的屏障动作,由最后一个进入屏障的线程执行

int await()

等待所有线程到达屏蔽后执行

int await(long timeout, TimeUnit unit)

等待所有线程到达屏蔽后执行或超时

int getNumberWaiting()

返回正在等待的线程个数

int getParties()

返回屏障拦截的线程数量,也就是构造方法传入的参数

boolean isBroken()

查询这个障碍是否处于中断状态

void reset()

将屏障重置为初始状态,这样这个栅栏又能被重复使用

CyclicBarrier栅栏的使用

场景:三个朋友约好到公园集合,一起出去逛街。

package com.morris.concurrent.tool.cyclicbarrier.api;

import lombok.extern.slf4j.Slf4j;

import java.util.Random;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;

/**
 * 演示CyclicBarrier栅栏的使用
 * 场景:三个朋友约好到公园集合,一起出去逛街
 */
@Slf4j
public class CyclicBarrierDemo {

    private static CyclicBarrier cyclicBarrier = new CyclicBarrier(3, () -> {
        log.info("all is arrive");
    });

    public static void main(String[] args) {
        new Thread(CyclicBarrierDemo::go, "ZhangSan").start();
        new Thread(CyclicBarrierDemo::go, "LiSi").start();
        new Thread(CyclicBarrierDemo::go, "WangWu").start();
        Thread.yield();
    }

    private static void go() {
        try {
            log.info("start go to park");
            TimeUnit.SECONDS.sleep(new Random(System.currentTimeMillis()).nextInt(30));
            log.info("arrive at park");
            log.info("当前正在等待的人数:" + cyclicBarrier.getNumberWaiting());
            cyclicBarrier.await();
            log.info("start go shopping...");
        } catch (InterruptedException | BrokenBarrierException e) {
            e.printStackTrace();
        }
    }

}

运行结果如下:

2020-09-24 17:04:09,074  INFO [ZhangSan] (CyclicBarrierDemo.java:30) - start go to park
2020-09-24 17:04:09,074  INFO [LiSi] (CyclicBarrierDemo.java:30) - start go to park
2020-09-24 17:04:09,074  INFO [WangWu] (CyclicBarrierDemo.java:30) - start go to park
2020-09-24 17:04:34,077  INFO [ZhangSan] (CyclicBarrierDemo.java:32) - arrive at park
2020-09-24 17:04:34,077  INFO [ZhangSan] (CyclicBarrierDemo.java:33) - 当前正在等待的人数:0
2020-09-24 17:04:34,077  INFO [WangWu] (CyclicBarrierDemo.java:32) - arrive at park
2020-09-24 17:04:34,079  INFO [WangWu] (CyclicBarrierDemo.java:33) - 当前正在等待的人数:1
2020-09-24 17:04:34,077  INFO [LiSi] (CyclicBarrierDemo.java:32) - arrive at park
2020-09-24 17:04:34,079  INFO [LiSi] (CyclicBarrierDemo.java:33) - 当前正在等待的人数:2
2020-09-24 17:04:34,079  INFO [LiSi] (CyclicBarrierDemo.java:18) - all is arrive
2020-09-24 17:04:34,079  INFO [LiSi] (CyclicBarrierDemo.java:35) - start go shopping...
2020-09-24 17:04:34,079  INFO [ZhangSan] (CyclicBarrierDemo.java:35) - start go shopping...
2020-09-24 17:04:34,080  INFO [WangWu] (CyclicBarrierDemo.java:35) - start go shopping...

CountDownLatch与CyclicBarrier的区别:

  • CountDownLatch:一个或者多个线程,等待其他多个线程完成某件事情之后才能执行;调用countDown()方法计数减一,计算为0时释放所有等待的线程,无法重置。
  • CyclicBarrier:多个线程互相等待,直到到达同一个同步点,再继续一起执行。计数达到指定值时释放所有等待线程,计数置为0重新开始,可重复利用。

源码分析

CyclicBarrier底层使用了ReentrantLock。

数据结构

private static class Generation {
    boolean broken = false;
}

private final ReentrantLock lock = new ReentrantLock(); // 可重入锁

private final Condition trip = lock.newCondition(); // 条件队列

private final int parties; // 屏障数,也就是参与的线程数

private final Runnable barrierCommand; // 到达屏障由最后一个进入屏障的线程执行的操作

private Generation generation = new Generation(); // barrier每一次使用都代表了一个generation实例

private int count; // 正在等待进入屏障的线程数量

await()

public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L);
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

private int dowait(boolean timed, long nanos)
    throws InterruptedException, BrokenBarrierException,
           TimeoutException {
    final ReentrantLock lock = this.lock;
    lock.lock(); // 持有独占锁
    try {
        final Generation g = generation;

        if (g.broken) // 屏障被破坏,抛出异常
            throw new BrokenBarrierException();

        if (Thread.interrupted()) { // 线程被中断
            breakBarrier(); // 唤醒所有的线程
            throw new InterruptedException(); // 抛出中断异常
        }

        int index = --count; // 减少正在等待进入屏障的线程数量
        if (index == 0) {  // 所有线程已到达屏障
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run(); // 最后一个线程执行
                ranAction = true;
                nextGeneration(); // 唤醒所有等待的线程,重置generation
                return 0;
            } finally {
                if (!ranAction) // barrierCommand执行失败
                    breakBarrier(); // 损坏当前屏障
            }
        }

        // 无限循环
        for (;;) {
            try {
                if (!timed) // 没有设置等待时间
                    // 等待
                    trip.await(); 
                else if (nanos > 0L) // 设置了等待时间,并且等待时间大于0
                    // 等待指定时长
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) { 
                if (g == generation && ! g.broken) { // 等于当前代并且屏障没有被损坏
                    // 损坏当前屏障
                    breakBarrier();
                    // 抛出异常
                    throw ie;
                } else { // 不等于当前代或者屏障被损坏
                    // 中断当前线程
                    Thread.currentThread().interrupt();
                }
            }

            if (g.broken) // 屏障被损坏,抛出异常
                throw new BrokenBarrierException();

            if (g != generation) // 不等于当前代
                // 返回索引
                return index;

            if (timed && nanos <= 0L) { // 设置了等待时间,并且等待时间小于0
                // 损坏屏障
                breakBarrier();
                // 抛出异常
                throw new TimeoutException();
            }
        }
    } finally {
        // 释放锁
        lock.unlock();
    }
}

// 唤醒等待队列中所有的线程
private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll(); // 唤醒
}

自定义CyclicBarrier11

package com.morris.concurrent.tool.cyclicbarrier.my;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * 使用wait-notify实现CyclicBarrier
 */
public class WaitNotifyCyclicBarrier {

    private final int parties; // 屏障数

    private volatile int numberWaiting; // 当前等待的线程数

    private List<Thread> waitingThreads;

    private boolean broken; // 是否被打断

    private Runnable runnable;

    public WaitNotifyCyclicBarrier(int parties) {
        this(parties, null);
    }

    public WaitNotifyCyclicBarrier(int parties, Runnable runnable) {
        this.parties = parties;
        waitingThreads = new ArrayList<>(parties - 1);
        this.runnable = runnable;
    }

    public void await() throws InterruptedException, BrokenBarrierException {
        try {
            doAwait(false, 0L);
        } catch (TimeoutException e) {
            throw new Error(); // not happen
        }
    }

    public void await(long timeout, TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException {
        doAwait(true, unit.toMillis(timeout));
    }

    private synchronized void doAwait(boolean timed, long mills) throws InterruptedException, BrokenBarrierException, TimeoutException {
        long beginTime = System.currentTimeMillis();
        long waitTime = 0; // 记录已经等了多久
        ++numberWaiting;
        while (numberWaiting < parties) {

            if (broken) {
                // 这里有可能第一次进来就broken
                throw new BrokenBarrierException();
            }

            if (Thread.interrupted()) {
                breakBarrier(); // 这里有可能第一次进来就中断了
                throw new InterruptedException();
            }

            try {
                if (timed) { // 超时处理

                    if (waitTime > mills) {
                        breakBarrier();
                        throw new TimeoutException();
                    }

                    waitingThreads.add(Thread.currentThread());
                    this.wait(mills - waitTime);
                    // 超时醒来再试一次
                    waitTime = System.currentTimeMillis() - beginTime;
                } else {
                    waitingThreads.add(Thread.currentThread());
                    this.wait();

                }
            } catch (InterruptedException e) {
                if (broken) {
                    // 其他线程被中断,导致本线程被broken
                    throw new BrokenBarrierException();
                } else {
                    // 只要有一个线程被中断了,所有等待的线程都会被唤醒,然后中断,要么都往下执行,要么都中断,看做一个整体
                    breakBarrier();
                    throw new InterruptedException();
                }
            }
        }

        boolean ranAction = false;
        try {
            if (runnable != null) {
                runnable.run();
            }
            ranAction = true;
        } finally {
            if (!ranAction) {
                breakBarrier();
            }
        }

        this.notifyAll();
    }

    private void breakBarrier() {
        waitingThreads.forEach(Thread::interrupt);
        numberWaiting = 0;
        broken = true;
    }

    public synchronized void reset() {
        if (0 != numberWaiting) {
            waitingThreads.forEach(Thread::interrupt);
            numberWaiting = 0;
        }
        broken = false;
    }

    public int getParties() {
        return parties;
    }

    public int getNumberWaiting() {
        return numberWaiting;
    }

    public boolean isBroken() {
        return broken;
    }
}