{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"# 2D TV Denoising\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-v3IqZG92UTb"
},
"source": [
"## Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "L56jXJk2NckF"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import cv2\n",
"from matplotlib import pyplot as plt\n",
"import scipy.signal as signal\n",
"from scipy.signal import convolve2d\n",
"import scipy.fft as fft\n",
"import urllib.request\n",
"from skimage.metrics import peak_signal_noise_ratio as PSNR\n",
"from IPython.display import Image, HTML\n",
"from matplotlib.animation import FuncAnimation\n",
"import time"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EjhU0Iub2YAF"
},
"source": [
"## Import image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 489
},
"id": "-h4yjINvOi_-",
"outputId": "6db609e2-80fa-47d8-9d5e-006a6e30c93c"
},
"outputs": [],
"source": [
"# Download the test image and decode it as a single-channel (grayscale) array\n",
"url = \"https://i.stack.imgur.com/kP0u2.png\"\n",
"with urllib.request.urlopen(url) as url_response:\n",
"    # View the downloaded bytes as uint8 without an intermediate bytearray copy\n",
"    img_array = np.frombuffer(url_response.read(), dtype=np.uint8)\n",
"\n",
"img = cv2.imdecode(img_array, cv2.IMREAD_GRAYSCALE)\n",
"# cv2.imdecode signals failure by returning None rather than raising\n",
"if img is None:\n",
"    raise ValueError(f\"Could not decode an image from {url}\")\n",
"\n",
"# With IMREAD_GRAYSCALE, img is a 2-D uint8 array (rows, cols) with no channel axis\n",
"# Rescale the intensities from [0, 255] to [0, 1]\n",
"x = img.astype(float) / 255.0\n",
"print(type(img))\n",
"print(img.shape)\n",
"plt.imshow(x, cmap=\"gray\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "uHwSwKqc2fcC"
},
"source": [
"## Define conv and fft functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cYohlgEgfHt7"
},
"outputs": [],
"source": [
"# FFT-based building blocks for the operators used in the rest of the notebook\n",
"def _kernel_spectrum(x, h):\n",
"    \"\"\"Return the 2-D FFT of ``h`` after zero-padding it (bottom/right) to the shape of ``x``.\"\"\"\n",
"    extra_rows = x.shape[0] - h.shape[0]\n",
"    extra_cols = x.shape[1] - h.shape[1]\n",
"    return fft.fft2(np.pad(h, ((0, extra_rows), (0, extra_cols))))\n",
"\n",
"\n",
"def conv2d_fft(x, h):\n",
"    \"\"\"Circular 2-D convolution of ``x`` with kernel ``h``, evaluated in the Fourier domain.\n",
"\n",
"    Returns the real part of the inverse FFT of the product of both spectra.\n",
"    \"\"\"\n",
"    product = fft.fft2(x) * _kernel_spectrum(x, h)\n",
"    return np.real(fft.ifft2(product))\n",
"\n",
"\n",
"def conv2dT_fft(x, h):\n",
"    \"\"\"Counterpart of ``conv2d_fft`` that multiplies by the conjugate kernel spectrum.\n",
"\n",
"    For a real kernel this is circular correlation of ``x`` with ``h``, i.e. the\n",
"    transpose of the convolution operator.\n",
"    \"\"\"\n",
"    product = fft.fft2(x) * np.conj(_kernel_spectrum(x, h))\n",
"    return np.real(fft.ifft2(product))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hnlTVaW62nE3"
},
"source": [
"## Noise function"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "V_hPb30HlU1C",
"tags": []
},
"outputs": [],
"source": [
"def awgn(img, n, rng=None):\n",
"    \"\"\"Return ``img`` corrupted by additive white Gaussian noise.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    img : array_like\n",
"        Input image; the generated noise has the same shape.\n",
"    n : float\n",
"        Standard deviation of the zero-mean noise. For an image normalized\n",
"        to [0, 1], choose a value between 0 and 1.\n",
"    rng : numpy.random.Generator, optional\n",
"        Source of randomness. When omitted, NumPy's global random state is\n",
"        used, so ``np.random.seed`` still controls the result.\n",
"\n",
"    Returns\n",
"    -------\n",
"    numpy.ndarray\n",
"        ``img + noise``. Values are not clipped and may fall outside [0, 1].\n",
"    \"\"\"\n",
"    img = np.asarray(img)\n",
"    if rng is None:\n",
"        noise = np.random.randn(*img.shape) * n\n",
"    else:\n",
"        noise = rng.standard_normal(img.shape) * n\n",
"    return img + noise"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mskYtr7p2paz"
},
"source": [
"## Add noise to the image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 443
},
"id": "OV8CXqRclrmY",
"outputId": "9b3d1572-d8bd-44c4-8955-5cf61d54791b"
},
"outputs": [],
"source": [
"# Observation model: y1 = x + n, with n zero-mean Gaussian noise\n",
"NOISE_SEED = 0  # fixed so that Restart & Run All reproduces the same noisy image\n",
"NOISE_STD = 0.2  # noise standard deviation on the [0, 1] intensity scale\n",
"\n",
"# awgn draws from NumPy's global random state, so seed it right before the call\n",
"np.random.seed(NOISE_SEED)\n",
"y1 = awgn(x, NOISE_STD)\n",
"\n",
"# Show the clean and the noisy image side by side on the same intensity scale\n",
"fig, (ax_clean, ax_noisy) = plt.subplots(1, 2, figsize=(15, 15))\n",
"ax_clean.imshow(x, cmap=\"gray\", clim=[0, 1])\n",
"ax_clean.set_title(\"image x\")\n",
"ax_noisy.imshow(y1, cmap=\"gray\", clim=[0, 1])\n",
"ax_noisy.set_title(\"Noisy image y = img + n\")\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_YxvRnTZ231n"
},
"source": [
"## Gradient operator"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 302
},
"id": "4eq9mIIu2M9D",
"outputId": "2041581a-e4fd-4c87-adb6-bc56990e996d"
},
"outputs": [],
"source": [
"# First-difference kernels; conv2d_fft applies them as circular convolutions\n",
"dh = np.array([[1, -1], [0, 0]])  # horizontal gradient filter\n",
"dv = np.array([[1, 0], [-1, 0]])  # vertical gradient filter\n",
"\n",
"\n",
"def Dh(x):\n",
"    \"\"\"Horizontal gradient of ``x``: circular convolution with ``dh``.\"\"\"\n",
"    return conv2d_fft(x, dh)\n",
"\n",
"\n",
"def Dv(x):\n",
"    \"\"\"Vertical gradient of ``x``: circular convolution with ``dv``.\"\"\"\n",
"    return conv2d_fft(x, dv)\n",
"\n",
"\n",
"def DhT(x):\n",
"    \"\"\"Transpose of ``Dh``, computed with the conjugate spectrum of ``dh``.\"\"\"\n",
"    return conv2dT_fft(x, dh)\n",
"\n",
"\n",
"def DvT(x):\n",
"    \"\"\"Transpose of ``Dv``, computed with the conjugate spectrum of ``dv``.\"\"\"\n",
"    return conv2dT_fft(x, dv)\n",
"\n",
"\n",
"# Plot the clean image x next to the gradient magnitudes |Dh x| and |Dv x|\n",
"fig, axs = plt.subplots(1, 3, figsize=(15, 15))\n",
"panels = [\n",
"    (x, \"image x\"),\n",
"    (np.abs(Dh(x)), r\"$|D_hx|$\"),\n",
"    (np.abs(Dv(x)), r\"$|D_vx|$\"),\n",
"]\n",
"for ax, (panel, title) in zip(axs, panels):\n",
"    ax.imshow(panel, cmap=\"gray\", clim=[0, 1])\n",
"    ax.set_title(title)\n",
"fig.tight_layout()\n",
"plt.show()"
]
},
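{
"cell_type": "markdown",
"metadata": {
"id": "sn78eU3X9y4M"
},
"source": [
"## TV-denoising Solver\n",
"Given the noisy image $y$, TV denoising estimates the clean image $x$ by solving\n",
"\n",
"$$\\min_x \\; \\tfrac{1}{2}\\lVert x - y\\rVert_2^2 + \\lambda\\,\\lVert Dx\\rVert_1,$$\n",
"\n",
"where $D$ stacks the horizontal and vertical gradient operators $D_h$ and $D_v$, and $\\lambda > 0$ controls the amount of smoothing.\n",
"\n",
"The solver below uses ADMM with the splitting $z = Dx$, a scaled dual variable $u$ and a penalty parameter $\\rho$:\n",
"\n",
"1. $x^{k+1} = (I + \\rho D^\\top D)^{-1}\\left(y + \\rho D^\\top (z^k - u^k)\\right)$, evaluated with the FFT;\n",
"2. $z^{k+1} = \\operatorname{soft}_{\\lambda/\\rho}\\left(Dx^{k+1} + u^k\\right)$ (element-wise soft-thresholding);\n",
"3. $u^{k+1} = u^k + Dx^{k+1} - z^{k+1}$.\n",
"\n",
"**TODO:** add a fuller derivation and explanation of the formulation."
]
},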
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yLpGSm_c3fsR"
},
"outputs": [],
"source": [
"def TV_denoising(y, lamb=2, rho=1e1, maxiter=200, return_history=False):\n",
"    \"\"\"Total-variation denoising of ``y`` with ADMM.\n",
"\n",
"    Solves\n",
"        minimize_x  0.5 * ||x - y||_2^2 + lamb * ||D x||_1\n",
"    where D stacks the horizontal and vertical gradient operators. The\n",
"    operators ``Dh``, ``Dv``, ``DhT``, ``DvT`` and the kernels ``dh``, ``dv``\n",
"    are read from the notebook namespace and must be defined beforehand.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    y : numpy.ndarray\n",
"        Noisy 2-D image.\n",
"    lamb : float\n",
"        Weight of the TV term.\n",
"    rho : float\n",
"        ADMM penalty parameter; the soft-threshold level is ``lamb / rho``.\n",
"    maxiter : int\n",
"        Number of ADMM iterations (there is no early stopping).\n",
"    return_history : bool\n",
"        When True, the estimate of every iteration is stored and returned.\n",
"\n",
"    Returns\n",
"    -------\n",
"    x_hat : numpy.ndarray\n",
"        Estimate after the last iteration.\n",
"    J : numpy.ndarray\n",
"        ``J[k]`` is the squared norm of ``D x - z`` at iteration ``k``.\n",
"    elapsed : float\n",
"        Wall-clock time of the call in seconds.\n",
"    history : list of numpy.ndarray\n",
"        Only returned when ``return_history`` is True.\n",
"    \"\"\"\n",
"    start = time.time()\n",
"\n",
"    def soft_thresh(v, t):\n",
"        # Shrink every entry of v towards zero by t; entries within [-t, t] become 0\n",
"        return np.maximum(np.abs(v) - t, 0.0) * np.sign(v)\n",
"\n",
"    # Zero-pad the kernels to the size of the *input* y (not a global image), so the\n",
"    # spectrum below matches the shape of the arrays it divides.\n",
"    dh_pad = np.pad(dh, ((0, y.shape[0] - dh.shape[0]), (0, y.shape[1] - dh.shape[1])))\n",
"    dv_pad = np.pad(dv, ((0, y.shape[0] - dv.shape[0]), (0, y.shape[1] - dv.shape[1])))\n",
"\n",
"    # Fourier-domain multiplier of Dh^T Dh + Dv^T Dv for circular convolutions\n",
"    DDT = np.abs(fft.fft2(dh_pad)) ** 2 + np.abs(fft.fft2(dv_pad)) ** 2\n",
"\n",
"    # -----------------------------\n",
"    # initialize iteration variables\n",
"    zh = np.zeros_like(y)\n",
"    zv = np.zeros_like(y)\n",
"    uh = np.zeros_like(zh)\n",
"    uv = np.zeros_like(zv)\n",
"    x_hat = np.zeros_like(y)\n",
"    J = np.zeros(maxiter)  # squared residual per iteration\n",
"    history = []\n",
"    threshold = lamb / rho\n",
"\n",
"    for k in range(maxiter):\n",
"        # x-update: solve (I + rho * D^T D) x = y + rho * D^T (z - u) with the FFT.\n",
"        # The transpose operators are linear, so they are applied once to (z - u).\n",
"        rhs = y + rho * (DhT(zh - uh) + DvT(zv - uv))\n",
"        x_hat = np.real(fft.ifft2(fft.fft2(rhs) / (rho * DDT + 1)))\n",
"\n",
"        # Gradients of the new estimate, shared by the z- and u-updates\n",
"        grad_h = Dh(x_hat)\n",
"        grad_v = Dv(x_hat)\n",
"\n",
"        # z-update: soft-threshold D x + u\n",
"        zh = soft_thresh(grad_h + uh, threshold)\n",
"        zv = soft_thresh(grad_v + uv, threshold)\n",
"\n",
"        # u-update: accumulate the constraint residual D x - z\n",
"        resid_h = grad_h - zh\n",
"        resid_v = grad_v - zv\n",
"        uh = uh + resid_h\n",
"        uv = uv + resid_v\n",
"\n",
"        # Convergence measure: squared norm of the residual\n",
"        J[k] = (resid_h**2).sum() + (resid_v**2).sum()\n",
"        # Only keep per-iteration copies when the caller asked for them\n",
"        if return_history:\n",
"            history.append(x_hat.copy())\n",
"\n",
"    elapsed = time.time() - start\n",
"    if return_history:\n",
"        return x_hat, J, elapsed, history\n",
"    return x_hat, J, elapsed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"id": "rxgu1lPqbgMp",
"outputId": "686ccf28-03aa-4167-c91b-330920756682"
},
"outputs": [],
"source": [
"# Run the solver and keep every iterate so the convergence can be inspected later\n",
"x_hat, J, duration, history = TV_denoising(\n",
"    y1, lamb=0.2, rho=2, maxiter=50, return_history=True\n",
")\n",
"\n",
"# Four panels placed in the first two rows of a 3x2 grid\n",
"fig = plt.figure()\n",
"fig.set_size_inches(5, 7.5)\n",
"\n",
"ax = fig.add_subplot(321)\n",
"ax.imshow(x, cmap=\"gray\", clim=[0, 1])\n",
"ax.set_title(\"original image\")\n",
"\n",
"ax2 = fig.add_subplot(322)\n",
"ax2.imshow(y1, cmap=\"gray\", clim=[0, 1])\n",
"ax2.set_title(\"noisy image\")\n",
"\n",
"ax3 = fig.add_subplot(323)\n",
"ax3.imshow(x_hat, cmap=\"gray\", clim=[0, 1])\n",
"ax3.set_title(\"recovered image\")\n",
"\n",
"# Residual per iteration on a logarithmic axis\n",
"ax4 = fig.add_subplot(324)\n",
"ax4.semilogy(range(len(J)), J, \"b-\", lw=2)\n",
"ax4.set_title(\"Convergence\")\n",
"ax4.set_xlabel(\"iteration (k)\")\n",
"\n",
"fig.tight_layout()\n",
"plt.show()\n",
"print(f\"Time taken = {duration}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TVD using L2 regularizer animated"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"fig, axs = plt.subplots(1, 2, figsize=(10, 4), dpi=80)\n",
"axis_img = axs[0].imshow(y1, cmap=\"gray\", clim=[0.0, 1.0])\n",
"\n",
"# Plot the full curve once so the axis limits are fixed; each frame then shows a prefix of it\n",
"(line,) = axs[1].semilogy(range(len(J)), J)\n",
"axs[1].set_xlabel(\"Iteration\")\n",
"axs[1].set_ylabel(\"Convergence\")\n",
"axs[1].set_title(\"Convergence vs Iteration\")\n",
"# Close the figure here; presumably this keeps a static copy from being shown next to the animation\n",
"plt.close(fig)\n",
"\n",
"\n",
"def animate(i):\n",
"    \"\"\"Draw frame ``i``: the i-th estimate and the convergence curve up to iteration ``i``.\"\"\"\n",
"    axs[0].set_title(f\"Recovered image at iteration {i}\")\n",
"    # The colormap and color limits were set when the image artist was created\n",
"    axis_img.set_data(history[i])\n",
"    line.set_data(np.arange(i + 1), J[: i + 1])\n",
"    return axis_img, line\n",
"\n",
"\n",
"# One frame per stored iterate, so the animation follows maxiter instead of a hard-coded 50.\n",
"# Frames are rendered lazily when the animation is displayed or saved, so no wait is needed here.\n",
"animation = FuncAnimation(fig, animate, frames=len(history), interval=120, blit=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"display(HTML(f'