十余年品牌的成都網(wǎng)站建設(shè)公司,上1000家企業(yè)網(wǎng)站設(shè)計經(jīng)驗.價格合理,可準(zhǔn)確把握網(wǎng)頁設(shè)計訴求.提供定制網(wǎng)站建設(shè)、購物商城網(wǎng)站建設(shè)、微信小程序定制開發(fā)、響應(yīng)式網(wǎng)站建設(shè)等服務(wù),我們設(shè)計的作品屢獲殊榮,是您值得信賴的專業(yè)網(wǎng)絡(luò)公司。
(手機(jī)橫屏看源碼更方便)
注:java源碼分析部分如無特殊說明均基于 java8 版本。
注:本文基于ForkJoinPool分治線程池類。
隨著在硬件上多核處理器的發(fā)展和廣泛使用,并發(fā)編程成為程序員必須掌握的一門技術(shù),在面試中也經(jīng)??疾槊嬖囌卟l(fā)相關(guān)的知識。
今天,我們就來看一道面試題:
如何充分利用多核CPU,計算很大數(shù)組中所有整數(shù)的和?
單線程相加?
我們最容易想到就是單線程相加,一個for循環(huán)搞定。
線程池相加?
如果進(jìn)一步優(yōu)化,我們會自然而然地想到使用線程池來分段相加,最后再把每個段的結(jié)果相加。
其它?
Yes,就是我們今天的主角——ForkJoinPool,但是它要怎么實現(xiàn)呢?似乎沒怎么用過哈^^
OK,剖析完了,我們直接來看三種實現(xiàn),不墨跡,直接上菜。
/**
* 計算1億個整數(shù)的和
*/
public class ForkJoinPoolTest01 {
public static void main(String[] args) throws ExecutionException, InterruptedException {
// 構(gòu)造數(shù)據(jù)
int length = 100000000;
long[] arr = new long[length];
for (int i = 0; i < length; i++) {
arr[i] = ThreadLocalRandom.current().nextInt(Integer.MAX_VALUE);
}
// 單線程
singleThreadSum(arr);
// ThreadPoolExecutor線程池
multiThreadSum(arr);
// ForkJoinPool線程池
forkJoinSum(arr);
}
private static void singleThreadSum(long[] arr) {
long start = System.currentTimeMillis();
long sum = 0;
for (int i = 0; i < arr.length; i++) {
// 模擬耗時,本文由公從號“彤哥讀源碼”原創(chuàng)
sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
}
System.out.println("sum: " + sum);
System.out.println("single thread elapse: " + (System.currentTimeMillis() - start));
}
private static void multiThreadSum(long[] arr) throws ExecutionException, InterruptedException {
long start = System.currentTimeMillis();
int count = 8;
ExecutorService threadPool = Executors.newFixedThreadPool(count);
List> list = new ArrayList<>();
for (int i = 0; i < count; i++) {
int num = i;
// 分段提交任務(wù)
Future future = threadPool.submit(() -> {
long sum = 0;
for (int j = arr.length / count * num; j < (arr.length / count * (num + 1)); j++) {
try {
// 模擬耗時
sum += (arr[j]/3*3/3*3/3*3/3*3/3*3);
} catch (Exception e) {
e.printStackTrace();
}
}
return sum;
});
list.add(future);
}
// 每個段結(jié)果相加
long sum = 0;
for (Future future : list) {
sum += future.get();
}
System.out.println("sum: " + sum);
System.out.println("multi thread elapse: " + (System.currentTimeMillis() - start));
}
private static void forkJoinSum(long[] arr) throws ExecutionException, InterruptedException {
long start = System.currentTimeMillis();
ForkJoinPool forkJoinPool = ForkJoinPool.commonPool();
// 提交任務(wù)
ForkJoinTask forkJoinTask = forkJoinPool.submit(new SumTask(arr, 0, arr.length));
// 獲取結(jié)果
Long sum = forkJoinTask.get();
forkJoinPool.shutdown();
System.out.println("sum: " + sum);
System.out.println("fork join elapse: " + (System.currentTimeMillis() - start));
}
private static class SumTask extends RecursiveTask {
private long[] arr;
private int from;
private int to;
public SumTask(long[] arr, int from, int to) {
this.arr = arr;
this.from = from;
this.to = to;
}
@Override
protected Long compute() {
// 小于1000的時候直接相加,可靈活調(diào)整
if (to - from <= 1000) {
long sum = 0;
for (int i = from; i < to; i++) {
// 模擬耗時
sum += (arr[i]/3*3/3*3/3*3/3*3/3*3);
}
return sum;
}
// 分成兩段任務(wù),本文由公從號“彤哥讀源碼”原創(chuàng)
int middle = (from + to) / 2;
SumTask left = new SumTask(arr, from, middle);
SumTask right = new SumTask(arr, middle, to);
// 提交左邊的任務(wù)
left.fork();
// 右邊的任務(wù)直接利用當(dāng)前線程計算,節(jié)約開銷
Long rightResult = right.compute();
// 等待左邊計算完畢
Long leftResult = left.join();
// 返回結(jié)果
return leftResult + rightResult;
}
}
}
彤哥偷偷地告訴你,實際上計算1億個整數(shù)相加,單線程是最快的,我的電腦大概是100ms左右,使用線程池反而會變慢。
所以,為了演示ForkJoinPool的牛逼之處,我把每個數(shù)都/3*3/3*3/3*3/3*3/3*3
了一頓操作,用來模擬計算耗時。
來看結(jié)果:
sum: 107352457433800662
single thread elapse: 789
sum: 107352457433800662
multi thread elapse: 228
sum: 107352457433800662
fork join elapse: 189
可以看到,F(xiàn)orkJoinPool相對普通線程池還是有很大提升的。
問題:普通線程池能否實現(xiàn)ForkJoinPool這種計算方式呢,即大任務(wù)拆中任務(wù),中任務(wù)拆小任務(wù),最后再匯總?
你可以試試看(-?_-?)
OK,下面我們正式進(jìn)入ForkJoinPool的解析。
基本思想
把一個規(guī)模大的問題劃分為規(guī)模較小的子問題,然后分而治之,最后合并子問題的解得到原問題的解。
步驟
(1)分割原問題:
(2)求解子問題:
(3)合并子問題的解為原問題的解。
在分治法中,子問題一般是相互獨立的,因此,經(jīng)常通過遞歸調(diào)用算法來求解子問題。
典型應(yīng)用場景
(1)二分搜索
(2)大整數(shù)乘法
(3)Strassen矩陣乘法
(4)棋盤覆蓋
(5)歸并排序
(6)快速排序
(7)線性時間選擇
(8)漢諾塔
ForkJoinPool是 java 7 中新增的線程池類,它的繼承體系如下:
ForkJoinPool和ThreadPoolExecutor都是繼承自AbstractExecutorService抽象類,所以它和ThreadPoolExecutor的使用幾乎沒有多少區(qū)別,除了任務(wù)變成了ForkJoinTask以外。
這里又運(yùn)用到了一種很重要的設(shè)計原則——開閉原則——對修改關(guān)閉,對擴(kuò)展開放。
可見整個線程池體系一開始的接口設(shè)計就很好,新增一個線程池類,不會對原有的代碼造成干擾,還能利用原有的特性。
fork()
fork()方法類似于線程的Thread.start()方法,但是它不是真的啟動一個線程,而是將任務(wù)放入到工作隊列中。
join()
join()方法類似于線程的Thread.join()方法,但是它不是簡單地阻塞線程,而是利用工作線程運(yùn)行其它任務(wù)。當(dāng)一個工作線程中調(diào)用了join()方法,它將處理其它任務(wù),直到注意到目標(biāo)子任務(wù)已經(jīng)完成了。
RecursiveAction
無返回值任務(wù)。
RecursiveTask
有返回值任務(wù)。
CountedCompleter
無返回值任務(wù),完成任務(wù)后可以觸發(fā)回調(diào)。
ForkJoinPool內(nèi)部使用的是“工作竊取”算法實現(xiàn)的。
(1)每個工作線程都有自己的工作隊列WorkQueue;
(2)這是一個雙端隊列,它是線程私有的;
(3)ForkJoinTask中fork的子任務(wù),將放入運(yùn)行該任務(wù)的工作線程的隊頭,工作線程將以LIFO的順序來處理工作隊列中的任務(wù);
(4)為了最大化地利用CPU,空閑的線程將從其它線程的隊列中“竊取”任務(wù)來執(zhí)行;
(5)從工作隊列的尾部竊取任務(wù),以減少競爭;
(6)雙端隊列的操作:push()/pop()僅在其所有者工作線程中調(diào)用,poll()是由其它線程竊取任務(wù)時調(diào)用的;
(7)當(dāng)只剩下最后一個任務(wù)時,還是會存在競爭,是通過CAS來實現(xiàn)的;
(1)最適合的是計算密集型任務(wù),本文由公從號“彤哥讀源碼”原創(chuàng);
(2)在需要阻塞工作線程時,可以使用ManagedBlocker;
(3)不應(yīng)該在RecursiveTask
(1)ForkJoinPool特別適合于“分而治之”算法的實現(xiàn);
(2)ForkJoinPool和ThreadPoolExecutor是互補(bǔ)的,不是誰替代誰的關(guān)系,二者適用的場景不同;
(3)ForkJoinTask有兩個核心方法——fork()和join(),有三個重要子類——RecursiveAction、RecursiveTask和CountedCompleter;
(4)ForkjoinPool內(nèi)部基于“工作竊取”算法實現(xiàn);
(5)每個線程有自己的工作隊列,它是一個雙端隊列,自己從隊列頭存取任務(wù),其它線程從尾部竊取任務(wù);
(6)ForkJoinPool最適合于計算密集型任務(wù),但也可以使用ManagedBlocker以便用于阻塞型任務(wù);
(7)RecursiveTask內(nèi)部可以少調(diào)用一次fork(),利用當(dāng)前線程處理,這是一種技巧;
ManagedBlocker怎么使用?
答:ManagedBlocker相當(dāng)于明確告訴ForkJoinPool框架要阻塞了,F(xiàn)orkJoinPool就會啟另一個線程來運(yùn)行任務(wù),以最大化地利用CPU。
請看下面的例子,自己琢磨哈^^。
/**
* 斐波那契數(shù)列
* 一個數(shù)是它前面兩個數(shù)之和
* 1,1,2,3,5,8,13,21
*/
public class Fibonacci {
public static void main(String[] args) {
long time = System.currentTimeMillis();
Fibonacci fib = new Fibonacci();
int result = fib.f(1_000).bitCount();
time = System.currentTimeMillis() - time;
System.out.println("result,本文由公從號“彤哥讀源碼”原創(chuàng) = " + result);
System.out.println("test1_000() time = " + time);
}
public BigInteger f(int n) {
Map cache = new ConcurrentHashMap<>();
cache.put(0, BigInteger.ZERO);
cache.put(1, BigInteger.ONE);
return f(n, cache);
}
private final BigInteger RESERVED = BigInteger.valueOf(-1000);
public BigInteger f(int n, Map cache) {
BigInteger result = cache.putIfAbsent(n, RESERVED);
if (result == null) {
int half = (n + 1) / 2;
RecursiveTask f0_task = new RecursiveTask() {
@Override
protected BigInteger compute() {
return f(half - 1, cache);
}
};
f0_task.fork();
BigInteger f1 = f(half, cache);
BigInteger f0 = f0_task.join();
long time = n > 10_000 ? System.currentTimeMillis() : 0;
try {
if (n % 2 == 1) {
result = f0.multiply(f0).add(f1.multiply(f1));
} else {
result = f0.shiftLeft(1).add(f1).multiply(f1);
}
synchronized (RESERVED) {
cache.put(n, result);
RESERVED.notifyAll();
}
} finally {
time = n > 10_000 ? System.currentTimeMillis() - time : 0;
if (time > 50)
System.out.printf("f(%d) took %d%n", n, time);
}
} else if (result == RESERVED) {
try {
ReservedFibonacciBlocker blocker = new ReservedFibonacciBlocker(n, cache);
ForkJoinPool.managedBlock(blocker);
result = blocker.result;
} catch (InterruptedException e) {
throw new CancellationException("interrupted");
}
}
return result;
// return f(n - 1).add(f(n - 2));
}
private class ReservedFibonacciBlocker implements ForkJoinPool.ManagedBlocker {
private BigInteger result;
private final int n;
private final Map cache;
public ReservedFibonacciBlocker(int n, Map cache) {
this.n = n;
this.cache = cache;
}
@Override
public boolean block() throws InterruptedException {
synchronized (RESERVED) {
while (!isReleasable()) {
RESERVED.wait();
}
}
return true;
}
@Override
public boolean isReleasable() {
return (result = cache.get(n)) != RESERVED;
}
}
}