graykode

(refactor) rename the --matorage_dir argument to --endpoint; add --region, --do_train and --do_predict arguments

...@@ -139,9 +139,10 @@ def start(chunked_sha_msgs, train=True): ...@@ -139,9 +139,10 @@ def start(chunked_sha_msgs, train=True):
139 max_target_length = args.max_target_length if train else args.val_max_target_length 139 max_target_length = args.max_target_length if train else args.val_max_target_length
140 140
141 data_config = DataConfig( 141 data_config = DataConfig(
142 - endpoint=args.matorage_dir, 142 + endpoint=args.endpoint,
143 access_key=os.environ['access_key'], 143 access_key=os.environ['access_key'],
144 secret_key=os.environ['secret_key'], 144 secret_key=os.environ['secret_key'],
145 + region=args.region,
145 dataset_name='commit-autosuggestions', 146 dataset_name='commit-autosuggestions',
146 additional={ 147 additional={
147 "mode" : ("training" if train else "evaluation"), 148 "mode" : ("training" if train else "evaluation"),
...@@ -175,8 +176,10 @@ def main(args): ...@@ -175,8 +176,10 @@ def main(args):
175 ] 176 ]
176 177
177 barrier = int(len(chunked_sha_msgs) * (1 - args.p_val)) 178 barrier = int(len(chunked_sha_msgs) * (1 - args.p_val))
178 - start(chunked_sha_msgs[:barrier], train=True) 179 + if args.do_train:
179 - start(chunked_sha_msgs[barrier:], train=False) 180 + start(chunked_sha_msgs[:barrier], train=True)
181 + if args.do_predict:
182 + start(chunked_sha_msgs[barrier:], train=False)
180 183
181 if __name__ == "__main__": 184 if __name__ == "__main__":
182 parser = argparse.ArgumentParser(description="Code to collect commits on github") 185 parser = argparse.ArgumentParser(description="Code to collect commits on github")
...@@ -187,10 +190,16 @@ if __name__ == "__main__": ...@@ -187,10 +190,16 @@ if __name__ == "__main__":
187 help="github url" 190 help="github url"
188 ) 191 )
189 parser.add_argument( 192 parser.add_argument(
190 - "--matorage_dir", 193 + "--endpoint",
191 type=str, 194 type=str,
192 required=True, 195 required=True,
193 - help='matorage saved directory.' 196 + help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
197 + )
198 + parser.add_argument(
199 + "--region",
200 + type=str,
201 + default=None,
202 + help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html'
194 ) 203 )
195 parser.add_argument( 204 parser.add_argument(
196 "--matorage_batch", 205 "--matorage_batch",
...@@ -226,6 +235,8 @@ if __name__ == "__main__": ...@@ -226,6 +235,8 @@ if __name__ == "__main__":
226 "than this will be truncated, sequences shorter will be padded.", 235 "than this will be truncated, sequences shorter will be padded.",
227 ) 236 )
228 parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset") 237 parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset")
238 + parser.add_argument("--do_train", action="store_true", default=False)
239 + parser.add_argument("--do_predict", action="store_true", default=False)
229 args = parser.parse_args() 240 args = parser.parse_args()
230 241
231 args.local_path = args.url.split('/')[-1] 242 args.local_path = args.url.split('/')[-1]
......