make_script.py
599 Bytes
from model import mobilenetv3
import torch
import torch.nn as nn
# Export a trained MobileNetV3 checkpoint as a TorchScript file.
#
# Steps: build the network, load the weights stored under 'state_dict' in
# the checkpoint, trace the network on a random example input, and save the
# traced module to disk.

CHECKPOINT_PATH = 'output/All/48860_model=MobilenetV3-ep=3000-block=6-class=8/model_best.pth.tar'
TRACE_INPUT_SHAPE = (256, 1, 224, 224)
SCRIPT_PATH = "mobilenetv3.pt"

# Everything below runs on the CPU.
device = torch.device('cpu')

# The state dict is loaded through a DataParallel wrapper.
# NOTE(review): presumably the checkpoint was saved from a DataParallel
# model, so its keys carry the "module." prefix -- confirm.
model = nn.DataParallel(mobilenetv3(n_class=8, blocknum=6))

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['state_dict'])

# Move to the target device and switch to evaluation mode before tracing.
model = model.to(device).eval()

# Random example input; it is only used to drive the trace.
x = torch.randn(*TRACE_INPUT_SHAPE)
print(x.shape)

# Trace the wrapped network (model.module), not the DataParallel wrapper,
# and write the resulting TorchScript module to disk.
jit_model = torch.jit.trace(model.module, x)
jit_model.save(SCRIPT_PATH)
# Sanity check: uncomment the two lines below to confirm the traced model runs.
#output = jit_model(torch.ones(3,1,224,224))
#print(output)