# 1D TV Denoising using Condat's Algorithm
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix, spdiags, diags, csc_matrix
from scipy.sparse.linalg import spsolve
from scipy.fftpack import fft, ifft
import time
import cv2
import urllib.request
from skimage.metrics import peak_signal_noise_ratio as PSNR
def _tv_cost(x, y, lam):
    """Return the 1-D TV-denoising objective 0.5*||x - y||^2 + lam*sum|diff(x)|."""
    return 0.5 * np.sum((x - y) ** 2) + lam * np.sum(np.abs(np.diff(x)))


def TV_Condat_v2(y, lam):
    """Denoise a 1-D signal by exact total-variation (TV) minimization.

    Computes

        x* = argmin_x  0.5 * ||x - y||_2^2 + lam * sum_k |x[k+1] - x[k]|

    with Condat's direct (non-iterative) algorithm: a single left-to-right
    sweep maintains a lower and an upper piecewise-constant bound on the
    solution and validates segments as soon as the two bounds cross.

    Parameters
    ----------
    y : array_like, 1-D
        Noisy input samples.  Converted to float so that integer input is
        not truncated.
    lam : float
        Regularization weight; presumably expected to be >= 0 (not checked).

    Returns
    -------
    x : ndarray
        The TV-denoised signal (same length as ``y``).
    cost : ndarray
        Diagnostic trace: ``cost[i]`` is the objective evaluated on the
        working array after sample ``i`` has been processed.  Until the sweep
        ends, that array still holds scratch values, so only ``cost[-1]`` is
        the objective of the returned solution.
    elapsed : float
        Seconds spent in the call.
    """
    start = time.perf_counter()
    y = np.asarray(y, dtype=float)
    N = len(y)
    cost = np.zeros(N)
    if N <= 1:
        # Zero or one sample has no variation to penalize: y itself is optimal.
        # Return the same 3-tuple shape as the general case.
        return y.copy(), cost, time.perf_counter() - start

    # x doubles as scratch storage during the sweep: the value of each
    # provisional segment is stored at that segment's start index, and
    # validated segments are written out in full.
    x = np.zeros(N)
    # Start indices of the provisional segments of the lower / upper bound.
    indstart_low = np.zeros(N, dtype=int)
    indstart_up = np.zeros(N, dtype=int)
    # j_low / j_up: index of the last provisional segment of each bound.
    # jseg: index of the first not-yet-validated segment, starting at indjseg.
    j_low = j_up = jseg = indjseg = 0
    # *_first: value of segment jseg; *_curr: value of the last segment.
    x_low_first = x_low_curr = y[0] - lam
    x_up_first = x_up_curr = y[0] + lam
    cost[0] = _tv_cost(x, y, lam)

    for i in range(1, N - 1):
        yi = y[i]
        if yi >= x_low_curr:
            if yi <= x_up_curr:
                # Average yi into the last upper segment.
                x_up_curr += (yi - x_up_curr) / (i - indstart_up[j_up] + 1)
                x[indjseg] = x_up_first
                # Merge the last upper segment into its predecessor while its
                # value does not exceed the predecessor's value.
                while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                    j_up -= 1
                    prev = x[indstart_up[j_up]]
                    x_up_curr = prev + (x_up_curr - prev) * (
                        (i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1)
                    )
                if j_up == jseg:
                    # Upper bound fell to the lower bound: validate lower segments.
                    while x_up_curr <= x_low_first and jseg < j_low:
                        jseg += 1
                        nxt = indstart_low[jseg]
                        x[indjseg:nxt] = x_low_first  # whole segment, end exclusive
                        x_up_curr += (x_up_curr - x_low_first) * (
                            (nxt - indjseg) / (i - nxt + 1)
                        )
                        indjseg = nxt
                        x_low_first = x[indjseg]
                    x_up_first = x_up_curr
                    j_up = jseg
                    indstart_up[jseg] = indjseg
                else:
                    x[indstart_up[j_up]] = x_up_curr
            else:
                # yi exceeds the upper bound: it opens a new upper segment.
                j_up += 1
                indstart_up[j_up] = i
                x[i] = yi
                x_up_curr = yi
            # Average yi into the last lower segment.
            x_low_curr += (yi - x_low_curr) / (i - indstart_low[j_low] + 1)
            x[indjseg] = x_low_first
            # Merge the last lower segment into its predecessor while its
            # value is not below the predecessor's value.
            while j_low > jseg and x_low_curr >= x[indstart_low[j_low - 1]]:
                j_low -= 1
                prev = x[indstart_low[j_low]]
                x_low_curr = prev + (x_low_curr - prev) * (
                    (i - indstart_low[j_low + 1] + 1) / (i - indstart_low[j_low] + 1)
                )
            if j_low == jseg:
                # Lower bound rose to the upper bound: validate upper segments.
                while x_low_curr >= x_up_first and jseg < j_up:
                    jseg += 1
                    nxt = indstart_up[jseg]
                    x[indjseg:nxt] = x_up_first  # whole segment, end exclusive
                    x_low_curr += (x_low_curr - x_up_first) * (
                        (nxt - indjseg) / (i - nxt + 1)
                    )
                    indjseg = nxt
                    x_up_first = x[indjseg]
                j_low = jseg
                indstart_low[jseg] = indjseg
                if indjseg == i:
                    # Reset for numerical robustness (not strictly required).
                    x_low_first = x_up_first - 2 * lam
                else:
                    x_low_first = x_low_curr
            else:
                x[indstart_low[j_low]] = x_low_curr
        else:
            # yi is below the lower bound: it opens a new lower segment.
            j_low += 1
            indstart_low[j_low] = i
            x[i] = yi
            x_low_curr = yi
            # Average yi into the last upper segment.
            x_up_curr += (yi - x_up_curr) / (i - indstart_up[j_up] + 1)
            x[indjseg] = x_up_first
            while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                j_up -= 1
                prev = x[indstart_up[j_up]]
                x_up_curr = prev + (x_up_curr - prev) * (
                    (i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1)
                )
            if j_up == jseg:
                # Upper bound fell to the lower bound: validate lower segments.
                while x_up_curr <= x_low_first and jseg < j_low:
                    jseg += 1
                    nxt = indstart_low[jseg]
                    x[indjseg:nxt] = x_low_first
                    x_up_curr += (x_up_curr - x_low_first) * (
                        (nxt - indjseg) / (i - nxt + 1)
                    )
                    indjseg = nxt
                    x_low_first = x[indjseg]
                j_up = jseg
                indstart_up[jseg] = indjseg
                if indjseg == i:
                    # Reset for numerical robustness (not strictly required).
                    x_up_first = x_low_first + 2 * lam
                else:
                    x_up_first = x_up_curr
            else:
                x[indstart_up[j_up]] = x_up_curr
        cost[i] = _tv_cost(x, y, lam)

    # Last sample: its boundary condition shifts the tube by lam.
    i = N - 1
    yi = y[i]
    if yi + lam <= x_low_curr:
        # Validate every lower segment, then pin the last sample.
        while jseg < j_low:
            jseg += 1
            nxt = indstart_low[jseg]
            x[indjseg:nxt] = x_low_first
            indjseg = nxt
            x_low_first = x[indjseg]
        x[indjseg:i] = x_low_first
        x[i] = yi + lam
    elif yi - lam >= x_up_curr:
        # Validate every upper segment, then pin the last sample.
        while jseg < j_up:
            jseg += 1
            nxt = indstart_up[jseg]
            x[indjseg:nxt] = x_up_first
            indjseg = nxt
            x_up_first = x[indjseg]
        x[indjseg:i] = x_up_first
        x[i] = yi - lam
    else:
        x_low_curr += (yi + lam - x_low_curr) / (i - indstart_low[j_low] + 1)
        x[indjseg] = x_low_first
        while j_low > jseg and x_low_curr >= x[indstart_low[j_low - 1]]:
            j_low -= 1
            prev = x[indstart_low[j_low]]
            x_low_curr = prev + (x_low_curr - prev) * (
                (i - indstart_low[j_low + 1] + 1) / (i - indstart_low[j_low] + 1)
            )
        if j_low == jseg:
            if x_up_first >= x_low_curr:
                # One lower segment remains: it covers indjseg..N-1 inclusive.
                x[indjseg : i + 1] = x_low_curr
            else:
                x_up_curr += (yi - lam - x_up_curr) / (i - indstart_up[j_up] + 1)
                x[indjseg] = x_up_first
                while j_up > jseg and x_up_curr <= x[indstart_up[j_up - 1]]:
                    j_up -= 1
                    prev = x[indstart_up[j_up]]
                    x_up_curr = prev + (x_up_curr - prev) * (
                        (i - indstart_up[j_up + 1] + 1) / (i - indstart_up[j_up] + 1)
                    )
                x[indstart_up[j_up] : i + 1] = x_up_curr
                while jseg < j_up:
                    jseg += 1
                    nxt = indstart_up[jseg]
                    x[indjseg:nxt] = x_up_first
                    indjseg = nxt
                    x_up_first = x[indjseg]
        else:
            x[indstart_low[j_low] : i + 1] = x_low_curr
            while jseg < j_low:
                jseg += 1
                nxt = indstart_low[jseg]
                x[indjseg:nxt] = x_low_first
                indjseg = nxt
                x_low_first = x[indjseg]
    cost[N - 1] = _tv_cost(x, y, lam)
    return x, cost, time.perf_counter() - start
# Load the clean (s) and noisy (y) "blocks" test signals.  The remote copies
# have disappeared (loading them raised FileNotFoundError), so fall back to a
# synthetic piecewise-constant signal with Gaussian noise of std 0.5, 256
# samples long, instead of aborting every cell that follows.
try:
    s = np.loadtxt(
        "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks.txt"
    )
    y = np.loadtxt(
        "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks_noisy.txt"
    )
except (OSError, ValueError) as err:
    # OSError covers the missing-file / network failure; ValueError covers a
    # server that answers with non-numeric content (e.g. an HTML error page).
    print(f"Could not load the blocks test signals ({err}); using a synthetic signal instead.")
    _rng = np.random.default_rng(0)  # fixed seed -> reproducible demo
    s = np.repeat([0.0, 4.0, 1.0, 3.0, -1.0, 2.0, 5.0, 0.0], 32)  # 8 plateaus x 32 = 256 samples
    y = s + 0.5 * _rng.standard_normal(s.size)
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
Cell In[3], line 1
----> 1 s = np.loadtxt(
2 "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks.txt"
3 )
4 y = np.loadtxt(
5 "https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks_noisy.txt"
6 )
File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/npyio.py:1338, in loadtxt(fname, dtype, comments, delimiter, converters, skiprows, usecols, unpack, ndmin, encoding, max_rows, quotechar, like)
1335 if isinstance(delimiter, bytes):
1336 delimiter = delimiter.decode('latin1')
-> 1338 arr = _read(fname, dtype=dtype, comment=comment, delimiter=delimiter,
1339 converters=converters, skiplines=skiprows, usecols=usecols,
1340 unpack=unpack, ndmin=ndmin, encoding=encoding,
1341 max_rows=max_rows, quote=quotechar)
1343 return arr
File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/npyio.py:975, in _read(fname, delimiter, comment, quote, imaginary_unit, usecols, skiplines, max_rows, converters, ndmin, unpack, dtype, encoding)
973 fname = os.fspath(fname)
974 if isinstance(fname, str):
--> 975 fh = np.lib._datasource.open(fname, 'rt', encoding=encoding)
976 if encoding is None:
977 encoding = getattr(fh, 'encoding', 'latin1')
File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/_datasource.py:193, in open(path, mode, destpath, encoding, newline)
156 """
157 Open `path` with `mode` and return the file object.
158
(...)
189
190 """
192 ds = DataSource(destpath)
--> 193 return ds.open(path, mode, encoding=encoding, newline=newline)
File /usr/share/miniconda/envs/L96M2lines/lib/python3.9/site-packages/numpy/lib/_datasource.py:533, in DataSource.open(self, path, mode, encoding, newline)
530 return _file_openers[ext](found, mode=mode,
531 encoding=encoding, newline=newline)
532 else:
--> 533 raise FileNotFoundError(f"{path} not found.")
FileNotFoundError: https://eeweb.engineering.nyu.edu/iselesni/lecture_notes/TVDmm/TVD_software/blocks.txt not found.
# Nominal parameters of the test data.
N = 256  # signal length
sigma = 0.5  # standard deviation of the noise

# Side-by-side view of the clean and the noisy input.
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for _ax, (_sig, _title) in zip(axs, ((s, "original signal"), (y, "noisy signal"))):
    _ax.plot(_sig)
    _ax.set_title(_title)
# Regularization weight handed to the TV denoiser.
lam = 3.0
x_condat, cost_condat, time_taken = TV_Condat_v2(y, lam)

# Clean input, noisy input, denoised output and the per-sample cost trace.
fig, axs = plt.subplots(1, 4, figsize=(30, 5))
axs[0].plot(s)
axs[0].set_title("original signal")
axs[1].plot(y)
axs[1].set_title("noisy signal")
axs[2].plot(x_condat)
axs[2].set_title("recovered signal")
# cost_condat[i] is the objective evaluated on the solver's working array
# after sample i was processed.
axs[3].plot(cost_condat)
axs[3].set_title("cost graph")
print(f"Time taken = {time_taken}")

# Standalone plot of the recovered signal.  Open a new figure explicitly:
# when this runs in the same script/cell as the subplots above, plt.plot
# would otherwise draw onto the current axes -- the "cost graph" panel.
plt.figure()
plt.plot(x_condat)