I want to use one of the pre-built Keras models (VGG, Inception, ResNet, etc.) included in tf.keras.applications for feature extraction to save me some time tra
Inside a model_fn you can only work with tensors, so you cannot simply hand it a Keras model object. You can try something like the code below instead, though it should be considered a hack. The nice part is that, apart from providing a model_fn, it also saves the weights of the loaded model as a checkpoint in the estimator's model directory, so the weights are restored from that checkpoint when you call estimator.train(...) or estimator.evaluate(...).
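import tensorflow as tf


def get_model_fn(model_dir=None, config=None):
    # Import the pretrained model (ImageNet weights, no classification head)
    base_model = tf.keras.applications.InceptionV3(
        weights='imagenet',
        include_top=False,
        input_shape=(200, 200, 3)
    )
    # model_to_estimator only accepts a compiled model; the optimizer and
    # loss here are placeholders, pick whatever your task actually needs
    base_model.compile(optimizer='adam', loss='mse')
    # same sanity check that model_to_estimator performs
    if not hasattr(base_model, 'optimizer') or not base_model.optimizer:
        raise ValueError(
            'Given keras model has not been compiled yet. '
            'Please compile first '
            'before creating the estimator.')
    # get estimator object from model; this also writes the loaded
    # weights to a checkpoint under model_dir
    keras_estimator_obj = tf.keras.estimator.model_to_estimator(
        keras_model=base_model,
        model_dir=model_dir,
        config=config,
    )
    # pull the model_fn(features, labels, mode) that we need
    # (hack: _model_fn is a private attribute)
    return keras_estimator_obj._model_fn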
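If you only need the pretrained network as a feature extractor, a possibly simpler route is to skip pulling out the private _model_fn and call predict directly on the estimator that model_to_estimator returns. The sketch below assumes a TensorFlow 2.x release that still ships tf.estimator and tf.keras.estimator (both are gone as of TF 2.16), uses a hypothetical random batch of four 200x200 images, and compiles with a placeholder optimizer and loss only because model_to_estimator refuses uncompiled models.

```python
import numpy as np
import tensorflow as tf

# Assumes TF 2.x <= 2.15: tf.estimator and tf.keras.estimator are gone in 2.16+.
base_model = tf.keras.applications.InceptionV3(
    weights='imagenet',      # downloads the ImageNet weights on first use
    include_top=False,       # drop the classifier, keep the conv feature maps
    input_shape=(200, 200, 3),
)
# model_to_estimator rejects uncompiled models; the optimizer and loss here
# are placeholders and are not used for prediction.
base_model.compile(optimizer='adam', loss='mse')

# The returned estimator warm-starts from the loaded Keras weights.
estimator = tf.keras.estimator.model_to_estimator(keras_model=base_model)


def input_fn():
    # Hypothetical dummy batch of 4 RGB images in [0, 255], scaled to [-1, 1]
    # the way InceptionV3 expects.
    images = np.random.uniform(0, 255, size=(4, 200, 200, 3)).astype(np.float32)
    images = tf.keras.applications.inception_v3.preprocess_input(images)
    # A single-input Keras model accepts a plain tensor as features.
    return tf.data.Dataset.from_tensor_slices(images).batch(2)


# predict() yields one dict per example, keyed by the Keras output name.
features = [next(iter(p.values())) for p in estimator.predict(input_fn)]
print(len(features), features[0].shape)
```

Each yielded prediction is a dict keyed by the name of the model's output layer; with include_top=False and 200x200 inputs, each value should be a 4x4x2048 feature map that you can pool or flatten for downstream use.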