最近业务方有一个需求,需要一次导入超过100万数据到系统数据库。可能大家首先会想,这么大的数据,干嘛通过程序去实现导入,为什么不直接通过SQL导入到数据库。
大数据量报表导出请参考:Java实现大批量数据导入导出(100W以上) -(二)导出
一、为什么一定要在代码实现
说说为什么不能通过SQL直接导入到数据库,而是通过程序实现:
- 首先,这个导入功能开始提供页面导入,只是开始业务方保证的一次只有<3W的数据导入;
- 其次,业务方导入的内容需要做校验,比如门店号,商品号等是否系统存在,需要程序校验;
- 最后,业务方导入的都是编码,数据库中还要存入对应名称,方便后期查询,SQL导入也是无法实现的。
基于以上上三点,就无法直接通过SQL语句导入数据库。那就只能老老实实的想办法通过程序实现。
二、程序实现有以下技术难点
- 一次读取这么大的数据量,肯定会导致服务器内存溢出;
- 调用接口保存一次传输数据量太大,网络传输压力会很大;
- 最终通过SQL一次批量插入,对数据库压力也比较大,如果业务同时操作这个表数据,很容易造成死锁。
三、解决思路
根据列举的技术难点我的解决思路是:
- 既然一次读取整个导入文件,那就先将文件流上传到服务器磁盘,然后分批从磁盘读取(支持多线程读取),这样就防止内存溢出;
- 调用插入数据库接口也是根据分批读取的内容进行调用;
- 分批插入数据到数据库。
四、具体实现代码
- 流式上传文件到服务器磁盘
略,一般Java上传就可以实现,这里就不贴出。
- 多线程分批从磁盘读取
批量读取文件:
import org.slf4j.Logger;
2 import org.slf4j.LoggerFactory;
3
4 import java.io.File;
5 import java.io.FileNotFoundException;
6 import java.io.RandomAccessFile;
7 import java.nio.ByteBuffer;
8 import java.nio.channels.FileChannel;
9
10 /**
11 * 类功能描述:批量读取文件
12 *
13 * @author WangXueXing create at 19-3-14 下午6:47
14 * @version 1.0.0
15 */
16 public class BatchReadFile {
17 private final Logger LOGGER = LoggerFactory.getLogger(BatchReadFile.class);
18 /**
19 * 字符集UTF-8
20 */
21 public static final String CHARSET_UTF8 = "UTF-8";
22 /**
23 * 字符集GBK
24 */
25 public static final String CHARSET_GBK = "GBK";
26 /**
27 * 字符集gb2312
28 */
29 public static final String CHARSET_GB2312 = "gb2312";
30 /**
31 * 文件内容分割符-逗号
32 */
33 public static final String SEPARATOR_COMMA = ",";
34
35 private int bufSize = 1024;
36 // 换行符
37 private byte key = "\n".getBytes()[0];
38 // 当前行数
39 private long lineNum = 0;
40 // 文件编码,默认为gb2312
41 private String encode = CHARSET_GB2312;
42 // 具体业务逻辑监听器
43 private ReaderFileListener readerListener;
44
45 public void setEncode(String encode) {
46 this.encode = encode;
47 }
48
49 public void setReaderListener(ReaderFileListener readerListener) {
50 this.readerListener = readerListener;
51 }
52
53 /**
54 * 获取准确开始位置
55 * @param file
56 * @param position
57 * @return
58 * @throws Exception
59 */
60 public long getStartNum(File file, long position) throws Exception {
61 long startNum = position;
62 FileChannel fcin = new RandomAccessFile(file, "r").getChannel();
63 fcin.position(position);
64 try {
65 int cache = 1024;
66 ByteBuffer rBuffer = ByteBuffer.allocate(cache);
67 // 每次读取的内容
68 byte[] bs = new byte[cache];
69 // 缓存
70 byte[] tempBs = new byte[0];
71 while (fcin.read(rBuffer) != -1) {
72 int rSize = rBuffer.position();
73 rBuffer.rewind();
74 rBuffer.get(bs);
75 rBuffer.clear();
76 byte[] newStrByte = bs;
77 // 如果发现有上次未读完的缓存,则将它加到当前读取的内容前面
78 if (null != tempBs) {
79 int tL = tempBs.length;
80 newStrByte = new byte[rSize + tL];
81 System.arraycopy(tempBs, 0, newStrByte, 0, tL);
82 System.arraycopy(bs, 0, newStrByte, tL, rSize);
83 }
84 // 获取开始位置之后的第一个换行符
85 int endIndex = indexOf(newStrByte, 0);
86 if (endIndex != -1) {
87 return startNum + endIndex;
88 }
89 tempBs = substring(newStrByte, 0, newStrByte.length);
90 startNum += 1024;
91 }
92 } finally {
93 fcin.close();
94 }
95 return position;
96 }
97
98 /**
99 * 从设置的开始位置读取文件,一直到结束为止。如果 end设置为负数,刚读取到文件末尾
100 * @param fullPath
101 * @param start
102 * @param end
103 * @throws Exception
104 */
105 public void readFileByLine(String fullPath, long start, long end) throws Exception {
106 File fin = new File(fullPath);
107 if (!fin.exists()) {
108 throw new FileNotFoundException("没有找到文件:" + fullPath);
109 }
110 FileChannel fileChannel = new RandomAccessFile(fin, "r").getChannel();
111 fileChannel.position(start);
112 try {
113 ByteBuffer rBuffer = ByteBuffer.allocate(bufSize);
114 // 每次读取的内容
115 byte[] bs = new byte[bufSize];
116 // 缓存
117 byte[] tempBs = new byte[0];
118 String line;
119 // 当前读取文件位置
120 long nowCur = start;
121 while (fileChannel.read(rBuffer) != -1) {
122 int rSize = rBuffer.position();
123 rBuffer.rewind();
124 rBuffer.get(bs);
125 rBuffer.clear();
126 byte[] newStrByte;
127 //去掉表头
128 if(nowCur == start){
129 int firstLineIndex = indexOf(bs, 0);
130 int newByteLenth = bs.length-firstLineIndex-1;
131 newStrByte = new byte[newByteLenth];
132 System.arraycopy(bs, firstLineIndex+1, newStrByte, 0, newByteLenth);
133 } else {
134 newStrByte = bs;
135 }
136
137 // 如果发现有上次未读完的缓存,则将它加到当前读取的内容前面
138 if (null != tempBs && tempBs.length != 0) {
139 int tL = tempBs.length;
140 newStrByte = new byte[rSize + tL];
141 System.arraycopy(tempBs, 0, newStrByte, 0, tL);
142 System.arraycopy(bs, 0, newStrByte, tL, rSize);
143 }
144 // 是否已经读到最后一位
145 boolean isEnd = false;
146 nowCur += bufSize;
147 // 如果当前读取的位数已经比设置的结束位置大的时候,将读取的内容截取到设置的结束位置
148 if (end > 0 && nowCur > end) {
149 // 缓存长度 - 当前已经读取位数 - 最后位数
150 int l = newStrByte.length - (int) (nowCur - end);
151 newStrByte = substring(newStrByte, 0, l);
152 isEnd = true;
153 }
154 int fromIndex = 0;
155 int endIndex = 0;
156 // 每次读一行内容,以 key(默认为\n) 作为结束符
157 while ((endIndex = indexOf(newStrByte, fromIndex)) != -1) {
158 byte[] bLine = substring(newStrByte, fromIndex, endIndex);
159 line = new String(bLine, 0, bLine.length, encode);
160 lineNum++;
161 // 输出一行内容,处理方式由调用方提供
162 readerListener.outLine(line.trim(), lineNum, false);
163 fromIndex = endIndex + 1;
164 }
165 // 将未读取完成的内容放到缓存中
166 tempBs = substring(newStrByte, fromIndex, newStrByte.length);
167 if (isEnd) {
168 break;
169 }
170 }
171 // 将剩下的最后内容作为一行,输出,并指明这是最后一行
172 String lineStr = new String(tempBs, 0, tempBs.length, encode);
173 readerListener.outLine(lineStr.trim(), lineNum, true);
174 } finally {
175 fileChannel.close();
176 fin.deleteOnExit();
177 }
178 }
179
180 /**
181 * 查找一个byte[]从指定位置之后的一个换行符位置
182 *
183 * @param src
184 * @param fromIndex
185 * @return
186 * @throws Exception
187 */
188 private int indexOf(byte[] src, int fromIndex) throws Exception {
189 for (int i = fromIndex; i < src.length; i++) {
190 if (src[i] == key) {
191 return i;
192 }
193 }
194 return -1;
195 }
196
197 /**
198 * 从指定开始位置读取一个byte[]直到指定结束位置为止生成一个全新的byte[]
199 *
200 * @param src
201 * @param fromIndex
202 * @param endIndex
203 * @return
204 * @throws Exception
205 */
206 private byte[] substring(byte[] src, int fromIndex, int endIndex) throws Exception {
207 int size = endIndex - fromIndex;
208 byte[] ret = new byte[size];
209 System.arraycopy(src, fromIndex, ret, 0, size);
210 return ret;
211 }
212 }
复制代码
以上是关键代码:利用FileChannel与ByteBuffer从磁盘中分批读取数据
多线程调用批量读取:
复制代码
1 /**
2 * 类功能描述: 线程读取文件
3 *
4 * @author WangXueXing create at 19-3-14 下午6:51
5 * @version 1.0.0
6 */
7 public class ReadFileThread extends Thread {
8 private ReaderFileListener processDataListeners;
9 private String filePath;
10 private long start;
11 private long end;
12 private Thread preThread;
13
14 public ReadFileThread(ReaderFileListener processDataListeners,
15 long start,long end,
16 String file) {
17 this(processDataListeners, start, end, file, null);
18 }
19
20 public ReadFileThread(ReaderFileListener processDataListeners,
21 long start,long end,
22 String file,
23 Thread preThread) {
24 this.setName(this.getName()+"-ReadFileThread");
25 this.start = start;
26 this.end = end;
27 this.filePath = file;
28 this.processDataListeners = processDataListeners;
29 this.preThread = preThread;
30 }
31
32 @Override
33 public void run() {
34 BatchReadFile readFile = new BatchReadFile();
35 readFile.setReaderListener(processDataListeners);
36 readFile.setEncode(processDataListeners.getEncode());
37 try {
38 readFile.readFileByLine(filePath, start, end + 1);
39 if(this.preThread != null){
40 this.preThread.join();
41 }
42 } catch (Exception e) {
43 throw new RuntimeException(e);
44 }
45 }
46 }
复制代码
监听读取:
复制代码
1 import java.util.ArrayList;
2 import java.util.List;
3
4 /**
5 * 类功能描述:读文件监听父类
6 *
7 * @author WangXueXing create at 19-3-14 下午6:52
8 * @version 1.0.0
9 */
10 public abstract class ReaderFileListener<T> {
11 // 一次读取行数,默认为1000
12 private int readColNum = 1000;
13
14 /**
15 * 文件编码
16 */
17 private String encode;
18
19 /**
20 * 分批读取行列表
21 */
22 private List<String> rowList = new ArrayList<>();
23
24 /**
25 *其他参数
26 */
27 private T otherParams;
28
29 /**
30 * 每读取到一行数据,添加到缓存中
31 * @param lineStr 读取到的数据
32 * @param lineNum 行号
33 * @param over 是否读取完成
34 * @throws Exception
35 */
36 public void outLine(String lineStr, long lineNum, boolean over) throws Exception {
37 if(null != lineStr && !lineStr.trim().equals("")){
38 rowList.add(lineStr);
39 }
40
41 if (!over && (lineNum % readColNum == 0)) {
42 output(rowList);
43 rowList = new ArrayList<>();
44 } else if (over) {
45 output(rowList);
46 rowList = new ArrayList<>();
47 }
48 }
49
50 /**
51 * 批量输出
52 *
53 * @param stringList
54 * @throws Exception
55 */
56 public abstract void output(List<String> stringList) throws Exception;
57
58 /**
59 * 设置一次读取行数
60 * @param readColNum
61 */
62 protected void setReadColNum(int readColNum) {
63 this.readColNum = readColNum;
64 }
65
66 public String getEncode() {
67 return encode;
68 }
69
70 public void setEncode(String encode) {
71 this.encode = encode;
72 }
73
74 public T getOtherParams() {
75 return otherParams;
76 }
77
78 public void setOtherParams(T otherParams) {
79 this.otherParams = otherParams;
80 }
81
82 public List<String> getRowList() {
83 return rowList;
84 }
85
86 public void setRowList(List<String> rowList) {
87 this.rowList = rowList;
88 }
89 }
复制代码
实现监听读取并分批调用插入数据接口:
复制代码
1 import com.today.api.finance.ImportServiceClient;
2 import com.today.api.finance.request.ImportRequest;
3 import com.today.api.finance.response.ImportResponse;
4 import com.today.api.finance.service.ImportService;
5 import com.today.common.Constants;
6 import com.today.domain.StaffSimpInfo;
7 import com.today.util.EmailUtil;
8 import com.today.util.UserSessionHelper;
9 import com.today.util.readfile.ReadFile;
10 import com.today.util.readfile.ReadFileThread;
11 import com.today.util.readfile.ReaderFileListener;
12 import org.slf4j.Logger;
13 import org.slf4j.LoggerFactory;
14 import org.springframework.beans.factory.annotation.Value;
15 import org.springframework.stereotype.Service;
16 import org.springframework.util.StringUtils;
17
18 import java.io.File;
19 import java.io.FileInputStream;
20 import java.util.ArrayList;
21 import java.util.Arrays;
22 import java.util.List;
23 import java.util.concurrent.FutureTask;
24 import java.util.stream.Collectors;
25
26 /**
27 * 类功能描述:报表导入服务实现
28 *
29 * @author WangXueXing create at 19-3-19 下午1:43
30 * @version 1.0.0
31 */
32 @Service
33 public class ImportReportServiceImpl extends ReaderFileListener<ImportRequest> {
34 private final Logger LOGGER = LoggerFactory.getLogger(ImportReportServiceImpl.class);
35 @Value("${READ_COL_NUM_ONCE}")
36 private String readColNum;
37 @Value("${REPORT_IMPORT_RECEIVER}")
38 private String reportImportReceiver;
39 /**
40 * 财务报表导入接口
41 */
42 private ImportService service = new ImportServiceClient();
43
44 /**
45 * 读取文件内容
46 * @param file
47 */
48 public void readTxt(File file, ImportRequest importRequest) throws Exception {
49 this.setOtherParams(importRequest);
50 ReadFile readFile = new ReadFile();
51 try(FileInputStream fis = new FileInputStream(file)){
52 int available = fis.available();
53 long maxThreadNum = 3L;
54 // 线程粗略开始位置
55 long i = available / maxThreadNum;
56
57 this.setRowList(new ArrayList<>());
58 StaffSimpInfo staffSimpInfo = ((StaffSimpInfo)UserSessionHelper.getCurrentUserInfo().getData());
59 String finalReportReceiver = getEmail(staffSimpInfo.getEmail(), reportImportReceiver);
60 this.setReadColNum(Integer.parseInt(readColNum));
61 this.setEncode(ReadFile.CHARSET_GB2312);
62 //这里单独使用一个线程是为了当maxThreadNum大于1的时候,统一管理这些线程
63 new Thread(()->{
64 Thread preThread = null;
65 FutureTask futureTask = null ;
66 try {
67 for (long j = 0; j < maxThreadNum; j++) {
68 //计算精确开始位置
69 long startNum = j == 0 ? 0 : readFile.getStartNum(file, i * j);
70 long endNum = j + 1 < maxThreadNum ? readFile.getStartNum(file, i * (j + 1)) : -2L;
71
72 //具体监听实现
73 preThread = new ReadFileThread(this, startNum, endNum, file.getPath(), preThread);
74 futureTask = new FutureTask(preThread, new Object());
75 futureTask.run();
76 }
77 if(futureTask.get() != null) {
78 EmailUtil.sendEmail(EmailUtil.REPORT_IMPORT_EMAIL_PREFIX, finalReportReceiver, "导入报表成功", "导入报表成功" ); //todo 等文案
79 }
80 } catch (Exception e){
81 futureTask.cancel(true);
82 try {
83 EmailUtil.sendEmail(EmailUtil.REPORT_IMPORT_EMAIL_PREFIX, finalReportReceiver, "导入报表失败", e.getMessage());
84 } catch (Exception e1){
85 //ignore
86 LOGGER.error("发送邮件失败", e1);
87 }
88 LOGGER.error("导入报表类型:"+importRequest.getReportType()+"失败", e);
89 } finally {
90 futureTask.cancel(true);
91 }
92 }).start();
93 }
94 }
95
96 private String getEmail(String infoEmail, String reportImportReceiver){
97 if(StringUtils.isEmpty(infoEmail)){
98 return reportImportReceiver;
99 }
100 return infoEmail;
101 }
102
103 /**
104 * 每批次调用导入接口
105 * @param stringList
106 * @throws Exception
107 */
108 @Override
109 public void output(List<String> stringList) throws Exception {
110 ImportRequest importRequest = this.getOtherParams();
111 List<List<String>> dataList = stringList.stream()
112 .map(x->Arrays.asList(x.split(ReadFile.SEPARATOR_COMMA)).stream().map(String::trim).collect(Collectors.toList()))
113 .collect(Collectors.toList());
114 LOGGER.info("上传数据:{}", dataList);
115 importRequest.setDataList(dataList);
116 // LOGGER.info("request对象:{}",importRequest, "request增加请求字段:{}", importRequest.data);
117 ImportResponse importResponse = service.batchImport(importRequest);
118 LOGGER.info("===========SUCESS_CODE======="+importResponse.getCode());
119 //导入错误,输出错误信息
120 if(!Constants.SUCESS_CODE.equals(importResponse.getCode())){
121 LOGGER.error("导入报表类型:"+importRequest.getReportType()+"失败","返回码为:", importResponse.getCode() ,"返回信息:",importResponse.getMessage());
122 throw new RuntimeException("导入报表类型:"+importRequest.getReportType()+"失败"+"返回码为:"+ importResponse.getCode() +"返回信息:"+importResponse.getMessage());
123 }
124 // if(importResponse.data != null && importResponse.data.get().get("batchImportFlag")!=null) {
125 // LOGGER.info("eywa-service请求batchImportFlag不为空");
126 // }
127 importRequest.setData(importResponse.data);
128
129 }
130 }
复制代码
注意:
第53行代码:
long maxThreadNum = 3L;
就是设置分批读取磁盘文件的线程数,我设置为3,大家不要设置太大,不然多个线程读取到内存,也会造成服务器内存溢出。
以上所有批次的批量读取并调用插入接口都成功发送邮件通知给导入人,任何一个批次失败直接发送失败邮件。
数据库分批插入数据:
复制代码
1 /**
2 * 批量插入非联机第三方导入账单
3 * @param dataList
4 */
5 def insertNonOnlinePayment(dataList: List[NonOnlineSourceData]) : Unit = {
6 if (dataList.nonEmpty) {
7 CheckAccountDataSource.mysqlData.withConnection { conn =>
8 val sql =
9 s""" INSERT INTO t_pay_source_data
10 (store_code,
11 store_name,
12 source_date,
13 order_type,
14 trade_type,
15 third_party_payment_no,
16 business_type,
17 business_amount,
18 trade_time,
19 created_at,
20 updated_at)
21 VALUES (?,?,?,?,?,?,?,?,?,NOW(),NOW())"""
22
23 conn.setAutoCommit(false)
24 var stmt = conn.prepareStatement(sql)
25 var i = 0
26 dataList.foreach { x =>
27 stmt.setString(1, x.storeCode)
28 stmt.setString(2, x.storeName)
29 stmt.setString(3, x.sourceDate)
30 stmt.setInt(4, x.orderType)
31 stmt.setInt(5, x.tradeType)
32 stmt.setString(6, x.tradeNo)
33 stmt.setInt(7, x.businessType)
34 stmt.setBigDecimal(8, x.businessAmount.underlying())
35 stmt.setString(9, x.tradeTime.getOrElse(null))
36 stmt.addBatch()
37 if ((i % 5000 == 0) && (i != 0)) { //分批提交
38 stmt.executeBatch
39 conn.commit
40 conn.setAutoCommit(false)
41 stmt = conn.prepareStatement(sql)
42
43 }
44 i += 1
45 }
46 stmt.executeBatch()
47 conn.commit()
48 }
49 }
50 }