# os.py 10.5 KB
""" general utility functions"""
import argparse
import importlib
import json
import logging
import os
import random
import re
import shutil
import sys
import typing
from argparse import ArgumentParser
from collections.abc import MutableMapping

import numpy
import torch
from box import Box

logger = logging.getLogger()


def listorstr(inp):
    """Cast the entries of ``inp`` to numbers where possible.

    A one-element sequence is unwrapped: its single entry is cast with
    ``try_cast`` and returned on its own. Otherwise every entry is cast and
    written back into ``inp``, and that same sequence is returned.
    """
    if len(inp) == 1:
        return try_cast(inp[0])

    # write the cast values back into the caller's sequence (in-place update)
    inp[:] = [try_cast(item) for item in inp]
    return inp


def try_cast(text):
    """Return ``text`` as an int, else as a float, else unchanged.

    The converters are tried in order (int first, then float); the first one
    that yields something other than ``None`` wins.
    """
    for convert in (try_int, try_float):
        converted = convert(text, None)
        if converted is not None:
            return converted
    return text


def try_float(text, default: typing.Optional[float] = 0.0):
    """Convert ``text`` to a float, falling back to ``default`` on failure.

    Args:
        text: value handed to ``float()``.
        default: value returned when the conversion raises; may be ``None``.

    Returns:
        ``float(text)`` if the conversion succeeds, otherwise ``default``.
    """
    try:
        return float(text)
    except Exception:
        # best-effort by design: any conversion failure yields the default
        return default


def try_int(text, default: typing.Optional[int] = 0):
    """Convert ``text`` with ``int()``; return ``default`` if that raises."""
    try:
        return int(text)
    except Exception:
        return default


def parse_args(parser: ArgumentParser) -> Box:
    """Parse the command line and merge it with defaults and an optional config file.

    Values are layered in increasing priority:
        1. defaults declared on the parser (action defaults, then parser defaults),
        2. the ``config`` mapping of the python file given through a ``config``
           argument, if the parser defines one and it is not ``None``,
        3. options explicitly passed on the command line.

    Args:
        parser: a fully configured ``ArgumentParser``.

    Returns:
        The merged settings, nested on "." by ``boxify_dict``.

    Raises:
        ImportError: if the config file cannot be loaded as a python module.
    """
    # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded
    import importlib.util

    # collect defaults; mirrors what ArgumentParser.parse_known_args does
    defaults = {}
    for action in parser._actions:
        if action.dest is not argparse.SUPPRESS and action.default is not argparse.SUPPRESS:
            defaults[action.dest] = action.default
    defaults.update(parser._defaults)

    # first pass: only needed to find out whether a config file was given
    args = parser.parse_args()
    config_file = vars(args).get("config")
    if config_file is not None:
        # load a .py config and let it override the parser defaults
        spec = importlib.util.spec_from_file_location("config", config_file)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load config file '{config_file}' as a python module")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        defaults.update(module.config)

    # second pass: with defaults and actions hidden, the namespace only holds
    # what was explicitly typed on the command line
    saved_defaults, saved_actions = parser._defaults, parser._actions
    parser._defaults = {}
    parser._actions = []
    try:
        args = parser.parse_args()
    finally:
        # put the parser back so it stays usable after this call
        parser._defaults, parser._actions = saved_defaults, saved_actions
    defaults.update(vars(args))

    return boxify_dict(defaults)


def boxify_dict(config):
    """Turn a flat dictionary into a nested ``Box`` by splitting keys on ".".

    a = {"model.a": 1, "model.b": 2, "alpha": 3} returns
    Box({"model": {"a": 1, "b": 2}, "alpha": 3})
    a = {"model.a": 1, "model.b": 2, "model": 3} raises an error

    An intermediate entry whose current value is ``None`` is treated as absent
    and replaced by a sub-dictionary.

    Raises:
        TypeError: if a key is used both as a plain value and as the parent of
            other keys.
    """
    new_config = {}
    for key in config:
        value = config[key]

        if "." not in key:
            if new_config.get(key) is not None:
                raise TypeError(f"Key '{key}' has values as well as child")
            new_config[key] = value
            continue

        *parents, leaf = key.split(".")
        node = new_config
        for part in parents:
            # create missing intermediate dictionaries on the way down
            if node.get(part) is None:
                node[part] = {}
            elif not isinstance(node[part], dict):
                raise TypeError(f"Key '{part}' has values as well as child")
            node = node[part]

        # a leaf must not silently replace a sub-dictionary built from other keys
        if isinstance(node.get(leaf), dict):
            raise TypeError(f"Key '{leaf}' has values as well as child")
        node[leaf] = value

    return Box(new_config)


# https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys
def flatten(d, parent_key='', sep='.'):
    """Collapse a nested mapping into a single-level ``Box``.

    Keys of nested mappings are joined to their parent key with ``sep``;
    ``parent_key`` is the prefix carried down by the recursion.
    """
    flat = {}
    for key, value in d.items():
        full_key = parent_key + sep + key if parent_key else key
        if not isinstance(value, MutableMapping):
            flat[full_key] = value
            continue
        # recurse into the sub-mapping and merge its flattened entries
        flat.update(flatten(value, full_key, sep=sep).items())
    return Box(flat)


def str2bool(v: typing.Union[bool, str, int]) -> bool:
    """Interpret ``v`` as a boolean.

    Args:
        v: a bool (returned unchanged), the integer 0 or 1, or one of the
            strings yes/true/t/y/1 and no/false/f/n/0 (case-insensitive).

    Returns:
        The boolean that ``v`` stands for.

    Raises:
        TypeError: if ``v`` is not one of the accepted values.
    """
    if isinstance(v, bool):
        return v
    if isinstance(v, int):
        # ints have no .lower(); only 0 and 1 are accepted
        if v in (0, 1):
            return bool(v)
        raise TypeError("Boolean value expected.")
    if isinstance(v, str):
        lowered = v.lower()
        if lowered in ("yes", "true", "t", "y", "1"):
            return True
        if lowered in ("no", "false", "f", "n", "0"):
            return False
    raise TypeError("Boolean value expected.")


def safe_isdir(dir_name):
    """Return True only if ``dir_name`` exists and is a directory."""
    if not os.path.exists(dir_name):
        return False
    return os.path.isdir(dir_name)


def safe_makedirs(dir_name):
    """Create ``dir_name`` (including parents); print the error instead of raising."""
    try:
        os.makedirs(dir_name)
    except OSError as err:
        # e.g. the directory already exists: report it and carry on
        print(err)


def jsonize(x: object) -> typing.Union[str, dict]:
    """Serialize ``x`` to a JSON string; return an empty dict if that fails."""
    try:
        serialized = json.dumps(x)
    except Exception:
        return {}
    else:
        return serialized


def copy_code(folder_to_copy, out_folder, replace=False):
    """Copy the directory tree ``folder_to_copy`` into ``out_folder``.

    The copy is named after ``folder_to_copy`` with every "/" replaced by "_".

    Args:
        folder_to_copy: directory to copy.
        out_folder: destination directory; created if it does not exist.
        replace: if True an existing copy is deleted first; otherwise a
            numeric suffix (``_1``, ``_2``, ...) is appended until the name is free.

    Raises:
        FileExistsError: if ``replace`` is True and the target name is taken by
            something that is not a directory.
        SystemExit: with a non-zero status if ``out_folder`` exists but is not
            a directory.
    """
    logger.info(f"copying {folder_to_copy} to {out_folder}")

    if not os.path.exists(out_folder):
        os.makedirs(out_folder)
    elif os.path.isdir(out_folder):
        logger.info(f"Not deleting existing result folder: {out_folder}")
    else:
        logger.error(f"{out_folder} is not a directory")
        # signal failure to the calling process (a bare exit() reports success)
        sys.exit(1)

    # replace / with _
    flat_name = re.sub("/", "_", folder_to_copy)
    folder_name = f"{out_folder}/{flat_name}"

    if replace:
        if os.path.exists(folder_name):
            if not os.path.isdir(folder_name):
                raise FileExistsError("There is a file with same name as folder")
            shutil.rmtree(folder_name)
    else:
        # create a new copy if something already exists
        suffix = 1
        candidate = folder_name
        while os.path.exists(candidate):
            candidate = f"{folder_name}_{suffix}"
            suffix += 1
        folder_name = candidate

    logger.info(f"Copying {folder_to_copy} to {folder_name}")
    shutil.copytree(folder_to_copy, folder_name)


def get_state_params(wandb_use, run_id, result_folder, statefile):
    """Search the result folder for the checkpoint and wandb run id to use.

    If we are not given run_id there are four cases:
        1. we want to restart the wandb run but did not look up run-id and/or statefile
        2. we want a fresh wandb run
        3. we are not using wandb at all and need to restart
        4. we are not using wandb and need a fresh run

    Case 1/3:
        - To restart a run, the result_folder name is expected to end with
          /run_<numeric>.
        - With wandb, the wandb folder inside it is listed and the run id is taken
          from the first entry (in reverse sorted order) holding current_model.pt.
        - Without wandb, current_model.pt is looked up directly inside the
          run_<numeric> folder and run id stays None.

    Case 2/4:
        - Anything else: a new folder <result_folder>/run_<num> is created, where
          <num> is one more than the largest number found in existing entries.

    Args:
        wandb_use: whether wandb is in use.
        run_id: wandb run id, or None to search for it.
        result_folder: folder holding the runs, or a specific run_<numeric> folder.
        statefile: checkpoint path, or None to search for it.

    Returns:
        Tuple ``(statefile, run_id, result_folder)`` with the resolved values.

    Raises:
        Exception: if a run id is known but no current_model.pt is found for it.
    """
    # this part searches for run_id, i.e. it only finds one when using wandb
    if run_id is None:
        # folder/run_<num> means resume; anything else means create a new run
        regex = r"^.*/?run_[0-9]+/?$"
        if re.match(regex, result_folder):
            if wandb_use:
                # search in wandb folder if it exists, else we want a new run
                if os.path.exists(f"{result_folder}/wandb/"):
                    # case 1
                    for folder in sorted(os.listdir(f"{result_folder}/wandb/"), reverse=True):
                        if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"):
                            run_id = folder.split("-")[-1]
                            logger.info(f"using run id {run_id}")
                            # stop at the first entry that holds a checkpoint;
                            # entries without one are skipped
                            break
            else:
                # case 3: not using wandb, search within run_<num> directory
                logger.info(f"not using wandb")
                if os.path.exists(f"{result_folder}/current_model.pt"):
                    statefile = f"{result_folder}/current_model.pt"
                    logger.info(f"using statefile {statefile}")
                # otherwise just start a new run
        else:
            # case 2/4: trailing part is not run_<num>, so create a fresh run folder
            # a result folder that does not exist yet simply has no previous runs
            existing = os.listdir(result_folder) if os.path.exists(result_folder) else []
            last_run_num = max([0] + [try_int(name[-4:]) for name in existing]) + 1
            result_folder = f"{result_folder}/run_{last_run_num:04d}"
            logger.info(f"Creating new run with {result_folder}")
            safe_makedirs(result_folder)

    # search for last checkpoint in case --statefile is none and we are resuming
    if run_id is not None and statefile is None:
        folders = sorted(os.listdir(f"{result_folder}/wandb"), reverse=True)
        for folder in folders:
            if run_id in folder:
                # check for current_model.pt
                if os.path.exists(f"{result_folder}/wandb/{folder}/current_model.pt"):
                    statefile = f"{result_folder}/wandb/{folder}/current_model.pt"
                    logger.info(f"Using state file {statefile} and run id {run_id}")
                    break
        if statefile is None:
            raise Exception("Did not find statefile, exiting!!")
    return statefile, run_id, result_folder


if __name__ == "__main__":
    # smoke test for boxify_dict: this first case is expected to succeed
    print(boxify_dict({"model.a": 1, "m   odel.b": 2, "alpha": 3}))

    # remaining boxify_dict cases; a raised error is printed, not propagated
    # (a repeated key in a dict literal keeps only its last value)
    boxify_samples = (
        {"model.a": 1, "model.b": 2, "model": 3},
        {"model": 4, "model.a": 1, "model.b": 2, "model": 3},
        {"model.a": 1, "model": 4, "model.b": 2, "model": 3},
    )
    for sample in boxify_samples:
        try:
            print(boxify_dict(sample))
        except Exception as err:
            print(err)

    # smoke test for flatten
    try:
        print(flatten({"model": {"attr1": 1, "attr2": {"attr_attr_3": 3}}, "train": 10}))
    except Exception as err:
        print(err)


def set_seed(seed):
    """Seed torch, numpy and python's ``random`` module.

    Args:
        seed: a single seed used for all three generators, or a list/tuple of
            three seeds in the order (torch, numpy, random).
    """
    if isinstance(seed, (list, tuple)):
        torch_seed, numpy_seed, random_seed = seed
    else:
        torch_seed = numpy_seed = random_seed = seed

    torch.manual_seed(torch_seed)
    numpy.random.seed(numpy_seed)
    random.seed(random_seed)