Skip to content
Snippets Groups Projects
Commit 61606465 authored by Philipp Griewank's avatar Philipp Griewank
Browse files

Added regularized inversion.

This dramatically helps when localization is applied, but comes with a not well constrained tuning parameter. Will definitely be used as the default method moving forward.
parent 0fbead2e
No related branches found
No related tags found
No related merge requests found
...@@ -957,3 +957,52 @@ def LETKF_analysis(bg,obs,m_const,da_const): ...@@ -957,3 +957,52 @@ def LETKF_analysis(bg,obs,m_const,da_const):
return x_a,x_ol_a return x_a,x_ol_a
def L2_regularized_inversion(A, b, alpha_init=0.1,alpha=None,mismatch_threshold=0.05):
"""Instead of solving for Ax=b, which isn't possible if A is not invertible, the regularization instead minimizes ||Ax-b||^2 + || alpha x ||^2.
The solution is unique and well defined: x = (AA.T + alpha*alpha I )^-1 A.T b.
While this works fine with np.linalg.inv, I use .solve instead because it is roughly a factor 3 quicker.
I tried to find a easy way to link the starting alpha to the model space or ensemble size, but after not finding anything easy I just start with a prescribed value
which is 0.1 by default. It is checked that the mismatch between sum(Ax-b)/sum(b) does not exceed the mismatch_threshold, and if it does alpha is reduced by a factor of 2 until it does.
The mismatch_threshold is given in percent, and is used to check if ||(Ax-b)||/||b|| falls below the theshold.
alternative would be to try to use the things presented by Shu-Chih Yang's talk at the ISDA-online
kappa_req = 10000.
n = numbers of non-diagonal componentts of correlation matrix C
r = average of non-diagonal componentts of correlation matrix C
lamda_max = 1+(n_cols-1)*r
alpha = lamda_max/kappa_req
print(kappa_req,r,lamda_max,alpha)
"""
n_cols = A.shape[1]
I = np.identity(n_cols)
#n_ens = b.size
#this is just a guess, but alpha should decrease with size, is not applied if a value other than 1 is applied
if alpha == None:
#alpha= 1./np.sqrt(n_cols)
alpha= alpha_init #1#n_cols/n_ens
#x = np.linalg.inv(A.T.dot(A) + alpha**2 *I).dot(A.T).dot(b)
x = np.linalg.solve((A.T.dot(A) + alpha**2 *I),(A.T).dot(b))#,rcond=-1)
while np.sum(np.abs(A.dot(x)-b))/np.sum(np.abs(b))> mismatch_threshold:
alpha = alpha/2
x = np.linalg.solve((A.T.dot(A) + alpha**2 *I),(A.T).dot(b))#,rcond=-1)
#x = np.linalg.inv(A.T.dot(A) + alpha**2 *I).dot(A.T).dot(b)
#print('reducing regularization:',alpha,np.sum(np.abs(A.dot(x)-b))/np.sum(np.abs(b)))
else:
x = np.linalg.solve((A.T.dot(A) + alpha**2 *I),(A.T).dot(b))#,rcond=-1)
#x = np.linalg.inv(A.T.dot(A) + alpha**2 *I).dot(A.T).dot(b)
return x
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment