# make_script.py: trace a trained MobileNetV3 checkpoint and save it as a TorchScript module.
from model import mobilenetv3
import torch
import torch.nn as nn

# Checkpoint produced by training; the directory name encodes the run settings.
CHECKPOINT_PATH = (
    "output/All/48860_model=MobilenetV3-ep=3000-block=6-class=8/"
    "model_best.pth.tar"
)
# Where the traced TorchScript module is written.
OUTPUT_PATH = "mobilenetv3.pt"
# Shape of the random example tensor used for tracing.
# Layout assumed to be (batch, channels, height, width) -- TODO confirm with model.py.
EXAMPLE_SHAPE = (256, 1, 224, 224)


def load_model(checkpoint_path=CHECKPOINT_PATH, device=None):
    """Build MobileNetV3 (8 classes, 6 blocks) and load trained weights.

    Args:
        checkpoint_path: Path to a checkpoint dict containing a ``"state_dict"`` key.
        device: Target device; defaults to CPU.

    Returns:
        The bare (not DataParallel-wrapped) model, moved to *device*, in eval mode.
    """
    if device is None:
        device = torch.device("cpu")

    net = mobilenetv3(n_class=8, blocknum=6)
    # The state_dict is loaded through a DataParallel wrapper, presumably because
    # the checkpoint was saved from one (keys prefixed "module.") -- TODO confirm.
    wrapper = nn.DataParallel(net)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    wrapper.load_state_dict(checkpoint["state_dict"])

    # wrapper.module is `net` itself; trace the unwrapped model, not the wrapper.
    net = wrapper.module.to(device)
    net.eval()
    return net


def trace_model(model, example_shape=EXAMPLE_SHAPE, device=None):
    """Trace *model* with a random input of *example_shape*; return the ScriptModule."""
    example = torch.randn(*example_shape, device=device)
    print(example.shape)
    # Tracing only needs a forward pass. Without no_grad, autograd keeps every
    # intermediate activation of the whole example batch alive during the trace.
    with torch.no_grad():
        return torch.jit.trace(model, example)


def check_traced(model, traced, batch_size=3, device=None):
    """Compare traced and eager outputs on a batch size different from the trace.

    Prints the result and returns True when the outputs match (or the check is
    skipped because the output is not a single tensor), False otherwise.
    """
    probe = torch.ones(batch_size, *EXAMPLE_SHAPE[1:], device=device)
    with torch.no_grad():
        expected = model(probe)
        actual = traced(probe)
    if not isinstance(expected, torch.Tensor):
        print("Skipping traced-output check: model output is not a single tensor.")
        return True
    if expected.shape != actual.shape:
        print(f"Traced model output shape {tuple(actual.shape)} != eager {tuple(expected.shape)}")
        return False
    max_diff = (expected - actual).abs().max().item()
    ok = torch.allclose(expected, actual, atol=1e-5)
    print(f"Traced vs eager (batch={batch_size}): max abs diff {max_diff:.3e} -> {'OK' if ok else 'MISMATCH'}")
    return ok


def main():
    """Load the checkpoint, trace the model, save it, then sanity-check the trace."""
    device = torch.device("cpu")
    model = load_model(CHECKPOINT_PATH, device)
    traced = trace_model(model, EXAMPLE_SHAPE, device)
    traced.save(OUTPUT_PATH)
    check_traced(model, traced, device=device)


if __name__ == "__main__":
    main()