# train.py (648 bytes)
import argparse
from configs import Config
from trainer import Trainer
from unet_trainer import UNetTrainer


def main(args, cfg):
    """Construct the trainer that matches the selected config and train it.

    Args:
        args: Parsed command-line namespace; ``args.config`` selects the
            trainer and the whole namespace is forwarded to its constructor.
        cfg: Loaded configuration object, forwarded to the trainer unchanged.
    """
    # The "segm" config gets the dedicated UNet trainer; every other config
    # falls back to the generic Trainer.
    trainer_cls = UNetTrainer if args.config == "segm" else Trainer
    runner = trainer_cls(args, cfg)
    runner.fit()


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the training entry point.

    Returns:
        A parser that accepts an optional positional ``config`` name
        (``"config"`` when omitted) and an optional ``--resume`` string
        (``None`` when omitted).
    """
    parser = argparse.ArgumentParser(description="Training custom model")
    # NOTE(review): forwarded to the trainer through ``args``; presumably a
    # checkpoint path to resume from. Confirm against Trainer/UNetTrainer.
    parser.add_argument("--resume", default=None, type=str, help="resume training")
    # nargs="?" is required for ``default`` to take effect. A plain positional
    # argument is mandatory, and argparse never applies its default.
    parser.add_argument(
        "config", nargs="?", default="config", type=str, help="config training"
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    # Resolved relative to the current working directory, so the script must be
    # launched from the directory that contains ``configs/``.
    config = Config(f"./configs/{args.config}.yaml")
    main(args, config)