Toggle navigation
Toggle navigation
This project
Loading...
Sign in
graykode
/
commit-autosuggestions
Go to a project
Toggle navigation
Toggle navigation pinning
Projects
Groups
Snippets
Help
Project
Activity
Repository
Pipelines
Graphs
Issues
0
Merge Requests
0
Wiki
Snippets
Network
Create a new issue
Builds
Commits
Issue Boards
Authored by
graykode
2020-09-07 22:45:39 +0900
Browse Files
Options
Browse Files
Download
Email Patches
Plain Diff
Commit
3ed49fc3f0f1cffd17e04005d9937ed13c74e6dc
3ed49fc3
1 parent
8b156fad
(add) split with train and test
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
16 deletions
gitcommit.py
gitcommit.py
View file @
3ed49fc
...
...
@@ -15,6 +15,7 @@
import
os
import
re
import
enum
import
random
import
logging
import
argparse
import
numpy
as
np
...
...
@@ -87,8 +88,7 @@ def sha_parse(sha, tokenizer, max_length=1024):
return
(
input_ids
,
attention_masks
,
patch_ids
)
def
message_parse
(
msg
,
tokenizer
,
max_length
=
56
):
msg
=
re
.
sub
(
r'#([0-9])+'
,
''
,
msg
)
msg
=
re
.
sub
(
r'(\(|)([A-z])+-([0-9])+(\)|)(:|)'
,
''
,
msg
)
msg
=
re
.
sub
(
r'(\(|)#([0-9])+(\)|)'
,
''
,
msg
)
msg
=
re
.
sub
(
r'[\u0100-\uFFFF\U00010000-\U0010FFFF]+'
,
''
,
msg
)
.
strip
()
msg
=
tokenizer
.
tokenize
(
msg
)
...
...
@@ -97,7 +97,7 @@ def message_parse(msg, tokenizer, max_length=56):
return
msg
def
jobs
(
sha_msgs
,
args
,
data_config
):
def
jobs
(
sha_msgs
,
args
,
data_config
,
train
=
True
):
input_ids
,
attention_masks
,
patch_ids
,
targets
=
[],
[],
[],
[]
data_saver
=
DataSaver
(
config
=
data_config
)
...
...
@@ -105,11 +105,19 @@ def jobs(sha_msgs, args, data_config):
for
sha_msg
in
sha_msgs
:
sha
,
msg
=
sha_msg
source
=
sha_parse
(
sha
,
tokenizer
=
args
.
tokenizer
)
source
=
sha_parse
(
sha
,
tokenizer
=
args
.
tokenizer
,
max_length
=
args
.
max_source_length
)
if
not
source
:
continue
input_id
,
attention_mask
,
patch_id
=
source
target
=
message_parse
(
msg
,
tokenizer
=
args
.
tokenizer
)
target
=
message_parse
(
msg
,
tokenizer
=
args
.
tokenizer
,
max_length
=
(
args
.
max_target_length
if
train
else
args
.
val_max_target_length
),
)
input_ids
.
append
(
input_id
)
attention_masks
.
append
(
attention_mask
)
...
...
@@ -124,9 +132,11 @@ def jobs(sha_msgs, args, data_config):
})
data_saver
.
disconnect
()
def
main
(
args
):
if
'access_key'
not
in
os
.
environ
or
'secret_key'
not
in
os
.
environ
:
raise
OSError
(
"access_key or secret_key are not found."
)
def
start
(
chunked_sha_msgs
,
train
=
True
):
logger
.
info
(
f
"Start
%
s pre-processing"
%
(
"training"
if
train
else
"evaluation"
))
max_target_length
=
args
.
max_target_length
if
train
else
args
.
val_max_target_length
data_config
=
DataConfig
(
endpoint
=
args
.
matorage_dir
,
...
...
@@ -134,27 +144,39 @@ def main(args):
secret_key
=
os
.
environ
[
'secret_key'
],
dataset_name
=
'commit-autosuggestions'
,
additional
=
{
"mode"
:
(
"training"
if
train
else
"evaluation"
),
"max_source_length"
:
args
.
max_source_length
,
"max_target_length"
:
args
.
max_target_length
,
"max_target_length"
:
max_target_length
,
"url"
:
args
.
url
,
},
attributes
=
[
attributes
=
[
(
'input_ids'
,
'int32'
,
(
args
.
max_source_length
,)),
(
'attention_masks'
,
'int32'
,
(
args
.
max_source_length
,)),
(
'patch_ids'
,
'int32'
,
(
args
.
max_source_length
,)),
(
'targets'
,
'int32'
,
(
args
.
max_target_length
,))
(
'targets'
,
'int32'
,
(
max_target_length
,))
]
)
func
=
partial
(
jobs
,
args
=
args
,
data_config
=
data_config
,
train
=
train
)
with
Pool
(
processes
=
args
.
num_workers
)
as
pool
:
with
tqdm
(
total
=
len
(
chunked_sha_msgs
))
as
pbar
:
for
i
,
_
in
tqdm
(
enumerate
(
pool
.
imap_unordered
(
func
,
chunked_sha_msgs
))):
pbar
.
update
()
def
main
(
args
):
if
'access_key'
not
in
os
.
environ
or
'secret_key'
not
in
os
.
environ
:
raise
OSError
(
"access_key or secret_key are not found."
)
sha_msgs
=
[(
c
.
hexsha
,
c
.
summary
)
for
c
in
repo
.
iter_commits
()]
random
.
shuffle
(
sha_msgs
)
chunked_sha_msgs
=
[
sha_msgs
[
x
:
x
+
args
.
matorage_batch
]
for
x
in
range
(
0
,
len
(
sha_msgs
),
args
.
matorage_batch
)
]
func
=
partial
(
jobs
,
args
=
args
,
data_config
=
data_config
)
with
Pool
(
processes
=
args
.
num_workers
)
as
pool
:
with
tqdm
(
total
=
len
(
chunked_sha_msgs
))
as
pbar
:
for
i
,
_
in
tqdm
(
enumerate
(
pool
.
imap_unordered
(
func
,
chunked_sha_msgs
))):
pbar
.
update
()
barrier
=
int
(
len
(
chunked_sha_msgs
)
*
(
1
-
args
.
p_val
))
start
(
chunked_sha_msgs
[:
barrier
],
train
=
True
)
start
(
chunked_sha_msgs
[
barrier
:],
train
=
False
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Code to collect commits on github"
)
...
...
@@ -196,6 +218,14 @@ if __name__ == "__main__":
help
=
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
,
)
parser
.
add_argument
(
"--val_max_target_length"
,
default
=
142
,
# these defaults are optimized for CNNDM. For xsum, see README.md.
type
=
int
,
help
=
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
,
)
parser
.
add_argument
(
"--p_val"
,
type
=
float
,
default
=
0.25
,
help
=
"percent of validation dataset"
)
args
=
parser
.
parse_args
()
args
.
local_path
=
args
.
url
.
split
(
'/'
)[
-
1
]
...
...
Please
register
or
login
to post a comment