Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve preprocessing #11

Merged
merged 2 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Improve preprocessing utils.
  • Loading branch information
hqucms committed Jan 27, 2024
commit 80c1fd02219e50f8dd50226260cb3706fe43494d
80 changes: 49 additions & 31 deletions weaver/utils/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ def __init__(self, print_info=True, **kwargs):
if print_info:
_logger.debug(opts)

self.train_load_branches = set()
self.train_aux_branches = set()
self.test_load_branches = set()
self.test_aux_branches = set()

self.selection = opts['selection']
self.test_time_selection = opts['test_time_selection'] if opts['test_time_selection'] else self.selection
self.var_funcs = copy.deepcopy(opts['new_variables'])
Expand Down Expand Up @@ -101,26 +106,27 @@ def _get(idx, default):
assert (isinstance(self.label_value, list))
self.label_names = ('_label_',)
label_exprs = ['ak.to_numpy(%s)' % k for k in self.label_value]
self.var_funcs['_label_'] = 'np.argmax(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs))
self.var_funcs['_labelcheck_'] = 'np.sum(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs))
self.register('_label_', 'np.argmax(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)))
self.register('_labelcheck_', 'np.sum(np.stack([%s], axis=1), axis=1)' % (','.join(label_exprs)), 'train')
else:
self.label_names = tuple(self.label_value.keys())
self.var_funcs.update(self.label_value)
self.register(self.label_value)
self.basewgt_name = '_basewgt_'
self.weight_name = None
if opts['weights'] is not None:
self.weight_name = 'weight_'
self.weight_name = '_weight_'
self.use_precomputed_weights = opts['weights']['use_precomputed_weights']
if self.use_precomputed_weights:
self.var_funcs[self.weight_name] = '*'.join(opts['weights']['weight_branches'])
self.register(self.weight_name, '*'.join(opts['weights']['weight_branches']), 'train')
else:
self.reweight_method = opts['weights']['reweight_method']
self.reweight_basewgt = opts['weights'].get('reweight_basewgt', None)
if self.reweight_basewgt:
self.var_funcs[self.basewgt_name] = self.reweight_basewgt
self.register(self.basewgt_name, self.reweight_basewgt, 'train')
self.reweight_branches = tuple(opts['weights']['reweight_vars'].keys())
self.reweight_bins = tuple(opts['weights']['reweight_vars'].values())
self.reweight_classes = tuple(opts['weights']['reweight_classes'])
self.register(self.reweight_branches + self.reweight_classes, to='train')
self.class_weights = opts['weights'].get('class_weights', None)
if self.class_weights is None:
self.class_weights = np.ones(len(self.reweight_classes))
Expand Down Expand Up @@ -167,44 +173,56 @@ def _log(msg, *args, **kwargs):
'reweight_discard_under_overflow']:
_log('%s: %s' % (k, getattr(self, k)))

# parse config
self.keep_branches = set()
aux_branches = set()
# selection
if self.selection:
aux_branches.update(_get_variable_names(self.selection))
self.register(_get_variable_names(self.selection), to='train')
# test time selection
if self.test_time_selection:
aux_branches.update(_get_variable_names(self.test_time_selection))
# var_funcs
self.keep_branches.update(self.var_funcs.keys())
for expr in self.var_funcs.values():
aux_branches.update(_get_variable_names(expr))
self.register(_get_variable_names(self.test_time_selection), to='test')
# inputs
for names in self.input_dicts.values():
self.keep_branches.update(names)
# labels
self.keep_branches.update(self.label_names)
# weight
if self.weight_name:
self.keep_branches.add(self.weight_name)
if not self.use_precomputed_weights:
aux_branches.update(self.reweight_branches)
aux_branches.update(self.reweight_classes)
self.register(names)
# observers
self.keep_branches.update(self.observer_names)
self.register(self.observer_names, to='test')
# monitor variables
self.keep_branches.update(self.monitor_variables)
# keep and drop
self.drop_branches = (aux_branches - self.keep_branches)
self.load_branches = (aux_branches | self.keep_branches) - set(self.var_funcs.keys()) - {self.weight_name, }
self.register(self.monitor_variables)
# resolve dependencies
func_vars = set(self.var_funcs.keys())
for (load_branches, aux_branches) in (self.train_load_branches, self.train_aux_branches), (self.test_load_branches, self.test_aux_branches):
while (load_branches & func_vars):
for k in (load_branches & func_vars):
aux_branches.add(k)
load_branches.remove(k)
load_branches.update(_get_variable_names(self.var_funcs[k]))
if print_info:
_logger.debug('drop_branches:\n %s', ','.join(self.drop_branches))
_logger.debug('load_branches:\n %s', ','.join(self.load_branches))
_logger.debug('train_load_branches:\n %s', ', '.join(sorted(self.train_load_branches)))
_logger.debug('train_aux_branches:\n %s', ', '.join(sorted(self.train_aux_branches)))
_logger.debug('test_load_branches:\n %s', ', '.join(sorted(self.test_load_branches)))
_logger.debug('test_aux_branches:\n %s', ', '.join(sorted(self.test_aux_branches)))

def __getattr__(self, name):
return self.options[name]

def register(self, name, expr=None, to='both'):
assert to in ('train', 'test', 'both')
if isinstance(name, dict):
for k, v in name.items():
self.register(k, v, to)
elif isinstance(name, (list, tuple)):
for k in name:
self.register(k, None, to)
else:
if to in ('train', 'both'):
self.train_load_branches.add(name)
if to in ('test', 'both'):
self.test_load_branches.add(name)
if expr:
self.var_funcs[name] = expr
if to in ('train', 'both'):
self.train_aux_branches.add(name)
if to in ('test', 'both'):
self.test_aux_branches.add(name)

def dump(self, fp):
with open(fp, 'w') as f:
yaml.safe_dump(self.options, f, sort_keys=False)
Expand Down
89 changes: 48 additions & 41 deletions weaver/utils/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from .fileio import _read_files


def _apply_selection(table, selection, funcs={}):
def _apply_selection(table, selection, funcs=None):
if selection is None:
return table
new_vars = {k: funcs[k] for k in _get_variable_names(selection) if k not in table.fields and k in funcs}
_build_new_variables(table, new_vars)
if funcs:
new_vars = {k: funcs[k] for k in _get_variable_names(selection) if k not in table.fields and k in funcs}
_build_new_variables(table, new_vars)
selected = ak.values_astype(_eval_expr(selection, table), 'bool')
return table[selected]

Expand All @@ -28,11 +29,6 @@ def _build_new_variables(table, funcs):
return table


def _clean_up(table, drop_branches):
columns = [k for k in table.fields if k not in drop_branches]
return table[columns]


def _build_weights(table, data_config, reweight_hists=None):
if data_config.weight_name is None:
raise RuntimeError('Error when building weights: `weight_name` is None!')
Expand Down Expand Up @@ -92,27 +88,33 @@ def __init__(self, filelist, data_config):
self.load_range = (0, data_config.preprocess.get('data_fraction', 0.1))

def read_file(self, filelist):
self.keep_branches = set()
self.load_branches = set()
keep_branches = set()
aux_branches = set()
load_branches = set()
for k, params in self._data_config.preprocess_params.items():
if params['center'] == 'auto':
self.keep_branches.add(k)
if k in self._data_config.var_funcs:
expr = self._data_config.var_funcs[k]
self.load_branches.update(_get_variable_names(expr))
else:
self.load_branches.add(k)
keep_branches.add(k)
load_branches.add(k)
if self._data_config.selection:
self.load_branches.update(_get_variable_names(self._data_config.selection))
_logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(self.keep_branches))
_logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(self.load_branches))
table = _read_files(filelist, self.load_branches, self.load_range, show_progressbar=True,
load_branches.update(_get_variable_names(self._data_config.selection))

func_vars = set(self._data_config.var_funcs.keys())
while (load_branches & func_vars):
for k in (load_branches & func_vars):
aux_branches.add(k)
load_branches.remove(k)
load_branches.update(_get_variable_names(self._data_config.var_funcs[k]))

_logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(keep_branches))
_logger.debug('[AutoStandardizer] aux_branches:\n %s', ','.join(aux_branches))
_logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(load_branches))

table = _read_files(filelist, load_branches, self.load_range, show_progressbar=True,
treename=self._data_config.treename,
branch_magic=self._data_config.branch_magic, file_magic=self._data_config.file_magic)
table = _apply_selection(table, self._data_config.selection, funcs=self._data_config.var_funcs)
table = _build_new_variables(
table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
table = _clean_up(table, self.load_branches - self.keep_branches)
table = _build_new_variables(table, {k: v for k, v in self._data_config.var_funcs.items() if k in aux_branches})
table = table[keep_branches]
return table

def make_preprocess_params(self, table):
Expand Down Expand Up @@ -142,7 +144,7 @@ def produce(self, output=None):
table = self.read_file(self._filelist)
preprocess_params = self.make_preprocess_params(table)
self._data_config.preprocess_params = preprocess_params
# must also propogate the changes to `data_config.options` so it can be persisted
# must also propagate the changes to `data_config.options` so it can be persisted
self._data_config.options['preprocess']['params'] = preprocess_params
if output:
_logger.info(
Expand All @@ -168,26 +170,31 @@ def __init__(self, filelist, data_config):
self._data_config = data_config.copy()

def read_file(self, filelist):
self.keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes +
(self._data_config.basewgt_name,))
self.load_branches = set()
for k in self.keep_branches:
if k in self._data_config.var_funcs:
expr = self._data_config.var_funcs[k]
self.load_branches.update(_get_variable_names(expr))
else:
self.load_branches.add(k)
keep_branches = set(self._data_config.reweight_branches + self._data_config.reweight_classes)
if self._data_config.reweight_basewgt:
keep_branches.add(self._data_config.basewgt_name)
aux_branches = set()
load_branches = keep_branches.copy()
if self._data_config.selection:
self.load_branches.update(_get_variable_names(self._data_config.selection))
_logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(self.keep_branches))
_logger.debug('[WeightMaker] load_branches:\n %s', ','.join(self.load_branches))
table = _read_files(filelist, self.load_branches, show_progressbar=True,
load_branches.update(_get_variable_names(self._data_config.selection))

func_vars = set(self._data_config.var_funcs.keys())
while (load_branches & func_vars):
for k in (load_branches & func_vars):
aux_branches.add(k)
load_branches.remove(k)
load_branches.update(_get_variable_names(self._data_config.var_funcs[k]))

_logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(keep_branches))
_logger.debug('[WeightMaker] aux_branches:\n %s', ','.join(aux_branches))
_logger.debug('[WeightMaker] load_branches:\n %s', ','.join(load_branches))

table = _read_files(filelist, load_branches, show_progressbar=True,
treename=self._data_config.treename,
branch_magic=self._data_config.branch_magic, file_magic=self._data_config.file_magic)
table = _apply_selection(table, self._data_config.selection, funcs=self._data_config.var_funcs)
table = _build_new_variables(
table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
table = _clean_up(table, self.load_branches - self.keep_branches)
table = _build_new_variables(table, {k: v for k, v in self._data_config.var_funcs.items() if k in aux_branches})
table = table[keep_branches]
return table

def make_weights(self, table):
Expand Down Expand Up @@ -284,7 +291,7 @@ def produce(self, output=None):
table = self.read_file(self._filelist)
wgts = self.make_weights(table)
self._data_config.reweight_hists = wgts
# must also propogate the changes to `data_config.options` so it can be persisted
# must also propagate the changes to `data_config.options` so it can be persisted
self._data_config.options['weights']['reweight_hists'] = {k: v.tolist() for k, v in wgts.items()}
if output:
_logger.info('Writing YAML file w/ reweighting info to %s' % output)
Expand Down
6 changes: 4 additions & 2 deletions weaver/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def _preprocess(table, data_config, options):
if len(table) == 0:
return []
# define new variables
table = _build_new_variables(table, data_config.var_funcs)
aux_branches = data_config.train_aux_branches if options['training'] else data_config.test_aux_branches
table = _build_new_variables(table, {k: v for k, v in data_config.var_funcs.items() if k in aux_branches})
# check labels
if data_config.label_type == 'simple' and options['training']:
_check_labels(table)
Expand All @@ -108,7 +109,8 @@ def _preprocess(table, data_config, options):


def _load_next(data_config, filelist, load_range, options):
table = _read_files(filelist, data_config.load_branches, load_range, treename=data_config.treename,
load_branches = data_config.train_load_branches if options['training'] else data_config.test_load_branches
table = _read_files(filelist, load_branches, load_range, treename=data_config.treename,
branch_magic=data_config.branch_magic, file_magic=data_config.file_magic)
table, indices = _preprocess(table, data_config, options)
return table, indices
Expand Down