Skip to content

Commit

Permalink
utility.split_point: refactored for mixed source type setups
Browse files Browse the repository at this point in the history
n_sources: refactored from int to list of ints
  • Loading branch information
hvasbath committed Jun 29, 2023
1 parent e88f862 commit d0e5698
Show file tree
Hide file tree
Showing 19 changed files with 170 additions and 92 deletions.
16 changes: 11 additions & 5 deletions beat/apps/beat.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def setup(parser):
action="callback",
callback=list_callback,
default=["RectangularSource"],
help="Source types to solve for. Can be any combination of the "
help="List of source types to solve for. Can be any combination of the "
"following for mode: geometry - %s; bem - %s; "
"Default: 'RectangularSource'"
% (
Expand All @@ -276,7 +276,7 @@ def setup(parser):
default=1,
action="callback",
callback=list_callback_int,
help="List integer Number of sources per source type to invert for. Default: [1]",
help="List of integer numbers of sources per source type to invert for. Default: [1]",
)

parser.add_option(
Expand Down Expand Up @@ -664,6 +664,11 @@ def setup(parser):

if options.mode == ffi_mode_str:
n_sources = problem.config.problem_config.n_sources
if len(n_sources) != 1:
raise TypeError(
"FFI with more than one source type is not implemented!"
)

if options.import_from_mode == geometry_mode_str:
logger.info("Importing non-linear source geometry results!")

Expand All @@ -672,12 +677,12 @@ def setup(parser):
point.pop(param)

point = utility.adjust_point_units(point)
source_points = utility.split_point(point)
source_points = utility.split_point(point, n_sources_total=n_sources[0])

reference_sources = bconfig.init_reference_sources(
source_points,
n_sources,
c.problem_config.source_type,
c.problem_config.source_types[0],
c.problem_config.stf_type,
event=c.event,
)
Expand Down Expand Up @@ -718,6 +723,7 @@ def setup(parser):

elif options.mode == geometry_mode_str:
if options.import_from_mode == geometry_mode_str:
# TODO update for n_sources refactoring
n_sources = problem.config.problem_config.n_sources
logger.info("Importing non-linear source geometry results!")

Expand Down Expand Up @@ -1542,7 +1548,7 @@ def setup(parser):
logger.info("Fault discretization done! Updating problem_config...")
logger.info("%s" % fault.__str__())

c.problem_config.n_sources = fault.nsubfaults
c.problem_config.n_sources = [fault.nsubfaults]
mode_c.npatches = fault.npatches
mode_c.subfault_npatches = fault.subfault_npatches

Expand Down
69 changes: 44 additions & 25 deletions beat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
dump,
load,
)
from typing import Dict as TDict
from typing import List as TList

from theano import config as tconfig

from beat import utility, bem
Expand Down Expand Up @@ -1207,29 +1210,41 @@ class DatatypeParameterMapping(Object):

sources_variables = List.T(Dict.T(String.T(), Int.T()))

def point_to_source_mapping(self):
total_variables = {}
start_idx = 0
for source_variables in self.sources_variables:
for variable, size in source_variables.items():
end_idx = size + start_idx
if variable in total_variables:
total_variables[variable].append(end_idx)
else:
total_variables[variable] = [start_idx, end_idx]
def __init__(self, **kwargs):

Object.__init__(self, **kwargs)

self._mapping = None

def point_to_sources_mapping(self) -> TDict[str, TList[int]]:

start_idx += size
if self._mapping is None:
start_idx = 0
total_variables = {}
for source_variables in self.sources_variables:
for variable, size in source_variables.items():
end_idx = size + start_idx
source_idxs = list(range(start_idx, end_idx))
if variable in total_variables:
total_variables[variable].extend(source_idxs)
else:
total_variables[variable] = source_idxs

start_idx += size

self._mapping = total_variables

return self._mapping

return total_variables
def point_variable_names(self) -> TList[int]:
return self.point_to_sources_mapping().keys()

def get_all_point_variable_names(self):
return self.get_total_variables_sizes().keys()
def total_variables_sizes(self) -> TDict[str, int]:

def get_total_variables_sizes(self):
mapping = self.point_to_source_mapping()
mapping = self.point_to_sources_mapping()
variables_sizes = {}
for variable, idxs in mapping.items():
variables_sizes[variable] = sum(idxs)
variables_sizes[variable] = len(idxs)

return variables_sizes

Expand All @@ -1249,7 +1264,7 @@ def __init__(self, **kwargs):
for datatype in self.datatypes:
self.mappings[datatype] = None

def add(self, sources_variables: Dict = {}, datatype: str = "geodetic"):
def add(self, sources_variables: TDict = {}, datatype: str = "geodetic"):
if datatype in self.mappings:
self.mappings[datatype] = DatatypeParameterMapping(
sources_variables=sources_variables
Expand All @@ -1259,7 +1274,13 @@ def add(self, sources_variables: Dict = {}, datatype: str = "geodetic"):
"Datatype for the source mapping has not been initialized!"
)

def get_unique_variables_sizes(self) -> Dict:
def __getitem__(self, k):
if k not in self.mappings.keys():
raise KeyError(k)

return self.mappings[k]

def unique_variables_sizes(self) -> TDict[str, int]:
"""
Combine source specific variable dicts into a common setup dict
Expand All @@ -1277,9 +1298,7 @@ def get_unique_variables_sizes(self) -> Dict:
)
unique_variables = {}
for datatype_parameter_mapping in self.mappings.values():
unique_variables.update(
datatype_parameter_mapping.get_total_variables_sizes()
)
unique_variables.update(datatype_parameter_mapping.total_variables_sizes())

return unique_variables

Expand Down Expand Up @@ -1583,12 +1602,12 @@ def get_test_point(self):

def get_parameter_shape(self, param):
if self.mode == ffi_mode_str and param.name in hypo_vars:
shape = self.n_sources
shape = self.n_sources[0]
elif self.mode == ffi_mode_str and self.mode_config.npatches:
shape = self.mode_config.subfault_npatches
if len(shape) == 0:
shape = self.mode_config.npatches
elif self.mode in [ffi_mode_str, geometry_mode_str]:
elif self.mode in [ffi_mode_str, geometry_mode_str, bem_mode_str]:
shape = param.dimension

else:
Expand Down Expand Up @@ -2155,7 +2174,7 @@ def init_dataset_config(config, datatype, mode):
n_sources = gmc.problem_config.n_sources[0]
point = {k: v.testvalue for k, v in gmc.problem_config.priors.items()}
point = utility.adjust_point_units(point)
source_points = utility.split_point(point)
source_points = utility.split_point(point, n_sources_total=n_sources)

reference_sources = init_reference_sources(
source_points,
Expand Down
5 changes: 3 additions & 2 deletions beat/ffi/fault.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ def point2sources(self, point, events=[]):
sf_patches = self.get_subfault_patches(
index, datatype=datatype, component=component
)
n_sf_patches = len(sf_patches)

ucomps = {}
for comp in slip_directions.keys():
Expand Down Expand Up @@ -707,8 +708,8 @@ def point2sources(self, point, events=[]):
except KeyError:
pass

patch_points = split_point(sf_point)
assert len(patch_points) == len(sf_patches)
patch_points = split_point(sf_point, n_sources_total=n_sf_patches)
assert len(patch_points) == n_sf_patches

for patch, patch_point in zip(sf_patches, patch_points):
update_source(patch, **patch_point)
Expand Down
13 changes: 10 additions & 3 deletions beat/models/geodetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,10 @@ def __init__(self, gc, project_dir, sources, mapping, events, hypers=False):
self.sources = sources
self.mapping = mapping

@property
def n_sources_total(self):
return len(self.sources)

def point2sources(self, point):
"""
Updates the composite source(s) (in place) with the point values.
Expand All @@ -610,13 +614,16 @@ def point2sources(self, point):
if hyper in tpoint:
tpoint.pop(hyper)

source_parameter_names = self.mapping.get_all_point_variable_names()
source_parameter_names = self.mapping.point_variable_names()
for param in list(tpoint.keys()):
if param not in source_parameter_names:
tpoint.pop(param)

# TODO source to point mapping
source_points = utility.split_point(tpoint)
source_points = utility.split_point(
tpoint,
point_to_sources=self.mapping.point_to_sources_mapping(),
n_sources_total=self.n_sources_total,
)
for i, source in enumerate(self.sources):
utility.update_source(source, **source_points[i])
# reset source time may result in store error otherwise
Expand Down
15 changes: 11 additions & 4 deletions beat/models/polarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def __init__(self, polc, project_dir, sources, mapping, events, hypers=False):
)
self.wavemaps.append(pmap)

@property
def n_sources_total(self):
return len(self.sources)

@property
def is_location_fixed(self):
"""
Expand Down Expand Up @@ -162,10 +166,9 @@ def point2sources(self, point):
if hyper in tpoint:
tpoint.pop(hyper)

source_params = list(self.sources[0].keys())

source_parameter_names = self.mapping.point_variable_names()
for param in list(tpoint.keys()):
if param not in source_params:
if param not in source_parameter_names:
tpoint.pop(param)

if "time" in tpoint:
Expand All @@ -175,7 +178,11 @@ def point2sources(self, point):
for i, event in enumerate(self.events): # multi event
tpoint["time"][i] += event.time

source_points = split_point(tpoint)
source_points = split_point(
tpoint,
point_to_sources=self.mapping.point_to_sources_mapping(),
n_sources_total=self.n_sources_total,
)

for i, source in enumerate(self.sources):
update_source(source, **source_points[i])
Expand Down
8 changes: 5 additions & 3 deletions beat/models/problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,13 @@ def __init__(self, config, hypers=False):
super(SourceOptimizer, self).__init__(config, hypers)

pc = config.problem_config
n_sources_total = sum(pc.n_sources)

if self.nevents != num.sum(pc.n_sources) and self.nevents != 1:
if self.nevents != n_sources_total and self.nevents != 1:
raise ValueError(
"Number of events and sources have to be equal or only one "
"event has to be used! Number if events %i and number of "
"sources: %i!" % (self.nevents, pc.n_sources)
"sources: %i!" % (self.nevents, n_sources_total)
)

# Init sources
Expand Down Expand Up @@ -675,12 +676,13 @@ def __init__(self, config, hypers=False):
pc = config.problem_config

dsources = transform_sources(self.sources, pc.datatypes, pc.decimation_factors)

mappings = pc.get_variables_mapping()
for datatype in pc.datatypes:
self.composites[datatype] = geometry_composite_catalog[datatype](
config[datatype + "_config"],
config.project_dir,
dsources[datatype],
mappings[datatype],
self.events,
hypers,
)
Expand Down
16 changes: 11 additions & 5 deletions beat/models/seismic.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,10 @@ def __init__(self, sc, project_dir, sources, mapping, events, hypers=False):

self.config = sc

@property
def n_sources_total(self):
return len(self.sources)

def point2sources(self, point):
"""
Updates the composite source(s) (in place) with the point values.
Expand All @@ -697,11 +701,9 @@ def point2sources(self, point):
if hyper in tpoint:
tpoint.pop(hyper)

source = self.sources[0]
source_params = list(source.keys()) + list(source.stf.keys())

source_parameter_names = self.mapping.point_variable_names()
for param in list(tpoint.keys()):
if param not in source_params:
if param not in source_parameter_names:
tpoint.pop(param)

# update source times
Expand All @@ -712,7 +714,11 @@ def point2sources(self, point):
for i, event in enumerate(self.events): # multi event
tpoint["time"][i] += event.time

source_points = utility.split_point(tpoint)
source_points = utility.split_point(
tpoint,
point_to_sources=self.mapping.point_to_sources_mapping(),
n_sources_total=self.n_sources_total,
)

for i, source in enumerate(self.sources):
utility.update_source(source, **source_points[i])
Expand Down
Loading

0 comments on commit d0e5698

Please sign in to comment.