mxnet linear model

Submitted by 落爺英雄遲暮 on 2020-11-06 05:41:10


import mxnet.ndarray as nd
from mxnet import gluon
from mxnet import autograd


# create synthetic data: y = X @ true_w + true_b + Gaussian noise
def set_data(true_w, true_b, num_examples, *args, **kwargs):
    num_inputs = len(true_w)
    X = nd.random_normal(shape=(num_examples, num_inputs))
    y = 0
    for num in range(num_inputs):
        y += true_w[num] * X[:, num]
    y += true_b
    # add a little noise so the data is not perfectly linear
    y += 0.1 * nd.random_normal(shape=y.shape)
    return X, y


# create data loader
def data_loader(batch_size, X, y, shuffle=False):
    data_set = gluon.data.ArrayDataset(X, y)
    data_iter = gluon.data.DataLoader(dataset=data_set, batch_size=batch_size, shuffle=shuffle)
    return data_iter


# create net: a single Dense (fully connected) layer is a linear model
def set_net(node_num):
    net = gluon.nn.Sequential()
    net.add(gluon.nn.Dense(node_num))
    net.initialize()
    return net


# create trainer: 'optimizer' is the optimizer name, e.g. 'sgd'
def set_trainer(net, optimizer, learning_rate):
    return gluon.Trainer(
        net.collect_params(), optimizer, {'learning_rate': learning_rate}
    )


square_loss = gluon.loss.L2Loss()


# training loop
def start_train(epochs, batch_size, data_iter, net, loss_method, trainer, num_examples):
    for e in range(epochs):
        total_loss = 0
        for data, label in data_iter:
            with autograd.record():
                output = net(data)
                loss = loss_method(output, label)
            loss.backward()
            trainer.step(batch_size)
            total_loss += nd.sum(loss).asscalar()
        print("Epoch %d, average loss: %f" % (e, total_loss / num_examples))
    dense = net[0]
    print(dense.weight.data())
    print(dense.bias.data())
    return dense.weight.data(), dense.bias.data()


true_w = [5, 8, 6]
true_b = 6
X, y = set_data(true_w=true_w, true_b=true_b, num_examples=1000)
data_iter = data_loader(batch_size=10, X=X, y=y, shuffle=True)
net = set_net(1)
trainer = set_trainer(net=net, optimizer='sgd', learning_rate=0.1)
start_train(epochs=5, batch_size=10, data_iter=data_iter, net=net, loss_method=square_loss, trainer=trainer,
            num_examples=1000)
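After training, the learned parameters should land close to true_w = [5, 8, 6] and true_b = 6. As a quick sanity check (not part of the original script; the sample input below is illustrative), you can run a prediction with the trained net and inspect the fitted layer:

# a minimal sketch, assuming the training above has already run
sample = nd.array([[1, 1, 1]])       # one input with three features
print(net(sample))                   # expected roughly 5 + 8 + 6 + 6 = 25
print(net[0].weight.data())          # should be close to true_w = [5, 8, 6]
print(net[0].bias.data())            # should be close to true_b = 6

Note that trainer.step(batch_size) scales the accumulated gradient by 1/batch_size before the SGD update, which is why the training loop passes the batch size rather than calling step(1).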
 
 