(refactor) argument renaming, add region, do_train and do_predict arguments
Showing
1 changed file
with
14 additions
and
3 deletions
... | @@ -139,9 +139,10 @@ def start(chunked_sha_msgs, train=True): | ... | @@ -139,9 +139,10 @@ def start(chunked_sha_msgs, train=True): |
139 | max_target_length = args.max_target_length if train else args.val_max_target_length | 139 | max_target_length = args.max_target_length if train else args.val_max_target_length |
140 | 140 | ||
141 | data_config = DataConfig( | 141 | data_config = DataConfig( |
142 | - endpoint=args.matorage_dir, | 142 | + endpoint=args.endpoint, |
143 | access_key=os.environ['access_key'], | 143 | access_key=os.environ['access_key'], |
144 | secret_key=os.environ['secret_key'], | 144 | secret_key=os.environ['secret_key'], |
145 | + region=args.region, | ||
145 | dataset_name='commit-autosuggestions', | 146 | dataset_name='commit-autosuggestions', |
146 | additional={ | 147 | additional={ |
147 | "mode" : ("training" if train else "evaluation"), | 148 | "mode" : ("training" if train else "evaluation"), |
... | @@ -175,7 +176,9 @@ def main(args): | ... | @@ -175,7 +176,9 @@ def main(args): |
175 | ] | 176 | ] |
176 | 177 | ||
177 | barrier = int(len(chunked_sha_msgs) * (1 - args.p_val)) | 178 | barrier = int(len(chunked_sha_msgs) * (1 - args.p_val)) |
179 | + if args.do_train: | ||
178 | start(chunked_sha_msgs[:barrier], train=True) | 180 | start(chunked_sha_msgs[:barrier], train=True) |
181 | + if args.do_predict: | ||
179 | start(chunked_sha_msgs[barrier:], train=False) | 182 | start(chunked_sha_msgs[barrier:], train=False) |
180 | 183 | ||
181 | if __name__ == "__main__": | 184 | if __name__ == "__main__": |
... | @@ -187,10 +190,16 @@ if __name__ == "__main__": | ... | @@ -187,10 +190,16 @@ if __name__ == "__main__": |
187 | help="github url" | 190 | help="github url" |
188 | ) | 191 | ) |
189 | parser.add_argument( | 192 | parser.add_argument( |
190 | - "--matorage_dir", | 193 | + "--endpoint", |
191 | type=str, | 194 | type=str, |
192 | required=True, | 195 | required=True, |
193 | - help='matorage saved directory.' | 196 | + help='matorage endpoint, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' |
197 | + ) | ||
198 | + parser.add_argument( | ||
199 | + "--region", | ||
200 | + type=str, | ||
201 | + default=None, | ||
202 | + help='matorage s3 region, check document of matorage: https://matorage.readthedocs.io/en/stable/storage.html' | ||
194 | ) | 203 | ) |
195 | parser.add_argument( | 204 | parser.add_argument( |
196 | "--matorage_batch", | 205 | "--matorage_batch", |
... | @@ -226,6 +235,8 @@ if __name__ == "__main__": | ... | @@ -226,6 +235,8 @@ if __name__ == "__main__": |
226 | "than this will be truncated, sequences shorter will be padded.", | 235 | "than this will be truncated, sequences shorter will be padded.", |
227 | ) | 236 | ) |
228 | parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset") | 237 | parser.add_argument("--p_val", type=float, default=0.25, help="percent of validation dataset") |
238 | + parser.add_argument("--do_train", action="store_true", default=False) | ||
239 | + parser.add_argument("--do_predict", action="store_true", default=False) | ||
229 | args = parser.parse_args() | 240 | args = parser.parse_args() |
230 | 241 | ||
231 | args.local_path = args.url.split('/')[-1] | 242 | args.local_path = args.url.split('/')[-1] | ... | ... |
-
Please register or login to post a comment