MongoTemplate很好用,但是基于xml注册为Bean时只能绑定在一个database上。

遇到需要支撑多个database的项目或动态切换database的项目就非常难受了。

解决的思路是按database名称把MongoTemplate缓存在Map中。连接池由MongoTemplate底层的MongoClient(MongoDB Java驱动)负责维护,所以不需要再自己管理连接池。

把管理这些MongoTemplate的类声明为Spring组件(这里使用@Repository),这样就可以通过@Value注入properties文件中的配置项。

使用ThreadLocal把本次操作要使用的MongoTemplate绑定到当前线程上,避免多线程并发切换database时互相干扰、导致结果不一致。

import com.mongodb.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.SimpleMongoDbFactory;
import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver;
import org.springframework.data.mongodb.core.convert.DefaultMongoTypeMapper;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.stereotype.Repository;

import java.net.UnknownHostException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Repository that lets a single Spring bean operate on any number of MongoDB databases.
 * <p>
 * One {@link MongoTemplate} is created lazily per database name and cached. All templates share a
 * single {@link MongoClient}, so there is only one driver-side connection pool no matter how many
 * databases are used. Before each operation, the template for the requested database is bound to
 * the calling thread through a {@link ThreadLocal}.
 *
 * @author ParanoidCAT
 * @since JDK 1.8
 */
@Repository(value = "mongoRepository")
public class MongoRepository {
    @Value("${mongo.host}")
    private String host;
    @Value("${mongo.port}")
    private Integer port;
    @Value("${mongo.username}")
    private String username;
    @Value("${mongo.password}")
    private String password;
    /** Database the credentials are authenticated against (the credential's source database). */
    @Value("${mongo.database}")
    private String database;
    @Value("${mongo.connectionsPerHost}")
    private Integer connectionsPerHost;
    @Value("${mongo.threadsAllowedToBlockForConnectionMultiplier}")
    private Integer threadsAllowedToBlockForConnectionMultiplier;
    @Value("${mongo.connectTimeout}")
    private Integer connectTimeout;
    @Value("${mongo.maxWaitTime}")
    private Integer maxWaitTime;
    @Value("${mongo.socketTimeout}")
    private Integer socketTimeout;
    @Value("${mongo.socketKeepAlive}")
    private Boolean socketKeepAlive;
    /** Template bound to the current thread by the most recent changeDatabase call on that thread. */
    private final ThreadLocal<MongoTemplate> threadLocal = new ThreadLocal<>();
    /**
     * Templates keyed by database name. This is a per-instance field because the templates depend
     * on this bean's connection settings.
     */
    private final Map<String, MongoTemplate> mongoTemplateCache = new ConcurrentHashMap<>(16);
    /** Guards lazy creation of {@link #mongoClient}. */
    private final Object mongoClientLock = new Object();
    /** Client shared by every cached template; created on first use. */
    private volatile MongoClient mongoClient;

    /**
     * Binds the template for {@code databaseName} to the current thread, creating and caching it on
     * first use.
     * <p>
     * {@link ConcurrentHashMap#computeIfAbsent} applies the factory at most once per key, so
     * concurrent first accesses to the same database share one template. The template is always
     * bound, including for the configured default database. Previously, that case could leave the
     * thread with no template and cause a NullPointerException.
     *
     * @param databaseName database to operate on
     * @throws NullPointerException  if {@code databaseName} is null (the cache rejects null keys)
     * @throws IllegalStateException if the MongoDB client cannot be created
     */
    private void changeDatabase(String databaseName) {
        threadLocal.set(mongoTemplateCache.computeIfAbsent(databaseName, this::newMongoTemplate));
    }

    /**
     * Adapter for {@link #createMongoTemplate(String)} that is usable as a mapping function.
     * It rethrows the checked exception as an unchecked one that keeps the original cause.
     */
    private MongoTemplate newMongoTemplate(String databaseName) {
        try {
            return createMongoTemplate(databaseName);
        } catch (UnknownHostException e) {
            throw new IllegalStateException("Unable to create MongoTemplate for database: " + databaseName, e);
        }
    }

    /**
     * Builds a template for {@code databaseName} on top of the shared client.
     *
     * @param databaseName database the template operates on
     * @return a new template; no "_class" type hints are written
     * @throws UnknownHostException if the driver reports the configured host cannot be resolved
     */
    private MongoTemplate createMongoTemplate(String databaseName) throws UnknownHostException {
        MongoDbFactory mongoDbFactory = new SimpleMongoDbFactory(mongoClient(), databaseName);
        MappingMongoConverter mappingMongoConverter = new MappingMongoConverter(new DefaultDbRefResolver(mongoDbFactory), new MongoMappingContext());
        // A null type key disables the "_class" type hint field in stored documents.
        mappingMongoConverter.setTypeMapper(new DefaultMongoTypeMapper(null));
        return new MongoTemplate(mongoDbFactory, mappingMongoConverter);
    }

    /**
     * Returns the client shared by all templates, creating it on first call.
     * <p>
     * This uses double-checked locking on a volatile field, so at most one client (and one
     * connection pool) exists per bean instance. If creation fails, nothing is stored and the next
     * call retries.
     *
     * @return the shared client
     * @throws UnknownHostException if the driver reports the configured host cannot be resolved
     */
    private MongoClient mongoClient() throws UnknownHostException {
        MongoClient client = mongoClient;
        if (client == null) {
            synchronized (mongoClientLock) {
                client = mongoClient;
                if (client == null) {
                    client = new MongoClient(
                            Collections.singletonList(new ServerAddress(host, port)),
                            // Credentials are always checked against the configured default database.
                            Collections.singletonList(MongoCredential.createCredential(username, database, password.toCharArray())),
                            new MongoClientOptions
                                    .Builder()
                                    .connectionsPerHost(connectionsPerHost)
                                    .threadsAllowedToBlockForConnectionMultiplier(threadsAllowedToBlockForConnectionMultiplier)
                                    .connectTimeout(connectTimeout)
                                    .maxWaitTime(maxWaitTime)
                                    .socketTimeout(socketTimeout)
                                    .socketKeepAlive(socketKeepAlive)
                                    .cursorFinalizerEnabled(true)
                                    .build()
                    );
                    mongoClient = client;
                }
            }
        }
        return client;
    }

    /**
     * Inserts one document.
     *
     * @param databaseName database name
     * @param t            instance to insert
     * @param <T>          type of the instance
     */
    public <T> void insert(String databaseName, T t) {
        changeDatabase(databaseName);
        threadLocal.get().insert(t);
    }

    /**
     * Inserts one document into the given collection.
     *
     * @param databaseName   database name
     * @param collectionName collection name
     * @param t              instance to insert
     * @param <T>            type of the instance
     */
    public <T> void insert(String databaseName, String collectionName, T t) {
        changeDatabase(databaseName);
        threadLocal.get().insert(t, collectionName);
    }

    /**
     * Inserts several documents.
     *
     * @param databaseName database name
     * @param tClass       class of the instances
     * @param tList        instances to insert
     * @param <T>          type of the instances
     */
    public <T> void insertAll(String databaseName, Class<T> tClass, List<T> tList) {
        changeDatabase(databaseName);
        threadLocal.get().insert(tList, tClass);
    }

    /**
     * Inserts several documents into the given collection.
     *
     * @param databaseName   database name
     * @param collectionName collection name
     * @param tList          instances to insert
     * @param <T>            type of the instances
     */
    public <T> void insertAll(String databaseName, String collectionName, List<T> tList) {
        changeDatabase(databaseName);
        threadLocal.get().insert(tList, collectionName);
    }

    /**
     * Removes one or more documents.
     *
     * @param databaseName database name
     * @param tClass       class of the documents
     * @param query        selection criteria
     * @param <T>          type of the documents
     * @return number of affected documents
     */
    public <T> long remove(String databaseName, Class<T> tClass, Query query) {
        changeDatabase(databaseName);
        return threadLocal.get().remove(query, tClass).getN();
    }

    /**
     * Removes one or more documents from the given collection.
     *
     * @param databaseName   database name
     * @param collectionName collection name
     * @param query          selection criteria
     * @return number of affected documents
     */
    public long remove(String databaseName, String collectionName, Query query) {
        changeDatabase(databaseName);
        return threadLocal.get().remove(query, collectionName).getN();
    }

    /**
     * Updates all matching documents.
     *
     * @param databaseName database name
     * @param tClass       class of the documents
     * @param query        selection criteria
     * @param update       changes to apply
     * @param <T>          type of the documents
     * @return number of affected documents
     */
    public <T> long updateMulti(String databaseName, Class<T> tClass, Query query, Update update) {
        changeDatabase(databaseName);
        return threadLocal.get().updateMulti(query, update, tClass).getN();
    }

    /**
     * Updates all matching documents in the given collection.
     *
     * @param databaseName   database name
     * @param collectionName collection name
     * @param query          selection criteria
     * @param update         changes to apply
     * @return number of affected documents
     */
    public long updateMulti(String databaseName, String collectionName, Query query, Update update) {
        changeDatabase(databaseName);
        return threadLocal.get().updateMulti(query, update, collectionName).getN();
    }

    /**
     * Finds all matching documents.
     *
     * @param databaseName database name
     * @param tClass       class of the documents
     * @param query        selection criteria
     * @param <T>          type of the documents
     * @return matching instances
     */
    public <T> List<T> find(String databaseName, Class<T> tClass, Query query) {
        changeDatabase(databaseName);
        return threadLocal.get().find(query, tClass);
    }

    /**
     * Finds all matching documents in the given collection.
     *
     * @param databaseName   database name
     * @param collectionName collection name
     * @param query          selection criteria
     * @param tClass         class of the documents
     * @param <T>            type of the documents
     * @return matching instances
     */
    public <T> List<T> find(String databaseName, String collectionName, Query query, Class<T> tClass) {
        changeDatabase(databaseName);
        return threadLocal.get().find(query, tClass, collectionName);
    }

    /**
     * Finds the first matching document.
     *
     * @param databaseName database name
     * @param tClass       class of the document
     * @param query        selection criteria
     * @param <T>          type of the document
     * @return the first match, or null if none
     */
    public <T> T findOne(String databaseName, Class<T> tClass, Query query) {
        changeDatabase(databaseName);
        return threadLocal.get().findOne(query, tClass);
    }

    /**
     * Finds the first matching document in the given collection.
     *
     * @param databaseName   database name
     * @param tClass         class of the document
     * @param collectionName collection name
     * @param query          selection criteria
     * @param <T>            type of the document
     * @return the first match, or null if none
     */
    public <T> T findOne(String databaseName, Class<T> tClass, String collectionName, Query query) {
        changeDatabase(databaseName);
        return threadLocal.get().findOne(query, tClass, collectionName);
    }
}

测试类:

40个线程同时并发,测试多线程调用是否安全

这里我在五个database中各建了一个名为"test"的collection,每个collection里放了一个{"name":"数据库名"}的Document用于测试。如果某个线程查到的name与它请求的database不一致,就会把该线程名和查询结果打印出来。

import com.mdruby.repository.MongoRepository;
import com.mongodb.BasicDBObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;

import java.util.Optional;
import java.util.concurrent.CountDownLatch;

/**
 * Manual concurrency check for {@link MongoRepository}.
 * <p>
 * Each request to {@code /test/run/{databaseName}} waits until 40 requests have arrived. All of
 * them then run 150 findOne calls against the "test" collection of their own database. Whenever
 * the document's "name" field differs from the requested database name, the thread name, the
 * database name and the document are printed. The latch is static, so only the first 40 requests
 * are held back; later requests start immediately.
 *
 * @author ParanoidCAT
 * @since JDK 1.8
 */
@Controller(value = "testController")
@RequestMapping(value = {"test"})
public class TestController {
    @Autowired
    private MongoRepository mongoRepository;
    /** Start gate: released once 40 requests have called countDown. */
    private static final CountDownLatch START_GATE = new CountDownLatch(40);

    @RequestMapping(value = {"/run/{databaseName}"}, method = {RequestMethod.GET}, produces = {"application/json;charset=utf-8"})
    @ResponseBody
    public void run(@PathVariable String databaseName) throws InterruptedException {
        // Register this request, then block until all 40 have registered so they start together.
        START_GATE.countDown();
        START_GATE.await();

        int remaining = 150;
        while (remaining-- > 0) {
            BasicDBObject found = mongoRepository.findOne(databaseName, BasicDBObject.class, "test", new Query());
            // A missing document or a missing "name" field falls back to the sentinel "testString".
            String actualName = found == null ? null : found.getString("name");
            if (actualName == null) {
                actualName = "testString";
            }
            if (!databaseName.equals(actualName)) {
                System.out.println(Thread.currentThread().getName() + ": " + databaseName + " - " + found);
            }
        }
    }
}