Mix-up for Roberta
uoo723 committed Nov 25, 2020
1 parent d2813ce commit ff0f9c9
Showing 3 changed files with 99 additions and 5 deletions.
9 changes: 7 additions & 2 deletions deepxml/models.py
@@ -28,7 +28,7 @@

# from apex import amp

-__all__ = ['Model', 'XMLModel']
+__all__ = ['Model', 'XMLModel', 'TransformerXML']


class Model(object):
@@ -342,7 +342,12 @@ def train_step(self, epoch: int, train_x: torch.Tensor,
        self.optimizer.zero_grad()
        self.model.train()

-        logits = self.model(train_x, attention_mask)[0]
+        if self.mixup_fn is not None and epoch >= self.mixup_warmup:
+            outputs = self.model(train_x, attention_mask, return_hidden=True)
+            hidden, train_y = self.mixup_fn(outputs[0], train_y)
+            logits = self.model(pass_hidden=True, outputs=(hidden, *outputs[1:]))[0]
+        else:
+            logits = self.model(train_x, attention_mask)[0]
        loss = self.loss_fn(logits, train_y)

        loss.backward()
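Note that `self.mixup_fn` and `self.mixup_warmup` are configured elsewhere and are not shown in this commit. Below is a minimal, hypothetical sketch of a mixup helper compatible with the call `hidden, train_y = self.mixup_fn(outputs[0], train_y)`, assuming standard mixup (Zhang et al., 2018) applied to the RoBERTa sequence output and multi-hot label targets; the `alpha` value and the in-batch permutation are assumptions, not the repository's actual implementation:

import numpy as np
import torch


def mixup(hidden: torch.Tensor, labels: torch.Tensor, alpha: float = 0.4):
    """Mix hidden states and label targets within a batch (hypothetical helper).

    hidden: (batch, seq_len, hidden_size) sequence output returned by the encoder.
    labels: (batch, num_labels) float targets, e.g. multi-hot labels for XML.
    """
    lam = float(np.random.beta(alpha, alpha))
    index = torch.randperm(hidden.size(0), device=hidden.device)
    mixed_hidden = lam * hidden + (1.0 - lam) * hidden[index]
    mixed_labels = lam * labels + (1.0 - lam) * labels[index]
    return mixed_hidden, mixed_labels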
91 changes: 90 additions & 1 deletion deepxml/networks.py
@@ -7,10 +7,14 @@
"""
import torch.nn as nn
import torch.nn.functional as F
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.modeling_roberta import (RobertaClassificationHead,
+                                           RobertaModel,
+                                           RobertaPreTrainedModel)

from deepxml.modules import *

-__all__ = ['AttentionRNN', 'FastAttentionRNN']
+__all__ = ['AttentionRNN', 'FastAttentionRNN', 'RobertaForSeqClassification']


class Network(nn.Module):
@@ -106,3 +110,88 @@ def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)
+
+
+# https://huggingface.co/transformers/_modules/transformers/modeling_roberta.html#RobertaForSequenceClassification
+# Reimplementation for mix-up
+class RobertaForSeqClassification(RobertaPreTrainedModel):
+    authorized_missing_keys = [r"position_ids"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta = RobertaModel(config, add_pooling_layer=False)
+        self.classifier = RobertaClassificationHead(config)
+
+        self.init_weights()
+
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        return_hidden=False,
+        pass_hidden=False,
+        outputs=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if return_hidden and pass_hidden:
+            raise ValueError("`return_hidden` and `pass_hidden` cannot be both true.")
+
+        if not pass_hidden:
+            outputs = self.roberta(
+                input_ids,
+                attention_mask=attention_mask,
+                token_type_ids=token_type_ids,
+                position_ids=position_ids,
+                head_mask=head_mask,
+                inputs_embeds=inputs_embeds,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                return_dict=return_dict,
+            )
+            sequence_output = outputs[0]
+        else:
+            sequence_output = outputs[0]
+
+        if return_hidden:
+            return outputs
+
+        logits = self.classifier(sequence_output)
+
+        loss = None
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = nn.MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
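For reference, a hypothetical usage sketch of the two-pass flow this class enables, mirroring the `train_step` change above; the `roberta-base` checkpoint, `num_labels=4`, and `return_dict=False` are illustrative assumptions, not the repository's actual configuration:

from transformers import RobertaConfig, RobertaTokenizer

from deepxml.networks import RobertaForSeqClassification

config = RobertaConfig.from_pretrained("roberta-base", num_labels=4, return_dict=False)
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaForSeqClassification.from_pretrained("roberta-base", config=config)

batch = tokenizer(["an example", "another example"], return_tensors="pt", padding=True)

# Pass 1: run the encoder only and return its outputs (no classification head).
outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], return_hidden=True)
hidden = outputs[0]  # (batch, seq_len, hidden_size)

# ... mix `hidden` (and the labels) here ...

# Pass 2: feed the (mixed) hidden states through the classification head only.
logits = model(pass_hidden=True, outputs=(hidden, *outputs[1:]))[0]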
4 changes: 2 additions & 2 deletions deepxml/train/transformer.py
@@ -7,14 +7,14 @@
from deepxml.data_utils import get_data, get_mlb, output_res
from deepxml.dataset import MultiLabelDataset
from deepxml.models import TransformerXML
+from deepxml.networks import RobertaForSeqClassification
from logzero import logger
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
-from transformers import RobertaForSequenceClassification

from .utils import load_dataset, log_config, log_results

MODEL_TYPE = {"roberta": RobertaForSequenceClassification}
MODEL_TYPE = {"roberta": RobertaForSeqClassification}


def transformer_train(
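For context, a hypothetical sketch of how the `MODEL_TYPE` registry is typically consumed inside `transformer_train`; the `"roberta-base"` checkpoint name and the `num_labels` value are assumptions, not shown in this diff:

num_labels = 3714  # e.g. the label-set size of the target dataset (assumption)
network_cls = MODEL_TYPE["roberta"]  # now resolves to RobertaForSeqClassification
network = network_cls.from_pretrained("roberta-base", num_labels=num_labels)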
