深入分析 ThreadLocal

概述

ThreadLocal 提供了线程本地的实例。它与普通变量的区别在于,每个使用该变量的线程都会初始化一个完全独立的实例副本。ThreadLocal 变量通常被 private static 修饰。当一个线程结束时,它所使用的所有 ThreadLocal 实例副本都可被回收。

每个访问 ThreadLocal 变量的线程都有自己的一个实例副本。一个可能的方案是 ThreadLocal 维护一个 Map,键是 Thread,值是它在该 Thread 内的实例。线程通过该 ThreadLocal 的 get() 方法获取实例时,只需要以线程为键,从 Map 中找出对应的实例即可。

这种方式存在的问题是,当增加线程与减少线程均需要写入 Map,在多线程条件下需要保证该 Map 线程安全。虽然 ConcurrentHashMap 能够保证线程安全,但总归是借助锁来实现。

那么如果让 Map 由 Thread 维护,从而使得每个 Thread 只访问自己的 Map,那就不存在多线程写的问题,也就不需要锁。

该方案虽然没有锁的问题,但是由于每个线程访问某 ThreadLocal 变量后,都会在自己的 Map 内维护该 ThreadLocal 变量与具体实例的映射,如果不删除这些引用(映射),则这些 ThreadLocal 不能被回收,可能会造成内存泄漏。

Map 由 ThreadLocal 类的静态内部类 ThreadLocalMap 提供。该类的实例维护某个 ThreadLocal 与具体实例的映射。与 HashMap 不同的是,ThreadLocalMap 的 Entry 的键的引用是弱引用,而值的引用是强引用。

使用弱引用可以保证当没有强引用指向 ThreadLocal 变量时,它可被回收,从而避免 ThreadLocal 实例不能被回收而造成的内存泄漏的问题。

然而,这样仍然存在内存泄漏问题。当 Entry 的键被置为 null 时,Entry 的值仍然是强引用,其所指向的实例不能被回收。除此之外,Entry 本身也不能被回收。

针对该问题,ThreadLocalMap 的 set() 方法中,通过 replaceStaleEntry() 方法将所有键为 null 的 Entry 的值设置为 null,从而使得该值可被回收。另外,会在 rehash() 方法中通过 expungeStaleEntry() 方法将键和值为 null 的 Entry 设置为 null ,从而使得该 Entry 可被回收。通过这种方式,ThreadLocal 可防止内存泄漏。

ThreadLocal 源码分析

ThreadLocal 源码中有几个比较重要的方法:get()set()remove()initialValue()setInitialValue()

get()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// ThreadLocal.java
public T get() {
// 获取当前线程
Thread t = Thread.currentThread();
// 线程首先获取自身的 ThreadLocalMap
// ThreadLocalMap 是位于 ThreadLocal 类的静态内部类
ThreadLocalMap map = getMap(t);
if (map != null) {
// 如果 map 不是null,获取在 ThreadLocalMap 中 ThreadlLocal 对象作为 key 获取对应的值
ThreadLocalMap.Entry e = map.getEntry(this);
//  Entry 不为 null,则 Entry 中的值即为所访问的本线程对应的值
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 返回该 ThreadLocal 变量在该线程中对应值的初始值
return setInitialValue();
}

首先在当前线程中获取 ThreadLocalMap 对象(Map 中 key 是 ThreadLocal 对象,value 则是设置的值)。如果 Map 存在,则获取当前 ThreadLocal 对象的值,并返回。如果 Map 不存在,则从setInitialValue() 方法中获取。

setInitialValue()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
// ThreadLocal.java
private T setInitialValue() {
// initialValue() 默认返回值为 null
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

// `initialValue()` 方法在实际应用中被重写
protected T initialValue() {   
return null;
}

// 创建 ThreadLocalMap
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

set()

1
2
3
4
5
6
7
8
9
10
11
12
13
//ThreadLocal.java
public void set(T value) {
// 获取当前线程
Thread t = Thread.currentThread();
// 获取该线程的 ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null)
// 调用 ThreadLocalMap.set() 方法
map.set(this, value);
else
// 创建 ThreadLocalMap
createMap(t, value);
}

set() 方法将 ThreadLocalMap 中该 ThreadLocal 对应的值设置为指定值。

remove()

1
2
3
4
5
6
7
// ThreadLocal.java
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
// 调用 ThreadLocalMap.remove() 方法
m.remove(this);
}

ThreadLocalMap

可以发现,ThreadLocal 中的 set()remove() 方法都是借助 ThreadLocalMap 的 set()remove() 方法实现,下面具体看一下这些方法源码。

1
2
3
4
5
6
7
8
9
10
11
12
13
// ThreadLocal.ThreadLocalMap
private static final int INITIAL_CAPACITY = 16;
private int size = 0;
private int threshold; // Default to 0
private Entry[] table;

static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

ThreadLocalMap 是 ThreadLocal 中的一个静态内部类,默认容量为 16,Map 的键是 ThreadLocal 对象,值是 Object 对象。ThreadLocalMap 维护着一个哈希表,即 Entry 数组。

值得注意的是,ThreadLocalMap 的每个 Entry 的键是一个弱引用。

set()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// ThreadLocal.ThreadLocalMap
private void set(ThreadLocal<?> key, Object value) {

Entry[] tab = table;
int len = tab.length;
// 根据 ThreadLocal 的哈希值得到对应的下标
int i = key.threadLocalHashCode & (len-1);
// 线性探测
// 首先通过下标找对应的 Entry 对象,若不存在则创建一个新的 Entry 对象
// 若存在,但 key 冲突或者 key 是 null,则将下标加一(加一后如果小于数组长度则使用该值,否则使用 0),
// 再次尝试获取对应的 Entry,如果不为 null,则在循环中继续判断 key 是否重复或者 k 是否是 null
// Entry 不为 null
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
// key 相同则覆盖 value
if (k == key) {
e.value = value;
return;
}
// key 为 null ,将 value 变为 null
if (k == null) {
//用新元素替换陈旧的元素
replaceStaleEntry(key, value, i);
return;
}
}
// Entry 为 null,则创建新的 Entry
tab[i] = new Entry(key, value);

int sz = ++size;

// cleanSomeSlots 清理脏 Entry
// 如果没有脏 Entry 需要清理并且数组中的元素大于阈值,则进行 rehash
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

set() 方法对脏 Entry(stale entry)做如下处理:

  1. 如果当前 table[i]!=null ,说明哈希冲突就需要向后环形查找,若在查找过程中遇到脏 Entry 就通过 replaceStaleEntry() 方法进行处理;
  2. 如果当前 table[i]==null ,说明新的 Entry 可以直接插入,但是插入后会调用 cleanSomeSlots() 方法检测并清除脏 Entry。
replaceStaleEntry()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

int slotToExpunge = staleSlot;

// 向前寻找最远的脏 Entry,并将其索引赋值给 slotToExpunge
// 停止条件是 Entry 为 null
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

if (k == key) {
//向后查找过程中发现 key 相同的 Entry 就覆盖并且和脏 Entry 进行交换
// 保证 Entry 放在其索引所在位置
e.value = value;

tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 如果在向前查找过程中未发现脏 Entry,那么就以当前位置作为 cleanSomeSlot 的起点
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 搜索脏 Entry 并进行清理
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// 如果向前未搜索到脏 Entry,则在查找过程遇到脏 Entry,后面就以此时这个位置
//作为起点执行 cleanSomeSlots
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// 如果在查找过程中没有找到可以覆盖的 Entry,则将新的 Entry 插入在脏 Entry 中
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

replaceStaleEntry() 方法的作用就是在进行普通的 set() 过程中,同时寻找连续的 Entry 不为 null 的区间,并进行脏 Entry 的清理。

expungeStaleEntry()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 将 Entry 和 Entry.value 置为 null
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash 直到遇到 Entry.key 为 null
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// Entry.key 为 null
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// Rehash 过程
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;

// 线性探测
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

expungeStaleEntry() 方法会清理掉当前脏 Entry,并且继续向后搜索,若再次遇到脏 Entry 则继续将其清理,持续清理动作直到 Entry 为 null 时退出。该方法执行结果是当前脏 Entry(staleSlot)到返回的 i 位中间所有的 Entry 都不是脏 Entry。

cleanSomeSlots()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}

注意参数 n 的作用,参数 n 主要用于控制扫描次数。在扫描过程中,如果一直没有遇到脏 Entry,就扫描log(n) 次,之所以是 log(n) 是因为 n >>>= 1,每次 n 右移一位相当于 n 除以 2。如果在扫描过程中遇到脏 Entry ,n 变为当前哈希表的长度,再扫描 log(n) 次,从而增大了搜索的范围。

rehash()
1
2
3
4
5
6
private void rehash() {
expungeStaleEntries();

if (size >= threshold - threshold / 4)
resize();
}

首先清理脏 Entry,清理后如果 size >= threshold - threshold / 4 成立,则执行 resize()

resize()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;

for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}

setThreshold(newLen);
size = count;
table = newTab;
}

private void setThreshold(int len) {   
threshold = len * 2 / 3;
}

ThreadLocalMap 扩容分为两个步骤:

  1. 当长度达到了容量的 2/3,就会执行 rehash() 方法,该方法会清理无用的数据;
  2. 如果清理完后,长度仍大于等于阀值的 3/4,则执行 resize() 方法,做真正的扩容。
remove()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}

remove() 方法中当遇到脏 Entry 时,也会调用 expungeStaleEntry() 清理脏 Entry。

Netty 中的 FastThreadLocal

Netty 中 FastThreadLocal 用来代替 ThreadLocal 存放线程本地变量,从 FastThreadLocalThread 类型的线程中访问本地变量时,比直接使用 ThreadLocal 会有更好的性能。注意,FastThreadLocalThread 和 FastThreadLocal 一起使用才能有更好的性能。

和 ThreadLocal 实现方式类似,FastThreadLocalThread 中有一个 InternalThreadLocalMap 类型的字段 threadLocalMap,这样一个线程对应一个 InternalThreadLocalMap 实例,该线程下所有的线程本地变量都会放 threadLocalMap 中的数组 indexedVariables(Object 数组) 中。

FastThreadLocal 用 Object 数组来替代了 ThreadLocal 的 Entry 数组,需要注意的是 Object[0] 存储一个Set<FastThreadLocal<?>>集合。

下面依次看下几个重要的类 FastThreadLocal、InternalThreadLocalMap、UnpaddedInternalThreadLocalMap、FastThreadLocalThread 的源码。

FastThreadLocal

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// FastThreadLocal.java

public class FastThreadLocal<V> {

// variablesToRemoveIndex 指定用来存放 FastThreadLocal 实例的集合 variablesToRemove
// 在 indexedVariables 数组中的位置
private static final int variablesToRemoveIndex = InternalThreadLocalMap.nextVariableIndex();

// 全局唯一 ID
private final int index;

public FastThreadLocal() {
// 每个 FastThreadLocal 实例对应一个唯一 index
index = InternalThreadLocalMap.nextVariableIndex();
}

// ...
}

从 FastThreadLocal 的构造函数可以看出,FastThreadLocal 初始化时得到一个 index 变量,其值通过InternalThreadLocalMap.nextVariableIndex() 获取。

注意,variablesToRemoveIndex 指定用来存放 FastThreadLocal 实例的集合 variablesToRemove 在 indexedVariables 数组中的位置。集合 variablesToRemove 一般是数组第一个元素,即 Onject[0] 。

InternalThreadLocalMap

相比 Thread 中使用 ThreadLocal.ThreadLocalMap 存放 ThreadLocal 资源,FastThreadLocalThread 使用 InternalThreadLocalMap 存放 FastThreadLocal 资源。

InternalThreadLocalMap 中使用数组 indexedVariables 来存放 FastThreadLocal 变量,数组 indexedVariables 定义在 InternalThreadLocalMap 的父类 UnpaddedInternalThreadLocalMap 中。

构造函数在初始化时,会开辟一个 32 元素的空间,并填充 UNSET。由于 FastThreadLocal 的 index 是递增的,FastThreadLocal 在数组中不一定是连续存放的,可能中间会有 UNSET。

通常 CPU 的缓存行一般是 64 或 128 字节,为了防止 InternalThreadLocalMap 的不同实例被加载到同一个缓存行,需要多余填充一些字段,使得每个实例的大小超出缓存行的大小。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public final class InternalThreadLocalMap extends UnpaddedInternalThreadLocalMap {

private static final int DEFAULT_ARRAY_LIST_INITIAL_CAPACITY = 8;
private static final int STRING_BUILDER_INITIAL_SIZE;
private static final int STRING_BUILDER_MAX_SIZE;
private static final int HANDLER_SHARABLE_CACHE_INITIAL_CAPACITY = 4;
private static final int INDEXED_VARIABLE_TABLE_INITIAL_SIZE = 32;

// 占位符
public static final Object UNSET = new Object();

private BitSet cleanerFlags;

// 用于字节填充,防止 InternalThreadLocalMap 的不同实例被加载到同一个缓存行
public long rp1, rp2, rp3, rp4, rp5, rp6, rp7, rp8, rp9;

private InternalThreadLocalMap() {
super(newIndexedVariableTable());
}

private static Object[] newIndexedVariableTable() {
// 初始容量
Object[] array = new Object[INDEXED_VARIABLE_TABLE_INITIAL_SIZE];
// 填充占位符
Arrays.fill(array, UNSET);
return array;
}

//...

}

UnpaddedInternalThreadLocalMap

InternalThreadLocalMap 继承于 UnpaddedInternalThreadLocalMap。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class UnpaddedInternalThreadLocalMap {

// 使用普通线程时,使用 ThreadLocal 存放当前线程的 InternalThreadLocalMap 实例
static final ThreadLocal<InternalThreadLocalMap> slowThreadLocalMap = new ThreadLocal<InternalThreadLocalMap>();
// 原子类
static final AtomicInteger nextIndex = new AtomicInteger();

// 主要用于 InternalThreadLocalMap 中存放线程本地变量
Object[] indexedVariables;

// ...

UnpaddedInternalThreadLocalMap(Object[] indexedVariables) {
this.indexedVariables = indexedVariables;
}
}

FastThreadLocalThread

FastThreadLocal 类中定义了许多方法,如 get()set()remove(),以及线程结束执行清理的 removeAll() 方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public class FastThreadLocalThread extends Thread {
// This will be set to true if we have a chance to wrap the Runnable.
private final boolean cleanupFastThreadLocals;

// InternalThreadLocalMap 变量
private InternalThreadLocalMap threadLocalMap;

public FastThreadLocalThread() {
cleanupFastThreadLocals = false;
}

public FastThreadLocalThread(Runnable target) {
super(FastThreadLocalRunnable.wrap(target));
cleanupFastThreadLocals = true;
}

// ...
}

FastThreadLocal.get()

1
2
3
4
5
6
7
8
9
10
public final V get() {
InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get();
// 用本 FastThreadLocal 实例的 index 去 indexedVariables 数组中取数据
Object v = threadLocalMap.indexedVariable(index);
if (v != InternalThreadLocalMap.UNSET) {
return (V) v;
}
// 没有,则初始化并返回
return initialize(threadLocalMap);
}

FastThreadLocal.set()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34

// 为当前线程设值
public final void set(V value) {
if (value != InternalThreadLocalMap.UNSET) {
InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get();
setKnownNotUnset(threadLocalMap, value);
} else {
remove();
}
}

private void setKnownNotUnset(InternalThreadLocalMap threadLocalMap, V value) {
if (threadLocalMap.setIndexedVariable(index, value)) {
// 添加到 Set
addToVariablesToRemove(threadLocalMap, this);
}
}

private static void addToVariablesToRemove(InternalThreadLocalMap threadLocalMap, FastThreadLocal<?> variable) {
//从 variablesToRemoveIndex 下标处获取 variablesToRemove 集合
Object v = threadLocalMap.indexedVariable(variablesToRemoveIndex);
Set<FastThreadLocal<?>> variablesToRemove;
if (v == InternalThreadLocalMap.UNSET || v == null) {
// 本线程第一次添加 FastThreadLocal 实例,创建 Set 存放 FastThreadLocal 实例
variablesToRemove = Collections.newSetFromMap(new IdentityHashMap<FastThreadLocal<?>, Boolean>());
// 在数组 indexVariables 指定下标 index 处存放数据
threadLocalMap.setIndexedVariable(variablesToRemoveIndex, variablesToRemove);
} else {
// 非第一次添加
variablesToRemove = (Set<FastThreadLocal<?>>) v;
}
// 加入 FastThreadLocal 实例
variablesToRemove.add(variable);
}

FastThreadLocal.remove()

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public final void remove(InternalThreadLocalMap threadLocalMap) {
if (threadLocalMap == null) {
return;
}

// 设置数组 indexVariables 指定下标 index 处的数据为占位符 UNSET,并返回原数据
Object v = threadLocalMap.removeIndexedVariable(index);
// 从 Set 中删除自身实例
removeFromVariablesToRemove(threadLocalMap, this);

if (v != InternalThreadLocalMap.UNSET) {
try {
// 空方法,供子类调用
onRemoval((V) v);
} catch (Exception e) {
PlatformDependent.throwException(e);
}
}
}

内存泄漏问题的解决

通过装饰器设计模式,FastThreadLocalRunnable 类将一个普通的 Runnable 对象进行装饰,之后在调用 FastThreadLocalRunnable.run() 的时候,实际上会调用真实对象(即普通的 Runnable 对象)的 run() 方法,执行完成之后,会执行 FastThreadLocal.removeAll() 方法,进行对当前线程的回收操作,删除当前线程上的 InternalThreadLocalMap 本身及内部数据,从而解决内存泄漏问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23

final class FastThreadLocalRunnable implements Runnable {
private final Runnable runnable;

private FastThreadLocalRunnable(Runnable runnable) {
this.runnable = ObjectUtil.checkNotNull(runnable, "runnable");
}

@Override
public void run() {
try {
runnable.run();
} finally {
// 清理数据,防止内存泄漏
FastThreadLocal.removeAll();
}
}

// 装饰 Runnable 对象
static Runnable wrap(Runnable runnable) {
return runnable instanceof FastThreadLocalRunnable ? runnable : new FastThreadLocalRunnable(runnable);
}
}

参考

  1. Java进阶(七)正确理解Thread Local的原理与适用场景
  2. 一篇文章,从源码深入详解ThreadLocal内存泄漏问题
  3. Netty 之线程本地变量 FastThreadLocal
  4. Netty源码分析– FastThreadLocal分析(十)