Toggle navigation
Toggle navigation
This project
Loading...
Sign in
graykode
/
commit-autosuggestions
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
graykode
2020-09-09 13:49:59 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
7055c069a5026ae1f5cd3fb49a23dc5dfcb027dd
7055c069
1 parent
2a254f02
(add) add patch_ids for model encoder inputs
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
5 deletions
finetune.py
generation_utils.py
modeling_bart.py
finetune.py
View file @
7055c06
...
...
@@ -115,8 +115,8 @@ class SummarizationModule(BaseTransformer):
for
d
in
[
self
.
model
.
encoder
,
self
.
model
.
decoder
]:
freeze_params
(
d
.
embed_tokens
)
def
forward
(
self
,
input_ids
,
**
kwargs
):
return
self
.
model
(
input_ids
,
**
kwargs
)
def
forward
(
self
,
input_ids
,
patch_ids
,
**
kwargs
):
return
self
.
model
(
input_ids
,
patch_ids
,
**
kwargs
)
def
ids_to_clean_text
(
self
,
generated_ids
:
List
[
int
]):
gen_text
=
self
.
tokenizer
.
batch_decode
(
...
...
@@ -133,7 +133,7 @@ class SummarizationModule(BaseTransformer):
else
:
decoder_input_ids
=
shift_tokens_right
(
tgt_ids
,
pad_token_id
)
outputs
=
self
(
src_ids
,
attention_mask
=
src_mask
,
decoder_input_ids
=
decoder_input_ids
,
use_cache
=
False
)
outputs
=
self
(
src_ids
,
src_patch
,
attention_mask
=
src_mask
,
decoder_input_ids
=
decoder_input_ids
,
use_cache
=
False
)
lm_logits
=
outputs
[
0
]
if
self
.
hparams
.
label_smoothing
==
0
:
# Same behavior as modeling_bart.py, besides ignoring pad_token_id
...
...
generation_utils.py
View file @
7055c06
...
...
@@ -114,6 +114,7 @@ class GenerationMixin:
def
generate
(
self
,
input_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
patch_ids
:
Optional
[
torch
.
LongTensor
]
=
None
,
max_length
:
Optional
[
int
]
=
None
,
min_length
:
Optional
[
int
]
=
None
,
do_sample
:
Optional
[
bool
]
=
None
,
...
...
@@ -396,12 +397,13 @@ class GenerationMixin:
# get encoder and store encoder outputs
encoder
=
self
.
get_encoder
()
encoder_outputs
:
ModelOutput
=
encoder
(
input_ids
,
attention_mask
=
attention_mask
,
return_dict
=
True
)
encoder_outputs
:
ModelOutput
=
encoder
(
input_ids
,
patch_ids
,
attention_mask
=
attention_mask
,
return_dict
=
True
)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if
num_return_sequences
>
1
or
num_beams
>
1
:
input_ids_len
=
input_ids
.
shape
[
-
1
]
input_ids
=
input_ids
.
unsqueeze
(
1
)
.
expand
(
batch_size
,
effective_batch_mult
*
num_beams
,
input_ids_len
)
patch_ids
=
patch_ids
.
unsqueeze
(
1
)
.
expand
(
batch_size
,
effective_batch_mult
*
num_beams
,
input_ids_len
)
attention_mask
=
attention_mask
.
unsqueeze
(
1
)
.
expand
(
batch_size
,
effective_batch_mult
*
num_beams
,
input_ids_len
)
...
...
@@ -409,6 +411,9 @@ class GenerationMixin:
input_ids
=
input_ids
.
contiguous
()
.
view
(
effective_batch_size
*
num_beams
,
input_ids_len
)
# shape: (batch_size * num_return_sequences * num_beams, cur_len)
patch_ids
=
patch_ids
.
contiguous
()
.
view
(
effective_batch_size
*
num_beams
,
input_ids_len
)
# shape: (batch_size * num_return_sequences * num_beams, cur_len)
attention_mask
=
attention_mask
.
contiguous
()
.
view
(
effective_batch_size
*
num_beams
,
input_ids_len
)
# shape: (batch_size * num_return_sequences * num_beams, cur_len)
...
...
modeling_bart.py
View file @
7055c06
...
...
@@ -307,7 +307,7 @@ class BartEncoder(nn.Module):
self
.
padding_idx
,
config
.
extra_pos_embeddings
,
)
self
.
embed_patches
=
nn
.
Embedding
(
3
,
config
.
d_model
)
self
.
embed_patches
=
nn
.
Embedding
(
3
,
config
.
d_model
,
padding_idx
=
0
)
self
.
layers
=
nn
.
ModuleList
([
EncoderLayer
(
config
)
for
_
in
range
(
config
.
encoder_layers
)])
self
.
layernorm_embedding
=
LayerNorm
(
embed_dim
)
if
config
.
normalize_embedding
else
nn
.
Identity
()
# mbart has one extra layer_norm
...
...
@@ -1113,6 +1113,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
):
return
{
"input_ids"
:
None
,
# encoder_outputs is defined. input_ids not needed
"patch_ids"
:
None
,
# encoder_outputs is defined. input_ids not needed
"encoder_outputs"
:
encoder_outputs
,
"past_key_values"
:
past
,
"decoder_input_ids"
:
decoder_input_ids
,
...
...
Please
register
or
login
to post a comment