add from_config to learner
king-menin committed Jan 19, 2019
1 parent 510ba7f commit 85b02b0
Showing 6 changed files with 141 additions and 87 deletions.
32 changes: 17 additions & 15 deletions modules/data/bert_data.py
@@ -298,6 +298,23 @@ def get_bert_data_loader_for_predict(path, learner):

class BertNerData(object):

@property
def config(self):
config = {
"train_path": self.train_path,
"valid_path": self.valid_path,
"vocab_file": self.vocab_file,
"data_type": self.data_type,
"max_seq_len": self.max_seq_len,
"batch_size": self.batch_size,
"is_cls": self.is_cls,
"cuda": self.cuda,
"is_meta": self.is_meta,
"label2idx": self.label2idx,
"cls2idx": self.cls2idx
}
return config

def __init__(self, train_path, valid_path, vocab_file, data_type,
train_dl=None, valid_dl=None, tokenizer=None,
label2idx=None, max_seq_len=424,
@@ -343,21 +360,6 @@ def from_config(cls, config):
config["train_path"], config["valid_path"], config["vocab_file"], config["data_type"],
*fn_res, batch_size=config["batch_size"], cuda=config["cuda"], is_meta=config["is_meta"])

def get_config(self):
config = {
"train_path": self.train_path,
"valid_path": self.valid_path,
"vocab_file": self.vocab_file,
"data_type": self.data_type,
"max_seq_len": self.max_seq_len,
"batch_size": self.batch_size,
"is_cls": self.is_cls,
"cuda": self.cuda,
"is_meta": self.is_meta,
"label2idx": self.label2idx,
"cls2idx": self.cls2idx
}
return config
# with open(config_path, "w") as f:
# json.dump(config, f)

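Usage sketch (not part of this diff): the new config property and from_config classmethod are meant to round-trip a BertNerData description through JSON. The variable data and the file name data_config.json below are illustrative assumptions.

import json

# data is an existing BertNerData instance
with open("data_config.json", "w") as f:
    json.dump(data.config, f)            # serialize the data description

with open("data_config.json") as f:
    restored = BertNerData.from_config(json.load(f))   # rebuild the loaders from it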
31 changes: 17 additions & 14 deletions modules/layers/decoders.py
@@ -548,7 +548,8 @@ def score(self, encoder_outputs, input_mask, labels_ids, cls_ids):
scores, intent_score = self.forward_model(encoder_outputs, input_mask)
batch_size = encoder_outputs.shape[0]
len_ = encoder_outputs.shape[1]
return self.loss(scores.view(batch_size * len_, -1), labels_ids.view(-1)) + self.intent_loss(intent_score, cls_ids)
return self.loss(scores.view(batch_size * len_, -1), labels_ids.view(-1)) + self.intent_loss(
intent_score, cls_ids)

@classmethod
def create(cls, label_size, intent_size,
@@ -659,6 +660,21 @@ def create(cls, label_size, input_dim, input_dropout=0.5, key_dim=64,


class NCRFDecoder(nn.Module):

@property
def config(self):
config = {
"name": "NCRFDecoder",
"params": {
"label_size": self.label_size,
"input_dim": self.input_dim,
"input_dropout": self.dropout.p,
"use_cuda": self.use_cuda,
"nbest": self.nbest
}
}
return config

def __init__(self,
crf, label_size, input_dim, input_dropout=0.5, nbest=8):
super(NCRFDecoder, self).__init__()
@@ -695,19 +711,6 @@ def score(self, inputs, labels_mask, labels):
crf_score = self.crf.neg_log_likelihood_loss(logits, labels_mask, labels) / logits.size(0)
return crf_score

def get_config(self):
config = {
"name": "NCRFDecoder",
"params": {
"label_size": self.label_size,
"input_dim": self.input_dim,
"input_dropout": self.dropout.p,
"use_cuda": self.use_cuda,
"nbest": self.nbest
}
}
return config

@classmethod
def from_config(cls, config):
return cls.create(**config)
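Usage sketch (assumptions: decoder is an existing NCRFDecoder instance, and NCRFDecoder.create accepts the same keyword names that config exposes). Note that from_config forwards its argument to create, so it expects the nested "params" dict rather than the full config:

cfg = decoder.config                                  # {"name": "NCRFDecoder", "params": {...}}
restored = NCRFDecoder.from_config(cfg["params"])     # rebuild the decoder from its params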
25 changes: 13 additions & 12 deletions modules/layers/encoders.py
@@ -5,6 +5,19 @@

class BertBiLSTMEncoder(nn.Module):

@property
def config(self):
config = {
"name": "BertBiLSTMEncoder",
"params": {
"hidden_dim": self.hidden_dim,
"rnn_layers": self.rnn_layers,
"use_cuda": self.use_cuda,
"embeddings": self.embeddings.config
}
}
return config

def __init__(self, embeddings,
hidden_dim=128, rnn_layers=1, use_cuda=True):
super(BertBiLSTMEncoder, self).__init__()
@@ -29,18 +42,6 @@ def from_config(cls, config):
raise NotImplemented("form_config is implemented only for BertEmbedder now :(")
return cls.create(embeddings, config["hidden_dim"], config["rnn_layers"], config["use_cuda"])

def get_config(self):
config = {
"name": "BertBiLSTMEncoder",
"params": {
"hidden_dim": self.hidden_dim,
"rnn_layers": self.rnn_layers,
"use_cuda": self.use_cuda,
"embeddings": self.embeddings.get_config()
}
}
return config

def init_weights(self):
# for p in self.lstm.parameters():
# nn.init.xavier_normal(p)
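Usage sketch (assumption: encoder is an existing BertBiLSTMEncoder built on a BertEmbedder, the only embedder from_config currently supports). The encoder config nests the embedder config, so a single dict is enough to rebuild the stack:

cfg = encoder.config                                  # nests encoder.embeddings.config under "params"
restored = BertBiLSTMEncoder.from_config(cfg["params"])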
43 changes: 20 additions & 23 deletions modules/models/bert_models.py
@@ -7,6 +7,26 @@


class NerModel(nn.Module, metaclass=abc.ABCMeta):

@property
def config(self):
try:
config = {
"name": self.__class__.__name__,
"params": {
"encoder": self.encoder.config,
"decoder": self.decoder.config,
"use_cuda": self.use_cuda
}
}
except AttributeError:
config = {}
print("config is empty :(. Maybe for this model from_config has not implemented yet.", file=sys.stderr)
except NotImplemented:
config = {}
print("config is empty :(. Maybe for this model from_config has not implemented yet.", file=sys.stderr)
return config

"""Base class for all Models"""
def __init__(self, encoder, decoder, use_cuda=True):
super(NerModel, self).__init__()
@@ -39,31 +59,8 @@ def get_n_trainable_params(self):
pp += num
return pp

def get_config(self):
try:
config = {
"name": self.__class__.__name__,
"params": {
"encoder": self.encoder.get_config(),
"decoder": self.decoder.get_config(),
"use_cuda": self.use_cuda
}
}
except AttributeError:
config = {}
print("config is empty :(. Maybe for this model from_config has not implemented yet.", file=sys.stderr)
except NotImplemented:
config = {}
print("config is empty :(. Maybe for this model from_config has not implemented yet.", file=sys.stderr)
return config

@classmethod
def from_config(cls, config):
name = config["name"]
config = config["params"]
# TODO: release all models (now only for BertBiLSTMNCRF)
if name not in released_models:
raise NotImplemented("from_config is implemented only for {} model :(".format(config["name"]))
encoder = released_models[name]["encoder"].from_config(**config["encoder"]["params"])
decoder = released_models[name]["decoder"].from_config(**config["decoder"]["params"])
return cls(encoder, decoder, config["use_cuda"])
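Usage sketch (assumption: model is an instance of a released model such as BertBiLSTMNCRF). Unlike the encoder and decoder, the model-level from_config takes the full config and dispatches on its "name" field through the released_models registry:

cfg = model.config        # {"name": "BertBiLSTMNCRF", "params": {"encoder": ..., "decoder": ..., "use_cuda": ...}}
restored = BertBiLSTMNCRF.from_config(cfg)   # only released models are supported for now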
10 changes: 7 additions & 3 deletions modules/train/optimization.py
@@ -21,25 +21,29 @@
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_


def warmup_cosine(x, warmup=0.002):
if x < warmup:
return x/warmup
return 0.5 * (1.0 + torch.cos(math.pi * x))


def warmup_constant(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0


def warmup_linear(x, warmup=0.002):
if x < warmup:
return x/warmup
return 1.0 - x


SCHEDULES = {
'warmup_cosine':warmup_cosine,
'warmup_constant':warmup_constant,
'warmup_linear':warmup_linear,
'warmup_cosine': warmup_cosine,
'warmup_constant': warmup_constant,
'warmup_linear': warmup_linear,
}


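For reference, a small sketch of what these schedule multipliers return as training progresses (x is the completed fraction of t_total steps; the values and the warmup setting below are illustrative):

for x in (0.001, 0.002, 0.5, 0.9):
    print(SCHEDULES["warmup_linear"](x, warmup=0.002))
# ramps linearly to 1.0 inside the warmup window, then decays as 1.0 - x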
87 changes: 67 additions & 20 deletions modules/train/train.py
@@ -1,12 +1,12 @@
from tqdm._tqdm_notebook import tqdm_notebook
from tqdm import tqdm
from modules.utils.utils import ipython_info
from sklearn_crfsuite.metrics import flat_classification_report
import logging
import torch
from modules.utils.plot_metrics import get_mean_max_metric
from torch.optim import Adam
from .optimization import BertAdam
import json
from modules.data.bert_data import BertNerData
from modules.models.released_models import released_models


logging.basicConfig(level=logging.INFO)
@@ -16,7 +16,7 @@ def train_step(dl, model, optimizer, lr_scheduler=None, clip=None, num_epoch=1):
model.train()
epoch_loss = 0
idx = 0
pr = tqdm_notebook(dl, total=len(dl), leave=False)
pr = tqdm(dl, total=len(dl), leave=False)
for batch in pr:
idx += 1
model.zero_grad()
@@ -87,7 +87,7 @@ def validate_step(dl, model, id2label, sup_labels, id2cls=None):
idx = 0
preds_cpu, targets_cpu = [], []
preds_cpu_cls, targets_cpu_cls = [], []
for batch in tqdm_notebook(dl, total=len(dl), leave=False):
for batch in tqdm(dl, total=len(dl), leave=False):
idx += 1
labels_mask, labels_ids = batch[-2:]
preds = model.forward(batch)
@@ -111,7 +111,7 @@ def predict(dl, model, id2label, id2cls=None):
idx = 0
preds_cpu = []
preds_cpu_cls = []
for batch, sorted_idx in tqdm_notebook(dl, total=len(dl), leave=False):
for batch, sorted_idx in tqdm(dl, total=len(dl), leave=False):
idx += 1
labels_mask, labels_ids = batch[-2:]
preds = model.forward(batch)
@@ -134,17 +134,50 @@ def predict(dl, model, id2label, id2cls=None):


class NerLearner(object):
def __init__(self, model, data, best_model_path, lr=0.001, betas=list([0.8, 0.9]), clip=5,
verbose=True, sup_labels=None, t_total=-1, warmup=0.1, weight_decay=0.01):
if ipython_info() or True:
global tqdm_notebook
tqdm_notebook = tqdm

@property
def config(self):
config = {
"data": self.data.config,
"model": self.model.config,
"learner": {
"best_model_path": self.best_model_path,
"lr": self.lr,
"betas": self.betas,
"clip": self.clip,
"verbose": self.verbose,
"sup_labels": self.sup_labels,
"t_total": self.t_total,
"warmup": self.warmup,
"weight_decay": self.weight_decay,
"validate_every": self.validate_every,
"schedule": self.schedule,
"e": self.e
}
}
return config

def __init__(self, model, data, best_model_path, lr=0.001, betas=[0.8, 0.9], clip=5,
verbose=True, sup_labels=None, t_total=-1, warmup=0.1, weight_decay=0.01,
validate_every=1, schedule="warmup_linear", e=1e-6):
self.model = model
self.optimizer = BertAdam(model, lr, t_total=t_total, b1=betas[0], b2=betas[1], max_grad_norm=clip)
self.optimizer_defaults = dict(model=model, lr=lr, warmup=warmup, t_total=t_total, schedule='warmup_linear',
b1=betas[0], b2=betas[1], e=1e-6, weight_decay=0.01,
max_grad_norm=clip)
self.optimizer_defaults = dict(
model=model, lr=lr, warmup=warmup, t_total=t_total, schedule="warmup_linear",
b1=betas[0], b2=betas[1], e=1e-6, weight_decay=weight_decay,
max_grad_norm=clip)

self.lr = lr
self.betas = betas
self.clip = clip
self.sup_labels = sup_labels
self.t_total = t_total
self.warmup = warmup
self.weight_decay = weight_decay
self.validate_every = validate_every
self.schedule = schedule
self.data = data
self.e = e
if sup_labels is None:
sup_labels = data.id2label[1:]
self.sup_labels = sup_labels
@@ -157,6 +190,18 @@ def __init__(self, model, data, best_model_path, lr=0.001, betas=list([0.8, 0.9]
self.best_target_metric = 0.
self.lr_scheduler = None

@classmethod
def from_config(cls, path):
with open(path, "r") as file:
config = json.load(file)
data = BertNerData.from_config(config["data"])
name = config["model"]["name"]
# TODO: release all models (now only for BertBiLSTMNCRF)
if name not in released_models:
raise NotImplemented("from_config is implemented only for {} model :(".format(config["name"]))
model = released_models[name].from_config(**config["model"]["params"])
return cls(data, model, **config["learner"])

def fit(self, epochs=100, resume_history=True, target_metric="f1"):
if not resume_history:
self.optimizer_defaults["t_total"] = epochs * len(self.data.train_dl)
@@ -176,12 +221,14 @@ def fit(self, epochs=100, resume_history=True, target_metric="f1"):

def fit_one_cycle(self, epoch, target_metric="f1"):
train_step(self.data.train_dl, self.model, self.optimizer, self.lr_scheduler, self.clip, epoch)
if self.data.is_cls:
rep, rep_cls = validate_step(self.data.valid_dl, self.model, self.data.id2label, self.sup_labels, self.data.id2cls)
self.cls_history.append(rep_cls)
else:
rep = validate_step(self.data.valid_dl, self.model, self.data.id2label, self.sup_labels)
self.history.append(rep)
if epoch % self.validate_every == 0:
if self.data.is_cls:
rep, rep_cls = validate_step(self.data.valid_dl, self.model, self.data.id2label, self.sup_labels,
self.data.id2cls)
self.cls_history.append(rep_cls)
else:
rep = validate_step(self.data.valid_dl, self.model, self.data.id2label, self.sup_labels)
self.history.append(rep)
idx, metric = get_mean_max_metric(self.history, target_metric, True)
if self.verbose:
logging.info("on epoch {} by max_{}: {}".format(idx, target_metric, metric))

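Usage sketch of the learner-level round-trip added in this commit: once a learner is built, its full configuration (data, model, and learner hyperparameters) can be dumped to JSON and later restored with NerLearner.from_config, which takes a path. The variable learner and the file name learner_config.json are illustrative assumptions.

import json

with open("learner_config.json", "w") as f:
    json.dump(learner.config, f)                       # persist the full setup

learner = NerLearner.from_config("learner_config.json")   # rebuild data, model, and learner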