os.py
10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
""" general utility functions"""
import argparse
import importlib
import json
import logging
import os
import random
import re
import shutil
import sys
import typing
from argparse import ArgumentParser
from collections.abc import MutableMapping
import numpy
import torch
from box import Box
# Module-wide logger. getLogger() is called without a name, so this is the root
# logger; nothing in this module configures its handlers or level.
logger = logging.getLogger()
def listorstr(inp):
    """Cast every element of *inp* with ``try_cast``.

    A single-element sequence collapses to its cast scalar; otherwise a new
    list of cast values is returned.

    Args:
        inp: sequence of values to cast (e.g. strings such as "3" or "0.5").

    Returns:
        The cast scalar when ``len(inp) == 1``, else a new list of cast values
        (an empty input gives ``[]``).
    """
    if len(inp) == 1:
        return try_cast(inp[0])
    # Build a fresh list instead of overwriting inp in place: the caller's
    # sequence is left untouched and immutable sequences (tuples) work too.
    return [try_cast(val) for val in inp]
def try_cast(text):
    """Convert *text* to int if possible, else to float, else return it unchanged.

    int is tried first, so "3" stays integral while "3.5" or "1e-3" become
    floats. Any exception raised by a conversion just moves on to the next one.
    """
    for converter in (int, float):
        try:
            return converter(text)
        except Exception:
            continue
    return text
def try_float(text, default: typing.Optional[float] = 0.0):
    """Return ``float(text)``, or *default* if the conversion raises.

    Args:
        text: value to convert (e.g. "1.5", "1e-3", "inf").
        default: value returned on failure; pass None to detect failure.

    Returns:
        The converted float, or *default*.
    """
    try:
        return float(text)
    except Exception:
        # Deliberately best-effort: any conversion failure falls back to default.
        return default
def try_int(text, default: typing.Optional[int] = 0):
    """Return ``int(text)``, or *default* when the conversion raises.

    Note that ``int`` truncates real numbers (``try_int(3.9) == 3``) but
    rejects float-looking strings such as "1.5".
    """
    try:
        converted = int(text)
    except Exception:
        converted = default
    return converted
def parse_args(parser: ArgumentParser) -> Box:
    """Parse ``sys.argv`` with precedence: parser defaults < config file < command line.

    If the parsed namespace has a non-None ``config`` attribute, it is taken as
    the path of a Python file whose module-level ``config`` mapping overrides
    the parser defaults. Arguments given explicitly on the command line
    override both. The merged flat mapping is nested on "." by ``boxify_dict``.

    Args:
        parser: a fully configured ArgumentParser. Its actions and defaults are
            restored before returning, so the parser stays usable.

    Returns:
        Box of the merged settings.
    """
    # First pass: an ordinary parse. This yields every default (with argparse's
    # own type conversion of string defaults applied) and the config path.
    merged = dict(vars(parser.parse_args()))

    config_file = merged.get("config")
    if config_file is not None:
        merged.update(_load_py_config(config_file))

    # Values typed on the command line win over defaults and the config file.
    merged.update(_explicit_cli_args(parser))
    return boxify_dict(merged)


def _load_py_config(config_file):
    """Execute the Python file *config_file* and return its module-level ``config``."""
    # importlib.util is a submodule; "import importlib" alone does not load it.
    import importlib.util

    spec = importlib.util.spec_from_file_location("config", config_file)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load config file {config_file!r}")
    module = importlib.util.module_from_spec(spec)
    # NOTE: this executes arbitrary code from the config file; only use trusted files.
    spec.loader.exec_module(module)
    return module.config


def _explicit_cli_args(parser):
    """Return only the arguments actually present on the command line.

    Every action's default is temporarily set to ``argparse.SUPPRESS`` and the
    ``set_defaults`` values are cleared, so argparse adds no attribute for
    arguments that were not given. Positional arguments are still parsed
    normally. The parser is restored afterwards, even if parsing fails.
    """
    saved_action_defaults = [(action, action.default) for action in parser._actions]
    saved_parser_defaults = parser._defaults
    try:
        for action in parser._actions:
            action.default = argparse.SUPPRESS
        parser._defaults = {}
        return vars(parser.parse_args())
    finally:
        for action, default in saved_action_defaults:
            action.default = default
        parser._defaults = saved_parser_defaults
def boxify_dict(config):
    """Nest a flat mapping on "." separators and wrap the result in a Box.

    Examples:
        {"model.a": 1, "model.b": 2, "alpha": 3}
            -> Box({"model": {"a": 1, "b": 2}, "alpha": 3})
        {"model.a": 1, "model.b": 2, "model": 3}
            -> TypeError ("model" is both a value and a parent)

    An existing value of None is treated as unset and may be replaced.
    Otherwise, a key used both as a value and as a parent raises TypeError.

    Args:
        config: flat mapping whose string keys may contain ".".

    Returns:
        Box holding the nested structure.

    Raises:
        TypeError: if a key is used both as a leaf value and as a parent.
    """
    new_config = {}
    for key, value in config.items():
        *parents, leaf = key.split(".")
        node = new_config
        # Walk (creating as needed) the intermediate dicts for the dotted prefix.
        for part in parents:
            child = node.get(part)
            if child is None:
                child = node[part] = {}
            elif not isinstance(child, dict):
                raise TypeError(f"Key '{part}' has values as well as child")
            node = child
        # Refuse to overwrite an existing non-None entry. Without this check,
        # "model.a" arriving after "model.a.b" would silently replace the
        # subtree built for "model.a.b".
        if node.get(leaf) is not None:
            raise TypeError(f"Key '{key}' has values as well as child")
        node[leaf] = value
    return Box(new_config)
# https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys
def flatten(d, parent_key='', sep='.'):
    """Flatten a nested mapping into a single-level Box with joined keys.

    Roughly the inverse of ``boxify_dict``, e.g.
    ``{"model": {"a": 1}, "train": 10}`` -> ``Box({"model.a": 1, "train": 10})``.
    Empty nested mappings contribute no keys.

    Args:
        d: mapping to flatten; nested ``MutableMapping`` values are recursed into.
        parent_key: prefix joined (with *sep*) in front of every key.
        sep: separator placed between key components.

    Returns:
        Box mapping each joined key to its leaf value.
    """
    # Collect the leaves once and build a single Box, instead of building an
    # intermediate Box at every nesting level.
    return Box(dict(_flatten_items(d, parent_key, sep)))


def _flatten_items(d, parent_key, sep):
    """Yield ``(joined_key, leaf_value)`` pairs for every leaf of *d*."""
    for key, value in d.items():
        # The f-string join also accepts non-str keys (e.g. ints), which the
        # string concatenation "parent + sep + key" rejected with TypeError.
        new_key = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, MutableMapping):
            yield from _flatten_items(value, new_key, sep)
        else:
            yield new_key, value
def str2bool(v: typing.Union[bool, str, int]) -> bool:
    """Interpret *v* as a boolean.

    Accepts bools unchanged, the ints 1 / 0, and the strings (case-insensitive)
    "yes", "true", "t", "y", "1" and "no", "false", "f", "n", "0". It can be
    used as an argparse ``type=`` callable; argparse reports the TypeError as
    an invalid value.

    Raises:
        TypeError: if *v* is not one of the recognised values.
    """
    if isinstance(v, bool):
        return v
    # str() lets int inputs (allowed by the annotation) share the string tables;
    # calling .lower() on an int would raise AttributeError.
    text = str(v).lower()
    if text in ("yes", "true", "t", "y", "1"):
        return True
    if text in ("no", "false", "f", "n", "0"):
        return False
    raise TypeError("Boolean value expected.")
def safe_isdir(dir_name):
    """Return True only when *dir_name* exists and is a directory."""
    # os.path.isdir is already False for a missing path, so it covers the existence check.
    return os.path.isdir(dir_name)
def safe_makedirs(dir_name):
    """Create *dir_name* (and any missing parents) without raising.

    An already-existing directory counts as success. Any other OSError
    (e.g. permission denied, or a non-directory occupying the path) is logged
    as a warning rather than raised, so callers should not assume the
    directory exists afterwards.
    """
    try:
        # exist_ok=True: an existing directory is the normal case, not an error to report.
        os.makedirs(dir_name, exist_ok=True)
    except OSError as err:
        # Best-effort by design: report through the module logger instead of stdout.
        logger.warning(f"Could not create directory {dir_name}: {err}")
def jsonize(x: object) -> typing.Union[str, dict]:
    """Serialise *x* to a JSON string, falling back to an empty dict on any failure."""
    try:
        encoded = json.dumps(x)
    except Exception:
        # Not JSON-serialisable (or serialisation failed otherwise): return {}.
        encoded = {}
    return encoded
def copy_code(folder_to_copy, out_folder, replace=False):
    """Copy the directory *folder_to_copy* into *out_folder*.

    The copy is named after the source path with "/" replaced by "_"
    (e.g. "src/models" -> "<out_folder>/src_models"). *out_folder* is created
    if it is missing and is never deleted.

    Args:
        folder_to_copy: directory to copy (str or path-like).
        out_folder: destination parent directory.
        replace: if True, delete an existing copy with the same name first.
            If False, keep it and use the first free name "<name>_1",
            "<name>_2", and so on.

    Raises:
        SystemExit: with exit status 1, if *out_folder* exists but is not a directory.
        FileExistsError: if *replace* is True and a non-directory occupies the
            destination name.
    """
    logger.info(f"copying {folder_to_copy} to {out_folder}")
    if os.path.exists(out_folder):
        if not os.path.isdir(out_folder):
            logger.error(f"{out_folder} is not a directory")
            # Non-zero status: a bare sys.exit() would report success (exit code 0).
            sys.exit(1)
        logger.info(f"Not deleting existing result folder: {out_folder}")
    else:
        os.makedirs(out_folder)

    # Flatten the source path into a single directory name ("/" -> "_").
    flat_name = os.fspath(folder_to_copy).replace("/", "_")
    folder_name = f"{out_folder}/{flat_name}"

    if not replace:
        # Never overwrite: take the first free name among folder_name, folder_name_1, ...
        candidate, suffix = folder_name, 1
        while os.path.exists(candidate):
            candidate = f"{folder_name}_{suffix}"
            suffix += 1
        folder_name = candidate
    elif os.path.exists(folder_name):
        if not os.path.isdir(folder_name):
            raise FileExistsError("There is a file with same name as folder")
        shutil.rmtree(folder_name)

    logger.info(f"Copying {folder_to_copy} to {folder_name}")
    shutil.copytree(folder_to_copy, folder_name)
def get_state_params(wandb_use, run_id, result_folder, statefile):
    """Resolve the checkpoint, wandb run id and run folder for a (possibly resumed) run.

    This is best-effort: *run_id* and/or *statefile* may be omitted, and they
    are inferred where possible.

    When *run_id* is None there are four cases:

    1. resume a wandb run without knowing its run id and/or statefile;
    2. start a fresh wandb run;
    3. resume a run without wandb;
    4. start a fresh run without wandb.

    Cases 1/3 apply when *result_folder* ends with ``run_<digits>`` (an
    optional trailing "/" is allowed):

    - With wandb, the first entry of ``<result_folder>/wandb/`` (in reverse-sorted
      name order) that contains ``current_model.pt`` supplies the run id. The id
      is the text after that entry's last "-".
    - Without wandb, ``<result_folder>/current_model.pt`` becomes the statefile
      if it exists.

    Otherwise (cases 2/4), a new folder ``<result_folder>/run_<NNNN>`` is
    created. Its number is one past the highest existing ``run_<N>`` entry, or 1
    if there is none.

    If a run id is known (given or found) but no statefile is, the statefile is
    looked up as ``<result_folder>/wandb/<entry containing run_id>/current_model.pt``.

    Returns:
        ``(statefile, run_id, result_folder)``.

    Raises:
        FileNotFoundError: a run id is known but no matching ``current_model.pt``
            was found (a subclass of ``Exception``, so existing handlers still apply).
    """
    if run_id is None:
        if re.match(r"^.*/?run_[0-9]+/?$", result_folder):
            if wandb_use:
                wandb_dir = f"{result_folder}/wandb/"
                # Case 1. Without a wandb folder there is nothing to resume from.
                if os.path.exists(wandb_dir):
                    for folder in sorted(os.listdir(wandb_dir), reverse=True):
                        # Assumes run_<N> holds a single run and no unrelated entries.
                        if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"):
                            run_id = folder.split("-")[-1]
                            logger.info(f"using run id {run_id}")
                            break
            else:
                # Case 3: no wandb, so look for the checkpoint directly inside run_<N>.
                logger.info("not using wandb")
                candidate = f"{result_folder}/current_model.pt"
                if os.path.exists(candidate):
                    statefile = candidate
                    logger.info(f"using statefile {statefile}")
        else:
            # Cases 2/4: not a run_<N> folder, so start a fresh run in a new sub-folder.
            run_num = _next_run_number(result_folder)
            result_folder = f"{result_folder}/run_{run_num:04d}"
            logger.info(f"Creating new run with {result_folder}")
            safe_makedirs(result_folder)

    # Resuming with a known run id but no statefile: locate its checkpoint.
    if run_id is not None and statefile is None:
        wandb_dir = f"{result_folder}/wandb"
        # A missing wandb folder means "no checkpoint" (reported below) rather
        # than a bare listdir error.
        folders = sorted(os.listdir(wandb_dir), reverse=True) if os.path.isdir(wandb_dir) else []
        for folder in folders:
            candidate = f"{wandb_dir}/{folder}/current_model.pt"
            if run_id in folder and os.path.exists(candidate):
                statefile = candidate
                logger.info(f"Using state file {statefile} and run id {run_id}")
                break
        if statefile is None:
            raise FileNotFoundError("Did not find statefile, exiting!!")
    return statefile, run_id, result_folder


def _next_run_number(result_folder):
    """Return one more than the largest N among ``run_<N>`` entries of *result_folder* (1 if none)."""
    # A not-yet-created results folder has no runs; os.listdir would raise here.
    if not os.path.isdir(result_folder):
        return 1
    numbers = [0]
    for name in os.listdir(result_folder):
        # Only count genuine run_<N> entries; other names ending in digits are ignored.
        match = re.fullmatch(r"run_([0-9]+)", name)
        if match:
            numbers.append(int(match.group(1)))
    return max(numbers) + 1
if __name__ == "__main__":
    # Manual smoke test of boxify_dict / flatten.
    # Dotted keys nest under their shared prefix: model -> {a, b}, alpha -> 3.
    a = {"model.a": 1, "model.b": 2, "alpha": 3}
    print(boxify_dict(a))
    # "model" is both a value and a parent in each case, so each one raises
    # (whatever the key order) and the error message is printed.
    for a in (
        {"model.a": 1, "model.b": 2, "model": 3},
        {"model": 3, "model.a": 1, "model.b": 2},
        {"model.a": 1, "model": 3, "model.b": 2},
    ):
        try:
            print(boxify_dict(a))
        except Exception as e:
            print(e)
    # Nested -> flat: model.attr1, model.attr2.attr_attr_3, train.
    try:
        a = {"model": {"attr1": 1, "attr2": {"attr_attr_3": 3}}, "train": 10}
        print(flatten(a))
    except Exception as e:
        print(e)
def set_seed(seed):
    """Seed the torch, numpy and Python ``random`` generators.

    Args:
        seed: a single int used for all three generators, or a 3-element
            list/tuple ``(torch_seed, numpy_seed, random_seed)``.
    """
    # Tuples are accepted as well as lists; previously a tuple fell through to
    # torch.manual_seed(tuple) and failed.
    if isinstance(seed, (list, tuple)):
        torch_seed, numpy_seed, random_seed = seed
    else:
        torch_seed = numpy_seed = random_seed = seed
    torch.manual_seed(torch_seed)
    numpy.random.seed(numpy_seed)
    random.seed(random_seed)