From e052621842abf8e1df2178d9b8a9549e3741bfb6 Mon Sep 17 00:00:00 2001
From: Lukas Kugler <lukas.kugler@univie.ac.at>
Date: Tue, 6 May 2025 01:45:28 +0200
Subject: [PATCH] small changes for single-script mode, error model

---
 config/jet.py                |   7 +-
 dartwrf/__init__.py          |   2 +-
 dartwrf/assimilate.py        | 135 +++++++++++++++++++----------------
 dartwrf/obs/error_models.py  |  23 ++++--
 dartwrf/prepare_wrfrundir.py |   6 +-
 dartwrf/print_config.py      |   7 +-
 dartwrf/utils.py             |  37 +++++++++-
 dartwrf/workflows.py         |  53 +++++++-------
 multiple_exps.py             |  78 ++++++++++++--------
 templates/namelist.input     |   2 +-
 test-singlenode.py           | 135 +++++++++++++++++++++++++++++++++++
 timer.py                     |  38 ++++++++++
 12 files changed, 392 insertions(+), 131 deletions(-)
 mode change 100644 => 100755 dartwrf/print_config.py
 create mode 100644 test-singlenode.py
 create mode 100644 timer.py

diff --git a/config/jet.py b/config/jet.py
index 5b30a22..d39f136 100755
--- a/config/jet.py
+++ b/config/jet.py
@@ -33,9 +33,10 @@ cluster_defaults = dict(
     WRF_exe_template = '/jetfs/home/lkugler/DART-WRF/templates/run_WRF.jet.sh',
     WRF_ideal_template = '/jetfs/home/lkugler/DART-WRF/templates/run_WRF_ideal.sh',
 
-    slurm_kwargs = {"account": "lkugler", "partition": "all", "time": "00:30:00",
-                    "ntasks": "1", "ntasks-per-core": "1", "mem": "30G",
-                    "mail-type": "FAIL", "mail-user": "lukas.kugler@univie.ac.at"},
+    slurm_kwargs = {"account": "lkugler", "partition": "devel", "time": "30",
+                    "nodes": "1", "ntasks": "1", "ntasks-per-core": "1", "mem": "25G",
+                    #"exclude": "jet07,jet11,jet16,jet18,jet19",
+                    "mail-type": "FAIL,TIME_LIMIT_80", "mail-user": "lukas.kugler@univie.ac.at"},
 
     # WRF file format, will only change if WRF changes
     wrfout_format = '/wrfout_d01_%Y-%m-%d_%H:%M:%S',
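
Assuming each slurm_kwargs key maps to the identically named sbatch long
option (as elsewhere in this codebase), the new defaults above render
roughly as:

    #SBATCH --account=lkugler
    #SBATCH --partition=devel
    #SBATCH --time=30            # a bare number means minutes in SLURM
    #SBATCH --nodes=1
    #SBATCH --ntasks=1
    #SBATCH --ntasks-per-core=1
    #SBATCH --mem=25G
    #SBATCH --mail-type=FAIL,TIME_LIMIT_80
    #SBATCH --mail-user=lukas.kugler@univie.ac.at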
diff --git a/dartwrf/__init__.py b/dartwrf/__init__.py
index b680692..46c1de7 100644
--- a/dartwrf/__init__.py
+++ b/dartwrf/__init__.py
@@ -1 +1 @@
-__all__ = []
\ No newline at end of file
+#__all__ = []
\ No newline at end of file
diff --git a/dartwrf/assimilate.py b/dartwrf/assimilate.py
index bdc25b8..2c8684c 100755
--- a/dartwrf/assimilate.py
+++ b/dartwrf/assimilate.py
@@ -14,7 +14,7 @@ from dartwrf.obs import obsseq
 from dartwrf import dart_nml
 
 
-def prepare_DART_grid_template():
+def prepare_DART_grid_template(cfg):
     """Prepare DART grid template wrfinput_d01 file from a prior file
     
     DART needs a wrfinput file as a template for the grid information.
@@ -30,7 +30,7 @@ def prepare_DART_grid_template():
     else:
         pass # what now?
 
-def prepare_prior_ensemble(assim_time, prior_init_time, prior_valid_time, prior_path_exp):
+def prepare_prior_ensemble(cfg, assim_time, prior_init_time, prior_valid_time, prior_path_exp):
     """Prepares DART files for running filter
     i.e.
     - links first guess state to DART first guess filenames
@@ -70,8 +70,8 @@ def prepare_prior_ensemble(assim_time, prior_init_time, prior_valid_time, prior_
         #if cluster.geo_em_forecast:
         #    wrfout_add_geo.run(cluster.geo_em_forecast, wrfout_dart)
 
-    use_linked_files_as_prior()
-    write_list_of_outputfiles()
+    use_linked_files_as_prior(cfg)
+    write_list_of_outputfiles(cfg)
 
     print("removing preassim and filter_restart")
     os.system("rm -rf " + cfg.dir_dart_run + "/preassim_*")
@@ -82,7 +82,7 @@ def prepare_prior_ensemble(assim_time, prior_init_time, prior_valid_time, prior_
     os.system("rm -rf " + cfg.dir_dart_run + "/obs_seq.fina*")
 
 
-def use_linked_files_as_prior():
+def use_linked_files_as_prior(cfg):
     """Instruct DART to use the prior ensemble as input
     """
     files = []
@@ -91,7 +91,7 @@ def use_linked_files_as_prior():
     write_txt(files, cfg.dir_dart_run+'/input_list.txt')
 
 
-def use_filter_output_as_prior():
+def use_filter_output_as_prior(cfg):
     """Use the last posterior as input for DART, e.g. to evaluate the analysis in observation space
     """
     files = []
@@ -108,14 +108,14 @@ def use_filter_output_as_prior():
     write_txt(files, cfg.dir_dart_run+'/input_list.txt')
 
 
-def write_list_of_outputfiles():
+def write_list_of_outputfiles(cfg):
     files = []
     for iens in range(1, cfg.ensemble_size+1):
         files.append("./filter_restart_d01." + str(iens).zfill(4))
     write_txt(files, cfg.dir_dart_run+'/output_list.txt')
 
 
-def filter():
+def filter(cfg):
     """Calls DART ./filter program
 
     Args:
@@ -146,7 +146,7 @@ def filter():
             "Check log file at " + cfg.dir_dart_run + "/log.filter")
 
 
-def archive_filteroutput(time):
+def archive_filteroutput(cfg, time):
     """Archive filter output files (filter_restart, preassim, postassim, output_mean, output_sd)
     """
     # archive diagnostics
@@ -210,7 +210,7 @@ def get_parametrized_error(obscfg, osf_prior) -> np.ndarray: # type: ignore
                             obscfg.get("sat_channel"))
     
 
-def set_obserr_assimilate_in_obsseqout(oso, outfile="./obs_seq.out"):
+def set_obserr_assimilate_in_obsseqout(cfg, oso, outfile="./obs_seq.out"):
     """"Overwrite existing variance values in obs_seq.out files
 
     Args:
@@ -231,32 +231,43 @@ def set_obserr_assimilate_in_obsseqout(oso, outfile="./obs_seq.out"):
         # modify observation error of each kind sequentially
         where_oso_iskind = oso.df.kind == kind
 
-        if hasattr(obscfg, "error_assimilate"):
-            if obscfg["error_assimilate"] == False:
+        if "error_assimilate" in obscfg:
+            if obscfg["error_assimilate"]  == False:
+                print("error_assimilate is False, will compute dynamic obs-errors")
                 # get a parametrized error for this observation type
-
-                f_osf = cfg.dir_dart_run + "/obs_seq.final"
-                if not os.path.isfile(f_osf):
-                    evaluate(cfg.time, f_out_pattern=f_osf)
-
-                # this file was generated by `evaluate()`
-                osf_prior = obsseq.ObsSeq(f_osf)
-
-                where_osf_iskind = osf_prior.df.kind == kind
-
-                assim_err = get_parametrized_error(
-                    obscfg, osf_prior.df[where_osf_iskind])
-                oso.df.loc[where_oso_iskind, 'variance'] = assim_err**2
-                # assert np.allclose(assim_err, oso.df['variance']**2)  # check
+                
+                # parametrization is state dependent => need states
+                use_external_FO = obscfg.get("external_FO", False)
+                if use_external_FO:
+                    # TODO (not implemented yet):
+                    # 1) read the prior from obs_seq.out
+                    # 2) modify the observation error (OE) in obs_seq.out
+                    # for now, fall through without modifying anything
+                    pass
+                else:
+                    f_osf = cfg.dir_dart_run + "/obs_seq.final"
+                    if not os.path.isfile(f_osf):
+                        print('computing prior as input for dynamic obs errors')
+                        # generates obs_seq.final
+                        evaluate(cfg, cfg.time, f_out_pattern=f_osf)
+
+                    # read prior (obs_seq.final)
+                    osf_prior = obsseq.ObsSeq(f_osf)
+                    where_osf_iskind = osf_prior.df.kind == kind
+
+                    assim_err = get_parametrized_error(
+                        obscfg, osf_prior.df[where_osf_iskind])
+                    oso.df.loc[where_oso_iskind, 'variance'] = assim_err**2
+                    # assert np.allclose(assim_err, oso.df['variance']**2)  # check
             else:
                 # overwrite with user-defined values
                 oso.df.loc[where_oso_iskind,
-                           'variance'] = obscfg["error_assimilate"]**2
+                            'variance'] = obscfg["error_assimilate"]**2
 
     oso.to_dart(outfile)
 
 
-def reject_small_FGD(time, oso):
+def reject_small_FGD(cfg, time, oso):
     """Quality control of observations
     We assume that the prior values have been evaluated and are in `run_DART/obs_seq.final`
 
@@ -330,7 +341,7 @@ def reject_small_FGD(time, oso):
         print('saved', f_out_dart)
 
 
-def evaluate(assim_time,
+def evaluate(cfg, assim_time,
              obs_seq_out=False,
              prior_is_filter_output=False,
              f_out_pattern: str = './obs_seq.final'):
@@ -354,10 +365,10 @@ def evaluate(assim_time,
 
     if prior_is_filter_output:
         print('using filter_restart files in run_DART as prior')
-        use_filter_output_as_prior()
+        use_filter_output_as_prior(cfg)
     else:
         print('using files linked to `run_DART/<exp>/prior_ens*/wrfout_d01` as prior')
-        use_linked_files_as_prior()
+        use_linked_files_as_prior(cfg)
 
     # the observations at which to evaluate the prior at
     if obs_seq_out:
@@ -369,11 +380,11 @@ def evaluate(assim_time,
                                '/obs_seq.out does not exist')
 
     dart_nml.write_namelist(cfg, just_prior_values=True)
-    filter()
-    archive_filter_diagnostics(assim_time, f_out_pattern)
+    filter(cfg)
+    archive_filter_diagnostics(cfg, assim_time, f_out_pattern)
 
 
-def archive_filter_diagnostics(time, f_out_pattern):
+def archive_filter_diagnostics(cfg, time, f_out_pattern):
     """Copy the filter output txt to the archive
     """
     f_archive = time.strftime(f_out_pattern)
@@ -384,7 +395,7 @@ def archive_filter_diagnostics(time, f_out_pattern):
     print(f_archive, "saved.")
 
 
-def txtlink_to_prior(time, prior_init_time, prior_path_exp):
+def txtlink_to_prior(cfg, time, prior_init_time, prior_path_exp):
     """For reproducibility, write the path of the prior to a txt file
     """
     os.makedirs(cfg.dir_archive +
@@ -394,7 +405,7 @@ def txtlink_to_prior(time, prior_init_time, prior_path_exp):
                 + cfg.dir_archive + time.strftime('/%Y-%m-%d_%H:%M/')+'link_to_prior.txt')
 
 
-def prepare_inflation_2(time, prior_init_time):
+def prepare_adapt_inflation(cfg, time, prior_init_time):
     """Prepare inflation files (spatially varying)
 
     Recycles inflation files from previous assimilations
@@ -407,7 +418,6 @@ def prepare_inflation_2(time, prior_init_time):
     dir_priorinf = cfg.dir_archive + \
         prior_init_time.strftime(cfg.pattern_init_time)
 
-    f_default = cfg.dir_archive.replace('<exp>', cfg.name) + "/input_priorinf_mean.nc"
     f_prior = dir_priorinf + \
         time.strftime("/%Y-%m-%d_%H:%M_output_priorinf_mean.nc")
     f_new = cfg.dir_dart_run + '/input_priorinf_mean.nc'
@@ -415,11 +425,13 @@ def prepare_inflation_2(time, prior_init_time):
     if os.path.isfile(f_prior):
         copy(f_prior, f_new)
         print(f_prior, 'copied to', f_new)
-    else:  # no prior inflation file at the first assimilation
+    else:
+        # no prior inflation file at the first assimilation
         warnings.warn(f_prior + ' does not exist. Using default file instead.')
+        f_default = cfg.dir_archive+"/../input_priorinf_mean.nc"
         copy(f_default, f_new)
 
-    f_default = cfg.archive_base + "/input_priorinf_sd.nc"
+
     f_prior = dir_priorinf + \
         time.strftime("/%Y-%m-%d_%H:%M_output_priorinf_sd.nc")
     f_new = cfg.dir_dart_run + '/input_priorinf_sd.nc'
@@ -428,11 +440,13 @@ def prepare_inflation_2(time, prior_init_time):
         copy(f_prior, f_new)
         print(f_prior, 'copied to', f_new)
     else:
+        # no prior inflation file at the first assimilation
         warnings.warn(f_prior + ' does not exist. Using default file instead.')
+        f_default = cfg.dir_archive + "/../input_priorinf_sd.nc"
         copy(f_default, f_new)
 
 
-def archive_inflation_2(time):
+def archive_adapt_inflation(cfg, time):
     dir_output = cfg.dir_archive + time.strftime(cfg.pattern_init_time)
     os.makedirs(dir_output, exist_ok=True)
 
@@ -490,7 +504,7 @@ def prepare_run_DART_folder(cfg: Config):
     # remove any remains of a previous run
     os.makedirs(cfg.dir_dart_run, exist_ok=True)
     os.chdir(cfg.dir_dart_run)
-    os.system("rm -f input.nml obs_seq.in obs_seq.out obs_seq.out-orig obs_seq.final output_* input_*")
+    os.system("rm -f input.nml obs_seq.in obs_seq.out-orig obs_seq.final")
     
     __link_DART_exe()
     
@@ -500,13 +514,13 @@ def prepare_run_DART_folder(cfg: Config):
             continue  # only need to link RTTOV files once
 
 
-def get_obsseq_out(time, prior_path_exp, prior_init_time, prior_valid_time, lag=None):
+def get_obsseq_out(cfg, time, prior_path_exp, prior_init_time, prior_valid_time):
     """Prepares an obs_seq.out file in the run_DART folder
 
-    3 Options:
+    Options:
     1) Use existing obs_seq.out file
-    2) Use precomputed FO (for cloud fraction)
-    3) Generate new observations with new observation noise
+    2) Use precomputed FO (e.g. cloud fraction)
+    3) Generate new observations from nature run with new noise
 
     Args:
         time (datetime): time of assimilation
@@ -615,38 +629,37 @@ def main(cfg: Config):
     nml = dart_nml.write_namelist(cfg)
 
     print(" get observations with specified obs-error")
-    #lag = dt.timedelta(minutes=15)
-    oso = get_obsseq_out(time, prior_path_exp, prior_init_time, prior_valid_time)
+    oso = get_obsseq_out(cfg, time, prior_path_exp, prior_init_time, prior_valid_time)
     
     # prepare for assimilation
-    prepare_prior_ensemble(time, prior_init_time, prior_valid_time, prior_path_exp)
-    prepare_DART_grid_template()
+    prepare_prior_ensemble(cfg, time, prior_init_time, prior_valid_time, prior_path_exp)
+    prepare_DART_grid_template(cfg)
 
     # additional evaluation of prior (in case some observations are rejected)
     if do_reject_smallFGD:
         print(" evaluate prior for all observations (incl rejected) ")
-        evaluate(time, f_out_pattern=cfg.pattern_obs_seq_final+"-evaluate_prior")
+        evaluate(cfg, time, f_out_pattern=cfg.pattern_obs_seq_final+"-evaluate_prior")
 
     print(" assign observation-errors for assimilation ")
-    set_obserr_assimilate_in_obsseqout(oso, outfile=cfg.dir_dart_run + "/obs_seq.out")
+    set_obserr_assimilate_in_obsseqout(cfg, oso, outfile=cfg.dir_dart_run + "/obs_seq.out")
 
     if do_reject_smallFGD:
         print(" reject observations? ")
-        reject_small_FGD(time, oso)
+        reject_small_FGD(cfg, time, oso)
 
     prior_inflation_type = nml['&filter_nml']['inf_flavor'][0][0]
-    if prior_inflation_type == '2':
-        prepare_inflation_2(time, prior_init_time)
+    if prior_inflation_type != '0':
+        prepare_adapt_inflation(cfg, time, prior_init_time)
 
     print(" run filter ")
     dart_nml.write_namelist(cfg)
-    filter()
-    archive_filteroutput(time)
-    archive_filter_diagnostics(time, cfg.pattern_obs_seq_final)
-    txtlink_to_prior(time, prior_init_time, prior_path_exp)
+    filter(cfg)
+    archive_filteroutput(cfg, time)
+    archive_filter_diagnostics(cfg, time, cfg.pattern_obs_seq_final)
+    txtlink_to_prior(cfg, time, prior_init_time, prior_path_exp)
 
-    if prior_inflation_type == '2':
-        archive_inflation_2(time)
+    if prior_inflation_type != '0':
+        archive_adapt_inflation(cfg, time)
 
     if 'evaluate_posterior_in_obs_space' in cfg:
         if cfg.evaluate_posterior_in_obs_space:
@@ -658,7 +671,7 @@ def main(cfg: Config):
 
             # evaluate() separately after ./filter is crucial when assimilating cloud variables
             # as the obs_seq.final returned by ./filter was produced using un-clamped cloud variables
-            evaluate(time,
+            evaluate(cfg, time,
                     obs_seq_out=f_oso,
                     prior_is_filter_output=True,
                     f_out_pattern=cfg.pattern_obs_seq_final+"-evaluate")
diff --git a/dartwrf/obs/error_models.py b/dartwrf/obs/error_models.py
index 9ec0eae..999662d 100644
--- a/dartwrf/obs/error_models.py
+++ b/dartwrf/obs/error_models.py
@@ -14,7 +14,7 @@ def calc_obserr_WV(channel, Hx_nature, Hx_prior):
     """
     if channel not in ['WV62', 'WV73']:
         raise NotImplementedError("channel not implemented: " + channel)
-    debug = False
+    debug = True
 
     n_obs = len(Hx_nature)
     OEs = np.ones(n_obs)
@@ -29,10 +29,12 @@ def calc_obserr_WV(channel, Hx_nature, Hx_prior):
         if channel == 'WV62':
             oe_model = _OE_model_harnisch_WV62(mean_CI)
         elif channel == 'WV73':
-            oe_model = _OE_model_harnisch_WV73(mean_CI)
+            oe_model = _OE_model_new73(mean_CI)
         
         if debug:
-            print("BT_nature=", bt_y, "=> mean_CI=", mean_CI, "=> OE_assim=", oe_model)
+            print("bt_y=", bt_y, "bt_x_ens=", bt_x_ens)
+            print("CIs=", CIs)
+            print("=> mean_CI=", mean_CI, "=> OE_assim=", oe_model)
         
         OEs[iobs] = oe_model
     return OEs
@@ -44,7 +46,8 @@ def _cloudimpact(channel, bt_mod, bt_obs):
     """
     if channel == 'WV73':
         biascor_obs = 0
-        bt_lim = 255  # Kelvin for 7.3 micron WV channel
+        # bt_lim = 252  # Kelvin for 7.3 micron WV channel
+        bt_lim = 255.0  # Kelvin for 7.3 micron WV channel (updated)
     elif channel == 'WV62':
         biascor_obs = 0
         bt_lim = 232.5  # Kelvin for 6.2 micron WV channel
@@ -71,10 +74,18 @@ def _OE_model_harnisch_WV73(ci):
         # Kelvin, fit of Fig 7b, Harnisch 2016
         x_ci = [0, 5, 10.5, 13, 16]  # average cloud impact [K]
         y_oe = [1, 4.5, 10, 12, 13]  # adjusted observation error [K]
-        
-        #y_oe = [1.2, 3, 5, 6, 6.5]  # OE for WV62 !!!!
         oe_linear = interp1d(x_ci, y_oe, assume_sorted=True)
         return oe_linear(ci)
     else:  # assign highest observation error
         return 13.0
 
+def _OE_model_new73(ci):
+    # based on exp_nat250_WV73_obs6_loc6_oe2_inf3
+    if ci >= 0 and ci < 25:
+        x_ci = [0, 5, 10, 15, 25]  # average cloud impact [K]
+        y_oe = [1, 2, 9, 10.5, 7]  # adjusted observation error [K]
+        oe_linear = interp1d(x_ci, y_oe, assume_sorted=True)
+        return oe_linear(ci)
+    else:
+        return 7.0
+
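
A quick sanity check of _OE_model_new73 above, reading values off its
interpolation nodes:

    _OE_model_new73(0.0)    # -> 1.0   clear sky: smallest error
    _OE_model_new73(12.5)   # -> 9.75  midway between (10, 9) and (15, 10.5)
    _OE_model_new73(30.0)   # -> 7.0   outside [0, 25): constant fallback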
diff --git a/dartwrf/prepare_wrfrundir.py b/dartwrf/prepare_wrfrundir.py
index b512cfb..59e5752 100755
--- a/dartwrf/prepare_wrfrundir.py
+++ b/dartwrf/prepare_wrfrundir.py
@@ -10,10 +10,8 @@ Args:
 Returns:
     None
 """
-import os, sys, shutil
-import datetime as dt
-
-from dartwrf.utils import Config, symlink, link_contents, try_remove
+import os, sys
+from dartwrf.utils import Config, symlink, link_contents
 from dartwrf import prepare_namelist
 
 def run(cfg: Config):
diff --git a/dartwrf/print_config.py b/dartwrf/print_config.py
old mode 100644
new mode 100755
index d172532..6cb5d3b
--- a/dartwrf/print_config.py
+++ b/dartwrf/print_config.py
@@ -1,4 +1,9 @@
-"""Script to pretty-pring a config file (pickle format)"""
+#!/usr/bin/env python
+"""Script to pretty-pring a config file (pickle format)"
+
+Usage:
+    python print_config.py <config_file>
+"""
 import sys, pickle, pprint
 f = sys.argv[1]
 with open(f, 'rb') as f:
diff --git a/dartwrf/utils.py b/dartwrf/utils.py
index a42bfa5..623cdc6 100755
--- a/dartwrf/utils.py
+++ b/dartwrf/utils.py
@@ -120,6 +120,7 @@ class Config(object):
         self.python = 'python'
         self.pattern_obs_seq_out = pattern_obs_seq_out.replace('<archivedir>', self.dir_archive)
         self.pattern_obs_seq_final = pattern_obs_seq_final.replace('<archivedir>', self.dir_archive)
+        self.obs_kind_nrs = dict()  # will be filled later
         
         # optional
         self.assimilate_existing_obsseq = assimilate_existing_obsseq
@@ -400,4 +401,38 @@ def obskind_read(dart_srcdir: str) -> dict:
             kind_str = data[0].strip()
             kind_nr = int(data[1].strip())
             obskind_nrs[kind_str] = kind_nr
-    return obskind_nrs
\ No newline at end of file
+    return obskind_nrs
+
+def run_bash_command_in_directory(command, directory):
+    """Runs a Bash command in the specified directory.
+
+    Args:
+        command (str): The Bash command to execute.
+        directory (str): The path to the directory where the command should be run.
+
+    Returns:
+        subprocess.CompletedProcess on success, the caught CalledProcessError on failure, or None if the directory does not exist.
+    """
+    try:
+        result = subprocess.run(
+            command,
+            shell=True,
+            cwd=directory,
+            check=True,  # Raise an exception for non-zero exit codes
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True  # Decode stdout and stderr as text
+        )
+        print("Command executed successfully.")
+        print("Stdout:", result.stdout)
+        if result.stderr:
+            print("Stderr:", result.stderr)
+        return result
+    except subprocess.CalledProcessError as e:
+        print(f"Error executing command: {e}")
+        print("Stdout:", e.stdout)
+        print("Stderr:", e.stderr)
+        return e
+    except FileNotFoundError:
+        print(f"Error: Directory not found: {directory}")
+        return None
\ No newline at end of file
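
A minimal usage sketch for run_bash_command_in_directory (command and
directory are placeholders; this assumes `subprocess` is imported at the
top of utils.py, which this hunk does not show):

    result = run_bash_command_in_directory('echo hello', '/tmp')
    if result is not None and result.returncode == 0:
        print(result.stdout)  # -> hello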
diff --git a/dartwrf/workflows.py b/dartwrf/workflows.py
index 63d99ae..b44b35b 100644
--- a/dartwrf/workflows.py
+++ b/dartwrf/workflows.py
@@ -137,7 +137,7 @@ class WorkFlows(object):
         cmd = ' '.join([self.python, path_to_script, cfg.f_cfg_current])
 
         id = self.run_job(cmd, cfg, depends_on=[depends_on],
-                          **{"ntasks": "20", "time": "30", "mem": "200G", "ntasks-per-node": "20"})
+                          **{"ntasks": "20", "mem": "200G", "ntasks-per-node": "20"})
         return id
 
 
@@ -173,7 +173,7 @@ class WorkFlows(object):
                             ).replace('<wrf_rundir_base>', cfg.dir_wrf_run.replace('<ens>', '$IENS')
                             ).replace('<wrf_modules>', cfg.wrf_modules,
                             )
-        id = self.run_job(cmd, cfg, depends_on=[depends_on], time="30", array="1-"+str(cfg.ensemble_size))
+        id = self.run_job(cmd, cfg, depends_on=[depends_on], array="1-"+str(cfg.ensemble_size))
         return id
     
     def run_WRF(self, cfg, depends_on=None):
@@ -184,11 +184,12 @@ class WorkFlows(object):
         end = cfg.WRF_end
         
         # SLURM configuration for WRF
-        slurm_kwargs = {"array": "1-"+str(cfg.ensemble_size),
-                      "nodes": "1", 
-                      "ntasks": str(cfg.max_nproc_for_each_ensemble_member), 
-                      "ntasks-per-core": "1", "mem": "90G", 
-                      "ntasks-per-node": str(cfg.max_nproc_for_each_ensemble_member),}
+        slurm_kwargs = {
+                "partition": "general",
+                "array": "1-"+str(cfg.ensemble_size),
+                "ntasks": str(cfg.max_nproc_for_each_ensemble_member), 
+                "ntasks-per-core": "1", "mem": "100G", 
+                "ntasks-per-node": str(cfg.max_nproc_for_each_ensemble_member),}
 
         # command from template file
         wrf_cmd = script_to_str(cfg.WRF_exe_template
@@ -203,13 +204,11 @@ class WorkFlows(object):
 
         # run WRF ensemble
         time_in_simulation_hours = (end-start).total_seconds()/3600
-        # runtime_wallclock_mins_expected = int(time_in_simulation_hours*15*1.5 + 15)  
-        runtime_wallclock_mins_expected = int(time_in_simulation_hours*30*1.5 + 15)  
-        # usually max 15 min/hour + 50% margin + 15 min buffer
+        runtime_wallclock_mins_expected = int(time_in_simulation_hours*0.5)*60 + 20
+        # usually max 15 min/hour + 100% margin + 20 min buffer
         slurm_kwargs.update({"time": str(runtime_wallclock_mins_expected)})
-        if runtime_wallclock_mins_expected > 20:
-            slurm_kwargs.update({"partition": "amd"})
-            #cfg_update.update({"exclude": "jet03"})
+        # constraint now applied unconditionally; jobs mostly take < 15 mins
+        slurm_kwargs.update({"constraint": "zen4"})
 
         id = self.run_job(wrf_cmd, cfg, depends_on=[id], **slurm_kwargs)
         return id
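
For orientation: the new wallclock formula gives int(4*0.5)*60 + 20 = 140
minutes for a 4 h forecast, while a 1 h forecast floors to int(0.5) = 0 and
keeps only the 20-minute buffer, consistent with jobs mostly finishing in
under 15 minutes.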
@@ -224,8 +223,9 @@ class WorkFlows(object):
         cmd = ' '.join([self.python, path_to_script, cfg.f_cfg_current])
 
         id = self.run_job(cmd, cfg, depends_on=[depends_on], 
-                          **{"ntasks": "20", "time": "30", "mem": "110G",
-                            "ntasks-per-node": "20", "ntasks-per-core": "1"}, 
+                          **{"ntasks": str(cfg.max_nproc), "time": "30", 
+                             "mem": "110G", "partition": "devel",
+                             "ntasks-per-node": str(cfg.max_nproc), "ntasks-per-core": "1"}, 
                      )
         return id
 
@@ -246,7 +246,7 @@ class WorkFlows(object):
         """
         path_to_script = self.dir_dartwrf_run + '/prep_IC_prior.py'
         cmd = ' '.join([self.python, path_to_script, cfg.f_cfg_current])
-        id = self.run_job(cmd, cfg, depends_on=[depends_on], time="10")
+        id = self.run_job(cmd, cfg, depends_on=[depends_on])
         return id
 
     def update_IC_from_DA(self, cfg, depends_on=None):
@@ -262,7 +262,7 @@ class WorkFlows(object):
         path_to_script = self.dir_dartwrf_run + '/update_IC.py'
         cmd = ' '.join([self.python, path_to_script, cfg.f_cfg_current])
 
-        id = self.run_job(cmd, cfg, depends_on=[depends_on], time="10")
+        id = self.run_job(cmd, cfg, depends_on=[depends_on])
         return id
 
     def run_RTTOV(self, cfg, depends_on=None):
@@ -275,17 +275,22 @@ class WorkFlows(object):
                         '$SLURM_ARRAY_TASK_ID'])
 
         id = self.run_job(cmd, cfg, depends_on=[depends_on],
-                          **{"ntasks": "1", "time": "60", "mem": "10G", 
-                                "array": "1-"+str(cfg.ensemble_size)})
+                          **{"ntasks": "1", "mem": "15G", "array": "1-"+str(cfg.ensemble_size)})
         return id
 
-    def verify(self, cfg: Config, depends_on=None):
+    def verify(self, cfg: Config, init=False,
+               depends_on=None):
         """Not included in DART-WRF"""
-        cmd = ' '.join(['python /jetfs/home/lkugler/osse_analysis/plot_from_raw/analyze_fc.py', 
-                        cfg.name, cfg.verify_against, #cfg.time.strftime('%y%m%d_%H:%M'),
-                        'sat', 'wrf', 'has_node', 'np=10', 'mem=250G'])
+        if not init:
+            cmd = ' '.join(['python /jetfs/home/lkugler/osse_analysis/plot_from_raw/analyze_fc.py', 
+                            cfg.name, cfg.verify_against, 
+                            'sat', 'wrf', 'has_node', 'np=10', 'mem=250G'])
+        else:
+            cmd = ' '.join(['python /jetfs/home/lkugler/osse_analysis/plot_from_raw/analyze_fc.py', 
+                            cfg.name, cfg.verify_against, 'init='+cfg.time.strftime('%Y%m%d_%H:%M'), 'force',
+                            'sat', 'wrf', 'has_node', 'np=10', 'mem=250G'])
         self.run_job(cmd, cfg, depends_on=[depends_on],
-                         **{"time": "03:00:00", "mail-type": "FAIL,END", 
+                         **{"time": "04:00:00", "mail-type": "FAIL", "partition": "general",
                                     "ntasks": "10", "ntasks-per-node": "10", 
                                     "ntasks-per-core": "1", "mem": "250G"})
 
diff --git a/multiple_exps.py b/multiple_exps.py
index 6aa2a94..cbcc38c 100644
--- a/multiple_exps.py
+++ b/multiple_exps.py
@@ -9,22 +9,32 @@ from dartwrf.utils import Config
 from config.jet import cluster_defaults
 from config.defaults import dart_nml, CF_config, vis
 
-# test multiple assimilation windows (11-12, 12-13, 13-14, )
-timedelta_btw_assim = dt.timedelta(minutes=15)
+
 ensemble_size = 40
 
 dart_nml['&filter_nml'].update(num_output_state_members=ensemble_size,
-                                ens_size=ensemble_size)
+                               ens_size=ensemble_size)
 
 # which DART version to use?
-assimilate_cloudfractions = True
-cf1 = dict(kind='CF192km', loc_horiz_km=9999)
-cf2 = dict(kind='CF96km', loc_horiz_km=96)
-cf3 = dict(kind='CF48km', loc_horiz_km=48)
-cf4 = dict(kind='CF24km', loc_horiz_km=24)
-cf5 = dict(kind='CF12km', loc_horiz_km=12)
-
-assimilate_these_observations = [cf5]
+assimilate_cloudfractions = False
+
+# scale_km = 2
+# cf = dict(kind='CF{}km'.format(scale_km), loc_horiz_km=scale_km)
+
+cf1 = dict(kind='CF192km', loc_horiz_km=9999,
+           )
+cf2 = dict(kind='CF96km', loc_horiz_km=96,
+           )
+cf3 = dict(kind='CF48km', loc_horiz_km=48,
+           )
+cf4 = dict(kind='CF24km', loc_horiz_km=24,
+           )
+cf5 = dict(kind='CF12km', loc_horiz_km=12,
+           )
+cf6 = dict(kind='CF6km', loc_horiz_km=6,
+           )
+
+assimilate_these_observations = [vis,]  # or cf3, cf4, cf5
 
 if assimilate_cloudfractions:
     cluster_defaults.update(
@@ -38,11 +48,11 @@ else:
                 )
 
 
-
 time0 = dt.datetime(2008, 7, 30, 11)
+time_end = dt.datetime(2008, 7, 30, 15)
 
 id = None
-cfg = Config(name='exp_nat250_VIS_SO12_loc12_oe2_inf4-0.5', 
+cfg = Config(name='exp_nat250_VIS_obs12_loc12_oe2_inf4-0.5',
     model_dx = 2000,
     ensemble_size = ensemble_size,
     dart_nml = dart_nml,
@@ -54,9 +64,9 @@ cfg = Config(name='exp_nat250_VIS_SO12_loc12_oe2_inf4-0.5',
     CF_config=CF_config,
     
     assimilate_existing_obsseq = False,
-    #nature_wrfout_pattern = '/jetfs/home/lkugler/data/sim_archive/nat_250m_1600x1600x100/*/1/wrfout_d01_%Y-%m-%d_%H_%M_%S',
+    nature_wrfout_pattern = '/jetfs/home/lkugler/data/sim_archive/nat_250m_1600x1600x100/*/1/wrfout_d01_%Y-%m-%d_%H_%M_%S',
     #geo_em_nature = '/jetfs/home/lkugler/data/sim_archive/geo_em.d01.nc.2km_200x200',
-    # geo_em_nature = '/jetfs/home/lkugler/data/sim_archive/geo_em.d01.nc.250m_1600x1600',
+    geo_em_nature = '/jetfs/home/lkugler/data/sim_archive/geo_em.d01.nc.250m_1600x1600',
     
     update_vars = ['U', 'V', 'W', 'THM', 'PH', 'MU', 'QVAPOR', 'QCLOUD', 'QICE', 'QSNOW', 'PSFC'],
     #input_profile = '/jetfs/home/lkugler/data/sim_archive/nat_250m_1600x1600x100/2008-07-30_08:00/1/input_sounding',
@@ -66,31 +76,38 @@ cfg = Config(name='exp_nat250_VIS_SO12_loc12_oe2_inf4-0.5',
 
 w = WorkFlows(cfg)
 w.prepare_WRFrundir(cfg)
-#id = w.run_ideal(cfg, depends_on=id)
+# id = w.run_ideal(cfg, depends_on=id)
 
 # assimilate at these times
-assim_times = pd.date_range(time0, time0 + dt.timedelta(hours=4), freq=timedelta_btw_assim)
+timedelta_btw_assim = dt.timedelta(minutes=15)
+assim_times = pd.date_range(time0, time_end, freq=timedelta_btw_assim)
+#assim_times = [dt.datetime(2008, 7, 30, 12), dt.datetime(2008, 7, 30, 13), dt.datetime(2008, 7, 30, 14),]
 last_assim_time = assim_times[-1]
 
+
 # loop over assimilations
 for i, t in enumerate(assim_times):
+
+    # which scales?
+    # if t.minute == 0:
+    #     CF_config.update(scales_km=(48, 24, 12),)
+    # else:
+    #     CF_config.update(scales_km=(12,))
+        
+    #cfg.update(CF_config=CF_config)
     
-    if i == 0:
-        if t == dt.datetime(2008, 7, 30, 11):
-            prior_init_time = dt.datetime(2008, 7, 30, 8)
-        else:
-            prior_init_time = t - dt.timedelta(minutes=15)
-            
-        cfg.update(time = t,
-            prior_init_time = prior_init_time,
+    if i == 0 and t == dt.datetime(2008, 7, 30, 11):
+        cfg.update(
+            time = t,
+            prior_init_time = dt.datetime(2008, 7, 30, 8),
             prior_valid_time = t,
             prior_path_exp = '/jetfs/home/lkugler/data/sim_archive/exp_nat250m_noDA/',)
     else:
-        cfg.update(time = t,
-            prior_init_time = assim_times[i-1],
+        cfg.update(
+            time = t,
+            prior_init_time = t - dt.timedelta(minutes=15),
             prior_valid_time = t,
             prior_path_exp = cfg.dir_archive,)
-                
 
     id = w.assimilate(cfg, depends_on=id)
 
@@ -119,7 +136,7 @@ for i, t in enumerate(assim_times):
     if t.minute == 0 and i != 0:
         # full hour but not first one
         # make long forecasts without restart files
-        timedelta_integrate = dt.timedelta(hours=4)
+        timedelta_integrate = dt.timedelta(hours=.25)
         restart_interval = 9999
         
         cfg.update( WRF_start=t, 
@@ -133,4 +150,7 @@ for i, t in enumerate(assim_times):
         id = w.run_WRF(cfg, depends_on=id)
         id = w.run_RTTOV(cfg, depends_on=id)
     
+        w.verify(cfg, init=True, depends_on=id)
+        
+# verify the rest
 w.verify(cfg, depends_on=id)
\ No newline at end of file
diff --git a/templates/namelist.input b/templates/namelist.input
index 4ed3019..430be87 100644
--- a/templates/namelist.input
+++ b/templates/namelist.input
@@ -31,7 +31,7 @@
  max_step_increase_pct               = 5, 51, 51,
  starting_time_step                  = 4,
  max_time_step                       = 16,
- min_time_step                       = 4,
+ min_time_step                       = 3,
  time_step                           = 8,
  time_step_fract_num                 = 0,
  time_step_fract_den                 = 1,
diff --git a/test-singlenode.py b/test-singlenode.py
new file mode 100644
index 0000000..5adb862
--- /dev/null
+++ b/test-singlenode.py
@@ -0,0 +1,135 @@
+from timer import Timer
+
+
+with Timer('imports'):
+    import datetime as dt
+    import pandas as pd
+    from dartwrf.workflows import WorkFlows
+    from dartwrf.utils import Config, run_bash_command_in_directory
+
+    # import default config for jet
+    from config.jet_1node import cluster_defaults
+    from config.defaults import dart_nml
+
+
+ensemble_size = 5
+
+dart_nml['&filter_nml'].update(num_output_state_members=ensemble_size,
+                               ens_size=ensemble_size)
+
+
+t = dict(var_name='Temperature', unit='[K]',
+         kind='RADIOSONDE_TEMPERATURE',
+         # n_obs=22500,
+         n_obs=1, obs_locations=[(45., 0.)],
+         error_generate=0.2, error_assimilate=0.2,
+         heights=range(1000, 17001, 2000),
+         loc_horiz_km=1000, loc_vert_km=4)
+
+assimilate_these_observations = [t,]
+
+
+time0 = dt.datetime(2008, 7, 30, 11)
+time_end = dt.datetime(2008, 7, 30, 11)
+
+id = None
+with Timer('Config()'):
+    cfg = Config(name='test-1node',
+        model_dx = 2000,
+        ensemble_size = ensemble_size,
+        dart_nml = dart_nml,
+        geo_em_forecast = '/jetfs/home/lkugler/data/sim_archive/geo_em.d01.nc.2km_200x200',
+        time = time0,
+        
+        assimilate_these_observations = assimilate_these_observations,
+        assimilate_existing_obsseq = False,
+        
+        nature_wrfout_pattern = '/jetfs/home/lkugler/data/sim_archive/exp_v1.18_P1_nature+1/*/1/wrfout_d01_%Y-%m-%d_%H:%M:%S',
+        geo_em_nature = '/jetfs/home/lkugler/data/sim_archive/geo_em.d01.nc.2km_200x200',
+        
+        update_vars = ['U', 'V', 'W', 'THM', 'PH', 'MU', 'QVAPOR', 'QCLOUD', 'QICE', 'QSNOW', 'PSFC'],
+        #input_profile = '/jetfs/home/lkugler/data/sim_archive/nat_250m_1600x1600x100/2008-07-30_08:00/1/input_sounding',
+        verify_against = 'nat_250m_blockavg2km',
+        **cluster_defaults)
+
+with Timer('prepare_WRFrundir'):
+    w = WorkFlows(cfg)
+    import dartwrf.prepare_wrfrundir as prepwrf
+    prepwrf.run(cfg)
+    # w.prepare_WRFrundir(cfg)
+
+# id = w.run_ideal(cfg, depends_on=id)
+
+# assimilate at these times
+timedelta_btw_assim = dt.timedelta(minutes=15)
+assim_times = pd.date_range(time0, time_end, freq=timedelta_btw_assim)
+#assim_times = [dt.datetime(2008, 7, 30, 12), dt.datetime(2008, 7, 30, 13), dt.datetime(2008, 7, 30, 14),]
+last_assim_time = assim_times[-1]
+
+
+# loop over assimilations
+for i, t in enumerate(assim_times):
+
+    if i == 0 and t == dt.datetime(2008, 7, 30, 11):
+        cfg.update(
+            time = t,
+            prior_init_time = dt.datetime(2008, 7, 30, 8),
+            prior_valid_time = t,
+            prior_path_exp = '/jetfs/home/lkugler/data/sim_archive/exp_nat250m_noDA/')
+    else:
+        cfg.update(
+            time = t,
+            prior_init_time = t - dt.timedelta(minutes=15),
+            prior_valid_time = t,
+            prior_path_exp = cfg.dir_archive,)
+
+    with Timer('assimilate'):
+        #id = w.assimilate(cfg, depends_on=id)
+        import dartwrf.assimilate as da
+        da.main(cfg)
+
+    with Timer('prepare_IC_from_prior'):
+        # 1) Set posterior = prior
+        #id = w.prepare_IC_from_prior(cfg, depends_on=id)
+        import dartwrf.prep_IC_prior as prep
+        prep.main(cfg)
+
+    with Timer('update_IC_from_DA'):
+        # 2) Update posterior += updates from assimilation
+        #id = w.update_IC_from_DA(cfg, depends_on=id)
+        import dartwrf.update_IC as upd
+        upd.update_initials_in_WRF_rundir(cfg)
+    
+    # integrate until next assimilation
+    timedelta_integrate = dt.timedelta(minutes=15)
+    restart_interval = timedelta_btw_assim.total_seconds()/60  # in minutes
+
+    cfg.update( WRF_start=t, 
+                WRF_end=t+timedelta_integrate, 
+                restart=True, 
+                restart_interval=restart_interval,
+                hist_interval_s=300,
+    )
+
+    import dartwrf.prepare_namelist as prepnl
+    prepnl.run(cfg)
+
+    # 3) Run WRF ensemble
+    #id = w.run_WRF(cfg, depends_on=id)
+
+    with Timer('run_WRF'):
+        # one shell command: load modules, then launch each member's wrf.exe
+        cmd = cfg.wrf_modules + '; '
+        
+        for iens in range(1, cfg.ensemble_size+1):
+            dir_wrf_run = cfg.dir_wrf_run.replace('<exp>', cfg.name).replace('<ens>', str(iens))
+
+            cmd += 'cd ' + dir_wrf_run + '; '
+            cmd += 'echo "'+dir_wrf_run+'"; '
+            cmd += 'mpirun -np 4 ./wrf.exe & '
+
+        cmd += 'wait; '
+        cmd += 'echo "WRF run completed."'
+
+        # Run the command in the specified directory
+        run_bash_command_in_directory(cmd, dir_wrf_run)
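
For ensemble_size = 2 the assembled command looks roughly like this (paths
hypothetical; each `&` backgrounds one member, `wait` blocks until all
wrf.exe processes finish):

    <wrf_modules>; cd <run_WRF>/1; echo "<run_WRF>/1"; mpirun -np 4 ./wrf.exe & cd <run_WRF>/2; echo "<run_WRF>/2"; mpirun -np 4 ./wrf.exe & wait; echo "WRF run completed."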
diff --git a/timer.py b/timer.py
new file mode 100644
index 0000000..9a32db5
--- /dev/null
+++ b/timer.py
@@ -0,0 +1,38 @@
+import time
+
+class Timer:
+    """
+    A context manager for measuring the execution time of a code block.
+    Prints a message before and after the timed block.
+    """
+    def __init__(self, message="Code block"):
+        """
+        Initializes the Timer with an optional message.
+
+        Args:
+            message (str, optional): The message to print before and after the timed block.
+                Defaults to "Code block".
+        """
+        self.message = message
+
+    def __enter__(self):
+        """
+        Starts the timer and prints the initial message.
+        """
+        print(f"{self.message} started.")
+        self.start_time = time.perf_counter()
+        return self  # Returns self, so you can access the timer object if needed
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Stops the timer, calculates the elapsed time, and prints the final message
+        along with the execution time.
+
+        Args:
+            exc_type: The type of exception that occurred, if any.
+            exc_val: The exception instance, if any.
+            exc_tb: A traceback object, if an exception occurred.
+        """
+        self.end_time = time.perf_counter()
+        self.elapsed_time = self.end_time - self.start_time
+        print(f"{self.message} finished in {self.elapsed_time:.4f} seconds.")
\ No newline at end of file
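
Typical usage of the Timer context manager:

    with Timer("assimilation step"):
        run_assimilation()  # placeholder for any code block
    # prints "assimilation step started." before the block, then e.g.
    # "assimilation step finished in 12.3456 seconds."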
-- 
GitLab