 
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from sklearn.decomposition import SparseCoder
from sklearn.utils.fixes import np_version, parse_version


def ricker_function(resolution, center, width):
    """Discrete, sub-sampled Ricker ("Mexican hat") wavelet.

    Evaluates the wavelet at the sample positions 0, 1, ..., resolution - 1.

    Parameters
    ----------
    resolution : int
        Number of samples to produce.
    center : float
        Sample position of the wavelet peak.
    width : float
        Width parameter; it plays the role of sigma in the Gaussian factor.

    Returns
    -------
    numpy.ndarray of shape (resolution,)
        The sampled wavelet values.
    """
    grid = np.linspace(0, resolution - 1, resolution)
    offset_sq = (grid - center) ** 2
    width_sq = width ** 2
    # Normalisation constant of the Ricker wavelet.
    amplitude = 2 / (np.sqrt(3 * width) * np.pi ** 0.25)
    return amplitude * (1 - offset_sq / width_sq) * np.exp(-offset_sq / (2 * width_sq))


def ricker_matrix(width, resolution, n_components):
    """Build a dictionary of Ricker (Mexican hat) wavelet atoms.

    The atoms share the same ``width`` and have their centres evenly spaced
    over [0, resolution - 1]. Each row is scaled to unit Euclidean norm.

    Parameters
    ----------
    width : float
        Width parameter passed to ``ricker_function`` for every atom.
    resolution : int
        Number of samples per atom.
    n_components : int
        Number of atoms (rows).

    Returns
    -------
    numpy.ndarray of shape (n_components, resolution)
    """
    atoms = np.empty((n_components, resolution))
    atom_centers = np.linspace(0, resolution - 1, n_components)
    # Iterating a 2-D array yields writable row views, so each atom is filled in place.
    for row, atom_center in zip(atoms, atom_centers):
        row[:] = ricker_function(resolution, atom_center, width)
    row_norms = np.sqrt((atoms ** 2).sum(axis=1))
    atoms /= row_norms[:, np.newaxis]
    return atoms


resolution = 1024
subsampling = 3  # subsampling factor: the dictionary has resolution // subsampling atoms
width = 100
n_components = resolution // subsampling

# Wavelet dictionaries: one with a single fixed width, and one made by stacking
# five equally sized sub-dictionaries of different widths along the atom axis.
D_fixed = ricker_matrix(width=width, resolution=resolution, n_components=n_components)
D_multi = np.vstack(
    [
        ricker_matrix(width=atom_width, resolution=resolution, n_components=n_components // 5)
        for atom_width in (10, 50, 100, 500, 1000)
    ]
)

# Test signal: a step function equal to 3.0 on the first quarter of the
# samples and -1.0 on the rest.
first_quarter = np.arange(resolution) < resolution / 4
y = np.where(first_quarter, 3.0, -1.0)

# Sparse-coding methods to compare. Each entry is a tuple
#   (title, transform_algorithm, transform_alpha, transform_n_nonzero_coefs, color)
estimators = []
# Orthogonal Matching Pursuit with 15 non-zero coefficients.
estimators.append(("正交匹配追踪法", "omp", None, 15, "navy"))
# Lasso solved by least-angle regression, with alpha = 2.
estimators.append(("最小角Lasso回归", "lasso_lars", 2, None, "turquoise"))

# Pass rcond explicitly to np.linalg.lstsq. This avoids the FutureWarning about
# its default changing on numpy >= 1.14; older numpy gets -1.
if np_version < parse_version("1.14"):
    lstsq_rcond = -1
else:
    lstsq_rcond = None

fig = plt.figure(figsize=(13, 6))
fig.canvas.manager.set_window_title("稀疏编码SparseCoder")  # Matplotlib >= 3.4
# On Matplotlib < 3.4 use fig.canvas.set_window_title(...) instead.

# Titles and legends contain Chinese text, so a CJK-capable font is needed.
# The SimHei file path below only exists on Windows. Pointing FontProperties at
# a missing file makes text rendering fail, so on other systems fall back to a
# family-name lookup and let matplotlib choose among these families.
_simhei_path = "C:\\Windows\\Fonts\\SimHei.ttf"
if os.path.isfile(_simhei_path):
    font = FontProperties(fname=_simhei_path)  # , size=16
else:
    font = FontProperties(
        family=[
            "SimHei",
            "Microsoft YaHei",
            "Noto Sans CJK SC",
            "WenQuanYi Micro Hei",
            "sans-serif",
        ]
    )

# One subplot per dictionary: left = fixed-width atoms, right = multi-width atoms.
for subplot_idx, (D, dict_title) in enumerate(
    zip((D_fixed, D_multi), ("固定宽度", "多种宽度")), start=1
):
    plt.subplot(1, 2, subplot_idx)
    plt.title("基于%s字典的稀疏编码" % dict_title, fontproperties=font)
    plt.plot(y, lw=2, linestyle="--", label="原始信号")

    # Wavelet approximation of y with each estimator. The loop variable is
    # named est_title so it does not shadow the dictionary title above.
    for est_title, algo, alpha, n_nonzero, color in estimators:
        coder = SparseCoder(
            dictionary=D,
            transform_n_nonzero_coefs=n_nonzero,
            transform_alpha=alpha,
            transform_algorithm=algo,
        )
        x = coder.transform(y.reshape(1, -1))  # sparse codes for the single sample
        density = np.count_nonzero(x)
        x = np.ravel(np.dot(x, D))  # reconstruction of y from the codes
        squared_error = np.sum((y - x) ** 2)
        plt.plot(
            x,
            color=color,
            lw=2,
            label="%s: %s个非零系数,\n%.2f误差" % (est_title, density, squared_error),
        )

    # Thresholding, then de-biasing: refit the coefficients of the atoms that
    # survived the threshold by least squares against y.
    coder = SparseCoder(
        dictionary=D, transform_algorithm="threshold", transform_alpha=20
    )
    x = coder.transform(y.reshape(1, -1))
    _, idx = np.where(x != 0)
    if idx.size:  # nothing to refit when no atom survived the threshold
        x[0, idx], _, _, _ = np.linalg.lstsq(D[idx, :].T, y, rcond=lstsq_rcond)
    x = np.ravel(np.dot(x, D))
    squared_error = np.sum((y - x) ** 2)
    plt.plot(
        x,
        color="darkorange",
        lw=2,
        label="阈值法(去偏差):\n%d个非零系数, %.2f误差" % (len(idx), squared_error),
    )
    plt.axis("tight")
    plt.legend(shadow=False, loc="best", prop=font)

plt.subplots_adjust(
    left=0.04, bottom=0.07, right=0.97, top=0.90, wspace=0.09, hspace=0.2
)
plt.show()
 
