import numpy as np
# Seed NumPy's global RNG at import time so np.random calls made after this
# module is imported start from a fixed state.
# NOTE(review): this is an import-time side effect on the *global* RNG that
# also affects callers' random streams; presumably intended to make noise
# generation elsewhere reproducible — confirm before relying on it.
seed=42
np.random.seed(seed)
from matplotlib import pyplot as plt
from skimage.filters import threshold_mean
def reshape(data):
    """Reshape a flattened image array into a square 2D image.

    :param data: image data whose total number of elements is a perfect
        square n = d * d (typically a flattened 1D array; any array-like,
        including an already-square 2D array, is accepted)
    :type data: np.ndarray of shape (n,)
    :return: 2D image array
    :rtype: np.ndarray of shape (d, d) with d = sqrt(n)
    :raises ValueError: if the number of elements is not a perfect square
    """
    arr = np.asarray(data)
    size = arr.size
    # Round rather than truncate so floating-point error in sqrt can never
    # produce d - 1 for a perfect square; the check below rejects non-squares.
    dim = int(round(np.sqrt(size)))
    if dim * dim != size:
        raise ValueError(
            f"cannot reshape array of {size} elements into a square image"
        )
    return np.reshape(arr, (dim, dim))
def plot(data, test, predicted, figsize=(5, 5), savefig=False, *,
         filename="result.png", cmap=None):
    """Plot training images, corrupted inputs and predictions side by side.

    Each figure row shows one training image, its corrupted (test) version
    and the corresponding prediction. The figure has ``len(data)`` rows;
    rows are filled up to the shortest of the three lists.

    :param data: flattened training images; at least 1
    :type data: list of np.ndarray
    :param test: flattened corrupted (test) versions of the images
    :type test: list of np.ndarray
    :param predicted: flattened predictions for the images in ``test``
    :type predicted: list of np.ndarray
    :param figsize: size of the matplotlib figure; defaults to (5, 5)
    :type figsize: tuple
    :param savefig: whether to save the figure to ``filename``; defaults to False
    :type savefig: bool
    :param filename: output path used when ``savefig`` is True;
        defaults to "result.png"
    :type filename: str
    :param cmap: colormap passed to ``imshow``; defaults to None
        (matplotlib's configured default)
    :return: None
    """
    data = [reshape(d) for d in data]
    test = [reshape(d) for d in test]
    predicted = [reshape(d) for d in predicted]

    # squeeze=False keeps axarr 2D even when there is a single row, so
    # axarr[row, col] indexing also works for one image.
    fig, axarr = plt.subplots(len(data), 3, figsize=figsize, squeeze=False)
    titles = ('Training Image', "Noisy Input Image", 'Prediction')

    # zip stops at the shortest list instead of indexing past the end of
    # data/predicted when the lists have different lengths.
    for row, images in enumerate(zip(data, test, predicted)):
        for col, image in enumerate(images):
            ax = axarr[row, col]
            if row == 0:
                ax.set_title(titles[col])
            ax.imshow(image, cmap=cmap)
            ax.axis('off')

    fig.tight_layout()
    if savefig:
        fig.savefig(filename)
    plt.show()
def preprocessing(img):
    """Binarize a grayscale image and flatten it into a +1/-1 vector.

    Pixels strictly greater than the threshold returned by
    ``threshold_mean(img)`` become +1; all other pixels become -1.

    :param img: grayscale image array
    :type img: np.ndarray of shape (n, n)
    :return: flattened array of +1/-1 values
    :rtype: np.ndarray of shape (n^2,)
    """
    # Unpacking the shape also rejects anything that is not 2D.
    rows, cols = img.shape
    above = img > threshold_mean(img)
    # Map {False, True} -> {-1, +1}.
    bipolar = above.astype(int) * 2 - 1
    return bipolar.reshape(rows * cols)