How to examine the feature weights of a TensorFlow LinearClassifier?

Submitted by 旧巷老猫 on 2019-12-22 13:52:59

Question


I am trying to understand the Large-scale Linear Models with TensorFlow documentation. The docs motivate these models as follows:

Linear models can be interpreted and debugged more easily than neural nets. You can examine the weights assigned to each feature to figure out what's having the biggest impact on a prediction.

So I ran the extended code example from the accompanying TensorFlow Linear Model Tutorial. In particular, I ran the example code from GitHub with the model-type flag set to wide. It ran correctly and produced accuracy: 0.833733, similar to the accuracy: 0.83557522 reported on the TensorFlow web page.

The example uses a tf.estimator.LinearClassifier to train the weights. However, in contrast to the quoted motivation of being able to examine the weights, I can't find any function to actually extract the trained weights in the LinearClassifier documentation.

Question: how do I access the trained weights for the various feature columns in a tf.estimator.LinearClassifier? I'd prefer to be able to extract all the weights in a NumPy array.

Note: I am coming from an R environment, where linear regression / classification models have a coef() function to extract the learned weights. I want to be able to compare linear models in both R and TensorFlow on the same datasets.


Answer 1:


After training the model with the Estimator, you can use tf.train.load_variable to retrieve the weights from the checkpoint. Use tf.train.list_variables to find the names of the model weights.
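A minimal sketch of that approach follows. It assumes TF 1.x-style canned-estimator checkpoints in which the LinearClassifier stores one variable per feature column under a name like linear/linear_model/<column>/weights, plus linear/linear_model/bias_weights for the bias. These names are an assumption, so check them against the output of tf.train.list_variables for your own model_dir. The helper linear_weights_from_checkpoint is hypothetical. It takes the (name, shape) pairs that tf.train.list_variables returns and a loader callable, so the filtering logic can be tried without a trained model.

```python
import numpy as np

# Assumed name scope of LinearClassifier weights in TF 1.x checkpoints;
# verify against tf.train.list_variables(model_dir) for your own model.
DEFAULT_PREFIX = "linear/linear_model/"


def linear_weights_from_checkpoint(variables, load, prefix=DEFAULT_PREFIX):
    """Return {feature_column_name: np.ndarray} plus a "bias" entry.

    variables: iterable of (name, shape) pairs, as returned by
        tf.train.list_variables(model_dir).
    load: callable taking a variable name and returning its value, e.g.
        lambda name: tf.train.load_variable(model_dir, name).
    """
    weights = {}
    for name, _shape in variables:
        if not name.startswith(prefix):
            continue  # skips e.g. global_step
        rest = name[len(prefix):]
        if rest == "bias_weights":
            weights["bias"] = np.asarray(load(name))
        elif rest.endswith("/weights"):
            # Optimizer slot variables end in a different suffix and are skipped.
            weights[rest[: -len("/weights")]] = np.asarray(load(name))
    return weights
```

With a trained estimator m and TensorFlow installed, a call would look like linear_weights_from_checkpoint(tf.train.list_variables(m.model_dir), lambda name: tf.train.load_variable(m.model_dir, name)). Passing the loader in as a callable keeps the name filtering separate from the checkpoint I/O.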

There are also plans to add this support directly to Estimator.
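To get closer to R's coef() output (one labelled number per coefficient, all in a single NumPy array), the per-variable arrays can be flattened. The sketch below assumes you already have a dict mapping each weight variable's name to its value. You can build that dict with tf.train.load_variable as described above. Alternatively, if your TensorFlow version's Estimator provides get_variable_names() and get_variable_value(name), you can use {n: m.get_variable_value(n) for n in m.get_variable_names()}. In either case, drop global_step and any optimizer slot variables first. The helper name coef_table is hypothetical.

```python
import numpy as np


def coef_table(weights):
    """Flatten {variable_name: array} into (labels, values), similar in spirit to R's coef().

    Arrays with more than one element (e.g. one weight per bucket or
    vocabulary entry) are flattened in row-major order and labelled
    "name[i]"; single-element arrays keep their plain name.
    """
    labels, chunks = [], []
    for name in sorted(weights):
        flat = np.asarray(weights[name], dtype=float).reshape(-1)
        if flat.size == 1:
            labels.append(name)
        else:
            labels.extend(f"{name}[{i}]" for i in range(flat.size))
        chunks.append(flat)
    values = np.concatenate(chunks) if chunks else np.empty(0)
    return labels, values
```

Sorting by variable name keeps the output order stable between runs, which makes it easier to line the values up against the coefficients of an R model fitted on the same data.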



Source: https://stackoverflow.com/questions/46131410/how-to-examine-the-feature-weights-of-a-tensorflow-linearclassifier
