同步组件CountDownLatch源码解析

共 6942字,需浏览 14分钟

 ·

2023-06-25 22:37

走过路过不要错过

点击蓝字关注我们

 

CountDownLatch概述

日常开发中,经常会遇到类似场景:主线程开启多个子线程执行任务,需要等待所有子线程执行完毕后再进行汇总。

在同步组件CountDownLatch出现之前,我们可以使用join方法来完成,简单实现如下:

public class JoinTest {
public static void main(String[] args) throws InterruptedException {
Thread A = new Thread(() -> {
try {
Thread.sleep(1000);
System.out.println("A finish!");
} catch (InterruptedException e) {
e.printStackTrace();
}
});
Thread B = new Thread(() -> {
try {
Thread.sleep(1000);
System.out.println("B finish!");

} catch (InterruptedException e) {
e.printStackTrace();
}
});
System.out.println("main thread wait ..");
A.start();
B.start();
A.join(); // 等待A执行结束
B.join(); // 等待B执行结束
System.out.println("all thread finish !");
}
}


但使用join方法并不是很灵活,并不能很好地满足某些场景的需要,而CountDownLatch则能够很好地代替它,并且相比之下,提供了更多灵活的特性:

CountDownLatch相比join方法对线程同步有更灵活的控制,原因如下:

  1. 调用子线程的join()方法后,该线程会一直被阻塞直到子线程运行完毕,而CountDownLatch使用计数器来允许子线程运行完毕或者运行中递减计数,await方法返回不一定必须等待线程结束。

  2. 使用线程池管理线程时,添加Runnable到线程池,没有办法再调用线程的join方法了。

使用案例与基本思路

public class TestCountDownLatch {

public static volatile CountDownLatch countDownLatch = new CountDownLatch(2);

public static void main (String[] args) throws InterruptedException {
ExecutorService executorService = Executors.newFixedThreadPool(2);
executorService.submit(() -> {
try {
Thread.sleep(1000);
System.out.println("A finish!");

} catch (InterruptedException e) {
e.printStackTrace();
} finally {
countDownLatch.countDown();
}
});
executorService.submit(() -> {
try {
Thread.sleep(1000);
System.out.println("B finish!");

} catch (InterruptedException e) {
e.printStackTrace();
} finally {
countDownLatch.countDown();
}
});
System.out.println("main thread wait ..");
countDownLatch.await();
System.out.println("all thread finish !");
executorService.shutdown();
}
}
// 结果
main thread wait ..
B finish!
A finish!
all thread finish !


  • 构建CountDownLatch实例,构造参数传参为2,内部计数初始值为2。

  • 主线程构建线程池,提交两个任务,接着调用countDownLatch.await()陷入阻塞。

  • 子线程执行完毕之后调用countDownLatch.countDown(),内部计数器减1。

  • 所有子线程执行完毕之后,计数为0,此时主线程的await方法返回。

类图与基本结构

public class CountDownLatch {
/**
* Synchronization control For CountDownLatch.
* Uses AQS state to represent count.
*/

private static final class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 4982264981922014374L;

Sync(int count) {
setState(count);
}
//...
}

private final Sync sync;

public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}

public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}

public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}

public void countDown() {
sync.releaseShared(1);
}

public long getCount() {
return sync.getCount();
}

public String toString() {
return super.toString() + "[Count = " + sync.getCount() + "]";
}
}


CountDownLatch基于AQS实现,内部维护一个Sync变量,继承了AQS。

在AQS中,最重要的就是state状态的表示,在CountDownLatch中使用state表示计数器的值,在初始化的时候,为state赋值。

几个同步方法实现比较简单,如果你不熟悉AQS,推荐你瞅一眼前置文章:

接下来我们简单看一看实现,主要学习两个方法:await()和countdown()。

void await()

当线程调用CountDownLatch的await方法后,线程会被阻塞,除非发生下面两种情况:

  1. 内部计数器值为0,getState() == 0

  2. 被其他线程中断,抛出异常,也就是currThread.interrupt()

    // CountDownLatch.java
public void await() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// AQS.java
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如果线程中断, 则抛出异常
if (Thread.interrupted())
throw new InterruptedException();
// 由子类实现,这里再Sync中实现,计数器为0就可以返回,否则进入AQS队列等待
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
// Sync
// 计数器为0 返回1, 否则返回-1
private static final class Sync extends AbstractQueuedSynchronizer {
protected int tryAcquireShared(int acquires) {
return (getState() == 0) ? 1 : -1;
}
}


boolean await(long timeout, TimeUnit unit)

当线程调用CountDownLatch的await方法后,线程会被阻塞,除非发生下面三种情况:

  1. 内部计数器值为0,getState() == 0,返回true。

  2. 被其他线程中断,抛出异常,也就是currThread.interrupt()

  3. 设置的timeout时间到了,超时返回false。

    // CountDownLatch.java
public boolean await(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
// AQS.java
public final boolean tryAcquireSharedNanos(int arg, long nanosTimeout)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();
return tryAcquireShared(arg) >= 0 ||
doAcquireSharedNanos(arg, nanosTimeout);
}


void countDown()

调用该方法,内部计数值减1,递减后如果计数器值为0,唤醒所有因调用await方法而被阻塞的线程,否则跳过。

    // CountDownLatch.java
public void countDown() {
sync.releaseShared(1);
}
// AQS.java
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
// Sync
private static final class Sync extends AbstractQueuedSynchronizer {
protected boolean tryReleaseShared(int releases) {
// 循环进行CAS操作
for (;;) {
int c = getState();
// 一旦为0,就返回false
if (c == 0)
return false;
int nextc = c-1;
// CAS尝试将state-1,只有这一步CAS成功且将state变成0的线程才会返回true
if (compareAndSetState(c, nextc))
return nextc == 0;
}
}
}


总结

  • CountDownLatch相比于join方法更加灵活且方便地实现线程间同步,体现在以下几点:

    • 调用子线程的join()方法后,该线程会一直被阻塞直到子线程运行完毕,而CountDownLatch使用计数器来允许子线程运行完毕或者运行中递减计数,await方法返回不一定必须等待线程结束。

    • 使用线程池管理线程时,添加Runnable到线程池,没有办法再调用线程的join方法了。

  • CountDownLatch使用state表示内部计数器的值,初始化传入count。

  • 线程调用countdown方法将会原子性地递减AQS的state值,线程调用await方法后将会置入AQS阻塞队列中,直到计数器为0,或被打断,或超时等才会返回,计数器为0时,当前线程还需要唤醒由于await()被阻塞的线程。




想进大厂的小伙伴请注意,

大厂面试的套路很神奇,

早做准备对大家更有好处,

埋头刷题效率低,

看面经会更有效率!

小编准备了一份大厂常问面经汇总集

剩下的就不会给大家一展出来了,以上资料按照一下操作即可获得


——将文章进行转发评论关注公众号【Java烤猪皮】,关注后继续后台回复领取口令“ 666 ”即可免费领文章取中所提供的资料。




往期精品推荐



腾讯、阿里、滴滴后台试题汇集总结 — (含答案)

面试:史上最全多线程序面试题!

最新阿里内推Java后端试题

JVM难学?那是因为你没有真正看完整这篇文章


结束


关注作者微信公众号 — 《JAVA烤猪皮》


了解了更多java后端架构知识以及最新面试宝典



看完本文记得给作者点赞+在看哦~~~大家的支持,是作者来源不断出文的动力~

浏览 14
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报