How to pass a parameter to Scikit-Learn Keras model function

Submitted by 浪尽此生 on 2019-11-28 09:48:59

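You can add an input_dim keyword argument to the KerasClassifier constructor. The wrapper passes it on to create_model, so create_model must itself accept an input_dim parameter:

model = KerasClassifier(build_fn=create_model, input_dim=5, epochs=150, batch_size=10, verbose=0)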

The previous answer no longer works.

An alternative is to make create_model return a function, since KerasClassifier's build_fn expects a callable that builds and returns the compiled model:

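from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier

def create_model(input_dim=None):
    # The outer function captures input_dim; KerasClassifier calls the inner one
    def model():
        # create model
        nn = Sequential()
        nn.add(Dense(12, input_dim=input_dim, kernel_initializer='uniform', activation='relu'))
        nn.add(Dense(6, kernel_initializer='uniform', activation='relu'))
        nn.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
        # Compile model
        nn.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
        return nn

    return model

# Call create_model once; the inner function it returns becomes build_fn
clf = KerasClassifier(build_fn=create_model(input_dim=5), epochs=150, batch_size=10, verbose=0)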
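The key point of this pattern is that create_model is called once with the parameter, and the zero-argument inner function it returns is what the wrapper later calls to build the network. The sketch below shows the same closure mechanics in plain Python so it runs without Keras; build_stub_model is a hypothetical stand-in that returns a dictionary instead of a compiled Sequential model.

```python
# Plain-Python sketch of the closure pattern; no Keras required.
# build_stub_model is a hypothetical stand-in for building a Sequential model.

def build_stub_model(input_dim):
    # Pretend "model": just records the sizes it would be built with.
    return {"input_dim": input_dim, "layers": [12, 6, 1]}


def create_model(input_dim=None):
    # The outer function captures input_dim ...
    def model():
        # ... and the zero-argument inner function uses it when called.
        return build_stub_model(input_dim)
    return model


build_fn = create_model(input_dim=5)  # this is what you would pass as build_fn
print(callable(build_fn))             # True
print(build_fn()["input_dim"])        # 5
```

Because input_dim is baked into the closure, the wrapper never sees it as one of its own parameters; changing it means calling create_model again.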

Or, even better, follow the documentation, which says:

sk_params takes both model parameters and fitting parameters. Legal model parameters are the arguments of build_fn. Note that like all other estimators in scikit-learn, build_fn should provide default values for its arguments, so that you could create the estimator without passing any values to sk_params.
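In other words, each keyword given to the wrapper is matched by name against the parameters of build_fn and of the training call. The sketch below imitates that routing with inspect.signature. The names split_sk_params and fit_stub are hypothetical, and the real wrapper checks against several Keras methods (such as Sequential.fit) rather than one stub, so treat this as an illustration of the rule, not its implementation.

```python
import inspect


def create_model(number_of_features=10):
    # Stand-in for the Keras builder: returns a description, not a compiled model.
    return {"number_of_features": number_of_features}


def fit_stub(x=None, y=None, batch_size=None, epochs=1, verbose=1):
    # Hypothetical stand-in for a fit() signature; it is only inspected, never called.
    return None


def split_sk_params(build_fn, fit_fn, sk_params):
    """Pick out the keywords each function accepts; reject names neither accepts."""
    build_names = set(inspect.signature(build_fn).parameters)
    fit_names = set(inspect.signature(fit_fn).parameters)
    unknown = sorted(set(sk_params) - build_names - fit_names)
    if unknown:
        raise ValueError(f"not a legal parameter: {', '.join(unknown)}")
    model_params = {k: v for k, v in sk_params.items() if k in build_names}
    fit_params = {k: v for k, v in sk_params.items() if k in fit_names}
    return model_params, fit_params


model_params, fit_params = split_sk_params(
    create_model, fit_stub,
    {"number_of_features": 20, "epochs": 25, "batch_size": 1000},
)
print(model_params)                  # {'number_of_features': 20}
print(fit_params)                    # {'epochs': 25, 'batch_size': 1000}
print(create_model(**model_params))  # {'number_of_features': 20}
print(create_model())                # {'number_of_features': 10}
```

In this sketch a name accepted by both functions is passed to both, and a name accepted by neither (for example input_dim here) raises ValueError. The last line shows why the default matters: with no sk_params at all, build_fn still produces a model.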

So you can define your function like this:

from keras.models import Sequential
from keras.layers import Dense

def create_model(number_of_features=10): # 10 is the *default value*
    # create model
    nn = Sequential()
    nn.add(Dense(12, input_dim=number_of_features, kernel_initializer='uniform', activation='relu'))
    nn.add(Dense(6, kernel_initializer='uniform', activation='relu'))
    nn.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
    # Compile model
    nn.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return nn

Then set that parameter by name when creating the wrapper:

KerasClassifier(build_fn=create_model, number_of_features=20, epochs=25, batch_size=1000, ...)
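One practical consequence of this approach is that number_of_features becomes an ordinary estimator parameter instead of a value hidden inside a closure, so scikit-learn tooling that reads and rewrites parameters through get_params/set_params (as clone and GridSearchCV do) can change it. The toy class below is a hypothetical, Keras-free imitation of that behaviour: ToyWrapper is not the real KerasClassifier, and create_model returns a dictionary instead of a compiled network.

```python
import inspect


def create_model(number_of_features=10):
    # Stand-in for the Keras builder: returns a description, not a compiled model.
    return {"number_of_features": number_of_features}


class ToyWrapper:
    """Hypothetical, minimal imitation of a build_fn-based scikit-learn wrapper."""

    def __init__(self, build_fn, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def get_params(self, deep=True):
        # scikit-learn utilities such as clone() read parameters through this.
        return {"build_fn": self.build_fn, **self.sk_params}

    def set_params(self, **params):
        # Grid searches call this to try each candidate value.
        if "build_fn" in params:
            self.build_fn = params.pop("build_fn")
        self.sk_params.update(params)
        return self

    def fit(self, X=None, y=None):
        # Build the model here, forwarding only the sk_params build_fn accepts.
        accepted = inspect.signature(self.build_fn).parameters
        model_args = {k: v for k, v in self.sk_params.items() if k in accepted}
        self.model_ = self.build_fn(**model_args)
        return self


clf = ToyWrapper(build_fn=create_model, number_of_features=20, epochs=25, batch_size=1000)
print(clf.fit().model_)    # {'number_of_features': 20}
clf.set_params(number_of_features=30)
print(clf.fit().model_)    # {'number_of_features': 30}
print(ToyWrapper(build_fn=create_model).fit().model_)  # {'number_of_features': 10}
```

The model is only built inside fit, so a parameter changed with set_params takes effect on the next fit; the default of 10 is used when no value is supplied.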