def L2_regularized_inversion(A, b, alpha_init=0.1, alpha=None, mismatch_threshold=0.05):
    """Solve ``A x = b`` in the L2 (Tikhonov) regularized least-squares sense.

    Instead of solving ``A x = b`` directly, which is not possible when ``A``
    is not invertible, this minimizes ``||A x - b||^2 + ||alpha x||^2``.
    The minimizer is unique and well defined for ``alpha != 0``:

        x = (A.T A + alpha^2 I)^-1 A.T b

    The system is solved with ``np.linalg.solve`` rather than by forming the
    inverse with ``np.linalg.inv``; the original author measured ``solve`` to
    be roughly a factor of 3 quicker.

    Choice of ``alpha``:

    * If ``alpha`` is given, it is used as is and a single solve is done.
    * If ``alpha`` is None, the search starts at ``alpha_init`` (no simple
      link between a good starting value and the model-space or ensemble
      size was found, hence the prescribed default of 0.1). The relative
      mismatch ``sum(|A x - b|) / sum(|b|)`` is then evaluated, and ``alpha``
      is halved until that mismatch no longer exceeds ``mismatch_threshold``.

    If the threshold cannot be reached at all (for example because ``b`` is
    not in the range of ``A``), halving stops once ``alpha^2 I`` no longer
    changes ``A.T A`` in floating point, a ``RuntimeWarning`` is issued and
    the last solution is returned. Previously this situation looped forever.

    Not implemented, noted by the original author as a possible alternative
    for choosing ``alpha`` (from Shu-Chih Yang's talk at ISDA-online)::

        kappa_req = 10000.
        n         = number of off-diagonal components of correlation matrix C
        r         = average of the off-diagonal components of C
        lamda_max = 1 + (n_cols - 1) * r
        alpha     = lamda_max / kappa_req

    Parameters
    ----------
    A : array with at least ``.shape[1]``, ``.T`` and ``.dot``
        System matrix; ``A.shape[1]`` is the number of unknowns.
    b : array
        Right-hand side, compatible with ``A.dot(x)``.
    alpha_init : float, optional
        Starting regularization strength when ``alpha`` is None.
    alpha : float or None, optional
        Fixed regularization strength. None enables the halving search.
    mismatch_threshold : float, optional
        Largest accepted relative mismatch, as a fraction (0.05 means 5 %).

    Returns
    -------
    x : numpy array
        The regularized solution.

    Raises
    ------
    numpy.linalg.LinAlgError
        Propagated from ``np.linalg.solve`` if the regularized matrix is
        exactly singular.
    """
    # Local import: the file's top-level import block is not visible here.
    import warnings

    n_cols = A.shape[1]
    identity = np.identity(n_cols)

    # Both products are independent of alpha, so compute them only once.
    gram = A.T.dot(A)
    rhs = A.T.dot(b)

    def _solve(reg):
        # Normal equations of the regularized least-squares problem.
        return np.linalg.solve(gram + reg**2 * identity, rhs)

    # A caller-supplied alpha is trusted: no mismatch check, single solve.
    if alpha is not None:
        return _solve(alpha)

    alpha = alpha_init
    x = _solve(alpha)

    b_norm = np.sum(np.abs(b))
    if b_norm == 0:
        # An all-zero right-hand side yields the zero solution for any alpha;
        # return it directly instead of evaluating a 0/0 relative mismatch.
        return x

    while np.sum(np.abs(A.dot(x) - b)) / b_norm > mismatch_threshold:
        if np.array_equal(gram + alpha**2 * identity, gram):
            # The regularization term is already below floating-point
            # resolution, so every further halving would solve the very same
            # system again: the threshold is unreachable.
            warnings.warn(
                "L2_regularized_inversion: mismatch_threshold could not be "
                "reached; returning the solution for a numerically "
                "negligible regularization.",
                RuntimeWarning,
            )
            break
        alpha = alpha / 2
        x = _solve(alpha)

    return x