{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "# Compressed Sensing for MRI\n", "\n", "\"Open" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L56jXJk2NckF", "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import cv2\n", "from matplotlib import pyplot as plt\n", "import scipy.signal as signal\n", "from scipy.signal import convolve2d\n", "import scipy.fft as fft\n", "import urllib.request\n", "from skimage.metrics import peak_signal_noise_ratio as PSNR\n", "import time\n", "import nibabel as nib\n", "\n", "np.random.seed(30)" ] }, { "cell_type": "markdown", "metadata": { "id": "EjhU0Iub2YAF" }, "source": [ "## Import image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "2E2rA4u3zP4y", "tags": [] }, "outputs": [], "source": [ "file_path = \"data/0219191_mystudy-0219-1114_anat_ses-01_T1w_20190219111436_5.nii.gz\"\n", "# file_path=\"/content/data/dicom_examples/nii/0219191_mystudy-0219-1114_anat_ses-01_scout_20190219111436_2_i00001.nii.gz\"\n", "\n", "t1_img = nib.load(file_path)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xuKAVYVHzXl0", "tags": [] }, "outputs": [], "source": [ "t1_data = t1_img.get_fdata()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wco86fxqzZ8w", "tags": [] }, "outputs": [], "source": [ "x_slice = t1_data[9, :, :]\n", "y_slice = t1_data[:, 19, :]\n", "z_slice = t1_data[:, :, 2]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "ISLUZAkxzcqO", "outputId": "59a6b0da-1131-4fa1-cf56-b30c29eec327", "tags": [] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", "\n", "slices = [x_slice, y_slice, z_slice]\n", "\n", "fig, axes = plt.subplots(1, len(slices))\n", "for i, slice in enumerate(slices):\n", " axes[i].imshow(slice, cmap=\"gray\", origin=\"lower\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 182 }, "id": "vnMMImDjCYlF", "outputId": "20ef1540-3372-43ed-d115-df820c29675e", "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "# Assuming x_slice, y_slice, and z_slice are the input images\n", "\n", "slices = [x_slice, y_slice, z_slice]\n", "\n", "# Find the maximum dimension among all images\n", "max_dim = max(\n", " [slice.shape[0] for slice in slices] + [slice.shape[1] for slice in slices]\n", ")\n", "\n", "# Create a new list to store the padded images\n", "padded_slices = []\n", "\n", "# Pad zeros to make each image square\n", "for slice in slices:\n", " height_diff = max_dim - slice.shape[0]\n", " width_diff = max_dim - slice.shape[1]\n", " pad_top = height_diff // 2\n", " pad_bottom = height_diff - pad_top\n", " pad_left = width_diff // 2\n", " pad_right = width_diff - pad_left\n", " padded_slice = np.pad(\n", " slice, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=\"constant\"\n", " )\n", " padded_slices.append(padded_slice)\n", "\n", "# Plot the padded images\n", "fig, axes = plt.subplots(1, len(padded_slices))\n", "for i, slice in enumerate(padded_slices):\n", " axes[i].imshow(slice, cmap=\"gray\", origin=\"lower\")\n", "for ax in axes.ravel():\n", " ax.axis(\"off\")" ] }, { "cell_type": "markdown", "metadata": { "id": "uHwSwKqc2fcC" }, "source": [ "## Define conv and fft functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cGin1TEJ3Rn3", "tags": [] }, "outputs": [], "source": [ "def to_fourier_domain(x):\n", " return fft.fftshift(fft.fft2(fft.ifftshift(x)))\n", "\n", "\n", "def to_image_domain(x):\n", " return fft.ifftshift(fft.ifft2(fft.fftshift(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cYohlgEgfHt7", "tags": [] }, "outputs": [], "source": [ "# Define some of the operators that we need...\n", "def conv2d_fft(x, h):\n", " p0 = x.shape[0] - h.shape[0]\n", " p1 = x.shape[1] - h.shape[1]\n", " h_pad = np.pad(h, ((p0 // 2, p0 // 2), (p1 // 2, p1 // 2)))\n", " Fh = to_fourier_domain(h_pad)\n", " Fx = to_fourier_domain(x)\n", " return to_image_domain(Fx * Fh)\n", "\n", "\n", "def conv2dT_fft(x, h):\n", " p0 = x.shape[0] - h.shape[0]\n", " p1 = x.shape[1] - h.shape[1]\n", " h_pad = np.pad(h, ((p0 // 2, p0 // 2), (p1 // 2, p1 // 2)))\n", " Fh = to_fourier_domain(h_pad)\n", " Fx = to_fourier_domain(x)\n", " return to_image_domain(Fx * np.conj(Fh))" ] }, { "cell_type": "markdown", "metadata": { "id": "vjOZUuyRwi9R" }, "source": [ "## Create the binary mask\n", "According to FastMRI paper" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OWIB1JOglWr8", "tags": [] }, "outputs": [], "source": [ "def generate_random_mask(img_shape, center_percentage=8):\n", " mask = np.zeros(img_shape)\n", " center_columns = int(img_shape[1] * center_percentage / 100)\n", " center_start = int(img_shape[1] / 2) - int(center_columns / 2)\n", " center_end = center_start + center_columns\n", "\n", " for col in range(center_start, center_end):\n", " mask[:, col] = 1\n", "\n", " num_random_cols = int(0.25 * img_shape[1])\n", " random_cols = np.random.choice(\n", " [i for i in range(img_shape[1]) if i < center_start or i >= center_end],\n", " num_random_cols,\n", " replace=False,\n", " )\n", "\n", " for col in random_cols:\n", " mask[:, col] = 1\n", "\n", " return mask.astype(int)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 468 }, "id": "q7XV56QPxSgw", "outputId": "0cf395fb-03e1-43e6-dd8e-aa6c7d1c915f", "tags": [] }, "outputs": [], "source": [ "rmask = generate_random_mask(\n", " x_slice.shape\n", ") # generate a 100x100 mask with 30% center columns\n", "# plot mask and mask_fft together\n", "fig, axs = plt.subplots(1, 1, figsize=(10, 5))\n", "axs.imshow(rmask, cmap=\"gray\")\n", "axs.set_title(\"Mask\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K7Pg_gnmmcZC", "tags": [] }, "outputs": [], "source": [ "# assuming img_shape to be square, shape is of type (x,x) -> (512,512)\n", "def generate_equidistant_mask(img_shape, acceleration_factor=4, center_percentage=25):\n", " mask = np.zeros(img_shape, dtype=np.float32)\n", " mask[::acceleration_factor] = 1.0\n", "\n", " # Set center_percentage% of the center lines to all ones\n", " fraction = 100 // center_percentage\n", " center_lines = img_shape[0] // fraction\n", " start = (img_shape[0] - center_lines) // 2\n", " end = start + center_lines\n", " mask[start:end] = 1.0\n", " return mask.T" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 468 }, "id": "poO-pPmZh0Uu", "outputId": "e906965f-d19c-49f7-ab00-90195dc4ebd3", "tags": [] }, "outputs": [], "source": [ "eq_mask = generate_equidistant_mask(\n", " x_slice.shape, acceleration_factor=4, center_percentage=8\n", ")\n", "# plot mask and mask_fft together\n", "fig, axs = plt.subplots(1, 1, figsize=(10, 5))\n", "axs.imshow(eq_mask, cmap=\"gray\")\n", "axs.set_title(\"Mask\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "R5QR8YaVztl-", "tags": [] }, "outputs": [], "source": [ "def get_log_spectrum(fft_data):\n", " return 20 * np.log(np.abs(fft_data) + 1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aOUURjP_z-uY", "tags": [] }, "outputs": [], "source": [ "# Input image for processing\n", "x = padded_slices[2]\n", "# Equidistant Mask\n", "eq_mask = generate_equidistant_mask(\n", " x.T.shape, acceleration_factor=4, center_percentage=8\n", ")\n", "# Random Mask\n", "rmask = generate_random_mask(x.shape)\n", "\n", "# Computing ploting parameters\n", "\n", "\n", "def param(x, mask):\n", " spectrum_image = to_fourier_domain(x.T)\n", " masked_spectrum = spectrum_image * mask\n", " zero_filled_ifft = to_image_domain(masked_spectrum * mask)\n", " return spectrum_image, masked_spectrum, zero_filled_ifft\n", "\n", "\n", "# Computing parameters for random mask and equidistant mask\n", "spectrum_image, masked_spectrum, zero_filled_ifft = param(x, eq_mask)\n", "spectrum_image_r, masked_spectrum_r, zero_filled_ifft_r = param(x, rmask)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 367 }, "id": "9oS7cwMOpGQE", "outputId": "a9509127-f61f-4b9b-c604-c0323f4bc427", "tags": [] }, "outputs": [], "source": [ "fig, axs = plt.subplots(2, 5, figsize=(15, 6))\n", "axs[0, 0].imshow(x.T, cmap=\"gray\")\n", "axs[0, 0].set_title(\"Original MRI\")\n", "axs[0, 1].imshow(eq_mask, cmap=\"gray\")\n", "axs[0, 1].set_title(\"Equidistant 4 fold Mask\")\n", "axs[0, 2].imshow(get_log_spectrum(spectrum_image), cmap=\"gray\")\n", "axs[0, 2].set_title(\"log spectrumFx\")\n", "axs[0, 3].imshow(get_log_spectrum(masked_spectrum), cmap=\"gray\")\n", "axs[0, 3].set_title(\"log spectrum MFx\")\n", "axs[0, 4].imshow(np.abs(zero_filled_ifft), cmap=\"gray\")\n", "axs[0, 4].set_title(\"zero filled ifft\")\n", "axs[1, 0].imshow(x.T, cmap=\"gray\")\n", "axs[1, 0].set_title(\"Original MRI\")\n", "axs[1, 1].imshow(rmask, cmap=\"gray\")\n", "axs[1, 1].set_title(\"Random 4-fold Mask\")\n", "axs[1, 2].imshow(get_log_spectrum(spectrum_image_r), cmap=\"gray\")\n", "axs[1, 2].set_title(\"log spectrumFx\")\n", "axs[1, 3].imshow(get_log_spectrum(masked_spectrum_r), cmap=\"gray\")\n", "axs[1, 3].set_title(\"log spectrum MFx\")\n", "axs[1, 4].imshow(np.abs(zero_filled_ifft_r), cmap=\"gray\")\n", "axs[1, 4].set_title(\"zero filled ifft\")\n", "for ax in axs.ravel():\n", " ax.axis(\"off\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "_YxvRnTZ231n" }, "source": [ "## Gradiant operator" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 302 }, "id": "4eq9mIIu2M9D", "outputId": "654ae10b-e446-44ab-da72-7607dd78b324", "tags": [] }, "outputs": [], "source": [ "# define gradient operators\n", "\"\"\"\n", "I am testing it with y1 for now\n", "\"\"\"\n", "dh = np.array([[1, -1], [0, 0]]) # horizontal gradient filter\n", "dv = np.array([[1, 0], [-1, 0]]) # vertical gradient filter\n", "\n", "Dh = lambda x: conv2d_fft(x, dh)\n", "Dv = lambda x: conv2d_fft(x, dv)\n", "\n", "DhT = lambda x: conv2dT_fft(x, dh)\n", "DvT = lambda x: conv2dT_fft(x, dv)\n", "\n", "# plot the image x and the gradient images Dh x and Dv x\n", "fig = plt.figure(figsize=(15, 15))\n", "plt.subplot(131)\n", "plt.imshow(x.T, cmap=\"gray\")\n", "plt.title(\"image x\")\n", "plt.subplot(132)\n", "plt.imshow(np.abs(Dh(x.T)), cmap=\"gray\")\n", "plt.title(r\"$|D_hx|$\")\n", "plt.subplot(133)\n", "plt.imshow(np.abs(Dv(x.T)), cmap=\"gray\")\n", "plt.title(r\"$|D_vx|$\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "sn78eU3X9y4M" }, "source": [ "## TV Primal Dual\n", "TV Compress Sensing MRI using Primal Dual" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3VStzu5ZL8a6", "tags": [] }, "outputs": [], "source": [ "def TV_MRI(MFx, M, lamb=2, maxiter=50, tol=1e-4):\n", " \"\"\"\n", " TV-Compress Senssing MRI solver to solve\n", " minimize 0.5 |z-Hx|_2^2 + lambda|Dx|_1\n", " z : reconstructed img in each update\n", " x : original image\n", " Hx: masked image\n", " \"\"\"\n", " start = time.time()\n", " # define the soft-thresholding function\n", " soft_thresh = lambda v, t: np.maximum(np.abs(v) - t, 0.0) * np.sign(v)\n", "\n", " # set step-sizes at maximum: τσL² < 1\n", " # note: PDS seems sensitive to these (given finite iterations at least...)\n", " L = np.sqrt(8) # Spectral norm of D\n", " tao = 0.99 / L\n", " sigma = 0.99 / L\n", "\n", " # Defining Prox functions\n", " \"\"\"\n", " z^k+1 = prox_tg(z^k-tD^Ty^k) \n", " y^k+1 = prox_sigmaf*(y^k+sigmaDx^k+1)\n", " \"\"\"\n", " # Proximal Gradient Descent on x (primal)\n", " prox_D = lambda v, z: to_image_domain(\n", " (to_fourier_domain(v) + tao * z) / (1 + tao * M)\n", " )\n", " # Proximal Gradient Ascent on z (dual)\n", " # prox_A = lambda v: (v - sigma*soft_thresh(v/sigma, lamb/sigma))\n", " prox_A = lambda v: np.clip(v, -lamb, lamb)\n", "\n", " # -----------------------------\n", " # initilize iteration variables\n", " z_hat = np.zeros_like(MFx)\n", " yh_hat = np.zeros_like(MFx)\n", " yv_hat = np.zeros_like(MFx)\n", " # For computing error\n", " J = np.zeros(maxiter)\n", "\n", " # Iterations\n", " k = 0\n", " while k < maxiter:\n", " # Update x - ProxD\n", " z_old = z_hat\n", " z_hat = prox_D(z_hat - tao * (DhT(yh_hat) + DvT(yv_hat)), MFx)\n", " # Update y - ProxA\n", " yh_hat = prox_A(yh_hat + sigma * (Dh(2 * z_hat - z_old)))\n", " yv_hat = prox_A(yv_hat + sigma * (Dv(2 * z_hat - z_old)))\n", " # compute the convergence\n", "\n", " # dual_h = Dh(z_hat) - yh_hat\n", " # dual_v = Dv(z_hat) - yv_hat\n", " # J[k] = (dual_h**2).sum()+(dual_v**2).sum()\n", " J[k] = np.abs(z_hat - z_old).sum()\n", " if J[k] < tol:\n", " break\n", " k = k + 1\n", "\n", " end = time.time()\n", " return z_hat, J, end - start" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 564 }, "id": "PlzBQCI48jyU", "outputId": "462e992c-2c3e-4fcb-816f-5c1d21c6995c", "tags": [] }, "outputs": [], "source": [ "x_t, j, duration1 = TV_MRI(masked_spectrum, eq_mask, lamb=5, maxiter=500, tol=1e-2)\n", "x_t_ifft = to_image_domain(x_t)\n", "fig = plt.figure()\n", "fig.set_size_inches(5, 7.5)\n", "ax = fig.add_subplot(321)\n", "ax.imshow(x.T, cmap=\"gray\")\n", "plt.title(\"original image\")\n", "ax2 = fig.add_subplot(322)\n", "ax2.imshow(np.abs(zero_filled_ifft), cmap=\"gray\")\n", "plt.title(\"Corrupted image\")\n", "ax3 = fig.add_subplot(323)\n", "ax3.imshow(np.abs(x_t), cmap=\"gray\")\n", "plt.title(\"recovered image\")\n", "ax4 = fig.add_subplot(324)\n", "ax4.semilogy(range(len(j)), j, \"b-\", lw=2)\n", "plt.title(\"Residual\")\n", "plt.xlabel(\"iteration (k)\")\n", "plt.tight_layout()\n", "plt.show()\n", "print(f\"Time taken = {duration1}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7OfGF0Xt70ET", "outputId": "dcacf6d7-ed26-4155-ad56-c88239e19ffc", "tags": [] }, "outputs": [], "source": [ "PSNR(x.T, np.abs(zero_filled_ifft), data_range=np.max(np.abs(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fOX39_YP5BQ2", "outputId": "48724a24-5c70-4c8b-f2a9-53233515d254", "tags": [] }, "outputs": [], "source": [ "PSNR(x.T, np.abs(x_t), data_range=np.max(np.abs(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 564 }, "id": "gSIZUDkyjzP6", "outputId": "5178fbb9-432b-4965-f523-60c94abc7406", "tags": [] }, "outputs": [], "source": [ "x_tr, jr, duration1r = TV_MRI(masked_spectrum_r, rmask, lamb=5, maxiter=500, tol=1e-2)\n", "\n", "fig = plt.figure()\n", "fig.set_size_inches(5, 7.5)\n", "ax = fig.add_subplot(321)\n", "ax.imshow(x.T, cmap=\"gray\")\n", "plt.title(\"original image\")\n", "ax2 = fig.add_subplot(322)\n", "ax2.imshow(np.abs(zero_filled_ifft_r), cmap=\"gray\")\n", "plt.title(\"Corrupted image\")\n", "ax3 = fig.add_subplot(323)\n", "ax3.imshow(np.abs(x_tr), cmap=\"gray\")\n", "plt.title(\"recovered image\")\n", "ax4 = fig.add_subplot(324)\n", "ax4.semilogy(range(len(jr)), jr, \"b-\", lw=2)\n", "plt.title(\"Residual\")\n", "plt.xlabel(\"iteration (k)\")\n", "plt.tight_layout()\n", "plt.show()\n", "print(f\"Time taken = {duration1r}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "83Ba6_NykFlp", "outputId": "52e3e9f7-0aab-4bc1-e04e-1dad4b0663e8", "tags": [] }, "outputs": [], "source": [ "PSNR(x.T, np.abs(zero_filled_ifft_r), data_range=np.max(np.abs(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "dHlNC7QBkJvC", "outputId": "b0f3e414-0e0e-400a-ceec-86f526b33b16", "tags": [] }, "outputs": [], "source": [ "PSNR(x.T, np.abs(x_tr), data_range=np.max(np.abs(x)))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 398 }, "id": "pboWXUd4LmVj", "outputId": "c00f0f06-10ca-4144-a51a-fc39d7d3ac3c", "tags": [] }, "outputs": [], "source": [ "fig, axs = plt.subplots(1, 1, figsize=(15, 6))\n", "axs.semilogy(range(len(jr)), jr, \"b-\", lw=2, label=\"Random Mask\")\n", "axs.semilogy(range(len(j)), j, \"r+\", lw=2, label=\"Equidistant Mask\")\n", "plt.title(\"Residual\")\n", "plt.legend()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "include_colab_link": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 4 }