How to convert tflite_graph.pb to detect.tflite properly

∥☆過路亽.° submitted on 2020-06-01 05:17:07

Question


I am using the TensorFlow Object Detection API to train a custom model based on ssdlite_mobilenet_v2_coco_2018_05_09 from the TensorFlow model zoo.

I successfully trained the model and tested it with a script provided in this tutorial.

Here is the problem: I need a detect.tflite file to use on my target machine (an embedded system). But when I convert my model to .tflite, the result outputs almost nothing, and when it does detect something, the detection is wrong. To make the .tflite file, I first ran export_tflite_ssd_graph.py and then ran toco on its output with the following command, based on the docs and some Google searches:

toco --graph_def_file=$OUTPUT_DIR/tflite_graph.pb --output_file=$OUTPUT_DIR/detect.tflite --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --allow_custom_ops

Also, my code for running detection with a .tflite file works properly: I tested it with the detect.tflite file from ssd_mobilenet_v3_small_coco.
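For reference, detection code for a model exported this way reads the four tensors listed in --output_arrays. Below is a minimal plain-Python sketch of that post-processing, assuming the layout documented for export_tflite_ssd_graph.py (TFLite_Detection_PostProcess: boxes as [ymin, xmin, ymax, xmax] normalized to 0..1; :1 class ids; :2 scores; :3 number of detections) and that the batch dimension has already been removed. The name decode_detections and the 0.5 score threshold are hypothetical, and plain lists stand in for the arrays returned by the TFLite interpreter.

```python
from typing import List, Tuple


# Hypothetical helper: turns the four TFLite_Detection_PostProcess outputs
# (batch dimension already removed) into pixel-space boxes.
def decode_detections(
    boxes: List[List[float]],  # TFLite_Detection_PostProcess:   [N][4] ymin, xmin, ymax, xmax in 0..1
    classes: List[float],      # TFLite_Detection_PostProcess:1: [N] class ids (stored as floats)
    scores: List[float],       # TFLite_Detection_PostProcess:2: [N] confidence scores
    num_detections: float,     # TFLite_Detection_PostProcess:3: number of valid rows
    image_width: int,
    image_height: int,
    score_threshold: float = 0.5,  # assumed threshold, tune for your model
) -> List[Tuple[int, float, Tuple[int, int, int, int]]]:
    results = []
    # Only the first num_detections rows are valid; the remaining rows are padding.
    for i in range(int(num_detections)):
        if scores[i] < score_threshold:
            continue
        ymin, xmin, ymax, xmax = boxes[i]
        # Scale normalized coordinates back to pixels: (left, top, right, bottom).
        pixel_box = (
            round(xmin * image_width),
            round(ymin * image_height),
            round(xmax * image_width),
            round(ymax * image_height),
        )
        results.append((int(classes[i]), scores[i], pixel_box))
    return results


if __name__ == "__main__":
    boxes = [[0.1, 0.2, 0.5, 0.6], [0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
    classes = [0.0, 2.0, 0.0]
    scores = [0.9, 0.3, 0.0]
    print(decode_detections(boxes, classes, scores, 2.0, 640, 480))
```

Running it prints [(0, 0.9, (128, 48, 384, 240))]: the second row falls below the threshold and the third lies beyond num_detections.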


Answer 1:


The problem was the toco command. Some of the documents I used were outdated and misled me: toco is deprecated, and I should have used the tflite_convert tool instead.

Here is the full command I used (run from the training directory):

tflite_convert --graph_def_file tflite_inference_graph/tflite_graph.pb --output_file=./detect.tflite --output_format=TFLITE --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --inference_type=QUANTIZED_UINT8 --mean_values=128 --std_dev_values=127 --change_concat_input_ranges=false --allow_custom_ops
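The --mean_values and --std_dev_values flags tell the converter how the uint8 input relates to the float values seen in training; the TensorFlow 1.x converter documentation defines real_value = (quantized_value - mean_value) / std_dev_value. The sketch below compares that mapping, with 128 and 127, against the float preprocessing 2 * pixel / 255 - 1 that the SSD MobileNet feature extractors in the Object Detection API apply before normalized_input_image_tensor (assumed here to be this model's preprocessing). The function names are hypothetical.

```python
# Sketch: how --mean_values=128 --std_dev_values=127 map uint8 pixels to the
# float range the model expects on normalized_input_image_tensor.
MEAN_VALUE = 128.0
STD_DEV_VALUE = 127.0


def dequantize_input(q: int, mean: float = MEAN_VALUE, std: float = STD_DEV_VALUE) -> float:
    """real_value = (quantized_value - mean_value) / std_dev_value"""
    return (q - mean) / std


def float_preprocess(pixel: int) -> float:
    """Assumed float preprocessing: scale 0..255 linearly to -1..1."""
    return 2.0 * pixel / 255.0 - 1.0


if __name__ == "__main__":
    for pixel in (0, 128, 255):
        print(pixel, round(dequantize_input(pixel), 4), round(float_preprocess(pixel), 4))
```

It prints 0 -1.0079 -1.0, 128 0.0 0.0039 and 255 1.0 1.0, so the two mappings differ by less than 0.01 over the whole 0..255 range. This suggests why these values fit the model: with the quantized model, raw uint8 pixels go into the input tensor and the mean/std pair stands in for the float scaling.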

I trained the ssdlite_mobilenet_v2_coco_2018_05_09 model and added the following block at the end of my .config file:

graph_rewriter {
  quantization {
    delay: 400
    weight_bits: 8
    activation_bits: 8
  }
}
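In this block, weight_bits and activation_bits make training simulate 8-bit weights and activations, and delay postpones that simulation until step 400. The following plain-Python sketch illustrates the affine (scale and zero-point) uint8 round trip that such fake quantization emulates, for an assumed known [min, max] range. It is a simplified illustration, not the Object Detection API's actual graph rewrite (which also tracks activation ranges during training); fake_quantize and maybe_fake_quantize are hypothetical names.

```python
from typing import List


def fake_quantize(values: List[float], min_val: float, max_val: float, num_bits: int = 8) -> List[float]:
    """Quantize to num_bits unsigned integers over [min_val, max_val], then dequantize."""
    qmax = 2 ** num_bits - 1  # 255 for 8 bits
    # Widen the range so that 0.0 lies inside it and can be represented exactly.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / qmax
    if scale == 0.0:
        return [0.0 for _ in values]
    zero_point = round(-min_val / scale)  # integer level that represents real 0.0
    out = []
    for x in values:
        q = round(x / scale) + zero_point
        q = max(0, min(qmax, q))  # clamp to the unsigned num_bits range
        out.append((q - zero_point) * scale)
    return out


def maybe_fake_quantize(
    values: List[float],
    min_val: float,
    max_val: float,
    step: int,
    delay: int = 400,
    num_bits: int = 8,
) -> List[float]:
    """Simplified take on `delay`: before step `delay`, values pass through unchanged."""
    if step < delay:
        return list(values)
    return fake_quantize(values, min_val, max_val, num_bits)
```

Every output is one of at most 256 levels and 0.0 maps back exactly to 0.0; the idea behind training with this in place is that the model learns to tolerate the rounding error before it is converted to uint8.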

I also used this command to generate tflite_graph.pb in the tflite_inference_graph directory:

python export_tflite_ssd_graph.py --pipeline_config_path 2020-05-17_train_ssdlite_v2/ssd_mobilenet_v2_coco.config --trained_checkpoint_prefix 2020-05-17_train_ssdlite_v2/train/model.ckpt-1146 --output_directory 2020-05-17_train_ssdlite_v2/tflite_inference_graph --add_postprocessing_op=true
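Since delay: 400 postpones quantization during training, it is worth checking that the checkpoint passed to --trained_checkpoint_prefix comes from a later step, so the exported weights were actually trained with quantization simulated. A small sanity check, assuming the usual model.ckpt-<step> naming that the prefix above follows; checkpoint_step, quantized_steps and QUANT_DELAY are hypothetical names.

```python
import re

QUANT_DELAY = 400  # assumed to match graph_rewriter.quantization.delay in the .config


def checkpoint_step(prefix: str) -> int:
    """Extract the global step from a prefix like '.../model.ckpt-1146'."""
    match = re.search(r"model\.ckpt-(\d+)$", prefix)
    if match is None:
        raise ValueError(f"unexpected checkpoint prefix: {prefix!r}")
    return int(match.group(1))


def quantized_steps(prefix: str, delay: int = QUANT_DELAY) -> int:
    """Number of training steps that ran with quantization simulated (0 if none)."""
    return max(0, checkpoint_step(prefix) - delay)


if __name__ == "__main__":
    prefix = "2020-05-17_train_ssdlite_v2/train/model.ckpt-1146"
    print(checkpoint_step(prefix), quantized_steps(prefix))
```

For the prefix above it prints 1146 746, i.e. the checkpoint includes 746 steps of quantization-aware training.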

Note: I wanted to use a quantized model on my embedded system. That is why I added graph_rewriter to the config file and --inference_type=QUANTIZED_UINT8 to my tflite_convert command.
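To make the float versus quantized difference explicit, here is a hypothetical helper that assembles the tflite_convert command from the answer. The quantized branch reproduces the flags used above; the float branch is an assumption: for a model trained without graph_rewriter, the four quantization flags are dropped so inference_type stays at its default (FLOAT). tflite_convert_args is a made-up name, and shlex.join needs Python 3.8 or later.

```python
import shlex
from typing import List

OUTPUT_ARRAYS = ",".join([
    "TFLite_Detection_PostProcess",
    "TFLite_Detection_PostProcess:1",
    "TFLite_Detection_PostProcess:2",
    "TFLite_Detection_PostProcess:3",
])


def tflite_convert_args(graph_def_file: str, output_file: str, quantized: bool) -> List[str]:
    """Build the tflite_convert argument list, optionally for a float model."""
    args = [
        "tflite_convert",
        f"--graph_def_file={graph_def_file}",
        f"--output_file={output_file}",
        "--output_format=TFLITE",
        "--input_shapes=1,300,300,3",
        "--input_arrays=normalized_input_image_tensor",
        f"--output_arrays={OUTPUT_ARRAYS}",
    ]
    if quantized:
        # Only meaningful when the graph was trained with graph_rewriter quantization.
        args += [
            "--inference_type=QUANTIZED_UINT8",
            "--mean_values=128",
            "--std_dev_values=127",
            "--change_concat_input_ranges=false",
        ]
    # TFLite_Detection_PostProcess is a custom op, so it must be allowed explicitly.
    args.append("--allow_custom_ops")
    return args


if __name__ == "__main__":
    cmd = tflite_convert_args("tflite_inference_graph/tflite_graph.pb", "./detect.tflite", quantized=True)
    print(shlex.join(cmd))
```

Printing the quantized variant gives a command equivalent to the one above: the --graph_def_file=... form and the unquoted, comma-joined output array names are interpreted the same way by the shell and by tflite_convert.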



Source: https://stackoverflow.com/questions/61749548/how-to-convert-tflite-graph-pb-to-detect-tflite-properly