上一篇中老吕介绍了ThreadLocal线程数据绑定的原理,今天聊聊父子线程之间如何继承ThreadLocal上维护的数据。

        开发过程中异步执行任务有两种情况,第一种情况是 主线程 通过 new Thread()的方式产生了一个子线程,然后把 task 交给子线程去执行;第二种情况是主线程将task提交到线程池去执行。不同的情况需要不同的方案解决。

第一种情况:

通过InheritableThreadLocal来代替ThreadLocal

先写个例子测试下:

public class Test1 {
    public static void main(String[] args) {
        InheritableThreadLocal<String> threadLocal = new InheritableThreadLocal<>();
        threadLocal.set("userId-1001");
        System.out.println("父线程set:userId-1001");
        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                System.out.println("子线程get:"+threadLocal.get());
            }
        });
        thread.start();
        System.out.println("over");
        while (true);
    }
}


测试结果正常实现预期目标:
父线程set:userId-1001
子线程get:userId-1001
over

InheritableThreadLocal是ThreadLocal的子类,它增加的功能就是可以继承父线程上绑定的数据,下面看下它的源码是如何做到的继承:

public class InheritableThreadLocal<T> extends ThreadLocal<T> {


//这个方法的意义是在从父到子复制数据过程中,如果想修改,可以覆盖这个方法,这里没有做任何修改,直接用的父线程的值
    protected T childValue(T parentValue) {
        return parentValue;
    }
//这个方法体现了ThreadLocal类和InheritableThreadLocal类的区别,数据存储的位置不同
//在InheritableThreadLocal中数据是存储在 Thread对象的 inheritableThreadLocals中
//而ThreadLocal中数据是存储在Thread对象的 threadLocals中
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
  //createMap和getMap是对应的,创建map时放到 Thread对象的 inheritableThreadLocals 中 
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

java 子线程无法获取attributes 子线程获取父线程的threadlocal_System

我们再来看下在产生子线程过程中继承父线程的数据是如何实现的

public Thread(Runnable target) {
    init(null, target, "Thread-" + nextThreadNum(), 0);
}
private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize) {
    init(g, target, name, stackSize, null, true);
}


private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
    。。。省略了无关代码
    
    //这就是继承的逻辑,inheritThreadLocals默认是true,并且父线程中inheritableThreadLocals不为null
    //复制过程在 ThreadLocal.createInheritedMap 方法中
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    
    。。。省略了无关代码
}


//new 了一个新的 ThreadLocalMap对象
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
    return new ThreadLocalMap(parentMap);
}


//把key!=null 的Entry 复制过来,浅复制,key共用,value也是共用
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];


    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                //这个childValue的意义就是你想修改数据时就覆盖,
                //而在InheritableThreadLocal中是原值返回,不做任何修改
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}

可以看到创建了新的ThreadLocalMap对象,新的Entry,但是里面的key和value是复用的父线程中的对象。

至此InheritableThreadLocal类就讲清楚了,它适用的情况就是 子线程必须是主线程临时创建的。

第二种情况:针对线程池这种情况

方案1:通过阿里开源组件transmittable-thread-local解决

https://github.com/alibaba/transmittable-thread-local

方案2:手撕一个简单组件解决

思路:写个代理类,代理Runnable和Callable类

/**
 * 代理类:增加ThreadLocal数据传递功能
 */
class TaskProxy<V> implements Runnable, Callable {


    private Runnable runnable;
    private Callable<V> callable;


    public TaskProxy(Runnable runnable){
        this.runnable = runnable;
        storeThreadLocal();
    }
    public TaskProxy(Callable callable){
        this.callable = callable;
        storeThreadLocal();
    }


    @Override
    public void run() {
        restoreThreadLocal();
        this.runnable.run();
        clearThreadLocal();
    }


    @Override
    public Object call() throws Exception {
        restoreThreadLocal();
        V v = this.callable.call();
        clearThreadLocal();
        return v;
    }




    //------------------------绑定的数据-----------
    private String userId;
    private String traceId;
    private void storeThreadLocal() {
        this.userId = ThreadLocalUtil.getUserId();
        this.traceId = ThreadLocalUtil.getTraceId();
    }
    private void restoreThreadLocal() {
        ThreadLocalUtil.setUserId(userId);
        ThreadLocalUtil.setTraceId(traceId);
    }
    private void clearThreadLocal() {
        ThreadLocalUtil.removeUserId();
        ThreadLocalUtil.removeTraceId();
    }


}




/**
 * ThreadLocal工具类
 */
class ThreadLocalUtil{
    
    //可以使用一个自定义 上下文DTO 来存储数据,就不需要写多个ThreadLocal了
    private static ThreadLocal<String> userIdThreadLocal = new ThreadLocal<>();
    private static ThreadLocal<String> traceIdThreadLocal = new ThreadLocal<>();
    public static void setUserId(String userId){
        userIdThreadLocal.set(userId);
    }
    public static void setTraceId(String traceId){
        traceIdThreadLocal.set(traceId);
    }


    public static String getUserId(){
        return userIdThreadLocal.get();
    }
    public static String getTraceId(){
        return traceIdThreadLocal.get();
    }


    public static void removeUserId(){
        userIdThreadLocal.remove();
    }


    public static void removeTraceId(){
        traceIdThreadLocal.remove();
    }
}






//测试下
public class Test2 {
    static ExecutorService executorService = Executors.newFixedThreadPool(1);
    {
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                System.out.println("预热,产生核心线程");
            }
        });
    }
    public static void main(String[] args) {
        //-----主线程绑定数据----
        ThreadLocalUtil.setUserId("1001");
        ThreadLocalUtil.setTraceId("o98iuj76yhe3");


        //复用核心线程,未使用代理
        executorService.submit(new Runnable() {
            @Override
            public void run() {
                System.out.println("未使用代理:"+ThreadLocalUtil.getUserId());
                System.out.println("未使用代理:"+ThreadLocalUtil.getTraceId());
            }
        });


        //复用核心线程,Runnable使用代理
        executorService.submit((Runnable) new TaskProxy(new Runnable() {
            @Override
            public void run() {
                System.out.println("使用代理Runnable:"+ThreadLocalUtil.getUserId());
                System.out.println("使用代理Runnable:"+ThreadLocalUtil.getTraceId());
            }
        }));


        //复用核心线程,Callable使用代理
        executorService.submit((Callable) new TaskProxy<String>(new Callable() {
            @Override
            public String call() throws Exception {
                System.out.println("使用代理Callable:"+ThreadLocalUtil.getUserId());
                System.out.println("使用代理Callable:"+ThreadLocalUtil.getTraceId());
                return "ok";
            }
        }));
        System.out.println("over");
        while (true);
    }
}


    //测试结果--使用了代理类的都能正常传递数据
未使用代理:null
未使用代理:null
使用代理Runnable:1001
使用代理Runnable:o98iuj76yhe3
使用代理Callable:1001
使用代理Callable:o98iuj76yhe3

总结

本文详述了InheritableThreadLocal的实现原理,

ThreadLocal和InheritableThreadLocal的区别,

以及如何解决不同情况下ThreadLocal的数据传递问题,

大家可以根据自己的需要选择不同的方案。