Last active: December 7, 2018 21:25
-
-
Save alexlee-gk/cbc9bfa6e5be51b53c622684cec0a3f3 to your computer and use it in GitHub Desktop.
SSIM TensorFlow implementation that matches scikit-image's compare_ssim
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
from tensorflow.python.util import nest | |
def _with_flat_batch(flat_batch_fn):
    """Wrap a function that expects rank-4 `[batch, height, width, channels]` input.

    The returned function accepts `x` of shape `[..., height, width, channels]`.
    It collapses all leading dimensions into a single batch dimension and calls
    `flat_batch_fn(flat_x, *args, **kwargs)`. It then restores the original
    leading dimensions on every tensor in the (possibly nested) result.

    Args:
        flat_batch_fn: Callable taking a rank-4 tensor as its first argument and
            returning a tensor, or a nest of tensors, whose first dimension is
            the flattened batch dimension.

    Returns:
        A callable with the same extra arguments as `flat_batch_fn`.
    """
    def fn(x, *args, **kwargs):
        shape = tf.shape(x)
        flat_batch_x = tf.reshape(x, tf.concat([[-1], shape[-3:]], axis=0))
        flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs)

        def _unflatten(t):
            # Use the dynamic shape of the result: its static shape may contain
            # unknown (None) dimensions, which cannot be concatenated into a
            # shape tensor.
            return tf.reshape(t, tf.concat([shape[:-3], tf.shape(t)[1:]], axis=0))

        return nest.map_structure(_unflatten, flat_batch_r)
    return fn
def structural_similarity(X, Y, K1=0.01, K2=0.03, win_size=7,
                          data_range=1.0, use_sample_covariance=True):
    """
    Structural SIMilarity (SSIM) index between two images.

    Local statistics use a uniform `win_size` x `win_size` averaging window.
    They are computed only where the window fits entirely inside the image
    ('VALID' padding). The per-pixel SSIM map is then averaged over height,
    width and channels.

    Args:
        X: A tensor of shape `[..., in_height, in_width, in_channels]`.
        Y: A tensor of shape `[..., in_height, in_width, in_channels]`.
        K1: Stabilizing constant for the luminance term;
            C1 = (K1 * data_range) ** 2.
        K2: Stabilizing constant for the contrast/structure term;
            C2 = (K2 * data_range) ** 2.
        win_size: Side length of the square averaging window.
        data_range: Dynamic range of the input values, e.g. 1.0 for images
            in [0, 1] or 255 for 8-bit images.
        use_sample_covariance: If True, (co)variances are normalized by
            NP - 1 (sample covariance); otherwise by NP (population
            covariance, as in Wang et al. 2004). NP = win_size ** 2.

    Returns:
        The SSIM between images X and Y: a tensor whose shape is the
        broadcast of the leading `...` dimensions of X and Y.

    Reference:
        https://github.com/scikit-image/scikit-image/blob/master/skimage/measure/_structural_similarity.py

    Broadcasting is supported.
    """
    X = tf.convert_to_tensor(X)
    Y = tf.convert_to_tensor(Y)
    # The averaging kernel is created in the input dtype. For integer inputs
    # its weight would truncate to 0 and SSIM would collapse to exactly 1,
    # so promote integer images to floating point first.
    if not X.dtype.is_floating:
        X = tf.cast(X, tf.float32)
    if not Y.dtype.is_floating:
        Y = tf.cast(Y, X.dtype)

    ndim = 2  # number of spatial dimensions
    nch = tf.shape(X)[-1]

    filter_func = _with_flat_batch(tf.nn.depthwise_conv2d)
    # Float literal on purpose: under Python 2 `1 / win_size ** 2` is integer
    # division (== 0). That zeroes every local statistic and makes the
    # result identically 1. Building the weight directly in X.dtype also
    # avoids a float32 round-trip for float64 inputs.
    weight = tf.constant(1.0 / win_size ** 2, dtype=X.dtype)
    kernel = tf.fill([win_size, win_size, nch, 1], weight)
    filter_args = {'filter': kernel, 'strides': [1] * 4, 'padding': 'VALID'}

    NP = win_size ** ndim

    # filter has already normalized by NP
    if use_sample_covariance:
        cov_norm = float(NP) / (NP - 1)  # sample covariance (float division)
    else:
        cov_norm = 1.0  # population covariance to match Wang et. al. 2004

    # compute means
    ux = filter_func(X, **filter_args)
    uy = filter_func(Y, **filter_args)

    # compute variances and covariances
    uxx = filter_func(X * X, **filter_args)
    uyy = filter_func(Y * Y, **filter_args)
    uxy = filter_func(X * Y, **filter_args)
    vx = cov_norm * (uxx - ux * ux)
    vy = cov_norm * (uyy - uy * uy)
    vxy = cov_norm * (uxy - ux * uy)

    R = data_range
    C1 = (K1 * R) ** 2
    C2 = (K2 * R) ** 2

    A1, A2, B1, B2 = ((2 * ux * uy + C1,
                       2 * vxy + C2,
                       ux ** 2 + uy ** 2 + C1,
                       vx + vy + C2))
    D = B1 * B2
    S = (A1 * A2) / D

    # Mean over the (cropped) spatial map and channels.
    ssim = tf.reduce_mean(S, axis=[-3, -2, -1])
    return ssim
def main():
    """Compare `structural_similarity` with scikit-image on random images.

    Prints the batch-mean SSIM computed by `structural_similarity`, followed
    by the mean of scikit-image's per-image SSIM. The two values should
    agree closely.
    """
    import numpy as np
    try:
        # scikit-image >= 0.16; `skimage.measure.compare_ssim` was removed
        # in scikit-image 0.18.
        from skimage.metrics import structural_similarity as skimage_ssim
    except ImportError:
        from skimage.measure import compare_ssim as skimage_ssim

    def skimage_multichannel_ssim(image0, image1):
        # scikit-image's multichannel SSIM is the mean of per-channel SSIM.
        # Computing it explicitly avoids the `multichannel` -> `channel_axis`
        # keyword change between scikit-image versions.
        return np.mean([skimage_ssim(image0[..., c], image1[..., c], data_range=1.0)
                        for c in range(image0.shape[-1])])

    batch_size = 4
    image_shape = (64, 64, 3)
    images0 = np.random.random((batch_size,) + image_shape)
    images1 = np.random.random((batch_size,) + image_shape)

    ssim_tf = tf.reduce_mean(structural_similarity(images0, images1))
    if getattr(tf, 'executing_eagerly', lambda: False)():
        # Eager mode (e.g. TF 2.x): the value is already computed.
        ssim_tf = ssim_tf.numpy()
    else:
        # Graph mode (TF 1.x): the context manager closes the session.
        with tf.Session() as sess:
            ssim_tf = sess.run(ssim_tf)

    ssim_skimage = np.mean([skimage_multichannel_ssim(image0, image1)
                            for image0, image1 in zip(images0, images1)])
    print(ssim_tf, ssim_skimage)
if __name__ == '__main__':
    # Script entry point: run the TensorFlow vs. scikit-image SSIM comparison.
    main()
I agree with brunopop: ssim_tf is always 1. Could you please explain the reason?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hello, I tried running your script as-is and the numbers don't match: ssim_tf is always 1.0. Are you sure you are flattening the batch correctly in _with_flat_batch?