Java并发编程学习——Java并发编程之美学习笔记十

Java并发保重线程同步器原理剖析

CountDownLatch 原理剖析

CountDownLatch 介绍

日常开发中可能我们可能遇到需要开启多个子线程去并行执行任务,并且 主线程需要等待所有子线程执行完毕后再进行汇总 的场景。我们可以使用 join() 方法(等待该子线程线程执行完毕),但是join()不灵活而且很多场景可能使用不了,所以JDK中提供了 CountDownLatch 这个类。我们来看一下 CountDownLatch 的使用案例。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class JoinCountDownLatch {

private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);

public static void main(String[] args) throws InterruptedException {
Thread thread1 = new Thread(new Runnable() {
@Override
public void run() {
for (int i = 0; i < 1000; i++) {
System.out.println("child thread1 running!");
}
countDownLatch.countDown();
System.out.println("child thread1 over!");
}
});

Thread thread2 = new Thread(new Runnable() {
@Override
public void run() {
for (int i = 0; i < 1000; i++) {
System.out.println("child thread2 running!");
}
countDownLatch.countDown();
System.out.println("child thread2 over!");
}
});

thread1.start();
thread2.start();

System.out.println("wait all child thread over");

countDownLatch.await();

System.out.println("all child thread over");
}

}

运行结果:

CountDownLatch演示

我们可以看到 main 函数最后一条语句总是等待两个子线程运行结束才会运行。

当然我们还可以使用线程池的方式创建,以避免直接操作Thread。而且使用线程池来管理线程一般直接添加 Runnable 到线程池,这个时候我们就没有办法调用 join 方法了,所以说 CountDownLatch 比 join 更具有灵活性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class JoinCountDownLatch {

private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);

public static void main(String[] args) throws InterruptedException {

ExecutorService executorService = Executors.newFixedThreadPool(2);

executorService.submit(new Runnable() {
@Override
public void run() {
for (int i = 0; i < 1000; i++) {
System.out.println("child thread1 running!");
}
countDownLatch.countDown();
System.out.println("child thread1 over!");
}
});

executorService.submit(new Runnable() {
@Override
public void run() {
for (int i = 0; i < 1000; i++) {
System.out.println("child thread2 running!");
}
countDownLatch.countDown();
System.out.println("child thread2 over!");
}
});

System.out.println("wait all child thread over");

countDownLatch.await();

System.out.println("all child thread over");
}

}

CountDownLatch 实现原理探究

在学习 AQS 的时候提到过, AQS 是同步器的基本组成部分,而且其中 AQS 的 state 是用来表示 CountDownLatch 的计数器的。我们可以查看 CountDownLatch 的类图结构。

CountDownLatch演示

因为 Sync 是继承了 AQS 的,他实现了一些 AQS 的方法,所以可以说 CountDownLatch 是基于 AQS 实现的。

1
2
3
4
5
6
7
8
public CountDownLatch(int count) {
if (count < 0) throw new IllegalArgumentException("count < 0");
this.sync = new Sync(count);
}
// 设置 state
Sync(int count) {
setState(count);
}
  • void await()

    当线程调用 CountDownLatch 对象的 await() 方法后,当前线程会被阻塞(上面案例是主线程调用的await 所以主线程会被阻塞)当所有线程调用了 CountDownLatch 的 countDown 方法后,即计数器的值为0的时候,调用 await 方法的线程会返回,或者当其他线程调用了当前被阻塞线程的 interrupt() 方法中断了饿当前线程,当前线程就会抛出 InterruptedException 异常返回

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
public void await() throws InterruptedException {
// 调用的是sync方法 其实就是调用的 AQS
sync.acquireSharedInterruptibly(1);
}
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
if (Thread.interrupted())
throw new InterruptedException();、
// 调用 tryAcquireShared 判断 这里AQS没有实现 调用的是实现类Sync的tryAcquireShared
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
// Sync的 tryAcquireShared
protected int tryAcquireShared(int acquires) {
// 返回的state 不为0就返回 -1调用doAcquireSharedInterruptibly阻塞
// 为0不阻塞
return (getState() == 0) ? 1 : -1;
}
// 回顾一下 AQS 的阻塞 这是获取共享资源被阻塞
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
  • void countDown() 方法

    线程调用该方法后 计数器的值递减,如果递减后计数器为0则唤醒因为调用 await 方法而被阻塞的线程,否则什么都不做

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
public void countDown() {
sync.releaseShared(1);
}
// 这是 AQS 定义的
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
// 这是Sync实现的
protected boolean tryReleaseShared(int releases) {
// Decrement count; signal when transition to zero
for (;;) {
// 获取 state
int c = getState();
// 入过为0则false 意思就是什么都不做
if (c == 0)
return false;
int nextc = c-1;
if (compareAndSetState(c, nextc))
// 如果CAS设置成功那么判断此时state是否为0 如果为0那么返回true
// 返回true代表要对阻塞线程进行唤醒
// 返回false代表什么都不做
return nextc == 0;
}
}
// AQS 中唤醒阻塞线程
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

CountDownLatch 小结

CountDownLatch 通过使用 AQS 实现,其中使用 AQS 的状态变量来存放计数器的值,当调用countDown方法的时候使state递减,调用await未得到满足的时候会 调用线程会被放入 AQS 阻塞队列中等待。 当其他线程调用 countDown方法并得到递减后的state为0的时候会调用 AQS 的 doReleaseShared 方法来激活由于调用 await() 方法而被阻塞的线程。

回环屏障 CyclicBarrier

对于 CountDownLatch 来说,线程同步后,等到计数器为0之后在调用 await 和 countDown 方法都会立即返回,也就是说 CountDownLatch 是一次性的。而 CyclicBarrier 会在所有子线程执行完毕后 重置 CyclicBarrier 的状态

CyclicBarrier 使用案例介绍

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
public class CyclicBarrierTest1 {

// 这里构造方法里可以添加任务 这个任务会在所有调用await方法的线程全部到达
// 屏障点(计数器为0)的时候调用
// 并且等到这个任务执行完毕 被阻塞的线程会被唤醒继续执行
private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2,
() -> System.out.println(Thread.currentThread() + "task1 merge result"));

public static void main(String[] args) {
ExecutorService executorService = Executors.newFixedThreadPool(2);

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "task-1");

System.out.println(Thread.currentThread() + "enter in barrier");
try {
// 计数器会递减
cyclicBarrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread() + "out barrier");
}
});

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "task-2");

System.out.println(Thread.currentThread() + "enter in barrier");
try {
cyclicBarrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
// 被唤醒后继续执行
System.out.println(Thread.currentThread() + "out barrier");
}
});

executorService.shutdown();
}

}

运行结果:

CountDownLatch演示

我们再来看一个 重复使用 的例子。这里我们没执行完一个步骤就汇总一次

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
public class CyclicBarrierTest2 {
// 计数器初始化为2 并且定义了汇总任务
private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2,
() -> System.out.println(Thread.currentThread() + "merge"));

public static void main(String[] args) {
ExecutorService executorService = Executors.newFixedThreadPool(2);

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "step1");
try {
cyclicBarrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread() + "step2");
try {
cyclicBarrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread() + "out barrier");
}
});

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "step1");
try {
cyclicBarrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread() + "step2");
try {
cyclicBarrier.await();
} catch (InterruptedException e) {
e.printStackTrace();
} catch (BrokenBarrierException e) {
e.printStackTrace();
}
System.out.println(Thread.currentThread() + "out barrier");
}
});

executorService.shutdown();
}

}

运行结果:

CountDownLatch演示

CyclicBarrier 实现原理探究

我们首先看一下 CyclicBarrier 的类图。

CyclicBarrier

由此我们可以知道 CyclicBarrier 是通过 独占锁 来实现的。parties用来记录线程个数,这里表示多少个线程调用await方法后 所有线程才会冲破屏障。count 一开始等于 parties,count计数器变为0之后会将parties的值重新赋值给count,以达到重复利用的功能。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public CyclicBarrier(int parties, Runnable barrierAction) {
if (parties <= 0) throw new IllegalArgumentException();
this.parties = parties;
// 将parties 赋值给count
this.count = parties;
this.barrierCommand = barrierAction;
}
public CyclicBarrier(int parties) {
this(parties, null);
}
// 独占锁
private final ReentrantLock lock = new ReentrantLock();
// 使用trip条件变量实现同步
private final Condition trip = lock.newCondition();
private final int parties;
private final Runnable barrierCommand;
private Generation generation = new Generation();
private int count;
// 里面的broken记录该屏障是否被打破
private static class Generation {
boolean broken = false;
}
  • int await() 方法

    当前线程调用该方法会被阻塞,知道满足下面条件之一才会返回:

    1. parties个线程调用了该方法,即到达屏障点
    2. 其他线程调用了当前线程的interrupt() 方法
    3. 与当前屏障点关联的 broken 标志被设置为 true 会抛出 BrokenbarrierException 然后返回。
1
2
3
4
5
6
7
public int await() throws InterruptedException, BrokenBarrierException {
try {
return dowait(false, 0L);
} catch (TimeoutException toe) {
throw new Error(toe); // cannot happen
}
}
  • int dowait() 方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
private int dowait(boolean timed, long nanos)
throws InterruptedException, BrokenBarrierException,
TimeoutException {
// 获取锁
final ReentrantLock lock = this.lock;
lock.lock();
try {
// 判断broken标志
final Generation g = generation;
if (g.broken)
throw new BrokenBarrierException();
// 是否被打断
if (Thread.interrupted()) {
// 被打断也要重置和唤醒
breakBarrier();
throw new InterruptedException();
}
// 将count递减
int index = --count;
// 如果执行后为0 那么执行屏障的任务
if (index == 0) { // tripped
boolean ranAction = false;
try {
final Runnable command = barrierCommand;
if (command != null)
command.run();
ranAction = true;
// 激活其他因调用await而被阻塞的线程 并重置cyclicBarrier
nextGeneration();
return 0;
} finally {
// 重置和唤醒
if (!ranAction)
breakBarrier();
}
}
// loop until tripped, broken, interrupted, or timed out
for (;;) {
try {
// 没有设置超时时间
if (!timed)
// 放入条件变量阻塞队列
trip.await();
else if (nanos > 0L)
nanos = trip.awaitNanos(nanos);
} catch (InterruptedException ie) {
if (g == generation && ! g.broken) {
breakBarrier();
throw ie;
} else {
// We're about to finish waiting even if we had not
// been interrupted, so this interrupt is deemed to
// "belong" to subsequent execution.
Thread.currentThread().interrupt();
}
}
if (g.broken)
throw new BrokenBarrierException();
if (g != generation)
return index;
if (timed && nanos <= 0L) {
breakBarrier();
throw new TimeoutException();
}
}
} finally {
lock.unlock();
}
}

// 唤醒阻塞队列并重置
private void nextGeneration() {
// 唤醒条件队列的所有阻塞线程
trip.signalAll();
// 重置
count = parties;
generation = new Generation();
}

private void breakBarrier() {
generation.broken = true;
count = parties;
trip.signalAll();
}

CyclicBarrier 小结

与 CountDownLatch 不同的是 CyclicBarrier 可以实现复用,并且特别适用分段任务有序执行的场景。CyclicBarrier适用独占锁来保证计数器的原子性更新,并使用条件队列来实现线程同步。

信号量 Semaphore 原理探究

Semaphore 也是 Java 中的一个同步器,和前面的 CountDownLatch 和 CyclicBarrier 不同的是 Semaphore内部的计数器是递增的,并且在初始化的时候可以指定一个初始值,但是 并不需要知道需要同步的线程个数, 而是在需要同步的地方调用 acquire 方法时指定需要同步的线程个数

Semaphore 案例介绍

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
public class SemaphoreTest {

// 生成 Semaphore 初始化计数器为0 因为是递增 所以只需要在需要
// 的时候指定 递增到多少 不需要一开始指定需要同步的线程个数
private static Semaphore semaphore = new Semaphore(0);

public static void main(String[] args) throws InterruptedException {

ExecutorService executorService = Executors.newFixedThreadPool(2);

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "over");
semaphore.release();
}
});

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "over");
semaphore.release();
}
});

// 需要时指定acquire
semaphore.acquire(2);

System.out.println("merge");

executorService.shutdown();

}

}

运行结果:

Semaphore

我们再来看一下 使用 Semaphore 实现复用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
public class SemaphoreTest2 {

private static Semaphore semaphore = new Semaphore(0);

public static void main(String[] args) throws InterruptedException {

ExecutorService executorService = Executors.newFixedThreadPool(2);

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "taskA over");
semaphore.release();
}
});

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "taskA over");
semaphore.release();
}
});

semaphore.acquire(2);

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "taskB over");
semaphore.release();
}
});

executorService.submit(new Runnable() {
@Override
public void run() {
System.out.println(Thread.currentThread() + "taskB over");
semaphore.release();
}
});

semaphore.acquire(2);

System.out.println("all task is over");


executorService.shutdown();

}

}

Semaphore

怎么能够复用呢? 其实是因为主线程调用 acquire 方法返回后 信号量会重新变成0。

Semaphore 实现原理探究

Semaphore

由类图可知,我们还是使用 AQS 实现的,并且还实现了获取信号量时是采用 公平策略 还是 非公平策略。

1
2
3
4
5
6
7
// permits是初始化的计数器值
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
  • void acquire() 方法

    当前线程调用该方法是 希望获取一个信号量资源。如果信号量个数大于0则当前信号量的计数会减一,然后该方法直接返回。否则如果当前信号量个数等于0,则当前线程会被放入 AQS 的阻塞队列。 当其他线程调用该线程的 interrupt() 方法中断了当前线程,当前线程会抛出 InterruptedException 然后返回。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
// AQS 实现的
public final void acquireSharedInterruptibly(int arg)
throws InterruptedException {
// 如果打断 抛出异常
if (Thread.interrupted())
throw new InterruptedException();
// 调用sync自己实现的tryAcquireShared 尝试获取共享资源
// 如果小于0 阻塞当前调用线程
if (tryAcquireShared(arg) < 0)
doAcquireSharedInterruptibly(arg);
}
// 我们首先看非公平实现
protected int tryAcquireShared(int acquires) {
return nonfairTryAcquireShared(acquires);
}
final int nonfairTryAcquireShared(int acquires) {
// 无限循环
for (;;) {
// 获取状态值
int available = getState();
// 剩余量
int remaining = available - acquires;
// 如果剩余量小于0 那么返回剩余量 此时为负值 那么会直接调用doAcquireSharedInterruptibly
// 不小于0 那么cas设置状态值 如果成功返回剩余量 设置不成功就一直循环
// 如果返回大于0返回剩余值
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
// 我们看一下 AQS 阻塞 放入阻塞队列里
private void doAcquireSharedInterruptibly(int arg)
throws InterruptedException {
final Node node = addWaiter(Node.SHARED);
boolean failed = true;
try {
for (;;) {
final Node p = node.predecessor();
if (p == head) {
int r = tryAcquireShared(arg);
if (r >= 0) {
setHeadAndPropagate(node, r);
p.next = null; // help GC
failed = false;
return;
}
}
if (shouldParkAfterFailedAcquire(p, node) &&
parkAndCheckInterrupt())
throw new InterruptedException();
}
} finally {
if (failed)
cancelAcquire(node);
}
}
  • void acquire(int permits)
1
2
3
4
5
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
// 这里指定了个数
sync.acquireSharedInterruptibly(permits);
}
  • void release()

    该方法是把当前的 Semaphore 对象的信号量值增加1,如果当前线程又因为调用 acquire 方法被阻塞放入 AQS 阻塞队列中,则会 根据公平策略选择一个信号量个数能被满足的线程进行激活, 激活的线程会尝试获取刚增加的信号量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
public void release() {
sync.releaseShared(1);
}
// AQS
public final boolean releaseShared(int arg) {
if (tryReleaseShared(arg)) {
doReleaseShared();
return true;
}
return false;
}
// Sync实现到的tryReleaseShared
protected final boolean tryReleaseShared(int releases) {
// 无限循环
for (;;) {
// 获取当前信号量
int current = getState();
// 获取加上release的信号量
int next = current + releases;
// 如果相加后小于current 说明发生了整型溢出则抛出异常
if (next < current) // overflow
throw new Error("Maximum permit count exceeded");
// 尝试CAS设置信号量 直到成功为止
// 成功后会调用doReleaseShared 可选择AQS阻塞队列符合要求的线程进行激活
if (compareAndSetState(current, next))
return true;
}
}
// AQS 进行激活
private void doReleaseShared() {
for (;;) {
Node h = head;
if (h != null && h != tail) {
int ws = h.waitStatus;
if (ws == Node.SIGNAL) {
if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
continue; // loop to recheck cases
unparkSuccessor(h);
}
else if (ws == 0 &&
!compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
continue; // loop on failed CAS
}
if (h == head) // loop if head changed
break;
}
}

Semaphore 小结

Semaphore 内部使用了一个递增的计数器,这样就可以不在初始化的时候指定需要同步的线程个数了。它通过 AQS 实现,并且在获取信号量时有公平和非公平策略选择。

-------------本文结束感谢阅读-------------