Question
I have converted the Google Inception trained model .pb file, which reads like below:
A
mixed_9/join/concat_dimConst*
dtype0*
value :
A
mixed_8/join/concat_dimConst*
dtype0*
value :
A
mixed_7/join/concat_dimConst*
dtype0*
value :
A
mixed_6/join/concat_dimConst*
using Google Protobuf's --decode_raw, which reads from stdin. The output now reads like a .proto file, including the names of the layers and some encoded numbers. Here are the first 30 lines of the .proto file:
syntax="proto2";
1 {
  1: "mixed_10/join/concat_dim"
  2: "Const"
  5 {
    1: "dtype"
    2 {
      6: 3
    }
  }
  5 {
    1: "value"
    2 {
      8 {
        1: 3
        2: ""
        7: "\003"
      }
    }
  }
}
1 {
  1: "mixed_9/join/concat_dim"
  2: "Const"
  5 {
    1: "dtype"
    2 {
      6: 3
    }
  }
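--decode_raw has no schema, so it prints field numbers instead of field names. A hedged sketch for reading them: the partial tables below map those numbers to names from TensorFlow's graph.proto, node_def.proto, attr_value.proto, tensor.proto and types.proto (check them against the TensorFlow version that produced the file). The table names and the label function are hypothetical helpers, not a TensorFlow API.

```python
# Partial schema tree: field number -> (field name, child schema or None for a leaf).
TENSOR_PROTO = {1: ("dtype", None), 2: ("tensor_shape", None),
                4: ("tensor_content", None), 5: ("float_val", None),
                7: ("int_val", None)}
ATTR_VALUE = {2: ("s", None), 3: ("i", None), 4: ("f", None), 5: ("b", None),
              6: ("type", None), 7: ("shape", None), 8: ("tensor", TENSOR_PROTO)}
# NodeDef.attr is a map<string, AttrValue>; each map entry has key = 1, value = 2.
ATTR_MAP_ENTRY = {1: ("key", None), 2: ("value", ATTR_VALUE)}
NODE_DEF = {1: ("name", None), 2: ("op", None), 3: ("input", None),
            4: ("device", None), 5: ("attr", ATTR_MAP_ENTRY)}
GRAPH_DEF = {1: ("node", NODE_DEF)}
# A few DataType enum values (types.proto).
DATA_TYPES = {1: "DT_FLOAT", 3: "DT_INT32", 12: "DT_QUINT8"}


def label(path):
    """Translate a path of raw field numbers, e.g. [1, 5, 2, 8, 7], into field names."""
    schema, names = GRAPH_DEF, []
    for number in path:
        # Raises KeyError for fields missing from these partial tables.
        name, schema = schema[number]
        names.append(name)
    return ".".join(names)


print(label([1, 1]))           # node.name
print(label([1, 5, 2, 6]))     # node.attr.value.type
print(label([1, 5, 2, 8, 7]))  # node.attr.value.tensor.int_val
print(DATA_TYPES[3])           # DT_INT32
```

Read this way, "6: 3" under "dtype" is type DT_INT32, and the "value" attr holds a tensor with dtype 3 and a packed int_val of "\003", i.e. the scalar 3 used as the concat dimension. Float weights stored as constants would instead appear as a tensor with "1: 1" (DT_FLOAT), typically with the raw bytes in field 4 (tensor_content).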
Parsing the file, I am looking for the trained weights of the Inception model, for instance in this node:
1 {
  1: "Mul"
  2 {
    10: 108
    12: 0x7265646c6f686563
  }
  5 {
    1: "dtype"
    2 {
      6: 1
    }
  }
  5 {
    1: "shape"
    2 {
      7: ""
    }
  }
}
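The "2 { 10: 108 12: 0x7265646c6f686563 }" under "Mul" looks like a --decode_raw artifact rather than data. Field 2 of a NodeDef is the op name, and without a schema --decode_raw has to guess whether a length-delimited field is a string or a nested message. The sketch below is a simplified, hypothetical re-implementation of that schema-less parsing (it handles only wire types 0, 1, 2 and 5). It shows that the bytes of the string "Placeholder" parse cleanly as exactly these two fields, which suggests this node's op is a Placeholder (a float32 input) rather than a weight tensor.

```python
import struct


def read_varint(buf, pos):
    """Decode a base-128 varint starting at buf[pos]; return (value, new_pos)."""
    result, shift = 0, 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # high bit clear: last byte of the varint
            return result, pos
        shift += 7


def parse_fields(buf):
    """Schema-less parse of one protobuf message into (field, wire_type, value) tuples."""
    fields, pos = [], 0
    while pos < len(buf):
        tag, pos = read_varint(buf, pos)
        field_number, wire_type = tag >> 3, tag & 0x7
        if wire_type == 0:    # varint
            value, pos = read_varint(buf, pos)
        elif wire_type == 1:  # fixed64, little-endian
            (value,) = struct.unpack_from("<Q", buf, pos)
            pos += 8
        elif wire_type == 2:  # length-delimited: string, bytes or nested message
            length, pos = read_varint(buf, pos)
            value = buf[pos:pos + length]
            pos += length
        elif wire_type == 5:  # fixed32, little-endian
            (value,) = struct.unpack_from("<I", buf, pos)
            pos += 4
        else:
            raise ValueError(f"unsupported wire type {wire_type}")
        fields.append((field_number, wire_type, value))
    return fields


# Parse the op-name string as if it were a nested message, as --decode_raw did.
for number, wire_type, value in parse_fields(b"Placeholder"):
    print(number, hex(value) if wire_type == 1 else value)
# 10 108
# 12 0x7265646c6f686563
```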
On the other hand, using a small Python script, I could print out all the tensors in the Inception model:
import os
from pprint import pprint

import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.GraphDef)
from tensorflow.python.platform import gfile

INCEPTION_LOG_DIR = '/tmp/inception_v3_log'
if not os.path.exists(INCEPTION_LOG_DIR):
    os.makedirs(INCEPTION_LOG_DIR)

with tf.Session() as sess:
    model_filename = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'
    # Parse the serialized GraphDef and import it into the default graph.
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
    # Print every float32 output tensor of every non-Placeholder op.
    pprint([out for op in tf.get_default_graph().get_operations()
            if op.type != 'Placeholder'
            for out in op.values() if out.dtype == tf.float32])
This lists all the layers of the model, and the Mul layer above corresponds to the middle line of the script's output:
(<tf.Tensor 'mixed/join/concat_dim:0' shape=() dtype=int32>,)
(<tf.Tensor 'Mul:0' shape=<unknown> dtype=float32>,)
(<tf.Tensor 'conv/conv2d_params_quint8_const:0' shape=(3, 3, 3, 32) dtype=quint8>,)
My problem is that I can't find a way to read these float32 values, which I assume are the weights of each layer.
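For reference, when a weight tensor is stored as a constant, its values usually sit as raw bytes in the tensor_content field of the TensorProto inside the Const node's "value" attr. A minimal sketch of unpacking such bytes in plain Python, assuming little-endian byte order and only the two dtypes handled here; decode_tensor_content is a hypothetical helper, and the input bytes are made up for the demonstration. Note that the conv weights in the script output are quint8, so their bytes are 8-bit quantized codes; turning them back into floats would also need the quantization range, which this sketch does not handle.

```python
import struct


def decode_tensor_content(content, dtype):
    """Unpack raw TensorProto.tensor_content bytes into a flat list.

    Assumes little-endian byte order. Handles only two DataType values:
    DT_FLOAT (1) -> 4-byte floats, DT_QUINT8 (12) -> unsigned bytes.
    """
    if dtype == 1:  # DT_FLOAT
        count = len(content) // 4
        return list(struct.unpack(f"<{count}f", content))
    if dtype == 12:  # DT_QUINT8: raw quantized codes, not float weights
        return list(content)
    raise ValueError(f"unsupported dtype {dtype}")


# Made-up bytes, packed here only to show the round trip.
raw = struct.pack("<3f", 0.5, -1.25, 2.0)
print(decode_tensor_content(raw, 1))               # [0.5, -1.25, 2.0]
print(decode_tensor_content(b"\x00\x7f\xff", 12))  # [0, 127, 255]
```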
I have tried protoc v3.3 on my .proto file, but I am receiving an error:
$ protoc inception.proto.utf --print_free_field_numbers
inception.proto.utf:2:1: Expected top-level statement (e.g. "message").
Any help would be appreciated.
P.S.: The .pb file of the Inception model is available here.
Answer 1:
Unless your model doesn't have any variables (trained model parameters), or they were already converted to constants before export, you'll also need to load the variable values from a separate checkpoint file. They may also be difficult to load, because, as far as I understand, .pb files don't save the collections the variables were in. MetaGraphDefs were created for this reason, and there's a good chance you'll be better off looking for a relevant one of these.
If your model truly doesn't have any variables, you should be able to get the values of that layer by running the session after loading the GraphDef:
session.run('Mul:0')
You may have to use a feed_dict if the model has placeholders.
Note: these won't be the weights of the layer, but the result of the multiplication.
Source: https://stackoverflow.com/questions/45829802/seeing-the-float32-weight-in-a-proto-file