Save/Load MXNet model parameters using NumPy

Submitted by 不问归期 on 2021-01-28 09:36:06

Question


How can I save the parameters for an MXNet model into a NumPy file (.npy)? After doing so, how can I load these parameters from the .npy file back into my model?

Here is a minimal example that saves the parameters of an MXNet model using the MXNet API.

import mxnet as mx
from mxnet.gluon.model_zoo import vision

# Run on CPU unless num_gpus > 0
num_gpus = 0
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

# Build ResNet-50 v2 and download its pretrained ImageNet weights
resnet = vision.resnet50_v2(pretrained=True, ctx=ctx)

# Write every parameter to MXNet's own binary .params format
resnet.save_parameters('model.params')
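The same parameters can also be handed to NumPy directly. A minimal sketch, assuming the MXNet 1.x Gluon API, where each `Parameter` from `collect_params()` exposes `data()` (an NDArray) and NDArray has `asnumpy()`. A single .npy file holds only one array, so this sketch swaps in NumPy's .npz archive, which stores many arrays under their names. The helper `save_params_npz` and the demo arrays are hypothetical; the MXNet conversion appears only in a comment so the block runs with NumPy alone.

```python
import os
import tempfile

import numpy as np


def save_params_npz(arrays, path):
    """Save a mapping of parameter name -> NumPy array to an .npz archive.

    With MXNet Gluon (assumed 1.x API), the mapping could be built as:
        arrays = {name: p.data().asnumpy()
                  for name, p in resnet.collect_params().items()}
    """
    # np.savez stores each keyword argument as a separate named .npy member
    np.savez(path, **arrays)


# Demo with made-up arrays standing in for real network weights
demo = {
    "conv0_weight": np.ones((4, 3, 3, 3), dtype=np.float32),
    "dense0_bias": np.zeros(10, dtype=np.float32),
}
tmpdir = tempfile.mkdtemp()
npz_path = os.path.join(tmpdir, "model_params.npz")
save_params_npz(demo, npz_path)

with np.load(npz_path) as archive:
    print(sorted(archive.files))  # ['conv0_weight', 'dense0_bias']
```

Keeping the parameter names as archive keys means the file can later be matched back to the model by name rather than by position.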

And here is a minimal example that loads the parameters back into the model from that file using the MXNet API.

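import mxnet as mx
from mxnet.gluon.model_zoo import vision

# Run on CPU unless num_gpus > 0
num_gpus = 0
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]

# Build the same architecture; pretrained=False skips the download because
# the weights are replaced from the file on the next line anyway
resnet = vision.resnet50_v2(pretrained=False, ctx=ctx)

# Load the saved parameters into the network on the chosen context(s)
resnet.load_parameters('model.params', ctx=ctx)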
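For the loading direction, a minimal sketch under the same assumption (MXNet 1.x Gluon, where `Parameter.set_data()` accepts an NDArray and `mx.nd.array()` wraps a NumPy array) reads an .npz archive back into a name -> array dict and checks every name and shape before anything is assigned. `load_params_npz` and `expected_shapes` are hypothetical names; with a real, already initialized model the expected shapes could come from `p.shape` for each parameter in `collect_params()`.

```python
import os
import tempfile

import numpy as np


def load_params_npz(path, expected_shapes):
    """Load name -> array pairs from an .npz archive and validate them.

    expected_shapes maps each parameter name to its shape. Raises KeyError or
    ValueError instead of leaving a model partially loaded.
    """
    with np.load(path) as archive:
        arrays = {name: archive[name] for name in archive.files}

    # Names must match exactly: nothing missing, nothing extra
    missing = set(expected_shapes) - set(arrays)
    unexpected = set(arrays) - set(expected_shapes)
    if missing or unexpected:
        raise KeyError(f"missing={sorted(missing)}, unexpected={sorted(unexpected)}")

    # Shapes must match what the model expects
    for name, shape in expected_shapes.items():
        if arrays[name].shape != tuple(shape):
            raise ValueError(
                f"{name}: file has {arrays[name].shape}, model expects {tuple(shape)}"
            )

    # With MXNet (assumed 1.x API), each array could then be pushed back, e.g.
    #   for name, p in net.collect_params().items():
    #       p.set_data(mx.nd.array(arrays[name]))
    return arrays


# Demo: write a small archive, then load it against matching shapes
tmpdir = tempfile.mkdtemp()
npz_path = os.path.join(tmpdir, "model_params.npz")
np.savez(
    npz_path,
    fc_weight=np.arange(6, dtype=np.float32).reshape(2, 3),
    fc_bias=np.zeros(2, dtype=np.float32),
)

params = load_params_npz(npz_path, {"fc_weight": (2, 3), "fc_bias": (2,)})
print(params["fc_weight"].sum())  # 15.0
```

Validating first and assigning afterwards avoids ending up with a network where only some layers received the new weights.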

In both examples above, I use the MXNet API to save and load the model parameters. Instead, I want to save the parameters with NumPy, load them back with NumPy, and then use them in my MXNet model. How can I do that?
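If a single .npy file is a hard requirement, one hedged option is to let `np.save` pickle a whole name -> array dict into a 0-d object array; reading it back then needs `allow_pickle=True` and `.item()` to recover the dict. Pickled data can run code when loaded, so this sketch assumes the file comes from a trusted source. The file name and demo arrays are illustrative only, and the MXNet line in the comment assumes the 1.x Gluon API.

```python
import os
import tempfile

import numpy as np

# Illustrative parameter dict; with MXNet Gluon (assumed 1.x API) it could be
# {name: p.data().asnumpy() for name, p in net.collect_params().items()}
params = {
    "conv0_weight": np.full((2, 2), 0.5, dtype=np.float32),
    "conv0_bias": np.array([1.0, -1.0], dtype=np.float32),
}

path = os.path.join(tempfile.mkdtemp(), "model_params.npy")

# np.save wraps the dict in a 0-d object array and pickles it
np.save(path, params, allow_pickle=True)

# allow_pickle=True is required to read object arrays; .item() unwraps the dict
restored = np.load(path, allow_pickle=True).item()
print(sorted(restored))  # ['conv0_bias', 'conv0_weight']
```

Compared with an .npz archive, this keeps everything in one .npy file at the cost of depending on pickle.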

Source: https://stackoverflow.com/questions/62942031/save-load-mxnet-model-parameters-using-numpy
