Estimator是TensorFlow的高阶API。除了TensorFlow官方定义的内置Estimator之外,用户也可以实现自定义的Estimator。
Estimator定义
Estimator的构造如下:
def __init__(self, model_fn, # 定义模型,根据不同的模式分别定义训练、评估和预测的图。 model_dir=None, # 模型导出目录 config=None, # 配置参数 params=None, # 自定义Estimator的额外参数 warm_start_from=None): # 模型热启动
其中最核心的参数为`model_fn`,其接口如下:
def _model_fn(features, # 特征,可以是Tensor或dict of Tensor labels, # 标签 mode, # 模式 params, # 自定义参数,即上面Estimator构造函数中的params config): # 配置参数
`model_fn`会被Estimator多次调用,通过调用TensorFlow的layer来实现模型。通过模式字段(ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT)来判断是训练、评估还是预测阶段,分别构造不同的图。`model_fn`的返回结构为`EstimatorSpec`,使用其中的训练、loss和预测的OP,Estimator就可以驱动完成训练、评估和预测。
EstimatorSpec的定义如下
def __new__(cls, mode, # 模式 predictions=None, # 预测的Tensor或dict,mode为PREDICT时必填。 loss=None, # loss Tensor,mode为TRAIN或EVAL时必填。 train_op=None, # 训练OP,mode为TRAIN时必填。 eval_metric_ops=None, # 评估OP的dict export_outputs=None, training_chief_hooks=None, training_hooks=None, scaffold=None, evaluation_hooks=None, prediction_hooks=None):
训练
Estimator的训练接口如下
def train(self, input_fn, # 返回训练特征和标签的tuple hooks=None, # 通过hook指定训练过程中的自定义行为 steps=None, # 训练步数 max_steps=None, ## 训练总步数 saving_listeners=None): with context.graph_mode(): hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps)) loss = self._train_model(input_fn, hooks, saving_listeners) logging.info('Loss for final step: %s.', loss)
`_train_model`根据是否配置了分布式策略(`self._train_distribution`),分别调用分布式训练或本地训练的函数。
def _train_model(self, input_fn, hooks, saving_listeners): if self._train_distribution: return self._train_model_distributed(input_fn, hooks, saving_listeners) else: return self._train_model_default(input_fn, hooks, saving_listeners)
我们先看本地训练的实现。
def _train_model_default(self, input_fn, hooks, saving_listeners): with ops.Graph().as_default() as g, g.device(self._device_fn): random_seed.set_random_seed(self._config.tf_random_seed) global_step_tensor = self._create_and_assert_global_step(g) features, labels, input_hooks = ( self._get_features_and_labels_from_input_fn( input_fn, ModeKeys.TRAIN)) worker_hooks.extend(input_hooks) estimator_spec = self._call_model_fn( features, labels, ModeKeys.TRAIN, self.config) global_step_tensor = training_util.get_global_step(g) return self._train_with_estimator_spec(estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners)
其流程为先创建global_step,然后调用`input_fn`来得到训练特征和标签,调用`model_fn`来得到训练图,最后进入training loop。
`_get_features_and_labels_from_input_fn`最终会调用`input_fn`(固定在CPU设备上执行),得到训练特征和标签。
with ops.device('/cpu:0'): return input_fn(**kwargs)
`_call_model_fn`会调用`model_fn`,注意此处传递的mode参数为`ModeKeys.TRAIN`,用于表征训练阶段。
def _call_model_fn(self, features, labels, mode, config): model_fn_results = self._model_fn(features=features, **kwargs)
下面看`_train_with_estimator_spec`的实现。
def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks, global_step_tensor, saving_listeners): # 满足条件则热启动 if self._warm_start_settings: warm_starting_util.warm_start(*self._warm_start_settings) # 创建Hook worker_hooks.extend(hooks) worker_hooks.append(training.NanTensorHook(estimator_spec.loss) worker_hooks.append(training.LoggingTensorHook(...)) saver_hooks = [ h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)] worker_hooks.extend(estimator_spec.training_hooks) worker_hooks.append(training.SummarySaverHook(...)) worker_hooks.append(training.StepCounterHook(...)) with training.MonitoredTrainingSession( master=self._config.master, is_chief=self._config.is_chief, checkpoint_dir=self._model_dir, scaffold=estimator_spec.scaffold, hooks=worker_hooks, chief_only_hooks=( tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)), save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=save_summary_steps, config=self._session_config, log_step_count_steps=log_step_count_steps) as mon_sess: loss = None any_step_done = False while not mon_sess.should_stop(): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) any_step_done = True if not any_step_done: logging.warning('Training with estimator made no steps. ' 'Perhaps input is empty or misspecified.') return loss
前面主要在创建Hook,后面使用MonitoredTrainingSession进行Training loop。
评估
评估的接口为
def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, name=None):
其中`input_fn`与训练函数中的`input_fn`接口相同,调用后返回评估用的特征和标签。评估最终会调用到下面的函数:
def _actual_eval(self, input_fn, strategy=None, steps=None, hooks=None, checkpoint_path=None, name=None): ... def _evaluate(): (scaffold, update_op, eval_dict, all_hooks) = ( self._evaluate_build_graph(input_fn, hooks, checkpoint_path)) return self._evaluate_run( checkpoint_path=checkpoint_path, scaffold=scaffold, update_op=update_op, eval_dict=eval_dict, all_hooks=all_hooks, output_dir=self.eval_dir(name)) return _evaluate()
`_evaluate_build_graph`的实现如下:
def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None): """Builds the graph and related hooks to run evaluation.""" (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = ( self._call_model_fn_eval(input_fn, self.config)) all_hooks = list(input_hooks) all_hooks.extend(hooks) all_hooks.extend(list(evaluation_hooks or [])) if scaffold and scaffold.local_init_op: # 创建评估step evaluation._get_or_create_eval_step() # pylint: disable=protected-access scaffold = monitored_session.Scaffold( local_init_op=control_flow_ops.group( scaffold.local_init_op, monitored_session.Scaffold.default_local_init_op()), copy_from_scaffold=scaffold ) return scaffold, update_op, eval_dict, all_hooks
`_evaluate_build_graph`会调用`_call_model_fn_eval`进行评估构图,然后返回scaffold、update_op、eval_dict以及汇总后的all_hooks。
def _call_model_fn_eval(self, input_fn, config): """Call model_fn for evaluation and handle return values.""" features, labels, input_hooks = self._get_features_and_labels_from_input_fn( input_fn, ModeKeys.EVAL) estimator_spec = self._call_model_fn( features, labels, ModeKeys.EVAL, config) eval_metric_ops = _verify_and_create_loss_metric( estimator_spec.eval_metric_ops, estimator_spec.loss) update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops) return (estimator_spec.scaffold, estimator_spec.evaluation_hooks, input_hooks, update_op, eval_dict)
`_call_model_fn_eval`的流程为从`input_fn`获取评估用的特征和标签,然后调用`model_fn`进行评估构图。`_actual_eval`调用完`_evaluate_build_graph`之后,接着调用`_evaluate_run`。
def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict, all_hooks, output_dir): """Run evaluation.""" eval_results = evaluation._evaluate_once( # pylint: disable=protected-access checkpoint_path=checkpoint_path, master=self._config.evaluation_master, scaffold=scaffold, eval_ops=update_op, final_ops=eval_dict, hooks=all_hooks, config=self._session_config) ...
def _evaluate_once(checkpoint_path, master='', scaffold=None, eval_ops=None, feed_dict=None, final_ops=None, final_ops_feed_dict=None, hooks=None, config=None): # 准备eval_ops if isinstance(eval_ops, dict): eval_ops['update_eval_step'] = update_eval_step elif isinstance(eval_ops, (tuple, list)): eval_ops = list(eval_ops) + [update_eval_step] else: eval_ops = [eval_ops, update_eval_step] eval_step_value = _get_latest_eval_step_value(eval_ops) # Prepare the session creator. session_creator = monitored_session.ChiefSessionCreator( scaffold=scaffold, checkpoint_filename_with_path=checkpoint_path, master=master, config=config) with monitored_session.MonitoredSession( session_creator=session_creator, hooks=hooks) as session: if eval_ops is not None: while not session.should_stop(): session.run(eval_ops, feed_dict)
`_evaluate_once`执行最终的评估逻辑,先准备好评估用的ops,然后通过MonitoredSession执行评估的loop。
预测
预测的接口和实现如下,相对最为简单。
def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None, yield_single_examples=True): with ops.Graph().as_default() as g: # 从`input_fn`获取预测用的特征。 features, input_hooks = self._get_features_from_input_fn( input_fn, ModeKeys.PREDICT) estimator_spec = self._call_model_fn( features, None, ModeKeys.PREDICT, self.config) predictions = self._extract_keys( estimator_spec.predictions, predict_keys) with training.MonitoredSession( session_creator=training.ChiefSessionCreator( checkpoint_filename_with_path=checkpoint_path, master=self._config.master, scaffold=estimator_spec.scaffold, config=self._session_config), hooks=all_hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions)
导出模型
Estimator最后一个重要接口为导出模型接口:
def export_saved_model( self, export_dir_base, serving_input_receiver_fn, assets_extra=None, as_text=False, checkpoint_path=None, experimental_mode=ModeKeys.PREDICT): input_receiver_fn_map = {experimental_mode: serving_input_receiver_fn} return self._export_all_saved_models( export_dir_base, input_receiver_fn_map, assets_extra=assets_extra, as_text=as_text, checkpoint_path=checkpoint_path, strip_default_attrs=True)
def _export_all_saved_models( self, export_dir_base, input_receiver_fn_map, assets_extra=None, as_text=False, checkpoint_path=None, strip_default_attrs=True): with context.graph_mode(): builder = saved_model_builder.SavedModelBuilder(temp_export_dir) if input_receiver_fn_map.get(ModeKeys.PREDICT): self._add_meta_graph_for_mode( builder, input_receiver_fn_map, checkpoint_path, save_variables, mode=ModeKeys.PREDICT, strip_default_attrs=strip_default_attrs) builder.save(as_text)
内置Estimator
我们看一下LinearClassifierV2的实现
class LinearClassifierV2(estimator.EstimatorV2): def __init__(self, feature_columns, model_dir=None, n_classes=2, weight_column=None, label_vocabulary=None, optimizer='Ftrl', config=None, warm_start_from=None, loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, sparse_combiner='sum'): head = head_utils.binary_or_multi_class_head( n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary, loss_reduction=loss_reduction) def _model_fn(features, labels, mode, config): """Call the defined shared _linear_model_fn.""" return _linear_model_fn_v2( features=features, labels=labels, mode=mode, head=head, feature_columns=tuple(feature_columns or []), optimizer=optimizer, config=config, sparse_combiner=sparse_combiner) super(LinearClassifierV2, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config, warm_start_from=warm_start_from)
可以看到内置Estimator的实现和自定义Estimator的实现没什么区别,也是通过实现model_fn并创建Estimator实例得到的。