Skip to content

Commit

Permalink
Move passes-related settings outside _create_params so they don't con…
Browse files Browse the repository at this point in the history
…fuse model loading
  • Loading branch information
osma committed Jan 30, 2019
1 parent 94b1ca3 commit b251014
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions annif/backend/vw_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def initialize(self):
backend_id=self.backend_id)
self.debug('loading VW model from {}'.format(path))
params = self._create_params({'i': path, 'quiet': True})
if 'passes' in params:
# don't confuse the model with passes
del params['passes']
self.debug("model parameters: {}".format(params))
self._model = pyvw.vw(**params)
self.debug('loaded model {}'.format(str(self._model)))

Expand Down Expand Up @@ -111,9 +115,6 @@ def _create_params(self, params):
params.update({param: self._convert_param(param, val)
for param, val in self.params.items()
if param in self.VW_PARAMS})
if params.get('passes', 1) > 1:
# need a cache file when there are multiple passes
params.update({'cache': True, 'kill_cache': True})
if self.algorithm == 'oaa':
# only the oaa algorithm supports probabilities output
params.update({'probabilities': True, 'loss_function': 'logistic'})
Expand All @@ -124,6 +125,9 @@ def _create_model(self, project):
trainpath = os.path.join(self._get_datadir(), self.TRAIN_FILE)
params = self._create_params(
{'data': trainpath, self.algorithm: len(project.subjects)})
if params.get('passes', 1) > 1:
# need a cache file when there are multiple passes
params.update({'cache': True, 'kill_cache': True})
self.debug("model parameters: {}".format(params))
self._model = pyvw.vw(**params)
modelpath = os.path.join(self._get_datadir(), self.MODEL_FILE)
Expand Down

0 comments on commit b251014

Please sign in to comment.