graykode

(add) add patch_ids for model encoder inputs

@@ -115,8 +115,8 @@ class SummarizationModule(BaseTransformer):
         for d in [self.model.encoder, self.model.decoder]:
             freeze_params(d.embed_tokens)
 
-    def forward(self, input_ids, **kwargs):
-        return self.model(input_ids, **kwargs)
+    def forward(self, input_ids, patch_ids, **kwargs):
+        return self.model(input_ids, patch_ids, **kwargs)
 
     def ids_to_clean_text(self, generated_ids: List[int]):
         gen_text = self.tokenizer.batch_decode(
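For context, patch_ids is a second tensor with exactly the shape of input_ids, labelling each source token by the kind of diff line it came from. The id scheme below (0 = padding, 1 = tokens from added lines, 2 = tokens from removed lines) is an assumption inferred from the 3-entry patch embedding introduced further down, not something stated in this diff, and build_patch_ids is a hypothetical helper used only to illustrate the alignment:

import torch

def build_patch_ids(line_token_ids, line_labels, max_len, pad_token_id=1):
    # line_token_ids: one list of token ids per diff line
    # line_labels: per-line label, e.g. 1 for '+' lines, 2 for '-' lines (assumed scheme)
    input_ids, patch_ids = [], []
    for tokens, label in zip(line_token_ids, line_labels):
        input_ids.extend(tokens)
        patch_ids.extend([label] * len(tokens))
    # truncate, then pad both sequences to the same fixed length
    input_ids = input_ids[:max_len] + [pad_token_id] * max(0, max_len - len(input_ids))
    patch_ids = patch_ids[:max_len] + [0] * max(0, max_len - len(patch_ids))
    return torch.tensor([input_ids]), torch.tensor([patch_ids])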
@@ -133,7 +133,7 @@ class SummarizationModule(BaseTransformer):
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)
 
-        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+        outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
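Note that src_patch is passed positionally here, so it lands on the second positional parameter of SummarizationModule.forward above and, through it, of the underlying BART forward. A trimmed sketch of the ordering this relies on (the full signature change lives in modeling_bart.py and is only partially shown in this excerpt):

def forward(self, input_ids, patch_ids=None, attention_mask=None,
            decoder_input_ids=None, use_cache=None, **kwargs):
    # patch_ids must stay the second positional parameter for the call above to work
    ...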
@@ -114,6 +114,7 @@ class GenerationMixin:
     def generate(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        patch_ids: Optional[torch.LongTensor] = None,
         max_length: Optional[int] = None,
         min_length: Optional[int] = None,
         do_sample: Optional[bool] = None,
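With the extra argument in place, generation for a new diff would presumably be invoked along these lines (tensor names and decoding hyperparameters are illustrative, not taken from the repository):

generated_ids = model.generate(
    input_ids=src_ids,        # (batch, seq_len) token ids of the tokenized diff
    patch_ids=src_patch,      # (batch, seq_len) per-token diff-line labels, aligned with src_ids
    attention_mask=src_mask,
    num_beams=4,
    max_length=64,
)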
@@ -396,12 +397,13 @@
 
         # get encoder and store encoder outputs
         encoder = self.get_encoder()
-        encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
+        encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True)
 
         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]
             input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+            patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
             attention_mask = attention_mask.unsqueeze(1).expand(
                 batch_size, effective_batch_mult * num_beams, input_ids_len
             )
@@ -409,6 +411,9 @@
             input_ids = input_ids.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+            patch_ids = patch_ids.contiguous().view(
+                effective_batch_size * num_beams, input_ids_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
             attention_mask = attention_mask.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
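The two hunks above make patch_ids follow input_ids through the exact same beam-search reshaping, so every expanded hypothesis keeps its patch labels aligned with its tokens. A small self-contained check of the shape arithmetic (values are illustrative):

import torch

batch_size, seq_len = 2, 5
num_beams, effective_batch_mult = 4, 1             # effective_batch_mult is 1 when do_sample=False
effective_batch_size = batch_size * effective_batch_mult

patch_ids = torch.randint(0, 3, (batch_size, seq_len))                                   # (2, 5)
patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, seq_len)
patch_ids = patch_ids.contiguous().view(effective_batch_size * num_beams, seq_len)
print(patch_ids.shape)                                                                   # torch.Size([8, 5])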
@@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
                 self.padding_idx,
                 config.extra_pos_embeddings,
             )
-        self.embed_patches = nn.Embedding(3, config.d_model)
+        self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
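Only the embedding table itself appears in this hunk; how it is consumed is presumably analogous to the positional embedding, i.e. the looked-up patch vectors are summed into the encoder's input representation. The sketch below is an assumption modeled on that pattern, not an excerpt of the modified BartEncoder.forward. Setting padding_idx=0 keeps row 0 of the table at zero (and excluded from gradient updates), so positions labelled 0 contribute no patch signal:

import torch
import torch.nn as nn

d_model = 768
embed_tokens = nn.Embedding(50265, d_model, padding_idx=1)    # vocabulary size is illustrative
embed_patches = nn.Embedding(3, d_model, padding_idx=0)       # matches the line added above

input_ids = torch.tensor([[0, 713, 16, 10, 1296, 2]])         # toy token ids
patch_ids = torch.tensor([[0, 1,   1,  2,  2,    0]])         # 1 = added line, 2 = removed line (assumed)

hidden = embed_tokens(input_ids) + embed_patches(patch_ids)   # positional embeddings omitted here
print(hidden.shape)                                           # torch.Size([1, 6, 768])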
@@ -1113,6 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
     ):
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "patch_ids": None,  # encoder_outputs is defined. patch_ids not needed
             "encoder_outputs": encoder_outputs,
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,