前情提要

在SpringBoot中使用多线程进行业务处理时,难免需要在子线程中使用到系统的上下文获取到其中的某些信息,例如从SecurityContextHolder中获取用户验证信息,或者是从RequestContextHolder中获取到请求信息,但是,在多线程中直接这么调用会报一个空指针异常,换句话说就是主线程中的这些上下文信息无法传递给子线程进行共享。

解决方案
使用ThreadLocal

ThreadLocal是线程独享的,可以在主线程中获取到对应的上下文信息然后设置到threadLocal中,最后在子线程中使用,这样,子线程中的ThreadLocal都是独享的,互不干扰。但是缺点也很明显,那就是每次设计多线程使用都需要提前获取之后设置到ThreadLocal,然后再由子线程使用,代码重复性比较高,这里就不做介绍了

使用TaskDecorator

TaskDecorator是一个接口,作为线程的一个装饰器,可以再原有线程的基础上对线程做相应的增强。具体实现就是在异步线程池的定义中加入一个实现了TaskDecorator的静态内部类,然后将其实例对象设置到线程池中,以后执行多线程时使用该线程池去执行即可

package com.we.applet.wifi.config;

import com.we.applet.wifi.common.Constant;
import com.we.applet.wifi.utils.LogUtil;
import org.springframework.aop.interceptor.AsyncExecutionAspectSupport;
import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.task.TaskDecorator;
import org.springframework.scheduling.annotation.AsyncConfigurer;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.annotation.Resource;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionHandler;

/**
 * @Author: 张定辉
 * @CreateDate: 2022/12/12
 * @Description: 线程池配置
 */
@Configuration
public class AsyncConfig implements AsyncConfigurer {
    /**
     * 内部类,解决父子线程之间上下文(例如SecurityContext,ServletRequestAttributes等)不能共享问题
     */
    static class ContextTaskDecorator implements TaskDecorator {
        @Override
        public Runnable decorate(Runnable runnable) {
            //先拿到主线程的上下文对象
            SecurityContext securityContext = SecurityContextHolder.getContext();
            ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
            return()->{
                try{
                    //在子线程中重新设置进去
                    SecurityContextHolder.setContext(securityContext);
                    RequestContextHolder.setRequestAttributes(servletRequestAttributes,true);
                    runnable.run();
                }finally {
                    //使用完成之后清除子线程中的上下文,避免内存泄露
                    SecurityContextHolder.clearContext();
                    RequestContextHolder.resetRequestAttributes();
                }
            };
        }
    }


    @Resource
    private Constant constant;
    /**
     * 异步任务使用的线程池
     *
     * @return 线程池
     */
    @Bean(name = AsyncExecutionAspectSupport.DEFAULT_TASK_EXECUTOR_BEAN_NAME)
    @Override
    public Executor getAsyncExecutor() {
        return createThreadPool(constant.getDefaultThreadPoolCore(),
                                constant.getDefaultThreadPoolMax(),
                                constant.getDefaultThreadPoolKeep(),
                                constant.getDefaultThreadPoolQueue(),
                                constant.getDefaultThreadPoolPrefix());
    }

    private Executor createThreadPool(int core,int max,int keep,int queue,String prefix){
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(core);
        executor.setMaxPoolSize(max);
        executor.setKeepAliveSeconds(keep);
        executor.setThreadNamePrefix(prefix);
        executor.setQueueCapacity(queue);
        executor.setRejectedExecutionHandler(customizeRejectedExecutionHandler());
        //设置线程装饰器,解决父子线程之间上下文无法传递问题
        executor.setTaskDecorator(new ContextTaskDecorator());
        executor.initialize();
        return executor;
    }
    /**
     * 线程运行时抛出异常处理逻辑
     */
    @Override
    public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() {
        return (throwable, method, objects) -> {
            LogUtil.error("====================捕获线程异常=================");
            LogUtil.error("错误信息:{}", throwable.getMessage());
            LogUtil.error("调用的方法:{}", method.getName());
            LogUtil.error("参数列表:{}", objects);
            LogUtil.error("===============================================");
        };
    }

    /**
     * 自定义线程池拒绝策略
     */
    public RejectedExecutionHandler customizeRejectedExecutionHandler() {
        return (r, executor) -> {
            if (!executor.isShutdown()) {
                try {
                    LogUtil.info("线程池已满,线程进入休眠");
                    //线程休眠5秒
                    Thread.sleep(5000);
                } catch (InterruptedException e) {
                    LogUtil.error(e.getMessage());
                }
                //再尝试入队
                executor.execute(r);
            }
        };
    }
}