最近想试试写个多线程下载的程序,顺便记下来吧~
怎么开始
多线程下载,是通过在请求header中添加`Range`进行部分下载来实现的。那么我们要怎么确定`Range`的值呢?首先我们当然要得到下载文件的大小:向资源地址发送HEAD请求,就可以从响应头的`Content-Length`字段中得到文件的大小。因此,如果响应头中不存在`Content-Length`字段,我们就没办法进行多线程下载。此外,服务器还必须支持`Range`请求(对分段请求返回`206 Partial Content`),否则同样只能进行单线程下载。
下面的代码中我使用了jodd的http库(jodd-http),相关用法可以查看jodd-http的官方文档。
代码
获取真实下载链接
在下载之前,我们首先要获取到文件的真实下载链接:向目标url发送HEAD请求,如果返回301/302等重定向状态码,就从响应头的`Location`中取出新地址继续请求,直到返回200为止:
/**
 * Follows HTTP redirects (301/302/303/307/308) with HEAD requests to find the real download URL.
 *
 * @param url the URL given by the user
 * @return the final URL when it answers 200; {@code ""} for a blank input; otherwise the original
 *     {@code url} (a message with the last status code is printed)
 */
private String getRealUrl(String url) {
    if (StringUtil.isBlank(url)) {
        return "";
    }
    // Bounded so that a redirect loop (A -> B -> A ...) cannot hang forever.
    final int maxRedirects = 10;
    String current = url;
    int statusCode = 0;
    for (int hop = 0; hop <= maxRedirects; hop++) {
        HttpResponse response = HttpRequest.head(current).send();
        statusCode = response.statusCode();
        String target = response.header("Location");
        if (!isRedirect(statusCode) || StringUtil.isBlank(target)) {
            break;
        }
        // For 3xx responses the real URL is in the Location header; it may be relative.
        current = resolveLocation(current, target);
    }
    if (statusCode == HttpStatus.HTTP_OK) {
        return current;
    }
    System.out.printf("获取真实链接失败,StatusCode: %d%n", statusCode);
    return url;
}

/** Returns whether {@code statusCode} is a redirect that carries a Location header. */
private static boolean isRedirect(int statusCode) {
    return statusCode == HttpStatus.HTTP_MOVED_PERMANENTLY
            || statusCode == HttpStatus.HTTP_MOVED_TEMPORARY
            || statusCode == 303
            || statusCode == 307
            || statusCode == 308;
}

/** Resolves a (possibly relative) Location value against the URL that returned it. */
private static String resolveLocation(String base, String target) {
    try {
        return java.net.URI.create(base).resolve(target.trim()).toString();
    } catch (IllegalArgumentException e) {
        // Not a syntactically valid URI (e.g. unescaped characters): use the raw header value.
        return target.trim();
    }
}
解析资源信息
要进行多线程下载,我们需要先得到文件大小、文件名等信息:
/**
 * Reads the file size and file name with a HEAD request and decides between single- and
 * multi-threaded download. Also fills {@code partNames} with the paths used later for
 * downloading and merging.
 *
 * @throws Exception if {@code url} is blank, or the file name cannot be decoded
 */
private void init() throws Exception {
    if (StringUtil.isBlank(this.url)) {
        throw new Exception("url不存在");
    }
    HttpResponse response = HttpRequest.head(url).send();
    // Only trust the headers of a successful HEAD: an error page has its own Content-Length.
    boolean headOk = response.statusCode() == HttpStatus.HTTP_OK;
    long length = headOk ? parseContentLength(response.header("Content-Length")) : -1;
    if (length < 0) {
        System.out.println("无法获取到文件大小,只能进行单线程下载");
        this.threadNum = 1;
        this.contentLength = 0;
    } else {
        this.contentLength = length;
        String acceptRanges = response.header("Accept-Ranges");
        if (acceptRanges != null && acceptRanges.trim().equalsIgnoreCase("none")) {
            System.out.println("服务器不支持分段下载,只能进行单线程下载");
            this.threadNum = 1;
        } else if (length < this.threadNum) {
            // Every part must cover at least one byte, otherwise its Range header is invalid.
            this.threadNum = (int) Math.max(1, length);
        }
    }
    this.isMulti = this.threadNum > 1;
    this.filename = resolveFilename(response.header("Content-Disposition"), this.url);
    // Remember the path of every part for downloading and merging; a single-threaded
    // download writes straight to the target file.
    this.partNames = new String[this.threadNum];
    if (this.threadNum == 1) {
        this.partNames[0] = this.location + "/" + this.filename;
    } else {
        for (int i = 0; i < this.partNames.length; i++) {
            this.partNames[i] = this.location + "/" + this.filename + i;
        }
    }
    System.out.println(this);
}

/** Parses a Content-Length header; returns -1 when it is missing, malformed or negative. */
private static long parseContentLength(String header) {
    if (StringUtil.isBlank(header)) {
        return -1;
    }
    try {
        long value = Long.parseLong(header.trim());
        return value >= 0 ? value : -1;
    } catch (NumberFormatException e) {
        return -1;
    }
}

/**
 * Picks the file name: the {@code filename=} parameter of Content-Disposition first, then the
 * last path segment of the URL (without query/fragment), finally the current timestamp.
 */
private static String resolveFilename(String contentDisposition, String url)
        throws UnsupportedEncodingException {
    String name = "";
    if (!StringUtil.isBlank(contentDisposition) && contentDisposition.contains("filename=")) {
        String value = contentDisposition
                .substring(contentDisposition.indexOf("filename=") + "filename=".length())
                .trim();
        if (value.startsWith("\"")) {
            // Quoted value: take everything up to the closing quote.
            int close = value.indexOf('"', 1);
            value = close == -1 ? value.substring(1) : value.substring(1, close);
        } else {
            // Token value: ends at the next parameter separator.
            int semicolon = value.indexOf(';');
            if (semicolon != -1) {
                value = value.substring(0, semicolon);
            }
        }
        // Baidu Cloud sends a UTF-8 file name that arrives decoded as ISO-8859-1, so re-decode
        // it; other encodings can be handled here as they come up.
        name = new String(value.trim().getBytes("ISO8859-1"), "utf-8");
    }
    if (name.isEmpty()) {
        String path = url;
        int cut = path.indexOf('?');
        if (cut != -1) {
            path = path.substring(0, cut);
        }
        cut = path.indexOf('#');
        if (cut != -1) {
            path = path.substring(0, cut);
        }
        int slash = path.lastIndexOf('/');
        if (slash != -1) {
            name = path.substring(slash + 1);
        }
    }
    // The name comes from the server/URL: drop any directory part so it cannot escape `location`.
    name = name.substring(Math.max(name.lastIndexOf('/'), name.lastIndexOf('\\')) + 1).trim();
    if (name.isEmpty() || name.equals(".") || name.equals("..")) {
        name = String.valueOf(System.currentTimeMillis());
    }
    return name;
}
单线程下载
在无法获取到`Content-Length`的时候,我们就只能进行单线程下载了,以下是单线程下载的实现:
/**
 * Downloads the whole resource with one GET request into {@code location/filename}.
 *
 * @throws IOException if the server does not answer 200 or the file cannot be written
 */
private void singleDownload() throws IOException {
    System.out.println("开始下载");
    HttpResponse response = HttpRequest.get(this.url).send();
    int statusCode = response.statusCode();
    if (statusCode != HttpStatus.HTTP_OK) {
        // Without this check an error page would be saved as the downloaded file.
        throw new IOException("下载失败,StatusCode: " + statusCode);
    }
    // NOTE: bodyBytes() returns the whole body as one byte array, i.e. it is held in memory.
    byte[] body = response.bodyBytes();
    ByteBuffer byteBuffer = ByteBuffer.wrap(body);
    File file = new File(location + "/" + filename);
    try (FileOutputStream outputStream = new FileOutputStream(file);
         FileChannel channel = outputStream.getChannel()) {
        // The channel contract does not promise one write() drains the buffer, so loop.
        while (byteBuffer.hasRemaining()) {
            channel.write(byteBuffer);
        }
    }
    if (this.contentLength <= 0) {
        // The size was unknown up front; record it so the speed report is meaningful.
        this.contentLength = body.length;
    }
    System.out.println("下载完成");
}
合并文件
进行多线程下载,我采用的是分块下载然后再合并的方式,因此,各部分下载完成之后需要对文件进行合并。
/**
 * Concatenates the part files listed in {@code partNames} (in order) into
 * {@code location/filename} and deletes each part once it has been copied.
 * Does nothing except logging when the download was not multi-threaded.
 *
 * @throws IOException if a part cannot be read or the target cannot be written
 */
private void merge() throws IOException {
    System.out.println("merging ...");
    if (this.isMulti) {
        String filepath = this.location + "/" + this.filename;
        System.out.println("文件保存路径:" + filepath);
        try (FileChannel writeChannel = new FileOutputStream(filepath).getChannel()) {
            for (String partName : partNames) {
                File tmp = new File(partName);
                try (FileChannel readChannel = new FileInputStream(tmp).getChannel()) {
                    long size = readChannel.size();
                    long position = 0;
                    // transferTo may move fewer bytes than requested, so loop until done.
                    while (position < size) {
                        position += readChannel.transferTo(position, size - position, writeChannel);
                    }
                }
                // The part is fully merged: delete it now (fall back to JVM exit if that fails).
                if (!tmp.delete()) {
                    tmp.deleteOnExit();
                }
            }
        }
    }
    System.out.println("finished");
}
多线程下载
/**
 * Splits the file into {@code threadNum} byte ranges, downloads them concurrently into part
 * files and merges them. Any failed part aborts the whole download.
 *
 * @throws ExecutionException if a part download failed (wrong status, short body, I/O error)
 */
private void multiDownload() throws IOException, InterruptedException, ExecutionException {
    ExecutorService executor = Executors.newFixedThreadPool(this.threadNum);
    List<Future<?>> futures = new ArrayList<>(this.threadNum);
    long partSize = contentLength / threadNum;
    try {
        for (int i = 0; i < this.threadNum; i++) {
            String partName = location + "/" + filename + i;
            // Range bounds are inclusive; the last part also takes the division remainder.
            long start = i * partSize;
            long end = (i == this.threadNum - 1) ? this.contentLength - 1 : (i + 1) * partSize - 1;
            futures.add(executor.submit(() -> {
                downloadPart(partName, start, end);
                return null;
            }));
            this.partNames[i] = partName;
        }
        for (Future<?> future : futures) {
            future.get();
        }
    } finally {
        // Always stop the pool (its threads are non-daemon); this also cancels queued parts
        // when an earlier part has failed.
        executor.shutdownNow();
    }
    System.out.println("isShutDown: " + executor.isShutdown());
    System.out.println("下载完成");
    merge();
}

/** Downloads bytes {@code start..end} (inclusive) of {@code url} into {@code partName}. */
private void downloadPart(String partName, long start, long end) throws IOException {
    System.out.println("开始下载:" + partName);
    HttpResponse response = HttpRequest.get(url)
            .header("Range", "bytes=" + start + "-" + end)
            .send();
    int statusCode = response.statusCode();
    if (statusCode != 206) {
        // The server ignored the Range header; skipping this part would corrupt the merge.
        System.out.println("文件下载失败");
        throw new IOException("分片下载失败:" + partName + ",StatusCode: " + statusCode);
    }
    byte[] body = response.bodyBytes();
    long expected = end - start + 1;
    if (body.length != expected) {
        System.out.println("文件下载失败");
        throw new IOException("分片大小不正确:" + partName + ",期望 " + expected + " 实际 " + body.length);
    }
    ByteBuffer buffer = ByteBuffer.wrap(body);
    try (FileOutputStream outputStream = new FileOutputStream(new File(partName));
         FileChannel channel = outputStream.getChannel()) {
        while (buffer.hasRemaining()) {
            channel.write(buffer);
        }
    }
    System.out.println("下载完成:" + partName);
}
以上就是进行多线程下载的代码,写的有点乱还请大家见谅。如果有什么不足也请大家指出一起进步 :)
下面给出完整代码:
package com.loong;
import jodd.http.HttpRequest;
import jodd.http.HttpResponse;
import jodd.http.HttpStatus;
import jodd.util.StringUtil;
import java.io.*;
import java.net.URI;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;
/**
 * Multi-threaded HTTP downloader built on jodd-http.
 *
 * <p>{@link #run()} first sends a HEAD request to learn the file size ({@code Content-Length})
 * and the file name. When the size is known and more than one thread is configured, the file is
 * split into {@code threadNum} byte ranges that are fetched concurrently with {@code Range}
 * requests into part files, which are then concatenated into {@code location/filename}.
 * Otherwise the file is fetched with one plain GET request.
 *
 * <p>Instances keep mutable per-download state without synchronization; use one instance per
 * download.
 */
public class Downloader {
    /** Thread count used by the two-argument constructor. */
    private static final int DEFAULT_THREAD_NUM = 3;
    /** Maximum number of redirects followed by {@link #getRealUrl(String)} before giving up. */
    private static final int MAX_REDIRECTS = 10;

    /** Download URL after following redirects. */
    private final String url;
    /** Whether {@link #run()} takes the ranged, multi-threaded path. */
    private boolean isMulti;
    /** Number of parts/threads; {@link #init()} may lower it. */
    private int threadNum;
    /** File size in bytes; 0 while unknown. */
    private long contentLength;
    /** Name of the target file inside {@link #location}. */
    private String filename;
    /** Paths of the part files (or of the target file itself when single-threaded). */
    private String[] partNames;
    /** Directory the file is saved to. */
    private final String location;

    /**
     * Creates a downloader that uses {@value #DEFAULT_THREAD_NUM} threads.
     * Note: redirects of {@code url} are resolved with network requests right away.
     */
    public Downloader(String url, String location) {
        this(url, location, DEFAULT_THREAD_NUM);
    }

    /**
     * Creates a downloader with the given thread count (values below 1 become 1).
     * Note: redirects of {@code url} are resolved with network requests right away.
     */
    public Downloader(String url, String location, int threadNum) {
        this.url = getRealUrl(url);
        this.location = location;
        this.threadNum = Math.max(1, threadNum);
        this.partNames = new String[this.threadNum];
    }

    /**
     * Downloads the whole resource with one GET request into {@code location/filename}.
     *
     * @throws IOException if the server does not answer 200 or the file cannot be written
     */
    private void singleDownload() throws IOException {
        System.out.println("开始下载");
        HttpResponse response = HttpRequest.get(this.url).send();
        int statusCode = response.statusCode();
        if (statusCode != HttpStatus.HTTP_OK) {
            // Without this check an error page would be saved as the downloaded file.
            throw new IOException("下载失败,StatusCode: " + statusCode);
        }
        // NOTE: bodyBytes() returns the whole body as one byte array, i.e. it is held in memory.
        byte[] body = response.bodyBytes();
        ByteBuffer byteBuffer = ByteBuffer.wrap(body);
        File file = new File(location + "/" + filename);
        try (FileOutputStream outputStream = new FileOutputStream(file);
             FileChannel channel = outputStream.getChannel()) {
            // The channel contract does not promise one write() drains the buffer, so loop.
            while (byteBuffer.hasRemaining()) {
                channel.write(byteBuffer);
            }
        }
        if (this.contentLength <= 0) {
            // The size was unknown up front; record it so the speed report is meaningful.
            this.contentLength = body.length;
        }
        System.out.println("下载完成");
    }

    /**
     * Splits the file into {@code threadNum} byte ranges, downloads them concurrently into part
     * files and merges them. Any failed part aborts the whole download.
     *
     * @throws ExecutionException if a part download failed (wrong status, short body, I/O error)
     */
    private void multiDownload() throws IOException, InterruptedException, ExecutionException {
        ExecutorService executor = Executors.newFixedThreadPool(this.threadNum);
        List<Future<?>> futures = new ArrayList<>(this.threadNum);
        long partSize = contentLength / threadNum;
        try {
            for (int i = 0; i < this.threadNum; i++) {
                String partName = location + "/" + filename + i;
                // Range bounds are inclusive; the last part also takes the division remainder.
                long start = i * partSize;
                long end = (i == this.threadNum - 1) ? this.contentLength - 1 : (i + 1) * partSize - 1;
                futures.add(executor.submit(() -> {
                    downloadPart(partName, start, end);
                    return null;
                }));
                this.partNames[i] = partName;
            }
            for (Future<?> future : futures) {
                future.get();
            }
        } finally {
            // Always stop the pool (its threads are non-daemon); this also cancels queued parts
            // when an earlier part has failed.
            executor.shutdownNow();
        }
        System.out.println("isShutDown: " + executor.isShutdown());
        System.out.println("下载完成");
        merge();
    }

    /** Downloads bytes {@code start..end} (inclusive) of {@code url} into {@code partName}. */
    private void downloadPart(String partName, long start, long end) throws IOException {
        System.out.println("开始下载:" + partName);
        HttpResponse response = HttpRequest.get(url)
                .header("Range", "bytes=" + start + "-" + end)
                .send();
        int statusCode = response.statusCode();
        if (statusCode != 206) {
            // The server ignored the Range header; skipping this part would corrupt the merge.
            System.out.println("文件下载失败");
            throw new IOException("分片下载失败:" + partName + ",StatusCode: " + statusCode);
        }
        byte[] body = response.bodyBytes();
        long expected = end - start + 1;
        if (body.length != expected) {
            System.out.println("文件下载失败");
            throw new IOException("分片大小不正确:" + partName + ",期望 " + expected + " 实际 " + body.length);
        }
        ByteBuffer buffer = ByteBuffer.wrap(body);
        try (FileOutputStream outputStream = new FileOutputStream(new File(partName));
             FileChannel channel = outputStream.getChannel()) {
            while (buffer.hasRemaining()) {
                channel.write(buffer);
            }
        }
        System.out.println("下载完成:" + partName);
    }

    /**
     * Concatenates the part files listed in {@code partNames} (in order) into
     * {@code location/filename} and deletes each part once it has been copied.
     * Does nothing except logging when the download was not multi-threaded.
     */
    private void merge() throws IOException {
        System.out.println("merging ...");
        if (this.isMulti) {
            String filepath = this.location + "/" + this.filename;
            System.out.println("文件保存路径:" + filepath);
            try (FileChannel writeChannel = new FileOutputStream(filepath).getChannel()) {
                for (String partName : partNames) {
                    File tmp = new File(partName);
                    try (FileChannel readChannel = new FileInputStream(tmp).getChannel()) {
                        long size = readChannel.size();
                        long position = 0;
                        // transferTo may move fewer bytes than requested, so loop until done.
                        while (position < size) {
                            position += readChannel.transferTo(position, size - position, writeChannel);
                        }
                    }
                    // The part is fully merged: delete it now (fall back to JVM exit if that fails).
                    if (!tmp.delete()) {
                        tmp.deleteOnExit();
                    }
                }
            }
        }
        System.out.println("finished");
    }

    /**
     * Reads the file size and file name with a HEAD request and decides between single- and
     * multi-threaded download. Also fills {@code partNames} with the paths used later for
     * downloading and merging.
     *
     * @throws Exception if {@code url} is blank, or the file name cannot be decoded
     */
    private void init() throws Exception {
        if (StringUtil.isBlank(this.url)) {
            throw new Exception("url不存在");
        }
        HttpResponse response = HttpRequest.head(url).send();
        // Only trust the headers of a successful HEAD: an error page has its own Content-Length.
        boolean headOk = response.statusCode() == HttpStatus.HTTP_OK;
        long length = headOk ? parseContentLength(response.header("Content-Length")) : -1;
        if (length < 0) {
            System.out.println("无法获取到文件大小,只能进行单线程下载");
            this.threadNum = 1;
            this.contentLength = 0;
        } else {
            this.contentLength = length;
            String acceptRanges = response.header("Accept-Ranges");
            if (acceptRanges != null && acceptRanges.trim().equalsIgnoreCase("none")) {
                System.out.println("服务器不支持分段下载,只能进行单线程下载");
                this.threadNum = 1;
            } else if (length < this.threadNum) {
                // Every part must cover at least one byte, otherwise its Range header is invalid.
                this.threadNum = (int) Math.max(1, length);
            }
        }
        this.isMulti = this.threadNum > 1;
        this.filename = resolveFilename(response.header("Content-Disposition"), this.url);
        // A single-threaded download writes straight to the target file.
        this.partNames = new String[this.threadNum];
        if (this.threadNum == 1) {
            this.partNames[0] = this.location + "/" + this.filename;
        } else {
            for (int i = 0; i < this.partNames.length; i++) {
                this.partNames[i] = this.location + "/" + this.filename + i;
            }
        }
        System.out.println(this);
    }

    /** Parses a Content-Length header; returns -1 when it is missing, malformed or negative. */
    private static long parseContentLength(String header) {
        if (StringUtil.isBlank(header)) {
            return -1;
        }
        try {
            long value = Long.parseLong(header.trim());
            return value >= 0 ? value : -1;
        } catch (NumberFormatException e) {
            return -1;
        }
    }

    /**
     * Picks the file name: the {@code filename=} parameter of Content-Disposition first, then the
     * last path segment of the URL (without query/fragment), finally the current timestamp.
     */
    private static String resolveFilename(String contentDisposition, String url)
            throws UnsupportedEncodingException {
        String name = "";
        if (!StringUtil.isBlank(contentDisposition) && contentDisposition.contains("filename=")) {
            String value = contentDisposition
                    .substring(contentDisposition.indexOf("filename=") + "filename=".length())
                    .trim();
            if (value.startsWith("\"")) {
                // Quoted value: take everything up to the closing quote.
                int close = value.indexOf('"', 1);
                value = close == -1 ? value.substring(1) : value.substring(1, close);
            } else {
                // Token value: ends at the next parameter separator.
                int semicolon = value.indexOf(';');
                if (semicolon != -1) {
                    value = value.substring(0, semicolon);
                }
            }
            // Baidu Cloud sends a UTF-8 file name that arrives decoded as ISO-8859-1, so
            // re-decode it; other encodings can be handled here as they come up.
            name = new String(value.trim().getBytes("ISO8859-1"), "utf-8");
        }
        if (name.isEmpty()) {
            String path = url;
            int cut = path.indexOf('?');
            if (cut != -1) {
                path = path.substring(0, cut);
            }
            cut = path.indexOf('#');
            if (cut != -1) {
                path = path.substring(0, cut);
            }
            int slash = path.lastIndexOf('/');
            if (slash != -1) {
                name = path.substring(slash + 1);
            }
        }
        // The name comes from the server/URL: drop any directory part so it cannot escape `location`.
        name = name.substring(Math.max(name.lastIndexOf('/'), name.lastIndexOf('\\')) + 1).trim();
        if (name.isEmpty() || name.equals(".") || name.equals("..")) {
            name = String.valueOf(System.currentTimeMillis());
        }
        return name;
    }

    /** Runs the download and prints the average speed. */
    public void run() throws Exception {
        init();
        long start = System.currentTimeMillis();
        if (isMulti) {
            multiDownload();
        } else {
            singleDownload();
        }
        long end = System.currentTimeMillis();
        // bytes per millisecond is roughly KB/s; clamp so a sub-millisecond run cannot divide by 0.
        long elapsed = Math.max(1, end - start);
        System.out.println("速度:" + this.contentLength / elapsed + "kb/s");
    }

    @Override
    public String toString() {
        return String.format("url: %s\nisMulti: %s\nthreadNum: %d\ncontentLength: %d\nfilename: %s\npartNames: %s\nlocation: %s",
                url, isMulti, threadNum, contentLength, filename, Arrays.toString(partNames), location);
    }

    /**
     * Follows HTTP redirects (301/302/303/307/308) with HEAD requests to find the real download URL.
     *
     * @return the final URL when it answers 200; {@code ""} for a blank input; otherwise the
     *     original {@code url} (a message with the last status code is printed)
     */
    private String getRealUrl(String url) {
        if (StringUtil.isBlank(url)) {
            return "";
        }
        String current = url;
        int statusCode = 0;
        // Bounded so that a redirect loop (A -> B -> A ...) cannot hang the constructor.
        for (int hop = 0; hop <= MAX_REDIRECTS; hop++) {
            HttpResponse response = HttpRequest.head(current).send();
            statusCode = response.statusCode();
            String target = response.header("Location");
            if (!isRedirect(statusCode) || StringUtil.isBlank(target)) {
                break;
            }
            // For 3xx responses the real URL is in the Location header; it may be relative.
            current = resolveLocation(current, target);
        }
        if (statusCode == HttpStatus.HTTP_OK) {
            return current;
        }
        System.out.printf("获取真实链接失败,StatusCode: %d%n", statusCode);
        return url;
    }

    /** Returns whether {@code statusCode} is a redirect that carries a Location header. */
    private static boolean isRedirect(int statusCode) {
        return statusCode == HttpStatus.HTTP_MOVED_PERMANENTLY
                || statusCode == HttpStatus.HTTP_MOVED_TEMPORARY
                || statusCode == 303
                || statusCode == 307
                || statusCode == 308;
    }

    /** Resolves a (possibly relative) Location value against the URL that returned it. */
    private static String resolveLocation(String base, String target) {
        try {
            return URI.create(base).resolve(target.trim()).toString();
        } catch (IllegalArgumentException e) {
            // Not a syntactically valid URI (e.g. unescaped characters): use the raw header value.
            return target.trim();
        }
    }

    public static void main(String[] args) throws Exception {
        Downloader downloader = new Downloader(
                "https://dldir1.qq.com/qqfile/qq/QQ8.9.6/22404/QQ8.9.6.exe",
                "/home/loong/Downloads",
                8);
        downloader.run();
    }
}