Java CountDownLatch 实现

CountDownLatch 也是基于 AQS(AbstractQueuedSynchronizer)实现的一种同步器,表示“所有线程都等待,直到锁打开才继续执行”的含义。CountDownLatch 的作用是:当一个线程需要另外一个或多个线程完成后,再开始执行。

CountDownLatch 可以用来实现很多场景,比如:

  1. 某个服务依赖于其他服务的启动才能启动
  2. 某个游戏,必须等所有就绪者都到达才能开始游戏
  3. 启动一组相关的线程
  4. 等待一组相关线程结束

AQS 实现类

CountDownLatch 提供了一个内部类 Sync 继承自 AQS

  1. CountDownLatch 可以让多个线程同时进入临界区,所以也是共享模式的 AQS
  2. 获取操作只是判断状态是否为0,即是否可以结束等待,进入临界区
  3. 释放操作是对状态递减1,所以叫 CountDown,类似报数的意思
private static final class Sync extends AbstractQueuedSynchronizer {

    Sync(int count) {
        setState(count);
    }

    int getCount() {
        return getState();
    }

    protected int tryAcquireShared(int acquires) {
        return (getState() == 0) ? 1 : -1;
    }

    protected boolean tryReleaseShared(int releases) {
        // Decrement count; signal when transition to zero
        for (;;) {
            int c = getState();
            if (c == 0)
                return false;
            int nextc = c-1;
            if (compareAndSetState(c, nextc))
                return nextc == 0;
        }
    }
}

CountDownLatch 维护了一个状态表示 Count 的总数,释放一次对这个总数减1直到为0,它的 tryXXX 方法传递的参数没有实际意义,只是为了适应接口。

主要方法

CountDownLatch 常用的方法:await()countDown()

public class CountDownLatch {
    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    // 阻塞等待
    // 当count的值为0时,调用await()方法的线程便会逐个被唤醒
    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    // 限时等待
        public boolean await(long timeout, TimeUnit unit)
        throws InterruptedException {
            return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
        }

    // count-1
    public void countDown() {
            sync.releaseShared(1);
        }
}

AQS 原理

state 值是 AbstractQueuedSynchronizer 类中的一个 volatile 变量,state 为 0 时,唤醒等待线程

public abstract class AbstractQueuedSynchronizer
      extends AbstractOwnableSynchronizer
      implements java.io.Serializable {

      private volatile int state;

    public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (tryAcquireShared(arg) < 0)      
            // Sync 类实现方法 >0 表示没有线程需要等待 <0 表示有线程需要等待
            doAcquireSharedInterruptibly(arg);
    }

    private void doAcquireSharedInterruptibly(int arg)
          throws InterruptedException {
        // 加入等待队列
        final Node node = addWaiter(Node.SHARED);
        boolean failed = true;
        try {
            for (;;) {
                // 当一个线程(Node代表)进入等待队列后,获取此Node的prev节点 
                final Node p = node.predecessor();
                if (p == head) { // 如果获取到的 prev 是 head(队列中第一个等待线程)
                    // 查看是否还有线程需要等待
                    int r = tryAcquireShared(arg);
                    if (r >= 0) {
                        // 没有线程需要等待了 state==0,退出循环返回
                        setHeadAndPropagate(node, r);
                        p.next = null; // help GC
                        failed = false;
                        return;
                    }
                }

                                // LockSupport#park 阻塞线程
                if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt())
                    throw new InterruptedException();
            }
        } finally {
            if (failed)
                cancelAcquire(node);
        }
    }
}

Add a Comment

电子邮件地址不会被公开。 必填项已用*标注