Hyunji

os

1 +""" general utility functions"""
2 +import argparse
3 +import importlib
4 +import json
5 +import logging
6 +import os
7 +import random
8 +import re
9 +import shutil
10 +import sys
11 +import typing
12 +from argparse import ArgumentParser
13 +from collections.abc import MutableMapping
14 +
15 +import numpy
16 +import torch
17 +from box import Box
18 +
# Root logger (getLogger() is called without a name); the helpers below log through it.
logger = logging.getLogger()
20 +
21 +
def listorstr(inp):
    """Cast every element of ``inp`` with ``try_cast``.

    A one-element sequence is unwrapped: the cast value of that single element
    is returned directly. Any other length yields a list of cast values.

    The input is left untouched; a new list is built instead of overwriting
    the caller's sequence, so immutable sequences (tuples) work as well.
    """
    if len(inp) == 1:
        return try_cast(inp[0])
    return [try_cast(val) for val in inp]
29 +
30 +
def try_cast(text):
    """Try to cast to int or float if possible, else return the input itself.

    ``int`` is attempted first, then ``float``. Values that are already
    numeric are returned unchanged, otherwise the int attempt would silently
    truncate a float such as ``1.5`` to ``1``.
    """
    if isinstance(text, (int, float)):
        return text

    for caster in (int, float):
        try:
            return caster(text)
        except (TypeError, ValueError, OverflowError):
            # not convertible with this caster; fall through to the next one
            continue

    return text
42 +
43 +
def try_float(text, default: typing.Optional[float] = 0.0):
    """Convert ``text`` to ``float``; return ``default`` if it cannot be converted.

    Only the errors ``float()`` raises for unconvertible input are absorbed
    (wrong type, malformed text, integer too large for a float).
    """
    try:
        return float(text)
    except (TypeError, ValueError, OverflowError):
        return default
51 +
52 +
def try_int(text, default: typing.Optional[int] = 0):
    """Convert ``text`` to ``int``; return ``default`` if it cannot be converted.

    Only the errors ``int()`` raises for unconvertible input are absorbed
    (wrong type, malformed text, NaN or infinite floats).
    """
    try:
        return int(text)
    except (TypeError, ValueError, OverflowError):
        return default
60 +
61 +
def parse_args(parser: ArgumentParser) -> Box:
    """Resolve options from parser defaults, an optional config file and the CLI.

    Precedence, lowest to highest:
      1. defaults declared on the parser (per-action and parser-level),
      2. the ``config`` mapping of the Python file given through a ``config``
         argument, if the parser defines one and it is set,
      3. arguments explicitly present on the command line.

    The merged flat dictionary is passed through ``boxify_dict`` so dotted
    keys become nested entries.

    NOTE: the config file is executed as Python code; only point this at
    trusted files.
    """
    # ``import importlib`` alone does not guarantee that the ``util`` submodule
    # is loaded, so import it explicitly before using it.
    import importlib.util

    # get defaults
    defaults = {}
    # taken from parser_known_args code
    # add any action defaults that aren't present
    for action in parser._actions:
        if action.dest is not argparse.SUPPRESS:
            if action.default is not argparse.SUPPRESS:
                defaults[action.dest] = action.default

    # add any parser defaults that aren't present
    defaults.update(parser._defaults)

    # check if there is config & read config
    args = parser.parse_args()
    config_file = vars(args).get("config")
    if config_file is not None:
        # load a .py config
        spec = importlib.util.spec_from_file_location("config", config_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # merge config and override defaults
        defaults.update(module.config.items())

    # override defaults with command line params
    # Parsing with the defaults and the action list emptied yields a namespace
    # that only holds what was typed on the command line. The parser state is
    # restored afterwards so the caller's parser stays usable.
    saved_defaults, saved_actions = parser._defaults, parser._actions
    parser._defaults = {}
    parser._actions = []
    try:
        args = parser.parse_args()
    finally:
        parser._defaults, parser._actions = saved_defaults, saved_actions
    defaults.update(vars(args))

    return boxify_dict(defaults)
96 +
97 +
def boxify_dict(config):
    """
    Take a flat dictionary and break it into sub-dictionaries based on "." separation.

    a = {"model.a": 1, "model.b": 2, "alpha": 3} returns
        Box({"model": {"a": 1, "b": 2}, "alpha": 3})
    a = {"model.a": 1, "model.b": 2, "model": 3} raises TypeError

    A key that carries both a value and children is rejected regardless of the
    order in which the keys appear. An entry whose value is None is treated as
    unset and may be replaced by children.

    Raises:
        TypeError: if a key has a value as well as children.
    """
    new_config = {}
    # iterate over keys and split on "."
    for key in config:
        value = config[key]
        if "." in key:
            *parents, leaf = key.split(".")
            node = new_config
            for part in parents:
                # create non-existent keys as dictionary recursively
                if node.get(part) is None:
                    node[part] = {}
                elif not isinstance(node[part], dict):
                    raise TypeError(f"Key '{part}' has values as well as child")
                node = node[part]
            # the leaf can only exist already if a longer key created a
            # subtree here; overwriting it would silently drop those children
            if node.get(leaf) is not None:
                raise TypeError(f"Key '{key}' has values as well as child")
            node[leaf] = value
        else:
            if new_config.get(key) is None:
                new_config[key] = value
            else:
                raise TypeError(f"Key '{key}' has values as well as child")

    return Box(new_config)
125 +
126 +
# https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys
def flatten(d, parent_key='', sep='.'):
    """Collapse a nested mapping into a single-level Box.

    Keys of nested mappings are joined to their parents' keys with ``sep``;
    ``parent_key`` is the prefix accumulated so far (empty at the top level).
    """
    flat = {}
    for name, value in d.items():
        path = parent_key + sep + name if parent_key else name
        if not isinstance(value, MutableMapping):
            flat[path] = value
            continue
        # nested mapping: recurse and merge its already-flattened entries
        flat.update(flatten(value, path, sep=sep).items())
    return Box(flat)
137 +
138 +
def str2bool(v: typing.Union[bool, str, int]) -> bool:
    """Interpret ``v`` as a boolean.

    Booleans are returned as-is. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0. The integers 1 and 0 are
    accepted too.

    Raises:
        TypeError: if the value is not one of the recognised spellings.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, int):
        # ints have no .lower(); compare through their decimal spelling
        v = str(v)
    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise TypeError("Boolean value expected.")
147 +
148 +
def safe_isdir(dir_name):
    """Return True only if ``dir_name`` exists and is a directory."""
    if not os.path.exists(dir_name):
        return False
    return os.path.isdir(dir_name)
151 +
152 +
def safe_makedirs(dir_name):
    """Create ``dir_name`` and any missing parents, never raising on failure.

    An already existing directory is not an error. Any other OS-level failure
    is reported through the module logger instead of being raised.
    """
    try:
        os.makedirs(dir_name, exist_ok=True)
    except OSError as e:
        logger.warning(f"could not create directory {dir_name}: {e}")
158 +
159 +
def jsonize(x: object) -> typing.Union[str, dict]:
    """Serialise ``x`` to a JSON string; return an empty dict if that fails."""
    try:
        return json.dumps(x)
    except Exception:
        return {}
166 +
167 +
def copy_code(folder_to_copy, out_folder, replace=False):
    """Copy the tree ``folder_to_copy`` into ``out_folder``.

    The copy is stored as ``<out_folder>/<folder_to_copy with "/" replaced by "_">``.
    ``out_folder`` is created if missing. If it exists but is not a directory,
    the process exits with a non-zero status.

    Args:
        folder_to_copy: source directory.
        out_folder: directory that receives the copy.
        replace: when False, an existing copy is kept and the new one gets a
            numeric suffix (``_1``, ``_2``, ...); when True, an existing copy
            is deleted first.

    Raises:
        FileExistsError: if ``replace`` is True and the target name is taken
            by a regular file.
    """
    logger.info(f"copying {folder_to_copy} to {out_folder}")

    if os.path.exists(out_folder):
        if not os.path.isdir(out_folder):
            logger.error(f"{out_folder} is not a directory")
            # a bare sys.exit() reports success (status 0); signal the failure
            sys.exit(1)
        logger.info(f"Not deleting existing result folder: {out_folder}")
    else:
        os.makedirs(out_folder)

    # replace / with _
    folder_name = f'{out_folder}/{re.sub("/", "_", folder_to_copy)}'

    if not replace:
        # create a new copy if something already exists
        i = 1
        candidate = folder_name
        while os.path.exists(candidate):
            candidate = f"{folder_name}_{i}"
            i += 1
        folder_name = candidate
    elif os.path.exists(folder_name):
        if not os.path.isdir(folder_name):
            raise FileExistsError("There is a file with same name as folder")
        shutil.rmtree(folder_name)

    logger.info(f"Copying {folder_to_copy} to {folder_name}")
    shutil.copytree(folder_to_copy, folder_name)
200 +
201 +
def get_state_params(wandb_use, run_id, result_folder, statefile):
    """Search the result folder for the state file and run id.

    If no run_id is given there are four cases:
        1. we want to resume a wandb run but did not look up the run id
           and/or the state file
        2. we want a fresh wandb run
        3. we are not using wandb and want to resume
        4. we are not using wandb and want a fresh run

    Case 1/3:
        - To resume, result_folder is expected to end with /run_<numeric>.
        - With wandb, the sub-folders of <result_folder>/wandb are scanned
          (in reverse sorted order) for one holding current_model.pt; the run
          id is the part of that folder name after the last "-".
        - Without wandb, current_model.pt is looked up directly inside the
          run_<numeric> folder; run id stays None.

    Case 2/4:
        - Anything else: a new folder <result_folder>/run_<NNNN> is created,
          numbered one above the highest existing run number.

    This is meant to be fail-safe, i.e. either of run_id or statefile may be
    left unspecified.

    Returns:
        Tuple ``(statefile, run_id, result_folder)``; ``result_folder`` is the
        newly created run folder in case 2/4.

    Raises:
        Exception: if a run id is known but no state file could be found.
    """
    # if not resume get run number and create result_folder/run_{run_num}
    # if someone is resuming we expect them to give the exact folder name up to run num.

    # this part of code searches for run_id i.e. will work only if we are using wandb
    if run_id is None:
        # if result folder is of type folder/run_<num>, then search for current checkpoint and
        # run-id, else we will just create a new run with run_<num+1>

        regex = r"^.*/?run_[0-9]+/?$"
        if re.match(regex, result_folder):

            # search for checkpoint and run-id if using wandb
            if wandb_use:
                # search in wandb folder if it exists else we want a new run
                if os.path.exists(f"{result_folder}/wandb/"):
                    # case 1
                    for folder in sorted(os.listdir(f"{result_folder}/wandb/"), reverse=True):
                        # assume run_<##> holds a single run and nothing else
                        if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"):
                            run_id = folder.split("-")[-1]
                            logger.info(f"using run id {run_id}")
                            # we are done, break out of the for loop
                            break
            else:
                # case 3
                # if not using wandb search within run_<num> directory
                logger.info("not using wandb")
                if os.path.exists(f"{result_folder}/current_model.pt"):
                    statefile = f"{result_folder}/current_model.pt"
                    logger.info(f"using statefile {statefile}")
                # otherwise just start a new run
        else:
            # trailing part is not run_<num>; that means the user wants a fresh run
            # so we give a fresh run and create a new folder
            # case 2/4
            # A result folder that does not exist yet simply has no previous
            # runs; listing it unguarded would raise before it gets created.
            existing = os.listdir(result_folder) if os.path.exists(result_folder) else []
            last_run_num = max([0] + [try_int(i[-4:]) for i in existing]) + 1
            result_folder = f"{result_folder}/run_{last_run_num:04d}"
            logger.info(f"Creating new run with {result_folder}")
            safe_makedirs(result_folder)

    # search for last checkpoint in case --statefile is none and we are resuming
    if run_id is not None and statefile is None:
        folders = sorted(os.listdir(f"{result_folder}/wandb"), reverse=True)
        for folder in folders:
            if run_id in folder:
                # check for current_model.pt
                if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"):
                    statefile = f"{result_folder}/wandb/{folder}/current_model.pt"
                    logger.info(f"Using state file {statefile} and run id {run_id}")
                    break
        if statefile is None:
            raise Exception("Did not find statefile, exiting!!")
    return statefile, run_id, result_folder
283 +
284 +
if __name__ == "__main__":
    # Manual smoke test for boxify_dict and flatten.

    # valid input: dotted keys become nested entries
    a = {"model.a": 1, "model.b": 2, "alpha": 3}
    print(boxify_dict(a))

    # "model" has a value as well as children: an error is expected whatever
    # the position of the plain "model" key. Each key appears once, since a
    # dict literal silently drops repeated keys.
    for a in (
        {"model.a": 1, "model.b": 2, "model": 3},
        {"model": 3, "model.a": 1, "model.b": 2},
        {"model.a": 1, "model": 3, "model.b": 2},
    ):
        try:
            print(boxify_dict(a))
        except Exception as e:
            print(e)

    # flatten is the inverse direction: nested dict -> dotted keys
    try:
        a = {"model": {"attr1": 1, "attr2": {"attr_attr_3": 3}}, "train": 10}
        print(flatten(a))
    except Exception as e:
        print(e)
313 +
314 +
def set_seed(seed):
    """Seed torch, numpy and Python's ``random`` module.

    Args:
        seed: either a single value used for all three generators, or a list
            or tuple of three values ``(torch_seed, numpy_seed, random_seed)``.
    """
    if isinstance(seed, (list, tuple)):
        torch_seed, numpy_seed, random_seed = seed
    else:
        torch_seed = numpy_seed = random_seed = seed

    torch.manual_seed(torch_seed)
    numpy.random.seed(numpy_seed)
    random.seed(random_seed)