From 68e599a420b26ea132d21ca1660f2d0122e086fc Mon Sep 17 00:00:00 2001
From: Philipp Griewank <philipp.griewank@univie.ac.at>
Date: Thu, 7 Dec 2023 16:41:17 +0100
Subject: [PATCH] Fixed a bug for LETKF weights with no obs

The weight matrix at grid points where no obs are present was set to
zero instead of to the identity matrix.
---
 da_functions.py | 55 ++++++++++++++++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 17 deletions(-)

diff --git a/da_functions.py b/da_functions.py
index 8b9ac93..b062a10 100644
--- a/da_functions.py
+++ b/da_functions.py
@@ -558,7 +558,7 @@ def run_linear_advection_KF_22(m_const,da_const,sat_operator):
             
             if da_const['method'] == 'LETKF':
                 # analysis,bla = LETKF_analysis_22(background,obs,obs_sat,m_const,da_const,sat_operator)
-                analysis,bla = LETKF_analysis_23(background,obs,obs_sat,m_const,da_const,sat_operator)
+                analysis,bla,W_a = LETKF_analysis_23(background,obs,obs_sat,m_const,da_const,sat_operator)
                 states[j]["an"] = states[j]["an"]+[analysis]
             
             if da_const['method'] == 'sqEnKF':
@@ -2139,15 +2139,21 @@ def LETKF_analysis_23(bg,obs,obs_sat,m_const,da_const,sat_operator):
         x_a = x_a.T
 
     else:
-        x_a, x_ol_a=LETKF_numba_loop(m_const['x_grid'],da_const['loc_length'],obs_loc,x_ol_b,x_b,R,Y_b,L,delta_y,X_b,m_const['nx'],m_const['dx'])
-        
-    return x_a, x_ol_a
-
-@jit
+        # print('making sure that the fix is being used')
+        # print('using the old version')
+        # print('trying the third version')
+        # print('sigh, what is not working now?')
+        # bla, blub,bleurgh=LETKF_numba_loop(m_const['x_grid'],da_const['loc_length'],obs_loc,x_ol_b,x_b,R,Y_b,L,delta_y,X_b,m_const['nx'],m_const['dx'])
+        # print('this is annoying, I hate numba')
+        x_a, x_ol_a,W_a=LETKF_numba_loop(m_const['x_grid'],da_const['loc_length'],obs_loc,x_ol_b,x_b,R,Y_b,L,delta_y,X_b,m_const['nx'],m_const['dx'])
+    return x_a, x_ol_a,W_a
+
+# @jit
 def LETKF_numba_loop(x_grid,loc_length,obs_loc,x_ol_b,x_b,R,Y_b,L,delta_y,X_b,N,dx):
     x_grid_neg = -x_grid[::-1]-dx
     x_grid_ext = x_grid+dx*N
     x_a_loc = np.zeros((N,L))
+    W_a_loc = np.zeros((N,L,L))
     x_ol_a_loc = np.zeros((N))
     
     for g in range(N):
@@ -2173,6 +2179,7 @@ def LETKF_numba_loop(x_grid,loc_length,obs_loc,x_ol_b,x_b,R,Y_b,L,delta_y,X_b,N,
             #If no observations are within the double gc radius no need to bother computing things
             x_ol_a_loc[g] = x_ol_b[g]
             x_a_loc[g,:]  = x_b[g,:]
+            W_a_loc[g,:,:] = np.identity(L)
         else:
             R_inv =  np.linalg.inv(R) * np.diag(gc_loc)
 
@@ -2186,12 +2193,21 @@ def LETKF_numba_loop(x_grid,loc_length,obs_loc,x_ol_b,x_b,R,Y_b,L,delta_y,X_b,N,
             # Trying to replace sqrtm with numpy eigenvectors so jit can work.
             evalues, evectors = np.linalg.eigh((L-1.)*P_tilde_a)
             W_a = evectors * np.sqrt(evalues) @ np.linalg.inv(evectors)
-            w_a = W_a+w_ol_a
-            x_a =  np.dot(X_b[g,:],w_a).T+ x_ol_a
-            x_ol_a_loc[g] = x_ol_a
+            # w_a = W_a+w_ol_a
+            # x_a =  np.dot(X_b[g,:],w_a).T+ x_ol_a #this seems to be wrong
+            # x_a =  np.dot(X_b[g,:],w_a).T+ x_ol_b[g] # this should be right but isn't returning the right stuff
+            # This fix somehow makes things worse. I'll try something slightly different
+            # We'll try Xa = Xb Wa, and then add x_ol_a to that
+            X_a = np.dot(X_b[g,:],W_a)
+            x_a = X_a+x_ol_a
+            x_ol_a_loc[g]   = x_ol_a
             x_a_loc[g,:]    = x_a
+            W_a_loc[g,:,:]  = W_a
+            # print(W_a_loc.shape)
+            # print(x_a_loc.shape)
+            # print(x_ol_a_loc.shape)
 
-    return x_a_loc,x_ol_a_loc
+    return x_a_loc,x_ol_a_loc, W_a_loc
 
 
 def state_to_observation_space(X,m_const,da_const,sat_operator):
@@ -2237,7 +2253,9 @@ def state_to_observation_space(X,m_const,da_const,sat_operator):
 def Kalman_gain_observation_deviations(bg,m_const,da_const,sat_operator):
     """
     Calculates Kalman gain, but using YYT instead of HBH, and XYT intead of BHT,
-    This is important for satellite data because we don't need a linearized version of the satellite operator 
+    This is important for satellite data because we don't need a linearized version of the satellite operator
+
+    Important fix! Localization was applied incorrectly: instead of localizing only BHT (here XYT), the whole Kalman gain was localized.
     """
     n_obs_h  =len(da_const["obs_loc"])
     n_obs_sat =len(da_const["obs_loc_sat"])
@@ -2255,7 +2273,8 @@ def Kalman_gain_observation_deviations(bg,m_const,da_const,sat_operator):
     if da_const['loc']:
         L_obs, L_obs_state = localization_matrices_observation_space(m_const,da_const)
         YYlocR_inv = np.linalg.inv(L_obs*np.dot(dY_b,dY_b.T)/(L-1)+R)
-        K = L_obs_state*np.dot(X_b,np.dot(dY_b.T,YYlocR_inv))/(L-1)
+        # K = L_obs_state*np.dot(X_b,np.dot(dY_b.T,YYlocR_inv))/(L-1)
+        K = np.dot(L_obs_state*np.dot(X_b,dY_b.T),YYlocR_inv)/(L-1)
     else:
         YYR_inv = np.linalg.inv(np.dot(dY_b,dY_b.T)/(L-1)+R)
         K = np.dot(X_b,np.dot(dY_b.T,YYR_inv))/(L-1)
@@ -2358,6 +2377,11 @@ def single_step_analysis_forecast_22(background,truth,da_const,m_const,sat_opera
     """
     obs, obs_sat = generate_obs_22_single(truth,m_const,da_const,sat_operator,obs_seed) 
     
+    """
+    create dictionary to store this single step 
+    """
+    quad_state = {}
+    
     #Getting the analysis
     if da_const['method'] == 'EnKF':
         np.random.seed(obs_seed+100000)
@@ -2365,7 +2389,8 @@ def single_step_analysis_forecast_22(background,truth,da_const,m_const,sat_opera
     
     if da_const['method'] == 'LETKF':
         # an,bla = LETKF_analysis_22(background,obs,obs_sat,m_const,da_const,sat_operator)
-        an,bla = LETKF_analysis_23(background,obs,obs_sat,m_const,da_const,sat_operator)
+        an,bla,W_a = LETKF_analysis_23(background,obs,obs_sat,m_const,da_const,sat_operator)
+        quad_state['W_a'] = W_a
     
     if da_const['method'] == 'sqEnKF':
         an = sqEnKF_analysis_22(background,obs,obs_sat,da_const,m_const,sat_operator)
@@ -2388,10 +2413,6 @@ def single_step_analysis_forecast_22(background,truth,da_const,m_const,sat_opera
         fc[:,i]    = linear_advection_model(an[:,i],u_ens,dhdt_ens,m_const["dx"],da_const["dt"],da_const["nt"])
 
     
-    """
-    create dictionary to store this single step 
-    """
-    quad_state = {}
     quad_state['bg'] = background
     quad_state['an'] = an
     quad_state['bf'] = bf
-- 
GitLab