我们把时钟拨到 11 年前,2007 年,在第 34 届 SIGGRAPH 2007 数字图形学年会上,以色列的两位教授 Shai Avidan 和 Ariel Shamir 展示了一种新的缩放裁剪图像方法,他们称之为 Seam Carving for Content-Aware Image Resizing,也就是我们后来所说的“接缝剪裁”(Seam Carving)算法。
这个算法能实现什么效果呢?
这项技术能计算出图像上的“关键部分”和“不重要区域”,从而使得随意改变一个图像的高宽比但“不会让图像内容变得扭曲变形”。
算法还能判断出图片里哪部分重要哪部分不重要?
简单的说,利用这个技术我们可以在缩放时固定图片中特定区域的大小,或者可以在缩小时让特定的区块被周围图像缝合消除,并且因为“seam carving”的缝补算法,你可以让图片缩放后仍然维持整体的完整性。
举实际应用的例子来说,利用 Seam Carving 算法我们可以将原本窄镜头的夕阳照片,修改成广角镜头的夕阳照片,且照片中心的太阳不会因为图片拉宽而变形;或者我们可以将原本中间隔著距离的两人合照,修改成靠在一起的合照,且图片也不会因为修改变形。这是一个很有趣也让人觉得很厉害的技术,是你从没有玩过的船新版本切图工具。
接缝剪裁算法这种很新颖的技术,能让我们在没有损失图像中重要内容的情况下裁切图像。因此它又常被称为“内容感知”裁剪或图像重定向。
到底这种算法有多奇妙?我们看下面这个图:
注意画面中的各个元素的间距
使用接缝剪裁算法,我们可以把它变成这样:
重要元素没有丢,没有变形
可以看到,图片中的大部分重要内容比如小船都完整的保存了下来。算法移除了一些岩石以及湖水(所以我们看到图中的小船离得更近了)。这就是接缝剪裁算法的神奇之处,它能在调整图像大小本身的同时,也能保留图像中最重要最突出的内容。如果我们在切图时,既想获得合适的图像大小,也想保留图像的完整内容,使用传统的切图方法几乎无法做到。而使用接缝剪裁算法就能实现二者兼得。
关于算法的核心原理,在原论文中解释的非常清楚了,网上也有很多解析文章,这里不再赘述。在本文我(作者Karthik Karanth——译者注)就以上面所举的例子为素材,重点讲讲如何用Python基本实现接缝剪裁算法。
算法论文地址:
工作过程概览
在接缝裁剪(seam carving)算法中,缝隙(seam)就是指从左到右或从上到下的连续像素,它们横向或纵向穿过整个图像。
因此,为了执行缝隙拼接,我们需要两个重要的输入:1.原始图片:我们想要调整大小的图片。
2.能量图(energy map): 我们从原始图像导出的能量图。
能量图应该代表图像的最显著的区域。通常,我们使用梯度幅度,熵图或显著图表示。
算法工作过程如下所示:
为每个像素分配一个能量值
找到能量值最小的像素的八连通路径
删除路径中的所有像素
重复前面1-3步,直到删除的行/列数量达到理想状态
在本文,我们会假设只想裁切图像的宽度,也就是只删除列。但是同样的方法也能用于删除行。
下面是我们需要导入的环境依赖:
import sys
import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve
# tqdm并非必需,但能为提供很美观的进度条,方便我们查看进度
from tqdm import trange
能量图
第一步是为每个像素计算出一个能量值。原作者在论文中定义了很多不同的能量函数供我们使用,我们使用最基本的那个:
那么这到底是啥意思呢?I 指图像,那么这个方程告诉我们的是,对于图像中的每个像素,每个通道,我们执行如下操作:找到X轴中的偏导数
找到Y轴中的偏导数
将它们的绝对值相加。
这就会成为该像素的能量值。那么问题来了“怎么计算图像中的导数?”计算图像的导数有很多种方法,我们这里使用 sobel 滤波器。这是一种卷积内核,可以在图像的每个通道上运行。这里是图像两个不同方向上的滤波器:
我们从直觉上可以认为,第一个滤波器会用其顶部值的差将每个像素替换为其在底部的值。第二个滤波器会用其左边值和右边值的差替换每个像素。这样就能捕捉 3X3 区域像素的整体趋势。实际上,这种方法和边缘检测算法高度相关。
计算能量图就比较简单了:
def calc_energy(img):
filter_du = np.array([
[1.0, 2.0, 1.0],
[0.0, 0.0, 0.0],
[-1.0, -2.0, -1.0],
])
# 这会将它从2D滤波转换为3D滤波器
# 为每个通道:R,G,B复制相同的滤波器
filter_du = np.stack([filter_du] * 3, axis=2)
filter_dv = np.array([
[1.0, 0.0, -1.0],
[2.0, 0.0, -2.0],
[1.0, 0.0, -1.0],
])
# 这会将它从2D滤波转换为3D滤波器
# 为每个通道:R,G,B复制相同的滤波器
filter_dv = np.stack([filter_dv] * 3, axis=2)
img = img.astype('float32')
convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))
# 我们将红,绿,蓝通道中的能量相加
energy_map = convolved.sum(axis=2)
return energy_map
我们将能量图可视化:
很明显,具有最小变分的区域,比如天空、静止的水域,都有非常低的能量(较暗的区域)。在我们运行接缝裁剪算法时,被移除的线条会在紧密关联图像中这些区域的同时,试图保存图像中具有高能量的部分(较亮的区域)。
找到能量值最小的缝隙
我们的下一个目标是找到从图像顶部到底部之间具有最小能量值的路径。这条线必须是八连通的线:意味着线条上的每个像素必须在边缘或拐角处彼此相连。比如,下图的红线就是我们要找的缝隙:
那么我们是怎么发现这条线的?很明显(明显??),这个问题可以很好的转化为动态规划概念!
我们创建一个称为 M 的 2D 数组,存储该像素上可见的最小能量值。如果你不熟悉动态规划,这里大概就是说 M[i,j] 会在图像中这个点包含最小能量,同时考虑图像顶部到底部之间所有可能经过这个点的缝隙。所以,需要从图像顶部遍历至图像底部的最小能量值会出现在 M 的最后一行。我们需要从这里回溯,找到在该缝隙中出现的像素列,因此我们会使用这些值和 2D 数组,调用 backtrack。
def minimum_seam(img):
r, c, _ = img.shape
energy_map = calc_energy(img)
M = energy_map.copy()
backtrack = np.zeros_like(M, dtype=np.int)
for i in range(1, r):
for j in range(0, c):
# 处理图像的左侧边缘,确保我们不会索引-1
if j == 0:
idx = np.argmin(M[i - 1, j:j + 2])
backtrack[i, j] = idx + j
min_energy = M[i - 1, idx + j]
else:
idx = np.argmin(M[i - 1, j - 1:j + 2])
backtrack[i, j] = idx + j - 1
min_energy = M[i - 1, idx + j - 1]
M[i, j] += min_energy
return M, backtrack
从具有最小能量值的缝隙中删除像素
然后我们移除具有最小能量值的缝隙,返回一个新图像:
def carve_column(img):
r, c, _ = img.shape
M, backtrack = minimum_seam(img)
# 创建一个(r,c)矩阵,填充值为True
# 后面会从值为False的图像中移除所有像素
mask = np.ones((r, c), dtype=np.bool)
# 找到M的最后一行中的最小元素的位置
j = np.argmin(M[-1])
for i in reversed(range(r)):
# 标记出需要删除的像素
mask[i, j] = False
j = backtrack[i, j]
# 因为图像有3个通道,我们将蒙版转换为3D
mask = np.stack([mask] * 3, axis=2)
# 删除蒙版中所有标记为False的像素,
# 将其大小重新调整为新图像的维度
img = img[mask].reshape((r, c - 1, 3))
return img
在每一列重复此项操作
到了这里我们已经打好了所有的地基!现在,我们反复运行 carve_column 函数,直到删除了理想数量的列。我们创建一个 crop_c 函数,它会将图像和一个比例因子作为输入。如果图像维度为(300,600),我们想把它缩减为(150,600),我们需要输入 0.5 作为参数 scale_c 的值。
def crop_c(img, scale_c):
r, c, _ = img.shape
new_c = int(scale_c * c)
for i in trange(c - new_c): # use range if you don't want to use tqdm
img = carve_column(img)
return img
汇总信息
我们可以添加一个主函数,从如下命令行调用该函数:
def main():
scale = float(sys.argv[1])
in_filename = sys.argv[2]
out_filename = sys.argv[3]
img = imread(in_filename)
out = crop_c(img, scale)
imwrite(out_filename, out)
if __name__ == '__main__':
main()
然后用如下代码运行:
python carver.py 0.5 image.jpg cropped.jpg
现在,cropped.jpg 应该包含如下一张图:
这样我们就用 Python 实现了接缝剪裁算法!
那么行呢?
很简单,只需旋转一下图像,运行 crop_c 就 ok 了!
def crop_r(img, scale_r):
img = np.rot90(img, 1, (0, 1))
img = crop_c(img, scale_r)
img = np.rot90(img, 3, (0, 1))
return img
将如下内容添加至主函数,现在我们也能剪裁行了!
def main():
if len(sys.argv) != 5:
print('usage: carver.py ', file=sys.stderr)
sys.exit(1)
which_axis = sys.argv[1]
scale = float(sys.argv[2])
in_filename = sys.argv[3]
out_filename = sys.argv[4]
img = imread(in_filename)
if which_axis == 'r':
out = crop_r(img, scale)
elif which_axis == 'c':
out = crop_c(img, scale)
else:
print('usage: carver.py ', file=sys.stderr)
sys.exit(1)
imwrite(out_filename, out)
以如下代码运行:
python carver.py r 0.5 image2.jpg cropped.jpg
这时我们就能将下面这张正方形照片:
转换为广角镜头的矩形照片,而且完整的保留了原图的重要内容:
效果拔群
结语
希望本文能帮助你更好的理解接缝裁剪算法,以及用 Python 实现它。我现在正在研究怎么改进这种算法,让它运行的更快一些。一个比较简单的改动会是利用计算出的图像中的同一缝隙,去除图像的多个缝隙。我自己试验了几次,发现这样能使算法运行的更快,每次迭代时去除的缝隙数量越多,算法就越快,不过图像质量会有明显的损失。另一个优化方式是在 GPU 上计算能量图。
以下是完整的程序:
#!/usr/bin/env python
"""Usage: python carver.py Copyright 2018 Karthik Karanth, MIT License"""
import sys
from tqdm import trange
import numpy as np
from imageio import imread, imwrite
from scipy.ndimage.filters import convolve
def calc_energy(img):
filter_du = np.array([
[1.0, 2.0, 1.0],
[0.0, 0.0, 0.0],
[-1.0, -2.0, -1.0],
])
# 这会将它从2D滤波转换为3D滤波器
# 为每个通道:R,G,B复制相同的滤波器
filter_du = np.stack([filter_du] * 3, axis=2)
filter_dv = np.array([
[1.0, 0.0, -1.0],
[2.0, 0.0, -2.0],
[1.0, 0.0, -1.0],
])
# 这会将它从2D滤波转换为3D滤波器
# 为每个通道:R,G,B复制相同的滤波器
filter_dv = np.stack([filter_dv] * 3, axis=2)
img = img.astype('float32')
convolved = np.absolute(convolve(img, filter_du)) + np.absolute(convolve(img, filter_dv))
# 我们计算红,绿,蓝通道中的能量值之和
energy_map = convolved.sum(axis=2)
return energy_map
def crop_c(img, scale_c):
r, c, _ = img.shape
new_c = int(scale_c * c)
for i in trange(c - new_c):
img = carve_column(img)
return img
def crop_r(img, scale_r):
img = np.rot90(img, 1, (0, 1))
img = crop_c(img, scale_r)
img = np.rot90(img, 3, (0, 1))
return img
def carve_column(img):
r, c, _ = img.shape
M, backtrack = minimum_seam(img)
mask = np.ones((r, c), dtype=np.bool)
j = np.argmin(M[-1])
for i in reversed(range(r)):
mask[i, j] = False
j = backtrack[i, j]
mask = np.stack([mask] * 3, axis=2)
img = img[mask].reshape((r, c - 1, 3))
return img
def minimum_seam(img):
r, c, _ = img.shape
energy_map = calc_energy(img)
M = energy_map.copy()
backtrack = np.zeros_like(M, dtype=np.int)
for i in range(1, r):
for j in range(0, c):
# 处理图像的左侧边缘,确保我们不会索引-1
if j == 0:
idx = np.argmin(M[i-1, j:j + 2])
backtrack[i, j] = idx + j
min_energy = M[i-1, idx + j]
else:
idx = np.argmin(M[i - 1, j - 1:j + 2])
backtrack[i, j] = idx + j - 1
min_energy = M[i - 1, idx + j - 1]
M[i, j] += min_energy
return M, backtrack
def main():
if len(sys.argv) != 5:
print('usage: carver.py ', file=sys.stderr)
sys.exit(1)
which_axis = sys.argv[1]
scale = float(sys.argv[2])
in_filename = sys.argv[3]
out_filename = sys.argv[4]
img = imread(in_filename)
if which_axis == 'r':
out = crop_r(img, scale)
elif which_axis == 'c':
out = crop_c(img, scale)
else:
print('usage: carver.py ', file=sys.stderr)
sys.exit(1)
imwrite(out_filename, out)
if __name__ == '__main__':
main()