如何有效地将三通道图像的每个像素映射到一个通道?
问题内容:
我正在编写一个python程序来预处理图像,以用作语义分割任务的标签。原始图像具有三个通道,其中代表每个像素的三个值的矢量表示该像素的类别标签。例如,[0,0,0]的像素可以是1类,[0,0,255]可以是2类,依此类推。
我需要将这些图像转换为单通道图像,像素值从0开始并依次增加以表示每个类。本质上,我需要将旧图像中的[0,0,0]转换为新图像中的0,将[0,0,255]转换为1,以此类推。
图像分辨率很高,宽度和高度超过2000像素。我需要对数百张图像执行此操作。我目前使用的方法包括遍历每个像素并将3维值替换为相应的标量值。
filename="file.png"
label_list = [[0,0,0], [0,0,255]] # for example. there are more classes like this
image = imread(filename)
new_image = np.empty((image.shape[0], image.shape[1]))
for i in range(image.shape[0]):
for j in range(image.shape[1]):
for k, label in enumerate(label_list):
if np.array_equal(image[i][j], label):
new_image[i][j] = k
break
imsave("newname.png", new_image)
问题是上述程序效率很低,并且每个图像需要花费几分钟才能运行。这太多了,无法处理我的所有图像,因此我需要对其进行改进。
首先,我认为通过转换label_list
为numpy数组并使用,np.where
可以删除最里面的循环。但是,我不确定如何在np.where
二维数组中查找一维数组以及它是否会有所改善。
从这个线程开始,我试图定义一个函数并将其直接应用于图像。但是,我需要将每个3维标签映射到一个标量。字典不能包含列表作为键。会有更好的方法来这样做,这会有所帮助吗?
有没有办法提高效率,或者有更好的方法来完成上述程序的工作?
谢谢。
问题答案:
方法1
这是views
和np.searchsorted
-
# https://stackoverflow.com/a/45313353/ @Divakar
def view1D(a, b): # a, b are arrays
a = np.ascontiguousarray(a)
b = np.ascontiguousarray(b)
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
return a.view(void_dt).ravel(), b.view(void_dt).ravel()
# Trace back a 2D array back to given labels
def labelrows(a2D, label_list):
# Reduce array and labels to 1D
a1D,b1D = view1D(a2D, label_list)
# Use searchsorted to trace back label indices
sidx = b1D.argsort()
return sidx[np.searchsorted(b1D, a1D, sorter=sidx)]
因此,要将其用于3D
图像阵列,我们需要将高度和宽度合并为一个维度,并使颜色通道保持原样暗淡,并使用标签功能。
方法#2
针对具有[0,255]
范围的图像元素进行了调整,我们可以利用矩阵乘法来降低维数,从而进一步提高性能,就像这样-
def labelpixels(img3D, label_list):
# scale array
s = 256**np.arange(img.shape[-1])
# Reduce image and labels to 1D
img1D = img.reshape(-1,img.shape[-1]).dot(s)
label1D = np.dot(label_list, s)
# Use searchsorted to trace back label indices
sidx = label1D.argsort()
return sidx[np.searchsorted(label1D, img1D, sorter=sidx)]
样本运行如何扩展图像大小并验证-
In [194]: label_list = [[0,255,255], [0,0,0], [0,0,255], [255, 0, 255]]
In [195]: idx = [2,0,3,1,0,3,1,2] # We need to retrieve this back
In [196]: img = np.asarray(label_list)[idx].reshape(2,4,3)
In [197]: img
Out[197]:
array([[[ 0, 0, 255],
[ 0, 255, 255],
[255, 0, 255],
[ 0, 0, 0]],
[[ 0, 255, 255],
[255, 0, 255],
[ 0, 0, 0],
[ 0, 0, 255]]])
In [198]: labelrows(img.reshape(-1,img.shape[-1]), label_list)
Out[198]: array([2, 0, 3, 1, 0, 3, 1, 2])
In [217]: labelpixels(img, label_list)
Out[217]: array([2, 0, 3, 1, 0, 3, 1, 2])
最后,输出需要重新调整为2D
-
In [222]: labelpixels(img, label_list).reshape(img.shape[:-1])
Out[222]:
array([[2, 0, 3, 1],
[0, 3, 1, 2]])