MXNet Python Model API

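The model API provides a simplified way to train neural networks using common best practices. It is a thin wrapper built on top of the ndarray and symbol modules that handles the routine parts of training, such as running the training loop, saving models, checkpointing, and spreading work across devices.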


Train the Model

To train a model, perform two steps: define the network as a symbol, then pass that symbol to mx.model.FeedForward.create, which creates and trains the model. The following example configures a two-layer neural network.

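    import mxnet as mx

    # configure a two-layer neural network
    data = mx.symbol.Variable('data')
    fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
    act1 = mx.symbol.Activation(fc1, name='relu1', act_type='relu')
    fc2 = mx.symbol.FullyConnected(act1, name='fc2', num_hidden=64)
    softmax = mx.symbol.SoftmaxOutput(fc2, name='sm')
    # create and train a model; data_set (your training data, e.g. an
    # mx.io.DataIter) and num_epoch are placeholders you supply
    model = mx.model.FeedForward.create(
        softmax,
        X=data_set,
        num_epoch=num_epoch,
        learning_rate=0.01)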
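To make the symbol graph above concrete, here is a minimal pure-Python sketch of the computation it describes for one input vector: a fully connected layer, a ReLU, a second fully connected layer, and a softmax. The layer sizes are shrunk and the weights are hypothetical hand-picked values (named after the fc1/fc2 layers above); this illustrates the math only and is not how MXNet executes the graph.

```python
import math

def fully_connected(x, weight, bias):
    # one weight row per output unit, analogous to FullyConnected's num_hidden
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]

def relu(x):
    return [max(0.0, v) for v in x]

def softmax(x):
    # subtract the max before exponentiating for numerical stability
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def forward(x, params):
    # fc1 -> relu1 -> fc2 -> softmax, mirroring the symbol definition
    h = relu(fully_connected(x, params["fc1_weight"], params["fc1_bias"]))
    return softmax(fully_connected(h, params["fc2_weight"], params["fc2_bias"]))

# hypothetical toy parameters: 3 inputs -> 4 hidden units -> 2 classes
params = {
    "fc1_weight": [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5],
                   [-0.6, 0.1, 0.4], [0.2, 0.2, 0.2]],
    "fc1_bias": [0.0, 0.1, 0.0, -0.1],
    "fc2_weight": [[1.0, -1.0, 0.5, 0.0], [-0.5, 1.0, 0.0, 0.5]],
    "fc2_bias": [0.0, 0.0],
}
probs = forward([1.0, 2.0, 3.0], params)
print(probs)
```

The output is a probability vector over the classes; in the real network above it would have 64 entries, one per unit of fc2.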

You can also create a model in the scikit-learn style: construct a FeedForward object first, then call its fit method to train it.

    # create a model using the sklearn-style two-step way
    model = mx.model.FeedForward(softmax, num_epoch=num_epoch, learning_rate=0.01)
    model.fit(X=data_set)
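The two creation styles differ only in when training happens. The toy sketch below uses a hypothetical ToyModel class (not an MXNet class) that fits y ≈ w·x with plain SGD; its create classmethod constructs and trains in one call, like FeedForward.create, while the constructor-plus-fit path mirrors the scikit-learn style.

```python
class ToyModel:
    """Hypothetical trainer supporting both creation styles.

    Fits y ~ w * x with plain SGD on squared error.
    """

    def __init__(self, num_epoch=10, learning_rate=0.01):
        # step 1: store hyperparameters only; nothing is trained yet
        self.num_epoch = num_epoch
        self.learning_rate = learning_rate
        self.w = 0.0

    def fit(self, X, y):
        # step 2: run the training loop over the data
        for _ in range(self.num_epoch):
            for xi, yi in zip(X, y):
                grad = 2.0 * (self.w * xi - yi) * xi  # d/dw of (w*x - y)^2
                self.w -= self.learning_rate * grad
        return self

    @classmethod
    def create(cls, X, y, **kwargs):
        # one-step style: construct and fit in a single call
        return cls(**kwargs).fit(X, y)

X = [1.0, 2.0, 3.0, 4.0]
y = [2.0, 4.0, 6.0, 8.0]  # generated with w = 2

one_step = ToyModel.create(X, y, num_epoch=50, learning_rate=0.01)
two_step = ToyModel(num_epoch=50, learning_rate=0.01)
two_step.fit(X, y)
print(one_step.w, two_step.w)
```

Both paths run the same training loop, so they produce the same weight; the two-step form is convenient when you want to configure a model in one place and train it later.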

For more information, see Model API Reference.

Save the Model

When training is done, save your work. You can pickle the model directly with Python, or use the save and load functions that MXNet provides, which store the network structure and the learned parameters in separate files.

    # save a model to mymodel-symbol.json and mymodel-0100.params
    prefix = 'mymodel'
    iteration = 100
    model.save(prefix, iteration)

    # load model back
    model_loaded = mx.model.FeedForward.load(prefix, iteration)
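Splitting the structure from the weights is what keeps these files portable. The sketch below reproduces the file-naming scheme from the comment above (prefix-symbol.json and prefix-NNNN.params) with hypothetical save/load helpers. As a simplifying assumption it writes both files as JSON; real MXNet stores the .params file in a binary NDArray format.

```python
import json
import os
import tempfile

def checkpoint_paths(prefix, iteration):
    # naming scheme: mymodel-symbol.json and mymodel-0100.params
    return "%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, iteration)

def save(prefix, iteration, graph, params):
    # hypothetical helper: structure and weights go to separate files
    sym_path, param_path = checkpoint_paths(prefix, iteration)
    with open(sym_path, "w") as f:
        json.dump(graph, f)   # network structure only
    with open(param_path, "w") as f:
        json.dump(params, f)  # learned weights only (binary in real MXNet)

def load(prefix, iteration):
    # hypothetical helper: read both files back
    sym_path, param_path = checkpoint_paths(prefix, iteration)
    with open(sym_path) as f:
        graph = json.load(f)
    with open(param_path) as f:
        params = json.load(f)
    return graph, params

workdir = tempfile.mkdtemp()
prefix = os.path.join(workdir, "mymodel")
graph = {"layers": ["fc1", "relu1", "fc2", "sm"]}  # hypothetical graph description
params = {"fc1_bias": [0.0, 0.1], "fc2_bias": [0.5]}
save(prefix, 100, graph, params)
print(sorted(os.listdir(workdir)))  # ['mymodel-0100.params', 'mymodel-symbol.json']
loaded_graph, loaded_params = load(prefix, 100)
```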

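The advantage of these save and load functions is that their output is language agnostic: other MXNet language bindings can read the same files. You should also be able to save to and load from cloud storage, such as Amazon S3 and HDFS, directly, provided your MXNet build includes S3 or HDFS support.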

Periodic Checkpointing

We recommend checkpointing your model after each iteration (training epoch). To do this, pass the checkpoint callback mx.callback.do_checkpoint(prefix) to FeedForward.create as its epoch_end_callback argument. Training then automatically saves a checkpoint under the given prefix at the end of each iteration.

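    model = mx.model.FeedForward.create(softmax, X=data_set, num_epoch=num_epoch,
        epoch_end_callback=mx.callback.do_checkpoint(prefix))  # save every epoch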
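A checkpoint callback is simply a function that the training loop calls at the end of every epoch. The sketch below shows that pattern with a hypothetical factory, make_checkpoint_callback, and a hypothetical train loop; instead of writing files, the callback records the file name it would write, numbering checkpoints from 1. It illustrates the callback mechanism only and is not MXNet's do_checkpoint implementation.

```python
def make_checkpoint_callback(prefix, saved, period=1):
    """Hypothetical stand-in for a do_checkpoint-style callback factory.

    Records the file name it would write in `saved` instead of writing a file.
    """
    def _callback(epoch, params):
        # epoch is 0-based, so epoch + 1 is the number of finished epochs
        if (epoch + 1) % period == 0:
            saved.append("%s-%04d.params" % (prefix, epoch + 1))
    return _callback

def train(num_epoch, epoch_end_callback=None):
    # hypothetical training loop that invokes the callback after each epoch
    params = {"w": 0.0}
    for epoch in range(num_epoch):
        params["w"] += 1.0  # placeholder for one epoch of real training
        if epoch_end_callback is not None:
            epoch_end_callback(epoch, params)
    return params

saved = []
train(3, epoch_end_callback=make_checkpoint_callback("mymodel", saved))
print(saved)  # ['mymodel-0001.params', 'mymodel-0002.params', 'mymodel-0003.params']
```

Because the factory returns a closure, the prefix and period are fixed when you build the callback, and the training loop only needs to know the callback's call signature.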

You can load a model checkpoint later using FeedForward.load, passing the same prefix and the iteration number of the checkpoint you want.
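To resume from the most recent checkpoint you need its iteration number. The hypothetical helper below (not part of MXNet) scans a list of file names, such as the output of os.listdir, for files matching prefix-NNNN.params and returns the highest iteration found; you could then pass that number to FeedForward.load.

```python
import re

def latest_iteration(prefix, filenames):
    """Hypothetical helper: highest NNNN among prefix-NNNN.params files, or None."""
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.params$")
    iterations = [int(m.group(1)) for m in map(pattern.match, filenames) if m]
    return max(iterations) if iterations else None

files = ["mymodel-symbol.json", "mymodel-0001.params",
         "mymodel-0010.params", "mymodel-0002.params"]
print(latest_iteration("mymodel", files))  # 10
```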

Use Multiple Devices

Set the ctx argument of FeedForward.create to the list of devices that you want to train on.

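    # num_device is the number of GPUs to use; data_set is a placeholder
    devices = [mx.gpu(i) for i in range(num_device)]
    model = mx.model.FeedForward.create(softmax, X=data_set, num_epoch=num_epoch,
                                        ctx=devices)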

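Training runs in parallel on the GPUs that you specify: each data batch is split across the devices, and the gradients they compute are aggregated before the parameters are updated.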
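Splitting a batch across devices works because the gradient of a summed loss is the sum of the gradients of its parts. The conceptual sketch below simulates devices with list slices and uses a one-parameter squared-error model; the split and data_parallel_gradient helpers are hypothetical and are not MXNet's multi-device implementation.

```python
def gradient(w, xs, ys):
    # sum over the slice of d/dw (w*x - y)^2
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys))

def split(seq, num_device):
    # give each simulated device a contiguous, near-equal slice of the batch
    size, rem = divmod(len(seq), num_device)
    out, start = [], 0
    for i in range(num_device):
        end = start + size + (1 if i < rem else 0)
        out.append(seq[start:end])
        start = end
    return out

def data_parallel_gradient(w, xs, ys, num_device):
    # each simulated device computes the gradient of its slice; results are summed
    per_device = [gradient(w, x_part, y_part)
                  for x_part, y_part in zip(split(xs, num_device),
                                            split(ys, num_device))]
    return sum(per_device)

xs = [1.0, 2.0, 3.0, 4.0, 5.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1]
full = gradient(0.5, xs, ys)
parallel = data_parallel_gradient(0.5, xs, ys, num_device=2)
print(abs(full - parallel) < 1e-9)  # True
```

The aggregated gradient matches the single-device gradient up to floating-point rounding, which is why adding devices changes how fast a batch is processed but not what the update computes.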

Initializer API Reference

Evaluation Metric API Reference

Optimizer API Reference

Model API Reference

Next Steps