Source code for nabqr.visualization

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.cm import ScalarMappable


[docs] def visualize_results(y_hat, q_hat, ylabel): """Create a visualization of prediction intervals with actual values. Parameters ---------- y_hat : numpy.ndarray Actual observed values q_hat : numpy.ndarray Predicted quantiles for different probability levels ylabel : str Label for the y-axis Returns ------- None Saves the plot as 'TEST_NABQR_taqr_pi_plot.pdf' and displays it Notes ----- - Creates a filled plot showing prediction intervals using a blue gradient - Overlays actual values as a black line - Automatically adjusts x-axis date formatting """ y_hat = pd.Series(np.array(y_hat).flatten()) try: taqr_results_corrected_plot = pd.DataFrame(np.array(q_hat).T, index=y_hat.index) except: taqr_results_corrected_plot = pd.DataFrame(np.array(q_hat), index=y_hat.index) m = taqr_results_corrected_plot.shape[1] # Ensemble size # Define the color gradient from dark blue to light cyan colors = [ (173 / 255, 217 / 255, 229 / 255), (19 / 255, 25 / 255, 148 / 255), (173 / 255, 217 / 255, 229 / 255), ] cmap = LinearSegmentedColormap.from_list("blue_to_cyan", colors, N=100) norm = plt.Normalize(vmin=0, vmax=m - 2) # Normalize for the ensemble size sm = ScalarMappable(cmap=cmap, norm=norm) plt.figure(figsize=(6, 4)) for i in range(m - 1): color = sm.to_rgba(i) plt.fill_between( taqr_results_corrected_plot.index, taqr_results_corrected_plot.iloc[:, i], taqr_results_corrected_plot.iloc[:, i + 1], color=color, alpha=1, ) plt.plot(y_hat.index, y_hat, color="white", linewidth=3) # White outline plt.plot(y_hat, color="black", label="Actuals", linewidth=1) # Actual line plt.xlim(y_hat.index[0], y_hat.index[-1]) # Create legend elements line = Line2D([0], [0], color="black", lw=2, label="Actuals") contour = Line2D( [0], [0], color=sm.to_rgba(m // 2), lw=5, alpha=0.9, label="Prediction Interval" ) plt.legend(handles=[line, contour]) plt.xlabel("Time") plt.ylabel(ylabel) # Configure date formatting on x-axis locator = mdates.AutoDateLocator(minticks=6, maxticks=8) formatter = mdates.ConciseDateFormatter(locator) ax = plt.gca() ax.xaxis.set_major_locator(locator) ax.xaxis.set_major_formatter(formatter) plt.tight_layout() plt.savefig(f"{ylabel}_taqr_pi_plot.pdf") plt.show()