Skip to content

Commit

Permalink
Several bugfixes and test improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jreniel committed Jan 16, 2020
1 parent b7dacbb commit e9093b1
Show file tree
Hide file tree
Showing 13 changed files with 124 additions and 63 deletions.
17 changes: 10 additions & 7 deletions pyschism/cmd/argument_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import argparse
import logging
import sys
import re


def add_mesh_options(parser):
Expand All @@ -17,9 +15,16 @@ def add_mesh_options(parser):
parser.add_argument('--crs', default=4326)


def add_general_options(parser):
def add_general_options(parser, runtype=None):
add_mesh_options(parser)

if runtype is not None:
if runtype == "tidal":
add_tidal_run_options(parser)

elif runtype == 'best_track':
add_best_track_options(parser)

# output directory
msg = "Directory to which SCHISM input files will be written to. "
parser.add_argument('--output-directory', '--outdir', "-o", help=msg)
Expand All @@ -39,7 +44,7 @@ def add_general_options(parser):
add_server_options(parser)

# add tidal constituents
add_tidal_constituents(parser)
add_tidal_constituents_options(parser)

# add surface output requests
add_surface_output_request('elevation', parser)
Expand Down Expand Up @@ -132,7 +137,7 @@ def add_server_options(parser):
)


def add_tidal_constituents(parser):
def add_tidal_constituents_options(parser):
# tidal constituents
msg = "Tidal constituent to be forced in the model. Pass "
msg += "--use-constituent='all' to use all available constituents "
Expand Down Expand Up @@ -201,7 +206,6 @@ def add_stations_output_request(physical_var, parser):


def add_best_track_options(parser):
add_mesh_options(parser)
# storm_id
msg = "National Hurricane Center (NHC) storm id. "
msg += " Examples: AL132012 for Sandy2012 or AL152017 for Maria2017."
Expand All @@ -211,7 +215,6 @@ def add_best_track_options(parser):


def add_tidal_run_options(parser):
add_mesh_options(parser)
# start_date
msg = "Start date is relative to hotstart, that is, this is the "
msg += "true start date of the model (in UTC time). Use format "
Expand Down
7 changes: 4 additions & 3 deletions pyschism/cmd/plot_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def mesh(self):
except AttributeError:
self.__mesh = Mesh.open(
self.args.hgrid,
# vgrid=self.args.vgrid,
# crs=self.args.crs,
vgrid=self.args.vgrid,
crs=self.args.crs,
)
return self.__mesh

Expand Down Expand Up @@ -88,7 +88,8 @@ def parse_args():
parser = argparse.ArgumentParser(
description="Program to see a quick plot of an SCHISM mesh.")
parser.add_argument('hgrid')
# parser.add_argument('--vgrid')
parser.add_argument('--vgrid')
parser.add_argument('--crs')
parser.add_argument("--vmin", type=float)
parser.add_argument("--vmax", type=float)
parser.add_argument("--no-topobathy", action="store_true",)
Expand Down
3 changes: 1 addition & 2 deletions pyschism/cmd/tidal_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ class TidalRunCommand(SchismBaseCommand):

def parse_args():
parser = argparse.ArgumentParser()
argument_parser.add_general_options(parser)
argument_parser.add_tidal_run_options(parser)
argument_parser.add_general_options(parser, 'tidal')
return parser.parse_args()


Expand Down
14 changes: 7 additions & 7 deletions pyschism/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def run(

return self._output_collection

def dump(
def write(
self,
output_directory,
overwrite=False,
Expand All @@ -87,15 +87,15 @@ def dump(
if not outdir.exists():
outdir.mkdir(parents=True)
if hgrid:
self.mesh.hgrid.dump(outdir / hgrid, overwrite)
self.mesh.hgrid.write(outdir / hgrid, overwrite)
if vgrid:
self.mesh.vgrid.dump(outdir / vgrid, overwrite)
self.mesh.vgrid.write(outdir / vgrid, overwrite)
if fgrid:
self.mesh.fgrid.dump(outdir / fgrid, overwrite)
self.mesh.fgrid.write(outdir / fgrid, overwrite)
if param:
self.param.dump(outdir / param, overwrite)
self.param.write(outdir / param, overwrite)
if bctides:
self.bctides.dump(outdir / bctides, overwrite)
self.bctides.write(outdir / bctides, overwrite)

@property
def mesh(self):
Expand All @@ -118,7 +118,7 @@ def use_transport(self):
return self._use_transport

def _run_local(self, nproc, outdir, overwrite):
self.dump(outdir, overwrite)
self.write(outdir, overwrite)

def _run_coldstart(self, nproc, wdir):
self._stage_files('coldstart', nproc, wdir)
Expand Down
14 changes: 14 additions & 0 deletions pyschism/driver/param.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@


class Param:

def write(self, path, overwrite=False):
raise NotImplementedError(self.nml)

@property
def nml(self):
return {
"CORE": self.core,
"OPT": self.opt,
"SCHOUT": self.schout
}
17 changes: 14 additions & 3 deletions pyschism/mesh/gmesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,14 @@ def transform_to(self, dst_crs):
dst_crs = CRS.from_user_input(dst_crs)
if self.srs != dst_crs.srs:
transformer = Transformer.from_crs(
self.crs, dst_crs, always_xy=True)
self._vertices = list(zip(*transformer.transform(self.x, self.y)))
self.crs, dst_crs,
always_xy=True
)
xy = list(zip(*transformer.transform(self.x, self.y)))
ids = list(self._coords.keys())
self._coords = {
ids[i]: coord for i, coord in enumerate(xy)
}
self._crs = dst_crs

def get_index(self, id):
Expand All @@ -45,6 +51,12 @@ def tricontourf(self, axes=None, show=True, figsize=None, **kwargs):
axes.tricontourf(self.triangulation, self.values, **kwargs)
return axes

@_figure
def tripcolor(self, axes=None, show=True, figsize=None, **kwargs):
if len(self.triangles) > 0:
axes.tripcolor(self.triangulation, self.values, **kwargs)
return axes

@_figure
def triplot(
self,
Expand Down Expand Up @@ -107,7 +119,6 @@ def plot_wireframe(self, axes=None, show=False, **kwargs):
return axes

@property
@lru_cache
def coords(self):
return np.array(
[coord for coord in self._coords.values()]
Expand Down
35 changes: 35 additions & 0 deletions pyschism/mesh/gr3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pathlib
import numpy as np
from pyschism.mesh.gmesh import Gmesh
from functools import lru_cache
from collections import defaultdict


Expand Down Expand Up @@ -160,3 +161,37 @@ def open(cls, gr3, crs=None):
kwargs = reader(gr3)
kwargs.update({"crs": crs})
return cls(**kwargs)

def write(self, path, overwrite=False):
writer(self.grd, path, overwrite)

@property
@lru_cache
def grd(self):
description = self.description
if self.crs is not None:
description += f" CRS: {self.crs.srs}"
return {
"nodes": self.nodes,
"elements": self.elements,
"description": description,
}

@property
@lru_cache
def nodes(self):
return {id: ((x, y), -self.values[i]) for i, (id, (x, y))
in enumerate(self._coords.items())}

@property
@lru_cache
def elements(self):
keys = [id for id in self._triangles]
keys.extend([id for id in self._quads])
keys.sort(key=int)
geom = dict(self._triangles.items())
geom.update(dict(self._quads.items()))
elements = dict()
for i, id in enumerate(keys):
elements[id] = geom[id]
return elements
42 changes: 15 additions & 27 deletions pyschism/mesh/hgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,6 @@ def set_boundary_data(self, ibtype, id, indexes, **properties):
**properties
}

def write(self, path, overwrite=False):
grd = {
'description': self.description,
'nodes': self.nodes,
'elements': self.elements,
'boundaries': self.boundaries,
}
gr3.writer(grd, path, overwrite)

@fig._figure
def make_plot(
self,
Expand All @@ -105,8 +96,17 @@ def make_plot(
vmax = np.max(self.values)
kwargs.update(**fig.get_topobathy_kwargs(self.values, vmin, vmax))
kwargs.pop('col_val')
self.tricontourf(axes=axes, vmin=vmin, vmax=vmax, **kwargs)
kwargs.pop('levels')
levels = kwargs.pop('levels')
if vmin != vmax:
self.tricontourf(
axes=axes,
levels=levels,
vmin=vmin,
vmax=vmax,
**kwargs
)
else:
self.tripcolor(axes=axes, **kwargs)
self.quadface(axes=axes, **kwargs)
axes.axis('scaled')
if extent is not None:
Expand Down Expand Up @@ -160,22 +160,10 @@ def plot_boundaries(

@property
@lru_cache
def nodes(self):
return {id: ((x, y), -self.values[i]) for i, (id, (x, y))
in enumerate(self._coords.items())}

@property
@lru_cache
def elements(self):
keys = [id for id in self._triangles]
keys.extend([id for id in self._quads])
keys.sort(key=int)
geom = dict(self._triangles.items())
geom.update(dict(self._quads.items()))
elements = dict()
for i, id in enumerate(keys):
elements[id] = geom[id]
return elements
def grd(self):
grd = super().grd
grd.update({"boundaries": self.boundaries})
return grd

@property
def boundaries(self):
Expand Down
8 changes: 5 additions & 3 deletions pyschism/mesh/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ def open(cls, hgrid, vgrid=None, fgrid=None, crs=None):
return m

def make_plot(self, **kwargs):
if self.vgrid is None:
self.hgrid.make_plot(**kwargs)
else:
if self.vgrid.is3D():
msg = "Plotting not yet supported for 3D meshes."
raise NotImplementedError(msg)
elif self.vgrid.is2D():
self.hgrid.make_plot(**kwargs)

@property
def hgrid(self):
Expand Down Expand Up @@ -70,4 +70,6 @@ def _hgrid(self, hgrid):
def _vgrid(self, vgrid):
if vgrid is not None:
assert isinstance(vgrid, Vgrid)
else:
vgrid = Vgrid()
self.__vgrid = vgrid
6 changes: 6 additions & 0 deletions pyschism/mesh/vgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def write(self, path, overwrite=False):
with open(path, 'w') as f:
f.write(self.boilerplate_2D)

def is2D(self):
return True

def is3D(self):
return False

@property
@lru_cache
def boilerplate_2D(self): # TODO: *QUICK HACK*, please fix ASAP.
Expand Down
12 changes: 8 additions & 4 deletions tests/cmd/test_argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ def test_add_tidal_run_options(self):
parser = argparse.ArgumentParser()
cmd = ["", "NULL", "start_date", "end_date", "--spinup-days=0"]
with patch.object(sys, 'argv', cmd):
self.assertIsNone(argument_parser.add_tidal_run_options(parser))
self.assertIsNone(
argument_parser.add_general_options(parser, 'tidal')
)

def test_add_brest_track_run_options(self):
def test_add_best_track_run_options(self):
parser = argparse.ArgumentParser()
cmd = ["", "NULL", "storm_id", "-start-date", "--end-date"]
cmd = ["", "NULL", "storm_id", "--start-date", "--end-date"]
with patch.object(sys, 'argv', cmd):
self.assertIsNone(argument_parser.add_best_track_options(parser))
self.assertIsNone(
argument_parser.add_general_options(parser, 'best_track')
)


if __name__ == '__main__':
Expand Down
5 changes: 5 additions & 0 deletions tests/mesh/test_hgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,11 @@ def test_triplot(self):
h = Hgrid(self.nodes, self.elements, self.boundaries)
h.triplot()

def test_make_plot_flat_domain(self):
nodes = {id: (coord, 0.) for id, (coord, _) in self.nodes.items()}
h = Hgrid(nodes, self.elements, self.boundaries)
h.make_plot()

def test__fgrid_getter(self):
h = Hgrid(self.nodes, self.elements)
h.fgrid
Expand Down
7 changes: 0 additions & 7 deletions tests/mesh/test_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,6 @@ def test_make_plot(self):
m = Mesh.open(self.hgrid)
m.make_plot()

def test_make_plot_3D_raise(self):
m = Mesh.open(self.hgrid, self.vgrid)
self.assertRaises(
NotImplementedError,
m.make_plot
)

def test_default_fgrid(self):
m = Mesh.open(self.hgrid)
assert isinstance(m.fgrid, Fgrid)

0 comments on commit e9093b1

Please sign in to comment.