一键抠图3:Android实现人像抠图 (Portrait Matting)
目录
一键抠图3:Android实现人像抠图 (Portrait Matting)
1. 前言
2. 抠图算法
3. 模型Android部署
(1) 将Pytorch模型转换ONNX模型
(2) 将ONNX模型转换为TNN模型
(3) Android端上部署模型
(4) Android测试效果
(5) 运行APP闪退:dlopen failed: library "libomp.so" not found
4.Android项目源码下载
5.人像抠图C++版本
6.人像抠图Python版本
1. 前言
这是一键抠图项目系列之《Android实现人像抠图 (Portrait Matting)》;本篇主要分享将Python训练后的matting模型部署到Android平台。我们将开发一个简易的、可实时运行的人像抠图Android Demo。Android版本人像抠图模型推理支持CPU和GPU加速,在GPU(OpenCL)加速下,可以达到头发细致级别的人像抠图效果,为了方便后续模型工程化和Android平台部署,项目提供高精度版本人像抠图和轻量化快速版人像抠图,并开发了Python/C++/Android多个版本;
先展示一下Android版本一键抠图效果:
模型选择 | 原图 | 高精度人像抠图 | 视频抠图 |
更多项目《一键抠图》系列文章请参考:
- 一键抠图1:Python实现人像抠图 (Portrait Matting)
- 一键抠图2:C/C++实现人像抠图 (Portrait Matting)
- 一键抠图3:Android实现人像抠图 (Portrait Matting)
2. 抠图算法
基于深度学习的Matting分为两大类:
- 一种是基于辅助信息输入。即除了原图和标注图像外,还需要输入其他的信息辅助预测。最常见的辅助信息是Trimap,即将图片划分为前景,背景及过度区域三部分。另外也有以背景或交互点作为辅助信息。
- 一种是不依赖任何辅助信息,直接对Alpha进行预测。如本博客复现的MODNet
第一种方法,需要加入辅助信息,而辅助信息一般较难获取,这也限制其应用,为了提升Matting的应用性,针对Portrait Matting领域MODNet摒弃了辅助信息,直接实现Alpha预测,实现了实时Matting,极大提升了基于深度学习Matting的应用价值。
更多抠图算法(Matting),请参考我的一篇博客《图像抠图Image Matting算法调研》:
可能,有小伙伴搞不清楚分割(segmentation)和抠图(matting)有什么区别,我这里简单说明一下:
- 分割(segmentation):从深度学习的角度来说,分割本质是像素级别的分类任务,其损失函数最简单的莫过于是交叉熵CrossEntropyLoss(当然也可以是Focal Loss,IOU Loss,Dice Loss等);对于前景和背景分割任务,输出Mask的每个像素要么是0,要么是1。如果拿去直接做图像融合,就很不自然,Mask边界很生硬,这时就需要使用抠图算法了
- 抠图(matting): 而抠图本质是一种回归任务,其损失函数可以是MSE Loss,L1 Loss,L2 Loss等,对于前景和背景抠图任务,输出Mask的每个像素是0~1之间的连续值,可看作是对图像透明通道(Alpha)的回归预测。可以用公式表示为C = αF + (1-α)B ,其中α(不透明度)、F(前景色)和B(背景色),alpha是[0, 1]之间的连续值,可以理解为像素属于前景的概率。在人像分割任务中,alpha只能取0或1,而抠图任务中,alpha可取[0, 1]之间的连续值,
- 本质上就是一句话:分割是分类任务,而抠图是回归任务。
3. 模型Android部署
目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署。部署流程可分为四步:训练模型->将模型转换ONNX模型->将ONNX模型转换为TNN模型->Android端上部署TNN模型。
(1) 将Pytorch模型转换ONNX模型
训练好模型后,你需要先将Pytorch模型转换为ONNX模型,并使用onnx-simplifier简化网络结构,Python版本的已经提供了ONNX转换脚本,终端输入命令如下:
# 导出ONNX模型
python export.py --model_type "modnet" --model_file "work_space/modnet_416/model/best_model.pth"
GitHub: https://github.com/daquexian/onnx-simplifier
Install: pip3 install onnx-simplifier
(2) 将ONNX模型转换为TNN模型
目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行Android端上部署
TNN转换工具:
- (1)将ONNX模型转换为TNN模型,请参考TNN官方说明:TNN/onnx2tnn.md at master · Tencent/TNN · GitHub
- (2)一键转换,懒人必备:一键转换 Caffe, ONNX, TensorFlow 到 NCNN, MNN, Tengine (可能存在版本问题,这个工具转换的TNN模型可能不兼容,建议还是自己build源码进行转换,2022年9约25日测试可用)
转换成功后,会生成两个文件(*.tnnproto和*.tnnmodel) ,下载下来后面会用到
(3) Android端上部署模型
项目Android部署框架采用TNN,支持多线程CPU和GPU加速推理,在普通手机上可以实时处理。项目Android源码,核心算法均采用C++实现,上层通过JNI接口调用。
如果你想在这个Android Demo部署你自己训练的模型,你可将训练好的Pytorch模型转换ONNX ,再转换成TNN模型,然后把TNN模型代替你模型即可。
- 这是项目Android源码JNI接口 ,Java部分
matting接口:实现基本的人像构图Matting功能
fusion接口:实现人像构图Matting,并与背景图进行融合
mattingFusion接口:人像构图Matting,并与背景图进行融合(会返回mask)
package com.cv.tnn.model;
import android.graphics.Bitmap;
public class Detector {
static {
System.loadLibrary("tnn_wrapper");
}
/***
* 初始化检测模型
* @param proto: TNN *.tnnproto文件文件名(含后缀名)
* @param model: TNN *.tnnmodel文件文件名(含后缀名)
* @param root:模型文件的根目录,放在assets文件夹下
* @param model_type:模型类型
* @param num_thread:开启线程数
* @param useGPU:是否使用GPU
*/
public static native void init(String proto, String model, String root, int model_type, int num_thread, boolean useGPU);
/***
* 缩放图片
* @param bitmap
* @param resize_width
* @param resize_height
* @return
*/
public static Bitmap resizeBitmap(Bitmap bitmap, int resize_width, int resize_height) {
int width = bitmap.getWidth();
int height = bitmap.getHeight();
if (resize_width <= 0 && resize_height <= 0) {
return bitmap;
} else if (resize_height <= 0) {
resize_height = height * resize_width / width;
} else if (resize_width <= 0) {
resize_width = width * resize_height / height;
}
Bitmap dst = Bitmap.createScaledBitmap(bitmap, resize_width, resize_height, false);
return dst;
}
/***
* 人像构图Matting
* @param bitmap 输入图像(bitmap),ARGB_8888格式
* @param mask 输出Mask图像(bitmap),ARGB_8888格式,调用前需要createBitmap初始化大小,如
* Bitmap.createBitmap(Width, Height, Bitmap.Config.ARGB_8888);
* @return
*/
public static native void matting(Bitmap bitmap, Bitmap mask);
/***
* 人像构图Matting,并与背景图进行融合
* @param bitmap 输入图像(bitmap),ARGB_8888格式
* @param bgmap 输入背景图像(bitmap),ARGB_8888格式,可任意大小的图像
* @param fusion 输出与背景融合后图像,调用前需要createBitmap初始化大小,ARGB_8888格式
*/
public static native void fusion(Bitmap bitmap, Bitmap bgmap, Bitmap fusion);
/***
* 人像构图Matting,并与背景图进行融合
* @param bitmap 输入图像(bitmap),ARGB_8888格式
* @param bgmap 输入背景图像(bitmap),ARGB_8888格式,可任意大小的图像
* @param fusion 输出与背景融合后图像,调用前需要createBitmap初始化大小,ARGB_8888格式
* @param mask 输出Mask图像(bitmap),调用前需要createBitmap初始化大小,ARGB_8888格式
* @return
*/
public static native void mattingFusion(Bitmap bitmap, Bitmap bgmap, Bitmap fusion, Bitmap mask);
}
- 这是Android项目源码JNI接口 ,C++部分
#include <jni.h>
#include <string>
#include <fstream>
#include "src/segment.h"
#include "src/object_detection.h"
#include "src/Types.h"
#include "debug.h"
#include "android_utils.h"
#include "opencv2/opencv.hpp"
using namespace dm;
using namespace vision;
static Segment *segment = nullptr;
static ObjectDetection *detector = nullptr;
JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *reserved) {
return JNI_VERSION_1_6;
}
JNIEXPORT void JNI_OnUnload(JavaVM *vm, void *reserved) {
}
extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_init(JNIEnv *env,
jclass clazz,
jstring proto,
jstring model,
jstring root,
jint model_type,
jint num_thread,
jboolean use_gpu) {
if (segment != nullptr) {
delete segment;
segment = nullptr;
}
std::string parent = env->GetStringUTFChars(root, 0);
std::string proto_file = parent + env->GetStringUTFChars(proto, 0);
std::string model_file = parent + env->GetStringUTFChars(model, 0);
DeviceType device = use_gpu ? GPU : CPU;
LOGW("parent : %s", parent.c_str());
LOGW("useGPU : %d", use_gpu);
LOGW("device_type: %d", device);
LOGW("model_type : %d", model_type);
LOGW("num_thread : %d", num_thread);
SegmentParam model_param = SEG_MODEL_TYPE[model_type];//模型参数
segment = new Segment(model_file,
proto_file,
model_param,
num_thread,
device);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_matting(JNIEnv *env, jclass clazz, jobject bitmap,
jobject out_mask) {
cv::Mat image;//bgr
cv::Mat bg;//bgr
BitmapToMatrix(env, bitmap, image);
cv::Mat mask;
cv::Mat fusion;
// 检测人像分割
segment->detect(image, mask);
// 返回Mask
MatrixToBitmap(env, mask, out_mask);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_fusion(JNIEnv *env, jclass clazz,
jobject bitmap,
jobject bgmap,
jobject out_fusion) {
cv::Mat image;//bgr
cv::Mat bg;//bgr
BitmapToMatrix(env, bitmap, image);
BitmapToMatrix(env, bgmap, bg);
cv::Mat mask;
cv::Mat fusion;
// 检测人像分割
segment->detect(image, mask);
// 将matte与背景bg进行融合fusion
image_fusion(image, mask, fusion, bg);
// 融合fusion图像
MatrixToBitmap(env, fusion, out_fusion);
}
extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_mattingFusion(JNIEnv *env, jclass clazz,
jobject bitmap,
jobject bgmap,
jobject out_fusion,
jobject out_mask) {
cv::Mat image;//bgr
cv::Mat bg;//bgr
BitmapToMatrix(env, bitmap, image);
BitmapToMatrix(env, bgmap, bg);
cv::Mat mask;
cv::Mat fusion;
// 检测人像分割
segment->detect(image, mask);
// 将matte与背景bg进行融合fusion
image_fusion(image, mask, fusion, bg);
// 融合fusion图像
MatrixToBitmap(env, fusion, out_fusion);
MatrixToBitmap(env, mask, out_mask);
}
extern "C"
JNIEXPORT jobjectArray JNICALL
Java_com_cv_tnn_model_Detector_detect(JNIEnv *env, jclass clazz, jobject bitmap,
jfloat score_thresh, jfloat iou_thresh) {
cv::Mat bgr;
BitmapToMatrix(env, bitmap, bgr);
int src_h = bgr.rows;
int src_w = bgr.cols;
// 检测区域为整张图片的大小
FrameInfo resultInfo;
// 开始检测
if (detector != nullptr) {
detector->detect(bgr, &resultInfo, score_thresh, iou_thresh);
} else {
ObjectInfo objectInfo;
objectInfo.x1 = 0;
objectInfo.y1 = 0;
objectInfo.x2 = 84;
objectInfo.y2 = 84;
objectInfo.label = 0;
resultInfo.info.push_back(objectInfo);
}
int nums = resultInfo.info.size();
LOGW("object nums: %d\n", nums);
auto BoxInfo = env->FindClass("com/cv/tnn/model/FrameInfo");
auto init_id = env->GetMethodID(BoxInfo, "<init>", "()V");
auto box_id = env->GetMethodID(BoxInfo, "addBox", "(FFFFIF)V");
auto ky_id = env->GetMethodID(BoxInfo, "addKeyPoint", "(FFF)V");
jobjectArray ret = env->NewObjectArray(resultInfo.info.size(), BoxInfo, nullptr);
for (int i = 0; i < nums; ++i) {
auto info = resultInfo.info[i];
env->PushLocalFrame(1);
//jobject obj = env->AllocObject(BoxInfo);
jobject obj = env->NewObject(BoxInfo, init_id);
// set bbox
//LOGW("rect:[%f,%f,%f,%f] label:%d,score:%f \n", info.rect.x,info.rect.y, info.rect.w, info.rect.h, 0, 1.0f);
env->CallVoidMethod(obj, box_id, info.x1, info.y1, info.x2 - info.x1, info.y2 - info.y1,
info.label, info.score);
// set keypoint
for (const auto &kps : info.landmarks) {
//LOGW("point:[%f,%f] score:%f \n", lm.point.x, lm.point.y, lm.score);
env->CallVoidMethod(obj, ky_id, (float) kps.x, (float) kps.y, 1.0f);
}
obj = env->PopLocalFrame(obj);
env->SetObjectArrayElement(ret, i, obj);
}
return ret;
}
(4) Android测试效果
实际使用中,建议你:
- 背景越单一,抠图的效果越好,背景越复杂,抠图效果越差;建议你实际使用中,找一比较单一的背景,如墙面,天空等
- 上半身抠图的效果越好,下半身或者全身抠图效果较差;本质上这是数据的问题,因为训练数据70%都是只有上半身的
- 白种人抠图的效果越好,黑人和黄种人抠图效果较差;这也是数据的问题,因为训练数据大部分都是隔壁的老外
下图是高精度版本人像抠图和快速人像构图的测试效果,相对而言,高精度版本人像抠图可以精细到发丝级别的抠图效果;而快速人像构图目前仅能实现基本的抠图效果:
原图 | Mask图像 | 融合图像 | |
高 精 度 人 像 抠 图 | |||
超 快 人 像 抠 图 |
(5) 运行APP闪退:dlopen failed: library "libomp.so" not found
参考解决方法:
解决dlopen failed: library “libomp.so“ not found_PKing666666的博客
Android SDK和NDK相关版本信息,请参考:
4.Android项目源码下载
Android Demo APP下载地址:
Android项目源码下载地址:一键抠图Portrait Matting人像抠图 (C++和Android源码)
整套Android项目源码内容包含:
- Android版本人像抠图算法,支持CPU和GPU
- 提供高精度版本人像抠图,可以达到精细到发丝级别的抠图效果(Android GPU 150ms, CPU 500ms左右)
- 提供轻量化快速版人像抠图,满足基本的人像抠图效果,可以在Android达到实时的抠图效果(Android GPU 60ms, CPU 140ms左右)
- Android Demo支持图片,视频,摄像头测试
- 所有依赖库都已经配置好,可直接build运行,若运行出现闪退,请参考dlopen failed: library “libomp.so“ not found 解决。
5.人像抠图C++版本
一键抠图2:C/C++实现人像抠图 (Portrait Matting)
6.人像抠图Python版本
一键抠图1:Python实现人像抠图 (Portrait Matting)