I want to convert a 3-channel RGB image to an index image with Python. It's used for handling the labels when training a deep net for semantic segmentation. By index image I m
This is similar to what Armali and Mendrika proposed, but I had to tweak their approach a little to get it to work (possibly my own fault), so I am sharing a snippet that works for me.
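import numpy as np

# Palette: row i is the RGB color of class index i.
COLORS = np.array([
    [0, 0, 0],
    [0, 0, 255],
    [255, 0, 0]
])
# Base 256 (not 255), so that every 8-bit RGB triple gets a unique hash.
W = np.power(256, [0, 1, 2])
HASHES = np.sum(W * COLORS, axis=-1)
HASH2COLOR = {int(h): c for h, c in zip(HASHES, COLORS)}
HASH2IDX = {int(h): i for i, h in enumerate(HASHES)}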
def rgb2index(segmentation_rgb):
    """
    Turn a 3-channel RGB color image into a 1-channel index image.
    """
    s_shape = segmentation_rgb.shape
    # One integer hash per pixel, shape (height, width).
    s_hashes = np.sum(W * segmentation_rgb, axis=-1)
    # apply_along_axis passes length-1 arrays, so take element 0 before the lookup.
    func = lambda x: HASH2IDX[int(x[0])]
    segmentation_idx = np.apply_along_axis(func, 0, s_hashes.reshape((1, -1)))
    segmentation_idx = segmentation_idx.reshape(s_shape[:2])
    return segmentation_idx
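# 3x3 test image: every row contains the three palette colors in order.
segmentation = np.array([[0, 0, 0], [0, 0, 255], [255, 0, 0]] * 3).reshape((3, 3, 3))
rgb2index(segmentation)  # -> array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])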
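
The per-pixel dictionary lookup can get slow on large label images. As a rough sketch of the same hashing idea without a Python-level loop, the hashes can be sorted once and looked up with `np.searchsorted`. The names `rgb2index_vectorized`, `ORDER` and `SORTED_HASHES` are my own and not part of the snippet above, and I am assuming 8-bit colors and base-256 weights so that every color has a unique hash:

```python
import numpy as np

# Same palette as in the snippet above; row i is the color of class i.
COLORS = np.array([
    [0, 0, 0],
    [0, 0, 255],
    [255, 0, 0],
])
# Assumption: base 256, so every 8-bit RGB triple maps to a distinct integer.
W = np.power(256, [0, 1, 2])
HASHES = np.sum(W * COLORS, axis=-1)

# Sort the palette hashes once so np.searchsorted can act as a lookup table.
ORDER = np.argsort(HASHES)
SORTED_HASHES = HASHES[ORDER]


def rgb2index_vectorized(segmentation_rgb):
    """Hypothetical variant: (height, width, 3) RGB labels -> (height, width) indices."""
    s_hashes = np.sum(W * segmentation_rgb.astype(np.int64), axis=-1)
    # Position of every pixel hash within the sorted palette hashes.
    pos = np.searchsorted(SORTED_HASHES, s_hashes)
    pos = np.clip(pos, 0, len(SORTED_HASHES) - 1)
    # Fail loudly on colors outside the palette instead of mislabeling them.
    if not np.array_equal(SORTED_HASHES[pos], s_hashes):
        raise ValueError("image contains a color that is not in COLORS")
    # Map sorted positions back to the original class indices.
    return ORDER[pos]


segmentation = np.array([[0, 0, 0], [0, 0, 255], [255, 0, 0]] * 3).reshape((3, 3, 3))
print(rgb2index_vectorized(segmentation))
```

For the 3x3 test image this should give the same index image as `rgb2index`, with each row equal to `[0, 1, 2]`.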
[Image: example plot]
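
For plotting, it can help to go the other way and turn an index image back into RGB. A minimal sketch, assuming the same palette as above: `index2rgb` is a hypothetical helper name of mine, and it relies only on NumPy integer-array indexing.

```python
import numpy as np

# Same palette as above; row i is the color of class i.
COLORS = np.array([
    [0, 0, 0],
    [0, 0, 255],
    [255, 0, 0],
], dtype=np.uint8)


def index2rgb(segmentation_idx):
    """Hypothetical helper: (height, width) index image -> (height, width, 3) RGB."""
    # Integer-array indexing picks row COLORS[i] for every pixel labeled i.
    return COLORS[segmentation_idx]


segmentation_idx = np.array([[0, 1, 2], [0, 1, 2], [0, 1, 2]])
segmentation_rgb = index2rgb(segmentation_idx)
print(segmentation_rgb.shape)  # (3, 3, 3)
```

The resulting array can be passed to an image viewer to check the labels visually.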
The code is also available here: https://github.com/theRealSuperMario/supermariopy/blob/dev/scripts/rgb2labels.py