From 3f5f57676715157adc4aa8108d7e4f6e625f8ebc Mon Sep 17 00:00:00 2001
From: lkugler <lukas.kugler@gmail.com>
Date: Thu, 8 Jun 2023 11:42:58 +0200
Subject: [PATCH] method to get model grid indices

---
 dartwrf/obs/obsseq.py | 84 ++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 76 insertions(+), 8 deletions(-)

diff --git a/dartwrf/obs/obsseq.py b/dartwrf/obs/obsseq.py
index 49f3bd8..2446f00 100755
--- a/dartwrf/obs/obsseq.py
+++ b/dartwrf/obs/obsseq.py
@@ -141,6 +141,61 @@ class ObsRecord(pd.DataFrame):
         assert np.allclose(Hx.mean(axis=1), self[what+' ensemble mean'])
         return Hx.values
 
+
+
+    def get_model_grid_indices(self, wrf_file_with_grid):
+        """Retrieve the model grid indices closest to each observation
+
+        Note:
+            Only the horizontal grid is considered.
+            "Closest" is the Euclidean distance in (lat, lon) degrees,
+            not the great-circle distance.
+
+        Args:
+            wrf_file_with_grid (str):   path to wrf file with grid information
+                                        (must contain XLAT_M and XLONG_M)
+
+        Returns:
+            pd.DataFrame (n_obs, 2)     columns: wrf_i, wrf_j
+                wrf_i / wrf_j index the first / second horizontal grid axis;
+                the index of the DataFrame is the index of this ObsRecord
+
+        Raises:
+            ValueError: if the grid coordinates are not 2D with equal shape
+        """
+        from scipy import spatial
+        import xarray as xr
+
+        # load coordinates of wrf grid; the context manager closes the file
+        with xr.open_dataset(wrf_file_with_grid) as grid:
+            xlat = grid.XLAT_M.values.squeeze()
+            xlon = grid.XLONG_M.values.squeeze()
+
+        if xlat.ndim != 2 or xlat.shape != xlon.shape:
+            raise ValueError('XLAT_M and XLONG_M must be 2D with equal shape, '
+                             'got %s and %s' % (xlat.shape, xlon.shape))
+
+        # build search tree; ravel() flattens in row-major (C) order,
+        # so point number n of the tree belongs to grid cell
+        # (n // xlat.shape[1], n % xlat.shape[1])
+        tree = spatial.KDTree(np.c_[xlat.ravel(), xlon.ravel()])
+
+        # get lat lon of observations
+        lon_lat = self.get_lon_lat()
+        obs_coords = np.c_[lon_lat['lat'].values, lon_lat['lon'].values]
+
+        # find the nearest grid point of all observations in one query
+        _, flat_indices = tree.query(obs_coords)
+
+        # convert flat indices back to 2D grid indices; unravel_index uses
+        # the full grid shape, so this is also correct for non-square grids
+        i_grid, j_grid = np.unravel_index(flat_indices, xlat.shape)
+
+        return pd.DataFrame(index=self.index,
+            data=dict(wrf_i=i_grid.astype(np.int32),
+                      wrf_j=j_grid.astype(np.int32)))
+
+
     def get_lon_lat(self):
         """Retrieve longitude and latitude of observations
 
@@ -190,21 +245,26 @@ class ObsRecord(pd.DataFrame):
         return nlayers
 
     def superob(self, window_km):
-        """Select subset, average, overwrite existing obs with average
-
-        TODO: allow different obs types (KIND)
-        TODO: loc3d overwrite with mean
-        Metadata is copied from the first obs in a superob-box
+        """Create super-observations (averaged observations)
 
         Note:
             This routine discards observations (round off)
-            e.g. 31 obs with 5 obs-window => obs #31 is not processed
+            e.g. 31 obs with 5 obs-window => obs #31 is not processed.
+
+            Metadata is copied from the first observation in a superob-box
+
+            The location (loc3d) of the new observation is taken from the center observation
+
+        TODO: allow different obs types (KIND)
 
         Args:
             window_km (numeric):        horizontal window edge length
                                         includes obs on edge
                                         25x25 km with 5 km obs density
                                         = average 5 x 5 observations
+
+        Returns:
+            ObsRecord
         """
         def calc_deg_from_km(distance_km, center_lat):
             """Approximately calculate distance in degrees from meters
@@ -264,9 +324,17 @@ class ObsRecord(pd.DataFrame):
         out = self.drop(self.index)  # this df will be filled
         boxes = []
 
-        for i in range(0, nx+1 - win_obs, win_obs):
+        for i in range(0, nx+1 - win_obs, win_obs):  
+            # i is the x-index of the first observation in the current window;
+            # it advances in steps of "number of observations in superob window"
+            # i.e. i = 0, win_obs, 2*win_obs, 3*win_obs, ...
+
             for j in range(0, nx+1 - win_obs, win_obs):
+                # same as i but in y direction
+
                 for k in range(0, nlayers):
+                    # k is the index of the vertical layer
+
                     if debug: print(i,j,k)
 
                     # find indices of observations within superob window
@@ -301,7 +369,7 @@ class ObsRecord(pd.DataFrame):
                     # average spread and other values
                     for key in obs_box:
                         if key in ['loc3d', 'kind', 'metadata', 'time']:
-                            pass
+                            pass  # these parameters are not averaged
                         elif 'spread' in key:
                             # stdev of mean of values = sqrt(mean of variances)
                             obs_mean.at[key] = np.sqrt((obs_box[key]**2).mean())
-- 
GitLab