from dataclasses import dataclass
from transformers.models.t5.modeling_t5 import (
    T5Stack,
    T5Block,
    T5LayerNorm,
    T5LayerSelfAttention,
    T5LayerFF,
    T5LayerCrossAttention,
    T5PreTrainedModel,
    T5ForConditionalGeneration,
)
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
import copy
from transformers.modeling_outputs import (
    ModelOutput,
    BaseModelOutput,
    BaseModelOutputWithPast,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
)
from transformers.modeling_utils import (
    PreTrainedModel,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)
from transformers.utils import logging
from transformers import BeamScorer, BeamSearchScorer, T5Tokenizer, T5Config
import argparse
from data import load_train_dataloaders

logger = logging.get_logger(__name__)


class JointEncoder(T5Stack):
    def __init__(self, config, embed_tokens=None):
        super(T5Stack, self).__init__(config)
        self.config = config
        self.d_model = self.config.d_model

        self.embed_tokens = embed_tokens
        self.is_decoder = self.config.is_decoder
        assert self.config.is_decoder is False

        self.block = nn.ModuleList(
            [
                T5Block(config, has_relative_attention_bias=(i == 0))
                for i in range(config.num_layers)
            ]
        )
        self.final_layer_norm = T5LayerNorm(
            config.d_model, eps=config.layer_norm_epsilon
        )
        self.dropout = nn.Dropout(config.dropout_rate)

        ## Set maximum 512 whole words in a source text
        self.whole_word_embeddings = nn.Embedding(
            512, config.d_model  ## config.d_model is 768 for base
        )
        self.position_embeddings = nn.Embedding(
            120, config.d_model  ## config.d_model is 768 for base
        )
        self.init_weights()
        self.position_embeddings.weight.data[0] = torch.zeros_like(
            self.position_embeddings.weight.data[0]
        )

        self.model_parallel = False
        self.device_map = None

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        whole_word_embedding_type=None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if inputs_embeds is None:
            assert (
                self.embed_tokens is not None
            ), "You have to initialize the model with valid token embeddings"
            inputs_embeds = self.embed_tokens(input_ids)  ### embedding step - add HERE ###
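        # Whole-word ids map each sub-word token to the index of the whole word it
        # belongs to; the matching embedding is added on top of the token embedding
        # below. "shijie" selects the 512-entry whole_word_embeddings table, otherwise
        # the 120-entry position_embeddings table (whose row 0 is zeroed at init) is used.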
        if whole_word_ids is not None:
            if whole_word_embedding_type == "shijie":
                whole_word_embeds = self.whole_word_embeddings(whole_word_ids)
            else:
                whole_word_embeds = self.position_embeddings(whole_word_ids)
            assert whole_word_embeds.shape[-1] == inputs_embeds.shape[-1]
            inputs_embeds = inputs_embeds + whole_word_embeds

        B, L = inputs_embeds.size()[:-1]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=inputs_embeds.dtype, device=inputs_embeds.device
            )

        # We can provide a self-attention mask of dimensions
        # [batch_size, from_seq_length, to_seq_length] ourselves, in which case
        # we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(
            attention_mask, (B, L), inputs_embeds.device
        )

        # initialize past_key_values with `None` if past does not exist
        if past_key_values is None:
            past_key_values = [None] * len(self.block)

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        present_key_value_states = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None

        hidden_states = self.dropout(inputs_embeds)

        if self.config.num_layers > 0:
            assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias

            seq_length = L
            q_len = seq_length
            k_len = seq_length

            # [1, n_heads, Q_len, K_len]
            text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(L, L)
            num_heads = text_position_bias.size(1)
            position_bias = text_position_bias.new_zeros(
                1, num_heads, seq_length, seq_length
            )
            position_bias[:, :, :L, :L] = text_position_bias

            position_bias = position_bias + extended_attention_mask

        for i, (layer_module, past_key_value) in enumerate(
            zip(self.block, past_key_values)
        ):
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=extended_attention_mask,
                position_bias=position_bias,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                encoder_decoder_position_bias=None,
                layer_head_mask=head_mask[i],
                past_key_value=past_key_value,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            if use_cache is False:
                layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]

            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights),
            # (self-attention position bias), (cross-attention weights),
            # (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]

            # We share the position biases between the layers - the first layer stores them
            position_bias = layer_outputs[2]

            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (
                    present_key_value_state,
                )

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_value_states,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )


class P5(T5ForConditionalGeneration):
    _keys_to_ignore_on_load_missing = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
    ]
    _keys_to_ignore_on_load_unexpected = [
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]

    def __init__(self, config):
        super(T5ForConditionalGeneration, self).__init__(config)
        self.config = config

        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = JointEncoder(encoder_config, self.shared)
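        # Only the encoder is customized (JointEncoder adds whole-word embeddings);
        # the decoder below is a stock T5Stack sharing the same token embedding table.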
        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        self.decoder = T5Stack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.init_weights()

        self.model_parallel = False
        self.device_map = None

        self.sigmoid = nn.Sigmoid()

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def extend_vocab(self, vocab_size):
        new_shared = nn.Embedding(vocab_size, self.config.d_model)
        old_weight = self.shared.weight.data.detach().clone()
        old_vocab_size = old_weight.size(0)
        new_shared.weight.data[:old_vocab_size, :] = old_weight
        self.shared = new_shared

        new_lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
        old_weight = self.lm_head.weight.data.detach().clone()
        old_vocab_size = old_weight.size(0)
        new_lm_head.weight.data[:old_vocab_size, :] = old_weight
        self.lm_head = new_lm_head

        self.encoder.embed_tokens = self.shared
        self.decoder.embed_tokens = self.shared
        self.lm_head.weight = self.shared.weight

        self.config.vocab_size = vocab_size
        self.encoder.config.vocab_size = vocab_size
        self.decoder.config.vocab_size = vocab_size

    def forward(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        alpha=2,
        **kwargs,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                whole_word_ids=whole_word_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=hidden_states.dtype, device=hidden_states.device
            )
        encoder_attention_mask = attention_mask

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True
        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)
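        # With reduce_loss=False the cross-entropy below is returned per token
        # (reduction="none"); callers mask and average it with the target attention
        # mask, as the __main__ smoke test at the bottom of this file does.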
        loss = None
        if labels is not None:
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))

        if return_hidden_state:
            return P5Seq2SeqLMOutput(
                loss=loss,
                logits=lm_logits,
                encoder_last_hidden_state=hidden_states,
                past_key_values=decoder_outputs.past_key_values,
                decoder_last_hidden_state=decoder_outputs.last_hidden_state,
                decoder_hidden_states=decoder_outputs.hidden_states,
            )

        return P5Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_hidden_states=decoder_outputs.hidden_states,
        )
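    # `predict` mirrors `forward` but returns a list with one per-token loss tensor
    # per example (handy for scoring candidates at evaluation time) instead of a
    # single flattened loss.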
    def predict(
        self,
        input_ids=None,
        whole_word_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        **kwargs,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                whole_word_ids=whole_word_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert (
                labels is None
            ), "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=hidden_states.dtype, device=hidden_states.device
            )
        encoder_attention_mask = attention_mask

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True
        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = []
        if labels is not None:
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
            for b in range(lm_logits.size(0)):
                loss_value = loss_fct(lm_logits[b], labels[b])
                loss.append(loss_value)

        if return_hidden_state:
            return (
                loss,
                P5Seq2SeqLMOutput(
                    loss=None,
                    logits=lm_logits,
                    encoder_last_hidden_state=hidden_states,
                    past_key_values=decoder_outputs.past_key_values,
                    decoder_last_hidden_state=decoder_outputs.last_hidden_state,
                    decoder_hidden_states=decoder_outputs.hidden_states,
                ),
            )

        return (
            loss,
            P5Seq2SeqLMOutput(
                loss=None,
                logits=lm_logits,
                past_key_values=decoder_outputs.past_key_values,
                decoder_last_hidden_state=decoder_outputs.last_hidden_state,
                decoder_hidden_states=decoder_outputs.hidden_states,
            ),
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past=None,
        attention_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        if past is not None:
            input_ids = input_ids[:, -1:]

        output = {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
        }
        return output

    @staticmethod
    def _expand_inputs_for_generation(
        input_ids: torch.LongTensor,
        expand_size: int = 1,
        is_encoder_decoder: bool = False,
        attention_mask: torch.LongTensor = None,
        encoder_outputs: ModelOutput = None,
        **model_kwargs,
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        expanded_return_idx = (
            torch.arange(input_ids.shape[0])
            .view(-1, 1)
            .repeat(1, expand_size)
            .view(-1)
            .to(input_ids.device)
        )
        input_ids = input_ids.index_select(0, expanded_return_idx)

        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = token_type_ids.index_select(
                0, expanded_return_idx
            )

        if attention_mask is not None:
            model_kwargs["attention_mask"] = attention_mask.index_select(
                0, expanded_return_idx
            )

        if is_encoder_decoder:
            assert encoder_outputs is not None
            encoder_outputs[
                "last_hidden_state"
            ] = encoder_outputs.last_hidden_state.index_select(0, expanded_return_idx)
            model_kwargs["encoder_outputs"] = encoder_outputs
        return input_ids, model_kwargs
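    # `adversarial` mean-pools the encoder states between `feature_boundary_ids` into a
    # user embedding, scores it with an external discriminator, and combines the
    # discriminator loss with the masked LM loss (subtracted when updating the main
    # model, added when train_discriminator=True).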
    def adversarial(
        self,
        discriminator=None,
        discriminator_label=None,
        input_ids=None,
        feature_boundary_ids=None,
        discriminator_weight=None,
        whole_word_ids=None,
        attention_mask=None,
        encoder_outputs=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        labels_attention=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        reduce_loss=False,
        return_hidden_state=False,
        train_discriminator=False,
        **kwargs,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                whole_word_ids=whole_word_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        B = hidden_states.size(0)
        user_embeddings = None
        for b in range(B):
            user_embedding = (
                hidden_states[b][
                    feature_boundary_ids[b][0] : feature_boundary_ids[b][1]
                ]
                .mean(dim=0)
                .unsqueeze(0)
            )
            if b == 0:
                user_embeddings = user_embedding
            else:
                user_embeddings = torch.cat([user_embeddings, user_embedding], dim=0)
        # B * embedding_dim
        assert user_embeddings is not None

        discriminator_loss = discriminator(user_embeddings, discriminator_label)

        if (
            labels is not None
            and decoder_input_ids is None
            and decoder_inputs_embeds is None
        ):
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens
        # should be given as an input
        if past_key_values is not None:
            assert (
                labels is None
            ), "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).to(
                dtype=hidden_states.dtype, device=hidden_states.device
            )
        encoder_attention_mask = attention_mask

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        assert self.config.tie_word_embeddings is True
        if self.config.tie_word_embeddings:
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        rec_loss = None
        if labels is not None:
            if reduce_loss:
                loss_fct = CrossEntropyLoss(ignore_index=-100)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")
            rec_loss = loss_fct(
                lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)
            )
            lm_mask = labels_attention != 0
            lm_mask = lm_mask.float()
            B, L = labels.size()
            rec_loss = rec_loss.view(B, L) * lm_mask
            rec_loss = (rec_loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)).mean()

        if not train_discriminator:
            loss = rec_loss - discriminator_weight * discriminator_loss
        else:
            loss = rec_loss + discriminator_weight * discriminator_loss

        return P5AdversarialSeq2SeqLMOutput(
            loss=loss,
            rec_loss=rec_loss,
            discriminator_loss=discriminator_loss,
            feature_embeddings=user_embeddings,
        )
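# Output containers: P5Seq2SeqLMOutput mirrors transformers' Seq2SeqLMOutput but also
# exposes decoder_last_hidden_state; the adversarial variant additionally carries
# rec_loss, discriminator_loss, and the pooled feature_embeddings.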
@dataclass
class P5Seq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language model outputs.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Language modeling loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can
            be used (see ``past_key_values`` input) to speed up sequential decoding.
        decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average
            in the self-attention heads.
        encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average
            in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class P5AdversarialSeq2SeqLMOutput(ModelOutput):
    """
    Base class for sequence-to-sequence language model outputs with adversarial losses.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Language modeling loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can
            be used (see ``past_key_values`` input) to speed up sequential decoding.
        decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
        decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the decoder, after the attention softmax, used to compute the weighted average
            in the self-attention heads.
        encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder of the model.
        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each
            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attention weights of the encoder, after the attention softmax, used to compute the weighted average
            in the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    rec_loss: Optional[torch.FloatTensor] = None
    discriminator_loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    feature_embeddings: Optional[torch.FloatTensor] = None


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="data/")
    parser.add_argument(
        "--task", type=str, default="movie", help="movie, insurance, AliEC"
    )
    parser.add_argument("--toy", action="store_true")
    parser.add_argument("--batch_size", type=int, default=4)
    args = parser.parse_args()

    tokenizer = T5Tokenizer.from_pretrained("t5-base")
    train_loader, val_loader, test_loader = load_train_dataloaders(args, tokenizer)
    print("finished loading data")

    config = T5Config.from_pretrained("t5-base")
    model = P5(config).cuda()
    print("finished building model")

    for batch in train_loader:
        input_ids = batch[0].cuda()
        attn = batch[1].cuda()
        whole_input_ids = batch[2].cuda()
        output_ids = batch[3].cuda()
        output_attention = batch[4].cuda()

        output = model(
            input_ids=input_ids,
            whole_word_ids=whole_input_ids,
            attention_mask=attn,
            labels=output_ids,
            return_dict=True,
        )
        loss = output["loss"]

        lm_mask = output_attention != 0
        lm_mask = lm_mask.float()
        B, L = output_ids.size()
        loss = loss.view(B, L) * lm_mask
        loss = loss.sum(dim=1) / lm_mask.sum(dim=1).clamp(min=1)
        print(loss)
        break
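    # Minimal sketch of the per-example loss path via `predict`, reusing the batch
    # from the loop above (runs only if the loader produced at least one batch).
    per_example_losses, pred_output = model.predict(
        input_ids=input_ids,
        whole_word_ids=whole_input_ids,
        attention_mask=attn,
        labels=output_ids,
        return_dict=True,
    )
    # `per_example_losses` is a list with one per-token loss tensor per example;
    # `pred_output.logits` has shape (batch_size, target_length, vocab_size).
    print(len(per_example_losses), pred_output.logits.shape)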