{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "view-in-github" }, "source": [ "# 2D TV Denoising\n", "\n", "\"Open" ] }, { "cell_type": "markdown", "metadata": { "id": "-v3IqZG92UTb" }, "source": [ "## Import libraries" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "L56jXJk2NckF" }, "outputs": [], "source": [ "import numpy as np\n", "import cv2\n", "from matplotlib import pyplot as plt\n", "import scipy.signal as signal\n", "from scipy.signal import convolve2d\n", "import scipy.fft as fft\n", "import urllib.request\n", "from skimage.metrics import peak_signal_noise_ratio as PSNR\n", "from IPython.display import Image, HTML\n", "from matplotlib.animation import FuncAnimation\n", "import time" ] }, { "cell_type": "markdown", "metadata": { "id": "EjhU0Iub2YAF" }, "source": [ "## Import image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 489 }, "id": "-h4yjINvOi_-", "outputId": "6db609e2-80fa-47d8-9d5e-006a6e30c93c" }, "outputs": [], "source": [ "# Reading image (grayscale)\n", "url = \"https://i.stack.imgur.com/kP0u2.png\"\n", "with urllib.request.urlopen(url) as url_response:\n", " img_array = np.asarray(bytearray(url_response.read()), dtype=np.uint8)\n", " img = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)\n", "# img is a 3-dimensional numpy array (third number indicates channel)\n", "# Converting to (0,1)\n", "x = img.astype(float) / 255.0\n", "print(type(img))\n", "print(img.shape)\n", "plt.imshow(x, cmap=\"gray\")" ] }, { "cell_type": "markdown", "metadata": { "id": "uHwSwKqc2fcC" }, "source": [ "## Define conv and fft functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "cYohlgEgfHt7" }, "outputs": [], "source": [ "# Define some of the operators that we need...\n", "def conv2d_fft(x, h):\n", " p0 = x.shape[0] - h.shape[0]\n", " p1 = x.shape[1] - h.shape[1]\n", " h_pad = np.pad(h, ((0, p0), (0, p1)))\n", " Fh = fft.fft2(h_pad)\n", " Fx = fft.fft2(x)\n", " return np.real(fft.ifft2(Fx * Fh))\n", "\n", "\n", "def conv2dT_fft(x, h):\n", " p0 = x.shape[0] - h.shape[0]\n", " p1 = x.shape[1] - h.shape[1]\n", " h_pad = np.pad(h, ((0, p0), (0, p1)))\n", " Fh = fft.fft2(h_pad)\n", " Fx = fft.fft2(x)\n", " return np.real(fft.ifft2(Fx * np.conj(Fh)))" ] }, { "cell_type": "markdown", "metadata": { "id": "hnlTVaW62nE3" }, "source": [ "## Noise function" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "V_hPb30HlU1C", "tags": [] }, "outputs": [], "source": [ "def awgn(img, n):\n", " \"\"\"Generating Gaussian Noise\n", " with 0 mean and standard deviation n\n", " choose n between 0,1 for normalized image\"\"\"\n", "\n", " noise = np.random.randn(*img.shape) * n\n", " # Add the noise to the input image\n", " noisy_image = img + noise\n", "\n", " return noisy_image" ] }, { "cell_type": "markdown", "metadata": { "id": "mskYtr7p2paz" }, "source": [ "## Add noise to the image" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 443 }, "id": "OV8CXqRclrmY", "outputId": "9b3d1572-d8bd-44c4-8955-5cf61d54791b" }, "outputs": [], "source": [ "\"\"\"\n", "y1 = img + n\n", "\"\"\"\n", "# Add noise to the image\n", "y1 = awgn(x, 0.2)\n", "\n", "fig = plt.figure(figsize=(15, 15))\n", "plt.subplot(121)\n", "plt.imshow(x, cmap=\"gray\", clim=[0, 1])\n", "plt.title(\"image x\")\n", "plt.subplot(122)\n", "plt.imshow(y1, cmap=\"gray\", clim=[0, 1])\n", "plt.title(\"Noisy image y = img + n\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "_YxvRnTZ231n" }, "source": [ "## Gradiant operator" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 302 }, "id": "4eq9mIIu2M9D", "outputId": "2041581a-e4fd-4c87-adb6-bc56990e996d" }, "outputs": [], "source": [ "# define gradient operators\n", "\"\"\"\n", "I am testing it with y1 for now\n", "\"\"\"\n", "dh = np.array([[1, -1], [0, 0]]) # horizontal gradient filter\n", "dv = np.array([[1, 0], [-1, 0]]) # vertical gradient filter\n", "\n", "Dh = lambda x: conv2d_fft(x, dh)\n", "Dv = lambda x: conv2d_fft(x, dv)\n", "\n", "DhT = lambda x: conv2dT_fft(x, dh)\n", "DvT = lambda x: conv2dT_fft(x, dv)\n", "\n", "# plot the image x and the gradient images Dh x and Dv x\n", "fig = plt.figure(figsize=(15, 15))\n", "plt.subplot(131)\n", "plt.imshow(x, cmap=\"gray\", clim=[0, 1])\n", "plt.title(\"image x\")\n", "plt.subplot(132)\n", "plt.imshow(np.abs(Dh(x)), cmap=\"gray\", clim=[0, 1])\n", "plt.title(r\"$|D_hx|$\")\n", "plt.subplot(133)\n", "plt.imshow(np.abs(Dv(x)), cmap=\"gray\", clim=[0, 1])\n", "plt.title(r\"$|D_vx|$\")\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "sn78eU3X9y4M" }, "source": [ "## TV-denoising Solver\n", "Write TV-denoising formulation with explanation here later" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yLpGSm_c3fsR" }, "outputs": [], "source": [ "def TV_denoising(y, lamb=2, rho=1e1, maxiter=200, return_history=False):\n", " \"\"\"\n", " TV-denoising solver to solve\n", " minimize 0.5 |x-y|_2^2 + lambda|Dx|_1\n", " \"\"\"\n", " # define the soft-thresholding function\n", " \"\"\"\n", " In the TVD case we have : \n", " Vector (v): Dx_(k+1) + u_(k)\n", " Threshold (t): lamb/rho\n", " \"\"\"\n", " start = time.time()\n", " soft_thresh = lambda v, t: np.maximum(np.abs(v) - t, 0.0) * np.sign(v)\n", "\n", " # DDT\n", " \"\"\"\n", " DDT = please check notes for fourier transform format\n", " \"\"\"\n", " # Calculating the difference between sizes x and d for padding purpose\n", " p0 = x.shape[0] - dh.shape[0]\n", " p1 = x.shape[1] - dh.shape[1]\n", " dh_pad = np.pad(dh, ((0, p0), (0, p1)))\n", "\n", " p0 = x.shape[0] - dv.shape[0]\n", " p1 = x.shape[1] - dv.shape[1]\n", " dv_pad = np.pad(dv, ((0, p0), (0, p1)))\n", "\n", " # Refer to Parisima's notes for computing DDT using FFT\n", " DDT = np.abs(fft.fft2(dh_pad)) ** 2 + np.abs(fft.fft2(dv_pad)) ** 2\n", "\n", " # -----------------------------\n", " # initilize iteration variables\n", " zh = np.zeros_like(y)\n", " zv = np.zeros_like(y)\n", " uh = np.zeros_like(zh)\n", " uv = np.zeros_like(zv)\n", " x_hat = np.zeros_like(y)\n", " # For computing error\n", " J = np.zeros(maxiter)\n", " history = []\n", "\n", " for k in range(maxiter):\n", " # solve the L2-L2 problem (update x)\n", " rhs = y + rho * ((DhT(zh) + DvT(zv)) - DhT(uh) - DvT(uv))\n", " F_rhs = fft.fft2(rhs)\n", " x_hat = np.real(fft.ifft2(F_rhs / (rho * DDT + 1)))\n", "\n", " # solve the TV problem (update z)\n", " zh = soft_thresh(Dh(x_hat) + uh, lamb / rho)\n", " zv = soft_thresh(Dv(x_hat) + uv, lamb / rho)\n", "\n", " # update u\n", " dual_h = Dh(x_hat) - zh\n", " dual_v = Dv(x_hat) - zv\n", " uh = uh + dual_h\n", " uv = uv + dual_v\n", "\n", " # compute the error\n", " J[k] = (dual_h**2).sum() + (dual_v**2).sum()\n", " history.append(x_hat.copy())\n", " end = time.time()\n", " if return_history:\n", " return x_hat, J, end - start, history\n", " return x_hat, J, end - start" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 564 }, "id": "rxgu1lPqbgMp", "outputId": "686ccf28-03aa-4167-c91b-330920756682" }, "outputs": [], "source": [ "x_hat, J, duration, history = TV_denoising(\n", " y1, lamb=0.2, rho=2, maxiter=50, return_history=True\n", ")\n", "\n", "fig = plt.figure()\n", "fig.set_size_inches(5, 7.5)\n", "ax = fig.add_subplot(321)\n", "ax.imshow(x, cmap=\"gray\", clim=[0, 1])\n", "plt.title(\"original image\")\n", "ax2 = fig.add_subplot(322)\n", "ax2.imshow(y1, cmap=\"gray\", clim=[0, 1])\n", "plt.title(\"noisy image\")\n", "ax3 = fig.add_subplot(323)\n", "ax3.imshow(x_hat, cmap=\"gray\", clim=[0, 1])\n", "plt.title(\"recovered image\")\n", "ax4 = fig.add_subplot(324)\n", "ax4.semilogy(range(len(J)), J, \"b-\", lw=2)\n", "plt.title(\"Convergence\")\n", "plt.xlabel(\"iteration (k)\")\n", "plt.tight_layout()\n", "plt.show()\n", "print(f\"Time taken = {duration}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## TVD using L2 regularizer animated" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "fig, axs = plt.subplots(1, 2, figsize=(10, 4), dpi=80)\n", "axis_img = axs[0].imshow(y1, cmap=\"gray\", clim=[0.0, 1.0])\n", "\n", "(line,) = axs[1].semilogy(range(len(J)), J)\n", "axs[1].set_xlabel(\"Iteration\")\n", "axs[1].set_ylabel(\"Convergence\")\n", "axs[1].set_title(\"Convergence vs Iteration\")\n", "plt.close()\n", "\n", "\n", "def animate(i):\n", " axs[0].set_title(f\"Recovered image at iteration {i}\")\n", " axis_img.set(cmap=\"gray\", clim=[0.0, 1.0])\n", " axis_img.set_data(history[i])\n", " line.set_data(np.array(range(i + 1)), J[: (i + 1)])\n", " return axis_img, line\n", "\n", "\n", "animation = FuncAnimation(fig, animate, frames=50, interval=120, blit=True)\n", "# let animation load\n", "time.sleep(1)\n", "plt.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [] }, "outputs": [], "source": [ "display(HTML(f'
{animation.to_html5_video()}
'))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "colab": { "include_colab_link": true, "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" } }, "nbformat": 4, "nbformat_minor": 4 }