最近想试试写个多线程下载的程序,顺便记下来吧~

怎么开始

多线程下载,是通过在请求header中添加Range进行部分下载来实现的。那么我们要怎么确定Range的值呢?首先我们当然要得到下载文件的大小,像资源地址发送head请求,就可以从响应头中的Content-Length中得到文件的大小了,因此,如果响应头中不存在Content-Length字段,那我们就没办法进行多线程下载。

下面的代码中我使用了jodd的http库,相关信息可以查看[jodd-http]

代码

获取真实下载链接

在下载之前,我们首先要获取到文件的真实下载链接,这个我们通过向目标url发送head请求就可以得到:

private String getRealUrl(String url) {
    String tmpUrl = url;
    if (StringUtil.isBlank(url)) {
        return "";
    }
    HttpResponse response;
    int statusCode;
    while (true) {
        response = HttpRequest.head(url).send();
        statusCode = response.statusCode();
        if (statusCode == HttpStatus.HTTP_MOVED_PERMANENTLY || statusCode == HttpStatus.HTTP_MOVED_TEMPORARY) {
            // 当statusCode为301或302时,从响应头中获取真实链接(Location)
            url = response.header("Location");
        } else {
            break;
        }
    }
    if (statusCode == HttpStatus.HTTP_OK) {
        return url;
    } else {
        System.out.printf("获取真实链接失败,StatusCode: %d", statusCode);
        return tmpUrl;
    }
}

解析资源信息

要进行多线程下载文件,我们需要得到文件大小等信息

private void init() throws Exception {
    if (StringUtil.isBlank(this.url)) {
        throw new Exception("url不存在");
    }
    HttpResponse response = HttpRequest.head(url).send();
    // 获取文件大小
    String contentLength = response.header("Content-Length");
    if (StringUtil.isBlank(contentLength)) {
        System.out.println("无法获取到文件大小,只能进行单线程下载");
        this.threadNum = 1;
        this.isMulti = false;
        this.contentLength = 0;
    } else {
        this.contentLength = Long.parseLong(contentLength);
        if (this.threadNum > 1) {
            this.isMulti = true;
        }
    }

    // 获取文件名
    // 文件名可以从响应头的Content-Disposition获取url中获得,这里我们优先考虑前者
    // 当然,有的下载链接可能获取不到文件名,这就需要自定义规则或者手动输入了,这里我直接命名为时间戳
    String contentDisposition = response.header("Content-Disposition");
    // Content-Disposition中获取不到文件名
    if (StringUtil.isBlank(contentDisposition) || !contentDisposition.contains("filename=")) {
        int index = this.url.lastIndexOf("/");
        if (index != -1) {
            this.filename = this.url.substring(index + 1);
        } else {
            this.filename = String.valueOf(System.currentTimeMillis());
        }
    } else {
        int index = contentDisposition.indexOf("filename=");
        // 对取得的文件名信息进行转码,此处是针对百度云的文件名(ISO8859-1编码)进行转码,
        // 其他规则以后遇到再添加
        this.filename = new String(contentDisposition.substring(index + 10, contentDisposition.length() - 1).getBytes("ISO8859-1"), "utf-8");
    }
    // 这里我保存了各个部分的名称,方便后面进行下载和合并
    if (this.threadNum == 1) {
        this.partNames[0] = this.location + "/" + this.filename;
    } else {
        for (int i = 0; i < this.partNames.length; i++) {
            partNames[i] = this.location + "/" + this.filename + i;
        }
    }
    System.out.println(this);
}

单线程下载

在无法获取到Content-Length的时候我们就只能进行单线程下载了,以下是单线程下载的实现:

private void singleDownload() throws IOException {
    System.out.println("开始下载");
    HttpResponse response = HttpRequest.get(this.url).send();
    ByteBuffer byteBuffer = ByteBuffer.wrap(response.bodyBytes());
    File file = new File(location + "/" + filename);
    FileOutputStream outputStream = new FileOutputStream(file);
    FileChannel channel = outputStream.getChannel();
    channel.write(byteBuffer);
    outputStream.close();
    channel.close();
    System.out.println("下载完成");
}

合并文件

进行多线程下载,我采用的是分块下载然后再合并的方式,因此,各部分下载完成之后需要对文件进行合并。

private void merge() throws IOException {
    System.out.println("merging ...");
    if (this.isMulti) {
        String filepath = this.location + "/" + this.filename;
        System.out.println("文件保存路径:" + filepath);
        try (FileChannel writeChannel = new FileOutputStream(new File(this.location + "/" + this.filename)).getChannel()) {
            for (String partName : partNames) {
                File tmp = new File(partName);
                FileInputStream inputStream = new FileInputStream(tmp);
                try (FileChannel readChannel = inputStream.getChannel()) {
                    ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
                    while (readChannel.read(byteBuffer) != -1) {
                        byteBuffer.flip();
                        writeChannel.write(byteBuffer);
                        byteBuffer.clear();
                    }
                }
                // 合并完成后删除临时文件
                tmp.deleteOnExit();
            }
        }
    }
    System.out.println("finished");
}

多线程下载

private void multiDownload() throws IOException, InterruptedException, ExecutionException {
    ThreadFactory threadFactory = Executors.defaultThreadFactory();
    List<Future> futures = new ArrayList<>((int) Math.ceil(this.threadNum / 0.75));
    // 使用线程池
    ThreadPoolExecutor executor = new ThreadPoolExecutor(this.threadNum, this.threadNum, 500, TimeUnit.MILLISECONDS, new LinkedBlockingDeque<>(this.threadNum), threadFactory);
    long partSize = contentLength / threadNum;
    for (int i = 0; i < this.threadNum; i++) {
        String partName = location + "/" + filename + i;
        // 获取文件片段的Range
        long start = i * partSize;
        long end;
        if (i == this.threadNum - 1) {
            end = this.contentLength;
        } else {
            end = (i + 1) * partSize - 1;
        }
        // 下载的实现
        Future future = executor.submit(() -> {
            System.out.println("开始下载:" + partName);
            HttpResponse response = HttpRequest.get(url)
                    .header("Range", "bytes=" + start + "-" + end)
                    .send();
            if (response.statusCode() == 206) {
                ByteBuffer buffer = ByteBuffer.wrap(response.bodyBytes());
                File file = new File(partName);
                try (FileOutputStream outputStream = new FileOutputStream(file);
                     FileChannel channel = outputStream.getChannel()
                ) {
                    channel.write(buffer);
                } catch (IOException e) {
                    System.out.println("文件下载失败");
                    e.printStackTrace();
                }
            }
            System.out.println("下载完成:" + partName);
        });
        futures.add(future);
        this.partNames[i] = partName;
    }
    for (Future future : futures) {
        future.get();
    }
    executor.shutdown();
    System.out.println("isShutDown: " + executor.isShutdown());
    System.out.println("下载完成");
    merge();
}

以上就是进行多线程下载的代码,写的有点乱还请大家见谅。如果有什么不足也请大家指出一起进步 :)

下面给出完整代码:

package com.loong;

import jodd.http.HttpRequest;
import jodd.http.HttpResponse;
import jodd.http.HttpStatus;
import jodd.util.StringUtil;

import java.io.*;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.*;

public class Downloader {
    private String url;
    private boolean isMulti;
    private int threadNum;
    private long contentLength;
    private String filename;
    private String[] partNames;
    private String location;

    public Downloader(String url, String location) {
        this.url = getRealUrl(url);
        this.location = location;
        this.threadNum = 3;
        this.partNames = new String[this.threadNum];
    }

    private Downloader(String url, String location, int threadNum) {
        this.url = getRealUrl(url);
        this.location = location;
        if (threadNum < 1) {
            this.threadNum = 1;
        } else {
            this.threadNum = threadNum;
        }
        this.partNames = new String[this.threadNum];
    }

    private void singleDownload() throws IOException {
        System.out.println("开始下载");
        HttpResponse response = HttpRequest.get(this.url).send();
        ByteBuffer byteBuffer = ByteBuffer.wrap(response.bodyBytes());
        File file = new File(location + "/" + filename);
        FileOutputStream outputStream = new FileOutputStream(file);
        FileChannel channel = outputStream.getChannel();
        channel.write(byteBuffer);
        outputStream.close();
        channel.close();
        System.out.println("下载完成");
    }

    private void multiDownload() throws IOException, InterruptedException, ExecutionException {
        ThreadFactory threadFactory = Executors.defaultThreadFactory();
        List<Future> futures = new ArrayList<>((int) Math.ceil(this.threadNum / 0.75));
        // 使用线程池
        ThreadPoolExecutor executor = new ThreadPoolExecutor(this.threadNum, this.threadNum, 500, TimeUnit.MILLISECONDS, new LinkedBlockingDeque<>(this.threadNum), threadFactory);
        long partSize = contentLength / threadNum;
        for (int i = 0; i < this.threadNum; i++) {
            String partName = location + "/" + filename + i;
            // 获取文件片段的Range
            long start = i * partSize;
            long end;
            if (i == this.threadNum - 1) {
                end = this.contentLength;
            } else {
                end = (i + 1) * partSize - 1;
            }
            // 下载的实现
            Future future = executor.submit(() -> {
                System.out.println("开始下载:" + partName);
                HttpResponse response = HttpRequest.get(url)
                        .header("Range", "bytes=" + start + "-" + end)
                        .send();
                if (response.statusCode() == 206) {
                    ByteBuffer buffer = ByteBuffer.wrap(response.bodyBytes());
                    File file = new File(partName);
                    try (FileOutputStream outputStream = new FileOutputStream(file);
                         FileChannel channel = outputStream.getChannel()
                    ) {
                        channel.write(buffer);
                    } catch (IOException e) {
                        System.out.println("文件下载失败");
                        e.printStackTrace();
                    }
                }
                System.out.println("下载完成:" + partName);
            });
            futures.add(future);
            this.partNames[i] = partName;
        }
        for (Future future : futures) {
            future.get();
        }
        executor.shutdown();
        System.out.println("isShutDown: " + executor.isShutdown());
        System.out.println("下载完成");
        merge();
    }

private void merge() throws IOException {
    System.out.println("merging ...");
    if (this.isMulti) {
        String filepath = this.location + "/" + this.filename;
        System.out.println("文件保存路径:" + filepath);
        try (FileChannel writeChannel = new FileOutputStream(new File(this.location + "/" + this.filename)).getChannel()) {
            for (String partName : partNames) {
                File tmp = new File(partName);
                FileInputStream inputStream = new FileInputStream(tmp);
                try (FileChannel readChannel = inputStream.getChannel()) {
                    ByteBuffer byteBuffer = ByteBuffer.allocate(1024);
                    while (readChannel.read(byteBuffer) != -1) {
                        byteBuffer.flip();
                        writeChannel.write(byteBuffer);
                        byteBuffer.clear();
                    }
                }
                // 合并完成后删除临时文件
                tmp.deleteOnExit();
            }
        }
    }
    System.out.println("finished");
}

    /**
     * 获取文件名、文件大小等
     */
    private void init() throws Exception {
        if (StringUtil.isBlank(this.url)) {
            throw new Exception("url不存在");
        }
        HttpResponse response = HttpRequest.head(url).send();
        // 获取文件大小
        String contentLength = response.header("Content-Length");
        if (StringUtil.isBlank(contentLength)) {
            System.out.println("无法获取到文件大小,只能进行单线程下载");
            this.threadNum = 1;
            this.isMulti = false;
            this.contentLength = 0;
        } else {
            this.contentLength = Long.parseLong(contentLength);
            if (this.threadNum > 1) {
                this.isMulti = true;
            }
        }

        // 获取文件名
        // 文件名可以从响应头的Content-Disposition获取url中获得,这里我们优先考虑前者
        // 当然,有的下载链接可能获取不到文件名,这就需要自定义规则或者手动输入了,这里我直接命名为时间戳
        String contentDisposition = response.header("Content-Disposition");
        // Content-Disposition中获取不到文件名
        if (StringUtil.isBlank(contentDisposition) || !contentDisposition.contains("filename=")) {
            int index = this.url.lastIndexOf("/");
            if (index != -1) {
                this.filename = this.url.substring(index + 1);
            } else {
                this.filename = String.valueOf(System.currentTimeMillis());
            }
        } else {
            int index = contentDisposition.indexOf("filename=");
            // 对取得的文件名信息进行转码,此处是针对百度云的链接进行转码(ISO8859-1B编码),其他规则以后遇到再添加
            this.filename = new String(contentDisposition.substring(index + 10, contentDisposition.length() - 1).getBytes("ISO8859-1"), "utf-8");
        }
        if (this.threadNum == 1) {
            this.partNames[0] = this.location + "/" + this.filename;
        } else {
            for (int i = 0; i < this.partNames.length; i++) {
                partNames[i] = this.location + "/" + this.filename + i;
            }
        }
        System.out.println(this);
    }

    public void run() throws Exception {
        init();
        long start = System.currentTimeMillis();
        if (isMulti) {
            multiDownload();
        } else {
            singleDownload();
        }
        long end = System.currentTimeMillis();
        System.out.println("速度:" + this.contentLength / (end - start) + "kb/s");
    }

    @Override
    public String toString() {
        return String.format("url: %s\nisMulti: %s\nthreadNum: %d\ncontentLength: %d\nfilename: %s\npartNames: %s\nlocation: %s",
                url, isMulti, threadNum, contentLength, filename, Arrays.toString(partNames), location);
    }

    private String getRealUrl(String url) {
        String tmpUrl = url;
        if (StringUtil.isBlank(url)) {
            return "";
        }
        HttpResponse response;
        int statusCode;
        while (true) {
            response = HttpRequest.head(url).send();
            statusCode = response.statusCode();
            if (statusCode == HttpStatus.HTTP_MOVED_PERMANENTLY || statusCode == HttpStatus.HTTP_MOVED_TEMPORARY) {
                url = response.header("Location");
            } else {
                break;
            }
        }
        if (statusCode == HttpStatus.HTTP_OK) {
            return url;
        } else {
            System.out.printf("获取真实链接失败,StatusCode: %d", statusCode);
            return tmpUrl;
        }
    }

    public static void main(String[] args) throws Exception {
        Downloader downloader = new Downloader(
                "https://dldir1.qq.com/qqfile/qq/QQ8.9.6/22404/QQ8.9.6.exe",
                "/home/loong/Downloads",
                8);
        downloader.run();
    }
}