一键抠图3:Android实现人像抠图 (Portrait Matting)

目录

一键抠图3:Android实现人像抠图 (Portrait Matting)

1. 前言

2. 抠图算法

3. 模型Android部署

(1) 将Pytorch模型转换ONNX模型

(2) 将ONNX模型转换为TNN模型

(3) Android端上部署模型

(4) Android测试效果 

(5) 运行APP闪退:dlopen failed: library "libomp.so" not found

4.Android项目源码下载

5.人像抠图C++版本

6.人像抠图Python版本


1. 前言

这是一键抠图项目系列之《Android实现人像抠图 (Portrait Matting)》;本篇主要分享将Python训练后的matting模型部署到Android平台。我们将开发一个简易的、可实时运行的人像抠图Android Demo。Android版本人像抠图模型推理支持CPU和GPU加速,在GPU(OpenCL)加速下,可以达到头发细致级别的人像抠图效果,为了方便后续模型工程化和Android平台部署,项目提供高精度版本人像抠图和轻量化快速版人像抠图,并开发了Python/C++/Android多个版本;

一键抠图3:Android实现人像抠图 (Portrait Matting)_matting

先展示一下Android版本一键抠图效果:

模型选择

原图

高精度人像抠图

视频抠图

一键抠图3:Android实现人像抠图 (Portrait Matting)_matting_02

一键抠图3:Android实现人像抠图 (Portrait Matting)_matting_03

一键抠图3:Android实现人像抠图 (Portrait Matting)_matting_04

一键抠图3:Android实现人像抠图 (Portrait Matting)_matting_05



更多项目《一键抠图》系列文章请参考:

  1. 一键抠图1:Python实现人像抠图 (Portrait Matting)
  2. 一键抠图2:C/C++实现人像抠图 (Portrait Matting)
  3. 一键抠图3:Android实现人像抠图 (Portrait Matting)

一键抠图3:Android实现人像抠图 (Portrait Matting)_人像抠图_06


2. 抠图算法

基于深度学习的Matting分为两大类:

  • 一种是基于辅助信息输入。即除了原图和标注图像外,还需要输入其他的信息辅助预测。最常见的辅助信息是Trimap,即将图片划分为前景,背景及过度区域三部分。另外也有以背景或交互点作为辅助信息。
  • 一种是不依赖任何辅助信息,直接对Alpha进行预测。如本博客复现的MODNet

第一种方法,需要加入辅助信息,而辅助信息一般较难获取,这也限制其应用,为了提升Matting的应用性,针对Portrait Matting领域MODNet摒弃了辅助信息,直接实现Alpha预测,实现了实时Matting,极大提升了基于深度学习Matting的应用价值。

更多抠图算法(Matting),请参考我的一篇博客《图像抠图Image Matting算法调研》:

可能,有小伙伴搞不清楚分割(segmentation)和抠图(matting)有什么区别,我这里简单说明一下:

  •  分割(segmentation):从深度学习的角度来说,分割本质是像素级别的分类任务,其损失函数最简单的莫过于是交叉熵CrossEntropyLoss(当然也可以是Focal Loss,IOU Loss,Dice Loss等);对于前景和背景分割任务,输出Mask的每个像素要么是0,要么是1。如果拿去直接做图像融合,就很不自然,Mask边界很生硬,这时就需要使用抠图算法了
  •  抠图(matting): 而抠图本质是一种回归任务,其损失函数可以是MSE Loss,L1 Loss,L2 Loss等,对于前景和背景抠图任务,输出Mask的每个像素是0~1之间的连续值,可看作是对图像透明通道(Alpha)的回归预测。可以用公式表示为C = αF + (1-α)B ,其中α(不透明度)、F(前景色)和B(背景色),alpha是[0, 1]之间的连续值,可以理解为像素属于前景的概率。在人像分割任务中,alpha只能取0或1,而抠图任务中,alpha可取[0, 1]之间的连续值,
  • 本质上就是一句话:分割是分类任务,而抠图是回归任务。

3. 模型Android部署

目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署。部署流程可分为四步:训练模型->将模型转换ONNX模型->将ONNX模型转换为TNN模型->Android端上部署TNN模型。

(1) 将Pytorch模型转换ONNX模型

训练好模型后,你需要先将Pytorch模型转换为ONNX模型,并使用onnx-simplifier简化网络结构,Python版本的已经提供了ONNX转换脚本,终端输入命令如下:

# 导出ONNX模型
python export.py --model_type "modnet" --model_file "work_space/modnet_416/model/best_model.pth"

GitHub: https://github.com/daquexian/onnx-simplifier
Install:  pip3 install onnx-simplifier 

(2) 将ONNX模型转换为TNN模型

目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署

TNN转换工具:

一键抠图3:Android实现人像抠图 (Portrait Matting)_android_07


转换成功后,会生成两个文件(*.tnnproto和*.tnnmodel) ,下载下来后面会用到

(3) Android端上部署模型

项目Android部署框架采用TNN,支持多线程CPU和GPU加速推理,在普通手机上可以实时处理。项目Android源码,核心算法均采用C++实现,上层通过JNI接口调用。

如果你想在这个Android Demo部署你自己训练的模型,你可将训练好的Pytorch模型转换ONNX ,再转换成TNN模型,然后把TNN模型代替你模型即可。 

  • 这是项目Android源码JNI接口 ,Java部分

matting接口:实现基本的人像构图Matting功能
fusion接口:实现人像构图Matting,并与背景图进行融合
mattingFusion接口:人像构图Matting,并与背景图进行融合(会返回mask)

package com.cv.tnn.model;
 
import android.graphics.Bitmap;
 
public class Detector {
 
    static {
        System.loadLibrary("tnn_wrapper");
    }
 
 
    /***
     * 初始化检测模型
     * @param proto: TNN *.tnnproto文件文件名(含后缀名)
     * @param model: TNN *.tnnmodel文件文件名(含后缀名)
     * @param root:模型文件的根目录,放在assets文件夹下
     * @param model_type:模型类型
     * @param num_thread:开启线程数
     * @param useGPU:是否使用GPU
     */
    public static native void init(String proto, String model, String root, int model_type, int num_thread, boolean useGPU);
 
    /***
     * 缩放图片
     * @param bitmap
     * @param resize_width
     * @param resize_height
     * @return
     */
    public static Bitmap resizeBitmap(Bitmap bitmap, int resize_width, int resize_height) {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        if (resize_width <= 0 && resize_height <= 0) {
            return bitmap;
        } else if (resize_height <= 0) {
            resize_height = height * resize_width / width;
        } else if (resize_width <= 0) {
            resize_width = width * resize_height / height;
        }
        Bitmap dst = Bitmap.createScaledBitmap(bitmap, resize_width, resize_height, false);
        return dst;
    }
 
 
    /***
     * 人像构图Matting
     * @param bitmap 输入图像(bitmap),ARGB_8888格式
     * @param mask   输出Mask图像(bitmap),ARGB_8888格式,调用前需要createBitmap初始化大小,如
     *               Bitmap.createBitmap(Width, Height, Bitmap.Config.ARGB_8888);
     * @return
     */
    public static native void matting(Bitmap bitmap, Bitmap mask);
 
 
    /***
     * 人像构图Matting,并与背景图进行融合
     * @param bitmap 输入图像(bitmap),ARGB_8888格式
     * @param bgmap  输入背景图像(bitmap),ARGB_8888格式,可任意大小的图像
     * @param fusion 输出与背景融合后图像,调用前需要createBitmap初始化大小,ARGB_8888格式
     */
    public static native void fusion(Bitmap bitmap, Bitmap bgmap, Bitmap fusion);
 
    /***
     * 人像构图Matting,并与背景图进行融合
     * @param bitmap 输入图像(bitmap),ARGB_8888格式
     * @param bgmap  输入背景图像(bitmap),ARGB_8888格式,可任意大小的图像
     * @param fusion 输出与背景融合后图像,调用前需要createBitmap初始化大小,ARGB_8888格式
     * @param mask   输出Mask图像(bitmap),调用前需要createBitmap初始化大小,ARGB_8888格式
     * @return
     */
    public static native void mattingFusion(Bitmap bitmap, Bitmap bgmap, Bitmap fusion, Bitmap mask);
 
 
}
  • 这是Android项目源码JNI接口 ,C++部分
#include <jni.h>
#include <string>
#include <fstream>
#include "src/segment.h"
#include "src/object_detection.h"
#include "src/Types.h"
#include "debug.h"
#include "android_utils.h"
#include "opencv2/opencv.hpp"

using namespace dm;
using namespace vision;

static Segment *segment = nullptr;
static ObjectDetection *detector = nullptr;


JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *reserved) {
    return JNI_VERSION_1_6;
}

JNIEXPORT void JNI_OnUnload(JavaVM *vm, void *reserved) {

}


extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_init(JNIEnv *env,
                                    jclass clazz,
                                    jstring proto,
                                    jstring model,
                                    jstring root,
                                    jint model_type,
                                    jint num_thread,
                                    jboolean use_gpu) {
    if (segment != nullptr) {
        delete segment;
        segment = nullptr;
    }
    std::string parent = env->GetStringUTFChars(root, 0);
    std::string proto_file = parent + env->GetStringUTFChars(proto, 0);
    std::string model_file = parent + env->GetStringUTFChars(model, 0);
    DeviceType device = use_gpu ? GPU : CPU;
    LOGW("parent     : %s", parent.c_str());
    LOGW("useGPU     : %d", use_gpu);
    LOGW("device_type: %d", device);
    LOGW("model_type : %d", model_type);
    LOGW("num_thread : %d", num_thread);
    SegmentParam model_param = SEG_MODEL_TYPE[model_type];//模型参数
    segment = new Segment(model_file,
                          proto_file,
                          model_param,
                          num_thread,
                          device);

}


extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_matting(JNIEnv *env, jclass clazz, jobject bitmap,
                                       jobject out_mask) {
    cv::Mat image;//bgr
    cv::Mat bg;//bgr
    BitmapToMatrix(env, bitmap, image);
    cv::Mat mask;
    cv::Mat fusion;
    // 检测人像分割
    segment->detect(image, mask);
    // 返回Mask
    MatrixToBitmap(env, mask, out_mask);
}



extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_fusion(JNIEnv *env, jclass clazz,
                                      jobject bitmap,
                                      jobject bgmap,
                                      jobject out_fusion) {
    cv::Mat image;//bgr
    cv::Mat bg;//bgr
    BitmapToMatrix(env, bitmap, image);
    BitmapToMatrix(env, bgmap, bg);
    cv::Mat mask;
    cv::Mat fusion;
    // 检测人像分割
    segment->detect(image, mask);
    // 将matte与背景bg进行融合fusion
    image_fusion(image, mask, fusion, bg);
    // 融合fusion图像
    MatrixToBitmap(env, fusion, out_fusion);
}



extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_mattingFusion(JNIEnv *env, jclass clazz,
                                             jobject bitmap,
                                             jobject bgmap,
                                             jobject out_fusion,
                                             jobject out_mask) {
    cv::Mat image;//bgr
    cv::Mat bg;//bgr
    BitmapToMatrix(env, bitmap, image);
    BitmapToMatrix(env, bgmap, bg);
    cv::Mat mask;
    cv::Mat fusion;
    // 检测人像分割
    segment->detect(image, mask);
    // 将matte与背景bg进行融合fusion
    image_fusion(image, mask, fusion, bg);
    // 融合fusion图像
    MatrixToBitmap(env, fusion, out_fusion);
    MatrixToBitmap(env, mask, out_mask);
}



extern "C"
JNIEXPORT jobjectArray JNICALL
Java_com_cv_tnn_model_Detector_detect(JNIEnv *env, jclass clazz, jobject bitmap,
                                      jfloat score_thresh, jfloat iou_thresh) {
    cv::Mat bgr;
    BitmapToMatrix(env, bitmap, bgr);
    int src_h = bgr.rows;
    int src_w = bgr.cols;
    // 检测区域为整张图片的大小
    FrameInfo resultInfo;
    // 开始检测
    if (detector != nullptr) {
        detector->detect(bgr, &resultInfo, score_thresh, iou_thresh);
    } else {
        ObjectInfo objectInfo;
        objectInfo.x1 = 0;
        objectInfo.y1 = 0;
        objectInfo.x2 = 84;
        objectInfo.y2 = 84;
        objectInfo.label = 0;
        resultInfo.info.push_back(objectInfo);
    }

    int nums = resultInfo.info.size();
    LOGW("object nums: %d\n", nums);

    auto BoxInfo = env->FindClass("com/cv/tnn/model/FrameInfo");
    auto init_id = env->GetMethodID(BoxInfo, "<init>", "()V");
    auto box_id = env->GetMethodID(BoxInfo, "addBox", "(FFFFIF)V");
    auto ky_id = env->GetMethodID(BoxInfo, "addKeyPoint", "(FFF)V");
    jobjectArray ret = env->NewObjectArray(resultInfo.info.size(), BoxInfo, nullptr);
    for (int i = 0; i < nums; ++i) {
        auto info = resultInfo.info[i];
        env->PushLocalFrame(1);
        //jobject obj = env->AllocObject(BoxInfo);
        jobject obj = env->NewObject(BoxInfo, init_id);
        // set bbox
        //LOGW("rect:[%f,%f,%f,%f] label:%d,score:%f \n", info.rect.x,info.rect.y, info.rect.w, info.rect.h, 0, 1.0f);
        env->CallVoidMethod(obj, box_id, info.x1, info.y1, info.x2 - info.x1, info.y2 - info.y1,
                            info.label, info.score);
        // set keypoint
        for (const auto &kps : info.landmarks) {
            //LOGW("point:[%f,%f] score:%f \n", lm.point.x, lm.point.y, lm.score);
            env->CallVoidMethod(obj, ky_id, (float) kps.x, (float) kps.y, 1.0f);
        }
        obj = env->PopLocalFrame(obj);
        env->SetObjectArrayElement(ret, i, obj);
    }
    return ret;
}

(4) Android测试效果 

实际使用中,建议你:

  • 背景越单一,抠图的效果越好,背景越复杂,抠图效果越差;建议你实际使用中,找一比较单一的背景,如墙面,天空等
  • 上半身抠图的效果越好,下半身或者全身抠图效果较差;本质上这是数据的问题,因为训练数据70%都是只有上半身的
  • 白种人抠图的效果越好,黑人和黄种人抠图效果较差;这也是数据的问题,因为训练数据大部分都是隔壁的老外

下图是高精度版本人像抠图和快速人像构图的测试效果,相对而言,高精度版本人像抠图可以精细到发丝级别的抠图效果;而快速人像构图目前仅能实现基本的抠图效果:

原图

 Mask图像

 融合图像

一键抠图3:Android实现人像抠图 (Portrait Matting)_android_08


一键抠图3:Android实现人像抠图 (Portrait Matting)_android_09


一键抠图3:Android实现人像抠图 (Portrait Matting)_matting_10


一键抠图3:Android实现人像抠图 (Portrait Matting)_人像抠图_11

一键抠图3:Android实现人像抠图 (Portrait Matting)_android_12

一键抠图3:Android实现人像抠图 (Portrait Matting)_人像抠图_13

(5) 运行APP闪退:dlopen failed: library "libomp.so" not found

参考解决方法:
解决dlopen failed: library “libomp.so“ not found_PKing666666的博客

 Android SDK和NDK相关版本信息,请参考: 

一键抠图3:Android实现人像抠图 (Portrait Matting)_一键抠图_14

 

一键抠图3:Android实现人像抠图 (Portrait Matting)_人像抠图_15


4.Android项目源码下载

 Android Demo APP下载地址:

Android项目源码下载地址:一键抠图Portrait Matting人像抠图 (C++和Android源码)

整套Android项目源码内容包含:

  1.  Android版本人像抠图算法,支持CPU和GPU
  2. 提供高精度版本人像抠图,可以达到精细到发丝级别的抠图效果(Android GPU 150ms,  CPU 500ms左右)
  3. 提供轻量化快速版人像抠图,满足基本的人像抠图效果,可以在Android达到实时的抠图效果(Android GPU 60ms,  CPU 140ms左右)
  4. Android Demo支持图片,视频,摄像头测试
  5. 所有依赖库都已经配置好,可直接build运行,若运行出现闪退,请参考dlopen failed: library “libomp.so“ not found 解决。

5.人像抠图C++版本

一键抠图2:C/C++实现人像抠图 (Portrait Matting)


6.人像抠图Python版本

一键抠图1:Python实现人像抠图 (Portrait Matting)