How to use tf.metrics.__ with estimator model predict output

Submitted by 橙三吉 on 2019-12-22 22:46:51

Question


I am trying to follow the TensorFlow 1.4 API documentation to achieve what I need in a learning process.

At this stage I can produce a predict object, for example:

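import tensorflow as tf

# A 3-class DNN classifier whose checkpoints are written to model_dir
classifier = tf.estimator.DNNClassifier(feature_columns=feature_cols, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/xlz_model")

# predict() returns a generator of prediction dicts, one per test example
predict = classifier.predict(input_fn=input_pd_fn_prt(test_f), predict_keys=["class_ids"])
label = tf.constant(test_l.values, tf.int64)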

How can I use predict and label in tf.metrics.auc, for example:

out, opt = tf.metrics.auc(label, predict)

I have tried many different options, but there is no clear documentation on how these TensorFlow APIs should be used.


Answer 1:


tf.metrics.auc returns two ops, a value tensor and an update op:

auc, update_op = tf.metrics.auc(...)

Running sess.run(auc) returns the current AUC value without changing it. This is the value you want to report, for example print(sess.run([auc, cost], feed_dict={...})).
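Every tf.metrics function follows this same split between a read-only value and an update op. As a rough, TensorFlow-free illustration, the pattern can be sketched in plain Python. The class StreamingMean is hypothetical and only loosely modeled on tf.metrics.mean (which keeps a running total and count); it is not TensorFlow code.

```python
class StreamingMean:
    """Hypothetical pure-Python analogue of a tf.metrics value/update pair."""

    def __init__(self):
        self.reset()

    def reset(self):
        # Analogue of re-running tf.local_variables_initializer()
        self.total = 0.0
        self.count = 0

    def update(self, values):
        # Analogue of sess.run(update_op): fold one batch into the accumulators
        for v in values:
            self.total += v
            self.count += 1
        return self.value()

    def value(self):
        # Analogue of sess.run(metric): read the accumulators, change nothing
        return self.total / self.count if self.count else 0.0


m = StreamingMean()
m.update([1.0, 2.0])  # batch 1
m.update([3.0, 6.0])  # batch 2
print(m.value())  # 3.0
```

Reading value() any number of times leaves the state alone; only update() moves it, which mirrors why auc can be printed freely while update_op must be run exactly once per batch.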

The AUC metric may need to be computed over many calls to sess.run, for example when the dataset you are computing it on does not fit in memory. That is where update_op comes in: you run it on each batch to accumulate the statistics (true/false positive and negative counts at a set of thresholds) from which the AUC is computed.

So during a test set evaluation, you might have this:

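# Run update_op once per batch to accumulate the AUC statistics
for _ in range(num_batches):
    sess.run([accuracy, cost, update_op], feed_dict={...})

# Read the accumulated value once, after all batches
print("Final (accumulated) AUC value:", sess.run(auc))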
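To make the accumulation in this loop concrete, here is a simplified pure-Python sketch of what update_op and auc compute between them. It assumes binary labels and scores in [0, 1], keeps confusion counts at a fixed grid of num_thresholds thresholds (roughly the approach described for TF 1.x tf.metrics.auc), and integrates the ROC curve with the trapezoidal rule. StreamingAUC is a hypothetical name; this is not TensorFlow's implementation and does not reproduce its exact numerics (for instance, empty denominators simply yield 0 here).

```python
class StreamingAUC:
    """Hypothetical pure-Python sketch of a streaming ROC AUC."""

    def __init__(self, num_thresholds=200):
        # Evenly spaced thresholds over [0, 1], with the ends nudged just
        # outside the range so every score lies between the extremes
        eps = 1e-7
        inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
        self.thresholds = [0.0 - eps] + inner + [1.0 + eps]
        self.reset()

    def reset(self):
        # Analogue of sess.run(tf.local_variables_initializer())
        n = len(self.thresholds)
        self.tp, self.fp = [0] * n, [0] * n
        self.tn, self.fn = [0] * n, [0] * n

    def update(self, labels, scores):
        # Analogue of sess.run(update_op): add this batch's counts
        for label, score in zip(labels, scores):
            for i, t in enumerate(self.thresholds):
                predicted_positive = score > t
                if label and predicted_positive:
                    self.tp[i] += 1
                elif label:
                    self.fn[i] += 1
                elif predicted_positive:
                    self.fp[i] += 1
                else:
                    self.tn[i] += 1

    def result(self):
        # Analogue of sess.run(auc): read the counts, change nothing
        tpr = [tp / (tp + fn) if tp + fn else 0.0 for tp, fn in zip(self.tp, self.fn)]
        fpr = [fp / (fp + tn) if fp + tn else 0.0 for fp, tn in zip(self.fp, self.tn)]
        # Thresholds increase left to right, so fpr decreases; trapezoidal rule
        return sum((fpr[i] - fpr[i + 1]) * (tpr[i] + tpr[i + 1]) / 2
                   for i in range(len(self.thresholds) - 1))


metric = StreamingAUC()
metric.update([0, 1], [0.3, 0.6])  # batch 1
metric.update([1, 0], [0.4, 0.7])  # batch 2
print(metric.result())  # 0.5
```

Because the counts are simply summed, feeding the data in two batches gives the same result as feeding it in one, which is why running update_op per batch and reading auc once at the end works.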

When you want to reset the accumulated values (for example, before re-evaluating your test set), re-initialize the local variables. The tf.metrics functions add their accumulator variables to the local variables collection, which by default does not include trainable variables such as weights, so this resets the metric without touching the model.

sess.run(tf.local_variables_initializer())  # Resets AUC accumulator variables
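On the Estimator side of the question: classifier.predict() returns a Python generator that yields one dict per example rather than a tensor, so it cannot be passed to tf.metrics.auc directly, and tf.metrics.auc expects binary labels with scores in [0, 1], which the class_ids of a 3-class model are not. One option, assumed here rather than taken from the answer, is to request the "probabilities" prediction key, materialize the generator into lists, and score one class against the rest. The helper one_vs_rest and the stand-in generator fake_predict are hypothetical; in real code the generator would come from classifier.predict(input_fn=..., predict_keys=["probabilities"]).

```python
def one_vs_rest(predictions, labels, positive_class):
    """Turn Estimator-style prediction dicts into binary AUC inputs.

    predictions: iterable of dicts, each assumed to hold a per-class
        "probabilities" sequence (like the generator from predict()).
    labels: integer class ids, in the same order as the predictions.
    Returns (binary_labels, scores) for the chosen positive class.
    """
    scores = []
    for pred in predictions:  # consumes the generator exactly once
        scores.append(float(pred["probabilities"][positive_class]))
    binary_labels = [int(label == positive_class) for label in labels]
    if len(scores) != len(binary_labels):
        raise ValueError("predictions and labels have different lengths")
    return binary_labels, scores


def fake_predict():
    # Hypothetical stand-in for classifier.predict(...): yields one dict per example
    for probs in ([0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5]):
        yield {"probabilities": probs}


binary_labels, scores = one_vs_rest(fake_predict(), [0, 1, 2], positive_class=1)
print(binary_labels, scores)  # [0, 1, 0] [0.2, 0.8, 0.3]
```

The two lists can then be turned into tensors with tf.constant and passed as tf.metrics.auc(labels, predictions), after which the update_op / auc sequence described above applies. Repeating this with each class as positive_class gives one AUC per class.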

https://www.tensorflow.org/api_docs/python/tf/metrics/auc



Source: https://stackoverflow.com/questions/47231777/how-to-use-tf-metrics-with-estimator-model-predict-output
