MyBatis-Plus多数据源dynamic-datasource解决多线程情境下数据源切换失效问题

前言:项目中使用MyBatis-Plus多数据源dynamic-datasource,完成多数据源的切换;但是在并发场景下,我们会发现线程会一直访问默认数据源(配置的Master数据),并没有访问我们在上一步切换后的数据源,之前切换的数据源失效了;显然多数据源对于并发的处理并不友好,那么我们怎么解决这个问题呢。

本文是在springboot项目已集成dynamic-datasource 基础上延伸的问题,项目集成多数据源可以参考:Idea+maven+spring-cloud项目搭建系列–13 整合MyBatis-Plus多数据源dynamic-datasource

1 问题产生的原因:

问题的产生来源于多数据源com.baomidou.dynamic.datasource.toolkit 包下DynamicDataSourceContextHolder 类的问题,当我们打开这个类,会发现,存储当前线程的数据源使用了 ThreadLocal:

package com.baomidou.dynamic.datasource.toolkit;

import java.util.ArrayDeque;
import java.util.Deque;
import org.springframework.core.NamedThreadLocal;
import org.springframework.util.StringUtils;

public final class DynamicDataSourceContextHolder {
	// 线程数据源的存储
    private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER = new NamedThreadLocal<Deque<String>>("dynamic-datasource") {
        protected Deque<String> initialValue() {
            return new ArrayDeque();
        }
    };

    private DynamicDataSourceContextHolder() {
    }

    public static String peek() {
    	//  访问数据库时 从队列中peek 出来数据源
        return (String)((Deque)LOOKUP_KEY_HOLDER.get()).peek();
    }
	// 放入要切换的数据源
    public static String push(String ds) {
        String dataSourceStr = StringUtils.isEmpty(ds) ? "" : ds;
        ((Deque)LOOKUP_KEY_HOLDER.get()).push(dataSourceStr);
        return dataSourceStr;
    }
	// 从队列获取数据源
    public static void poll() {
        Deque<String> deque = (Deque)LOOKUP_KEY_HOLDER.get();
        deque.poll();
        if (deque.isEmpty()) {
            LOOKUP_KEY_HOLDER.remove();
        }

    }
	// 清除数据源
    public static void clear() {
        LOOKUP_KEY_HOLDER.remove();
    }
}

再来看下 NamedThreadLocal:

// 此处可以看到继承了 ThreadLocal 类
public class NamedThreadLocal<T> extends ThreadLocal<T> {

	private final String name;


	/**
	 * Create a new NamedThreadLocal with the given name.
	 * @param name a descriptive name for this ThreadLocal
	 */
	public NamedThreadLocal(String name) {
		Assert.hasText(name, "Name must not be empty");
		this.name = name;
	}

	@Override
	public String toString() {
		return this.name;
	}

}

简单概况下数据源的切换流程:
当我们进行数据源切换的时候,实际上是向当前线程所持有的LOOKUP_KEY_HOLDER 的ThreadLocal 对象放入数据源,这样在当前线程在进行数据库访问的时候,会得到当前的数据源,然后找到对应的jdbc 连接,完成数据的访问;
因为LOOKUP_KEY_HOLDER 对象是用ThreadLocal 修饰的,也就是说它是线程隔离的,所以当我们在切换完数据源之后,在子线程中维护的LOOKUP_KEY_HOLDER 是空的,再找不到数据源的情况下,就访问到了默认的数据源;

2 问题处理的思路:

既然是由于线程中保存数据源是每个线程隔离的,要想在并发的情形下仍然可以正常的数据源切换,要就需要打破其隔离性:
解决思路1:在开启线程执行任务时 ,先获取到父线程的数据源,然后在子线程内手动完成数据源的切换,保证子父线程数据源的一致性;
解决思路2:在项目中创建一个特殊的线程池,当有任务的执行时,进行拦截,获取父线程的数据源然后手动进行数据源的切换;
解决思路3:项目中覆盖DynamicDataSourceContextHolder 类修改LOOKUP_KEY_HOLDER 的对象,使得子线程在执行任务时,可以拿到父线程的数据源标识,这样也可以保证,子父线程访问数据源的一致性;改方法可以在不入侵原有业务代码的情况下,在业务开发者无感知的情况下,做到统一拦截并进行代理,完成父类数据源的传递;

3 问题解决的办法:

3.1 针对于解决思路1:
在执行线程任务时,进行手动的切换 demo:

// 获取当前父线程的数据源
String parentDb = "";
  new Thread(()->{
         // 切换数据源
        DynamicDataSourceContextHolder.push(parentDb ); 
       try {
       		// do some thing
		  }finally {
		  	// 最后移除数据源
		  	 DynamicDataSourceContextHolder.clear();
		  }    
  }).start();
// 在子线程执行任务时

3.2 针对于解决思路2:
创建一个线程池,当执行任务时,都使用改线程池:
线程配置类:TaskExecutionConfig

import org.springframework.context.annotation.*;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.ThreadPoolExecutor;

@Configuration
public class TaskExecutionConfig {
    // cpu 核心数
    private static final int DEFAULT_THREADS = Math.max(1, Runtime.getRuntime().availableProcessors());


    @Primary
    @Bean(name = {"taskHolderExecutorProxy", "executor"})
    public TaskHolderExecutorProxy threadPoolTaskExecutor() {
        ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();

        threadPoolTaskExecutor.setCorePoolSize(DEFAULT_THREADS);
        threadPoolTaskExecutor.setMaxPoolSize(DEFAULT_THREADS << 1);
        threadPoolTaskExecutor.setQueueCapacity(Integer.MAX_VALUE);
        threadPoolTaskExecutor.setKeepAliveSeconds(120);
        threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.DiscardPolicy());
        threadPoolTaskExecutor.initialize();
        return new TaskHolderExecutorProxy(threadPoolTaskExecutor);
    }
}

线程执行任务时进行拦截进行数据源切换:TaskHolderExecutorProxy

import org.springframework.web.context.request.RequestAttributes;
import org.springframework.web.context.request.RequestContextHolder;

import java.util.concurrent.Executor;

/**
 * 使用多线程并行查询时,非主线程 尝试获取 用户上下文 (即httpServletRequest)时
 * 用户上下文为空,会导致 使用多线程查询的服务 无法使用多租户功能
 * 所以这个proxy在提交任务到线程池之前先保存线程的上下文,
 * 这样非主线程也能拿到主线程的用户上下文,从而使用多租户
 */

public class TaskHolderExecutorProxy implements Executor {

    /**
     * 被代理的线程池
     */
    private final Executor executor;

    public TaskHolderExecutorProxy(Executor executor) {
        this.executor = executor;
    }

    @Override
    public void execute(Runnable command) {
//		保存主线程的 用户上下文
       // RequestAttributes requestAttributes = RequestContextHolder.currentRequestAttributes();
		// 获取当前父线程的数据源
		String parentDb = "";
        executor.execute(() -> {

//			为线程池 设置用户上下文
          //  RequestContextHolder.setRequestAttributes(requestAttributes);
            // 切换数据源
            DynamicDataSourceContextHolder.push(parentDb ); 
            try {
                command.run();
            } finally {
//				清理线程池线程的上下文
//                RequestContextHolder.resetRequestAttributes();
 				// 最后移除数据源
		  		 DynamicDataSourceContextHolder.clear();
            }
        });

    }
}

3.3 针对于解决思路3:重写DynamicDataSourceContextHolder 类覆盖掉MyBatis-Plus 原有的类,并进行代理,在子线程任务执行之前放入父线程的数据源标识,并在子线程任务执行结束之后移除改数据源标识:
3.3.1 首先需要引入一个阿里的jar ,让其可以帮助我们将父线程ThreadLocal 修饰的常量,可以继承到子线程中:

  <!-- https://mvnrepository.com/artifact/com.alibaba/transmittable-thread-local -->
   <dependency>
        <groupId>com.alibaba</groupId>
        <artifactId>transmittable-thread-local</artifactId>
        <version>2.12.1</version>
    </dependency>

3.3.2 重写 DynamicDataSourceContextHolder 类:
我们需要在项目中创建一个路径和MyBatis-Plus 下 DynamicDataSourceContextHolder 类 路径相同,类名相同的DynamicDataSourceContextHolder 类:
在这里插入图片描述
DynamicDataSourceContextHolder 中我们重新定义LOOKUP_KEY_HOLDER

package com.baomidou.dynamic.datasource.toolkit;

import org.springframework.util.StringUtils;

import java.util.ArrayDeque;
import java.util.Deque;

public class DynamicDataSourceContextHolder {
    private static final ThreadLocal<Deque<String>> LOOKUP_KEY_HOLDER = new ChildThreadTreadLocal<Deque<String>>("dynamic-datasource") {
        protected Deque<String> initialValue() {
            return new ArrayDeque();
        }
    };

    private DynamicDataSourceContextHolder() {
    }

    public static String peek() {
        return (String)((Deque)LOOKUP_KEY_HOLDER.get()).peek();
    }

    public static String push(String ds) {
        String dataSourceStr = StringUtils.isEmpty(ds) ? "" : ds;
        ((Deque)LOOKUP_KEY_HOLDER.get()).push(dataSourceStr);
        return dataSourceStr;
    }

    public static void poll() {
        Deque<String> deque = (Deque)LOOKUP_KEY_HOLDER.get();
        deque.poll();
        if (deque.isEmpty()) {
            LOOKUP_KEY_HOLDER.remove();
        }

    }

    public static void clear() {
        LOOKUP_KEY_HOLDER.remove();
    }
}

相同包路径下定义ChildThreadTreadLocal类:在该类中我们继承TransmittableThreadLocal 类帮我进行父子线程数据的传递

package com.baomidou.dynamic.datasource.toolkit;

import com.alibaba.ttl.TransmittableThreadLocal;
import org.springframework.util.Assert;

public class ChildThreadTreadLocal<T> extends TransmittableThreadLocal {
    private final String name;

    public ChildThreadTreadLocal(String name) {
        Assert.hasText(name, "Name must not be empty");
        this.name = name;
    }

    public String toString() {
        return this.name;
    }
}

3.3.3 对项目中所以线程任务的执行增加代理
在需要代理的项目跟路径下放入之前pom 下载到maven 仓库的transmittable-thread-local-2.12.1.jar 包
新建buildlocal 文件包,并放入transmittable-thread-local-2.12.1.jar 包:
在这里插入图片描述
3.3.4 项目启动的jvm 参数增加代理:
在这里插入图片描述
-javaagent:xxxx/buildlocal/transmittable-thread-local-2.12.1.jar

3.3.5 对于线上部署docker 时 ,在doker 容器启动时增加代理:
在这里插入图片描述
4 总结:

  • 针对方法1和方法2:都需要侵入代码进行数据源的切换和移除;
  • 针对方法3 因为重新了DynamicDataSourceContextHolder 并且对数据源对象LOOKUP_KEY_HOLDER 使用TransmittableThreadLocal 进行修饰,当启动项目是使用-javaagent:完成代理后,每次在子线程进行任务执行时子线程都可以获取到父线程中的数据源,从而保证了子父线程数据源的一致性,并且该方法不需要入侵原有的业务代码;

5 扩展:
在项目开启-javaagent:xxxx/buildlocal/transmittable-thread-local-2.12.1.jar 线程的代理后,测试ThreadLocal 数据的可见性:


import com.alibaba.ttl.TransmittableThreadLocal;
import com.cric.zhongjian.common.datasource.Master;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

@RestController
public class ThreadTestController {
    private final static ThreadLocal<String> threadLocal1 = new ThreadLocal<>();
    private final static ThreadLocal<String> threadLocal2 = new InheritableThreadLocal<>();
    private final static ThreadLocal<String> threadLocal3= new TransmittableThreadLocal<>();

    @Master
    @GetMapping("threadlocal")
    public void testThread() throws InterruptedException {
        List<String > a = Arrays.asList("100,200".split(","));
        threadLocal1.set("x");
        threadLocal2.set("y");
        threadLocal3.set("z");

        new Thread(()->{
            System.out.println( Thread.currentThread().getId()+":"+ Thread.currentThread().getName());
            System.out.println("ExecutorServicex threadLocal1.get() = " + threadLocal1.get());
            System.out.println("ExecutorServicex threadLocal2.get() = " + threadLocal2.get());
            System.out.println("ExecutorServicex threadLocal3.get() = " + threadLocal3.get());
            System.out.println("================== ");
        }).start();
        Thread.sleep(1000);

        threadLocal1.set("a");
        threadLocal2.set("b");
        threadLocal3.set("c");
        ExecutorService fixedThreadPool = Executors.newFixedThreadPool(1);
        fixedThreadPool.submit(()->{
            System.out.println( Thread.currentThread().getId()+":"+ Thread.currentThread().getName());
            System.out.println("ExecutorService1 threadLocal1.get() = " + threadLocal1.get());
            System.out.println("ExecutorService1 threadLocal2.get() = " + threadLocal2.get());
            System.out.println("ExecutorService1 threadLocal3.get() = " + threadLocal3.get());
            System.out.println("================== ");
        });
        Thread.sleep(1000);
        threadLocal1.set("1");
        threadLocal2.set("2");
        threadLocal3.set("3");
        fixedThreadPool.submit(()->{
            System.out.println( Thread.currentThread().getId()+":"+ Thread.currentThread().getName());
            System.out.println("ExecutorService2 threadLocal1.get() = " + threadLocal1.get());
            System.out.println("ExecutorService2 threadLocal2.get() = " + threadLocal2.get());
            System.out.println("ExecutorService2 threadLocal3.get() = " + threadLocal3.get());
            System.out.println("================== ");
        });

        Thread.sleep(1000);
        threadLocal1.set("aa");
        threadLocal2.set("bb");
        threadLocal3.set("cc");
        a.parallelStream().forEach(e->{
            System.out.println(Thread.currentThread().getName()+":parallelStream threadLocal1.get() = " + threadLocal1.get());
            System.out.println(Thread.currentThread().getName()+":parallelStream threadLocal2.get() = " + threadLocal2.get());
            System.out.println(Thread.currentThread().getName()+":parallelStream threadLocal3.get() = " + threadLocal3.get());
            System.out.println("================== ");

        });


    }
}

测试结果:

160:Thread-30
ExecutorServicex threadLocal1.get() = null
ExecutorServicex threadLocal2.get() = y
ExecutorServicex threadLocal3.get() = z
================== 
161:pool-9-thread-1
ExecutorService1 threadLocal1.get() = null
ExecutorService1 threadLocal2.get() = b
ExecutorService1 threadLocal3.get() = c
================== 
161:pool-9-thread-1
ExecutorService2 threadLocal1.get() = null
ExecutorService2 threadLocal2.get() = b
ExecutorService2 threadLocal3.get() = 3
================== 
http-nio-9201-exec-2:parallelStream threadLocal1.get() = aa
http-nio-9201-exec-2:parallelStream threadLocal2.get() = bb
http-nio-9201-exec-2:parallelStream threadLocal3.get() = cc
================== 
ForkJoinPool.commonPool-worker-3:parallelStream threadLocal1.get() = null
ForkJoinPool.commonPool-worker-3:parallelStream threadLocal2.get() = bb
ForkJoinPool.commonPool-worker-3:parallelStream threadLocal3.get() = cc
================== 

可以看到当使用TransmittableThreadLocal 修饰后,在项目中进行子线程任务的执行时,子线程都可以拿到父线程的ThreadLocal 数据;

6 参考:
6.1 TransmittableThreadLocal的使用及原理解析;
6.2 springboot springmvc 拦截线程池线程执行业务逻辑;