Hyunji

logging

1 +""" logging related functionality """
2 +
3 +
4 +import logging
5 +
6 +from torch.utils.tensorboard import SummaryWriter
7 +
8 +logger = logging.getLogger()
9 +
10 +
11 +def print_verbose(string, verbose):
12 + if verbose:
13 + print(string)
14 +
15 +
16 +def loss_logger_helper(
17 + loss, aux_loss, writer: SummaryWriter, step: int, epoch: int, log_every: int,
18 + string: str = "train", force_print: bool = False, new_line: bool = False
19 +):
20 + # write to tensorboard at every step but only print at log step or when force_print is passed
21 + writer.add_scalar(f"{string}/loss", loss, step)
22 + for k, v in aux_loss.items():
23 + writer.add_scalar(f"{string}/" + k, v, step)
24 +
25 + if step % log_every == 0 or force_print:
26 + logger.info(f"{string}/loss: ({step}/{epoch}) {loss}")
27 +
28 +
29 +
30 + if force_print:
31 + if new_line:
32 + for k, v in aux_loss.items():
33 + logger.info(f"{string}/{k}:{v} ")
34 + else:
35 + str_ = ""
36 + for k, v in aux_loss.items():
37 + str_ += f"{string}/{k}:{v} "
38 + logger.info(f"{str_}")