1D TV Denoising using Condat's Algorithm

Open In Colab

Reference

import numpy as np
import math
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix, spdiags, diags, csc_matrix
from scipy.sparse.linalg import spsolve
from scipy.fftpack import fft, ifft
import time
import cv2
import urllib.request
from skimage.metrics import peak_signal_noise_ratio as PSNR
def TV_Condat_v2(y, lam):
    """Denoise a 1-D signal with total-variation regularisation (Condat, version 2).

    The routine builds, in a single left-to-right pass, the signal ``x`` that
    minimises ``0.5 * sum((x - y) ** 2) + lam * sum(abs(diff(x)))``. It tracks
    a lower and an upper candidate (``x_low``/``x_up``), each stored as a list
    of constant segments, and "validates" segments into ``x`` whenever the two
    candidates touch, which is where a jump of the solution is placed.

    Parameters
    ----------
    y : array_like
        Noisy 1-D input signal. It is converted to a float array so that
        integer input does not truncate the (generally fractional) estimate.
    lam : float
        Regularisation weight; larger values give a flatter result.

    Returns
    -------
    x : numpy.ndarray
        Denoised signal, same length as ``y``.
    cost : numpy.ndarray
        ``cost[i]`` is the objective evaluated on the content of ``x`` right
        after sample ``i`` was processed. Only ``cost[-1]`` refers to the
        finished solution; earlier entries describe a partially filled ``x``.
        Each evaluation sums over the full arrays.
    elapsed : float
        Wall-clock seconds spent inside the function.

    Notes
    -----
    Inputs of length 0 or 1 return the same ``(x, cost, elapsed)`` triple as
    every other input (``x`` is a copy of ``y``, ``cost`` is all zeros).
    """
    start = time.time()
    y = np.asarray(y, dtype=float)
    N = len(y)
    if N <= 1:
        # Nothing to smooth; keep the 3-tuple contract so callers can unpack.
        return y.copy(), np.zeros(N), time.time() - start

    x = np.zeros(N)
    cost = np.zeros(N)

    def objective():
        # TV-denoising objective for the current (possibly partial) x.
        return 0.5 * np.sum(np.abs(x - y) ** 2) + lam * np.sum(np.abs(np.diff(x)))

    # indstart_low[k] / indstart_up[k]: start index of the k-th segment of the
    # lower / upper candidate. The value of each segment is stored in x at that
    # start index until the segment is validated.
    indstart_low = np.zeros(N, dtype=int)
    indstart_up = np.zeros(N, dtype=int)
    # j_low / j_up: index of the current (last) segment of each candidate.
    # jseg: index of the first not-yet-validated segment; indjseg: its start.
    j_low = j_up = jseg = indjseg = 0
    indstart_low[0] = indstart_up[0] = 0
    x_low_first = y[0] - lam
    x_up_first = y[0] + lam
    x_low_curr = x_low_first
    x_up_curr = x_up_first
    cost[0] = objective()

    for i in range(1, N - 1):
        if y[i] >= x_low_curr:
            if y[i] <= x_up_curr:
                # Extend the current upper segment with y[i] (running mean).
                x_up_curr = x_up_curr + (y[i] - x_up_curr) / (i - indstart_up[j_up] + 1)
                x[indjseg] = x_up_first
                # Fuse upper segments so that x_up stays nondecreasing.
                while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                    j_up -= 1
                    x_up_curr = x[indstart_up[j_up]] + (
                        x_up_curr - x[indstart_up[j_up]]
                    ) * ((i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1))
                if j_up == jseg:
                    # A downward jump is possible: validate segments of x_low.
                    while x_up_curr <= x_low_first and jseg < j_low:
                        jseg += 1
                        # Half-open slice: fills indjseg .. indstart_low[jseg] - 1.
                        x[indjseg : indstart_low[jseg]] = x_low_first
                        x_up_curr = x_up_curr + (x_up_curr - x_low_first) * (
                            (indstart_low[jseg] - indjseg)
                            / (i - indstart_low[jseg] + 1)
                        )
                        indjseg = indstart_low[jseg]
                        x_low_first = x[indjseg]
                    x_up_first = x_up_curr
                    j_up = jseg
                    indstart_up[jseg] = indjseg
                else:
                    x[indstart_up[j_up]] = x_up_curr
            else:
                # y[i] is above the upper candidate: open a new upper segment.
                j_up += 1
                indstart_up[j_up] = i
                x[i] = y[i]
                x_up_curr = x[i]

            # Extend the current lower segment with y[i] (running mean).
            x_low_curr = x_low_curr + (y[i] - x_low_curr) / (
                i - indstart_low[j_low] + 1
            )
            x[indjseg] = x_low_first
            # Fuse lower segments so that x_low stays nonincreasing.
            while j_low > jseg and x_low_curr >= x[indstart_low[j_low - 1]]:
                j_low -= 1
                x_low_curr = x[indstart_low[j_low]] + (
                    x_low_curr - x[indstart_low[j_low]]
                ) * ((i - indstart_low[j_low + 1] + 1) / (i - indstart_low[j_low] + 1))
            if j_low == jseg:
                # An upward jump is possible: validate segments of x_up.
                while x_low_curr >= x_up_first and jseg < j_up:
                    jseg += 1
                    x[indjseg : indstart_up[jseg]] = x_up_first
                    x_low_curr = x_low_curr + (x_low_curr - x_up_first) * (
                        (indstart_up[jseg] - indjseg) / (i - indstart_up[jseg] + 1)
                    )
                    indjseg = indstart_up[jseg]
                    x_up_first = x[indjseg]
                x_low_first = x_low_curr
                j_low = jseg
                indstart_low[jseg] = indjseg
                if indjseg == i:
                    # Optional reset that improves numerical robustness.
                    x_low_first = x_up_first - 2 * lam
            else:
                x[indstart_low[j_low]] = x_low_curr
        else:
            # y[i] is below the lower candidate: open a new lower segment.
            j_low = j_low + 1
            indstart_low[j_low] = i
            x[i] = y[i]
            x_low_curr = x[i]

            # fusion of x_up to keep it nondecreasing
            x_up_curr = x_up_curr + (y[i] - x_up_curr) / (i - indstart_up[j_up] + 1)
            x[indjseg] = x_up_first

            while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                j_up = j_up - 1
                x_up_curr = x[indstart_up[j_up]] + (
                    x_up_curr - x[indstart_up[j_up]]
                ) * ((i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1))
            if j_up == jseg:
                # a jump in x downwards is possible
                while x_up_curr <= x_low_first and jseg < j_low:
                    # validation of segments of x_low in x
                    jseg += 1
                    x[indjseg : indstart_low[jseg]] = x_low_first
                    x_up_curr = x_up_curr + (x_up_curr - x_low_first) * (
                        (indstart_low[jseg] - indjseg) / (i - indstart_low[jseg] + 1)
                    )
                    indjseg = indstart_low[jseg]
                    x_low_first = x[indjseg]
                x_up_first = x_up_curr
                j_up = jseg
                indstart_up[jseg] = indjseg
                if indjseg == i:
                    # this part is not mandatory, it is a kind of reset to increase numerical robustness.
                    x_up_first = x_low_first + 2 * lam
            else:
                x[indstart_up[j_up]] = x_up_curr
        cost[i] = objective()

    # Last sample: the boundary condition shifts y[N-1] by +/- lam.
    i = N - 1
    if y[i] + lam <= x_low_curr:
        # the segments of x_low are validated
        while jseg < j_low:
            jseg += 1
            x[indjseg : indstart_low[jseg]] = x_low_first
            indjseg = indstart_low[jseg]
            x_low_first = x[indjseg]
        x[indjseg:i] = x_low_first
        x[i] = y[i] + lam
    elif y[i] - lam >= x_up_curr:
        # the segments of x_up are validated
        while jseg < j_up:
            jseg += 1
            x[indjseg : indstart_up[jseg]] = x_up_first
            indjseg = indstart_up[jseg]
            x_up_first = x[indjseg]

        x[indjseg:i] = x_up_first
        x[i] = y[i] - lam
    else:
        x_low_curr = x_low_curr + (y[i] + lam - x_low_curr) / (
            i - indstart_low[j_low] + 1
        )
        x[indjseg] = x_low_first
        while j_low > jseg and x_low_curr >= x[indstart_low[j_low - 1]]:
            j_low -= 1
            x_low_curr = x[indstart_low[j_low]] + (
                x_low_curr - x[indstart_low[j_low]]
            ) * ((i - indstart_low[j_low + 1] + 1) / (i - indstart_low[j_low] + 1))
        if j_low == jseg:
            if x_up_first >= x_low_curr:
                # The remaining samples, including the last one, share x_low_curr.
                x[indjseg : i + 1] = x_low_curr
            else:
                x_up_curr = x_up_curr + (y[i] - lam - x_up_curr) / (
                    i - indstart_up[j_up] + 1
                )
                x[indjseg] = x_up_first
                while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                    j_up = j_up - 1
                    x_up_curr = x[indstart_up[j_up]] + (
                        x_up_curr - x[indstart_up[j_up]]
                    ) * ((i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1))
                x[indstart_up[j_up] : i + 1] = x_up_curr
                while jseg < j_up:
                    jseg = jseg + 1
                    x[indjseg : indstart_up[jseg]] = x_up_first
                    indjseg = indstart_up[jseg]
                    x_up_first = x[indjseg]
        else:
            x[indstart_low[j_low] : i + 1] = x_low_curr
            while jseg < j_low:
                jseg = jseg + 1
                x[indjseg : indstart_low[jseg]] = x_low_first
                indjseg = indstart_low[jseg]
                x_low_first = x[indjseg]
    cost[N - 1] = objective()
    end = time.time()
    return x, cost, end - start
# Clean ("blocks") and noisy test signals, fetched from the NYU TVD software page.
# NOTE: needs network access; np.loadtxt raises FileNotFoundError when the
# resource cannot be retrieved.
clean_signal_url = "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks.txt"
noisy_signal_url = "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks_noisy.txt"
s = np.loadtxt(clean_signal_url)
y = np.loadtxt(noisy_signal_url)
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[3], line 1
----> 1 s = np.loadtxt(
      2     "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks.txt"
      3 )
      4 y = np.loadtxt(
      5     "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks_noisy.txt"
      6 )

File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/npyio.py:1338, in loadtxt(fname, dtype, comments, delimiter, converters, skiprows, usecols, unpack, ndmin, encoding, max_rows, quotechar, like)
   1335 if isinstance(delimiter, bytes):
   1336     delimiter = delimiter.decode('latin1')
-> 1338 arr = _read(fname, dtype=dtype, comment=comment, delimiter=delimiter,
   1339             converters=converters, skiplines=skiprows, usecols=usecols,
   1340             unpack=unpack, ndmin=ndmin, encoding=encoding,
   1341             max_rows=max_rows, quote=quotechar)
   1343 return arr

File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/npyio.py:975, in _read(fname, delimiter, comment, quote, imaginary_unit, usecols, skiplines, max_rows, converters, ndmin, unpack, dtype, encoding)
    973     fname = os.fspath(fname)
    974 if isinstance(fname, str):
--> 975     fh = np.lib._datasource.open(fname, 'rt', encoding=encoding)
    976     if encoding is None:
    977         encoding = getattr(fh, 'encoding', 'latin1')

File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/_datasource.py:193, in open(path, mode, destpath, encoding, newline)
    156 """
    157 Open `path` with `mode` and return the file object.
    158 
   (...)
    189 
    190 """
    192 ds = DataSource(destpath)
--> 193 return ds.open(path, mode, encoding=encoding, newline=newline)

File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/_datasource.py:533, in DataSource.open(self, path, mode, encoding, newline)
    530     return _file_openers[ext](found, mode=mode,
    531                               encoding=encoding, newline=newline)
    532 else:
--> 533     raise FileNotFoundError(f"{path} not found.")

FileNotFoundError: https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks.txt not found.
N = 256  # N : signal length
sigma = 0.5  # sigma : standard deviation of noise

# Figure 1: the clean and the noisy input side by side.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for ax, (data, title) in zip(axs, ((s, "original signal"), (y, "noisy signal"))):
    ax.plot(data)
    ax.set_title(title)

# Run the denoiser with regularisation weight lam.
lam = 3.0
x_condat, cost_condat, time_taken = TV_Condat_v2(y, lam)

# Figure 2: inputs, recovered signal and the per-sample cost trace.
fig, axs = plt.subplots(1, 4, figsize=(30, 5))
panels = (
    (s, "original signal"),
    (y, "noisy signal"),
    (x_condat, "recovered signal"),
    (cost_condat, "cost graph"),
)
for ax, (data, title) in zip(axs, panels):
    ax.plot(data)
    ax.set_title(title)

print(f"Time taken = {time_taken}")
plt.plot(x_condat)
# plt.axis('off')