Problem Using Keras Sequential Model for “reinforcelearn” Package in R

Submitted by 大城市里の小女人 on 2020-05-30 12:19:38

Question


I am trying to use a keras (version 2.2.50) neural network / sequential model to create a simple agent in a reinforcement learning setting with the reinforcelearn package (version 0.2.1), following this vignette: https://cran.r-project.org/web/packages/reinforcelearn/vignettes/agents.html . This is the code I use:

library('reinforcelearn')
library('keras')

model = keras_model_sequential() %>% 
  layer_dense(units = 10, input_shape = 4, activation = "linear") %>%
  compile(optimizer = optimizer_sgd(lr = 0.1), loss = "mae")

agent = makeAgent(policy = "softmax", val.fun = "neural.network", algorithm = "qlearning",
                  val.fun.args = list(model= model))

However, when I try to run the makeAgent function I get the following error message:

Error in .subset2(public_bind_env, "initialize")(...) : 
  Assertion on 'model' failed: Must inherit from class 'keras.models.Sequential', but has classes 'keras.engine.sequential.Sequential','keras.engine.training.Model','keras.engine.network.Network','keras.engine.base_layer.Layer','tensorflow.python.module.module.Module','tensorflow.python.training.tracking.tracking.AutoTrackable','tensorflow.python.training.tracking.base.Trackable','python.builtin.object'.

The problem seems to be that the model has the wrong class. What can I do to solve this?
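The failing check can be reproduced without keras or Python. reticulate gives a Python object an R class vector built from its Python class hierarchy, and an R class check such as inherits() (which checkmate's class assertion presumably behaves like) is an exact string match against that vector. The class vector in the error contains 'keras.engine.sequential.Sequential' but not 'keras.models.Sequential', apparently the name Sequential had in the keras release reinforcelearn 0.2.1 was written against. The sketch below uses a mock object with an abbreviated copy of the class vector from the error; the names `model` and `required` are only illustrative.

```r
# Mock object carrying (part of) the class vector reported in the error.
# No keras or Python installation is needed to run this.
model <- structure(list(), class = c(
  "keras.engine.sequential.Sequential",
  "keras.engine.training.Model",
  "keras.engine.network.Network",
  "keras.engine.base_layer.Layer",
  "python.builtin.object"
))

# Class name required by the assertion in reinforcelearn 0.2.1
required <- "keras.models.Sequential"

# Class membership is an exact string match, so the check fails even
# though the object is a Sequential model:
inherits(model, required)                              # FALSE
inherits(model, "keras.engine.sequential.Sequential")  # TRUE
```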


Answer 1:


I was able to solve the problem by downloading the package source from CRAN (https://cran.r-project.org/src/contrib/reinforcelearn_0.2.1.tar.gz) and commenting out the class assertion in the initialize method of the ValueNetwork R6 class:

ValueNetwork = R6::R6Class("ValueNetwork",
  public = list(
    model = NULL,

    # keras model # fixme: add support for mxnet
    initialize = function(model) {
      # checkmate::assertClass(model, "keras.models.Sequential")
      self$model = model
    },
...

Then I reinstalled the package from the modified source with: install.packages([file path], repos = NULL, type = "source")
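The manual steps (download, edit, reinstall) can also be scripted. What follows is a minimal base-R sketch, not part of reinforcelearn. It assumes the tarball is still reachable at the CRAN URL quoted in the answer; CRAN moves older releases to its Archive directory, so the URL may need adjusting. It also assumes the assertion line appears in the package's R/ files exactly as shown above. `comment_out_assertion()` and `patch_and_install_reinforcelearn()` are hypothetical helper names. Only the first function is a pure text edit; the second downloads and installs.

```r
# Hypothetical helper: comment out the keras class assertion in a vector of
# source lines, keeping the original indentation. Lines that are already
# commented out are left alone, so running it twice changes nothing.
comment_out_assertion <- function(lines,
    pattern = 'checkmate::assertClass(model, "keras.models.Sequential")') {
  hit <- grepl(pattern, lines, fixed = TRUE) &
    !grepl("^[[:space:]]*#", lines)
  lines[hit] <- sub("^([[:space:]]*)", "\\1# ", lines[hit])
  lines
}

# Hypothetical helper: download the source tarball, patch every R file that
# contains the assertion, then install from the patched directory.
patch_and_install_reinforcelearn <- function(workdir = tempdir()) {
  # URL taken from the answer; it may have moved to CRAN's Archive.
  url <- "https://cran.r-project.org/src/contrib/reinforcelearn_0.2.1.tar.gz"
  tarball <- file.path(workdir, basename(url))
  download.file(url, tarball, mode = "wb")
  untar(tarball, exdir = workdir)  # CRAN tarballs unpack to <pkgname>/
  pkgdir <- file.path(workdir, "reinforcelearn")

  r_files <- list.files(file.path(pkgdir, "R"),
                        pattern = "\\.[Rr]$", full.names = TRUE)
  for (f in r_files) {
    lines <- readLines(f)
    patched <- comment_out_assertion(lines)
    if (!identical(lines, patched)) writeLines(patched, f)
  }

  # Same call as in the answer, pointed at the unpacked source directory
  install.packages(pkgdir, repos = NULL, type = "source")
  invisible(pkgdir)
}
```

The helper matches the full assertion line literally (`fixed = TRUE`) rather than any `assertClass` call. That way, other class checks in the package stay untouched.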



Source: https://stackoverflow.com/questions/60971763/problem-using-keras-sequential-model-for-reinforcelearn-package-in-r
