From 4722007500e5265ac3558e6b36b8cf318de359b1 Mon Sep 17 00:00:00 2001
From: Philipp Griewank <philipp.griewank@uni-koeln.de>
Date: Thu, 28 Oct 2021 14:04:34 +0200
Subject: [PATCH] Messy update

The functions now included are the backbone of the version used for the initial paper draft. This was unceremoniously dumped here before I begin trying to localize the implicit sensitivity approach.
---
 da_functions.py    | 439 ++++++++++++++++++++++++++++++++++++++++++++-
 model_functions.py |   3 +
 plot_functions.py  |  10 +-
 3 files changed, 445 insertions(+), 7 deletions(-)

diff --git a/da_functions.py b/da_functions.py
index d6926c8..9194800 100644
--- a/da_functions.py
+++ b/da_functions.py
@@ -251,6 +251,21 @@ def get_analysis(bg,obs,K,H,da_const):
     an =  bg + np.dot(K,obs_pert-bg_obs)    
     return an
 
+def get_analysis_v2(bg, obs, K, H, obs_error_vec):
+    """
+    Computes the analysis: an = bg + K*(obs_pert - H*bg), where obs_pert are the perturbed observations mapped to observation space. v2 comes with an obs_error_vec so that different observations can have different errors. 
+    Mostly made to deal with the shallow water model, but I guess it could be used to make some fun tests. 
+    """
+    #print(obs)
+    #print(obs_error_vec)
+    obs_pert = np.dot(H,obs+np.random.normal(0,obs_error_vec,len(obs)))
+    bg_obs = np.dot(H,bg)
+    #obs_orig =np.dot(H,obs) 
+    #print('obs_orig-bg_obs',np.sum(obs_orig-bg_obs),'obs_pert-bg_obs',np.sum(obs_pert-bg_obs),'update/np.sum(x)',np.sum(np.dot(K,obs_pert-bg_obs) )/np.sum(bg),'max(K)',np.max(K))
+    an =  bg + np.dot(K,obs_pert-bg_obs)    
+    return an
+
+
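+# Illustrative sketch (hypothetical helper, not part of the processing chain): assumes the
+# stacked SWM state/observation ordering [u, h, r], each of length nx, and mirrors how
+# single_step_analysis_forecast_v2 below assembles its per-variable perturbation vector.
+def example_swm_analysis_call(bg, obs, K, H, nx, u_std, h_std, r_std):
+    """
+    bg and obs are full stacked vectors of length 3*nx, K and H the gain and observation
+    operator from the EnKF step. Builds a per-variable observation error vector and hands
+    it to get_analysis_v2.
+    """
+    obs_error_vec = np.hstack([np.ones(nx)*u_std,   # u observation errors
+                               np.ones(nx)*h_std,   # h observation errors
+                               np.ones(nx)*r_std])  # r observation errors
+    return get_analysis_v2(bg, obs, K, H, obs_error_vec)
+
+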
 def run_linear_advection_KF(m_const,da_const):
     """
     The heart and soul of the whole linear advection EnKF filter. 
@@ -961,7 +976,7 @@ def LETKF_analysis(bg,obs,m_const,da_const):
 
 
 
-def L2_regularized_inversion(A, b, alpha_init=0.1,alpha=None,mismatch_threshold=0.05):
+def L2_regularized_inversion(A, b, alpha_init=1,alpha=None,mismatch_threshold=0.05):
     """Instead of solving for Ax=b, which isn't possible if A is not invertible, the regularization instead minimizes ||Ax-b||^2 + || alpha x ||^2.
     The solution is unique and well defined: x = (AA.T + alpha*alpha I )^-1 A.T b.
     
@@ -991,7 +1006,7 @@ def L2_regularized_inversion(A, b, alpha_init=0.1,alpha=None,mismatch_threshold=
         #x = np.linalg.inv(A.T.dot(A) + alpha**2 *I).dot(A.T).dot(b)
         x = np.linalg.solve((A.T.dot(A) + alpha**2 *I),(A.T).dot(b))#,rcond=-1)
         while  np.sum(np.abs(A.dot(x)-b))/np.sum(np.abs(b))> mismatch_threshold:
-            alpha = alpha/2
+            alpha = alpha/2.
             x = np.linalg.solve((A.T.dot(A) + alpha**2 *I),(A.T).dot(b))#,rcond=-1)
             #x = np.linalg.inv(A.T.dot(A) + alpha**2 *I).dot(A.T).dot(b)
             #print('reducing regularization:',alpha,np.sum(np.abs(A.dot(x)-b))/np.sum(np.abs(b)))
@@ -1004,5 +1019,425 @@ def L2_regularized_inversion(A, b, alpha_init=0.1,alpha=None,mismatch_threshold=
 
 
 
+def single_step_analysis_forecast_v2(background,truth,da_const,m_const,model_seed=0,obs_seed=0):
+    """
+    The idea is that this should be merged into single_step_analysis_forecast, for a uniform processing chain for the paper. So all SWM changes will be included via flags. 
+    
+    What is a bit annoying is that the original was built around the states dictionary. I am not really a fan of the states setup, so I will switch to background and truth matrices.
+
+    For now the perturbed observations in the EnKF for the SWM are the same as the actual errors. 
+    """
+    
+    """
+    constant matrices that follow from previously defined constants
+    """
+    if m_const['model'] == 'LA':
+        H = np.identity(m_const["nx"])[da_const["obs_loc"],:]                              # Observation operator    
+        R = da_const["used_std_obs"]**2*np.identity(len(da_const["obs_loc"]))           # Observation error covariance matrix
+    if m_const['model'] == 'SWM':
+        H = np.identity(m_const["nx"]*3)[da_const["obs_loc"],:]                              # Observation operator    
+        # pick the observation error according to which section of the stacked [u,h,r] state the observation index falls into
+        obs_loc_vec = np.array(da_const["obs_loc"])
+        obs_error_vec = np.ones(len(obs_loc_vec))*da_const['used_h_std_obs']
+        obs_error_vec[obs_loc_vec<m_const['nx']] = da_const['used_u_std_obs']
+        obs_error_vec[obs_loc_vec>=2*m_const['nx']] = da_const['used_r_std_obs']
+        R = np.diag(obs_error_vec**2)
+    
+
+    #Construct localization matrix C if loc!=None
+    if da_const["loc"]:
+        C = loc_matrix(da_const,m_const)
+    else:
+        C=np.ones([m_const["nx"],m_const["nx"]])
+    if m_const['model'] == 'SWM':
+        C = np.hstack([C,C,C])
+        C = np.vstack([C,C,C])
+
+    """
+    Generate obs
+    """
+    #make obs by adding some noise; even with a fixed seed setup, obs_seed should change over time so that the differences are not always the same for each measurement location
+    np.random.seed(obs_seed)
+    
+    if m_const['model']=='LA': obs = truth + np.random.normal(0,da_const["True_std_obs"],m_const["nx"])
+    if m_const['model']=='SWM': 
+        u_obs_noise =np.random.normal(0,da_const["True_u_std_obs"],m_const["nx"]) 
+        h_obs_noise =np.random.normal(0,da_const["True_h_std_obs"],m_const["nx"]) 
+        r_obs_noise =np.random.normal(0,da_const["True_r_std_obs"],m_const["nx"]) 
+        obs = truth + np.hstack([u_obs_noise,h_obs_noise,r_obs_noise])
+    
+    #Generate new truth constants and integrate in time
+    if m_const['model']=='LA':
+        if da_const["fixed_seed"]==True: np.random.seed(model_seed)
+        u_truth = np.random.normal(m_const["u_ref"],m_const["u_std_truth"])
+        dhdt_truth = np.random.normal(m_const["dhdt_ref"],m_const["dhdt_std_truth"])
+        truth_forecast = linear_advection_model(truth,u_truth,dhdt_truth,m_const["dx"],da_const["dt"],da_const["nt"])
+    
+    if m_const['model']=='SWM':
+        #I assume we are going to have to add an empty second dimension here
+        truth_double  = np.vstack([truth,truth]).T
+        truth_double_forecast  = shallow_water(truth_double,m_const)
+        truth_forecast = truth_double_forecast[:,0]
+    
+
+    """
+    EnKF
+    """
+    if da_const['method'] == 'EnKF':
+        # Compute the background error covariance matrix
+        P = np.cov(background)*C
+    
+        # define relative weights of observations and background in terms of the Kalman gain matrix
+        K = KalmanGain(P, R, H)
+    
+        # Compute the analysis for each ensemble members
+        an = np.zeros_like(background)
+        if m_const['model'] == 'LA': obs_pert_vec = np.ones(len(obs))*da_const['pert_std_obs']
+        if m_const['model'] == 'SWM': 
+            obs_pert_vec = np.hstack([np.ones(m_const['nx'])*da_const['pert_u_std_obs'],
+                                      np.ones(m_const['nx'])*da_const['pert_h_std_obs'],
+                                      np.ones(m_const['nx'])*da_const['pert_r_std_obs']]) 
+        for i in range(da_const["nens"]):
+            an[:,i] = get_analysis_v2(background[:,i],obs,K,H,obs_pert_vec)
+    """
+    LETKF
+    """
+    if da_const['method'] == 'LETKF':
+        an,bla = LETKF_analysis(background,obs,m_const,da_const)   
+    
+
+    """
+    Predict blind forecast and forecast
+    """
+    bf = np.zeros_like(background)
+    fc = np.zeros_like(background)
+    
+    if m_const['model'] == 'LA':
+        for i in range(da_const["nens"]):
+            if da_const["fixed_seed"]==True: np.random.seed(i+model_seed*da_const["nens"])
+            u_ens      = np.random.normal(m_const["u_ref"],da_const["u_std_ens"])
+            dhdt_ens   = np.random.normal(m_const["dhdt_ref"],da_const["dhdt_std_ens"])
+            bf[:,i]    = linear_advection_model(background[:,i],u_ens,dhdt_ens,m_const["dx"],da_const["dt"],da_const["nt"])
+            fc[:,i]    = linear_advection_model(an[:,i],u_ens,dhdt_ens,m_const["dx"],da_const["dt"],da_const["nt"])
+    if m_const['model'] == 'SWM':
+        bf  = shallow_water(background,m_const)
+        fc  = shallow_water(an,m_const)
+
+    
+    """
+    create dictionary to store this single step 
+    """
+    quad_state = {}
+    quad_state['bg'] = background
+    quad_state['an'] = an
+    quad_state['bf'] = bf
+    quad_state['fc'] = fc
+    quad_state['tr_fc'] = truth_forecast
+    quad_state['tr_bg'] = truth
+    quad_state['obs'] = obs
+    return quad_state
+
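+# Small diagnostic sketch (hypothetical, not used in the paper chain): shows how the
+# quad_state dictionary returned by single_step_analysis_forecast_v2 can be consumed.
+def quad_state_rmse(quad_state):
+    """
+    RMSE of the ensemble-mean background and forecast against the truths stored in
+    quad_state ('tr_bg' and 'tr_fc').
+    """
+    rmse = {}
+    rmse['bg'] = np.sqrt(np.mean((np.mean(quad_state['bg'], axis=1) - quad_state['tr_bg'])**2))
+    rmse['fc'] = np.sqrt(np.mean((np.mean(quad_state['fc'], axis=1) - quad_state['tr_fc'])**2))
+    return rmse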
+
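+# Illustrative response function sketch (hypothetical, not wired into any experiment): a
+# variant of the 'sum over first and last third' idea listed in the vr_reloaded docstring
+# below. Any function of the state vector with this signature can be passed as func_J.
+def sum_first_third(x):
+    """
+    Sum of the state over the first third of the domain.
+    """
+    n = len(x)
+    return np.sum(x[:n//3])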
+
+def vr_reloaded(background,truth,m_const,da_const,func_J=sum_mid_tri,
+                reduc = 1,reg_flag=1,
+                quad_state = None,dJdx_inv=None,alpha=None,mismatch_threshold=0.1,
+                iterative_flag=1,explicit_sens_flag = 1,exp_num=0):
+
+    """
+    New version takes only the background and truth, then calculates the quad, and finally calculates the response function and the variance reduction. 
+    
+    The quad can be supplied to save time, which makes sense when you are changing the response function. 
+    It is also the way to go if you don't need the real reduction value, only the estimate. 
+    
+    Should also return the dJ values of all 4 ensembles (analysis, background, forecast, blind forecast)
+    
+    exp_num (experiment number) needs to be passed on to the LA forecast somehow, because it determines the random model error of the ensembles. 
+    This will become a major part of the paper, and will see lots of refining. Important options are: 
+    
+    not implemented:
+    - state model reduction
+    
+    response functions planned:
+    - mid tri, sum over middle of domain.
+    - sum over first and last third
+    - Triangle sum
+    - right/left rain amount. 
+    
+    Speed up options: 
+    - If the same forecast can be recycled, the quad can be calculated once and then passed on. 
+    - If the sensitivity can be recycled, that can be calculated once and then reused. 
+    
+    implemented:
+    -iteratively and all at once
+    -explicit vs implicit 
+    -reduced model spacing by subsampling the grid. reduc is the spacing, e.g. reduc=3 means every third grid point is used;
+     each observation is mapped to the nearest reduced model grid point.
+    -reg_flag added to use L2 regularization
+    """
+    if iterative_flag ==0: from scipy.linalg import sqrtm
+    ###########################################################################
+    #First, we calculate the quad (analysis, forecast, blind forecast)
+    #A precomputed quad_state can be supplied to quickly use different response functions
+    ###########################################################################
+    if quad_state is None:
+        quad_state = single_step_analysis_forecast_v2(background,truth,da_const,m_const)
+    
+    ###########################################################################
+    # Next we need the response functions. 
+    # We only really need the forecast and blind forecast, but for fun I'll calculate all for now
+    ###########################################################################
+    
+    #For now I am not worried about being efficient
+    nobs = len(da_const["obs_loc"])
+    obs = np.arange(nobs)
+    nens = da_const["nens"]
+    nstate = len(truth)
+
+    bf_response = np.zeros(nens)
+    fc_response = np.zeros(nens)
+    an_response = np.zeros(nens)
+    bg_response = np.zeros(nens)
+    for n in range(da_const["nens"]):
+        bf_response[n] = func_J(quad_state["bf"][:,n])
+        fc_response[n] = func_J(quad_state["fc"][:,n])
+        an_response[n] = func_J(quad_state["an"][:,n])
+        bg_response[n] = func_J(quad_state["bg"][:,n])
+        
+    J_dict = {}
+    J_dict['bf']  = bf_response
+    J_dict['bg']  = bg_response
+    J_dict['an']  = an_response
+    J_dict['fc']  = fc_response
+    J_dict['tr_bg']  = func_J(quad_state['tr_bg'])
+    J_dict['tr_fc']  = func_J(quad_state['tr_fc'])
+    
+    ###########################################################################
+    # Creating the R,H, and C matrices we will need for the VR estimate 
+    ###########################################################################
+    
+    if m_const['model'] == 'LA':
+        H = np.identity(m_const["nx"])[da_const["obs_loc"],:]                              # Observation operator    
+        R = da_const["used_std_obs"]**2*np.identity(len(da_const["obs_loc"]))           # Observation error covariance matrix
+    if m_const['model'] == 'SWM':
+        H = np.identity(m_const["nx"]*3)[da_const["obs_loc"],:]                              # Observation operator    
+        # pick the observation error according to which section of the stacked [u,h,r] state the observation index falls into
+        obs_loc_vec = np.array(da_const["obs_loc"])
+        obs_error_vec = np.ones(len(obs_loc_vec))*da_const['used_h_std_obs']
+        obs_error_vec[obs_loc_vec<m_const['nx']] = da_const['used_u_std_obs']
+        obs_error_vec[obs_loc_vec>=2*m_const['nx']] = da_const['used_r_std_obs']
+        R = np.diag(obs_error_vec**2)
+    
+    if da_const["loc"]:
+        C = loc_matrix(da_const,m_const)
+    else:
+        C=np.ones([m_const["nx"],m_const["nx"]])
+    if m_const['model'] == 'SWM':
+        C = np.hstack([C,C,C])
+        C = np.vstack([C,C,C])
+    
+    
+    
+    
+
+
+    if reduc>1: #Does currently not work for 
+        #Defining reduced model domain/state, and then defining obs location on new reduced grid
+        #For the SWM model this only really makes sense if nx is a multiple of reduc 
+        reduced_model_domain = np.arange(0,nstate,reduc)
+        reduced_model_size = len(reduced_model_domain)
+        reduced_obs_loc = np.zeros(nobs).astype(int)
+        for o in range(nobs):
+            reduced_obs_loc[o] = (np.abs(reduced_model_domain-da_const['obs_loc'][o])).argmin()
+        #print('reduced_model_domain:',reduced_model_domain)
+        #print('reduced_model_size  :',reduced_model_size  )
+        #print('reduced_obs_loc     :',reduced_obs_loc     )
+            
+        #Getting the localization matrix in reduced model space
+        H = np.identity(m_const["nx"])[reduced_model_domain,:]   
+        C = np.dot(H,np.dot(C,H.T))
+        x = quad_state['bg'][reduced_model_domain,:]
+        
+    else:
+        x = quad_state['bg'][:,:]
+        
+        #x = states[j]['bg'][t][:,:]
+    dx = x.T-np.mean(x,axis=1)
+    dx = dx.T
+    dx_orig = dx+0
+    
+    A = np.dot(dx,dx.T)/(dx.shape[1]-1)
+
+    J  = bf_response
+    dJ      = J-np.mean(J)
+    dJ_orig = J-np.mean(J)
+
+
+    ###############################################################################################
+    # Sensitivity
+    ###############################################################################################
+    #Covariance between response function perturbations dJ and state ensemble perturbations dx
+    cov_dJdx_vec = np.dot(dJ,dx.T)/(dx.shape[1]-1)
+
+    if explicit_sens_flag==1:
+        #If a sensitivity is provided it is used instead of calculating it 
+        if type(dJdx_inv) == np.ndarray:
+            #testing supplied sensitivity
+            rel_error_sens = np.sum(np.abs(A.dot(dJdx_inv)-cov_dJdx_vec))/np.sum(np.abs(cov_dJdx_vec))
+            if rel_error_sens>0.05:
+                print('using supplied sensitivity has a relative error of:',rel_error_sens)
+        #Computing the sensitivity; using the regularized version is highly recommended. 
+        else:
+            if reg_flag == 1:
+                dJdx_inv = L2_regularized_inversion(A,cov_dJdx_vec,alpha=alpha,mismatch_threshold=mismatch_threshold)
+            else:
+                A_inv = np.linalg.pinv(A)
+                dJdx_inv = np.dot(A_inv,cov_dJdx_vec)
+
+    estimated_J = bf_response + 0.
+    
+    
+    if iterative_flag ==0:
+        
+        vr_individual = 0.
+        
+        # note: this rebuilds H and R for the LA setup, overriding the model-dependent versions constructed above
+        H = np.identity(m_const["nx"])[da_const["obs_loc"],:]                              # Observation operator    
+        R = da_const["used_std_obs"]**2*np.identity(len(da_const["obs_loc"]))           # Observation error covariance matrix
+            
+        
+        R_obs = R # Observation error covariance matrix
+        H_obs = H # Not the cleanest coding I know
+        
+        if explicit_sens_flag ==1:
+            #Tanja's approach of calculating the square-root Kalman gain, following formula (10) of Whitaker and Hamill (2002)
+            E = np.matmul(C*A,H_obs.T)
+            E = np.matmul(H_obs,E)
+            E = E + R_obs
+            Esqrt = sqrtm(E)
+            #alpha = 1./(1.+np.sqrt(R_obs/E))
+            # Kalman gain, formula (10) of Whitaker and Hamill (2002)
+            K1 = np.matmul(C*A,H_obs.T)
+            K2 = np.linalg.inv(Esqrt)
+            K2 = K2.T
+            K3 = np.linalg.inv(Esqrt + sqrtm(R_obs))
+            K  = np.matmul(K1,K2)
+            K  = np.matmul(K,K3)
+
+            Hdx = np.matmul(H_obs,dx)
+            dx_prime = dx - np.dot(K,Hdx)
+            
+            #estimated J calculated by updating the original dJ
+            estimated_J = estimated_J -np.dot(dJdx_inv,np.dot(K,Hdx))
+            
+            #Using the cheaper variance calculation instead of going to A-B
+            new_J = np.dot(dJdx_inv.T,dx_prime)
+            vr_individual = np.var(new_J,ddof=1)-np.var(np.dot(dJdx_inv.T,dx),ddof=1)
+            vr_total = vr_individual
+            #print('all at once:',vr_individual)
+            dx = dx_prime
+        
+        if explicit_sens_flag ==0: 
+            HAHt = np.dot(H_obs,np.dot(C*A,H_obs.T))
+            HAHtRinv= np.linalg.inv(HAHt+R_obs)
+
+            dJHdxt           =  np.dot(dJ,np.dot(H_obs,dx).T)/(nens-1)
+            vr_individual = -np.dot(dJHdxt,np.dot(HAHtRinv,dJHdxt))
+            vr_total = vr_individual
+    
+    
+    if iterative_flag:
+        vr_individual = np.zeros(nobs)
+        for o in range(nobs):  #loop over each observation individually
+
+            #New A after dx was updated 
+            A = np.cov(dx,ddof=1)
+
+            #Selecting the single R value for this observation
+            R_obs = R[o,o]*np.identity(1)           # Observation error covariance matrix
+            H_obs = H[o,:]   #Not sure about this :(
+            #if 
+            ##H_rms = np.identity(nobs)[o,:] 
+            #H_rms = np.identity(reduced_model_size)[reduced_obs_loc[o],:] 
+            ##H = np.identity(m_const["nx"])[da_const["obs_loc"][o],:]
+#
+            ##HAHt = np.dot(H,np.dot(A,H.T))
+            ##HAHtRinv= np.linalg.inv(HAHt+R)
+            ##dJHdxt = np.dot(dJ,np.dot(H,dx).T)/(nens-1)
+            
+            #Now we get the change in dx, which uses the localized A matrix and a square root 
+            E = np.matmul(C*A,H_obs.T)
+            E = np.matmul(H_obs,E)
+            E = E + R_obs
+            alpha = 1./(1.+np.sqrt(R_obs/E))
+
+
+            HAHt = np.dot(H_obs,np.dot(C*A,H_obs.T))
+            HAHtRinv= np.linalg.inv(HAHt+R_obs)
+            #print((C*A).shape,H_obs.T.shape,np.dot(C*A,H_obs.T).shape,HAHtRinv.shape)
+    ##
+
+            if explicit_sens_flag ==1: 
+            
+                K = np.dot(C*A,H_obs.T)*HAHtRinv
+                # Update state vector
+                Hdx = np.matmul(H_obs,dx)
+                dx_prime = dx - alpha*np.outer(K,Hdx)
+            
+                #Update expected J for each ensemble member
+                estimated_J = estimated_J -np.dot(dJdx_inv,alpha*np.outer(K,Hdx))
+
+                #A_new = np.dot(dx,dx.T)/(dx.shape[1]-1)
+                #Now we get the variance estimate using the difference between new and old B
+                #dSigma = np.matmul(A_new-A,dJdx_inv.T)
+                #print('wtf: np.sum(np.abs(A_new-A_old))  ',o,np.sum(np.abs(A_new-A)))
+                #dSigma = np.matmul(dJdx_inv,dSigma)
+                #vr_individual[o] = dSigma#np.matmul(np.atleast_2d(np.diag(A_new-A)),(dJdx_inv.T**2))
+                
+                vr_individual[o] = np.var(np.dot(dJdx_inv.T,dx_prime),ddof=1)-np.var(np.dot(dJdx_inv.T,dx),ddof=1)
+                dx = dx_prime
+            if explicit_sens_flag==0:
+                dJHdxt           = np.dot(dJ,np.dot(H_obs,dx).T)/(nens-1)
+                vr_individual[o] = -np.dot(dJHdxt,np.dot(HAHtRinv,dJHdxt))
+                
+                #Still needs localization! 
+                
+                #Now we include dJ into dx to include it in the update
+                dxJ = np.vstack([dx,dJ])
+                AxJ = np.dot(dxJ,dxJ.T)/(dxJ.shape[1]-1)
+                HxJ = np.hstack([H_obs,np.array(0)])
+                CxJ = np.ones([nstate+1,nstate+1])
+                CxJ[:nstate,:nstate] = C
+                HAHt = np.dot(HxJ,np.dot(CxJ*AxJ,HxJ.T))
+                HAHtRinv= np.linalg.inv(HAHt+R_obs)
+
+                #print(AxJ.shape,CxJ.shape)
+                #print(AxJ.shape,HxJ.T.shape,np.dot(AxJ,HxJ.T).shape,HAHtRinv.shape)
+                K = np.dot(CxJ*AxJ,HxJ.T)*HAHtRinv
+
+
+                # Update state vector
+                HdxJ = np.matmul(HxJ,dxJ)
+                dxJ = dxJ - alpha*np.outer(K,HdxJ)
+
+                dx = dxJ[:-1,:]
+                #old_var_dJ = np.var(dJ)
+                dJ = dxJ[-1,:]
+                estimated_J = dJ+np.mean(bf_response)
+                #new_var_dJ = np.var(dJ)
+            
+        #Final variance reduction estimate
+        if explicit_sens_flag==0: vr_total=np.sum(vr_individual)
+        
+        if explicit_sens_flag==1: 
+            vr_total=np.var(np.dot(dJdx_inv.T,dx),ddof=1)-np.var(np.dot(dJdx_inv.T,dx_orig),ddof=1)
+            
+            #Checking different formulation
+
+
+    #
+    
+    J_dict['es'] =  estimated_J
+    
+    J_fc= fc_response
+    dJ_fc = J_fc-np.mean(J_fc)
+    real_reduction=np.var(dJ_fc) - np.var(dJ_orig)
+        
+    return vr_total,vr_individual,real_reduction,J_dict,dJdx_inv,quad_state,dx
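+
+
+# Minimal usage sketch (hypothetical helper, not part of the paper chain): background is a
+# (nstate, nens) ensemble matrix and truth a length-nstate vector, as in vr_reloaded above.
+def compare_vr_estimate(background, truth, m_const, da_const):
+    """
+    Runs vr_reloaded once with its defaults and returns the estimated variance reduction
+    alongside the real reduction diagnosed from the forecast ensemble, as a quick
+    sanity check of the estimate.
+    """
+    vr_total, vr_individual, real_reduction, J_dict, dJdx_inv, quad_state, dx = vr_reloaded(
+        background, truth, m_const, da_const)
+    return vr_total, real_reduction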
 
 
diff --git a/model_functions.py b/model_functions.py
index ad05856..26505af 100644
--- a/model_functions.py
+++ b/model_functions.py
@@ -26,6 +26,9 @@ def set_model_constants(nx=101,dx=100,u=2.33333333,u_std=0,dhdt=0,dhdt_std=0,h_i
     const["dhdt_ref"]        = dhdt        # reference dhdt around which ensemble is generated
     const["dhdt_std_truth"]  = dhdt_std   # standard deviation of dhdt for truth
     
+    #model
+    const["model"] = 'LA'
+
     return const
 
 def gaussian_initial_condition(x,sig):
diff --git a/plot_functions.py b/plot_functions.py
index 72bdc57..ae06f44 100644
--- a/plot_functions.py
+++ b/plot_functions.py
@@ -87,15 +87,15 @@ def quad_plotter(quad_state,m_const,da_const):
     
     for i in range(da_const["nens"]):
         ax[0,0].plot(m_const['x_grid'],quad_state['bg'][:,i],'r',alpha =alpha,zorder=1)
-        ax[0,1].plot(m_const['x_grid'],quad_state['bf'][:,i],'magenta',alpha =alpha,zorder=1)
-        ax[1,0].plot(m_const['x_grid'],quad_state['an'][:,i],'b',alpha =alpha,zorder=1)
+        ax[0,1].plot(m_const['x_grid'],quad_state['bf'][:,i],'b',alpha =alpha,zorder=1)
+        ax[1,0].plot(m_const['x_grid'],quad_state['an'][:,i],'magenta',alpha =alpha,zorder=1)
         ax[1,1].plot(m_const['x_grid'],quad_state['fc'][:,i],'c',alpha =alpha,zorder=1)
     
 
     ax[0,0].plot(m_const["x_grid"],quad_state['tr_bg'],'k',zorder=10,label='truth')
     ax[1,0].plot(m_const["x_grid"],quad_state['tr_bg'],'k',zorder=10,label='truth')
-    ax[0,1].plot(m_const["x_grid"],quad_state['tr_fc'],'k',zorder=10,label='truth')
-    ax[1,1].plot(m_const["x_grid"],quad_state['tr_fc'],'k',zorder=10,label='truth')
+    #ax[0,1].plot(m_const["x_grid"],quad_state['tr_fc'],'k',zorder=10,label='truth')
+    #ax[1,1].plot(m_const["x_grid"],quad_state['tr_fc'],'k',zorder=10,label='truth')
     
     ax[0,0].plot(m_const['x_grid'],np.mean(quad_state['bg'][:,:],axis=1),'k--',alpha =1,zorder=2,label='ens mean')
     ax[1,0].plot(m_const['x_grid'],np.mean(quad_state['an'][:,:],axis=1),'k--',alpha =1,zorder=2,label='ens mean')
@@ -116,7 +116,7 @@ def quad_plotter(quad_state,m_const,da_const):
     ax[1,0].set_xlabel('x [m]')
     ax[0,0].set_ylabel('h [m]')
     ax[1,0].set_ylabel('h [m]')
-    plt.legend()
+    ax[1,0].legend()
     return fig,ax
 
 
-- 
GitLab