Showing
1 changed file
with
33 additions
and
0 deletions
src/models/regression.py
0 → 100644
1 | +import torch | ||
2 | + | ||
3 | +from lib.base_model import Base as BaseModel | ||
4 | + | ||
5 | + | ||
6 | +class Regression(BaseModel): | ||
7 | + | ||
8 | + def __init__(self, net): | ||
9 | + super().__init__() | ||
10 | + self.net = net | ||
11 | + | ||
12 | + def forward(self, batch): | ||
13 | + | ||
14 | + return self.net(batch[0].to(self.device)) | ||
15 | + | ||
16 | + def loss(self, pred, batch, reduce=True): | ||
17 | + ret_obj = {} | ||
18 | + y = batch[1].to(self.device).float() | ||
19 | + N = y.shape[0] | ||
20 | + y = y.reshape(N, -1) | ||
21 | + y_pred = pred.y_pred.reshape(N, -1) | ||
22 | + loss = torch.nn.functional.mse_loss(y_pred, y, reduction="none").sum(dim=1) | ||
23 | + | ||
24 | + mae = torch.abs(y_pred - y).sum(dim=1) | ||
25 | + | ||
26 | + if reduce: | ||
27 | + #print(sum(y[0])/len(y[[0]])) | ||
28 | + #print(sum(y_pred[0])/len(y_pred[0])) | ||
29 | + #print(sum((y_pred/y)[0])/len((y_pred/y)[0])) | ||
30 | + mae = mae.mean() | ||
31 | + loss = loss.mean() | ||
32 | + | ||
33 | + return loss, {"mse": loss, "mae": mae} |
-
Please register or login to post a comment