Toggle navigation
Toggle navigation
This project
Loading...
Sign in
graykode
/
commit-autosuggestions
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
graykode
2020-09-10 17:18:55 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
1b5806694a22952fae842c627cf72de066824a10
1b580669
1 parent
ce932c68
(refactor) get out train.py from train folder
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
35 additions
and
17 deletions
requirements.txt
train.py
train/finetune.py
train/lightning_base.py
train/modeling_bart.py
train/modeling_utils.py
requirements.txt
View file @
1b58066
...
...
@@ -2,6 +2,7 @@ whatthepatch
gitpython
matorage
transformers
packaging
psutil
sacrebleu
...
...
train.py
0 → 100644
View file @
1b58066
# Copyright 2020-present Tae Hwan Jung
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
import
pytorch_lightning
as
pl
from
train.finetune
import
main
,
SummarizationModule
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
SummarizationModule
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
args
=
parser
.
parse_args
()
main
(
args
)
\ No newline at end of file
train/finetune.py
View file @
1b58066
...
...
@@ -12,7 +12,7 @@ import pytorch_lightning as pl
import
torch
from
torch.utils.data
import
DataLoader
from
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
from
train.
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
from
transformers
import
MBartTokenizer
,
T5ForConditionalGeneration
from
transformers.modeling_bart
import
shift_tokens_right
...
...
@@ -260,16 +260,16 @@ class SummarizationModule(BaseTransformer):
def
get_dataset
(
self
,
type_path
)
->
Seq2SeqDataset
:
max_target_length
=
self
.
target_lens
[
type_path
]
data_config
=
DataConfig
(
endpoint
=
arg
s
.
endpoint
,
endpoint
=
self
.
hparam
s
.
endpoint
,
access_key
=
os
.
environ
[
"access_key"
],
secret_key
=
os
.
environ
[
"secret_key"
],
region
=
arg
s
.
region
,
region
=
self
.
hparam
s
.
region
,
dataset_name
=
"commit-autosuggestions"
,
additional
=
{
"mode"
:
(
"training"
if
type_path
==
"train"
else
"evaluation"
),
"max_source_length"
:
self
.
hparams
.
max_source_length
,
"max_target_length"
:
max_target_length
,
"url"
:
arg
s
.
url
,
"url"
:
self
.
hparam
s
.
url
,
},
attributes
=
[
(
"input_ids"
,
"int32"
,
(
self
.
hparams
.
max_source_length
,)),
...
...
@@ -462,13 +462,3 @@ def main(args, model=None) -> SummarizationModule:
# test() without a model tests using the best checkpoint automatically
trainer
.
test
()
return
model
\ No newline at end of file
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
SummarizationModule
.
add_model_specific_args
(
parser
,
os
.
getcwd
())
args
=
parser
.
parse_args
()
main
(
args
)
...
...
train/lightning_base.py
View file @
1b58066
...
...
@@ -21,7 +21,7 @@ from transformers import (
PretrainedConfig
,
PreTrainedTokenizer
,
)
from
modeling_bart
import
BartForConditionalGeneration
from
train.
modeling_bart
import
BartForConditionalGeneration
from
transformers.optimization
import
(
Adafactor
,
...
...
train/modeling_bart.py
View file @
1b58066
...
...
@@ -41,7 +41,7 @@ from transformers.modeling_outputs import (
Seq2SeqQuestionAnsweringModelOutput
,
Seq2SeqSequenceClassifierOutput
,
)
from
modeling_utils
import
PreTrainedModel
from
train.
modeling_utils
import
PreTrainedModel
import
logging
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
...
...
train/modeling_utils.py
View file @
1b58066
...
...
@@ -39,7 +39,7 @@ from transformers.file_utils import (
is_torch_tpu_available
,
replace_return_docstrings
,
)
from
generation_utils
import
GenerationMixin
from
train.
generation_utils
import
GenerationMixin
import
logging
logger
=
logging
.
getLogger
(
__name__
)
# pylint: disable=invalid-name
...
...
Please
register
or
login
to post a comment