How can I view weights in a .tflite file?

Submitted by ↘锁芯ラ on 2020-05-13 06:33:10

Question


I downloaded the pre-trained MobileNet .pb file and found that it is not quantized, whereas a fully quantized model has to be converted to the .tflite format. Since I'm not familiar with mobile app development tools, how can I get the fully quantized weights of MobileNet from a .tflite file? More precisely, how can I extract the quantization parameters and view their numerical values?


Answer 1:


The Netron model viewer lets you view and export the data stored in a model, and it also draws a clear diagram of the network: https://github.com/lutzroeder/netron




Answer 2:


I'm also still learning how TFLite works, so this may not be the best approach, and I'd appreciate any expert opinions. Here's what I've found so far using the FlatBuffers Python API.

First, compile the schema with flatc, the FlatBuffers compiler. This generates the Python bindings in a folder called tflite.

flatc --python tensorflow/contrib/lite/schema/schema.fbs

Then you can load the model and get the tensor you want. Each Tensor has a method called Buffer(), which, according to the schema, is

An index that refers to the buffers table at the root of the model.

So it points you to the location of the data.
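To make this indirection concrete, here is a minimal sketch in plain Python. The classes `Buffer`, `Tensor`, `SubGraph` and `Model` below are hypothetical stand-ins that mirror the schema tables; they are not the classes generated by flatc. The sketch follows the schema's convention that entry 0 of the buffers table is an empty sentinel, so tensors without constant data (such as activations) can point at index 0.

```python
from dataclasses import dataclass, field
from typing import List

# Hypothetical stand-ins for the schema tables; NOT the flatc-generated classes.


@dataclass
class Buffer:
    data: bytes = b""  # raw bytes; empty for tensors with no constant data


@dataclass
class Tensor:
    name: str
    buffer: int  # index into Model.buffers (the schema's Tensor.buffer field)


@dataclass
class SubGraph:
    tensors: List[Tensor] = field(default_factory=list)


@dataclass
class Model:
    subgraphs: List[SubGraph] = field(default_factory=list)
    # Buffers live at the model root; entry 0 is the empty sentinel buffer.
    buffers: List[Buffer] = field(default_factory=lambda: [Buffer()])


def tensor_data(model: Model, subgraph_idx: int, tensor_idx: int) -> bytes:
    """Resolve tensor -> buffer index -> root-level buffers table."""
    tensor = model.subgraphs[subgraph_idx].tensors[tensor_idx]
    return model.buffers[tensor.buffer].data


model = Model()
model.buffers.append(Buffer(data=bytes([1, 2, 3, 4])))  # becomes buffer 1
model.subgraphs.append(SubGraph(tensors=[
    Tensor(name="conv/weights", buffer=1),  # constant tensor with data
    Tensor(name="input", buffer=0),         # activation: points at the sentinel
]))

print(tensor_data(model, 0, 0))  # b'\x01\x02\x03\x04'
print(tensor_data(model, 0, 1))  # b''
```

Keeping the data in one root-level table lets tensors in different subgraphs share a buffer by index instead of copying bytes.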

from tflite import Model

with open('/path/to/model.tflite', 'rb') as f:
    buf = f.read()
model = Model.Model.GetRootAsModel(buf, 0)
subgraph = model.Subgraphs(0)
# Check tensor.Name() to find the tensor_idx you want
tensor_idx = 0  # placeholder: replace with the index of your tensor
tensor = subgraph.Tensors(tensor_idx)
buffer_idx = tensor.Buffer()
buffer = model.Buffers(buffer_idx)
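The tensor_idx above still has to be found by name. A minimal sketch that scans a subgraph is shown here; it assumes the flatc-generated accessors `TensorsLength()`, `Tensors(i)` and `Name()` (which returns bytes). `find_tensor_index` is a hypothetical helper, and the `_StubTensor`/`_StubSubGraph` classes are hypothetical stand-ins used only so the example runs without a model file.

```python
from typing import Optional


def find_tensor_index(subgraph, name: str) -> Optional[int]:
    """Return the index of the tensor called `name`, or None if absent.

    Assumes the flatc-generated API: subgraph.TensorsLength(),
    subgraph.Tensors(i), and tensor.Name() returning bytes.
    """
    for i in range(subgraph.TensorsLength()):
        tensor_name = subgraph.Tensors(i).Name()
        if tensor_name is not None and tensor_name.decode("utf-8") == name:
            return i
    return None


# Hypothetical stubs mimicking the generated accessors, for demonstration only.
class _StubTensor:
    def __init__(self, name: bytes):
        self._name = name

    def Name(self):
        return self._name


class _StubSubGraph:
    def __init__(self, names):
        self._tensors = [_StubTensor(n) for n in names]

    def TensorsLength(self):
        return len(self._tensors)

    def Tensors(self, i):
        return self._tensors[i]


sg = _StubSubGraph([b"input", b"conv/weights", b"conv/bias"])
print(find_tensor_index(sg, "conv/weights"))  # 1
```

With a real model you would pass `model.Subgraphs(0)` instead of the stub, using a tensor name taken from Netron or from printing every `Name()` in the loop.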

After that, you can read the raw bytes with buffer.DataAsNumpy(), or one byte at a time with buffer.Data(j).
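The raw bytes still have to be interpreted using the tensor's type and shape. The sketch below assumes the TensorType codes from schema.fbs (only a subset is mapped) and little-endian storage; `buffer_to_array` is a hypothetical helper name, not part of any API.

```python
import numpy as np

# Subset of TensorType codes from schema.fbs; TFLite data is little-endian.
_TFLITE_DTYPES = {
    0: np.dtype("<f4"),  # FLOAT32
    1: np.dtype("<f2"),  # FLOAT16
    2: np.dtype("<i4"),  # INT32
    3: np.dtype("u1"),   # UINT8
    4: np.dtype("<i8"),  # INT64
    7: np.dtype("<i2"),  # INT16
    9: np.dtype("i1"),   # INT8
}


def buffer_to_array(raw: bytes, tensor_type: int, shape) -> np.ndarray:
    """Reinterpret a tensor's raw buffer bytes as a typed, shaped array."""
    dtype = _TFLITE_DTYPES.get(tensor_type)
    if dtype is None:
        raise ValueError(f"unsupported TensorType code: {tensor_type}")
    return np.frombuffer(raw, dtype=dtype).reshape(tuple(shape))


# Example: six little-endian int32 values laid out as a 2x3 tensor.
raw = np.array([1, 2, 3, 4, 5, 6], dtype="<i4").tobytes()
print(buffer_to_array(raw, 2, [2, 3]).tolist())  # [[1, 2, 3], [4, 5, 6]]
```

With the generated bindings this would look roughly like `buffer_to_array(buffer.DataAsNumpy().tobytes(), tensor.Type(), tensor.ShapeAsNumpy())`. Check `buffer.DataLength() > 0` first, since tensors without constant data have an empty buffer. For a quantized tensor the result holds the stored integers; its scale and zero point are in `tensor.Quantization()`.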

References:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs
https://github.com/google/flatbuffers/tree/master/samples




Answer 3:


Using TensorFlow 2.0, you can extract the weights along with some information about each tensor (name, shape, dtype, quantization) with the following script, inspired by the TensorFlow documentation:

import tensorflow as tf
import h5py

# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="v3-large_224_1.0_uint8.tflite")
interpreter.allocate_tensors()

# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Get details for every tensor in the model
all_layers_details = interpreter.get_tensor_details()

with h5py.File("mobilenet_v3_weights_infos.hdf5", "w") as f:
    for layer in all_layers_details:
        # Create one HDF5 group per tensor, keyed by its index
        grp = f.create_group(str(layer['index']))

        # Store the tensor's metadata as group attributes
        grp.attrs["name"] = layer['name']
        grp.attrs["shape"] = layer['shape']
        # dtype is a NumPy type object, so store its string form
        grp.attrs["dtype"] = str(layer['dtype'])
        # (scale, zero_point) pair for quantized tensors
        grp.attrs["quantization"] = layer['quantization']

        # Store the tensor values in a dataset
        grp.create_dataset("weights", data=interpreter.get_tensor(layer['index']))
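The integers stored in a quantized tensor map back to real numbers through real = scale * (q - zero_point). A minimal NumPy sketch, assuming either per-tensor quantization (one scale and zero point) or per-axis quantization (one pair per slice along a quantized dimension); `dequantize` is a hypothetical helper name.

```python
import numpy as np


def dequantize(q, scales, zero_points, quantized_dimension=0):
    """Recover real values from quantized integers: real = scale * (q - zero_point).

    Supports per-tensor quantization (a single scale/zero point) and per-axis
    quantization (one scale/zero point per slice along `quantized_dimension`).
    """
    scales = np.asarray(scales, dtype=np.float64)
    zero_points = np.asarray(zero_points, dtype=np.int64)
    if scales.size == 0:  # empty parameters: tensor is not quantized
        return np.asarray(q)
    q = np.asarray(q).astype(np.int64)  # widen first so q - zero_point cannot wrap
    if scales.size == 1:
        real = scales.reshape(()) * (q - zero_points.reshape(()))
    else:
        # Reshape so the parameters broadcast along the quantized axis.
        bshape = [1] * q.ndim
        bshape[quantized_dimension] = -1
        real = scales.reshape(bshape) * (q - zero_points.reshape(bshape))
    return real.astype(np.float32)


# Per-tensor uint8 example: scale 0.5, zero point 128
q = np.array([0, 128, 255], dtype=np.uint8)
print(dequantize(q, [0.5], [128]).tolist())  # [-64.0, 0.0, 63.5]

# Per-axis int8 example: one scale per row (axis 0), zero points 0
w = np.array([[2, 4, 6], [4, 8, 12]], dtype=np.int8)
print(dequantize(w, [0.5, 0.25], [0, 0], quantized_dimension=0).tolist())
# [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
```

In recent TensorFlow 2.x releases, each entry returned by `get_tensor_details()` also carries a `quantization_parameters` dict with `scales`, `zero_points` and `quantized_dimension` keys, so with `p = layer['quantization_parameters']` one could call `dequantize(interpreter.get_tensor(layer['index']), p['scales'], p['zero_points'], p['quantized_dimension'])`. The older `quantization` tuple holds only a single (scale, zero_point) pair.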


Source: https://stackoverflow.com/questions/52111699/how-can-i-view-weights-in-a-tflite-file
