Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wpreimes committed Jan 17, 2025
1 parent 311d28a commit 7b7a184
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/ecmwf_models/era5/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _download(curr_start, curr_end):
fname = "{start}_{end}.{ext}".format(
start=curr_start.strftime("%Y%m%d"),
end=curr_end.strftime("%Y%m%d"),
ext="zip")
ext="zip" if grb is False else "grb")

dl_file = os.path.join(downloaded_data_path, fname)

Expand Down
7 changes: 5 additions & 2 deletions src/ecmwf_models/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def unzip_nc(
with tempfile.TemporaryDirectory() as tmpdir:
with zipfile.ZipFile(input_zip, "r") as zip_ref:
zip_ref.extractall(tmpdir)
ncfiles = [f for f in os.listdir(tmpdir) if f.endswith(".nc")]
ncfiles = [os.path.join(tmpdir, f) for f in os.listdir(tmpdir)
if f.endswith(".nc")]
if len(ncfiles) == 1:
shutil.move(ncfiles[0], output_nc)
else:
Expand Down Expand Up @@ -98,6 +99,8 @@ def save_ncs_from_nc(
ext='nc')

nc_in = xr.open_dataset(input_nc, mask_and_scale=True)
if 'valid_time' in nc_in.dims:
nc_in = nc_in.rename_dims({"valid_time": 'time'})
if 'valid_time' in nc_in.variables:
nc_in = nc_in.rename_vars({"valid_time": 'time'})

Expand All @@ -118,7 +121,7 @@ def save_ncs_from_nc(

# Expver identifies preliminary data
if 'expver' in subset:
expver = str(subset['expver'].values[i])
expver = str(np.atleast_1d(subset['expver'].values)[i])
subset = subset.drop_vars('expver')
try:
ext = EXPVER[expver]
Expand Down
28 changes: 22 additions & 6 deletions tests/tests_era5/test_era5_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import xarray as xr
import pytest
import tempfile
import zipfile

from c3s_sm.misc import read_summary_yml

Expand Down Expand Up @@ -60,7 +61,9 @@ def test_download_with_cdo_not_installed():
"ecmwf_models-test-data", "download",
"era5_example_downloaded_raw.nc")
save_ncs_from_nc(
infile, out_path, 'ERA5', grid=grid, keep_original=True)
infile, out_path, 'ERA5', grid=grid,
keep_original=True)


def test_dry_download_nc_era5():
with tempfile.TemporaryDirectory() as dl_path:
Expand All @@ -72,8 +75,14 @@ def test_dry_download_nc_era5():
os.path.dirname(os.path.abspath(__file__)), '..',
"ecmwf_models-test-data", "download",
"era5_example_downloaded_raw.nc")
trgt = os.path.join(dl_path, 'temp_downloaded', '20100101_20100101.nc')
shutil.copyfile(thefile, trgt)

assert os.path.exists(thefile)

trgt = os.path.join(dl_path, "temp_downloaded",
"20100101_20100101.zip")
with zipfile.ZipFile(trgt, 'w') as zip:
# Add the file to the ZIP archive
zip.write(thefile, arcname="20100101_20100101.nc")

assert os.path.isfile(trgt)

Expand Down Expand Up @@ -176,9 +185,16 @@ def test_download_nc_era5_regridding():
os.path.dirname(os.path.abspath(__file__)), '..',
"ecmwf_models-test-data", "download",
"era5_example_downloaded_raw.nc")
shutil.copyfile(
thefile,
os.path.join(dl_path, 'temp_downloaded', '20100101_20100101.nc'))

assert os.path.exists(thefile)

trgt = os.path.join(dl_path, "temp_downloaded",
"20100101_20100101.zip")
with zipfile.ZipFile(trgt, 'w') as zip:
# Add the file to the ZIP archive
zip.write(thefile, arcname="20100101_20100101.nc")

assert os.path.isfile(trgt)

startdate = enddate = datetime(2010, 1, 1)

Expand Down

0 comments on commit 7b7a184

Please sign in to comment.