def _tv_objective(x, y, lam):
    """Return the TV-denoising objective 0.5*||x - y||^2 + lam*sum(|diff(x)|)."""
    return 0.5 * np.sum(np.abs(x - y) ** 2) + lam * np.sum(np.abs(np.diff(x)))


def TV_Condat_v2(y, lam):
    """Denoise a 1-D signal with total-variation regularisation (Condat's direct method).

    Computes the minimiser ``x`` of ``0.5 * ||x - y||^2 + lam * sum(|x[k+1] - x[k]|)``
    by tracking a lower and an upper piecewise-constant candidate (``x_low`` and
    ``x_up``) and validating segments of the solution whenever the two candidates
    cross.

    Parameters
    ----------
    y : array_like
        1-D noisy signal.  It is converted to a float array, so integer input
        no longer truncates the result.
    lam : float
        Regularisation weight (``lam >= 0``).

    Returns
    -------
    x : numpy.ndarray
        Denoised signal, same length as ``y``.
    cost : numpy.ndarray
        ``cost[i]`` is the objective evaluated on the *working* buffer after
        sample ``i`` has been processed; only ``cost[-1]`` is the objective of
        the final solution.  Evaluating it at every step costs O(N) per sample.
    elapsed : float
        Run time in seconds.

    Notes
    -----
    Signals with fewer than two samples are returned unchanged, but still as
    the same ``(x, cost, elapsed)`` triple as every other input.
    """
    start = time.perf_counter()
    y = np.asarray(y, dtype=float)
    N = len(y)
    if N <= 1:
        # Nothing to regularise: the signal is its own minimiser (objective 0).
        return y.copy(), np.zeros(N), time.perf_counter() - start

    x = np.zeros_like(y)
    cost = np.zeros(N)
    # Start index of every pending segment of the lower / upper candidate.
    indstart_low = np.zeros(N, dtype=int)
    indstart_up = np.zeros(N, dtype=int)
    # j_low / j_up: index of the last pending segment of each candidate.
    # jseg / indjseg: last validated segment and the sample where it starts.
    j_low = j_up = jseg = indjseg = 0
    x_low_first = y[0] - lam
    x_up_first = y[0] + lam
    x_low_curr = x_low_first
    x_up_curr = x_up_first
    cost[0] = _tv_objective(x, y, lam)

    for i in range(1, N - 1):
        if y[i] >= x_low_curr:
            if y[i] <= x_up_curr:
                # fusion of x_up to keep it nondecreasing
                x_up_curr = x_up_curr + (y[i] - x_up_curr) / (i - indstart_up[j_up] + 1)
                x[indjseg] = x_up_first
                while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                    j_up -= 1
                    x_up_curr = x[indstart_up[j_up]] + (
                        x_up_curr - x[indstart_up[j_up]]
                    ) * ((i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1))
                if j_up == jseg:
                    # a jump in x downwards is possible
                    while x_up_curr <= x_low_first and jseg < j_low:
                        # validation of segments of x_low in x
                        jseg += 1
                        # Half-open slice: fills up to the sample just before
                        # the next segment start.
                        x[indjseg : indstart_low[jseg]] = x_low_first
                        x_up_curr = x_up_curr + (x_up_curr - x_low_first) * (
                            (indstart_low[jseg] - indjseg) / (i - indstart_low[jseg] + 1)
                        )
                        indjseg = indstart_low[jseg]
                        x_low_first = x[indjseg]
                    x_up_first = x_up_curr
                    j_up = jseg
                    indstart_up[jseg] = indjseg
                else:
                    x[indstart_up[j_up]] = x_up_curr
            else:
                # new segment in x_up
                j_up += 1
                indstart_up[j_up] = i
                x[i] = y[i]
                x_up_curr = x[i]

            # fusion of x_low to keep it nonincreasing
            x_low_curr = x_low_curr + (y[i] - x_low_curr) / (i - indstart_low[j_low] + 1)
            x[indjseg] = x_low_first
            while j_low > jseg and x_low_curr >= x[indstart_low[j_low - 1]]:
                j_low -= 1
                x_low_curr = x[indstart_low[j_low]] + (
                    x_low_curr - x[indstart_low[j_low]]
                ) * ((i - indstart_low[j_low + 1] + 1) / (i - indstart_low[j_low] + 1))
            if j_low == jseg:
                # a jump in x upwards is possible
                while x_low_curr >= x_up_first and jseg < j_up:
                    # validation of segments of x_up in x
                    jseg += 1
                    x[indjseg : indstart_up[jseg]] = x_up_first
                    x_low_curr = x_low_curr + (x_low_curr - x_up_first) * (
                        (indstart_up[jseg] - indjseg) / (i - indstart_up[jseg] + 1)
                    )
                    indjseg = indstart_up[jseg]
                    x_up_first = x[indjseg]
                x_low_first = x_low_curr
                j_low = jseg
                indstart_low[jseg] = indjseg
                if indjseg == i:
                    # not mandatory: a reset that increases numerical robustness
                    x_low_first = x_up_first - 2 * lam
            else:
                x[indstart_low[j_low]] = x_low_curr
        else:
            # new segment in x_low
            j_low = j_low + 1
            indstart_low[j_low] = i
            x[i] = y[i]
            x_low_curr = x[i]

            # fusion of x_up to keep it nondecreasing
            x_up_curr = x_up_curr + (y[i] - x_up_curr) / (i - indstart_up[j_up] + 1)
            x[indjseg] = x_up_first
            while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                j_up = j_up - 1
                x_up_curr = x[indstart_up[j_up]] + (
                    x_up_curr - x[indstart_up[j_up]]
                ) * ((i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1))
            if j_up == jseg:
                # a jump in x downwards is possible
                while x_up_curr <= x_low_first and jseg < j_low:
                    # validation of segments of x_low in x
                    jseg += 1
                    x[indjseg : indstart_low[jseg]] = x_low_first
                    x_up_curr = x_up_curr + (x_up_curr - x_low_first) * (
                        (indstart_low[jseg] - indjseg) / (i - indstart_low[jseg] + 1)
                    )
                    indjseg = indstart_low[jseg]
                    x_low_first = x[indjseg]
                x_up_first = x_up_curr
                j_up = jseg
                indstart_up[jseg] = indjseg
                if indjseg == i:
                    # not mandatory: a reset that increases numerical robustness
                    x_up_first = x_low_first + 2 * lam
            else:
                x[indstart_up[j_up]] = x_up_curr
        cost[i] = _tv_objective(x, y, lam)

    # Last sample: the tube closes, so the targets are y[i] + lam / y[i] - lam.
    i = N - 1
    if y[i] + lam <= x_low_curr:
        # the segments of x_low are validated
        while jseg < j_low:
            jseg += 1
            x[indjseg : indstart_low[jseg]] = x_low_first
            indjseg = indstart_low[jseg]
            x_low_first = x[indjseg]
        x[indjseg:i] = x_low_first
        x[i] = y[i] + lam
    elif y[i] - lam >= x_up_curr:
        # the segments of x_up are validated
        while jseg < j_up:
            jseg += 1
            x[indjseg : indstart_up[jseg]] = x_up_first
            indjseg = indstart_up[jseg]
            x_up_first = x[indjseg]
        x[indjseg:i] = x_up_first
        x[i] = y[i] - lam
    else:
        # fusion of x_low to keep it nonincreasing
        x_low_curr = x_low_curr + (y[i] + lam - x_low_curr) / (i - indstart_low[j_low] + 1)
        x[indjseg] = x_low_first
        while j_low > jseg and x_low_curr >= x[indstart_low[j_low - 1]]:
            j_low -= 1
            x_low_curr = x[indstart_low[j_low]] + (
                x_low_curr - x[indstart_low[j_low]]
            ) * ((i - indstart_low[j_low + 1] + 1) / (i - indstart_low[j_low] + 1))
        if j_low == jseg:
            if x_up_first >= x_low_curr:
                # The last segment runs through the final sample inclusive.
                x[indjseg : i + 1] = x_low_curr
            else:
                # fusion of x_up to keep it nondecreasing
                x_up_curr = x_up_curr + (y[i] - lam - x_up_curr) / (i - indstart_up[j_up] + 1)
                x[indjseg] = x_up_first
                while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                    j_up = j_up - 1
                    x_up_curr = x[indstart_up[j_up]] + (
                        x_up_curr - x[indstart_up[j_up]]
                    ) * ((i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1))
                x[indstart_up[j_up] : i + 1] = x_up_curr
                while jseg < j_up:
                    jseg = jseg + 1
                    x[indjseg : indstart_up[jseg]] = x_up_first
                    indjseg = indstart_up[jseg]
                    x_up_first = x[indjseg]
        else:
            x[indstart_low[j_low] : i + 1] = x_low_curr
            while jseg < j_low:
                jseg = jseg + 1
                x[indjseg : indstart_low[jseg]] = x_low_first
                indjseg = indstart_low[jseg]
                x_low_first = x[indjseg]

    cost[N - 1] = _tv_objective(x, y, lam)
    return x, cost, time.perf_counter() - start
# %% Run the Condat TV denoiser on the noisy signal and compare the results.
lam = 3.0
x_condat, cost_condat, time_taken = TV_Condat_v2(y, lam)

# One panel per series, drawn left to right in the order listed here.
panels = (
    (s, "original signal"),
    (y, "noisy signal"),
    (x_condat, "recovered signal"),
    (cost_condat, "cost graph"),
)
fig, axs = plt.subplots(1, 4, figsize=(30, 5))
for ax, (series, title) in zip(axs, panels):
    ax.plot(series)
    ax.set_title(title)
print(f"Time taken = {time_taken}")

# %% Recovered signal on its own.
plt.plot(x_condat)