-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
import torch
from model import SSD300
from apex.parallel.LARC import LARC
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
def load_model_weight(model, model_path, map_location=None):
    """Load checkpoint weights into ``model`` and switch it to eval mode.

    The checkpoint must be a dict with a ``"model"`` entry. That entry may be
    either a full module object (the format this script was written for; its
    ``state_dict()`` is used) or a plain state_dict mapping.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        model_path: Path to the checkpoint file, passed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. For example, ``"cpu"`` loads
            a GPU-saved checkpoint on a CPU-only machine. ``None`` (default)
            keeps the previous behaviour.

    Returns:
        The same ``model`` object, in eval mode.

    Raises:
        KeyError: if the checkpoint has no ``"model"`` entry.
    """
    import inspect

    load_kwargs = {"map_location": map_location}
    # A pickled nn.Module can only be restored by full unpickling. PyTorch
    # >= 2.6 defaults to weights_only=True, which rejects such checkpoints,
    # so opt out explicitly where the flag exists. Older releases lack the
    # parameter and always unpickle fully.
    # SECURITY: full unpickling can execute arbitrary code - only load
    # checkpoints from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(model_path, **load_kwargs)

    entry = checkpoint["model"]
    # Accept both a pickled module and a bare state_dict mapping.
    state_dict = entry.state_dict() if hasattr(entry, "state_dict") else entry
    model.load_state_dict(state_dict)
    # Module.eval() returns the module itself.
    return model.eval()
def export_onnx_model(model, input_shape, onnx_path, input_names=None, output_names=None, dynamic_axes=None):
    """Trace a two-input ``model`` with all-ones dummy tensors and export ONNX.

    The model is called with two tensors:
    - a 3-channel input (presumably RGB);
    - a second input whose shape is ``input_shape``, read as
      ``(batch, channels, height, width)``.
    The 3-channel input shares the batch, height and width of the second one.

    NOTE(review): this reading matches the only visible caller, which passes
    (1, 1, 512, 640) and so reproduces the previously hard-coded shapes.
    Confirm against ``SSD300.forward``.

    Args:
        model: Module to export. Dummy inputs are created on the device of
            its first parameter (CPU if it has none).
        input_shape: 4-element ``(N, C, H, W)`` shape of the second input.
        onnx_path: Destination file for the ONNX graph.
        input_names: Optional names for the graph inputs.
        output_names: Optional names for the graph outputs.
        dynamic_axes: Optional dynamic-axis spec, forwarded to
            ``torch.onnx.export`` (previously accepted but silently dropped).

    Raises:
        ValueError: if ``input_shape`` does not have exactly 4 entries.
    """
    if len(input_shape) != 4:
        raise ValueError(
            f"input_shape must be (N, C, H, W), got {tuple(input_shape)}"
        )
    batch, second_channels, height, width = input_shape

    # Build the dummy inputs where the model lives instead of assuming CUDA,
    # so a CPU-resident model can be exported too.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    inputs_r = torch.ones((batch, 3, height, width), device=device)
    inputs_t = torch.ones((batch, second_channels, height, width), device=device)

    torch.onnx.export(
        model,
        (inputs_r, inputs_t),
        onnx_path,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )
if __name__ == "__main__":
    # Load the trained SSD300 checkpoint and export it to ONNX.
    checkpoint = "./checkpoint_ssd300.pth.tar083"
    onnx_path = "halfway.onnx"
    input_shape = (1, 1, 512, 640)

    # The positional 2 is presumably the number of classes - confirm against
    # model.SSD300.
    model = SSD300(2).to("cuda")
    model = load_model_weight(model, checkpoint)
    export_onnx_model(model, input_shape, onnx_path)
It was based on this site.
Metadata
Metadata
Assignees
Labels
No labels