CRF/Seq2Seq/CTC的Loss实现对比

本秂侑毒 提交于 2019-12-16 08:05:34

CRF/Seq2Seq/CTC的目标函数对比(CRF Loss解析)

这里基于TensorFlow的实现,对三种序列化的任务的目标函数做一个总结。

1. 序列化任务的定义和训练

输入输出都是序列。

先明确下三个任务的不同:

CRF:通常用于序列标注任务,比如:BiLSTM+CRF、IDCNN+CRF,场景的特点是,输入与输出是一一对应的
其中语义模型先根据输入生成每个字的“打分”(后验概率的-log),作为解码时的反向观测概率。
CRF结构
Seq2Seq:通常用于生成式问答、摘要生成、机器翻译等等,一般是一种编码器和解码器的结构,特点是:输入与输出长度不一定相同。
Seq2Seq

CTC解码:一种语音识别的方法,输入为语音,输出为文字,特点是:一种输出可能对应着多个正确的路径。
CTC可以参考:https://distill.pub/2017/ctc/
在这里插入图片描述
三个问题都是解码问题,因为特点的不同,目标函数也不一样:

对于CRF来说目标函数包含两个部分:loss = unary potential + pairwise potential = ClassifyLoss + TranstionLoss(名字是我自己编的,感觉好理解一些),然后用句子的真实长度做mask。

  • ClassifyLoss(unary potential):语义模型会生成每个字的得分,其实就是预测的tag在tags词典中的概率的-log,这里用交叉熵计算出来的。
  • TranstionLoss(pairwise potential):利用到转移矩阵,为当前的输出tags对应解码空间中的前向概率(实际是概率的-log)

在这里插入图片描述
对应的CRF的解码方法可以参考:Tensorflow 中 crf_decode 和 viterbi_decode 的使用

Seq2Seq Loss:这个比较简单,解码时每次生成一个字,这里的loss就是每个字对应的交叉熵,在整个句子上的平均,一般训练时使用one-best解码,配合teach force(就是使用标注数据的上一个字生成下一个字)。

CTC Loss:这里一个label序列不是会对应多条正确的解码路径么,这里的loss就是这些解码路径的前向概率-log求和。

2. CRF Loss代码层面分析

CRF一个比较实用的教程(基于pytorch的):

  • https://towardsdatascience.com/implementing-a-linear-chain-conditional-random-field-crf-in-pytorch-16b0b9c4b4ea

记录一下CRF LOSS的计算思路

def crf_log_likelihood(inputs,
                       tag_indices,
                       sequence_lengths,
                       transition_params=None):
  """Computes the log-likelihood of tag sequences in a CRF.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the log-likelihood.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix, if available.
  Returns:
    log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
      each example, given the sequence of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. This is either
        provided by the caller or created in this function.
  """
  # Get shape information.
  num_tags = inputs.get_shape()[2].value

  # Get the transition matrix if not provided.
  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])

  sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                       transition_params)
  log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)

  # Normalize the scores to get the log-likelihood per example.
  log_likelihood = sequence_scores - log_norm
  return log_likelihood, transition_params

loss的计算分为两个部分:

  • crf_sequence_score:计算当前预测结果与标注结果的loss
  • crf_log_norm:对loss的batch维度做归一化实用,因为是log值,所以用减法来实现;

crf_sequence_score 的计算如下:


def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Computes the unnormalized score for a tag sequence.

  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """
  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of the single tag.
  def _single_seq_fn():
    batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = array_ops.reshape(
        math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = array_ops.gather_nd(
        array_ops.squeeze(inputs, [1]),
        array_ops.concat([example_inds, tag_indices], axis=1))
    sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                                      array_ops.zeros_like(sequence_scores),
                                      sequence_scores)
    return sequence_scores

  def _multi_seq_fn():
    # Compute the scores of the given tag sequence.
    unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                     transition_params)
    sequence_scores = unary_scores + binary_scores
    return sequence_scores

  return utils.smart_cond(
      pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
                          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)

分为两种情况:

  • 序列长度为1:这里就没转移概率的事儿了
  • 序列长度不为1:也分成两个部分计算:crf_unary_score + crf_binary_score
    • crf_unary_score:计算有效长度内(sequence_lengths),使用tag_indices作为索引去inputs中检索,将检索的分数作为这一部分的loss,其实就是交叉熵了;
    • crf_binary_score:是将当前的标签序列作为path,使用当前的转译矩阵计算这个path的得分,并将这个得分最小化;

unary potentials指的是神经网络模型的预测结果。

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!