Skip to content

Commit

Permalink
Create drop attributes func (#720)
Browse files Browse the repository at this point in the history
* fix forcing_feedback settings formatting

* add check for user_pp_scripts attribute before looping through list to multifilepreprocessor add_user_pp_scripts method

* add snakeviz to env_dev.yml

* move drop_atts loop to a separate function that is called by crop_date_range, and before merging datasets in query_catalog in the preprocessor
  • Loading branch information
wrongkindofdoctor authored Dec 12, 2024
1 parent 5b0d16c commit 48a016a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/conda/env_dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ dependencies:
- intake-esm=2024.2.6
- cf_xarray=0.8.4
- cloud_sptheme
- snakeviz=2.2.0
30 changes: 18 additions & 12 deletions src/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,21 @@ def normalize_group_time_vals(self, time_vals: np.ndarray) -> np.ndarray:
time_vals[i] = '0' + time_vals[i]
return time_vals

def drop_attributes(self, xr_ds: "xr.Dataset") -> "xr.Dataset":
    """Drop bookkeeping variables that cause conflicts during xarray dataset merge.

    Note: despite the name, these are dataset *variables/coordinates*
    (removed via ``drop_vars``), not entries in ``Dataset.attrs``.

    Parameters
    ----------
    xr_ds:
        The xarray dataset to clean before merging/concatenation.

    Returns
    -------
    xr.Dataset
        A dataset with the conflicting variables removed; names not
        present in the dataset are silently skipped.
    """
    drop_atts = ['average_T2',
                 'time_bnds',
                 'lat_bnds',
                 'lon_bnds',
                 'average_DT',
                 'average_T1',
                 'height',
                 'date']
    # errors='ignore' skips names that are absent, replacing the
    # per-name existence check + drop loop with a single call
    return xr_ds.drop_vars(drop_atts, errors='ignore')

def check_multichunk(self, group_df: pd.DataFrame, case_dr, log) -> pd.DataFrame:
"""Sort the files found by date, grabs the files whose 'chunk_freq' is the
largest number where endyr-startyr modulo 'chunk_freq' is zero and throws out
Expand All @@ -834,6 +849,7 @@ def check_multichunk(self, group_df: pd.DataFrame, case_dr, log) -> pd.DataFrame
return pd.DataFrame.from_dict(group_df).reset_index()

def crop_date_range(self, case_date_range: util.DateRange, xr_ds, time_coord) -> xr.Dataset:
xr_ds = self.drop_attributes(xr_ds)
xr_ds = xr.decode_cf(xr_ds,
decode_coords=True, # parse coords attr
decode_times=True,
Expand Down Expand Up @@ -965,6 +981,7 @@ def check_group_daterange(self, df: pd.DataFrame, date_range: util.DateRange,
# hit an exception; return empty DataFrame to signify failure
return pd.DataFrame(columns=group_df.columns)


def query_catalog(self,
case_dict: dict,
data_catalog: str,
Expand All @@ -990,15 +1007,6 @@ def query_catalog(self,
if 'date_range' not in [c.lower() for c in cols]:
cols.append('date_range')

drop_atts = ['average_T2',
'time_bnds',
'lat_bnds',
'lon_bnds',
'average_DT',
'average_T1',
'height',
'date']

for case_name, case_d in case_dict.items():
# path_regex = re.compile(r'(?i)(?<!\\S){}(?!\\S+)'.format(case_name))
path_regex = [re.compile(r'({})'.format(case_name))]
Expand Down Expand Up @@ -1134,9 +1142,7 @@ def query_catalog(self,
var_xr = xr.concat([var_xr, cat_subset_dict[cat_index]], var.X.name)
else:
var_xr = xr.concat([var_xr, cat_subset_dict.values[cat_index]], var.N.name)
for att in drop_atts:
if var_xr.get(att, None) is not None:
var_xr = var_xr.drop_vars(att)
var_xr = self.drop_attributes(var_xr)
# add standard_name to the variable xarray dataset if it is not defined
for vname in var_xr.variables:
if (not isinstance(var_xr.variables[vname], xr.IndexVariable)
Expand Down

0 comments on commit 48a016a

Please sign in to comment.