From 931330eb4a8a19cb791d2b2606196e81c886ab38 Mon Sep 17 00:00:00 2001
From: Stefano Serafin <serafin@jet01.img.univie.ac.at>
Date: Fri, 24 Jan 2025 10:08:42 +0100
Subject: [PATCH] fixed plot_CBL_identifiability to deal with randomly sorted
 observations

---
 graphics.py | 49 ++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 44 insertions(+), 5 deletions(-)

diff --git a/graphics.py b/graphics.py
index 7afdc5b..d30a0b8 100644
--- a/graphics.py
+++ b/graphics.py
@@ -5,6 +5,8 @@ from matplotlib import pyplot as p
 import matplotlib.colors as colors
 from ENDA import experiment
 
+is_sorted = lambda a: np.all(a[:-1] <= a[1:])
+
 def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
     #https://stackoverflow.com/questions/18926031/how-to-extract-a-subset-of-a-colormap-as-a-new-colormap-in-matplotlib
     new_cmap = colors.LinearSegmentedColormap.from_list(
@@ -27,7 +29,10 @@ def plot_CBL_assimilation(exp,figtitle,which_cycle=0,ax=None):
     o = exp.observations[which_cycle]
     b = exp.da.backgrounds[which_cycle,:,:]
     a = exp.da.analyses[which_cycle,:,:]
-    ocoords = exp.obs_coordinates
+    if exp.obs_coordinates.ndim == 1:
+        ocoords = exp.obs_coordinates
+    elif exp.obs_coordinates.ndim == 2:
+        ocoords = exp.obs_coordinates[which_cycle]
     nz = exp.nr.nz
     zt = exp.nr.zt 
 
@@ -251,10 +256,37 @@ def plot_CBL_identifiability(cbl,sigma_o,ax=None):
     npar = cbl.nr.parameter_number
     ntimes = cbl.da.ncycles
     times = cbl.da.time
-    zobs = cbl.da.obs_coordinates
+
+    # Expand times array to 2D (matching the number of observations)
+    times = times[:,None]+np.zeros((ntimes,nobs))
+
+    # Read model equivalents and state vector elements
     theta = cbl.da.model_equivalents  # ntim*nobs*nens
     pfac = cbl.da.backgrounds[:,-npar,:] # ntim*nx*nens
     pfac = cbl.nr.parameter_transform[0](pfac,kind='inv')
+
+    # Read observation coordinates
+    if cbl.da.obs_coordinates.ndim == 1:
+        zobs = cbl.da.obs_coordinates
+        zobs = zobs[None,:]+np.zeros(times.shape)
+    elif cbl.da.obs_coordinates.ndim == 2:
+        zobs = cbl.da.obs_coordinates
+
+    # At each time, check if the coordinate array is sorted
+    # if not, resort both coordinates and model equivalents
+    for k in range(ntimes):
+        if is_sorted(zobs[k,:]):
+            pass
+        else:
+            indices = np.argsort(zobs[k,:])
+            zobs[k,:] = zobs[k,indices]
+            theta[k,:] = theta[k,indices]
+            cbl.da.innov[k,:] = cbl.da.innov[k,indices] # ntim*nobs
+            cbl.da.cov_yy[k,:] = cbl.da.cov_yy[k,indices] # ntim*nobs
+            cbl.da.cov_pp[k,:] = cbl.da.cov_pp[k,indices] # ntim*nobs(*npar)
+            cbl.da.cov_py[k,:] = cbl.da.cov_py[k,indices] # ntim*nobs(*npar)
+            cbl.da.increments[k,:] = cbl.da.increments[k,indices] # ntim*nobs(*npar)
+
     if hasattr(cbl.da,'cov_py'):
         sigma2_yb = cbl.da.cov_yy # ntim*nobs
         sigma2_p = np.squeeze(cbl.da.cov_pp) # ntim*nobs(*npar)
@@ -273,6 +305,10 @@ def plot_CBL_identifiability(cbl,sigma_o,ax=None):
     # For plotting
     contours = np.linspace(cbl.nr.theta_0,cbl.nr.theta_0+cbl.nr.gamma*zmax,ncont)
 
+    # Before plotting, transpose the 2D arrays of times and coordinates
+    times = times.T
+    zobs = zobs.T
+
     # Make plots
     if ax is None:
         pass
@@ -381,7 +417,7 @@ def plot_p(p_factors,theta_profiles,zt,figtitle,ax=None):
             ax.plot(theta_profiles[i],zt,label='$p=%4.1f$'%p_factors[i])
         return ax
 
-def plot_spread(cbl,ax=None):
+def plot_spread(cbl,plot='spread',ax=None):
 
     # Read relevant dimensions
     times = cbl.history['0000']['time']
@@ -396,7 +432,10 @@ def plot_spread(cbl,ax=None):
         theta[k,:,:]= cbl.history['%04u'%k]['theta']
 
     # Compute spread
-    spread = theta.std(axis=0)
+    if plot=='spread':
+        bkgd = theta.std(axis=0)
+    elif plot=='mean':
+        bkgd = theta.mean(axis=0)
 
     # Plot
     zmax = 2000
@@ -404,7 +443,7 @@ def plot_spread(cbl,ax=None):
 
     if ax is not None:
         # Make plots
-        c = ax.pcolormesh(times/3600.,zt,spread)
+        c = ax.pcolormesh(times/3600.,zt,bkgd)
         ax.contour(times/3600.,zt,theta.mean(axis=0),
                     np.linspace(cbl.theta_0,cbl.theta_0+cbl.gamma*zmax,ncont),
                     colors='white',
-- 
GitLab