# Reconstructed from a whitespace-collapsed git patch against da_functions.py
# (commit 4ce7f2add89817066651bce46de4aa29594923fe, Philipp Griewank,
# Fri, 11 Aug 2023: "Optimized the LETKF function using numba. Still isn't
# fast, but isn't as slow as it used to be.").
#
# Besides adding the two functions below, the same patch also
#   * adds `from numba import jit` to the imports of da_functions.py,
#   * switches the 'LETKF' branch of run_linear_advection_KF_22 and of
#     single_step_analysis_forecast_22 from LETKF_analysis_22 to
#     LETKF_analysis_23 (the old call is left behind as a comment),
#   * replaces `n_x` by `m_const['nx']` when building `L_obs_state` in the
#     non-localized branch of square_root_Kalman_gain_observation_deviations.

import numpy as np

try:
    from numba import jit
except ImportError:  # numba only accelerates the loop; the maths is plain NumPy
    def jit(func):
        """No-op stand-in for ``numba.jit`` when it is used as a bare decorator."""
        return func


def LETKF_analysis_23(bg, obs, obs_sat, m_const, da_const, sat_operator):
    """
    Same as LETKF_analysis_22, but the loop over grid points lives in a
    separate numba function (LETKF_numba_loop).

    The original author reports it to be roughly 10 times faster than the 22
    version.  Further speed-up could probably be achieved by reducing the size
    of the state arrays so that they exclude areas outside the influence of
    the localized observations.

    Follows the recipe and notation of Hunt 2007: what is normally called B is
    now P_a, `_ol` means "over line" (the ensemble mean), and P_tilde_a is B
    expressed in ensemble space.

    Parameters
    ----------
    bg : background ensemble; it is averaged over axis 1, so axis 0 is the
        state and axis 1 the ensemble members.
    obs, obs_sat : observation arrays, indexed with da_const["obs_loc"] and
        da_const["obs_loc_sat"] respectively.
    m_const : dict, read for 'x_grid', 'nx' and 'dx' (localized branch only).
    da_const : dict, read for 'obs_loc', 'obs_loc_sat', 'nens', 'R', 'loc'
        and 'loc_length'.
    sat_operator : forwarded unchanged to state_to_observation_space.

    Returns
    -------
    x_a : analysis ensemble, same layout as `bg`.
    x_ol_a : analysis ensemble mean.

    If both observation location lists are empty there is nothing to
    assimilate, and a copy of the background together with its mean is
    returned (the original code failed on an unbound `y_obs` in that case).
    """
    n_obs_h = len(da_const["obs_loc"])
    n_obs_sat = len(da_const["obs_loc_sat"])

    x_b = bg
    x_ol_b = np.mean(x_b, axis=1)

    if n_obs_h + n_obs_sat == 0:
        # Without observations the analysis equals the background.
        return np.array(x_b), x_ol_b

    obs_loc = np.hstack([da_const["obs_loc"], da_const["obs_loc_sat"]]).astype(int)

    L = da_const['nens']
    # Ensemble perturbations around the mean, same layout as x_b.
    X_b = x_b - x_ol_b[:, None]

    R = da_const['R']

    # Ports the background ensemble to perturbations (Y_b) and mean (y_ol_b)
    # in observation space.
    Y_b, y_ol_b = state_to_observation_space(bg, m_const, da_const, sat_operator)

    # Build the observation vector from whichever observation types are present.
    obs_parts = []
    if n_obs_h > 0:
        obs_parts.append(obs[da_const["obs_loc"]])
    if n_obs_sat > 0:
        obs_parts.append(obs_sat[da_const["obs_loc_sat"]])
    y_obs = obs_parts[0] if len(obs_parts) == 1 else np.hstack(obs_parts)

    delta_y = y_obs - y_ol_b

    if da_const['loc'] == False:  # noqa: E712 - keep the original truth test
        # Local import: only this branch needs the matrix square root.
        from scipy.linalg import sqrtm

        R_inv = np.linalg.inv(R)  # computed once instead of twice

        # Covariance matrix in ensemble space.
        YRY = np.dot(Y_b.T, np.dot(R_inv, Y_b))
        P_tilde_a = np.linalg.inv((L - 1) * np.identity(L) + YRY)

        # Ensemble-mean analysis via the weighting vector w_ol_a.
        w_ol_a = np.dot(P_tilde_a, np.dot(Y_b.T, np.dot(R_inv, delta_y)))
        x_ol_a = x_ol_b + np.dot(X_b, w_ol_a)

        # Ensemble via the square root of the error covariance matrix, with
        # the analysis mean added back to the deviations.
        W_a = np.real(sqrtm((L - 1) * P_tilde_a))
        w_a = W_a + w_ol_a

        x_a = np.dot(X_b, w_a) + x_ol_a[:, None]
    else:
        x_a, x_ol_a = LETKF_numba_loop(
            m_const['x_grid'], da_const['loc_length'], obs_loc, x_ol_b, x_b,
            R, Y_b, L, delta_y, X_b, m_const['nx'], m_const['dx'])

    return x_a, x_ol_a


@jit
def LETKF_numba_loop(x_grid, loc_length, obs_loc, x_ol_b, x_b, R, Y_b, L, delta_y, X_b, N, dx):
    """
    Localized LETKF analysis, computed independently for each grid point.

    For every grid point the inverse observation error covariance is
    multiplied element-wise by diag(Gaspari-Cohn weights), where the weights
    depend on the periodic distance between the grid point and each
    observation.  Grid points whose weights are all zero keep their
    background values.

    Parameters
    ----------
    x_grid : 1-D grid coordinates; the domain is treated as periodic with
        length N*dx.
    loc_length : localization length; weights vanish beyond 2*loc_length.
    obs_loc : integer grid indices of the observations.
    x_ol_b, x_b : background mean (N) and background ensemble (N, L).
    R : observation error covariance matrix.
    Y_b : background perturbations in observation space (n_obs, L).
    L : ensemble size.
    delta_y : observations minus background mean in observation space.
    X_b : background perturbations in state space (N, L).
    N, dx : number of grid points and grid spacing.

    Returns
    -------
    x_a_loc : analysis ensemble (N, L).
    x_ol_a_loc : analysis mean (N).
    """
    domain_length = dx * N
    x_a_loc = np.zeros((N, L))
    x_ol_a_loc = np.zeros(N)

    # Loop invariants: the original inverted R anew for every grid point.
    R_inv_full = np.linalg.inv(R)
    prior_precision = (L - 1) * np.identity(L)
    x_obs = x_grid[obs_loc]

    for g in range(N):
        # Shortest distance on the periodic domain.  For a grid starting at 0
        # this equals the original minimum over the grid and its two shifted
        # images; unlike the mirrored-image construction it does not depend
        # on where the grid starts.
        dist_reg = np.abs(x_obs - x_grid[g])
        dist = np.minimum(dist_reg, domain_length - dist_reg)

        # Gaspari-Cohn function, with no mirroring.
        ra = np.abs(dist) / loc_length
        gc_loc = np.zeros_like(ra)
        i = np.where(ra <= 1.)[0]
        gc_loc[i] = -0.25*ra[i]**5 + 0.5*ra[i]**4 + 0.625*ra[i]**3 - 5./3.*ra[i]**2 + 1.
        i = np.where((ra > 1.) & (ra <= 2.))[0]
        gc_loc[i] = (1./12.*ra[i]**5 - 0.5*ra[i]**4 + 0.625*ra[i]**3
                     + 5./3.*ra[i]**2 - 5.*ra[i] + 4. - 2./3./ra[i])

        if np.max(gc_loc) == 0.:
            # No observation within twice the localization length:
            # nothing to compute, keep the background.
            x_ol_a_loc[g] = x_ol_b[g]
            x_a_loc[g, :] = x_b[g, :]
        else:
            # Apply the localization weights to the inverse of R.
            R_inv = R_inv_full * np.diag(gc_loc)

            # Covariance matrix in ensemble space.
            YRY = np.dot(Y_b.T, np.dot(R_inv, Y_b))
            P_tilde_a = np.linalg.inv(prior_precision + YRY)

            # Ensemble-mean analysis via the weighting vector w_ol_a.
            w_ol_a = np.dot(P_tilde_a, np.dot(Y_b.T, np.dot(R_inv, delta_y)))
            x_ol_a = x_ol_b[g] + np.dot(X_b[g], w_ol_a)

            # Symmetric square root through an eigendecomposition (sqrtm is
            # not available under jit).  eigh returns orthonormal
            # eigenvectors, so their inverse is the transpose; eigenvalues
            # are clipped at zero so round-off cannot produce NaNs.
            evalues, evectors = np.linalg.eigh((L - 1.) * P_tilde_a)
            evalues = np.maximum(evalues, 0.)
            W_a = (evectors * np.sqrt(evalues)) @ evectors.T

            # NOTE(review): broadcasting adds w_ol_a[j] to every entry of
            # column j.  If each row of X_b sums to zero (true when X_b are
            # deviations from the ensemble mean) this term cancels in the
            # product below, and the mean increment enters only via x_ol_a.
            w_a = W_a + w_ol_a
            x_a_loc[g, :] = np.dot(X_b[g, :], w_a) + x_ol_a
            x_ol_a_loc[g] = x_ol_a

    return x_a_loc, x_ol_a_loc