From 03401fd2cb36d0fdbb5e54fcf18f472c3bffbccc Mon Sep 17 00:00:00 2001
From: Stefano Serafin <serafin@jet02.img.univie.ac.at>
Date: Thu, 1 Aug 2024 16:48:40 +0200
Subject: [PATCH] added storage of covariances/innovations/increments at
 runtime; some clarifying comments on EAKF algorithm

---
 ENDA.py | 101 +++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 85 insertions(+), 16 deletions(-)

diff --git a/ENDA.py b/ENDA.py
index d3c13a0..3fd5c23 100644
--- a/ENDA.py
+++ b/ENDA.py
@@ -90,7 +90,7 @@ def compute_distances(xcoords,ocoords):
     return dist
 
 def eakf(xb,o,var_o,xcoord,ocoord,dist,\
-                    localization_cutoff=np.infty):
+                    localization_cutoff=np.infty,return_covariances_increments_and_innovations=False):
 
     def increments_in_obs_space(y_p,o,var_o):
     
@@ -102,14 +102,21 @@ def eakf(xb,o,var_o,xcoord,ocoord,dist,\
         weight_o = 1./var_o
         weight_p = 1./var_p
     
-        # Compute ensemble mean, ensemble variance and weight of posterior
+        # Compute weight of posterior, then ensemble variance of posterior
         weight_u =  weight_o + weight_p
         var_u = 1./weight_u
-        mean_u = (mean_p*weight_p + o*weight_o)*var_u
+
+        # Shift the mean from mean_p to mean_u
+        mean_u = (mean_p*weight_p + o*weight_o)*var_u 
         
-        # Compute analysis increments in observation space
+        # Compute ensemble of analysis increments in observation space.
+        # There are three terms in delta_y; the first two are the posterior
+        # ensemble state (shifted mean + adjusted perturbations; e.g. Eq 19 in Anderson 2003);
+        # the last one (y_p) is the prior state.
+        # Their difference is the increment in observation space.
         alpha = (var_u*weight_p)**0.5
-        delta_y = mean_u*np.ones(nens) - alpha*mean_p*np.ones(nens) + (alpha-1)*y_p
+        #delta_y = mean_u*np.ones(nens) + (alpha-1)*y_p - alpha*mean_p*np.ones(nens) 
+        delta_y = mean_u*np.ones(nens) + alpha*(y_p - mean_p*np.ones(nens)) - y_p
         
         return delta_y,weight_p
     
@@ -130,6 +137,18 @@ def eakf(xb,o,var_o,xcoord,ocoord,dist,\
 
     # Preallocate array for model equivalents
     yp = np.zeros((nobs,nens))+np.nan
+    increments = np.zeros((nvar,nobs))+np.nan
+
+    # Preallocate array to save variances and covariances
+    # (they are adjusted after ingesting every single observation).
+    # This is not needed by the algorithm, but may be useful
+    # for analysis purposes.
+    if return_covariances_increments_and_innovations:
+        cov_xx = np.zeros((nobs,nvar))+np.nan
+        cov_xy = np.zeros((nobs,nvar))+np.nan
+        cov_yy = np.zeros((nobs))+np.nan
+        innovations_mean = np.zeros((nobs))+np.nan
+        increments_mean = np.zeros((nobs,nvar))+np.nan
     
     # Process observations sequentially
     for i in range(nobs):
@@ -139,16 +158,35 @@ def eakf(xb,o,var_o,xcoord,ocoord,dist,\
         for j in range(nens):
             yp[i,j] = observation_operator(xu[:,j],xcoord,ocoord[i])
 
+        # Store innovation wrt the ensemble mean
+        if return_covariances_increments_and_innovations:
+            innovations_mean[i] = o[i]-yp[i].mean()
+
         # Determine analysis increments in observation space
+        # weight_p is a scalar
+        # delta_y has dimension nens
         delta_y,weight_p = increments_in_obs_space(yp[i,:],o[i],var_o[i])
 
-        # Update state variables in each ensemble member using covariances
+        # Update state variables in each ensemble member using ensemble covariances
+        # The default in np.cov is N-1 degrees of freedom.
         for k in range(nvar):
-            xu[k,:] = xu[k,:] + np.cov(xu[k,:],yp[i,:])[0,1]*weight_p*delta_y*\
-                np.interp(dist[k,i],ds,loc) # covariance localization
+            cov = np.cov(xu[k,:],yp[i,:]) # 2x2 covariance matrix of (state, model equivalent)
+            # Compute analysis increments (=adjustments of ensemble perturbations)
+            increments = cov[0,1]*weight_p*delta_y*\
+                np.interp(dist[k,i],ds,loc) # nens values; the np.interp factor applies covariance localization
+            xu[k,:] = xu[k,:] + increments
+            if return_covariances_increments_and_innovations:
+                cov_xx[i,k] = cov[0,0]
+                cov_xy[i,k] = cov[0,1]
+                cov_yy[i] = cov[1,1]
+                increments_mean[i,k] = increments.mean()
+
 
     # Return
-    return xu,yp
+    if return_covariances_increments_and_innovations:
+        return xu,yp,cov_xx,cov_xy,cov_yy,innovations_mean,increments_mean
+    else:
+        return xu,yp
 
 def letkf(xb,o_global,var_o,xcoord,ocoord,dist,\
                      localization_cutoff=np.infty,inflation_rho=1.00):
@@ -254,6 +292,7 @@ class cycle:
         initial_state_vector = nr.x0
         xcoord = nr.state_coordinates
         nobs = observations.shape[1]
+        npar = cbl_settings['parameter_number']
 
         # Special case: if you re-use a nature run in a experiment without
         # parameter estimation, the extra elements in the state vector must
@@ -269,7 +308,13 @@ class cycle:
         initial_state     = np.zeros((1,state_vector_length,nens)) + np.nan
         backgrounds       = np.zeros((ncycles,state_vector_length,nens)) + np.nan
         analyses          = np.zeros((ncycles,state_vector_length,nens)) + np.nan
-        model_equivalents = np.zeros((ncycles,nobs,nens)) + np.nan        
+        model_equivalents = np.zeros((ncycles,nobs,nens)) + np.nan
+        if cbl_settings['do_parameter_estimation'] and cbl_settings['return_covariances_increments_and_innovations']:
+            increments = np.zeros((ncycles,nobs,npar)) + np.nan
+            cov_py = np.zeros((ncycles,nobs,npar)) + np.nan
+            cov_pp = np.zeros((ncycles,nobs,npar)) + np.nan
+            cov_yy = np.zeros((ncycles,nobs)) + np.nan
+            innov = np.zeros((ncycles,nobs)) + np.nan
 
         # Turn inflation coefficients into an array
         inflation_coefficients_rtps = np.ones(state_vector_length)*inflation_rtps_alpha
@@ -278,6 +323,11 @@ class cycle:
         dist = compute_distances(xcoord,ocoord)
 
         # Replace any NaN distances with 0
+        # NaN distances only occur when the state is augmented with global
+        # parameters, whose location is undefined. Setting the distance to
+        # global parameters equal to zero ensures that their localization
+        # weight is one, so every observation updates them, as required
+        # by global updating.
         np.nan_to_num(dist,copy=False)
 
         # Initialize ensemble
@@ -325,12 +375,25 @@ class cycle:
 
             # Assimilate
             if FILTER == 'EAKF':
-                updates,model_equivalents[k,:,:] = eakf(backgrounds[k,:,:],observations[k,:],
-                            obs_error_sdev_assimilate,
-                            xcoord,
-                            ocoord,
-                            dist,
-                            localization_cutoff=localization_cutoff)
+                if cbl_settings['do_parameter_estimation'] and cbl_settings['return_covariances_increments_and_innovations']:
+                    updates,model_equivalents[k,:,:],cov_xx,cov_xy,cov_yy_dum,innov_dum,increm_dum = eakf(backgrounds[k,:,:],observations[k,:],
+                                obs_error_sdev_assimilate,
+                                xcoord,
+                                ocoord,
+                                dist,
+                                localization_cutoff=localization_cutoff,return_covariances_increments_and_innovations=True)
+                    increments[k,:,-npar] = increm_dum[:,-npar]
+                    cov_pp[k,:,-npar] = cov_xx[:,-npar]
+                    cov_py[k,:,-npar] = cov_xy[:,-npar]
+                    cov_yy[k,:] = cov_yy_dum
+                    innov[k,:] = innov_dum
+                else:
+                    updates,model_equivalents[k,:,:] = eakf(backgrounds[k,:,:],observations[k,:],
+                                obs_error_sdev_assimilate,
+                                xcoord,
+                                ocoord,
+                                dist,
+                                localization_cutoff=localization_cutoff)
             elif FILTER == 'LETKF':
                 updates,model_equivalents[k,:,:] = letkf(backgrounds[k,:,:],observations[k,:],
                             obs_error_sdev_assimilate,
@@ -376,6 +439,12 @@ class cycle:
         self.zt = nr.zt
         self.nens = nens
         self.initial_perturbed_parameters = da.initial_perturbed_parameters
+        if cbl_settings['do_parameter_estimation'] and cbl_settings['return_covariances_increments_and_innovations']:
+            self.cov_py = cov_py
+            self.cov_pp = cov_pp
+            self.cov_yy = cov_yy
+            self.innov = innov
+            self.increments = increments
 
 class experiment:
     def __init__(self,settings):
-- 
GitLab