graykode

(add) add patch_ids for model encoder inputs

@@ -115,8 +115,8 @@ class SummarizationModule(BaseTransformer):
         for d in [self.model.encoder, self.model.decoder]:
             freeze_params(d.embed_tokens)
 
-    def forward(self, input_ids, **kwargs):
-        return self.model(input_ids, **kwargs)
+    def forward(self, input_ids, patch_ids, **kwargs):
+        return self.model(input_ids, patch_ids, **kwargs)
 
     def ids_to_clean_text(self, generated_ids: List[int]):
         gen_text = self.tokenizer.batch_decode(
@@ -133,7 +133,7 @@ class SummarizationModule(BaseTransformer):
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
 
-        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+        outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
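A minimal sketch of the batch layout this change assumes: patch_ids must be token-aligned with input_ids so that every source token gets a patch-type embedding. The key names and the 0/1/2 coding (0 = padding, 1/2 = the two sides of a patch) are illustrative assumptions, not taken from this commit.

import torch

# Toy batch of 2 sequences of length 6; patch_ids has the same shape as input_ids.
input_ids = torch.randint(5, 100, (2, 6))
patch_ids = torch.tensor([[1, 1, 2, 2, 0, 0],
                          [1, 2, 2, 2, 2, 0]])   # 0 assumed to mark padding
attention_mask = (patch_ids != 0).long()

# With the patched SummarizationModule, the extra tensor rides along positionally
# (this is the call the _step hunk above makes with src_ids / src_patch):
# outputs = model(input_ids, patch_ids, attention_mask=attention_mask,
#                 decoder_input_ids=decoder_input_ids, use_cache=False)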
@@ -114,6 +114,7 @@ class GenerationMixin:
     def generate(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        patch_ids: Optional[torch.LongTensor] = None,
         max_length: Optional[int] = None,
         min_length: Optional[int] = None,
         do_sample: Optional[bool] = None,
@@ -396,12 +397,13 @@ class GenerationMixin:
 
             # get encoder and store encoder outputs
             encoder = self.get_encoder()
-            encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
+            encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True)
 
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]
             input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+            patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
             attention_mask = attention_mask.unsqueeze(1).expand(
                 batch_size, effective_batch_mult * num_beams, input_ids_len
             )
@@ -409,6 +411,9 @@ class GenerationMixin:
             input_ids = input_ids.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+            patch_ids = patch_ids.contiguous().view(
+                effective_batch_size * num_beams, input_ids_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
             attention_mask = attention_mask.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
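The beam-search bookkeeping applied to patch_ids above mirrors what generate() already does for input_ids and attention_mask. A self-contained illustration of the expand/view pattern:

import torch

batch_size, seq_len = 2, 4
num_beams, effective_batch_mult = 3, 1
effective_batch_size = batch_size * effective_batch_mult

patch_ids = torch.arange(batch_size * seq_len).view(batch_size, seq_len)

# Same unsqueeze/expand followed by contiguous().view() that generate() uses:
patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, seq_len)
patch_ids = patch_ids.contiguous().view(effective_batch_size * num_beams, seq_len)

print(patch_ids.shape)  # torch.Size([6, 4]): each original row repeated num_beams times, consecutively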
@@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
                 self.padding_idx,
                 config.extra_pos_embeddings,
             )
-        self.embed_patches = nn.Embedding(3, config.d_model)
+        self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
@@ -1113,6 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
     ):
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "patch_ids": None,  # encoder_outputs is defined. patch_ids not needed
             "encoder_outputs": encoder_outputs,
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
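Because the encoder has already consumed patch_ids by the time decoding starts, prepare_inputs_for_generation can return None for it, exactly as it does for input_ids. A hypothetical end-to-end call; model, tokenizer, and the tensors from the earlier sketch are placeholders, not defined in this commit:

summary_ids = model.generate(
    input_ids,
    patch_ids=patch_ids,              # used once, inside the encoder call in generate()
    attention_mask=attention_mask,
    num_beams=4,
    max_length=32,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))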