Using a Keras model inside a TF estimator

醉梦人生 2020-12-30 16:41

I want to use one of the pre-built Keras models (VGG, Inception, ResNet, etc.) included in tf.keras.applications for feature extraction to save me some time tra

2 Answers
  •  心在旅途
    2020-12-30 16:47

    Only tensors can be used inside model_fn. As a workaround, you could try something like the following (it can be considered a hack). The nice part is that, besides providing a model_fn, this code also stores the weights of the loaded model as a checkpoint in the model directory (model_dir). That way, the weights are loaded from the checkpoint when you call estimator.train(...) or estimator.evaluate(...).

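    import tensorflow as tf

    def make_model_fn(model_dir, config=None):
        """Return the model_fn of an Estimator built from a pretrained Keras model."""

        # Import the pretrained model (ImageNet weights, no classification head)
        base_model = tf.keras.applications.InceptionV3(
            weights='imagenet',
            include_top=False,
            input_shape=(200, 200, 3)
        )

        # model_to_estimator() only accepts a compiled model.
        # Placeholder optimizer/loss: replace them with whatever your task needs.
        base_model.compile(optimizer='adam', loss='mse')

        # some check (fails if the compile step above is removed)
        if getattr(base_model, 'optimizer', None) is None:
            raise ValueError(
                'Given keras model has not been compiled yet. '
                'Please compile first '
                'before creating the estimator.')

        # Get an Estimator object from the Keras model; this also saves the
        # pretrained weights as a checkpoint under model_dir
        keras_estimator_obj = tf.keras.estimator.model_to_estimator(
            keras_model=base_model,
            model_dir=model_dir,  # directory where checkpoints are written
            config=config,        # optional tf.estimator.RunConfig
        )

        # Pull out the model_fn we need (hack: _model_fn is a private attribute)
        return keras_estimator_obj._model_fn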
    
