MongoTemplate很好用,但是基于xml注册为Bean时只能绑定在一个database上。
遇到需要支撑多个database的项目或动态切换database的项目就非常难受了。
解决的思路是把MongoTemplate放在Map中缓存起来,由于MongoTemplate内部实现了连接池,所以不用再关心池的概念。
把管理容器的类声明为Spring的组件,这样一来就可以通过@Value引入properties文件中的属性
使用LocalThread来确保本地线程的安全,避免多线程并发调用时导致的结果不一致。
import com.mongodb.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.mongodb.MongoDbFactory;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.SimpleMongoDbFactory;
import org.springframework.data.mongodb.core.convert.DefaultDbRefResolver;
import org.springframework.data.mongodb.core.convert.DefaultMongoTypeMapper;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.Update;
import org.springframework.stereotype.Repository;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
/**
* @author ParanoidCAT
* @since JDK 1.8
*/
@Repository(value = "mongoRepository")
public class MongoRepository {
@Value("${mongo.host}")
private String host;
@Value("${mongo.port}")
private Integer port;
@Value("${mongo.username}")
private String username;
@Value("${mongo.password}")
private String password;
@Value("${mongo.database}")
private String database;
@Value("${mongo.connectionsPerHost}")
private Integer connectionsPerHost;
@Value("${mongo.threadsAllowedToBlockForConnectionMultiplier}")
private Integer threadsAllowedToBlockForConnectionMultiplier;
@Value("${mongo.connectTimeout}")
private Integer connectTimeout;
@Value("${mongo.maxWaitTime}")
private Integer maxWaitTime;
@Value("${mongo.socketTimeout}")
private Integer socketTimeout;
@Value("${mongo.socketKeepAlive}")
private Boolean socketKeepAlive;
private ThreadLocal<MongoTemplate> threadLocal = new ThreadLocal<>();
private static final Map<String, MongoTemplate> MONGO_TEMPLATE_CACHE = new ConcurrentHashMap<>(16);
private void changeDatabase(String databaseName) {
if (Optional.ofNullable(threadLocal.get()).map(MongoTemplate::getDb).map(DB::getName).orElse(database).equals(databaseName)) {
return;
}
if (MONGO_TEMPLATE_CACHE.containsKey(databaseName)) {
threadLocal.remove();
threadLocal.set(MONGO_TEMPLATE_CACHE.get(databaseName));
return;
}
synchronized (MONGO_TEMPLATE_CACHE) {
if (MONGO_TEMPLATE_CACHE.containsKey(databaseName)) {
changeDatabase(databaseName);
} else {
threadLocal.remove();
try {
threadLocal.set(createMongoTemplate(databaseName));
MONGO_TEMPLATE_CACHE.putIfAbsent(databaseName, threadLocal.get());
} catch (Exception e) {
// TODO 输出日志
System.out.println(e.toString());
}
}
}
}
private MongoTemplate createMongoTemplate(String databaseName) throws UnknownHostException {
MongoClient mongoClient = new MongoClient(
Collections.singletonList(new ServerAddress(host, port)),
Collections.singletonList(MongoCredential.createCredential(username, database, password.toCharArray())),
new MongoClientOptions
.Builder()
.connectionsPerHost(connectionsPerHost)
.threadsAllowedToBlockForConnectionMultiplier(threadsAllowedToBlockForConnectionMultiplier)
.connectTimeout(connectTimeout)
.maxWaitTime(maxWaitTime)
.socketTimeout(socketTimeout)
.socketKeepAlive(socketKeepAlive)
.cursorFinalizerEnabled(true)
.build()
);
MongoDbFactory mongoDbFactory = new SimpleMongoDbFactory(mongoClient, databaseName);
MappingMongoConverter mappingMongoConverter = new MappingMongoConverter(new DefaultDbRefResolver(mongoDbFactory), new MongoMappingContext());
mappingMongoConverter.setTypeMapper(new DefaultMongoTypeMapper(null));
return new MongoTemplate(mongoDbFactory, mappingMongoConverter);
}
/**
* 插入一条记录
*
* @param databaseName 数据库名
* @param t 实例
* @param <T> 实例所属的类
*/
public <T> void insert(String databaseName, T t) {
changeDatabase(databaseName);
threadLocal.get().insert(t);
}
/**
* 插入一条记录
*
* @param databaseName 数据库名
* @param collectionName 集合名
* @param t 实例
* @param <T> 实例所属的类
*/
public <T> void insert(String databaseName, String collectionName, T t) {
changeDatabase(databaseName);
threadLocal.get().insert(t, collectionName);
}
/**
* 插入多条记录
*
* @param databaseName 数据库名
* @param tClass 实例的class
* @param tList 实例
* @param <T> 实例所属的类
*/
public <T> void insertAll(String databaseName, Class<T> tClass, List<T> tList) {
changeDatabase(databaseName);
threadLocal.get().insert(tList, tClass);
}
/**
* 插入多条记录
*
* @param databaseName 数据库名
* @param collectionName 集合名
* @param tList 实例
* @param <T> 实例所属的类
*/
public <T> void insertAll(String databaseName, String collectionName, List<T> tList) {
changeDatabase(databaseName);
threadLocal.get().insert(tList, collectionName);
}
/**
* 移除一条或多条记录
*
* @param databaseName 数据库名
* @param tClass 实例的class
* @param query 查询条件
* @param <T> 实例所属的类
* @return
*/
public <T> long remove(String databaseName, Class<T> tClass, Query query) {
changeDatabase(databaseName);
return threadLocal.get().remove(query, tClass).getN();
}
/**
* 移除一条或多条记录
*
* @param databaseName 数据库名
* @param collectionName 集合名
* @param query 查询条件
* @return 受影响的记录条数
*/
public long remove(String databaseName, String collectionName, Query query) {
changeDatabase(databaseName);
return threadLocal.get().remove(query, collectionName).getN();
}
/**
* 更新多条记录
*
* @param databaseName 数据库名
* @param tClass 实例的class
* @param query 查询条件
* @param update 更新内容
* @param <T> 实例所属的类
* @return 受影响的记录条数
*/
public <T> long updateMulti(String databaseName, Class<T> tClass, Query query, Update update) {
changeDatabase(databaseName);
return threadLocal.get().updateMulti(query, update, tClass).getN();
}
/**
* 更新多条记录
*
* @param databaseName 数据库名
* @param collectionName 集合名
* @param query 查询条件
* @param update 更新内容
* @return 受影响的记录条数
*/
public long updateMulti(String databaseName, String collectionName, Query query, Update update) {
changeDatabase(databaseName);
return threadLocal.get().updateMulti(query, update, collectionName).getN();
}
/**
* 查询多条记录
*
* @param databaseName 数据库名
* @param tClass 实例的class
* @param query 查询条件
* @param <T> 实例所属的类
* @return 实例
*/
public <T> List<T> find(String databaseName, Class<T> tClass, Query query) {
changeDatabase(databaseName);
return threadLocal.get().find(query, tClass);
}
/**
* 查询多条记录
*
* @param databaseName 数据库名
* @param collectionName 集合名
* @param query 查询条件
* @param tClass 实例的class
* @param <T> 实例所属的类
* @return 实例
*/
public <T> List<T> find(String databaseName, String collectionName, Query query, Class<T> tClass) {
changeDatabase(databaseName);
return threadLocal.get().find(query, tClass, collectionName);
}
/**
* 查询第一条记录
*
* @param databaseName 数据库名
* @param tClass 实例的class
* @param query 查询条件
* @param <T> 实例所属的类
* @return 实例
*/
public <T> T findOne(String databaseName, Class<T> tClass, Query query) {
changeDatabase(databaseName);
return threadLocal.get().findOne(query, tClass);
}
/**
* 查询第一条记录
*
* @param databaseName 数据库名
* @param tClass 实例的class
* @param collectionName 集合名
* @param query 查询条件
* @param <T> 实例所属的类
* @return 实例
*/
public <T> T findOne(String databaseName, Class<T> tClass, String collectionName, Query query) {
changeDatabase(databaseName);
return threadLocal.get().findOne(query, tClass, collectionName);
}
}
测试类:
40个线程同时并发,测试多线程调用是否安全
这里我在五个database中放入了五个名称为"test"的collection,每个collection里放了一个{"name":"数据库名"}的Document用于测试
import com.mdruby.repository.MongoRepository;
import com.mongodb.BasicDBObject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
/**
* @author ParanoidCAT
* @since JDK 1.8
*/
@Controller(value = "testController")
@RequestMapping(value = {"test"})
public class TestController {
@Autowired
private MongoRepository mongoRepository;
private static final CountDownLatch COUNT_DOWN_LATCH = new CountDownLatch(40);
@RequestMapping(value = {"/run/{databaseName}"}, method = {RequestMethod.GET}, produces = {"application/json;charset=utf-8"})
@ResponseBody
public void run(@PathVariable String databaseName) throws InterruptedException {
// 线程计数+1
COUNT_DOWN_LATCH.countDown();
// 线程没到40个就等等
COUNT_DOWN_LATCH.await();
// 线程如果到了40个就一起放行,每个线程执行150次query
for (int i = 0; i < 150; i++) {
BasicDBObject basicDBObject = mongoRepository.findOne(databaseName, BasicDBObject.class, "test", new Query());
if (!databaseName.equals(Optional.ofNullable(basicDBObject).map(basicDBObject1 -> basicDBObject1.getString("name")).orElse("testString"))) {
System.out.println(Thread.currentThread().getName() + ": " + databaseName + " - " + basicDBObject);
}
}
}
}