Estimator is TensorFlow's high-level API. Besides the built-in Estimators that TensorFlow ships with, users can also implement custom Estimators.

Estimator definition

The Estimator constructor looks like this:

def __init__(self,
    model_fn,  # Defines the model; builds separate graphs for training, evaluation and prediction depending on the mode.
    model_dir=None,  # Directory for checkpoints and exported models
    config=None,     # Run configuration
    params=None,     # Extra parameters for a custom Estimator
    warm_start_from=None):  # Warm-start the model from an existing checkpoint

The most important argument is model_fn, whose signature is:

def _model_fn(features,  # Features: a Tensor or a dict of Tensors
              labels,    # Labels
              mode,      # Mode
              params,    # Custom parameters, i.e. the params passed to the Estimator constructor above
              config):   # Run configuration

model_fn is called by the Estimator multiple times, and the model inside it is built with TensorFlow layers. The mode argument (ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT) indicates whether the call is for training, evaluation or prediction, and a different graph is built for each phase. model_fn returns an EstimatorSpec; with the training, loss and prediction ops it contains, the Estimator can drive training, evaluation and prediction.

EstimatorSpec is defined as follows:

def __new__(cls,
              mode,  # Mode
              predictions=None,  # Prediction Tensor or dict; required when mode is PREDICT.
              loss=None,  # Loss Tensor; required when mode is TRAIN or EVAL.
              train_op=None,  # Training op; required when mode is TRAIN.
              eval_metric_ops=None,  # dict of evaluation metric ops
              export_outputs=None,
              training_chief_hooks=None,
              training_hooks=None,
              scaffold=None,
              evaluation_hooks=None,
              prediction_hooks=None):
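
Putting the two definitions together, a minimal custom model_fn might look like the sketch below. This is only an illustrative toy linear regression; the name my_model_fn, the feature key 'x' and the params key 'learning_rate' are assumptions, not part of the Estimator source.

import tensorflow as tf

def my_model_fn(features, labels, mode, params, config):
  # Toy linear model: a single dense layer over the assumed feature 'x'.
  logits = tf.compat.v1.layers.dense(features['x'], units=1)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # PREDICT: only predictions is required.
    return tf.estimator.EstimatorSpec(mode, predictions={'logits': logits})

  loss = tf.compat.v1.losses.mean_squared_error(labels, logits)
  if mode == tf.estimator.ModeKeys.EVAL:
    # EVAL: loss is required; eval_metric_ops is optional.
    return tf.estimator.EstimatorSpec(mode, loss=loss)

  # TRAIN: loss and train_op are required.
  optimizer = tf.compat.v1.train.GradientDescentOptimizer(params['learning_rate'])
  train_op = optimizer.minimize(
      loss, global_step=tf.compat.v1.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

estimator = tf.estimator.Estimator(
    model_fn=my_model_fn,
    model_dir='/tmp/my_model',
    params={'learning_rate': 0.01})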

Training

The Estimator training interface is:

def train(self,
            input_fn,    # Returns a tuple of training features and labels
            hooks=None,  # Hooks that inject custom behavior into training
            steps=None,  # Number of additional steps to train for (incremental)
            max_steps=None,  # Total number of steps to train for (absolute)
            saving_listeners=None):
    with context.graph_mode():
      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
      loss = self._train_model(input_fn, hooks, saving_listeners)
      logging.info('Loss for final step: %s.', loss)
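
A usage sketch, reusing the hypothetical estimator and feature key 'x' from above. input_fn here returns a tf.data.Dataset of (features, labels) pairs, which Estimator accepts in place of a plain tuple; note that steps counts additional steps from the current global step, while max_steps is an absolute total.

def train_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(
      ({'x': [[1.0], [2.0], [3.0]]}, [[2.0], [4.0], [6.0]]))
  return dataset.repeat().batch(2)

estimator.train(input_fn=train_input_fn, steps=100)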

Depending on the configuration, _train_model dispatches to either distributed training or local training.

def _train_model(self, input_fn, hooks, saving_listeners):
    if self._train_distribution:
      return self._train_model_distributed(input_fn, hooks, saving_listeners)
    else:
      return self._train_model_default(input_fn, hooks, saving_listeners)
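
self._train_distribution is populated from the RunConfig, so the distributed branch is taken when a distribution strategy has been configured. A hedged sketch, reusing the hypothetical my_model_fn from above:

strategy = tf.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=strategy)
distributed_estimator = tf.estimator.Estimator(
    model_fn=my_model_fn, config=run_config)
# train() on this estimator goes through _train_model_distributed.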

Let's look at the local training implementation first.

def _train_model_default(self, input_fn, hooks, saving_listeners):
    worker_hooks = []
    with ops.Graph().as_default() as g, g.device(self._device_fn):
      random_seed.set_random_seed(self._config.tf_random_seed)
      global_step_tensor = self._create_and_assert_global_step(g)
      features, labels, input_hooks = (
          self._get_features_and_labels_from_input_fn(
              input_fn, ModeKeys.TRAIN))
      worker_hooks.extend(input_hooks)
      estimator_spec = self._call_model_fn(
          features, labels, ModeKeys.TRAIN, self.config)
      global_step_tensor = training_util.get_global_step(g)
      return self._train_with_estimator_spec(estimator_spec, worker_hooks,
                                             hooks, global_step_tensor,
                                             saving_listeners)

The flow is: create the global_step, call input_fn to obtain the training features and labels, call model_fn to build the training graph, and finally enter the training loop.

_get_features_and_labels_from_input_fn eventually calls input_fn to get the training features and labels.

with ops.device('/cpu:0'):
  return input_fn(**kwargs)

_call_model_fn invokes model_fn; note that the mode passed in is ModeKeys.TRAIN, marking the training phase.

def _call_model_fn(self, features, labels, mode, config):
    model_fn_results = self._model_fn(features=features, **kwargs)
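
Only the arguments that the model_fn actually declares are forwarded, which is why a custom model_fn may omit params or config. A simplified sketch of that filtering using inspect (the real code uses its own signature-inspection helper; call_model_fn here is a hypothetical stand-in):

import inspect

def call_model_fn(model_fn, features, labels, mode, params, config):
  # Forward only the keyword arguments that model_fn declares.
  declared = inspect.signature(model_fn).parameters
  kwargs = {}
  if 'labels' in declared:
    kwargs['labels'] = labels
  if 'mode' in declared:
    kwargs['mode'] = mode
  if 'params' in declared:
    kwargs['params'] = params
  if 'config' in declared:
    kwargs['config'] = config
  return model_fn(features=features, **kwargs)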

Now let's look at the implementation of _train_with_estimator_spec.

def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,
                                 global_step_tensor, saving_listeners):
    # Warm-start if configured.
    if self._warm_start_settings:
      warm_starting_util.warm_start(*self._warm_start_settings)
    # Create the hooks.
    worker_hooks.extend(hooks)
    worker_hooks.append(training.NanTensorHook(estimator_spec.loss))
    worker_hooks.append(training.LoggingTensorHook(...))
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    worker_hooks.extend(estimator_spec.training_hooks)
    worker_hooks.append(training.SummarySaverHook(...))
    worker_hooks.append(training.StepCounterHook(...))

    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=save_summary_steps,
        config=self._session_config,
        log_step_count_steps=log_step_count_steps) as mon_sess:
      loss = None
      any_step_done = False
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
        any_step_done = True
    if not any_step_done:
      logging.warning('Training with estimator made no steps. '
                      'Perhaps input is empty or misspecified.')
    return loss

The first part mainly creates hooks; the rest runs the training loop inside a MonitoredTrainingSession.
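
The same loop can be reproduced outside of Estimator with a few lines of tf.compat.v1 graph code. In this minimal sketch the train op and loss are trivial stand-ins for estimator_spec.train_op and estimator_spec.loss:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
  global_step = tf.compat.v1.train.get_or_create_global_step()
  # Stand-ins for estimator_spec.train_op and estimator_spec.loss.
  train_op = tf.compat.v1.assign_add(global_step, 1)
  loss = tf.cast(global_step, tf.float32)

  # MonitoredTrainingSession handles variable initialization, checkpointing,
  # summaries and hooks, mirroring what _train_with_estimator_spec sets up.
  with tf.compat.v1.train.MonitoredTrainingSession(
      hooks=[tf.compat.v1.train.StopAtStepHook(last_step=10)]) as mon_sess:
    while not mon_sess.should_stop():
      _, loss_value = mon_sess.run([train_op, loss])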

Evaluation

The evaluation interface is:

def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None,
               name=None):
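
A usage sketch, reusing the hypothetical estimator and feature key 'x' from the training example; evaluate returns a dict of metrics such as loss and global_step:

def eval_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(
      ({'x': [[4.0], [5.0]]}, [[8.0], [10.0]]))
  return dataset.batch(2)

metrics = estimator.evaluate(input_fn=eval_input_fn)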

The input_fn here has the same interface as the input_fn used for training; when called, it returns the features and labels used for evaluation. Evaluation eventually reaches the following function:

def _actual_eval(self,
                   input_fn,
                   strategy=None,
                   steps=None,
                   hooks=None,
                   checkpoint_path=None,
                   name=None):
      ...
      def _evaluate():
        (scaffold, update_op, eval_dict, all_hooks) = (
            self._evaluate_build_graph(input_fn, hooks, checkpoint_path))
        return self._evaluate_run(
            checkpoint_path=checkpoint_path,
            scaffold=scaffold,
            update_op=update_op,
            eval_dict=eval_dict,
            all_hooks=all_hooks,
            output_dir=self.eval_dir(name))
      return _evaluate()

_evaluate_build_graph is implemented as follows:

def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):
    """Builds the graph and related hooks to run evaluation."""
    (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (
          self._call_model_fn_eval(input_fn, self.config))

    all_hooks = list(input_hooks)
    all_hooks.extend(hooks)
    all_hooks.extend(list(evaluation_hooks or []))
    if scaffold and scaffold.local_init_op:
      # Create the evaluation step counter.
      evaluation._get_or_create_eval_step()  # pylint: disable=protected-access

      scaffold = monitored_session.Scaffold(
          local_init_op=control_flow_ops.group(
              scaffold.local_init_op,
              monitored_session.Scaffold.default_local_init_op()),
          copy_from_scaffold=scaffold
      )
    return scaffold, update_op, eval_dict, all_hooks

_evaluate_build_graph calls _call_model_fn_eval to build the evaluation graph, and then returns the scaffold along with the update op, eval dict and hooks.

def _call_model_fn_eval(self, input_fn, config):
    """Call model_fn for evaluation and handle return values."""
    features, labels, input_hooks = self._get_features_and_labels_from_input_fn(
        input_fn, ModeKeys.EVAL)

    estimator_spec = self._call_model_fn(
        features, labels, ModeKeys.EVAL, config)
    eval_metric_ops = _verify_and_create_loss_metric(
        estimator_spec.eval_metric_ops, estimator_spec.loss)
    update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)
    return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,
            input_hooks, update_op, eval_dict)
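
For reference, the eval_metric_ops consumed here is a dict that maps metric names to (value_tensor, update_op) pairs, typically produced by tf.compat.v1.metrics; _extract_metric_update_ops splits the pairs into update_op and eval_dict. A hedged fragment showing what the EVAL branch of the hypothetical model_fn above could return (make_eval_spec is an assumed helper):

import tensorflow as tf

def make_eval_spec(mode, labels, logits):
  loss = tf.compat.v1.losses.mean_squared_error(labels, logits)
  metrics = {
      # Each value is a (value_tensor, update_op) pair.
      'mse': tf.compat.v1.metrics.mean_squared_error(labels, logits),
      'label_mean': tf.compat.v1.metrics.mean(labels),
  }
  return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)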

The flow of _call_model_fn_eval is: obtain the evaluation features and labels from input_fn, then call model_fn to build the evaluation graph.
After _actual_eval has called _evaluate_build_graph, it goes on to call _evaluate_run:

def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,
                    all_hooks, output_dir):
    """Run evaluation."""
    eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access
        checkpoint_path=checkpoint_path,
        master=self._config.evaluation_master,
        scaffold=scaffold,
        eval_ops=update_op,
        final_ops=eval_dict,
        hooks=all_hooks,
        config=self._session_config)
    ...
def _evaluate_once(checkpoint_path,
                   master='',
                   scaffold=None,
                   eval_ops=None,
                   feed_dict=None,
                   final_ops=None,
                   final_ops_feed_dict=None,
                   hooks=None,
                   config=None):
  # Prepare eval_ops: append the op that updates the eval step counter.
  if isinstance(eval_ops, dict):
    eval_ops['update_eval_step'] = update_eval_step
  elif isinstance(eval_ops, (tuple, list)):
    eval_ops = list(eval_ops) + [update_eval_step]
  else:
    eval_ops = [eval_ops, update_eval_step]

  eval_step_value = _get_latest_eval_step_value(eval_ops)

  # Prepare the session creator.
  session_creator = monitored_session.ChiefSessionCreator(
      scaffold=scaffold,
      checkpoint_filename_with_path=checkpoint_path,
      master=master,
      config=config)

  with monitored_session.MonitoredSession(
      session_creator=session_creator, hooks=hooks) as session:
    if eval_ops is not None:
      while not session.should_stop():
        session.run(eval_ops, feed_dict)

_evaluate_once performs the actual evaluation: it first prepares the evaluation ops and then runs the evaluation loop through a MonitoredSession.

Prediction

The prediction interface and implementation, which are the simplest of the three, are shown below.

def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):
      with ops.Graph().as_default() as g:
        # Get the prediction features from `input_fn`.
        features, input_hooks = self._get_features_from_input_fn(
            input_fn, ModeKeys.PREDICT)
        estimator_spec = self._call_model_fn(
            features, None, ModeKeys.PREDICT, self.config)

        predictions = self._extract_keys(
            estimator_spec.predictions, predict_keys)
        with training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                master=self._config.master,
                scaffold=estimator_spec.scaffold,
                config=self._session_config),
            hooks=all_hooks) as mon_sess:
          while not mon_sess.should_stop():
            preds_evaluated = mon_sess.run(predictions)
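
A usage sketch, again with the hypothetical estimator and 'logits' prediction key from earlier. predict returns a generator, and with yield_single_examples=True it yields one dict per example:

def predict_input_fn():
  dataset = tf.data.Dataset.from_tensor_slices({'x': [[7.0], [8.0]]})
  return dataset.batch(2)

for pred in estimator.predict(input_fn=predict_input_fn):
  print(pred['logits'])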

Exporting the model

The last important Estimator interface is the model export interface:

def export_saved_model(
      self, export_dir_base, serving_input_receiver_fn,
      assets_extra=None,
      as_text=False,
      checkpoint_path=None,
      experimental_mode=ModeKeys.PREDICT):
    input_receiver_fn_map = {experimental_mode: serving_input_receiver_fn}
    return self._export_all_saved_models(
        export_dir_base,
        input_receiver_fn_map,
        assets_extra=assets_extra,
        as_text=as_text,
        checkpoint_path=checkpoint_path,
        strip_default_attrs=True)
def _export_all_saved_models(
      self, export_dir_base, input_receiver_fn_map,
      assets_extra=None, as_text=False, checkpoint_path=None,
      strip_default_attrs=True):
    with context.graph_mode():
      builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
      if input_receiver_fn_map.get(ModeKeys.PREDICT):
        self._add_meta_graph_for_mode(
            builder, input_receiver_fn_map, checkpoint_path,
            save_variables, mode=ModeKeys.PREDICT,
            strip_default_attrs=strip_default_attrs)
      builder.save(as_text)
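
A usage sketch of the export interface, assuming the toy model above with its single feature 'x'. The serving_input_receiver_fn defines the tensors that the exported serving signature accepts:

def serving_input_receiver_fn():
  inputs = {'x': tf.compat.v1.placeholder(tf.float32, shape=[None, 1])}
  return tf.estimator.export.ServingInputReceiver(inputs, inputs)

export_dir = estimator.export_saved_model('/tmp/exported', serving_input_receiver_fn)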

Built-in Estimators

Let's take a look at the implementation of LinearClassifierV2:

class LinearClassifierV2(estimator.EstimatorV2):
  def __init__(self,
               feature_columns,
               model_dir=None,
               n_classes=2,
               weight_column=None,
               label_vocabulary=None,
               optimizer='Ftrl',
               config=None,
               warm_start_from=None,
               loss_reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
               sparse_combiner='sum'):
    head = head_utils.binary_or_multi_class_head(
        n_classes, weight_column=weight_column,
        label_vocabulary=label_vocabulary,
        loss_reduction=loss_reduction)

    def _model_fn(features, labels, mode, config):
      """Call the defined shared _linear_model_fn."""
      return _linear_model_fn_v2(
          features=features,
          labels=labels,
          mode=mode,
          head=head,
          feature_columns=tuple(feature_columns or []),
          optimizer=optimizer,
          config=config,
          sparse_combiner=sparse_combiner)

    super(LinearClassifierV2, self).__init__(
        model_fn=_model_fn,
        model_dir=model_dir,
        config=config,
        warm_start_from=warm_start_from)
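
A usage sketch of the built-in classifier; LinearClassifierV2 is what tf.estimator.LinearClassifier resolves to in TF 2.x, and the feature name 'age' and the toy data here are assumptions:

import tensorflow as tf

age = tf.feature_column.numeric_column('age')
classifier = tf.estimator.LinearClassifier(feature_columns=[age], n_classes=2)

def input_fn():
  dataset = tf.data.Dataset.from_tensor_slices(
      ({'age': [[25.0], [62.0]]}, [0, 1]))
  return dataset.repeat().batch(2)

classifier.train(input_fn=input_fn, steps=10)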

As you can see, the implementation of a built-in Estimator is no different from that of a custom one: it also defines a model_fn and creates an Estimator instance with it.