Showing 3 changed files with 11 additions and 5 deletions
@@ -115,8 +115,8 @@ class SummarizationModule(BaseTransformer):
             for d in [self.model.encoder, self.model.decoder]:
                 freeze_params(d.embed_tokens)

-    def forward(self, input_ids, **kwargs):
-        return self.model(input_ids, **kwargs)
+    def forward(self, input_ids, patch_ids, **kwargs):
+        return self.model(input_ids, patch_ids, **kwargs)

     def ids_to_clean_text(self, generated_ids: List[int]):
         gen_text = self.tokenizer.batch_decode(
@@ -133,7 +133,7 @@ class SummarizationModule(BaseTransformer):
         else:
             decoder_input_ids = shift_tokens_right(tgt_ids, pad_token_id)

-        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
+        outputs = self(src_ids, src_patch, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         lm_logits = outputs[0]
         if self.hparams.label_smoothing == 0:
             # Same behavior as modeling_bart.py, besides ignoring pad_token_id
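The fine-tuning module now threads a second tensor, patch_ids, through forward and passes src_patch to it in _step, so every encoder token carries a patch label alongside its token id. The sketch below shows one way such a tensor could be built so it lines up one-to-one with the tokenized source; the helper name build_patch_ids and the label convention (0 = padding, 1 and 2 = the two patch-line types, e.g. added vs. removed lines) are assumptions for illustration, not part of this change.

```python
import torch

# Hypothetical helper (not part of this diff): build patch_ids that align 1:1
# with the tokenized source, so self(src_ids, src_patch, ...) receives tensors
# of identical shape. Assumed labels, matching nn.Embedding(3, ..., padding_idx=0):
# 0 = padding, 1 and 2 = the two kinds of patch lines (e.g. added vs. removed).
def build_patch_ids(line_labels, tokens_per_line, max_len):
    ids = []
    for label, n_tokens in zip(line_labels, tokens_per_line):
        ids.extend([label] * n_tokens)   # repeat the line's label for each of its tokens
    ids = ids[:max_len]                  # truncate the same way input_ids is truncated
    ids += [0] * (max_len - len(ids))    # pad with 0 so the padding embedding stays inert
    return torch.tensor(ids, dtype=torch.long)

# Example: an added line of 3 tokens, a removed line of 2, another added line of 4.
src_patch = build_patch_ids([1, 2, 1], [3, 2, 4], max_len=12).unsqueeze(0)  # shape (1, 12)
```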
@@ -114,6 +114,7 @@ class GenerationMixin:
     def generate(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        patch_ids: Optional[torch.LongTensor] = None,
         max_length: Optional[int] = None,
         min_length: Optional[int] = None,
         do_sample: Optional[bool] = None,
@@ -396,12 +397,13 @@ class GenerationMixin:

             # get encoder and store encoder outputs
             encoder = self.get_encoder()
-            encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
+            encoder_outputs: ModelOutput = encoder(input_ids, patch_ids, attention_mask=attention_mask, return_dict=True)

         # Expand input ids if num_beams > 1 or num_return_sequences > 1
         if num_return_sequences > 1 or num_beams > 1:
             input_ids_len = input_ids.shape[-1]
             input_ids = input_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
+            patch_ids = patch_ids.unsqueeze(1).expand(batch_size, effective_batch_mult * num_beams, input_ids_len)
             attention_mask = attention_mask.unsqueeze(1).expand(
                 batch_size, effective_batch_mult * num_beams, input_ids_len
             )
@@ -409,6 +411,9 @@ class GenerationMixin:
             input_ids = input_ids.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
+            patch_ids = patch_ids.contiguous().view(
+                effective_batch_size * num_beams, input_ids_len
+            )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
             attention_mask = attention_mask.contiguous().view(
                 effective_batch_size * num_beams, input_ids_len
             )  # shape: (batch_size * num_return_sequences * num_beams, cur_len)
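generate() now accepts patch_ids, hands it to the encoder, and expands and reshapes it for beam search exactly as it does input_ids and attention_mask. A minimal usage sketch, assuming a tokenizer and a checkpoint of the patched model are already loaded and that src_patch was built as in the helper above (variable names are illustrative only):

```python
# Minimal usage sketch; `model`, `tokenizer`, `diff_text` and `src_patch` are assumed
# to exist and are not defined by this change.
batch = tokenizer(diff_text, return_tensors="pt", truncation=True, max_length=512)

summary_ids = model.generate(
    input_ids=batch["input_ids"],
    patch_ids=src_patch,                     # new argument; same shape as input_ids
    attention_mask=batch["attention_mask"],
    num_beams=4,
    max_length=64,
    early_stopping=True,
)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```

Note that with num_beams > 1 or num_return_sequences > 1 the expansion above indexes patch_ids directly, so a tensor (rather than None) has to be supplied in that case.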
@@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
                 self.padding_idx,
                 config.extra_pos_embeddings,
             )
-        self.embed_patches = nn.Embedding(3, config.d_model)
+        self.embed_patches = nn.Embedding(3, config.d_model, padding_idx=0)
         self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
         self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
         # mbart has one extra layer_norm
@@ -1113,6 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
     ):
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "patch_ids": None,  # encoder_outputs is defined. patch_ids not needed
             "encoder_outputs": encoder_outputs,
             "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,