基于OpenVINO的手写数字识别

小白学视觉

共 2243字,需浏览 5分钟

 ·

2020-10-26 09:55

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

转载自微信公众号OpenCV学堂

关注获取更多计算机视觉与深度学习知识

模型介绍

之前没有注意到,最近在OpenVINO2020R04版本的模型库中发现了它有个手写数字识别的模型,支持 or . 格式的数字识别与小数点识别。相关的模型为:

handwritten-score-recognition-0003

该模型是基于LSTM双向神经网络训练,基于CTC损失,

输入格式为:[NCHW]= [1x1x32x64]输出格式为:[WxBxL]=[16x1x13]

其中13表示"0123456789._#",#表示空白、_表示非数字的字符

对输出格式的解码方式支持CTC贪心与Beam搜索,演示程序使用CTC贪心解码,这种方式最简单,我喜欢!

代码演示

代码基于OPenVINO-Python SDK实现,首先需要说明一下,OpenVINO python SDK中主要的类是IECore,首先创建IECore实例对象,然后完成下面的流程操作:

创建实例,加载模型

1log.info("Creating Inference Engine")
2ie = IECore()
3net = ie.read_network(model=model_xml, weights=model_bin)


获取输入与输出层名称

 1log.info("Preparing input blobs")
2input_it = iter(net.input_info)
3input_blob = next(input_it)
4print(input_blob)
5output_it = iter(net.outputs)
6out_blob = next(output_it)
7
8# Read and pre-process input images
9print(net.input_info[input_blob].input_data.shape)
10n, c, h, w = net.input_info[input_blob].input_data.shape


加载网络为可执行网络,

1# Loading model to the plugin
2exec_net = ie.load_network(network=net, device_name="CPU")


读取输入图像,并处理为 or ., 格式,代码实现如下:

 1ocr = cv.imread("D:/images/zsxq/ocr1.png")
2cv.imshow("input", ocr)
3gray = cv.cvtColor(ocr, cv.COLOR_BGR2GRAY)
4binary = cv.adaptiveThreshold(gray, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY_INV, 2510)
5cv.imshow("binary", binary)
6contours, hireachy = cv.findContours(binary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
7for cnt in range(len(contours)):
8    area = cv.contourArea(contours[cnt])
9    if area < 10:
10        cv.drawContours(binary, contours, cnt, (0), -18)
11cv.imshow("remove noise", binary)
12
13# 获取每个分数
14temp = np.copy(binary)
15se = cv.getStructuringElement(cv.MORPH_RECT, (455))
16temp = cv.dilate(temp, se)
17contours, hireachy = cv.findContours(temp, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
18for cnt in range(len(contours)):
19    x, y, iw, ih = cv.boundingRect(contours[cnt])
20    roi = gray[y:y + ih, x:x + iw]
21    image = cv.resize(roi, (w, h))

输入原图:

二值化以后:

去掉干扰之后:


推理与解析

 1# Start sync inference
2log.info("Starting inference in synchronous mode")
3inf_start1 = time.time()
4res = exec_net.infer(inputs={input_blob: [img_blob]})
5inf_end1 = time.time() - inf_start1
6print("inference time(ms) : %.3f" % (inf_end1 * 1000))
7res = res[out_blob]
8
9# CTC greedy decode from here
10print(res.shape)
11# 解析输出text
12ocrstr = ""
13prev_pad = False;
14for i in range(res.shape[0]):
15    ctc = res[i] # 1x13
16    ctc = np.squeeze(ctc, 0)
17    index, prob = ctc_soft_max(ctc)
18    if digit_nums[index] == '#':
19        prev_pad = True
20    else:
21        if len(ocrstr) == 0 or prev_pad or (len(ocrstr) > 0 and digit_nums[index] != ocrstr[-1]):
22            prev_pad = False
23            ocrstr += digit_nums[index]
24cv.putText(ocr, ocrstr, (x, y-5), cv.FONT_HERSHEY_SIMPLEX, 0.75, (00255), 28)
25cv.rectangle(ocr, (x, y), (x+iw, y+ih), (02550), 280)


CTC贪心解析

这个上次有个哥们问我,原因居然是我很久以前写的代码,没有交代CTC贪心解析,OpenVINO的文本与数字识别均支持CTC贪心解析,这个实现非常简单,首先来看一下输出的格式[16x1x13],可以简化为[16x13],取得每个一行13列的softmax之后的最大值,或许还可以阈值一下,得到的结果就是输出,这个就是CTC贪心解析最直接的解释。不用看公式,看完你会晕倒而且写不出代码!这个函数为:

def ctc_soft_max(data):    sum = 0;    max_val = max(data)    index = np.argmax(data)    for i in range(len(data)):        sum += np.exp(data[i]- max_val)    prob = 1.0 / sum    return index, prob

最终的测试结果如下:


浏览 35
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报