From 0e88b899d97d899e999787c20e23f5b02da841bb Mon Sep 17 00:00:00 2001
From: lkugler <lukas.kugler@gmail.com>
Date: Fri, 19 May 2023 18:07:32 +0200
Subject: [PATCH] evaluate only if necessary

---
 dartwrf/assim_synth_obs.py | 54 +++++++++++++++++++++++---------------
 1 file changed, 33 insertions(+), 21 deletions(-)

diff --git a/dartwrf/assim_synth_obs.py b/dartwrf/assim_synth_obs.py
index e3860f3..b3dc551 100755
--- a/dartwrf/assim_synth_obs.py
+++ b/dartwrf/assim_synth_obs.py
@@ -204,26 +204,32 @@ def get_parametrized_error(obscfg, osf_prior):
         NotImplementedError('sat_channel not implemented', obscfg.get("sat_channel"))
 
 
-def set_obserr_assimilate_in_obsseqout(oso, osf_prior, outfile="./obs_seq.out"):
+def set_obserr_assimilate_in_obsseqout(oso, outfile="./obs_seq.out"):
     """"Overwrite existing variance values in obs_seq.out files
     
     Args:
         oso (ObsSeq): python representation of obs_seq.out file, will be modified and written to file
-        osf_prior (ObsSeq): python representation of obs_seq.final (output of filter in evaluate-mode without posterior)
-                            contains prior values; used for parameterized errors
 
     Returns:
         None    (writes to file)
+
+    Variables:
+        osf_prior (ObsSeq): python representation of obs_seq.final (output of filter in evaluate-mode without posterior)
+                        contains prior values; used for parameterized errors
     """
+
     for obscfg in exp.observations:
-        kind_str = obscfg['kind']
-        kind = osq.obs_kind_nrs[kind_str]
+        kind_str = obscfg['kind']  # e.g. 'RADIOSONDE_TEMPERATURE'
+        kind = osq.obs_kind_nrs[kind_str]  # e.g. 263
 
-        # modify each kind separately, one after each other
+        # modify observation error of each kind sequentially
         where_oso_iskind = oso.df.kind == kind
-        where_osf_iskind = osf_prior.df.kind == kind
-
+        
         if obscfg["error_assimilate"] == False:
+            osf_prior = obsseq.ObsSeq(cluster.dart_rundir + "/obs_seq.final")  # this file will be generated by `evaluate()`
+            
+            where_osf_iskind = osf_prior.df.kind == kind
+
             assim_err = get_parametrized_error(obscfg, osf_prior.df[where_osf_iskind])
             oso.df.loc[where_oso_iskind, 'variance'] = assim_err**2
             #assert np.allclose(assim_err, oso.df['variance']**2)  # check
@@ -233,7 +239,9 @@ def set_obserr_assimilate_in_obsseqout(oso, osf_prior, outfile="./obs_seq.out"):
 
     oso.to_dart(outfile)
 
-def qc_obs(time, oso, osf_prior):
+def qc_obs(time, oso):
+    osf_prior = obsseq.ObsSeq(cluster.dart_rundir + "/obs_seq.final")  # read the obs_seq.final that `evaluate()` leaves in the run directory
+
     # obs should be superobbed already!
     for i, obscfg in enumerate(exp.observations): 
         if i > 0:
@@ -328,9 +336,6 @@ def evaluate(assim_time,
     copy(cluster.dart_rundir + "/obs_seq.final", fout)
     print(fout, "saved.")
 
-    osf = obsseq.ObsSeq(cluster.dart_rundir + "/obs_seq.final")
-    return osf
-
 
 
 def generate_obsseq_out(time):
@@ -526,6 +531,10 @@ def main(time, prior_init_time, prior_valid_time, prior_path_exp):
         None
     """
     nproc = cluster.max_nproc
+    do_QC = getattr(exp, "reject_smallFGD", False)  # True: triggers additional evaluations of prior & posterior
+
+    # for which observation type do we have a parametrized observation error?
+    error_is_parametrized = [obscfg["error_assimilate"] == False for obscfg in exp.observations]
 
     prepare_run_DART_folder()
     nml = dart_nml.write_namelist()
@@ -540,15 +549,17 @@ def main(time, prior_init_time, prior_valid_time, prior_path_exp):
     print(" 1) get observations with specified obs-error")
     oso = get_obsseq_out(time)
 
-    print(" 2.1) evaluate prior for all observations (incl rejected)")
-    osf_prior = evaluate(time, output_format="%Y-%m-%d_%H:%M_obs_seq.final-eval_prior_allobs")
+    # is any observation error parametrized?
+    if any(error_is_parametrized) or do_QC:
+        print(" (optional) evaluate prior for all observations (incl rejected)")
+        evaluate(time, output_format="%Y-%m-%d_%H:%M_obs_seq.final-eval_prior_allobs")
 
-    print(" 2.2) assign observation-errors for assimilation ")
-    set_obserr_assimilate_in_obsseqout(oso, osf_prior, outfile=cluster.dart_rundir + "/obs_seq.out")
+    print(" assign observation-errors for assimilation ")
+    set_obserr_assimilate_in_obsseqout(oso, outfile=cluster.dart_rundir + "/obs_seq.out")
 
-    if getattr(exp, "reject_smallFGD", False):
+    if do_QC:
         print(" 2.3) reject observations? ")
-        qc_obs(time, oso, osf_prior)
+        qc_obs(time, oso)
 
     if prior_inflation_type == '2':
         prepare_inflation_2(time, prior_init_time)
@@ -561,9 +572,10 @@ def main(time, prior_init_time, prior_valid_time, prior_path_exp):
     if prior_inflation_type == '2':
         archive_inflation_2(time)
 
-    print(" 4) evaluate posterior observations for all observations (incl rejected)")
-    write_list_of_inputfiles_posterior(time)
-    if getattr(exp, "reject_smallFGD", False):
+    if do_QC:
+        print(" 4) evaluate posterior observations for all observations (incl rejected)")
+        write_list_of_inputfiles_posterior(time)
+        
         copy(cluster.archivedir+'/obs_seq_out/'+time.strftime('%Y-%m-%d_%H:%M_obs_seq.out-beforeQC'), 
              cluster.dart_rundir+'/obs_seq.out')
     evaluate(time, output_format="%Y-%m-%d_%H:%M_obs_seq.final-eval_posterior_allobs")
-- 
GitLab