Comparing the Objective Functions of CRF / Seq2Seq / CTC (with a CRF Loss Walkthrough)
This post summarizes the objective functions of three sequence tasks, based on the TensorFlow implementation.
1. Definition and Training of Sequence Tasks
In all three tasks, both the input and the output are sequences.
First, let's clarify how the three tasks differ:
CRF: typically used for sequence labeling, e.g. BiLSTM+CRF or IDCNN+CRF. The defining feature of this setting is that input and output positions correspond one-to-one.
Here the semantic model first produces a "score" for each token from the input (the -log of the posterior probability), which serves as the observation (emission) score during decoding.
Seq2Seq: typically used for generative QA, abstractive summarization, machine translation, and so on. It is generally an encoder-decoder architecture; its defining feature is that the input and output lengths need not be the same.
CTC decoding: a speech-recognition method where the input is audio and the output is text; its defining feature is that a single output may correspond to multiple correct alignment paths.
For CTC, see: https://distill.pub/2017/ctc/
All three are decoding problems, and because their characteristics differ, so do their objective functions:
For CRF, the objective consists of two parts: loss = unary potential + pairwise potential = ClassifyLoss + TransitionLoss (names I made up myself because I find them easier to understand; see the formula sketch right after this list), masked by the true sentence lengths.
- ClassifyLoss (unary potential): the semantic model produces a score for each token, which is essentially the -log of the predicted tag's probability over the tag vocabulary, computed here with cross entropy.
- TransitionLoss (pairwise potential): uses the transition matrix; it is the forward score of the current output tags within the decoding space (in practice, the -log of the probability).
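Written out (standard linear-chain CRF; the notation below is mine, mapping onto the two parts above), the per-sentence training loss is the negative log-likelihood:

-\log p(y \mid x) = \log Z(x) - s(x, y),
\qquad
s(x, y) = \underbrace{\sum_{t=1}^{T} e_t(y_t)}_{\text{ClassifyLoss (unary)}}
        + \underbrace{\sum_{t=1}^{T-1} A_{y_t,\, y_{t+1}}}_{\text{TransitionLoss (pairwise)}}

where e_t(\cdot) are the model's per-token scores, A is the transition matrix, T is the true sentence length (this is where the mask enters), and Z(x) = \sum_{y'} \exp s(x, y') runs over all possible tag sequences.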
For the corresponding CRF decoding methods, see: "Usage of crf_decode and viterbi_decode in TensorFlow".
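As a quick illustration of the decoding side, here is a hedged sketch assuming TF 1.x, where tf.contrib.crf.viterbi_decode is a pure-NumPy helper for one sequence (all shapes below are made up):

import numpy as np
from tensorflow.contrib import crf

seq_len, num_tags = 10, 5                                     # hypothetical sizes
score = np.random.randn(seq_len, num_tags).astype(np.float32) # unary scores per token
transition_params = np.random.randn(num_tags, num_tags).astype(np.float32)

# Viterbi search: best tag path under unary + transition scores.
viterbi_sequence, viterbi_score = crf.viterbi_decode(score, transition_params)
print(viterbi_sequence)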
Seq2Seq Loss: this one is simple. Decoding generates one token at a time, and the loss is each token's cross entropy, averaged over the whole sentence. Training usually uses one-best decoding with teacher forcing (i.e., the previous token from the gold annotation is fed in to generate the next token).
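A minimal sketch of that masked per-token cross entropy (my own toy shapes and names; decoder_logits would come from the decoder run with teacher forcing; assumes TF 1.15+/2.x):

import tensorflow as tf

batch, max_len, vocab = 2, 7, 1000                      # hypothetical sizes
decoder_logits = tf.random.normal([batch, max_len, vocab])
gold_ids = tf.random.uniform([batch, max_len], maxval=vocab, dtype=tf.int32)
lengths = tf.constant([5, 7])                           # true sentence lengths

# Cross entropy of predicting gold_ids[:, t] at each step t,
# masked to the true lengths and averaged over valid tokens.
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=gold_ids,
                                                    logits=decoder_logits)
mask = tf.sequence_mask(lengths, max_len, dtype=ce.dtype)
loss = tf.reduce_sum(ce * mask) / tf.reduce_sum(mask)   # sentence-level average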
CTC Loss: since one label sequence corresponds to multiple correct decoding paths, the loss is the -log of the total probability summed over all those paths (computed with the forward algorithm).
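A hedged sketch using tf.nn.ctc_loss (assuming TF 2.x, where it accepts dense labels together with their lengths; all sizes below are made up):

import tensorflow as tf

batch, time_steps, num_classes = 2, 50, 28              # class 0 used as blank here
logits = tf.random.normal([batch, time_steps, num_classes])  # acoustic model output
labels = tf.constant([[5, 3, 8, 0], [2, 2, 9, 4]])      # dense label ids, 0-padded
label_length = tf.constant([3, 4])
logit_length = tf.fill([batch], time_steps)

# Per example: -log of the summed probability of all valid alignment paths.
loss = tf.nn.ctc_loss(labels=labels, logits=logits,
                      label_length=label_length, logit_length=logit_length,
                      logits_time_major=False, blank_index=0)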
2. A Code-Level Look at the CRF Loss
A fairly practical CRF tutorial (PyTorch-based):
- https://towardsdatascience.com/implementing-a-linear-chain-conditional-random-field-crf-in-pytorch-16b0b9c4b4ea
Below are my notes on how the CRF loss is computed.
# Imports as in the original TF 1.x source
# (tensorflow/contrib/crf/python/ops/crf.py):
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs

def crf_log_likelihood(inputs,
                       tag_indices,
                       sequence_lengths,
                       transition_params=None):
  """Computes the log-likelihood of tag sequences in a CRF.
  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the log-likelihood.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix, if available.
  Returns:
    log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
        each example, given the sequence of tag indices.
    transition_params: A [num_tags, num_tags] transition matrix. This is either
        provided by the caller or created in this function.
  """
  # Get shape information.
  num_tags = inputs.get_shape()[2].value
  # Get the transition matrix if not provided.
  if transition_params is None:
    transition_params = vs.get_variable("transitions", [num_tags, num_tags])
  sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
                                       transition_params)
  log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
  # Normalize the scores to get the log-likelihood per example.
  log_likelihood = sequence_scores - log_norm
  return log_likelihood, transition_params
The loss computation splits into two parts:
- crf_sequence_score: computes the unnormalized score of the gold tag sequence under the current model;
- crf_log_norm: computes the log partition function, i.e. the log of the summed exp-scores over all possible tag sequences (via the forward algorithm); since everything lives in log space, the normalization is done with a subtraction.
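Before going into the parts, here is a hedged end-to-end usage sketch (assuming TF 1.x, where the function lives under tf.contrib.crf; the shapes are made up):

import tensorflow as tf  # TF 1.x assumed

batch, max_len, num_tags = 4, 10, 5                     # hypothetical sizes
unary_scores = tf.placeholder(tf.float32, [None, max_len, num_tags])
gold_tags = tf.placeholder(tf.int32, [None, max_len])
lengths = tf.placeholder(tf.int32, [None])

log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
    inputs=unary_scores, tag_indices=gold_tags, sequence_lengths=lengths)
loss = tf.reduce_mean(-log_likelihood)                  # minimize the NLL
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)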
crf_sequence_score is computed as follows:
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
  """Computes the unnormalized score for a tag sequence.
  Args:
    inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
        to use as input to the CRF layer.
    tag_indices: A [batch_size, max_seq_len] matrix of tag indices for which we
        compute the unnormalized score.
    sequence_lengths: A [batch_size] vector of true sequence lengths.
    transition_params: A [num_tags, num_tags] transition matrix.
  Returns:
    sequence_scores: A [batch_size] vector of unnormalized sequence scores.
  """
  # If max_seq_len is 1, we skip the score calculation and simply gather the
  # unary potentials of the single tag.
  def _single_seq_fn():
    batch_size = array_ops.shape(inputs, out_type=tag_indices.dtype)[0]
    example_inds = array_ops.reshape(
        math_ops.range(batch_size, dtype=tag_indices.dtype), [-1, 1])
    sequence_scores = array_ops.gather_nd(
        array_ops.squeeze(inputs, [1]),
        array_ops.concat([example_inds, tag_indices], axis=1))
    sequence_scores = array_ops.where(math_ops.less_equal(sequence_lengths, 0),
                                      array_ops.zeros_like(sequence_scores),
                                      sequence_scores)
    return sequence_scores

  def _multi_seq_fn():
    # Compute the scores of the given tag sequence.
    unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
    binary_scores = crf_binary_score(tag_indices, sequence_lengths,
                                     transition_params)
    sequence_scores = unary_scores + binary_scores
    return sequence_scores

  return utils.smart_cond(
      pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1],
                          1),
      true_fn=_single_seq_fn,
      false_fn=_multi_seq_fn)
This splits into two cases:
- sequence length 1: transition scores play no role here;
- sequence length > 1: the score is again computed in two parts: crf_unary_score + crf_binary_score.
- crf_unary_score: within the valid lengths (sequence_lengths), uses tag_indices to index into inputs and takes the gathered scores as this part of the score; this is essentially the cross-entropy part. Note that the "unary potentials" here are simply the neural network model's predictions.
- crf_binary_score: treats the current gold tag sequence as a path and scores that path with the transition matrix; maximizing the log-likelihood pushes this path score up relative to the log norm.
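To make the whole decomposition concrete, here is a minimal single-sequence NumPy sketch of the negative log-likelihood (my own illustration, not TensorFlow's code):

import numpy as np

def crf_neg_log_likelihood(emissions, transitions, tags):
    # emissions: [seq_len, num_tags] unary scores from the model;
    # transitions: [num_tags, num_tags]; tags: gold tag ids, [seq_len].
    seq_len, num_tags = emissions.shape
    # Gold-path score = unary part (crf_unary_score) + binary part (crf_binary_score).
    unary = emissions[np.arange(seq_len), tags].sum()
    binary = transitions[tags[:-1], tags[1:]].sum()
    gold_score = unary + binary
    # log Z via the forward algorithm, in log space for numerical stability.
    alpha = emissions[0]
    for t in range(1, seq_len):
        # scores[i, j] = alpha[i] + A[i, j] + e_t[j]; logsumexp over i.
        scores = alpha[:, None] + transitions + emissions[t][None, :]
        m = scores.max(axis=0)
        alpha = m + np.log(np.exp(scores - m).sum(axis=0))
    m = alpha.max()
    log_norm = m + np.log(np.exp(alpha - m).sum())
    return log_norm - gold_score  # = -(sequence_score - log_norm)

# Toy usage:
rng = np.random.RandomState(0)
print(crf_neg_log_likelihood(rng.randn(6, 4), rng.randn(4, 4),
                             np.array([0, 1, 1, 2, 3, 0])))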
Source: CSDN
Author: 泰迪宝宝
Link: https://blog.csdn.net/baobao3456810/article/details/103478739