diff --git a/dartwrf/obsseq.py b/dartwrf/obsseq.py index 5a06a7ceea4505d5bbc1ee670e787a0251535e40..adb1ca69567ba96089b8bf0d69197818c905f434 100755 --- a/dartwrf/obsseq.py +++ b/dartwrf/obsseq.py @@ -226,20 +226,29 @@ class ObsSeq(object): list_of_obsdict = obs_list_to_dict(obs_list) return list_of_obsdict - def get_prior_Hx_matrix(self): - """Return prior Hx array (n_obs, n_ens)""" + def _get_model_Hx(self, what): + if what not in ['prior', 'posterior']: + raise ValueError # which columns do we need? keys = self.df.columns - keys_bool = np.array(['prior ensemble member' in a for a in keys]) + keys_bool = np.array([what+' ensemble member' in a for a in keys]) # select columns in DataFrame - prior_Hx = self.df.iloc[:, keys_bool] + Hx = self.df.iloc[:, keys_bool] # consistency check: compute mean over ens - compare with value from file - assert np.allclose(prior_Hx.mean(axis=1), self.df['prior ensemble mean']) + assert np.allclose(Hx.mean(axis=1), self.df[what+' ensemble mean']) + + return Hx.values - return prior_Hx.values + def get_prior_Hx(self): + """Return prior Hx array (n_obs, n_ens)""" + return self._get_model_Hx('prior') + + def get_posterior_Hx(self): + """Return posterior Hx array (n_obs, n_ens)""" + return self._get_model_Hx('posterior') def get_truth_Hx(self): return self.df['truth'].values @@ -416,6 +425,7 @@ class ObsSeq(object): def write_obs(i, obs, next_i_obs=None, prev_i_obs=None): """Write the observation section of a obs_seq.out file + Args: i (int): index of observation obs (dict): observation data