Python implementation of [Colorization using Optimization]
Colorization using optimization
This is a Python implementation of the paper "Colorization using Optimization" (Anat Levin, Dani Lischinski and Yair Weiss). The idea is that neighboring pixels in a photo should have similar colors if their intensity levels are close. Thus it is possible to colorize a black-and-white photo from just a few color hints.
\[J \big( U \big) = \sum_r \Big( U\big(r\big) - \sum_{s\in N(r)} w_{rs} U(s) \Big)^2\]
r: a pixel (x,y)
s: neighboring pixels of point r.
$ w_{rs} $ : weight between points r and s.
$ U \big( r \big) $ : chrominance channel “U” in YUV color space of pixel r
If two neighboring pixels have similar intensity (channel Y), we assume they are more likely to have similar color (channels U and V); otherwise, their colors should be less similar. To express this as a weight, we use the following affinity function (equation 2 in the paper).
\[w_{rs} \propto \exp \Big( -\frac{\big( Y(r) - Y(s) \big)^2 }{2 \sigma_r^2} \Big)\]
$ Y \big( r \big) $ : intensity value (channel Y) in YUV color space of pixel r.
$ \sigma_r^2 $ : variance of the intensities in the window around r.
# import packages
import colorsys
import logging

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.sparse
import scipy.sparse.linalg
# NOTE: scipy.misc.imread was removed in SciPy 1.2; this import needs an older
# SciPy (with Pillow installed) to work.
from scipy.misc import imread
np.set_printoptions(precision=8, suppress=True)

# set the photo file path
path_pic = '/Users/larrysu/repos/mconda3/cv01/ex_1.bmp'
path_pic_marked = '/Users/larrysu/repos/mconda3/cv01/ex_1_marked.bmp'

# window width
wd_width = 1

# Load the original photo and the copy carrying the colour hints, then divide
# by 255 (assumes 8-bit channel values, giving floats in [0, 1] - TODO confirm).
pic_o_rgb = imread(path_pic)
pic_o = pic_o_rgb.astype(np.float64) / 255
pic_m_rgb = imread(path_pic_marked)
pic_m = pic_m_rgb.astype(np.float64) / 255

# Show both inputs side by side.
fig = plt.figure()
_panels = (('Black & White', pic_o), ('Color Hints', pic_m))
for _pos, (_title, _img) in enumerate(_panels, start=1):
    fig.add_subplot(1, 2, _pos).set_title(_title)
    imgplot = plt.imshow(_img)
plt.show()
With some human intervention, an image with “color hints” can be used to generate colors for all pixels in the photo. Before starting, we need to prepare some useful helpers:
# the window class, find the neighbor pixels around the center.
class WindowNeighbor:
    """Square window around one pixel, listing the other pixels inside it.

    Attributes:
        center: ``[row, col, Y]`` of the window centre, where ``Y`` is
            ``pic[row, col, 0]``.
        width: half-width of the window; the window spans up to
            ``2 * width + 1`` pixels per axis and is clipped at the image
            border.
        neighbors: list of ``[row, col, Y]`` for every pixel in the window
            except the centre.
        mean, var: initialised to ``None``; nothing in this class fills them.
    """

    def __init__(self, width, center, pic):
        """Build the window of half-width *width* around *center* in *pic*.

        Args:
            width: half-width of the window in pixels.
            center: any indexable whose first two items are ``row, col``
                (tuple, list, or a full ``[row, col, Y]`` triple).
            pic: array indexed as ``pic[row, col, channel]``; channel 0 is
                read as the intensity.
        """
        row, col = center[0], center[1]
        # Index with explicit scalars. ``pic[center]`` is only correct when
        # ``center`` is a tuple; a list would be treated as a fancy index
        # along the first axis and return the wrong values (or raise).
        self.center = [row, col, pic[row, col, 0]]
        self.width = width
        self.neighbors = None
        self.find_neighbors(pic)
        self.mean = None
        self.var = None

    def find_neighbors(self, pic):
        """Populate ``self.neighbors`` with the in-bounds window pixels."""
        row, col = self.center[0], self.center[1]
        # Clip the window to the image so border pixels get fewer neighbours.
        ix_r_min = max(0, row - self.width)
        ix_r_max = min(pic.shape[0], row + self.width + 1)
        ix_c_min = max(0, col - self.width)
        ix_c_max = min(pic.shape[1], col + self.width + 1)
        self.neighbors = [
            [r, c, pic[r, c, 0]]
            for r in range(ix_r_min, ix_r_max)
            for c in range(ix_c_min, ix_c_max)
            if not (r == row and c == col)  # the centre is not its own neighbour
        ]

    def __str__(self):
        return 'windows c=(%d, %d, %f) size: %d' % (
            self.center[0], self.center[1], self.center[2], len(self.neighbors))
# affinity functions, calculate weights of pixels in a window by their intensity.
def affinity_a(w):
    """Return the negated, normalised affinity weights of a window.

    Each neighbour ``s`` of the window centre ``r`` gets
    ``exp(-(Y(s) - Y(r))**2 / (2 * var))``, where ``var`` is the variance of
    all intensities in the window (centre included, floored at 1e-6). The
    weights are divided by their sum and negated, so they add up to -1.

    Args:
        w: object with ``neighbors`` (sequence of ``[row, col, Y]``) and
            ``center`` (``[row, col, Y]``), e.g. a ``WindowNeighbor``.

    Returns:
        Float array of shape ``(len(w.neighbors), 3)`` whose rows are
        ``[row, col, weight]``. Shape ``(0, 3)`` when there are no neighbours.
    """
    # Force a float array: with integer intensities NumPy would otherwise
    # build an integer array and truncate the fractional weights written
    # back into column 2.
    nbs = np.array(w.neighbors, dtype=float)
    if nbs.size == 0:
        # A window without neighbours (e.g. a 1x1 image) has no weights.
        return nbs.reshape(0, 3)

    sY = nbs[:, 2]
    cY = w.center[2]
    diff = sY - cY

    # Variance of the window intensities; the floor avoids dividing by zero
    # in perfectly flat regions.
    sig = np.var(np.append(sY, cY))
    if sig < 1e-6:
        sig = 1e-6

    wrs = np.exp(-np.power(diff, 2) / (sig * 2.0))
    # Negative sign: the weights are stored as the off-diagonal entries of a
    # row whose diagonal entry is +1.
    nbs[:, 2] = -wrs / np.sum(wrs)
    return nbs
# translate (row,col) to/from sequential number
def to_seq(r, c, rows):
    """Return the column-major sequential index of pixel ``(r, c)``.

    Args:
        r: row index.
        c: column index.
        rows: number of rows in the image.
    """
    column_offset = rows * c
    return column_offset + r
def fr_seq(seq, rows):
    """Return ``(row, col)`` for a column-major sequential index.

    Args:
        seq: sequential index equal to ``col * rows + row``.
        rows: number of rows in the image.

    Returns:
        Tuple ``(r, c)`` with ``r = seq % rows`` and ``c`` as an ``int``.
    """
    r = seq % rows
    # Floor division stays exact for arbitrarily large integer indices;
    # true division goes through a float and can round the column.
    c = int(seq // rows)
    return (r, c)
# combine 3 channels of YUV to a RGB photo: n x n x 3 array
def yuv_channels_to_rgb(cY, cU, cV, shape=None):
    """Convert three flat channel vectors into a ``rows x cols x 3`` RGB image.

    The conversion uses ``colorsys.yiq_to_rgb`` per pixel. The flat vectors
    are interpreted in column-major order (``order='F'``).

    Args:
        cY: flat sequence of intensity values.
        cU: flat sequence of first chrominance values, same length as *cY*.
        cV: flat sequence of second chrominance values, same length as *cY*.
        shape: optional ``(rows, cols)`` of the output image. When omitted,
            the module-level ``pic_rows`` and ``pic_cols`` are used.

    Returns:
        Float array of shape ``(rows, cols, 3)`` holding R, G, B planes.
    """
    if shape is None:
        shape = (pic_rows, pic_cols)
    rows, cols = shape

    # Iterate over the arguments themselves, not over a global result vector.
    ansRGB = np.array(
        [colorsys.yiq_to_rgb(y, u, v) for y, u, v in zip(cY, cU, cV)])

    pic_ansRGB = np.zeros((rows, cols, 3))
    for channel in range(3):
        pic_ansRGB[:, :, channel] = ansRGB[:, channel].reshape(
            rows, cols, order='F')
    return pic_ansRGB
def init_logger():
    """Configure root logging at DEBUG level and return the root logger.

    Messages are formatted as a timestamp followed by the message text.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)-15s %(message)s',
    )
    return logging.getLogger()
Prepare the Matrix: A
The matrix A holds the weights between pixels. Because only pixels inside the same window have a non-zero weight with each other, A is a sparse matrix of size n × n (n is the total number of pixels).
log = init_logger()

(pic_rows, pic_cols, _) = pic_o.shape
pic_size = pic_rows * pic_cols

# Intensity comes from the black & white photo, chrominance from the photo
# with colour hints. colorsys works in YIQ; this file calls the channels Y/U/V.
channel_Y, _, _ = colorsys.rgb_to_yiq(pic_o[:, :, 0], pic_o[:, :, 1], pic_o[:, :, 2])
_, channel_U, channel_V = colorsys.rgb_to_yiq(pic_m[:, :, 0], pic_m[:, :, 1], pic_m[:, :, 2])

# A pixel counts as a colour hint when its chrominance is not (almost) zero.
map_colored = (abs(channel_U) + abs(channel_V)) > 0.0001

pic_yuv = np.dstack((channel_Y, channel_U, channel_V))

# Entries are [center, (row, col), value]: one matrix coefficient each.
weightData = []

# build the weight matrix for each window.
# Pixels are visited column by column, matching the ordering used by to_seq.
for c in range(pic_cols):
    for r in range(pic_rows):
        w = WindowNeighbor(wd_width, (r, c), pic_yuv)
        if not map_colored[r, c]:
            # Unhinted pixel: couple it to its neighbours (negative weights).
            weights = affinity_a(w)
            for e in weights:
                weightData.append([w.center, (e[0], e[1]), e[2]])
        # Every pixel gets a 1 on the diagonal; hinted pixels get only that.
        weightData.append([w.center, (w.center[0], w.center[1]), 1.])

sp_idx_rc_data = [
    [to_seq(e[0][0], e[0][1], pic_rows), to_seq(e[1][0], e[1][1], pic_rows), e[2]]
    for e in weightData
]
# Use a concrete integer dtype: np.integer is an abstract type and NumPy
# deprecates passing it as a dtype.
sp_idx_rc = np.array([entry[:2] for entry in sp_idx_rc_data], dtype=np.int64)
sp_data = np.array([entry[2] for entry in sp_idx_rc_data], dtype=np.float64)
matA = scipy.sparse.csr_matrix(
    (sp_data, (sp_idx_rc[:, 0], sp_idx_rc[:, 1])), shape=(pic_size, pic_size))
Vector b
Using the chrominance channels U and V we build the vector $ \vec{b} $; then we can solve the equation:
\[A \ \vec{x} = \vec{b}\]
b_u = np.zeros(pic_size)
b_v = np.zeros(pic_size)

# Column-major flat positions of the hinted pixels (same ordering as to_seq).
idx_colored = np.nonzero(np.reshape(map_colored, pic_size, order='F'))

# Copy the known chrominance of the hinted pixels into the right-hand sides;
# all other entries stay zero.
pic_u_flat = np.reshape(pic_yuv[:, :, 1], pic_size, order='F')
pic_v_flat = np.reshape(pic_yuv[:, :, 2], pic_size, order='F')
b_u[idx_colored] = pic_u_flat[idx_colored]
b_v[idx_colored] = pic_v_flat[idx_colored]
Solve the optimization problem
log.info('Optimizing Ax=b')
ansY = pic_yuv[:, :, 0].reshape(pic_size, order='F')

# Both chrominance channels share the same matrix, so pass them as the two
# columns of one right-hand side and solve in a single call instead of two.
ans_uv = scipy.sparse.linalg.spsolve(matA, np.column_stack((b_u, b_v)))
ansU = ans_uv[:, 0]
ansV = ans_uv[:, 1]

pic_ans = yuv_channels_to_rgb(ansY, ansU, ansV)
log.info('Optimized Ax=b')

# Compare the input photo with the colourised result.
fig = plt.figure()
fig.add_subplot(1, 2, 1).set_title('Black & White')
imgplot = plt.imshow(pic_o_rgb)
fig.add_subplot(1, 2, 2).set_title('Colorized')
imgplot = plt.imshow(pic_ans)
plt.show()
2017-09-07 11:51:14,236 Optimizing Ax=b
2017-09-07 11:51:15,258 Optimized Ax=b
Iterative method of optimization : jacobi
We can also implement a simple iterative method to approximate the solution. Below are the results after 50, 100 and 300 iterations; more iterations give a better result.
\[A \ x = b \\ D \ x = b - R \ x \\ \to x^{k+1} = D^{-1} \big( b - R \ x^{k} \big)\]
D: diagonal matrix of A.
R: A - D
# jacobi method for iterative optimization
def jacobi(A, b, x, n, verbose=False):
    """Run *n* Jacobi sweeps for ``A x = b`` starting from *x*.

    Each sweep computes ``x = (b - R x) / D`` where ``D`` is the diagonal of
    ``A`` and ``R = A - D``. The input *x* is not modified; a new array is
    produced by every sweep.

    Args:
        A: sparse square matrix providing ``diagonal()`` and ``dot()``.
        b: right-hand side vector.
        x: initial guess.
        n: number of sweeps; with ``n == 0`` the initial guess is returned.
        verbose: when true, log the residual norm ``||b - A x||`` after each
            sweep at DEBUG level.

    Returns:
        The iterate after *n* sweeps.
    """
    D = A.diagonal()
    R = A - scipy.sparse.diags(D)
    logger = logging.getLogger(__name__) if verbose else None
    for i in range(n):
        x = (b - R.dot(x)) / D
        if logger is not None:
            # The extra matrix-vector product is only paid for when asked.
            logger.debug('jacobi sweep %d: residual norm %g',
                         i + 1, np.linalg.norm(b - A.dot(x)))
    return x
ansY = pic_yuv[:, :, 0].reshape(pic_size, order='F')

# Each longer run continues from the previous iterate instead of restarting
# from zero. jacobi() repeats the same deterministic sweep and returns a new
# array, so 50 + 50 sweeps equal 100 sweeps and 100 + 200 equal 300, while
# the total work drops from 450 to 300 sweeps per channel.
x_start = np.zeros(matA.shape[0])
ansU050 = jacobi(matA, b_u, x=x_start, n=50)
ansV050 = jacobi(matA, b_v, x=x_start, n=50)
ansU100 = jacobi(matA, b_u, x=ansU050, n=50)
ansV100 = jacobi(matA, b_v, x=ansV050, n=50)
ansU300 = jacobi(matA, b_u, x=ansU100, n=200)
ansV300 = jacobi(matA, b_v, x=ansV100, n=200)

pic_ans050 = yuv_channels_to_rgb(ansY, ansU050, ansV050)
pic_ans100 = yuv_channels_to_rgb(ansY, ansU100, ansV100)
pic_ans300 = yuv_channels_to_rgb(ansY, ansU300, ansV300)

# Show the input next to the result after 50, 100 and 300 sweeps.
fig = plt.figure(figsize=(8, 6))
fig.add_subplot(2, 2, 1).set_title('Black & White')
imgplot = plt.imshow(pic_o_rgb)
fig.add_subplot(2, 2, 2).set_title('Loop 50')
imgplot = plt.imshow(pic_ans050)
fig.add_subplot(2, 2, 3).set_title('Loop 100')
imgplot = plt.imshow(pic_ans100)
fig.add_subplot(2, 2, 4).set_title('Loop 300')
imgplot = plt.imshow(pic_ans300)
plt.tight_layout()
plt.show()
Try other photos
# Second example: same pipeline on another pair of photos.
path_pic = '/Users/larrysu/repos/mconda3/cv01/ex_2.bmp'
path_pic_marked = '/Users/larrysu/repos/mconda3/cv01/ex_2_marked.bmp'

# window width
wd_width = 1

# The original photo is read with mode='RGB'; the hinted photo is read
# without a mode argument.
pic_o_rgb = imread(path_pic, mode='RGB')
pic_m_rgb = imread(path_pic_marked)

# Divide by 255 (assumes 8-bit channel values - TODO confirm).
pic_o = pic_o_rgb.astype(np.float64) / 255
pic_m = pic_m_rgb.astype(np.float64) / 255
# prepare matrix A
(pic_rows, pic_cols, _) = pic_o.shape
pic_size = pic_rows * pic_cols

# Intensity comes from the black & white photo, chrominance from the photo
# with colour hints. colorsys works in YIQ; this file calls the channels Y/U/V.
channel_Y, _, _ = colorsys.rgb_to_yiq(pic_o[:, :, 0], pic_o[:, :, 1], pic_o[:, :, 2])
_, channel_U, channel_V = colorsys.rgb_to_yiq(pic_m[:, :, 0], pic_m[:, :, 1], pic_m[:, :, 2])

# A pixel counts as a colour hint when its chrominance is not (almost) zero.
map_colored = (abs(channel_U) + abs(channel_V)) > 0.0001

pic_yuv = np.dstack((channel_Y, channel_U, channel_V))

# Entries are [center, (row, col), value]: one matrix coefficient each.
weightData = []

# build the weight matrix for each window.
# Pixels are visited column by column, matching the ordering used by to_seq.
for c in range(pic_cols):
    for r in range(pic_rows):
        w = WindowNeighbor(wd_width, (r, c), pic_yuv)
        if not map_colored[r, c]:
            # Unhinted pixel: couple it to its neighbours (negative weights).
            weights = affinity_a(w)
            for e in weights:
                weightData.append([w.center, (e[0], e[1]), e[2]])
        # Every pixel gets a 1 on the diagonal; hinted pixels get only that.
        weightData.append([w.center, (w.center[0], w.center[1]), 1.])

sp_idx_rc_data = [
    [to_seq(e[0][0], e[0][1], pic_rows), to_seq(e[1][0], e[1][1], pic_rows), e[2]]
    for e in weightData
]
# Use a concrete integer dtype: np.integer is an abstract type and NumPy
# deprecates passing it as a dtype.
sp_idx_rc = np.array([entry[:2] for entry in sp_idx_rc_data], dtype=np.int64)
sp_data = np.array([entry[2] for entry in sp_idx_rc_data], dtype=np.float64)
matA = scipy.sparse.csr_matrix(
    (sp_data, (sp_idx_rc[:, 0], sp_idx_rc[:, 1])), shape=(pic_size, pic_size))
# prepare vector b
b_u = np.zeros(pic_size)
b_v = np.zeros(pic_size)

# Column-major flat positions of the hinted pixels (same ordering as to_seq).
idx_colored = np.nonzero(np.reshape(map_colored, pic_size, order='F'))

# Copy the known chrominance of the hinted pixels into the right-hand sides;
# all other entries stay zero.
pic_u_flat = np.reshape(pic_yuv[:, :, 1], pic_size, order='F')
pic_v_flat = np.reshape(pic_yuv[:, :, 2], pic_size, order='F')
b_u[idx_colored] = pic_u_flat[idx_colored]
b_v[idx_colored] = pic_v_flat[idx_colored]
# optimize the problem
ansY = pic_yuv[:, :, 0].reshape(pic_size, order='F')

# Both chrominance channels share the same matrix, so pass them as the two
# columns of one right-hand side and solve in a single call instead of two.
ans_uv = scipy.sparse.linalg.spsolve(matA, np.column_stack((b_u, b_v)))
ansU = ans_uv[:, 0]
ansV = ans_uv[:, 1]

pic_ans = yuv_channels_to_rgb(ansY, ansU, ansV)

# Show the input, the colour hints and the colourised result.
fig = plt.figure(figsize=(16, 13))
fig.add_subplot(2, 2, 1).set_title('Black & White')
imgplot = plt.imshow(pic_o_rgb)
fig.add_subplot(2, 2, 2).set_title('hints')
imgplot = plt.imshow(pic_m_rgb)
fig.add_subplot(2, 2, 3).set_title('Colorized')
imgplot = plt.imshow(pic_ans)
plt.show()