Showing
1 changed file
with
323 additions
and
0 deletions
lib/utils/os.py
0 → 100644
1 | +""" general utility functions""" | ||
2 | +import argparse | ||
3 | +import importlib | ||
4 | +import json | ||
5 | +import logging | ||
6 | +import os | ||
7 | +import random | ||
8 | +import re | ||
9 | +import shutil | ||
10 | +import sys | ||
11 | +import typing | ||
12 | +from argparse import ArgumentParser | ||
13 | +from collections.abc import MutableMapping | ||
14 | + | ||
15 | +import numpy | ||
16 | +import torch | ||
17 | +from box import Box | ||
18 | + | ||
19 | +logger = logging.getLogger() | ||
20 | + | ||
21 | + | ||
22 | +def listorstr(inp): | ||
23 | + if len(inp) == 1: | ||
24 | + return try_cast(inp[0]) | ||
25 | + | ||
26 | + for i, val in enumerate(inp): | ||
27 | + inp[i] = try_cast(val) | ||
28 | + return inp | ||
29 | + | ||
30 | + | ||
31 | +def try_cast(text): | ||
32 | + """ try to cast to int or float if possible, else return the text itself""" | ||
33 | + result = try_int(text, None) | ||
34 | + if result is not None: | ||
35 | + return result | ||
36 | + | ||
37 | + result = try_float(text, None) | ||
38 | + if result is not None: | ||
39 | + return result | ||
40 | + | ||
41 | + return text | ||
42 | + | ||
43 | + | ||
44 | +def try_float(text, default: typing.Optional[int] = 0.0): | ||
45 | + result = default | ||
46 | + try: | ||
47 | + result = float(text) | ||
48 | + except Exception as _: | ||
49 | + pass | ||
50 | + return result | ||
51 | + | ||
52 | + | ||
53 | +def try_int(text, default: typing.Optional[int] = 0): | ||
54 | + result = default | ||
55 | + try: | ||
56 | + result = int(text) | ||
57 | + except Exception as _: | ||
58 | + pass | ||
59 | + return result | ||
60 | + | ||
61 | + | ||
62 | +def parse_args(parser: ArgumentParser) -> Box: | ||
63 | + # get defaults | ||
64 | + defaults = {} | ||
65 | + # taken from parser_known_args code | ||
66 | + # add any action defaults that aren't present | ||
67 | + for action in parser._actions: | ||
68 | + if action.dest is not argparse.SUPPRESS: | ||
69 | + if action.default is not argparse.SUPPRESS: | ||
70 | + defaults[action.dest] = action.default | ||
71 | + | ||
72 | + # add any parser defaults that aren't present | ||
73 | + for dest in parser._defaults: | ||
74 | + defaults[dest] = parser._defaults[dest] | ||
75 | + | ||
76 | + # check if there is config & read config | ||
77 | + args = parser.parse_args() | ||
78 | + if vars(args).get("config") is not None: | ||
79 | + # load a .py config | ||
80 | + configFile = args.config | ||
81 | + spec = importlib.util.spec_from_file_location("config", configFile) | ||
82 | + module = importlib.util.module_from_spec(spec) | ||
83 | + spec.loader.exec_module(module) | ||
84 | + config = module.config | ||
85 | + # merge config and override defaults | ||
86 | + defaults.update({k: v for k, v in config.items()}) | ||
87 | + | ||
88 | + # override defaults with command line params | ||
89 | + # this will get rid of defaults and only read command line args | ||
90 | + parser._defaults = {} | ||
91 | + parser._actions = {} | ||
92 | + args = parser.parse_args() | ||
93 | + defaults.update({k: v for k, v in vars(args).items()}) | ||
94 | + | ||
95 | + return boxify_dict(defaults) | ||
96 | + | ||
97 | + | ||
98 | +def boxify_dict(config): | ||
99 | + """ | ||
100 | + this takes a flat dictionary and break it into sub-dictionaries based on "." seperation | ||
101 | + a = {"model.a": 1, "model.b" : 2, "alpha" : 3} will return Box({"model" : {"a" :1, | ||
102 | + "b" : 2}, alpha:3}) | ||
103 | + a = {"model.a": 1, "model.b" : 2, "model" : 3} will throw error | ||
104 | + """ | ||
105 | + new_config = {} | ||
106 | + # iterate over keys and split on "." | ||
107 | + for key in config: | ||
108 | + if "." in key: | ||
109 | + temp_config = new_config | ||
110 | + for k in key.split(".")[:-1]: | ||
111 | + # create non-existent keys as dictionary recursively | ||
112 | + if temp_config.get(k) is None: | ||
113 | + temp_config[k] = {} | ||
114 | + elif not isinstance(temp_config.get(k), dict): | ||
115 | + raise TypeError(f"Key '{k}' has values as well as child") | ||
116 | + temp_config = temp_config[k] | ||
117 | + temp_config[key.split(".")[-1]] = config[key] | ||
118 | + else: | ||
119 | + if new_config.get(key) is None: | ||
120 | + new_config[key] = config[key] | ||
121 | + else: | ||
122 | + raise TypeError(f"Key '{key}' has values as well as child") | ||
123 | + | ||
124 | + return Box(new_config) | ||
125 | + | ||
126 | + | ||
127 | +# https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys | ||
128 | +def flatten(d, parent_key='', sep='.'): | ||
129 | + items = [] | ||
130 | + for k, v in d.items(): | ||
131 | + new_key = parent_key + sep + k if parent_key else k | ||
132 | + if isinstance(v, MutableMapping): | ||
133 | + items.extend(flatten(v, new_key, sep=sep).items()) | ||
134 | + else: | ||
135 | + items.append((new_key, v)) | ||
136 | + return Box(dict(items)) | ||
137 | + | ||
138 | + | ||
139 | +def str2bool(v: typing.Union[bool, str, int]) -> bool: | ||
140 | + if isinstance(v, bool): | ||
141 | + return v | ||
142 | + if v.lower() in ("yes", "true", "t", "y", "1", 1): | ||
143 | + return True | ||
144 | + if v.lower() in ("no", "false", "f", "n", "0", 0): | ||
145 | + return False | ||
146 | + raise TypeError("Boolean value expected.") | ||
147 | + | ||
148 | + | ||
149 | +def safe_isdir(dir_name): | ||
150 | + return os.path.exists(dir_name) and os.path.isdir(dir_name) | ||
151 | + | ||
152 | + | ||
153 | +def safe_makedirs(dir_name): | ||
154 | + try: | ||
155 | + os.makedirs(dir_name) | ||
156 | + except OSError as e: | ||
157 | + print(e) | ||
158 | + | ||
159 | + | ||
160 | +def jsonize(x: object) -> typing.Union[str, dict]: | ||
161 | + try: | ||
162 | + temp = json.dumps(x) | ||
163 | + return temp | ||
164 | + except Exception as e: | ||
165 | + return {} | ||
166 | + | ||
167 | + | ||
168 | +def copy_code(folder_to_copy, out_folder, replace=False): | ||
169 | + logger.info(f"copying {folder_to_copy} to {out_folder}") | ||
170 | + | ||
171 | + if os.path.exists(out_folder): | ||
172 | + if not os.path.isdir(out_folder): | ||
173 | + logger.error(f"{out_folder} is not a directory") | ||
174 | + sys.exit() | ||
175 | + else: | ||
176 | + logger.info(f"Not deleting existing result folder: {out_folder}") | ||
177 | + else: | ||
178 | + os.makedirs(out_folder) | ||
179 | + | ||
180 | + # replace / with _ | ||
181 | + folder_name = f'{out_folder}/{re.sub("/", "_", folder_to_copy)}' | ||
182 | + | ||
183 | + # create a new copy if something already exists | ||
184 | + if not replace: | ||
185 | + i = 1 | ||
186 | + temp = folder_name | ||
187 | + while os.path.exists(temp): | ||
188 | + temp = f"{folder_name}_{i}" | ||
189 | + i += 1 | ||
190 | + folder_name = temp | ||
191 | + else: | ||
192 | + if os.path.exists(folder_name): | ||
193 | + if os.path.isdir(folder_name): | ||
194 | + shutil.rmtree(folder_name) | ||
195 | + else: | ||
196 | + raise FileExistsError("There is a file with same name as folder") | ||
197 | + | ||
198 | + logger.info(f"Copying {folder_to_copy} to {folder_name}") | ||
199 | + shutil.copytree(folder_to_copy, folder_name) | ||
200 | + | ||
201 | + | ||
202 | +def get_state_params(wandb_use, run_id, result_folder, statefile): | ||
203 | + """This searches for model and run id in result folder | ||
204 | + The logic is as follows | ||
205 | + | ||
206 | + if we are not given run_id there are four cases: | ||
207 | + - we want to restart the wandb run but too lazy to look up run-id or/and statefile | ||
208 | + - we want a new wandb fresh run | ||
209 | + - we are not using wandb at all and need to restart | ||
210 | + - we are not using wandb and need a fresh run | ||
211 | + | ||
212 | + Case 1/3: | ||
213 | + - If we want to restart the run, we expect the result_folder name to end with | ||
214 | + /run_<numeric>. | ||
215 | + - In this case, if we are using wandb then we need to go inside wandb folder, list all | ||
216 | + directory and pick up run id and (or) statefile | ||
217 | + - If we are not using wandb we just look for model inside the run_<numeric> folder and | ||
218 | + return statefile, run id as none | ||
219 | + | ||
220 | + case 2/4: | ||
221 | + if not 1/3, it is case 2/4 | ||
222 | + | ||
223 | + This is expected to be a fail safe script. i.e any of run_id or statefile may not be specified | ||
224 | + and relies on whims of the user _-_ | ||
225 | + """ | ||
226 | + # if not resume get run number and create result_folder/run_{run_num} | ||
227 | + # if someone is resuming we expect them to give the exact folder name upto run num. | ||
228 | + | ||
229 | + # this part of code searches for run_id i.e will work only if we are using wandb | ||
230 | + if run_id is None: | ||
231 | + # if result folder if of type folder/run_<num>, then search for current checkpoint and | ||
232 | + # run-id else we will just create a new run with run_<num+1> | ||
233 | + | ||
234 | + regex = r"^.*/?run_[0-9]+/?$" | ||
235 | + if re.match(regex, result_folder): | ||
236 | + | ||
237 | + # search for checkpoint and run-id if using wandb | ||
238 | + if wandb_use: | ||
239 | + # search in wandb folder if it exists else we want a new run | ||
240 | + if os.path.exists(f"{result_folder}/wandb/"): | ||
241 | + # case 1 | ||
242 | + for folder in sorted(os.listdir(f"{result_folder}/wandb/"), reverse=True): | ||
243 | + # assume run_<##> will have only single run, | ||
244 | + # also no other crap in this folder | ||
245 | + if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"): | ||
246 | + run_id = folder.split("-")[-1] | ||
247 | + logger.info(f"using run id {run_id}") | ||
248 | + # we are done break out of for loop | ||
249 | + break | ||
250 | + else: | ||
251 | + # case 3 | ||
252 | + # if not using wandb search within run_<num> directory | ||
253 | + logger.info(f"not using wandb") | ||
254 | + if os.path.exists(f"{result_folder}/current_model.pt"): | ||
255 | + statefile = f"{result_folder}/current_model.pt" | ||
256 | + logger.info(f"using statefile {statefile}") | ||
257 | + else: | ||
258 | + # just start a new run | ||
259 | + pass | ||
260 | + else: | ||
261 | + # trailing is not run_<num>; that means user wants a new fresh run | ||
262 | + # so we give a fresh run and create a new folder | ||
263 | + # case 2/4 | ||
264 | + last_run_num = max( | ||
265 | + [0] + [try_int(i[-4:]) for i in os.listdir(result_folder)]) + 1 | ||
266 | + result_folder = f"{result_folder}/run_{last_run_num:04d}" | ||
267 | + logger.info(f"Creating new run with {result_folder}") | ||
268 | + safe_makedirs(result_folder) | ||
269 | + | ||
270 | + # search for last checkpoint in case --statefile is none and we are resuming | ||
271 | + if run_id is not None and statefile is None: | ||
272 | + folders = sorted(os.listdir(f"{result_folder}/wandb"), reverse=True) | ||
273 | + for folder in folders: | ||
274 | + if run_id in folder: | ||
275 | + # check for current_model.pt | ||
276 | + if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"): | ||
277 | + statefile = f"{result_folder}/wandb/{folder}/current_model.pt" | ||
278 | + logger.info(f"Using state file {statefile} and run id {run_id}") | ||
279 | + break | ||
280 | + if statefile is None: | ||
281 | + raise Exception("Did not find statefile, exiting!!") | ||
282 | + return statefile, run_id, result_folder | ||
283 | + | ||
284 | + | ||
285 | +if __name__ == "__main__": | ||
286 | + # test boxify_dict | ||
287 | + a = {"model.a": 1, "m odel.b": 2, "alpha": 3} | ||
288 | + print(boxify_dict(a)) | ||
289 | + | ||
290 | + try: | ||
291 | + a = {"model.a": 1, "model.b": 2, "model": 3} | ||
292 | + print(boxify_dict(a)) | ||
293 | + except Exception as e: | ||
294 | + print(e) | ||
295 | + | ||
296 | + try: | ||
297 | + a = {"model": 4, "model.a": 1, "model.b": 2, "model": 3} | ||
298 | + print(boxify_dict(a)) | ||
299 | + except Exception as e: | ||
300 | + print(e) | ||
301 | + | ||
302 | + try: | ||
303 | + a = {"model.a": 1, "model": 4, "model.b": 2, "model": 3} | ||
304 | + print(boxify_dict(a)) | ||
305 | + except Exception as e: | ||
306 | + print(e) | ||
307 | + | ||
308 | + try: | ||
309 | + a = {"model": {"attr1": 1, "attr2": {"attr_attr_3": 3}}, "train": 10} | ||
310 | + print(flatten(a)) | ||
311 | + except Exception as e: | ||
312 | + print(e) | ||
313 | + | ||
314 | + | ||
315 | +def set_seed(seed): | ||
316 | + if isinstance(seed, list): | ||
317 | + torch_seed, numpy_seed, random_seed = seed | ||
318 | + else: | ||
319 | + torch_seed, numpy_seed, random_seed = seed, seed, seed | ||
320 | + | ||
321 | + torch.manual_seed(torch_seed) | ||
322 | + numpy.random.seed(numpy_seed) | ||
323 | + random.seed(random_seed) |
-
Please register or login to post a comment