Compressed Sensing for MRI#
Import libraries#
import numpy as np
import cv2
from matplotlib import pyplot as plt
import scipy.signal as signal
from scipy.signal import convolve2d
import scipy.fft as fft
import urllib.request
from skimage.metrics import peak_signal_noise_ratio as PSNR
import time
import nibabel as nib
np.random.seed(30)
Import image#
# Load the NIfTI volume from disk and take one 2-D slice along each of the
# three array axes (fixed indices 9, 19 and 2).
file_path = "data/0219191_mystudy-0219-1114_anat_ses-01_T1w_20190219111436_5.nii.gz"
# Alternative input kept for reference:
# file_path="/content/data/dicom_examples/nii/0219191_mystudy-0219-1114_anat_ses-01_scout_20190219111436_2_i00001.nii.gz"
t1_img = nib.load(file_path)
t1_data = t1_img.get_fdata()
x_slice, y_slice, z_slice = (
    t1_data[9, :, :],
    t1_data[:, 19, :],
    t1_data[:, :, 2],
)
import matplotlib.pyplot as plt
%matplotlib inline
# Show the three raw slices side by side.
slices = [x_slice, y_slice, z_slice]
fig, axes = plt.subplots(1, len(slices))
# Pair each axes object with its image directly; the loop variable is named
# ``img`` so that the builtin ``slice`` is not rebound at module level.
for ax, img in zip(axes, slices):
    ax.imshow(img, cmap="gray", origin="lower")
data:image/s3,"s3://crabby-images/2bcfc/2bcfc81341dde4895cabff729b1f1caf50511a86" alt="../_images/1651e2fe46a86abaaad177a6fd71170600a68f571803a305166d1c8a37958dba.png"
import numpy as np
import matplotlib.pyplot as plt
# Zero-pad x_slice, y_slice and z_slice to one common square size so that the
# same masks and FFT helpers can be applied to any of them.
slices = [x_slice, y_slice, z_slice]

# Side length of the square: the largest height or width found in any slice.
max_dim = max(dim for img in slices for dim in (img.shape[0], img.shape[1]))

# Centre each image inside a max_dim x max_dim frame of zeros. When the
# difference is odd, the extra row/column goes to the bottom/right.
padded_slices = []
for img in slices:
    height_diff = max_dim - img.shape[0]
    width_diff = max_dim - img.shape[1]
    pad_top = height_diff // 2
    pad_left = width_diff // 2
    padded_slices.append(
        np.pad(
            img,
            ((pad_top, height_diff - pad_top), (pad_left, width_diff - pad_left)),
            mode="constant",
        )
    )

# Plot the padded images without axis decorations.
fig, axes = plt.subplots(1, len(padded_slices))
for ax, img in zip(axes, padded_slices):
    ax.imshow(img, cmap="gray", origin="lower")
    ax.axis("off")
data:image/s3,"s3://crabby-images/ee368/ee368d90b9df67bcb7355610c817acce1c37ffaa" alt="../_images/b15f1a1b93f85baf962d0e862c1666c21fbac79a99d59148a8910777b796e2fd.png"
Define conv and fft functions#
def to_fourier_domain(x):
    """Return the centered 2-D DFT of ``x``.

    The input is treated as having its origin at the array centre: it is
    first un-shifted so the origin sits at index (0, 0), transformed with
    ``fft.fft2``, and the result is shifted so that the zero-frequency
    component ends up at the centre of the returned array.
    """
    origin_at_corner = fft.ifftshift(x)
    spectrum = fft.fft2(origin_at_corner)
    return fft.fftshift(spectrum)
def to_image_domain(x):
    """Return the centered 2-D inverse DFT of ``x``.

    This is the inverse of a centered forward transform of the form
    ``fftshift(fft2(ifftshift(img)))``: the centered spectrum is un-shifted
    with ``ifftshift``, inverse-transformed with ``fft.ifft2``, and the image
    is shifted back with ``fftshift``.

    The shift order matters: ``fftshift`` and ``ifftshift`` coincide only for
    even-length axes, so using them the other way round breaks the round trip
    whenever an axis has odd length.

    Returns a complex array of the same shape as ``x``.
    """
    return fft.fftshift(fft.ifft2(fft.ifftshift(x)))
# Define some of the operators that we need...
def conv2d_fft(x, h):
    """Filter ``x`` with kernel ``h`` by multiplying their centered spectra.

    ``h`` is zero-padded to the shape of ``x`` (kernel placed around the
    centre of the frame), both arrays are moved to the Fourier domain with
    ``to_fourier_domain``, multiplied element-wise and transformed back with
    ``to_image_domain``. Because the product is taken between DFTs, the
    filtering is circular (periodic boundary).

    Parameters
    ----------
    x : 2-D array, the image.
    h : 2-D array, the kernel; must not exceed ``x`` along either axis.

    Returns
    -------
    Complex array with the shape of ``x``.

    Raises
    ------
    ValueError
        If ``h`` is larger than ``x`` along an axis.
    """
    p0 = x.shape[0] - h.shape[0]
    p1 = x.shape[1] - h.shape[1]
    if p0 < 0 or p1 < 0:
        raise ValueError("kernel h must not be larger than x along any axis")
    # Leading pad stays p // 2 (unchanged placement for even differences);
    # the trailing pad takes the remainder so that an odd difference still
    # yields a kernel of exactly x.shape.
    h_pad = np.pad(h, ((p0 // 2, p0 - p0 // 2), (p1 // 2, p1 - p1 // 2)))
    Fh = to_fourier_domain(h_pad)
    Fx = to_fourier_domain(x)
    return to_image_domain(Fx * Fh)
def conv2dT_fft(x, h):
    """Apply the adjoint of the FFT-based filtering with kernel ``h`` to ``x``.

    Same construction as filtering with ``h`` in the Fourier domain, except
    that the spectrum of ``x`` is multiplied by the complex conjugate of the
    kernel spectrum.

    Parameters
    ----------
    x : 2-D array, the image.
    h : 2-D array, the kernel; must not exceed ``x`` along either axis.

    Returns
    -------
    Complex array with the shape of ``x``.

    Raises
    ------
    ValueError
        If ``h`` is larger than ``x`` along an axis.
    """
    p0 = x.shape[0] - h.shape[0]
    p1 = x.shape[1] - h.shape[1]
    if p0 < 0 or p1 < 0:
        raise ValueError("kernel h must not be larger than x along any axis")
    # Leading pad stays p // 2; the trailing pad takes the remainder so an
    # odd size difference still produces a kernel of exactly x.shape.
    h_pad = np.pad(h, ((p0 // 2, p0 - p0 // 2), (p1 // 2, p1 - p1 // 2)))
    Fh = to_fourier_domain(h_pad)
    Fx = to_fourier_domain(x)
    return to_image_domain(Fx * np.conj(Fh))
Create the binary mask#
According to the fastMRI paper
def generate_random_mask(img_shape, center_percentage=8, random_fraction=0.25):
    """Build a column-wise random undersampling mask.

    A contiguous band of columns around the middle of axis 1 is always kept,
    and additional columns are drawn uniformly at random (without
    replacement, using NumPy's global random state) from the columns outside
    that band.

    Parameters
    ----------
    img_shape : sequence of int
        Shape of the mask; columns are indexed along ``img_shape[1]``.
    center_percentage : float, default 8
        Percentage (0-100) of the columns that form the always-sampled
        central band.
    random_fraction : float, default 0.25
        Fraction of ``img_shape[1]`` to draw additionally from outside the
        central band. Note that the total sampled fraction is the central
        band plus this value.

    Returns
    -------
    ndarray of int
        Array of shape ``img_shape`` with 1 in sampled columns, 0 elsewhere.

    Raises
    ------
    ValueError
        Raised by ``np.random.choice`` when more random columns are requested
        than exist outside the central band.
    """
    n_cols = img_shape[1]
    mask = np.zeros(img_shape)

    # Fully sampled central band.
    center_columns = int(n_cols * center_percentage / 100)
    center_start = int(n_cols / 2) - int(center_columns / 2)
    center_end = center_start + center_columns
    mask[:, center_start:center_end] = 1

    # Candidate columns are everything left and right of the central band.
    outer_cols = np.concatenate(
        (np.arange(0, center_start), np.arange(center_end, n_cols))
    )
    num_random_cols = int(random_fraction * n_cols)
    random_cols = np.random.choice(outer_cols, num_random_cols, replace=False)
    mask[:, random_cols] = 1

    return mask.astype(int)
rmask = generate_random_mask(
    x_slice.shape
)  # mask with the shape of x_slice, using the default 8% central band plus random outer columns
# Display the random sampling mask on its own
fig, axs = plt.subplots(1, 1, figsize=(10, 5))
axs.imshow(rmask, cmap="gray")
axs.set_title("Mask")
plt.show()
data:image/s3,"s3://crabby-images/999fe/999fee707171b27e970752e3b993b3e02eb08696" alt="../_images/b630439180f9953ec5e89ff8497f42cb75b8777480ae9221ed1c103f267c751a.png"
# NOTE: the mask is built along axis 0 of ``img_shape`` and then transposed,
# so the returned array has shape (img_shape[1], img_shape[0]). For a square
# shape this makes no difference; otherwise pass the transposed target shape.
def generate_equidistant_mask(img_shape, acceleration_factor=4, center_percentage=25):
    """Build an equidistant line-undersampling mask with a full central band.

    Every ``acceleration_factor``-th line along axis 0 is kept, together with
    a contiguous band of lines centred on axis 0. The result is transposed
    before being returned, so the sampled lines appear as columns.

    Parameters
    ----------
    img_shape : sequence of int
        Shape used to build the mask before transposition.
    acceleration_factor : int, default 4
        Step between regularly sampled lines.
    center_percentage : float, default 25
        Percentage (0-100) of the lines along axis 0 forming the fully
        sampled central band. The band size is
        ``int(img_shape[0] * center_percentage / 100)``, the same rule as the
        random mask uses, so any percentage (including 0) is honoured.

    Returns
    -------
    ndarray of float32
        Transposed mask with 1.0 on sampled lines and 0.0 elsewhere.
    """
    mask = np.zeros(img_shape, dtype=np.float32)
    mask[::acceleration_factor] = 1.0

    # Size the central band directly from the percentage. Going through an
    # integer ``100 // center_percentage`` would quantise the fraction
    # (e.g. 30% -> 1/3, 60% -> everything) and divide by zero for 0 or >100.
    center_lines = int(img_shape[0] * center_percentage / 100)
    start = (img_shape[0] - center_lines) // 2
    end = start + center_lines
    mask[start:end] = 1.0

    return mask.T
# 4-fold equidistant mask with an 8% central band; note that the function
# returns the mask transposed relative to the shape passed in.
eq_mask = generate_equidistant_mask(
    x_slice.shape, acceleration_factor=4, center_percentage=8
)
# Display the equidistant sampling mask on its own
fig, axs = plt.subplots(1, 1, figsize=(10, 5))
axs.imshow(eq_mask, cmap="gray")
axs.set_title("Mask")
plt.show()
data:image/s3,"s3://crabby-images/201b3/201b32ad24bed1f78cb9d837304b04198ec064ac" alt="../_images/ea4e6ad0dfc2127c4fd5cc004f306a7e0c36dd60901bb12ccf0f11d7d046814a.png"
def get_log_spectrum(fft_data):
    """Return ``20 * ln(|fft_data| + 1)`` for display purposes.

    The magnitude is offset by one before taking the natural logarithm, so a
    zero coefficient maps to zero instead of minus infinity.
    """
    magnitude = np.abs(fft_data)
    log_magnitude = np.log(magnitude + 1)
    return 20 * log_magnitude
# Input image for processing: the zero-padded (square) third slice
x = padded_slices[2]
# Equidistant mask. generate_equidistant_mask returns its result transposed,
# so the shape of x.T is passed to obtain a mask with the shape of x.
eq_mask = generate_equidistant_mask(
    x.T.shape, acceleration_factor=4, center_percentage=8
)
# Random mask with the default central band and random-column settings
rmask = generate_random_mask(x.shape)
# Compute the quantities shown in the comparison figure
def param(x, mask):
    """Return the spectrum, masked spectrum and zero-filled reconstruction.

    Parameters
    ----------
    x : 2-D array
        Image; it is transposed before being transformed.
    mask : 2-D array
        Sampling mask, multiplied element-wise with the centered spectrum of
        ``x.T`` (so it must broadcast against that shape).

    Returns
    -------
    spectrum_image : centered spectrum of ``x.T``.
    masked_spectrum : ``spectrum_image * mask``.
    zero_filled_ifft : inverse transform of ``masked_spectrum`` (complex).
    """
    spectrum_image = to_fourier_domain(x.T)
    masked_spectrum = spectrum_image * mask
    # The mask is applied exactly once: multiplying by it again is a no-op
    # for a 0/1 mask and would make the reconstruction disagree with the
    # returned masked_spectrum for any non-binary mask.
    zero_filled_ifft = to_image_domain(masked_spectrum)
    return spectrum_image, masked_spectrum, zero_filled_ifft
# Spectra and zero-filled reconstructions for the equidistant and random masks
spectrum_image, masked_spectrum, zero_filled_ifft = param(x, eq_mask)
spectrum_image_r, masked_spectrum_r, zero_filled_ifft_r = param(x, rmask)

# One figure row per mask: original, mask, full log-spectrum, masked
# log-spectrum and magnitude of the zero-filled inverse transform.
fig, axs = plt.subplots(2, 5, figsize=(15, 6))
figure_rows = [
    [
        (x.T, "Original MRI"),
        (eq_mask, "Equidistant 4 fold Mask"),
        (get_log_spectrum(spectrum_image), "log spectrumFx"),
        (get_log_spectrum(masked_spectrum), "log spectrum MFx"),
        (np.abs(zero_filled_ifft), "zero filled ifft"),
    ],
    [
        (x.T, "Original MRI"),
        (rmask, "Random 4-fold Mask"),
        (get_log_spectrum(spectrum_image_r), "log spectrumFx"),
        (get_log_spectrum(masked_spectrum_r), "log spectrum MFx"),
        (np.abs(zero_filled_ifft_r), "zero filled ifft"),
    ],
]
for row_axes, row_panels in zip(axs, figure_rows):
    for ax, (panel_image, panel_title) in zip(row_axes, row_panels):
        ax.imshow(panel_image, cmap="gray")
        ax.set_title(panel_title)
        ax.axis("off")
plt.show()
data:image/s3,"s3://crabby-images/d47eb/d47ebbb7c0cb19246062b2200094c430eb4211e9" alt="../_images/ee6a77467e460beb8b6ebcdc7056c55c75ef804a381ccd178de6cffb1385eaab.png"
Gradient operator#
# Finite-difference gradient filters and the operators built from them.
# (Original author's note: "I am testing it with y1 for now".)
dh = np.array([[1, -1], [0, 0]])  # horizontal gradient filter
dv = np.array([[1, 0], [-1, 0]])  # vertical gradient filter


def Dh(x):
    """Horizontal gradient of ``x`` (FFT-based filtering with ``dh``)."""
    return conv2d_fft(x, dh)


def Dv(x):
    """Vertical gradient of ``x`` (FFT-based filtering with ``dv``)."""
    return conv2d_fft(x, dv)


def DhT(x):
    """Adjoint of the horizontal gradient applied to ``x``."""
    return conv2dT_fft(x, dh)


def DvT(x):
    """Adjoint of the vertical gradient applied to ``x``."""
    return conv2dT_fft(x, dv)
# Show the image next to the magnitudes of its horizontal and vertical gradients
fig = plt.figure(figsize=(15, 15))
gradient_panels = [
    (131, x.T, "image x"),
    (132, np.abs(Dh(x.T)), r"$|D_hx|$"),
    (133, np.abs(Dv(x.T)), r"$|D_vx|$"),
]
for subplot_code, panel_image, panel_title in gradient_panels:
    plt.subplot(subplot_code)
    plt.imshow(panel_image, cmap="gray")
    plt.title(panel_title)
plt.tight_layout()
plt.show()
data:image/s3,"s3://crabby-images/3b0f0/3b0f02866b198022fcdb20fa9354fa532f4c012a" alt="../_images/e3c6f587185ed44713024a389fa50271b7ea45a5c8a85d83fd1b94fb3ef783fe.png"
TV Primal Dual#
TV Compressed Sensing MRI using Primal-Dual
def TV_MRI(MFx, M, lamb=2, maxiter=50, tol=1e-4):
    """Reconstruct an image from masked k-space data with TV regularisation.

    Runs a primal-dual iteration aimed at a problem of the form

        minimize_z  0.5 * |M F z - MFx|_2^2 + lamb * |D z|_1

    where ``F`` is ``to_fourier_domain``, ``M`` the sampling mask and
    ``D = (Dh, Dv)`` the module-level gradient operators (``DhT``/``DvT``
    are their adjoints).

    Each iteration performs
        z^{k+1} = prox_{tau g}(z^k - tau * D^T y^k)          (primal)
        y^{k+1} = prox_{sigma f*}(y^k + sigma * D(2 z^{k+1} - z^k))  (dual)

    Parameters
    ----------
    MFx : ndarray
        Masked, centered spectrum of the image (the measurements).
    M : ndarray
        Sampling mask, broadcastable against ``MFx``.
        NOTE(review): the primal step assumes a 0/1 mask - confirm if
        weighted masks are ever passed.
    lamb : float, default 2
        Regularisation weight; the dual variables are kept within
        magnitude ``lamb``.
    maxiter : int, default 50
        Maximum number of iterations.
    tol : float, default 1e-4
        The loop stops early once ``sum(|z^{k+1} - z^k|)`` drops below it.

    Returns
    -------
    z_hat : ndarray
        Image-domain estimate (complex; it is the output of
        ``to_image_domain``).
    J : ndarray
        Length-``maxiter`` array with the per-iteration change
        ``sum(|z^{k+1} - z^k|)``; entries after an early stop stay 0.
    elapsed : float
        Run time in seconds.
    """
    start = time.perf_counter()

    # Step sizes chosen so that tau * sigma * L^2 < 1.
    # note: PDS seems sensitive to these (given finite iterations at least...)
    L = np.sqrt(8)  # bound used for the spectral norm of D
    tao = 0.99 / L
    sigma = 0.99 / L

    def prox_D(v, z):
        """Primal prox: closed-form data-fidelity update in the Fourier domain."""
        return to_image_domain((to_fourier_domain(v) + tao * z) / (1 + tao * M))

    def prox_A(v):
        """Dual prox: project every entry onto the disc ``|y| <= lamb``.

        The iterates are complex, so the projection has to rescale by the
        modulus. ``np.clip`` is not suitable here: on complex arrays it
        orders values lexicographically (real part first), which neither
        bounds the imaginary part nor preserves the phase. For real input
        this projection coincides with clipping to ``[-lamb, lamb]``.
        """
        magnitude = np.abs(v)
        # scale = lamb / |v| where |v| > lamb, 1 elsewhere (no division by 0).
        scale = np.divide(
            lamb, magnitude, out=np.ones_like(magnitude), where=magnitude > lamb
        )
        return v * scale

    # Iteration variables: primal image estimate and the two dual fields.
    z_hat = np.zeros_like(MFx)
    yh_hat = np.zeros_like(MFx)
    yv_hat = np.zeros_like(MFx)

    # Per-iteration change of the primal variable (convergence measure).
    J = np.zeros(maxiter)

    for k in range(maxiter):
        # Primal update.
        z_old = z_hat
        z_hat = prox_D(z_hat - tao * (DhT(yh_hat) + DvT(yv_hat)), MFx)

        # Dual update, evaluated at the extrapolated point 2 z^{k+1} - z^k.
        z_bar = 2 * z_hat - z_old
        yh_hat = prox_A(yh_hat + sigma * Dh(z_bar))
        yv_hat = prox_A(yv_hat + sigma * Dv(z_bar))

        J[k] = np.abs(z_hat - z_old).sum()
        if J[k] < tol:
            break

    return z_hat, J, time.perf_counter() - start
# TV reconstruction from the equidistantly undersampled spectrum
x_t, j, duration1 = TV_MRI(masked_spectrum, eq_mask, lamb=5, maxiter=500, tol=1e-2)
# NOTE(review): x_t is already an image-domain array (TV_MRI returns the output
# of to_image_domain), and x_t_ifft is not used in the lines visible here.
x_t_ifft = to_image_domain(x_t)
# Panels are placed on a 3x2 grid of which only the first four cells are used.
# Each plt.title call applies to the axes created just before it.
fig = plt.figure()
fig.set_size_inches(5, 7.5)
ax = fig.add_subplot(321)
ax.imshow(x.T, cmap="gray")
plt.title("original image")
ax2 = fig.add_subplot(322)
ax2.imshow(np.abs(zero_filled_ifft), cmap="gray")
plt.title("Corrupted image")
ax3 = fig.add_subplot(323)
ax3.imshow(np.abs(x_t), cmap="gray")
plt.title("recovered image")
# Per-iteration change of the estimate on a log scale
ax4 = fig.add_subplot(324)
ax4.semilogy(range(len(j)), j, "b-", lw=2)
plt.title("Residual")
plt.xlabel("iteration (k)")
plt.tight_layout()
plt.show()
print(f"Time taken = {duration1}")
data:image/s3,"s3://crabby-images/cb6ad/cb6addff6b39274e145b19f8640ee7093a5d64da" alt="../_images/bf874c580efab43ed5f882400e6a29460521afcacd32585f56a9149eb3c2d7f0.png"
Time taken = 7.693760633468628
# PSNR of the zero-filled reconstruction (equidistant mask) against the original image
PSNR(x.T, np.abs(zero_filled_ifft), data_range=np.max(np.abs(x)))
28.910346296010466
# PSNR of the TV reconstruction (equidistant mask) against the original image
PSNR(x.T, np.abs(x_t), data_range=np.max(np.abs(x)))
31.11096519230359
# TV reconstruction from the randomly undersampled spectrum
x_tr, jr, duration1r = TV_MRI(masked_spectrum_r, rmask, lamb=5, maxiter=500, tol=1e-2)
# Panels are placed on a 3x2 grid of which only the first four cells are used.
# Each plt.title call applies to the axes created just before it.
fig = plt.figure()
fig.set_size_inches(5, 7.5)
ax = fig.add_subplot(321)
ax.imshow(x.T, cmap="gray")
plt.title("original image")
ax2 = fig.add_subplot(322)
ax2.imshow(np.abs(zero_filled_ifft_r), cmap="gray")
plt.title("Corrupted image")
ax3 = fig.add_subplot(323)
ax3.imshow(np.abs(x_tr), cmap="gray")
plt.title("recovered image")
# Per-iteration change of the estimate on a log scale
ax4 = fig.add_subplot(324)
ax4.semilogy(range(len(jr)), jr, "b-", lw=2)
plt.title("Residual")
plt.xlabel("iteration (k)")
plt.tight_layout()
plt.show()
print(f"Time taken = {duration1r}")
data:image/s3,"s3://crabby-images/50a23/50a2376a531d567c6ff0250b82868b3b096f29d1" alt="../_images/cda146d2bf397fca286d2979e31ed1d6d68feffee9cb2ddb7c0a58ef4a8c0747.png"
Time taken = 7.645873546600342
# PSNR of the zero-filled reconstruction (random mask) against the original image
PSNR(x.T, np.abs(zero_filled_ifft_r), data_range=np.max(np.abs(x)))
28.999801885609365
# PSNR of the TV reconstruction (random mask) against the original image
PSNR(x.T, np.abs(x_tr), data_range=np.max(np.abs(x)))
33.28850707391314
# Overlay the per-iteration residuals of both reconstructions on one log plot
fig, axs = plt.subplots(1, 1, figsize=(15, 6))
residual_curves = (
    (jr, "b-", "Random Mask"),
    (j, "r+", "Equidistant Mask"),
)
for residual, line_format, curve_label in residual_curves:
    axs.semilogy(range(len(residual)), residual, line_format, lw=2, label=curve_label)
plt.title("Residual")
plt.legend()
<matplotlib.legend.Legend at 0x7f3062d97850>
data:image/s3,"s3://crabby-images/98ef5/98ef5830d8cfa8d280e4d5abbd9907a121a86317" alt="../_images/44f0ad5ecc2870c92f4b2702d66101ee6553083e2ae133a72027ac25b3787639.png"