简介

在日常的开发中,可能会遇到这样的场景:开启多个子线程执行一些耗时任务,然后在主线程汇总,在子线程执行的过程中,主线程保持阻塞状态直到子线程完成任务。

使用CountDownLatch类或者Thread类的join()方法都能实现这一点,下面通过例子来介绍这两种实现方法。

CountDownLatch的使用

一个小例子,等待所有玩家准备就绪,然后游戏才开始。

使用join方法实现:

public class Demo {
    public static void main(String[] args) throws InterruptedException {
        Runnable runnable = () -> {
            System.out.println(Thread.currentThread().getName() + ":准备就绪");
        };

        Thread thread1 = new Thread(runnable, "一号玩家");
        Thread thread2 = new Thread(runnable, "二号玩家");
        Thread thread3 = new Thread(runnable, "三号玩家");
        Thread thread4 = new Thread(runnable, "四号玩家");
        Thread thread5 = new Thread(runnable, "五号玩家");
        thread1.start();
        thread2.start();
        thread3.start();
        thread4.start();
        thread5.start();

	//主线程等待子线程执行完成再执行
        thread1.join();
        thread2.join();
        thread3.join();
        thread4.join();
        thread5.join();

        System.out.println("---游戏开始---");
    }
}

/*
 * 输出结果:
 * 二号玩家:准备就绪
 * 五号玩家:准备就绪
 * 四号玩家:准备就绪
 * 三号玩家:准备就绪
 * 一号玩家:准备就绪
 * ---游戏开始---
 */

使用CountDownLatch实现:

public class Demo {
    public static void main(String[] args) throws InterruptedException {
        //创建计数器初始值为5的CountDownLatch
        CountDownLatch countDownLatch = new CountDownLatch(5);

        Runnable runnable = () -> {
            try{
                System.out.println(Thread.currentThread().getName() + ":准备就绪");
            }catch (Exception ex){
                ex.printStackTrace();
            }finally {
                //计数器值减一
                countDownLatch.countDown();
            }
        };

        Thread thread1 = new Thread(runnable, "一号玩家");
        Thread thread2 = new Thread(runnable, "二号玩家");
        Thread thread3 = new Thread(runnable, "三号玩家");
        Thread thread4 = new Thread(runnable, "四号玩家");
        Thread thread5 = new Thread(runnable, "五号玩家");
        thread1.start();
        thread2.start();
        thread3.start();
        thread4.start();
        thread5.start();

        //等待计数器值为0
        countDownLatch.await();
        System.out.println("---游戏开始---");
    }
}

/*
 * 输出结果:
 * 四号玩家:准备就绪
 * 五号玩家:准备就绪
 * 一号玩家:准备就绪
 * 三号玩家:准备就绪
 * 二号玩家:准备就绪
 * ---游戏开始---
 */

CountDownLatch内部包含一个计数器,计数器的初始值为CountDownLatch构造函数传入的int类型的参数,countDown方法会递减计数器值,await方法会阻塞当前线程直到计数器值为0。

两种方式的区别:

当调用子线程的join方法时,会阻塞当前线程直到子线程结束。而CountDownLatch相对比较灵活,无需等到子线程结束,只要计数器值为0,await方法就会返回。

CountDownLatch源码

CountDownLatch源码:

public class CountDownLatch {
    /**
     * CountDownLatch的同步控制,使用AQS的状态值作为计数器值。
     */
    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;

        Sync(int count) {
            setState(count);
        }

        int getCount() {
            return getState();
        }

        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }

        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }

    private final Sync sync;

    /**
     * 构造函数,初始化计数器
     */
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    /**
     * 阻塞当前线程直到计数器值为0
     */
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    /**
     * 阻塞当前线程直到计数器值为0或者超时
     */
    public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    /**
     * 递减计数器值,当计数器值为0时,释放所有等待的线程。
     */
    public void countDown() {
        sync.releaseShared(1);
    }

    /**
     * 返回当前计数器值
     */
    public long getCount() {
        return sync.getCount();
    }

    public String toString() {
        return super.toString() + "[Count = " + sync.getCount() + "]";
    }
}

通过源码可以看出,CountDownLatch内部是使用AQS实现的,它使用AQS的状态变量state作为计数器值,静态内部类Sync继承了AQS并实现了tryAcquireShared和tryReleaseShared方法。

接下来重点看下await()和countDown()的源码:

await()方法内部调用的是AQS的acquireSharedInterruptibly方法,会将当前线程放入AQS队列等待,直到计数值为0。

public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
	//判断当前线程是否被中断,如果线程被中断则抛出异常
	if (Thread.interrupted())
		throw new InterruptedException();
	//判断计数器值是否为0,为0则直接返回,否则进AQS队列进行等待。
	if (tryAcquireShared(arg) < 0)
		doAcquireSharedInterruptibly(arg);
}

//CountDownLatch中Sync的tryAcquireShared方法实现,直接判断计数器值是否为0。
protected int tryAcquireShared(int acquires) {
	return (getState() == 0) ? 1 : -1;
}

countDown()方法内部调用的是AQS的releaseShared方法,每次调用都会递减计数值,直到计数值为0则调用AQS释放资源的方法。

public final boolean releaseShared(int arg) {
	if (tryReleaseShared(arg)) {
		//释放资源
		doReleaseShared();
		return true;
	}
	return false;
}

//CountDownLatch中Sync的tryReleaseShared方法实现
protected boolean tryReleaseShared(int releases) {
	for (;;) {
		int c = getState();
		//计数值为0直接返回
		if (c == 0)
			return false;
		//设置递减后的计数值
		int nextc = c-1;
		if (compareAndSetState(c, nextc))
			return nextc == 0;
	}
}