#!/usr/bin/env python

#Contains all functions related to setting up and running the ensemble Kalman filter. There is some overlap with the model functions, but keeping the two separated seemed a bit cleaner. 

import numpy as np
from model_functions import *

def set_da_constants(ncyc=10,nt=1,dt=500,u_std=0.5,dhdt_std=1e-4,True_std_obs=0.1,used_std_obs=0.1,pert_std_obs=0.00,obs_loc_h=np.arange(5,101,15),nens=100,nexp=1,init_noise=0.,fixed_seed=True,
                     loc=None,init_spread = False, init_spread_h=0.5,init_spread_x = 500.,
                     loc_length = 1000,loc_type='gaussian'):
    """
    Sets constants needed for data assimilation and stores them in a dictionary. 
    There is some confusing misnaming going on, e.g. "u_std_ens" = u_std, but nothing that needs to be changed immediately

    nt is not really relevant as long as we use the linear advection model, but I left it in from Yvonne in case we start applying it to different models. 
    
    TODO:
    - Figure out what size std_obs should be.
    - Add localization
    - Add different assimilation options (e.g. LETKF)
    
    """
    DA_const = {}
    
     
    DA_const["ncyc"] = ncyc                   # Number of assimilation cycles
    DA_const["nens"] = nens                   # number of ensemble members
    DA_const["nexp"] = nexp                   # number of parallel experiments to run
    
    
    DA_const["nt"] = nt                       # Number of model timesteps between observations
    DA_const["dt"] = dt                       # time duration of model timesteps 
    
    
    #Ensemble Errors and ensemble  
    DA_const["u_std_ens"]    = u_std          # Standard deviation of model u 
    DA_const["dhdt_std_ens"] = dhdt_std       # Standard deviation of model dydt 
    DA_const["fixed_seed"]   = fixed_seed     # If True, then the ensemble number is used as a seed, so that the u and dhdt values are always randomized the same way
    
    #Ensemble initialization
    DA_const["init_noise"]    = init_noise    # spread of noise added to initial conditions to avoid singular matrices
    DA_const["init_spread"]   = init_spread   # If True, adds an x and h displacement to initial ensemble spread 
    DA_const["init_spread_h"] = init_spread_h # initial spread of ensemble in h 
    DA_const["init_spread_x"] = init_spread_x # initial spread of ensemble in x
    
    #Observations
    DA_const["True_std_obs"] = True_std_obs   # Standard deviation of true observation error of y 
    DA_const["used_std_obs"] = used_std_obs   # Standard deviation of assumed observation error used to calculate R
    DA_const["pert_std_obs"] = pert_std_obs   # Standard deviation of pertubations added to the true observation for each ensemble member individually when updating
    DA_const["obs_loc"] = obs_loc_h           # index array of which state elements are observed
    
    
    #Localization 
    DA_const["loc"] = loc                     # localization yes or no, no for now
    DA_const["loc_length"] = loc_length                     # localization yes or no, no for now
    DA_const["loc_type"] = loc_type                    # localization yes or no, no for now
        
    return DA_const
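
# A minimal usage sketch of set_da_constants (the keyword values here are purely
# illustrative, not recommended settings):
#
#   da_const = set_da_constants(ncyc=20, nens=50, obs_loc_h=np.arange(5, 101, 10),
#                               loc=True, loc_length=2000, loc_type='gaspari_cohn')
#   da_const["nens"]   # -> 50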
    
def create_states_dict(j,states,m_const,da_const):
    """
    Generates the initial analysis ensemble and truth.
    Also creates the "states" dictionary in which the analysis ensemble, background ensemble, truth, and observations of every assimilation cycle are stored,
    structured as states[experiment]['an'/'bg'/'truth'/'obs'][cycle].
    Very memory hungry, as everything from all assimilation time steps and all experiments is stored, but that works fine for this simple model.

    A fixed seed is used by default so that the model errors of all ensemble members are constant. But this can also be randomized.
    An alternative would be to generate the model errors and store them, but this has not happened yet; it would be necessary for parameter estimation tests.

    Modified version of the Yvonne setup to work with the linear advection model.

    Todo:
    - make model errors stored variables to enable parameter estimation
    - Maybe use a more sensible name, such as init_da_ensembles
    - I am not convinced this is the best way to generate the initial analysis.
    - Might be best to change to a dictionary time axis. Forecast and analysis at the same time have a difference of 1 in the time integer.
    """
    nx = m_const["nx"]

    #initial conditions
    h_initial = gaussian_initial_condition(m_const["x_grid"],m_const["h_init_std"])
    

    #Generate truth
    if da_const["fixed_seed"]==True: np.random.seed(j)
    
    u_truth    = np.random.normal(m_const["u_ref"],m_const["u_std_truth"])
    dhdt_truth = np.random.normal(m_const["dhdt_ref"],m_const["dhdt_std_truth"])
    truth      = linear_advection_model(h_initial,u_truth,dhdt_truth,m_const["dx"],da_const["dt"],da_const["nt"])
    
    an = np.zeros((nx,da_const["nens"]))
    
    #First rows full of nans :)
    bg = np.zeros((nx,da_const["nens"]))
    bg[:] = np.nan
    obs = np.zeros(nx)
    obs[:] = np.nan
    
    for i in range(da_const["nens"]):
        if da_const["fixed_seed"]==True: np.random.seed(i+j*da_const["nens"])
        if da_const["init_noise"]>0:
            h_ens  = np.random.normal(h_initial,da_const["init_noise"])
        else: 
            h_ens  = h_initial
        if da_const["init_spread"]>0:
            #initial spread generated by moving waves forward and backward and up and down using 
            #da_const["init_spread_h"] and da_const["init_spread_x"]
            x_displace = np.random.normal(0.,da_const["init_spread_x"])
            h_displace = np.random.normal(0.,da_const["init_spread_h"])
            u_tmp = x_displace/da_const["dt"]
            h_ens    = semi_lagrangian_advection(h_ens,m_const["dx"],u_tmp,da_const["dt"])
            h_ens    = h_ens+h_displace
            
        u_ens      = np.random.normal(m_const["u_ref"],da_const["u_std_ens"])
        dhdt_ens   = np.random.normal(m_const["dhdt_ref"],da_const["dhdt_std_ens"])
        an[:,i]    = linear_advection_model(h_ens,u_ens,dhdt_ens,m_const["dx"],da_const["dt"],da_const["nt"])

    states[j]={}
    states[j]['bg']=[bg]
    states[j]['an']=[an]
    states[j]['truth']=[truth]
    states[j]['obs']=[obs]
    return an, truth, states
    
    

def generate_obs(truth,states,m_const,da_const,j,t):
    """
    Generates the truth and observations for the next assimilation step.
    If a fixed_seed is used the model "errors" of the truth remain the same, but noise of the obs is still reseeded differently each time.
    
    To avoid diffusion over time in the truth I take advantage of the linearity of the model to compute it directly from the initial conditions. 
    This only works if the truth is not perturbed each timestep though.  
    """
    #Generate new truth constants and integrate in time
    if da_const["fixed_seed"]==True: np.random.seed(j)
    u_truth = np.random.normal(m_const["u_ref"],m_const["u_std_truth"])
    dhdt_truth = np.random.normal(m_const["dhdt_ref"],m_const["dhdt_std_truth"])
    if da_const["fixed_seed"]==True:
        truth = linear_advection_model(states['truth'][0],u_truth,dhdt_truth,m_const["dx"],da_const["dt"]*t,da_const["nt"])
    else:
        truth = linear_advection_model(truth,u_truth,dhdt_truth,m_const["dx"],da_const["dt"],da_const["nt"])
    
    #make obs by adding some noise; the seed varies over time even when fixed_seed is used, so that the obs errors are not always the same at each measurement location
    if da_const["fixed_seed"]==True: np.random.seed((j+1)*t)
    obs = truth + np.random.normal(0,da_const["True_std_obs"],m_const["nx"])
    states["truth"] = states["truth"]+[truth]
    states["obs"] = states["obs"]+[obs]
    return truth, obs, states





def predict(analysis,states,da_const,m_const,j):
    """
    Runs the analysis ensemble forward in time using the model to generate the next forecast/background ensemble for the next assimilation step.

    Outlook:
    - Add different models than the linear_advection_model? 
    """
    bg = np.zeros((m_const["nx"],da_const["nens"]))
    for i in range(da_const["nens"]):
        if da_const["fixed_seed"]==True: np.random.seed(i+j*da_const["nens"])
        u_ens      = np.random.normal(m_const["u_ref"],da_const["u_std_ens"])
        dhdt_ens   = np.random.normal(m_const["dhdt_ref"],da_const["dhdt_std_ens"])
        bg[:,i]    = linear_advection_model(analysis[:,i],u_ens,dhdt_ens,m_const["dx"],da_const["dt"],da_const["nt"])
        
    states["bg"] = states["bg"]+[bg]
    return bg, states




def update(background,obs,R,H,C,states,da_const,m_const):
    """
    Computes the analysis by individually changing each ensemble member of the forecast/background through the shared ensemble Kalman Gain and the observations.
    Now also uses the localization matrix C.
    """
    # Compute the background error covariance matrix
    P = np.cov(background)*C

    # define the relative weights of observations and background in terms of the Kalman gain matrix
    K = KalmanGain(P, R, H)

    # Compute the analysis for each ensemble members
    an = np.zeros((m_const["nx"],da_const["nens"]))
    for i in range(da_const["nens"]):
        an[:,i] = get_analysis(background[:,i],obs,K,H,da_const)
    states["an"] = states["an"]+[an]
    return an, states




def KalmanGain(P,R,H):
    """
    Computes the Kalman gain matrix: K = PH^T(HPH^T+R)^(-1)
    """
    P_H = np.dot(P,H.transpose())
    S = np.dot(H,P_H)+R   
    K = np.dot(P_H,np.linalg.inv(S))
    
    return K
    
def get_analysis(bg,obs,K,H,da_const):    
    """
    Computes the analysis: an = bg + K(obs_pert - H*bg), where obs_pert are the perturbed observations mapped into observation space
    """
    obs_pert = np.dot(H,obs+np.random.normal(0,da_const["pert_std_obs"],len(obs)))
    bg_obs = np.dot(H,bg)
    an =  bg + np.dot(K,obs_pert-bg_obs)    
    return an
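
# Sketch of how KalmanGain and get_analysis fit together for one update step
# (shapes assumed: background is nx x nens, obs has length nx, obs_loc holds the observed indices):
#
#   H = np.identity(nx)[da_const["obs_loc"], :]                                  # nobs x nx
#   R = da_const["used_std_obs"]**2 * np.identity(len(da_const["obs_loc"]))      # nobs x nobs
#   P = np.cov(background)                                                       # nx x nx sample covariance
#   K = KalmanGain(P, R, H)                                                      # nx x nobs
#   an_member = get_analysis(background[:, 0], obs, K, H, da_const)              # update of one ensemble member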

def run_linear_advection_EnKF(m_const,da_const):
    """
    The heart and soul of the whole linear advection EnKF filter.
    Steps:
    - Computes the constant matrices H and R
    - Initializes the states dictionary with the first truth and analysis ensemble
    - Loops over experiments, and within each experiment loops over the assimilation cycles:
      generate observations from the truth, run the analysis forward to get the background, update with the observations to get the new analysis
    """

    """
    constant matrices that follow from previously defined constants
    """
    H = np.identity(m_const["nx"])[da_const["obs_loc"],:]                              # Observation operator    
    R = da_const["used_std_obs"]**2*np.identity(len(da_const["obs_loc"]))           # Observation error covariance matrix




    """
    create dictionary to store background, analysis, truth and observations for verification 
    """
    states = {}

    #Construct localization matrix C if loc!=None
    if da_const["loc"]:
        C = loc_matrix(da_const,m_const)
    else:
        C=np.ones([m_const["nx"],m_const["nx"]])

    """
    Loop over experiments
    """
    for j in range(da_const["nexp"]):

        """
        Initialize a random truth and a random analysis ensemble stored in "states"
        """
        analysis, truth, states = create_states_dict(j,states,m_const,da_const)
        

        """
        START DATA ASSIMILATION CYCLE
        """
        for t in range(0,da_const["ncyc"]):
            """
            Run the truth forward in time until the next assimilation step and create observations. 
            Note that this step is usually provided by nature and measurements obtained from weather stations, 
            wind profilers, radiosondes, aircraft reports, radars, satellites, etc.
            """
            truth, obs, states[j] = generate_obs(truth,states[j],m_const,da_const,j,t)

            """
            Predict
            """
            # Predict the state at the next assimilation step by running the analysis forward in time
            background, states[j] = predict(analysis,states[j],da_const,m_const,j)

            """
            Update
            """
            # Combine background and observations to get the analysis
            analysis, states[j] = update(background,obs,R,H,C,states[j],da_const,m_const)
        """
        END DATA ASSIMILATION CYCLE
        """
    """
    end loop over experiments
    """
    return states
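
# End-to-end usage sketch. How m_const is constructed lives in model_functions; the
# constructor name below is assumed, adjust it to whatever that module provides:
#
#   m_const  = set_model_constants()                    # hypothetical call from model_functions
#   da_const = set_da_constants(nexp=5, ncyc=30)
#   states   = run_linear_advection_EnKF(m_const, da_const)
#   rmse, spread = get_spread_and_rmse(states, da_const, m_const)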


def get_spread_and_rmse(states,da_const,m_const):
    """
    Computes RMSE over space and time respectively, averaged over all experiments, for the background ensemble mean and analysis ensemble mean.
    Structure of the output is as follows: rmse[dim][quan] where dim = 'space' or 'time' and quan = 'bg' or 'an'; 'bf' will also be calculated when available.

    Returns spread and rmse over time and space averaged over all experiments, so increasing the number of experiments smooths out the time errors.
    """
    
    
    n    = m_const["nx"]
    ncyc = da_const["ncyc"]+1
    nexp = da_const["nexp"]
    nens = da_const["nens"]
    time = np.arange(0,ncyc)
    
    quantities = ['an','bg']
    if 'bf' in states[0].keys(): quantities.append('bf')
    tr_matrix = np.zeros([nexp,ncyc,n])
    for j in range(nexp):
        tr_matrix[j,:,:] = np.array(states[j]["truth"][:])
    rmse = {}
    spread = {}
    

    for dim in ['time','space']: # Choose dim to average results
        rmse[dim]={}
        spread[dim]={}
        for quan in quantities: # Loop over fields
            
            
            rmse[dim][quan]={}
            spread[dim][quan] = {}
            rmse[dim][quan]['mean']=0.
            spread[dim][quan]['mean']=0.

            for j in range(nexp): # Loop over experiments
                ens_matrix = np.array(states[j][quan][:]) 
                if dim == "space": # Average over time to get one value per gridpoint
                    rmse[dim][quan][j] = np.zeros(n)
                    spread[dim][quan][j] = np.zeros(n)
                    for x in range(n): # Loop over gridpoints
                        rmse[dim][quan][j][x]   = np.nanmean(L2norm(ens_matrix[:,x,:].T-tr_matrix[j,:,x]))
                        spread[dim][quan][j][x] = np.nanmean(np.std(ens_matrix[:,x,:], axis=1, ddof=1))
                        
                if dim == "time": # Average over space so that one value remains per timestep
                    rmse[dim][quan][j]   = np.zeros(ncyc)
                    spread[dim][quan][j] = np.zeros(ncyc)
                    for t in time: # Loop DA cycles
                        rmse[dim][quan][j][t]   = np.mean(L2norm(ens_matrix[t,:,:].T-tr_matrix[j,t,:]))
                        spread[dim][quan][j][t] = np.mean(np.std(ens_matrix[t,:,:], axis=1, ddof=1))
                rmse[dim][quan]['mean']   = rmse[dim][quan]['mean']   + rmse[dim][quan][j]/float(nexp)
                spread[dim][quan]['mean'] = spread[dim][quan]['mean'] + spread[dim][quan][j] / float(nexp)

            rmse[dim][quan] = rmse[dim][quan]['mean']
            spread[dim][quan] = spread[dim][quan]['mean']
    return rmse, spread
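
# Example of accessing the output (a sketch): after
#   rmse, spread = get_spread_and_rmse(states, da_const, m_const)
# rmse['time']['an'] holds one value per assimilation cycle and rmse['space']['bg'] holds
# one value per gridpoint; spread is structured the same way.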


def L2norm(error):
    """
    Computes the RMS (square root of the mean of squares) over the first dimension of the input array
    """
    return np.sqrt(np.average(error**2,axis=0))



def predict_blind(background,states,da_const,m_const,j):
    """
    Runs the background ensemble forward in time using the model to predict the truth at the next assimilation step.
    This is then saved as 'bf', for blind forecast.
    """
    bf = np.zeros((m_const["nx"],da_const["nens"]))
    for i in range(da_const["nens"]):
        if da_const["fixed_seed"]==True: np.random.seed(i+j*da_const["nens"])
        u_ens      = np.random.normal(m_const["u_ref"],da_const["u_std_ens"])
        dhdt_ens   = np.random.normal(m_const["dhdt_ref"],da_const["dhdt_std_ens"])
        bf[:,i]    = linear_advection_model(background[:,i],u_ens,dhdt_ens,m_const["dx"],da_const["dt"],da_const["nt"])
        
    states["bf"] = states["bf"]+[bf]
    return bf, states



def add_blind_forecast(states,m_const,da_const):
    """
    Takes a states dictionary created by run_linear_advection_EnKF and adds a 'bf' field which stands for blind forecast, also known as free forecast. 
    This blind forecast is generated directly from the background ensemble, without assimilating first.
    Only makes sense if the randomized errors are carefully generated from the same seed, otherwise the model errors used to make the background ensemble will change when creating the blind forecast. 
    
    This feels dangerously close to philipp making a custom solution for everything himself again. 
    """
    """
    Loop over experiments
    """
    for j in range(da_const["nexp"]):
        
        #initializing new dictionary field, this is where things can easily go wrong, so pay attention
        states[j]['bf'] = []
        #add two first initial nan states, which we take from the initial forecast field
        states[j]['bf'].append(states[j]['bg'][0])
        states[j]['bf'].append(states[j]['bg'][0])

        """
        START blind forecasting
        """
        for t in range(2,da_const["ncyc"]+1):
            """
            Predict
            """
            #print('blind forecast time',t)
            # Predict the state at the next assimilation step by running the analysis forward in time
            blind_forecast, states[j] = predict_blind(states[j]['bg'][t-1],states[j],da_const,m_const,j)

        """
        END blind forecasting
        """
    """
    end loop over experiments
    """
    return states
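
# Usage sketch: the blind forecast only lines up with the background model errors when
# fixed_seed=True (see the docstring above), e.g.
#
#   states = run_linear_advection_EnKF(m_const, da_const)
#   states = add_blind_forecast(states, m_const, da_const)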

def sum_mid_tri(x):
    """
    Default response function: a simple sum over the middle third of the state vector 
    """
    nx = len(x)
    idx_str = int(nx/3.)
    idx_end = int(2*nx/3.)
    return np.sum(x[idx_str:idx_end])


def add_response(states,func_J=sum_mid_tri):
    """
    Goes through the analysis, background, truth, and blind forecast if it exists and applies the given response function to the ensemble members. 
    Isn't pretty, could be made more elegant by looping over the given variables instead of hard coding them.
    Also handles poorly that there are differing numbers of bg, an, and bf fields. 
    Also, so far it doesn't allow computing the truth as a response function, as only the individual ensemble members are passed to func_J with no other information.
    
    ToDo: 
    - Allow using more diverse func_Js
    - tidy up, is a mess
    """
    nexp = len(states)
    ncyc = len(states[0]["bg"])
    nx   = states[0]["bg"][0][:,0].shape[0]
    nens = states[0]["bg"][0][0,:].shape[0]

    #Poor way of checking if a blind forecast also exists
    bf_flag = 0
    if 'bf' in states[0]:  bf_flag = 1

    for e in range(nexp):
        states[e]["response"]={}

        #Initialize dictionary fields, not the cleanest solution
        c = 0

        an_response = np.zeros(nens)
        bg_response = np.zeros(nens)
        if bf_flag:        bf_response = np.zeros(nens)
        for n in range(nens):
            an_response[n] = func_J(states[e]["an"][0][:,n])
            bg_response[n] = func_J(states[e]["bg"][0][:,n])
            if bf_flag:  bf_response[n] = func_J(states[e]["bf"][0][:,n])


        states[e]["response"]["an"]   =[an_response]
        states[e]["response"]["bg"]   =[bg_response]
        states[e]["response"]["truth"]=[func_J(states[e]["truth"][0])]

        if bf_flag : states[e]["response"]["bf"]   =[bf_response]


        for c in range(1,ncyc):
            an_response = np.zeros(nens)
            bg_response = np.zeros(nens)
            bf_response = np.zeros(nens)
            for n in range(nens):
                an_response[n] = func_J(states[e]["an"][c][:,n])
                bg_response[n] = func_J(states[e]["bg"][c][:,n])
                if bf_flag:  bf_response[n] = func_J(states[e]["bf"][c][:,n])

            states[e]["response"]["bg"]   =states[e]["response"]["bg"]+[bg_response]
            if bf_flag: states[e]["response"]["bf"]   =states[e]["response"]["bf"]+[bf_response]

            states[e]["response"]["an"]   =states[e]["response"]["an"]+[an_response]
            states[e]["response"]["truth"]=states[e]["response"]["truth"]+[func_J(states[e]["truth"][c])]



    return states
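
# Any function mapping a state vector to a scalar can be passed as func_J. A hypothetical
# alternative to the default sum_mid_tri, just to illustrate the interface:
#
#   def max_mid_tri(x):
#       nx = len(x)
#       return np.max(x[int(nx/3.):int(2*nx/3.)])
#
#   states = add_response(states, func_J=max_mid_tri)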

def var_reduction_estimate(states,m_const,da_const,j=0,obs = [],ncyc=0,n_start = 10):
    """
    Just loops over a full experiment.
    Calculates the estimated variance reduction of the response function for all observations together, as well as for each observation individually.
    If a specific obs list is supplied, it is used instead of all observations.

    Following the naming of Hakim 2020.
    
    """
    
    #For now I am not worried about being efficient
    nobs = len(da_const["obs_loc"])
    if obs == []:
        obs = np.arange(nobs)
    

    nens = da_const["nens"]
    if ncyc ==0: ncyc = len(states[j]['response']['bf'])
    

    var_reduction_total      = []
    var_reduction_individual = []
    real_reduction           = []
    
    for t in range(n_start,ncyc-1):

        R = da_const["used_std_obs"]**2*np.identity(len(obs))           # Observation error covariance matrix
        H = np.identity(m_const["nx"])[da_const["obs_loc"][obs],:]

        x = states[j]['bg'][t][:,:]
        dx = x.T-np.mean(x,axis=1)
        dx = dx.T
        A = np.dot(dx,dx.T)/(dx.shape[1]-1)

        J= states[j]["response"]['bf'][t+1][:]
        dJ = J-np.mean(J)

        HAHt = np.dot(H,np.dot(A,H.T))
        HAHtRinv= np.linalg.inv(HAHt+R)
        dJHdxt = np.dot(dJ,np.dot(H,dx).T)/(nens-1)

        vr_total = -np.dot(dJHdxt,np.dot(HAHtRinv,dJHdxt))
        var_reduction_total.append(vr_total)

        #And now the loop for the individual reductions
        vr_individual = np.zeros(nobs)
        for o in range(nobs):
            R = da_const["used_std_obs"]**2*np.identity(1)           # Observation error covariance matrix
            H = np.identity(m_const["nx"])[da_const["obs_loc"][o],:]
            HAHt = np.dot(H,np.dot(A,H.T))
            HAHtRinv= np.linalg.inv(HAHt+R)
            dJHdxt = np.dot(dJ,np.dot(H,dx).T)/(nens-1)
            vr_individual[o] = - np.dot(dJHdxt,np.dot(HAHtRinv,dJHdxt))        
        var_reduction_individual.append(vr_individual)

        J_f=states[j]["response"]['bg'][t+1][:]
        dJ_f = J_f-np.mean(J_f)

        real_reduction.append(np.var(dJ_f) - np.var(dJ))
    return var_reduction_total,var_reduction_individual,real_reduction

def var_reduction_estimate_iterative(states,m_const,da_const,j=0,dJ_update_flag=0):

    """
    Iteratively goes through all observations, estimates the impact on the response function, updates dx and dJ, and continues on.
    Currently starts with the location with the largest estimated impact, and then goes to the next highest, and so on. But the ordering should be unimportant.
    It returns the total variance reduction estimate for each individual observation.

    Following the naming of Hakim 2020.

    It was a bit unclear in the paper, but we decided that we would update dx and dJ simultaneously by bringing dJ into the state vector. This was also confirmed by Hakim via email.
    This results in reduction estimates which are pretty much identical down to machine precision with those reached by including all observations at once.

    Is coded very inefficiently; it does not take advantage of only looking at the problem in observation space ("reduced model space").
    Currently not tested for inputting multiple observations at once.

    Output:
    Individual variance reduction for each observation

    Options, dJ_update_flag:
    0: Hakim version, updates dJ using the "appended state" approach
    1: Modified Hakim. Updates dJ, but then scales dJ to match the estimated variance reduction perfectly. Should be a very small tweak
    2: Scales dJ to match the reduced variance estimate. Should be extremely cheap and robust, but totally neglects covariance information.
    3: Doesn't update dJ at all, only used to check against what Tanya did 


    """

    #For now I am not worried about being efficient
    nobs = len(da_const["obs_loc"])
    #if obs == []:
    obs = np.arange(nobs)

    j=0

    nens = da_const["nens"]
    ncyc = len(states[j]['response']['bf'])


    var_reduction_total      = []
    var_reduction_individual = []
    observation_ranking      = []
    real_reduction           = []

    #R is always the same
    R = da_const["used_std_obs"]**2*np.identity(1)           # Observation error covariance matrix

    for t in range(1,ncyc-1):
    #for t in [10]:


        x = states[j]['bg'][t][:,:]
        dx = x.T-np.mean(x,axis=1)
        dx = dx.T
        A = np.dot(dx,dx.T)/(dx.shape[1]-1)

        J= states[j]["response"]['bf'][t+1][:]
        dJ = J-np.mean(J)


        obs_ordered = []
        obs_remain  = list(obs)
        vr_individual = np.zeros(nobs)
        for o in range(nobs):  #loop over the number of observations
        #for o in range(1):  #loop over the number of observations

            vr_ind_tmp = np.zeros(nobs)
            for oo in obs_remain: #loop over all observation locations which have not been used yet

                H = np.identity(m_const["nx"])[da_const["obs_loc"][oo],:]
                HAHt = np.dot(H,np.dot(A,H.T))
                HAHtRinv= np.linalg.inv(HAHt+R)
                dJHdxt = np.dot(dJ,np.dot(H,dx).T)/(nens-1)
                vr_ind_tmp[oo] = - np.dot(dJHdxt,np.dot(HAHtRinv,dJHdxt))
            #print(vr_ind_tmp)

            ind_min = np.where(vr_ind_tmp==np.min(vr_ind_tmp))[0][0]
            #print(ind_min)
            vr_individual[ind_min] = vr_ind_tmp[ind_min]
            obs_remain.remove(ind_min)

            H = np.identity(m_const["nx"])[da_const["obs_loc"][ind_min],:]

            #Now we add in Tanya's reduction factor
            E = np.matmul(A,H.T)
            E = np.matmul(H,E)
            E = E + R
            alpha = 1./(1.+np.sqrt(R/E))



            if dJ_update_flag==0 or dJ_update_flag==1:

                #Modified Tanya's code because I had an error.
                #Turned out it was a matrix multiplication issue I solved by switching the final matmul to an np.outer. It has the right dimension, but I fear my version will only work for single-point measurements.
                #Now we append dJ to dx so it is included in the update
                dxJ = np.vstack([dx,dJ])
                AxJ = np.dot(dxJ,dxJ.T)/(dxJ.shape[1]-1)
                HxJ = np.hstack([H,np.array(0)])
                HAHt = np.dot(HxJ,np.dot(AxJ,HxJ.T))
                HAHtRinv= np.linalg.inv(HAHt+R)


                K = np.dot(AxJ,HxJ.T)*HAHtRinv


                # Update state vector
                HdxJ = np.matmul(HxJ,dxJ)
                dxJ = dxJ - alpha*np.outer(K,HdxJ)

                dx = dxJ[:-1,:]
                old_var_dJ = np.var(dJ)
                dJ = dxJ[-1,:]
                new_var_dJ = np.var(dJ)

                if dJ_update_flag==1:
                    var_scaling = (old_var_dJ+ vr_ind_tmp[ind_min])/new_var_dJ
                    dJ=np.sqrt(var_scaling)*dJ
                    #print(old_var_dJ+vr_ind_tmp[ind_min],new_var_dJ,np.var(dJ))

            if dJ_update_flag==2 or dJ_update_flag==3:
                
                if dJ_update_flag==2:
                    var_scaling=(np.var(dJ)+vr_ind_tmp[ind_min])/np.var(dJ)
                    #New dJ
                    dJ=np.sqrt(var_scaling)*dJ

                #Update dx
                #Modified Tanya's code, same matmul -> np.outer fix as above; likely only works for single-point measurements.
                HAHt = np.dot(H,np.dot(A,H.T))
                HAHtRinv= np.linalg.inv(HAHt+R)
                K = np.dot(A,H.T)*HAHtRinv

                # Update state vector
                Hdx = np.matmul(H,dx)
                dx = dx - alpha*np.outer(K,Hdx)

            #Recalculating A makes things worse if you don't rescale dJ as well
            A = np.dot(dx,dx.T)/(dx.shape[1]-1)

        var_reduction_individual.append(vr_individual)
    return var_reduction_individual
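
# Usage sketch for the two variance-reduction estimators. Both rely on the 'bf' and
# 'response' fields, so add_blind_forecast and add_response must have been applied first:
#
#   states = add_response(add_blind_forecast(states, m_const, da_const))
#   vr_tot, vr_ind, real_red = var_reduction_estimate(states, m_const, da_const)
#   vr_ind_iter = var_reduction_estimate_iterative(states, m_const, da_const, dJ_update_flag=0)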



def loc_matrix(da_const,m_const):
    """
    Creates the localization matrix, using either a gaussian or gaspari cohn function 
    """
    C = np.zeros([m_const['nx'],m_const['nx']])
    # I cheat a bit by mirroring the functions to accommodate the periodic boundary conditions, but this should only lead to a maximum dx/2 error. 
    if da_const['loc_type']=='gaussian': 
        C[:,0] = gaussian_initial_condition(m_const["x_grid"],da_const["loc_length"])
    if da_const['loc_type']=='gaspari_cohn': 
        C[:,0] = gaspari_cohn(m_const["x_grid"],da_const["loc_length"])

    for i in range(1,m_const['nx']):
        C[:,i] = np.hstack([C[-1,i-1],C[:-1,i-1]])
    return C
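
# How the localization matrix is used (this mirrors what update() does above):
#
#   C = loc_matrix(da_const, m_const)     # nx x nx taper matrix
#   P = np.cov(background) * C            # elementwise (Schur) product with the sample covariance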

def gaspari_cohn(x,loc_length):
    """Gaspari-Cohn function."""
    
    ra = x/loc_length
    gp = np.zeros_like(ra)
    i=np.where(ra<=1.)[0]
    gp[i]=-0.25*ra[i]**5+0.5*ra[i]**4+0.625*ra[i]**3-5./3.*ra[i]**2+1.
    i=np.where((ra>1.)*(ra<=2.))[0]
    gp[i]=1./12.*ra[i]**5-0.5*ra[i]**4+0.625*ra[i]**3+5./3.*ra[i]**2-5.*ra[i]+4.-2./3./ra[i]
        
    #Now we mirror things to zero for the periodic boundary domain
    half_idx = int(ra.shape[0]/2)
    gp[-half_idx:] = gp[half_idx-1::-1]
    return gp