nolearn for multi-label classification

Submitted by 帅比萌擦擦* on 2019-12-01 11:46:28

As mentioned by Francisco Vargas, nolearn.dbn is deprecated and you should use nolearn.lasagne instead (if you can).

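If you want to do multi-label classification with nolearn.lasagne, set the regression parameter to True (so the targets are treated as a matrix with one column per label rather than a vector of class indices), define a validation score, and supply a custom loss.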
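For context, multi-label targets are usually encoded as a binary indicator matrix with one column per label. That is why they have to be fed to the network as a float matrix instead of integer class labels. Below is a minimal pure-Python sketch of that encoding. The helper name to_indicator_matrix and the sample label sets are hypothetical and only illustrate the idea.

```python
from typing import Iterable, List


def to_indicator_matrix(label_sets: Iterable[Iterable[int]], n_labels: int) -> List[List[float]]:
    """Encode each sample's set of label indices as a row of 0.0/1.0 values.

    Hypothetical helper: row i, column j is 1.0 if sample i carries label j.
    Floats are used because the network is trained as a regression on these values.
    """
    matrix = []
    for labels in label_sets:
        row = [0.0] * n_labels
        for j in labels:
            if not 0 <= j < n_labels:
                raise ValueError(f"label index {j} out of range for {n_labels} labels")
            row[j] = 1.0
        matrix.append(row)
    return matrix


# Three samples, 13 possible labels (one per output unit in the network example).
y = to_indicator_matrix([{0, 4}, {12}, set()], n_labels=13)
print(y[0][:5])  # [1.0, 0.0, 0.0, 0.0, 1.0]
```

A sample with no labels simply becomes an all-zero row, which a sigmoid output layer can represent without any special handling.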

Here's an example:

import numpy as np
import theano.tensor as T
from lasagne import layers
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from nolearn.lasagne import BatchIterator
from lasagne import nonlinearities

# custom loss: multi label cross entropy
def multilabel_objective(predictions, targets):
    epsilon = np.float32(1.0e-6)
    one = np.float32(1.0)
    pred = T.clip(predictions, epsilon, one - epsilon)
    return -T.sum(targets * T.log(pred) + (one - targets) * T.log(one - pred), axis=1)


net = NeuralNet(
    # customize "layers" to represent the architecture you want
    # here I took a dummy architecture
    layers=[(layers.InputLayer, {"name": 'input', 'shape': (None, 1, 229, 1)}),

            (layers.DenseLayer, {"name": 'hidden1', 'num_units': 20}),
            (layers.DenseLayer, {"name": 'output', 'nonlinearity': nonlinearities.sigmoid, 'num_units': 13})], #because you have 13 outputs

    # optimization method:
    update=nesterov_momentum,
    update_learning_rate=5*10**(-3),
    update_momentum=0.9,

    max_epochs=500,  # we want to train this many epochs
    verbose=1,

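    # Here are the important parameters for multi-label classification:
    regression=True,  # targets are a float matrix (one column per label), not class indices

    objective_loss_function=multilabel_objective,
    # custom_scores takes a list of (name, function) pairs; the older singular
    # custom_score argument is deprecated in later nolearn versions
    custom_scores=[("validation score", lambda x, y: np.mean(np.abs(x - y)))],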

    )

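# X_train: array of shape (n_samples, 1, 229, 1)
# labels_train: 0/1 indicator matrix of shape (n_samples, 13)
# casting both to float32 is the usual convention with Theano
net.fit(X_train.astype(np.float32), labels_train.astype(np.float32))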
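To see what multilabel_objective and the validation score compute, here is a plain-Python sketch of the same clipped binary cross-entropy for one sample, the mean-absolute-difference score, and a 0.5 threshold for turning sigmoid outputs back into label sets. It is meant for checking numbers by hand, not as a replacement for the Theano version. The function names and the 0.5 threshold are assumptions for illustration.

```python
import math
from typing import Sequence, Set

EPSILON = 1.0e-6  # same clipping constant as multilabel_objective


def multilabel_loss(pred: Sequence[float], target: Sequence[float]) -> float:
    """Binary cross-entropy summed over labels for one sample (mirrors the axis=1 sum)."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, EPSILON), 1.0 - EPSILON)  # like T.clip, avoids log(0)
        total += t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return -total


def validation_score(pred_rows: Sequence[Sequence[float]],
                     target_rows: Sequence[Sequence[float]]) -> float:
    """Mean absolute difference over all samples and labels, like np.mean(np.abs(x - y))."""
    diffs = [abs(p - t)
             for pr, tr in zip(pred_rows, target_rows)
             for p, t in zip(pr, tr)]
    return sum(diffs) / len(diffs)


def decode(pred: Sequence[float], threshold: float = 0.5) -> Set[int]:
    """Hypothetical decoding step: keep every label whose sigmoid output exceeds threshold."""
    return {j for j, p in enumerate(pred) if p > threshold}


pred = [0.9, 0.2, 0.6]
target = [1.0, 0.0, 0.0]
print(round(multilabel_loss(pred, target), 4))            # 1.2448
print(round(validation_score([pred], [target]), 4))       # 0.3
print(decode(pred))                                        # {0, 2}
```

Each label is scored independently, so a wrong prediction on one label (here label 2) raises the loss without affecting the others, which is the point of using a sigmoid per label instead of a softmax.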

fit calls BuildDBN, which can be found here. An important thing to note is that dbn has been deprecated, so you can only find it in old commits. Anyway, if you are looking for extra information, it is probably worth checking those two. From what I can see in your snippet, the first parameter of DBN, namely [data, 300, 10], should be [data.shape[1], 300, 10], according to both the documentation and the source code: the first entry of the layer-size list is the number of input units (the number of features), not the data itself. Hope this helps.
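The DBN fix comes down to putting the number of input features, not the data array, in the first slot of the layer-size list. Here is a minimal sketch of building that list, using a list of rows in place of a NumPy array (for a 2-D NumPy array, len(data[0]) corresponds to data.shape[1]). build_layer_sizes is a hypothetical helper, not part of nolearn.

```python
from typing import List, Sequence


def build_layer_sizes(data: Sequence[Sequence[float]], hidden: int, n_outputs: int) -> List[int]:
    """Return [n_features, hidden, n_outputs] as integer layer sizes.

    Hypothetical helper; len(data[0]) plays the role of data.shape[1].
    """
    if not data:
        raise ValueError("need at least one sample to infer the number of features")
    n_features = len(data[0])
    return [n_features, hidden, n_outputs]


data = [[0.1, 0.5, 0.3, 0.9], [0.2, 0.4, 0.8, 0.0]]  # 2 samples, 4 features
print(build_layer_sizes(data, 300, 10))  # [4, 300, 10]
```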
