1. 换脸流程
采用三维重建的方式重建出参考图的3Dshape,获得相应的颜色空间(identity);重建视频中人脸3Dshape,提取相应的vertices(shape);结合参考图的颜色空间以及目标图的vertices,渲染出更换了identity的face。
离线服务提取待换脸视频中每帧人脸的3D顶点信息,存放于redis中。由于顶点信息至少需要float32精度存放,若将整段视频的顶点信息存为一个value,该value会非常大,故把视频拆分成多段分别存放于redis中。取redis中信息时,用多线程并行取出各段并进行换脸,比单线程串行处理快很多。
2. 涉及的技术
人脸三维重建,图像渲染,图像补全,边缘检测,人脸分割
人脸三维重建:网络采用PRNet
3. 存在的问题及解决办法
抖动:视频换脸后出现抖动。对人脸检测框进行平滑处理可以有效降低抖动程度,由此确定抖动是由人脸检测精度低造成的;目前采用Face++人脸检测接口进行人脸检测。
边缘伪影:换脸mask不准确导致边缘出现伪影;通过设置模板mask并结合人脸分割,精确得到换脸mask。
眼镜:采用图像补全技术去除眼镜,先用边缘检测方法获取眼镜区域的mask,再根据得到的mask对图像进行补全。
眼睛转动及嘴巴张闭:由于三维重建后的眼睛和嘴巴是固定的,在换脸mask上去掉相应区域,保留视频中原有的眼睛和嘴巴。
参考图和原图存在色差:通过颜色校正(使渲染人脸在mask区域内各通道的均值与视频帧一致)调整图像的颜色和亮度。
4. 接口说明
版本: 1.0
描述:传入base64编码的图片数据和视频名,将检测到的人脸通过3D人脸重建替换视频中的人脸,并用换脸后的视频帧合成视频。
请求方式:post
请求链接:xxxxxxxxxx:9775/ai/v1/FaceSwap
图片要求:
图片格式:JPG(JPEG),PNG
图片像素尺寸:最小 200*200 像素,最大 4096*4096像素
5. 整体架构方案
6. 接口设计
接口请求参数:
参数名 | 必选 | 类型 | 说明 |
requestId | 是 | String | 用于区分每一次请求的唯一的字符串id |
inputImage | 是 | String | 图片的base64值 |
videoName | 是 | String | 视频名称 |
token | 是 | String | 服务鉴权标识,AI组统一分配 |
userId | 否 | String | 用户id |
接口返回结果示例:
{
"code": 0,
"msg": "success",
"data": {
"requestId": "100022",
"faceSwapRes": "True",
"timeUsed": "30.11962342262268"
}
}
接口返回参数说明:
参数名 | 类型 | 说明 |
requestId | String | 用户请求唯一标识 |
faceSwapRes | String | 换脸服务返回结果,成功为"True",失败为"False" |
timeUsed | String | 整个请求所花费的时间,单位为秒 |
接口状态码code:
状态码 | 状态说明 |
0 | 成功 |
2 | 未检测出人脸 |
3 | 鉴权失败 |
4 | 参数无效 |
5 | 图片尺寸超出范围 |
6 | 请求异常 |
7 | 视频名称为空 |
7. 代码如下:
# encoding:utf-8
import ast
import base64
import hashlib
import json
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, wait, ALL_COMPLETED, FIRST_COMPLETED, as_completed
from io import BytesIO
from logging.handlers import TimedRotatingFileHandler
from threading import Thread, Lock

import cv2
import numpy as np
import redis
from flask import Flask, request
from meinheld import server
from PIL import Image
from skimage.io import imread, imsave

from conf import config
from api import PRN
from glass_judge import *
# from utils.render import render_texture,render_texture_v1
# from utils.estimate_pose import rotate_pos
# from face_segmentation.face_segment import FaceSegment
from face_segmentation.face_segment import FaceSegmentFCN
from mesh.render import render_colors
from faceDetect.face_detection import FaceDetector, FaceTracker
from face_align import FaceAligner_v1
from Pluralistic.FaceEdit import FaceEditor, CropLayer
# Flask application; the /ai/v1/FaceSwap route is registered on it and it is
# served by meinheld in the __main__ block.
app = Flask(__name__)
def setLog(log_dir="log", level=logging.INFO):
    """Configure root logging: console output plus an hourly-rotating log file.

    The file is created as ``<log_dir>/run_faceswap_server<timestamp>.log``; it is
    rotated every hour and the last 72 rotated files are kept.

    Args:
        log_dir: Directory for the log file. It is created if missing. Defaults
            to "log", the directory the service always used.
        level: Level passed to ``logging.basicConfig``.
    """
    log_fmt = '%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    formatter = logging.Formatter(log_fmt)
    # TimedRotatingFileHandler opens the file immediately and raises
    # FileNotFoundError when the directory is missing, which aborted startup.
    os.makedirs(log_dir, exist_ok=True)
    timestamp = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    fh = TimedRotatingFileHandler(
        filename=os.path.join(log_dir, "run_faceswap_server" + timestamp + ".log"),
        when="H", interval=1,
        backupCount=72)
    fh.setFormatter(formatter)
    logging.basicConfig(level=level)
    logging.getLogger().addHandler(fh)
# Module-level initialisation: logging, GPU selection, model loading and the
# state shared between the request handler and the worker threads.
setLog()
# NOTE(review): these CUDA variables are set after the framework-backed modules
# above are imported; they only take effect if no CUDA context exists yet — confirm.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Only referenced by the commented-out landmark model below; unused otherwise.
model_path = "./checkpoint/snapshot/checkpoint_epoch_1000.pth.tar"
# faceLandMark = FaceLandmarkModel(model_path)
# segmentor = FaceSegment('./face_segmentation/checkpoints/model.pt')
# Face segmentation model; not referenced elsewhere in the visible code.
fcn_segmentor = FaceSegmentFCN('./face_segmentation/weights/Keras_FCN8s_face_seg_YuvalNirkin.h5')
MODEL_PATH = './faceDetect/model_new.pb'
# Detector used on the reference image in faceSwap.
face_detector = FaceDetector(MODEL_PATH, gpu_memory_fraction=0.25, visible_device_list='0')
face_aligner = FaceAligner_v1()
# cv2.dnn_registerLayer('Crop', CropLayer)
# PRNet 3D face reconstruction model (process / get_colors_from_texture / triangles).
prn = PRN(is_dlib=True)
# Glasses-removal editor used by the HTTP handler.
editor = FaceEditor()
# Pool that runs one get_redis task per redis segment of a video.
executor = ThreadPoolExecutor(config.threadPoolSize)
# Create the object connected to the redis database
pool = redis.ConnectionPool(host=config.redisHost, port=config.redisPort, password=config.redisPassword,
                            max_connections=config.maxConnections)
redisDb = redis.Redis(connection_pool=pool)
# Guards writes to imageList from the swapThread workers.
lock = Lock()
# Per-request state kept in module globals.
# NOTE(review): concurrent requests share (and overwrite) all of these.
swap_threads = []  # never populated in the visible code (thread path is commented out)
frame_dict_list = dict()  # only filled by get_redis1
all_task = list()  # futures submitted by faceSwap
imageList = [""]*5000  # base64 JPEG per frame index; index >= 5000 raises IndexError in swapThread
frame_count_all = 0  # number of redis entries counted by get_redis for the current request
fps = 25  # defaults; swapThread overwrites fps/w/h from the processed frames
w = 255
h = 255
def colorTransfer(src, dst, mask=None):
    """Shift the colours of ``dst`` so its per-channel mean matches ``src``.

    Only the pixels selected by ``mask`` change. For each channel, the mean of
    ``dst`` over those pixels is subtracted and the mean of ``src`` over the same
    pixels is added. The result is clipped to [0, 255] and cast back to ``dst``'s
    dtype. Pixels outside the mask are copied unchanged.

    Args:
        src: Reference image; only its pixels under the mask are read.
        dst: Image to correct, with the same height/width as ``src``.
        mask: Optional array whose first two axes match the image; non-zero
            entries select the pixels to use. ``None`` means the whole image.

    Returns:
        A new array with ``dst``'s dtype; ``dst`` itself is not modified.
    """
    if mask is None:
        # Every (row, col) of the image. The previous meshgrid version produced
        # (col, row) pairs, which indexed out of bounds on wide images and
        # covered only part of tall ones.
        rows, cols = np.indices(dst.shape[:2]).reshape(2, -1)
    else:
        # Row/column indices of the non-black mask pixels.
        rows, cols = np.where(mask != 0)[:2]
    transferred = np.copy(dst)
    if rows.size == 0:
        # Empty mask: nothing to correct (the means would be NaN).
        return transferred
    masked_src = src[rows, cols].astype(np.int32)
    masked_dst = dst[rows, cols].astype(np.int32)
    shifted = masked_dst - masked_dst.mean(axis=0) + masked_src.mean(axis=0)
    transferred[rows, cols] = np.clip(shifted, 0, 255)
    return transferred
def swapThread(alpha, new_colors, frame_key, frame_val, videoPath):
    """Swap the face in one video frame and store it, base64 JPEG, in ``imageList``.

    Args:
        alpha: Blend weight of the rendered face inside the mask (faceSwap uses 0.8).
        new_colors: Per-vertex colours of the reference face from
            ``prn.get_colors_from_texture``.
        frame_key: Entry key; the part before the first ":" is the frame index.
            The key "fps" is a metadata entry and is skipped.
        frame_val: String holding a dict literal with at least "vertices" (raw
            float32 bytes, reshaped to (43867, -1)) and "fps".
        videoPath: Directory prefix (with trailing "/") holding ``<idx>.jpg`` and
            ``<idx>_new_mask.jpg`` for each frame.

    Returns:
        True when the frame was stored (or the entry was the "fps" metadata);
        False when the frame/mask image cannot be read or encoding fails.

    Side effects:
        Sets the module globals ``fps``, ``h`` and ``w``, and writes
        ``imageList[idx]`` under ``lock``.
    """
    global fps, h, w
    start_time = time.time()
    logging.info(f"frame_key is: {str(frame_key)}")
    if frame_key == "fps":
        # Metadata entry, not a frame. Every frame entry carries its own "fps"
        # (read below), so skip it; falling through used to crash in eval()/split
        # and fail the whole request.
        return True
    frame_count = frame_key.split(":")[0]
    # SECURITY NOTE(review): eval() executes whatever is stored in redis. Consider
    # ast.literal_eval once the stored format is confirmed to be a plain literal.
    frame_val = eval(frame_val)
    vertices = frame_val.get("vertices")
    logging.info("vertices")
    fps = int(float(frame_val.get("fps")))

    new_mask = cv2.imread(videoPath + str(frame_count) + "_new_mask.jpg")
    if new_mask is None:
        # cv2.imread returns None for a missing/corrupt file; cvtColor would raise.
        logging.warning(f"mask for frame {frame_count} not found under {videoPath}")
        return False
    new_mask = np.where(cv2.cvtColor(new_mask, cv2.COLOR_BGR2GRAY) < 1, 0, 1)

    image = cv2.imread(videoPath + str(frame_count) + ".jpg")
    if image is None or not image.data or len(image) < 1:
        return False
    h, w = image.shape[:2]

    # Raw float32 buffer -> vertex array (43867 vertices, 3 coordinates each).
    # NOTE(review): np.fromstring's binary mode is deprecated; np.frombuffer is
    # the replacement once "vertices" is confirmed to always be bytes.
    vertices = np.fromstring(vertices, dtype=np.float32).reshape((43867, -1))
    # Render the reference identity onto this frame's face geometry.
    new_image = render_colors(vertices, prn.triangles, new_colors, h, w)
    new_image = (255 * new_image).astype(np.uint8)
    # Match the rendered face's colour to the video frame inside the mask.
    new_image = colorTransfer(image, new_image, new_mask)

    # Outside the mask keep the frame; inside, blend rendered face (alpha) with frame.
    mask3 = new_mask[:, :, np.newaxis]
    swap_image = image * (1 - mask3) + new_image * alpha * mask3 + image * (1 - alpha) * mask3

    # Poisson (seamless) cloning centred on the mask's bounding box.
    mask_u8 = (new_mask * 255).astype(np.uint8)
    x, y, box_w, box_h = cv2.boundingRect(mask_u8)
    center = (int(x + np.round(box_w / 2)), int(y + np.round(box_h / 2)))
    output = cv2.seamlessClone(swap_image.astype(np.uint8), image, mask_u8, center, cv2.NORMAL_CLONE)
    # NOTE(review): the frame was read with cv2.imread (BGR order); this swaps the
    # channels before JPEG encoding. Kept as-is because the video writer path
    # decodes with a matching convention — confirm it is intended.
    out = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
    if out is None or not out.data or len(out) < 1:
        return False
    logging.info(f"swap merge face cost time is: {str(time.time() - start_time)}")

    ok, buf = cv2.imencode(".jpg", out)
    if not ok:
        logging.warning(f"jpeg encoding failed for frame {frame_count}")
        return False
    out_base64 = base64.b64encode(buf)
    with lock:
        imageList[int(float(frame_count))] = out_base64
    return True
def faceSwapRun(alpha, new_colors, frame_dict, videoPath, imageList):
    """Serially swap the face in every frame of ``frame_dict``.

    Single-threaded variant of ``swapThread``; its only call site in the visible
    code is commented out. Unlike ``swapThread`` it writes into the ``imageList``
    argument and tracks fps/w/h locally instead of in module globals.

    Args:
        alpha: Blend weight of the rendered face inside the mask.
        new_colors: Per-vertex colours of the reference face.
        frame_dict: Mapping of entry key ("<idx>:..." or "fps") to the value
            stored in redis (a dict-literal string for frame entries).
        videoPath: Directory prefix (with trailing "/") holding ``<idx>.jpg`` and
            ``<idx>_new_mask.jpg``.
        imageList: Pre-sized list; slot ``idx`` receives the base64 JPEG of frame idx.

    Returns:
        ``(imageList, fps, w, h)``; fps/w/h stay 25/255/255 if no frame is processed.
    """
    fps = 25
    w = 255
    h = 255
    for frame_key, frame_val in frame_dict.items():
        start_time = time.time()
        logging.info(f"frame_key is: {str(frame_key)}")
        if frame_key == "fps":
            # Metadata entry, not a frame.
            fps = frame_val.get("fps")
            continue
        frame_count = frame_key.split(":")[0]
        # SECURITY NOTE(review): eval() executes whatever is stored in redis.
        frame_val = eval(frame_val)
        vertices = frame_val.get("vertices")
        logging.info("vertices")
        fps = int(float(frame_val.get("fps")))

        new_mask = cv2.imread(videoPath + str(frame_count) + "_new_mask.jpg")
        if new_mask is None:
            # Missing/corrupt mask: skip the frame instead of crashing in cvtColor.
            logging.warning(f"mask for frame {frame_count} not found under {videoPath}")
            continue
        new_mask = np.where(cv2.cvtColor(new_mask, cv2.COLOR_BGR2GRAY) < 1, 0, 1)

        image = cv2.imread(videoPath + str(frame_count) + ".jpg")
        if image is None or not image.data or len(image) < 1:
            continue
        h, w = image.shape[:2]

        # Raw float32 buffer -> vertex array (43867 vertices, 3 coordinates each).
        vertices = np.fromstring(vertices, dtype=np.float32).reshape((43867, -1))
        new_image = render_colors(vertices, prn.triangles, new_colors, h, w)
        new_image = (255 * new_image).astype(np.uint8)
        # Match the rendered face's colour to the video frame inside the mask.
        new_image = colorTransfer(image, new_image, new_mask)
        logging.debug(f"new_image shape: {new_image.shape}, image shape: {image.shape}")

        # Outside the mask keep the frame; inside, blend rendered face (alpha) with frame.
        mask3 = new_mask[:, :, np.newaxis]
        swap_image = image * (1 - mask3) + new_image * alpha * mask3 + image * (1 - alpha) * mask3

        # Poisson (seamless) cloning centred on the mask's bounding box.
        mask_u8 = (new_mask * 255).astype(np.uint8)
        x, y, box_w, box_h = cv2.boundingRect(mask_u8)
        center = (int(x + np.round(box_w / 2)), int(y + np.round(box_h / 2)))
        output = cv2.seamlessClone(swap_image.astype(np.uint8), image, mask_u8, center, cv2.NORMAL_CLONE)
        out = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        if out is None or not out.data or len(out) < 1:
            continue
        logging.info(f"swap merge face cost time is: {str(time.time() - start_time)}")

        encode_start = time.time()
        ok, buf = cv2.imencode(".jpg", out)
        if not ok:
            logging.warning(f"jpeg encoding failed for frame {frame_count}")
            continue
        imageList[int(float(frame_count))] = base64.b64encode(buf)
        logging.debug(f"encode base64 cost time is: {str(time.time() - encode_start)}")
        logging.debug(f"frame_count is : {frame_count}")
    return imageList, fps, w, h
def get_redis(video_key, redisDb, i, alpha, new_colors, videoPath):
    """Load one redis segment of per-frame data and face-swap each of its frames.

    Submitted to the thread pool by ``faceSwap``, once per existing segment key.

    Args:
        video_key: Redis key of the segment ("<video_id>-<n>").
        redisDb: Redis client to read from.
        i: Segment number as a string; unused, kept for interface compatibility.
        alpha, new_colors, videoPath: Forwarded to ``swapThread``.

    Returns:
        True once every entry of the segment has been processed (failures of
        individual frames are not reported), False if the key has no value.
    """
    global frame_count_all
    logging.info("key is: " + video_key)
    raw = redisDb.get(video_key)
    if raw is None:
        # The key existed when faceSwap checked it but may have expired since;
        # eval(None) used to raise and fail the whole request.
        logging.warning(f"redis key {video_key} has no value")
        return False
    # SECURITY NOTE(review): eval() executes whatever is stored in redis.
    frame_dict = eval(raw)
    # Several pool threads update this counter at once; `+=` on a global is a
    # non-atomic read-modify-write, so it is done under the shared lock.
    with lock:
        frame_count_all += len(frame_dict)
    for frame_key, frame_val in frame_dict.items():
        swapThread(alpha, new_colors, frame_key, frame_val, videoPath)
    return True
def get_redis1(video_key, redisDb, i, alpha, new_colors, videoPath):
    """Collect one redis segment's frame entries into ``frame_dict_list``.

    Alternative to ``get_redis`` that defers the face swap; its only call site in
    the visible code is commented out.

    Args:
        video_key: Redis key of the segment ("<video_id>-<n>").
        redisDb: Redis client to read from.
        i, alpha, new_colors, videoPath: Unused; kept so the signature matches
            ``get_redis``.

    Returns:
        True when the segment was merged, False if the key has no value.
    """
    global frame_count_all
    logging.info("key is: " + video_key)
    raw = redisDb.get(video_key)
    if raw is None:
        # The key may have expired after faceSwap checked that it exists.
        logging.warning(f"redis key {video_key} has no value")
        return False
    # SECURITY NOTE(review): eval() executes whatever is stored in redis.
    frame_dict = eval(raw)
    # Called from several pool threads: update the shared counter and dict together.
    with lock:
        frame_count_all += len(frame_dict)
        frame_dict_list.update(frame_dict)
    return True
def faceSwap(ref_image, video_id, prn, videoPath, max_segments=120):
    """Reconstruct the reference face and swap it into every stored frame of a video.

    Aligns the reference image, detects its face, and runs PRNet to get the
    position map. It then samples per-vertex colours and fans out one
    ``get_redis`` task per existing redis segment ``"<video_id>-1"`` ...
    ``"<video_id>-<max_segments>"``. Results are left in the global ``imageList``
    and ``frame_count_all``.

    NOTE(review): that state is module-global, so concurrent requests interfere.

    Args:
        ref_image: Reference image array (the HTTP handler converts it to RGB
            before calling).
        video_id: Prefix of the redis segment keys.
        prn: PRN model instance.
        videoPath: Directory prefix with the per-frame JPEGs/masks, forwarded to workers.
        max_segments: Highest segment number probed (default 120, the previous limit).

    Returns:
        True if all segment tasks finished without raising, False on any error
        (including no face found in the reference image).
    """
    global frame_count_all
    try:
        begin_time = time.time()
        # Weight of the rendered face when blending with the video frame.
        alpha = 0.8
        # Read the reference image and get its colour space.
        ref_image = face_aligner.aligner(ref_image)
        boxes, _ = face_detector(ref_image)
        if len(boxes) == 0:
            logging.error("no face detected in reference image")
            return False
        ref_pos = prn.process(ref_image, image_info=boxes[0])
        logging.info("faceDetector and prn.process cost time: " + str(time.time() - begin_time))
        ref_image = ref_image / 255.
        ref_texture = cv2.remap(ref_image, ref_pos[:, :, :2].astype(np.float32), None, interpolation=cv2.INTER_NEAREST,
                                borderMode=cv2.BORDER_CONSTANT, borderValue=(0))
        # Colour of every reconstructed vertex, sampled from the reference texture.
        new_colors = prn.get_colors_from_texture(ref_texture)
        logging.info("to remap get colors cost time: " + str(time.time() - begin_time))

        redis_time = time.time()
        frame_count_all = 0
        # Clear frames from the previous request; otherwise any frame that fails
        # here would be filled with a stale image from another video.
        with lock:
            imageList[:] = [""] * len(imageList)

        # Local task list: the old global list kept every future from every request.
        tasks = []
        for segment in range(1, max_segments + 1):
            video_key = video_id + "-" + str(segment)
            if redisDb.exists(video_key):
                tasks.append(executor.submit(get_redis, video_key, redisDb, str(segment),
                                             alpha, new_colors, videoPath))
        for future in as_completed(tasks):
            data = future.result()
            logging.info(f"in main: get page {str(data)}s success")
        logging.info("get redis val cost time: " + str(time.time() - redis_time))
        logging.info("frame_dict_list len is: " + str(len(frame_dict_list)))
        logging.info("threads swap face cost time: " + str(time.time() - begin_time))
        return True
    except Exception as ex:
        logging.exception(ex)
        return False
@app.route('/ai/v1/FaceSwap', methods=['POST'])
def faceSwapMethod():
    """HTTP handler for POST /ai/v1/FaceSwap.

    The request body must be a dict literal with ``requestId``, ``token``,
    ``videoName`` and ``inputImage`` (base64 image). The response is a JSON
    string with ``code``/``msg``; on code 0 it also carries
    ``data = {requestId, faceSwapRes, timeUsed}``. ``faceSwapRes`` is "True" or
    "False", and ``timeUsed`` is elapsed seconds as a string.

    Codes returned here: 0 success, 3 missing token, 4 missing/undecodable
    image, 5 image too large, 6 unexpected exception, 7 missing videoName.
    """
    try:
        start_time = time.time()
        # SECURITY: the body used to go through eval(), letting any client run
        # arbitrary Python on the server. ast.literal_eval accepts the same plain
        # literal payloads (dicts, strings, numbers, True/False/None) and nothing else.
        resParm = ast.literal_eval(str(request.data, encoding="utf-8"))
        requestId = resParm.get('requestId')
        # Service authentication. NOTE(review): only presence is checked here; the
        # token value itself is not validated.
        token = resParm.get('token')
        if not token:
            res = {'code': 3, 'msg': 'token fail'}
            logging.error("code: 3 msg: token fail ")
            return json.dumps(res)
        videoId = resParm.get("videoName")
        if videoId is None or videoId.strip() == '':
            res = {'code': 7, 'msg': 'videoName is null'}
            logging.error("code: 7 msg: videoName is null")
            # This return was missing, so the request continued with no video name.
            return json.dumps(res)
        modelImg_base64 = resParm.get("inputImage")
        if not modelImg_base64:
            res = {'code': 4, 'msg': ' picture param invalid'}
            logging.error("code: 4 msg: picture param invalid")
            return json.dumps(res)

        # Decode the base64 image once (BGR). Undecodable data is an invalid
        # parameter (code 4), not a generic request exception.
        try:
            modelImg_data_1 = cv2.imdecode(np.frombuffer(base64.b64decode(modelImg_base64), np.uint8),
                                           cv2.IMREAD_COLOR)
        except (ValueError, TypeError, cv2.error):
            modelImg_data_1 = None
        if modelImg_data_1 is None:
            res = {'code': 4, 'msg': ' picture param invalid'}
            logging.error("code: 4 msg: picture param invalid")
            return json.dumps(res)

        if is_has_glass(modelImg_base64):
            # Remove glasses by image inpainting before the 3D reconstruction.
            image = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
            modelImg_data_1 = editor.removeglasses(image)[0]
            # Debug dump of the de-glassed image (overwritten by every request).
            cv2.imwrite("glass_img.jpg", cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB))

        # Validate the image size.
        if modelImg_data_1.shape[0] > config.size or modelImg_data_1.shape[1] > config.size:
            res = {'code': 5, 'msg': ' picture size invalid'}
            logging.error("code: 5 msg: picture size invalid")
            return json.dumps(res)
        logging.info(f"modelImg_data_1 shape: {str(modelImg_data_1.shape)} size: {str(modelImg_data_1.size)}")

        time_predict = time.time()
        modelImg_data_1 = cv2.cvtColor(modelImg_data_1, cv2.COLOR_BGR2RGB)
        swapRes = gen_swap_face(modelImg_data_1, videoId, prn)
        logging.info(f"face swap cost Time is: {str(time.time() - time_predict)} ")
        timeUsed = time.time() - start_time
        data = {'requestId': requestId, 'faceSwapRes': str(swapRes), 'timeUsed': str(timeUsed)}
        res = {'code': 0, 'msg': 'success', 'data': data}
        logging.info(f"code:0 msg:success face swap cost Time is: {str(timeUsed)} ")
        return json.dumps(res)
    except Exception as e:
        logging.exception(e)
        res = {'code': 6, 'msg': 'request exception'}
        return json.dumps(res)
def gen_swap_face(modelImg_data_1, videoId, prn):
    """Run the face swap for one request and write the resulting MP4.

    Args:
        modelImg_data_1: Reference image array; its MD5 names the output folder
            and file.
        videoId: Video name/path. Its basename without extension selects
            ``./img_video/<name>/`` (frames, masks, output); the full value is
            used as the redis key prefix.
        prn: PRN model instance forwarded to ``faceSwap``.

    Returns:
        True when at least one frame was written to
        ``./img_video/<name>/<md5>/<name>-<md5>.mp4``, False otherwise.
    """
    out = None
    try:
        videoName = os.path.basename(videoId).split(".")[0]
        refImgMd = hashlib.md5(modelImg_data_1).hexdigest()
        videoPath = './img_video/' + videoName + "/"
        save_res_path = videoPath + refImgMd + "/"
        os.makedirs(save_res_path, exist_ok=True)

        time_predict = time.time()
        if not faceSwap(modelImg_data_1, videoId, prn, videoPath):
            return False
        logging.info(f"face swap Method cost Time is: {str(time.time() - time_predict)} ")

        # fps, w, h, imageList and frame_count_all are module globals set by the workers.
        im_size = (w, h)
        logging.info(f"imageList len is: {str(len(imageList))}")
        if len(imageList) < 1:
            return False

        start_time = time.time()
        frames_written = 0
        # frame_count_all counts redis entries and may exceed the preallocated list.
        for i in range(min(frame_count_all, len(imageList))):
            encoded = imageList[i]
            if encoded is None or len(encoded) < 1:
                continue
            # cv2.imdecode returns exactly the channel layout swapThread passed to
            # cv2.imencode, the same result as the former plt.imread + BGR2RGB pair
            # (plt was never imported in this file).
            frame = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), np.uint8), cv2.IMREAD_COLOR)
            if frame is None:
                continue
            if out is None:
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                out = cv2.VideoWriter(save_res_path + videoName + "-" + refImgMd + ".mp4", fourcc, fps, im_size, True)
            out.write(frame)
            frames_written += 1
        logging.info("image List to merge face video cost: " + str(time.time() - start_time))
        if frames_written == 0:
            # Previously reported success even though no video was produced.
            logging.warning(f"no swapped frames available for video {videoName}")
            return False
        return True
    except Exception as x:
        logging.exception(x)
        return False
    finally:
        # Release explicitly so the MP4 is finalized before the result is
        # reported, rather than relying on garbage collection of the writer.
        if out is not None:
            out.release()
def save_video_face(videoName, skip_frames=350, output_path="test_swap1_200f.mp4"):
    """Debug helper: copy a video to ``output_path`` without its first frames.

    Args:
        videoName: Path of the input video readable by ``cv2.VideoCapture``.
        skip_frames: Number of leading frames to drop (default 350, the
            previously hard-coded value).
        output_path: Destination file, written with the "mp4v" codec at the
            source frame rate and size.
    """
    cap = cv2.VideoCapture(videoName)
    out = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        im_size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
        frameId = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frameId < skip_frames:
                frameId += 1
                continue
            if out is None:
                fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                out = cv2.VideoWriter(output_path, fourcc, fps, im_size, True)
            out.write(frame)
            frameId += 1
    finally:
        # Release both handles so the output file is finalized and the input closed.
        cap.release()
        if out is not None:
            out.release()
if __name__ == "__main__":
    # Serve the Flask app with meinheld on every interface, port 9775.
    listen_address = ("0.0.0.0", 9775)
    logging.info('Starting the server...')
    server.listen(listen_address)
    server.run(app)