文字识别中CTC损失的直觉解释

小白学视觉

共 3638字,需浏览 8分钟

 ·

2022-04-18 23:37

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

导读

非常直观的解释了文字识别中非常常用的CTC损失和解码的操作。


如果你想让计算机识别文本,神经网络(NN)是一个不错的选择,因为它们目前的表现优于所有其他方法。这里的神经网络通常由卷积层(CNN)递归层(RNN)组成,前者用于提取特征序列,后者用于通过该序列传递信息。它输出每个序列元素的字符得分,它只需要用一个矩阵来表示。现在,我们想对这个矩阵做两件事:

  1. 训练:计算神经网络的损失来训练网络

  2. 推断:解码矩阵,得到包含在输入图像中的文本

这两项任务都由CTC操作完成。手写识别系统概述如图1所示。

让我们更仔细地看看CTC操作,并讨论它是如何工作的,剖析一下它所基于的复杂公式背后的巧妙思想。最后,如果你感兴趣,我会给你一些参考资料,你可以在其中找到Python代码和(不太复杂的)公式。

图1:手写字识别的概要图


我们为什么要用CTC


当然,我们可以用文本行图像创建一个数据集,然后为图像的每个水平位置指定相应的字符,如图2所示。然后,我们可以训练一个神经网络输出每个水平位置的字符得分。然而,这种解决方案有两个问题:

  1. 在字符级标注数据集非常耗时(而且枯燥)。

  2. 我们只得到字符分数,因此需要进一步处理才能得到最终文本。一个字符可以跨越多个水平位置,例如,我们可以得到“ttooo”,因为“o”是一个宽字符,如图2所示。我们必须删除所有重复的“t”和“o”。但如果被认可的文本是“too”呢?删除所有重复的“o”会得到错误的结果。如何处理这个问题?


图2:图像中每个水平位置的标注


  1. 我们只需要告诉CTC loss函数图像中出现的文本。因此,我们忽略了图像中字符的位置和宽度。

  2. 不需要对已识别的文本进行进一步处理。

CTC如何工作

正如已经讨论过的,我们不希望在每个水平位置上对图像进行标注(从现在开始,我们将其称为time-step)。神经网络的训练将以CTC损失函数为指导。我们只将神经网络的输出矩阵和相应的ground-truth(GT)文本输入到CTC损失函数中。但是它如何知道每个字符出现在哪里呢?它不知道。相反,它尝试图像中GT文本的所有可能的对齐,并对所有的得分求和。这样,如果对齐分数的求和值很高,则GT文本的分数就很高。

文本编码
 

还有一个问题是如何编码重复的字符(还记得我们说过的“too”吗?),它通过引入一个伪字符(称为blank,但不要将它与“真正的”blank混淆,即空白字符)来解决。在下面的文本中,这个特殊字符将被表示为“-”。我们使用一个聪明的编码模式来解决重复字符的问题:在编码文本时,我们可以在任何位置插入任意多的空格,而在解码时这些空格将被删除。但是,我们必须在重复的字符之间插入空格,比如“hello”。此外,我们可以重复每个字符,只要我们喜欢。

我们来看一些例子:

  • “to”→“ttttttooo”,或“-t-o”,或“to”

  • “too”- >“ttttto - o”或“- t - o - o”,或“- o”,但不能是“too”

正如你所看到的,这个模式还允许我们轻松地创建相同文本的不同对齐,例如“t-o”和“too”以及“-to”都表示相同的文本(“to”),但是对图像的对齐不同。该神经网络被训练成输出编码文本(编码在神经网络输出矩阵中)。

损失的计算
 

我们需要计算给定对图像和GT文本的损失值来训练神经网络。你已经知道,NN在每个时间步长输出一个矩阵,其中包含每个字符的得分。一个极小矩阵如图3所示:有两个时间步长(t0, t1)和三个字符(“A”,“b”和空白“-”)。每个时间步的字符得分总和为1。

图3:神经网络的输出矩阵。字符概率用颜色编码,数值在每个矩阵条目旁边打印。细线是表示文本“a”的路径,而粗虚线是表示文本“”的唯一路径。


此外,你已经知道损失的计算方法是将GT文本的所有可能对齐的所有得分相加,这样,文本在图像中出现的位置就不重要了。

一个对齐(或path,通常在文献中称为path)的得分是通过将相应的字符得分相乘来计算的。在上面的例子中,路径“aa”的得分为0.4·0.4=0.16,路径“a-”的得分为0.4·0.6=0.24,路径“a”的得分为0.6·0.4=0.24。为了得到给定GT文本的分数,我们对与此文本对应的所有路径的分数求和。假设示例中的GT文本是“a”:我们必须计算长度为2的所有可能路径(因为矩阵有2个时间步长),即:“aa”、“a-”和“a-”。我们已经计算了这些路径的得分,所以我们只需要对它们求和,得到0.4·0.4+0.4·0.6+0.6·0.4=0.64。假设GT文本为“”,我们看到只有一条对应的路径,即“--”,得到总分0.6·0.6=0.36。

如果你仔细看,你会发现我们计算的是GT文本的概率,而不是损失。然而,损失只是概率的负对数。损失值通过神经网络反向传播,神经网络的参数根据所使用的优化器进行更新,在此不再赘述。 

解码
 

当我们有一个训练过的神经网络时,我们通常想要用它来识别以前没有见过的图像中的文本。或者用更专业的术语来说>我们希望计算给定神经网络输出矩阵的最有可能的文本。你已经知道了计算给定文本得分的方法。但这一次,我们没有得到任何文本,事实上,这正是我们要找的文本。如果只有几个时间步骤和字符,那么尝试所有可能的文本都是可行的,但是对于实际的情况下,这是不可行的。

一个简单而快速的算法是最佳路径解码,它包括两个步骤:

  1. 它通过在每个时间步中选择最可能的字符来计算最佳路径。

  2. 它首先删除重复的字符,然后从路径中删除所有空格,从而撤消编码。剩下的表示已识别的文本。

如图4所示。字符是“a”、“b”和“-”(空白)。有5个时间步骤。让我们对这个矩阵应用我们最好的路径解码器:t0最可能的字符是“a”,同样适用于t1和t2。空白字符在t3处得分最高。最后,“b”最有可能出现在t4。这给出了路径aaa-b。我们删除重复的字符,这将生成“a-b”,然后从剩余的路径中删除任何空白,这将生成文本“ab”,我们将其输出为可识别的文本。

图4:神经网络的输出矩阵,粗虚线表示最佳路径。


当然,最佳路径解码只是一种近似。对于它给出错误结果的例子很容易构造:如果你从图3中解码矩阵,你将得到“”作为可识别的文本。但是我们已经知道""的概率只有0.36而"a"的概率是0.64。然而,这种近似算法在实际应用中往往能得到较好的结果。还有一些更高级的译码器,如beam-search解码器、prefix-search解码器和令牌传递译码器,它们也利用语言结构的信息来改进解码结果。

总结

首先,我们研究了神经网络解决方案所产生的问题。然后,我们看到CTC是如何解决这些问题的。然后,我们通过研究CTC如何编码文本、如何进行损失计算以及如何解码CTC训练的神经网络的输出来研究CTC的工作原理。

这将使你可以很好地理解在TensorFlow中调用ctc_loss或ctc_greedy_decoder等函数时,幕后发生了什么。然而,当你想要自己实现CTC时,你需要了解更多的细节,尤其是想使它运行得更快。Graves等人[1]介绍了CTC运算,并给出了所有相关的数学计算。如果你对如何改进解码感兴趣,请参阅有关beam-search解码的文章。我用Python和c++实现了一些解码器和损失函数,你可以在github上找到。最后,如果你想了解如何识别(手写)文本的全局,请参阅我关于如何构建手写文本识别系统的文章。

小白团队出品:零基础精通语义分割↓

下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 30
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报