基于OpenVINO的手写数字识别
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
转载自微信公众号:OpenCV学堂
关注获取更多计算机视觉与深度学习知识
模型介绍
之前没有注意到,最近在OpenVINO2020R04版本的模型库中发现了它有个手写数字识别的模型,支持
handwritten-score-recognition-0003
该模型是基于LSTM双向神经网络训练,基于CTC损失,
输入格式为:[NCHW]= [1x1x32x64]
输出格式为:[WxBxL]=[16x1x13]
其中13表示"0123456789._#",#表示空白、_表示非数字的字符
对输出格式的解码方式支持CTC贪心与Beam搜索,演示程序使用CTC贪心解码,这种方式最简单,我喜欢!
代码演示
代码基于OPenVINO-Python SDK实现,首先需要说明一下,OpenVINO python SDK中主要的类是IECore,首先创建IECore实例对象,然后完成下面的流程操作:
创建实例,加载模型
1log.info("Creating Inference Engine")
2ie = IECore()
3net = ie.read_network(model=model_xml, weights=model_bin)
获取输入与输出层名称
1log.info("Preparing input blobs")
2input_it = iter(net.input_info)
3input_blob = next(input_it)
4print(input_blob)
5output_it = iter(net.outputs)
6out_blob = next(output_it)
7
8# Read and pre-process input images
9print(net.input_info[input_blob].input_data.shape)
10n, c, h, w = net.input_info[input_blob].input_data.shape
加载网络为可执行网络,
1# Loading model to the plugin
2exec_net = ie.load_network(network=net, device_name="CPU")
读取输入图像,并处理为
1ocr = cv.imread("D:/images/zsxq/ocr1.png")
2cv.imshow("input", ocr)
3gray = cv.cvtColor(ocr, cv.COLOR_BGR2GRAY)
4binary = cv.adaptiveThreshold(gray, 255, cv.ADAPTIVE_THRESH_GAUSSIAN_C, cv.THRESH_BINARY_INV, 25, 10)
5cv.imshow("binary", binary)
6contours, hireachy = cv.findContours(binary, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
7for cnt in range(len(contours)):
8 area = cv.contourArea(contours[cnt])
9 if area < 10:
10 cv.drawContours(binary, contours, cnt, (0), -1, 8)
11cv.imshow("remove noise", binary)
12
13# 获取每个分数
14temp = np.copy(binary)
15se = cv.getStructuringElement(cv.MORPH_RECT, (45, 5))
16temp = cv.dilate(temp, se)
17contours, hireachy = cv.findContours(temp, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
18for cnt in range(len(contours)):
19 x, y, iw, ih = cv.boundingRect(contours[cnt])
20 roi = gray[y:y + ih, x:x + iw]
21 image = cv.resize(roi, (w, h))
输入原图:
二值化以后:
去掉干扰之后:
推理与解析
1# Start sync inference
2log.info("Starting inference in synchronous mode")
3inf_start1 = time.time()
4res = exec_net.infer(inputs={input_blob: [img_blob]})
5inf_end1 = time.time() - inf_start1
6print("inference time(ms) : %.3f" % (inf_end1 * 1000))
7res = res[out_blob]
8
9# CTC greedy decode from here
10print(res.shape)
11# 解析输出text
12ocrstr = ""
13prev_pad = False;
14for i in range(res.shape[0]):
15 ctc = res[i] # 1x13
16 ctc = np.squeeze(ctc, 0)
17 index, prob = ctc_soft_max(ctc)
18 if digit_nums[index] == '#':
19 prev_pad = True
20 else:
21 if len(ocrstr) == 0 or prev_pad or (len(ocrstr) > 0 and digit_nums[index] != ocrstr[-1]):
22 prev_pad = False
23 ocrstr += digit_nums[index]
24cv.putText(ocr, ocrstr, (x, y-5), cv.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2, 8)
25cv.rectangle(ocr, (x, y), (x+iw, y+ih), (0, 255, 0), 2, 8, 0)
CTC贪心解析
这个上次有个哥们问我,原因居然是我很久以前写的代码,没有交代CTC贪心解析,OpenVINO的文本与数字识别均支持CTC贪心解析,这个实现非常简单,首先来看一下输出的格式[16x1x13],可以简化为[16x13],取得每个一行13列的softmax之后的最大值,或许还可以阈值一下,得到的结果就是输出,这个就是CTC贪心解析最直接的解释。不用看公式,看完你会晕倒而且写不出代码!这个函数为:
def ctc_soft_max(data):
sum = 0;
max_val = max(data)
index = np.argmax(data)
for i in range(len(data)):
sum += np.exp(data[i]- max_val)
prob = 1.0 / sum
return index, prob
最终的测试结果如下:
评论