完整的文本检测与识别 | 附源码
共 33226字,需浏览 67分钟
·
2024-07-16 10:12
点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达
另外,你还记得每家店铺都有独特的名字书写方式吗?像Gucci、Sears、Pantaloons和Lifestyle这样的知名品牌在其商标中使用了曲线或圆形字体。虽然这一切吸引了顾客,但对于执行文本检测和识别的深度学习(DL)模型来说,它确实提出了挑战。
当你读取横幅上的文字时,你会怎么做?你的眼睛首先会检测到文本的存在,找出每个字符的位置,然后识别这些字符。这正是一个DL模型需要做的!最近,OCR在深度学习中成为热门话题,其中每个新架构都在努力超越其他架构。
流行的基于深度学习的OCR模块Tesseract在结构化文本(如文件)上表现出色,但在花哨字体的曲线、不规则形状的文本方面却表现不佳。幸运的是,我们有Clova AI提供的这些出色的网络,它们在真实世界中出现的各种文本外观方面胜过了Tesseract。在本博客中,我们将简要讨论这些架构并学习如何将它们整合起来。
使用CRAFT进行文本检测
场景文本检测是在复杂背景中检测文本区域并用边界框标记它们的任务。CRAFT是一项2019年提出的主要目标是定位单个字符区域并将检测到的字符链接到文本实例的全称:Character-Region Awareness For Text detection。
CRAFT采用了基于VGG-16的全卷积网络架构。简单来说,VGG16本质上是特征提取架构,用于将网络的输入编码成某种特征表示。CRAFT网络的解码段类似于UNet。它具有聚合低级特征的跳跃连接。CRAFT为每个字符预测两个分数:
区域分数:顾名思义,它给出了字符的区域。它定位字符。
亲和力分数:'亲和力'是指物质倾向于与另一种物质结合的程度。
因此,亲和力分数将字符合并为单个实例(一个词)。CRAFT生成两个地图作为输出:区域级地图和亲和力地图。让我们通过示例来理解它们的含义:
输入图像
存在字符的区域在区域地图中标记出来:
区域地图
亲和力地图以图形方式表示相关字符。红色表示字符具有较高的亲和力,必须合并为一个词:
亲和力地图
最后,将亲和力分数和区域分数组合起来,给出每个单词的边界框。坐标的顺序是:(左上)、(右上)、(右下)、(左下),其中每个坐标都是一个(x,y)对。
为什么不按照四点格式?
看下面的图片:你能在仅有4个值的情况下定位“LOVE”吗?
CRAFT是多语言的,这意味着它可以检测任何脚本中的文本。
文本识别:四阶段场景文本识别框架
2019年,Clova AI发表了一篇关于现有场景文本识别(STR)数据集的不一致性,并提出了一个大多数现有STR模型都适用的统一框架的研究论文。
让我们讨论这四个阶段:
转换:记住我们正在处理的是景观文本,它是任意形状和曲线的。如果我们直接进行特征提取,那么它需要学习输入文本的几何形状,这对于特征提取模块来说是额外的工作。因此,STR网络应用了薄板样条(TPS)变换,并将输入文本规范化为矩形形状。
特征提取:将变换后的图像映射到与字符识别相关的一组特征上。字体、颜色、大小和背景都被丢弃了。作者对不同的骨干网络进行了实验,包括ResNet、VGG和RCNN。
序列建模:如果我写下'ba_',你很可能猜到填在空格处的字母可能是'd'、'g'、't',而不是'u'、'p'。我们如何教网络捕捉上下文信息?使用BiLSTMs!但是,BiLSTMs会占用内存,因此用户可以根据需要选择或取消这个阶段。
预测:这个阶段从图像的已识别特征中估计输出字符序列。
作者进行了几个实验。他们为每个阶段选择了不同的网络。准确性总结在下表中:
代码
CRAFT预测每个单词的边界框。四阶段STR将单个单词(作为图像)作为输入,并预测字母。如果你正在处理单个字的图像(如CUTE80),使用这些DL模块的OCR将会很轻松。
步骤1:安装要求
步骤2:克隆代码库
步骤3:修改以返回检测框分数
CRAFT返回高于一定分数阈值的边界框。如果你想看到每个边界框的分数值,我们需要对原始库进行一些更改。打开克隆的CRAFT Repository中的craft_utils.py文件。你需要将第83行和第239行更改为如下所示。
"""Modify to Return Scores of Detection Boxes""""""Copyright (c) 2019-present NAVER Corp.MIT License"""# -*- coding: utf-8 -*-import numpy as npimport cv2import math""" auxilary functions """# unwarp corodinatesdef warpCoord(Minv, pt):out = np.matmul(Minv, (pt[0], pt[1], 1))return np.array([out[0]/out[2], out[1]/out[2]])""" end of auxilary functions """def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):# prepare datalinkmap = linkmap.copy()textmap = textmap.copy()img_h, img_w = textmap.shape""" labeling method """ret, text_score = cv2.threshold(textmap, low_text, 1, 0)ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)text_score_comb = np.clip(text_score + link_score, 0, 1)nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)det = []det_scores = []mapper = []for k in range(1,nLabels):# size filteringsize = stats[k, cv2.CC_STAT_AREA]if size < 10: continue# thresholdingif np.max(textmap[labels==k]) < text_threshold: continue# make segmentation mapsegmap = np.zeros(textmap.shape, dtype=np.uint8)segmap[labels==k] = 255segmap[np.logical_and(link_score==1, text_score==0)] = 0 # remove link areax, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1# boundary checkif sx < 0 : sx = 0if sy < 0 : sy = 0if ex >= img_w: ex = img_wif ey >= img_h: ey = img_hkernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)# make boxnp_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)rectangle = cv2.minAreaRect(np_contours)box = cv2.boxPoints(rectangle)# align diamond-shapew, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])box_ratio = max(w, h) / (min(w, h) + 1e-5)if abs(1 - box_ratio) <= 0.1:l, r = min(np_contours[:,0]), max(np_contours[:,0])t, b = min(np_contours[:,1]), max(np_contours[:,1])box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)# make clock-wise orderstartidx = box.sum(axis=1).argmin()box = np.roll(box, 4-startidx, 0)box = np.array(box)det.append(box)mapper.append(k)det_scores.append(np.max(textmap[labels==k]))return det, labels, mapper, det_scoresdef getPoly_core(boxes, labels, mapper, linkmap):# configsnum_cp = 5max_len_ratio = 0.7expand_ratio = 1.45max_r = 2.0step_r = 0.2polys = []for k, box in enumerate(boxes):# size filter for small instancew, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)if w < 10 or h < 10:polys.append(None); continue# warp imagetar = np.float32([[0,0],[w,0],[w,h],[0,h]])M = cv2.getPerspectiveTransform(box, tar)word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)try:Minv = np.linalg.inv(M)except:polys.append(None); continue# binarization for selected labelcur_label = mapper[k]word_label[word_label != cur_label] = 0word_label[word_label > 0] = 1""" Polygon generation """# find top/bottom contourscp = []max_len = -1for i in range(w):region = np.where(word_label[:,i] != 0)[0]if len(region) < 2 : continuecp.append((i, region[0], region[-1]))length = region[-1] - region[0] + 1if length > max_len: max_len = length# pass if max_len is similar to hif h * max_len_ratio < max_len:polys.append(None); continue# get pivot points with fixed lengthtot_seg = num_cp * 2 + 1seg_w = w / tot_seg # segment widthpp = [None] * num_cp # init pivot pointscp_section = [[0, 0]] * tot_segseg_height = [0] * num_cpseg_num = 0num_sec = 0prev_h = -1for i in range(0,len(cp)):(x, sy, ey) = cp[i]if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:# average previous segmentif num_sec == 0: breakcp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]num_sec = 0# reset variablesseg_num += 1prev_h = -1# accumulate center pointscy = (sy + ey) * 0.5cur_h = ey - sy + 1cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]num_sec += 1if seg_num % 2 == 0: continue # No polygon areaif prev_h < cur_h:pp[int((seg_num - 1)/2)] = (x, cy)seg_height[int((seg_num - 1)/2)] = cur_hprev_h = cur_h# processing last segmentif num_sec != 0:cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]# pass if num of pivots is not sufficient or segment widh is smaller than character heightif None in pp or seg_w < np.max(seg_height) * 0.25:polys.append(None); continue# calc median maximum of pivot pointshalf_char_h = np.median(seg_height) * expand_ratio / 2# calc gradiant and apply to make horizontal pivotsnew_pp = []for i, (x, cy) in enumerate(pp):dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]if dx == 0: # gradient if zeronew_pp.append([x, cy - half_char_h, x, cy + half_char_h])continuerad = - math.atan2(dy, dx)c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)new_pp.append([x - s, cy - c, x + s, cy + c])# get edge points to cover character heatmapsisSppFound, isEppFound = False, Falsegrad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])for r in np.arange(0.5, max_r, step_r):dx = 2 * half_char_h * rif not isSppFound:line_img = np.zeros(word_label.shape, dtype=np.uint8)dy = grad_s * dxp = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:spp = pisSppFound = Trueif not isEppFound:line_img = np.zeros(word_label.shape, dtype=np.uint8)dy = grad_e * dxp = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:epp = pisEppFound = Trueif isSppFound and isEppFound:break# pass if boundary of polygon is not foundif not (isSppFound and isEppFound):polys.append(None); continue# make final polygonpoly = []poly.append(warpCoord(Minv, (spp[0], spp[1])))for p in new_pp:poly.append(warpCoord(Minv, (p[0], p[1])))poly.append(warpCoord(Minv, (epp[0], epp[1])))poly.append(warpCoord(Minv, (epp[2], epp[3])))for p in reversed(new_pp):poly.append(warpCoord(Minv, (p[2], p[3])))poly.append(warpCoord(Minv, (spp[2], spp[3])))# add to final resultpolys.append(np.array(poly))return polysdef getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):boxes, labels, mapper, det_scores = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)if poly:polys = getPoly_core(boxes, labels, mapper, linkmap)else:polys = [None] * len(boxes)return boxes, polys, det_scoresdef adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):if len(polys) > 0:polys = np.array(polys)for k in range(len(polys)):if polys[k] is not None:polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)return polys
步骤4:从CRAFT中删除参数解析器
打开test.py并修改如下所示。我们删除了参数解析器。
"""Modify to Remove Argument Parser""""""Copyright (c) 2019-present NAVER Corp.MIT License"""# -*- coding: utf-8 -*-import sysimport osimport timeimport argparseimport torchimport torch.nn as nnimport torch.backends.cudnn as cudnnfrom torch.autograd import Variablefrom PIL import Imageimport cv2from skimage import ioimport numpy as npimport craft_utilsimport imgprocimport file_utilsimport jsonimport zipfilefrom craft import CRAFTfrom collections import OrderedDictdef copyStateDict(state_dict):if list(state_dict.keys())[0].startswith("module"):start_idx = 1else:start_idx = 0new_state_dict = OrderedDict()for k, v in state_dict.items():name = ".".join(k.split(".")[start_idx:])new_state_dict[name] = vreturn new_state_dictdef test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, args, refine_net=None):t0 = time.time()# resizeimg_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)ratio_h = ratio_w = 1 / target_ratio# preprocessingx = imgproc.normalizeMeanVariance(img_resized)x = torch.from_numpy(x).permute(2, 0, 1) # [h, w, c] to [c, h, w]x = Variable(x.unsqueeze(0)) # [c, h, w] to [b, c, h, w]if cuda:x = x.cuda()# forward passwith torch.no_grad():y, feature = net(x)# make score and link mapscore_text = y[0,:,:,0].cpu().data.numpy()score_link = y[0,:,:,1].cpu().data.numpy()# refine linkif refine_net is not None:with torch.no_grad():y_refiner = refine_net(y, feature)score_link = y_refiner[0,:,:,0].cpu().data.numpy()t0 = time.time() - t0t1 = time.time()# Post-processingboxes, polys, det_scores = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)# coordinate adjustmentboxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)for k in range(len(polys)):if polys[k] is None: polys[k] = boxes[k]t1 = time.time() - t1# render results (optional)render_img = score_text.copy()render_img = np.hstack((render_img, score_link))ret_score_text = imgproc.cvt2HeatmapImg(render_img)if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))return boxes, polys, ret_score_text, det_scores
步骤5:编写一个单独的脚本,将图像名称和检测框坐标保存到CSV文件中
这将帮助我们裁剪需要作为四阶段STR输入的单词。它还帮助我们将所有与边界框和文本相关的信息保存在一个地方。创建一个新文件(我将其命名为pipeline.py)并添加以下代码。
import sysimport osimport timeimport argparseimport torchimport torch.nn as nnimport torch.backends.cudnn as cudnnfrom torch.autograd import Variablefrom PIL import Imageimport cv2from skimage import ioimport numpy as npimport craft_utilsimport testimport imgprocimport file_utilsimport jsonimport zipfileimport pandas as pdfrom craft import CRAFTfrom collections import OrderedDictfrom google.colab.patches import cv2_imshowdef str2bool(v):return v.lower() in ("yes", "y", "true", "t", "1")#CRAFTparser = argparse.ArgumentParser(description='CRAFT Text Detection')parser.add_argument('--trained_model', default='weights/craft_mlt_25k.pth', type=str, help='pretrained model')parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')parser.add_argument('--link_threshold', default=0.4, type=float, help='link confidence threshold')parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda for inference')parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')parser.add_argument('--test_folder', default='/data/', type=str, help='folder path to input images')parser.add_argument('--refine', default=False, action='store_true', help='enable link refiner')parser.add_argument('--refiner_model', default='weights/craft_refiner_CTW1500.pth', type=str, help='pretrained refiner model')args = parser.parse_args()""" For test images in a folder """image_list, _, _ = file_utils.get_files(args.test_folder)image_names = []image_paths = []#CUSTOMISE STARTstart = args.test_folderfor num in range(len(image_list)):image_names.append(os.path.relpath(image_list[num], start))result_folder = './Results'if not os.path.isdir(result_folder):os.mkdir(result_folder)if __name__ == '__main__':data=pd.DataFrame(columns=['image_name', 'word_bboxes', 'pred_words', 'align_text'])data['image_name'] = image_names# load netnet = CRAFT() # initializeprint('Loading weights from checkpoint (' + args.trained_model + ')')if args.cuda:net.load_state_dict(test.copyStateDict(torch.load(args.trained_model)))else:net.load_state_dict(test.copyStateDict(torch.load(args.trained_model, map_location='cpu')))if args.cuda:net = net.cuda()net = torch.nn.DataParallel(net)cudnn.benchmark = Falsenet.eval()# LinkRefinerrefine_net = Noneif args.refine:from refinenet import RefineNetrefine_net = RefineNet()print('Loading weights of refiner from checkpoint (' + args.refiner_model + ')')if args.cuda:refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model)))refine_net = refine_net.cuda()refine_net = torch.nn.DataParallel(refine_net)else:refine_net.load_state_dict(copyStateDict(torch.load(args.refiner_model, map_location='cpu')))refine_net.eval()args.poly = Truet = time.time()# load datafor k, image_path in enumerate(image_list):print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')image = imgproc.loadImage(image_path)bboxes, polys, score_text, det_scores = test.test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly, args, refine_net)bbox_score={}for box_num in range(len(bboxes)):key = str (det_scores[box_num])item = bboxes[box_num]bbox_score[key]=itemdata['word_bboxes'][k]=bbox_score# save score textfilename, file_ext = os.path.splitext(os.path.basename(image_path))mask_file = result_folder + "/res_" + filename + '_mask.jpg'cv2.imwrite(mask_file, score_text)file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)data.to_csv('/content/Pipeline/data.csv', sep = ',', na_rep='Unknown')print("elapsed time : {}s".format(time.time() - t))
pandas DataFrame(变量data)在单独的列中存储图像名称和其中包含的单词的边界框。我们去掉了图像的完整路径,只保留了图像,以避免笨拙。你当然可以根据自己的需要进行定制。现在可以运行脚本了:
在这个阶段,CSV看起来像这样。对于每个检测,我们都存储了一个包含分数:坐标的Python字典。
步骤6:裁剪单词
现在我们有了每个框的坐标和分数。我们可以设置一个阈值,裁剪我们希望识别字符的单词。创建一个新脚本crop_images.py。请记住,在提到的地方添加你的路径。裁剪的单词保存在'dir'文件夹中。我们为每个图像创建一个文件夹,并以以下格式保存从中裁剪的单词:<父图像>_<由下划线分隔的8个坐标> 这样做可以帮助你跟踪每个裁剪单词来自哪个图像。
import osimport numpy as npimport cv2import pandas as pdfrom google.colab.patches import cv2_imshowdef crop(pts, image):"""Takes inputs as 8 pointsand Returns cropped, masked image with a white background"""rect = cv2.boundingRect(pts)x,y,w,h = rectcropped = image[y:y+h, x:x+w].copy()pts = pts - pts.min(axis=0)mask = np.zeros(cropped.shape[:2], np.uint8)cv2.drawContours(mask, [pts], -1, (255, 255, 255), -1, cv2.LINE_AA)dst = cv2.bitwise_and(cropped, cropped, mask=mask)bg = np.ones_like(cropped, np.uint8)*255cv2.bitwise_not(bg,bg, mask=mask)dst2 = bg + dstreturn dst2def generate_words(image_name, score_bbox, image):num_bboxes = len(score_bbox)for num in range(num_bboxes):bbox_coords = score_bbox[num].split(':')[-1].split(',\n')if bbox_coords!=['{}']:l_t = float(bbox_coords[0].strip(' array([').strip(']').split(',')[0])t_l = float(bbox_coords[0].strip(' array([').strip(']').split(',')[1])r_t = float(bbox_coords[1].strip(' [').strip(']').split(',')[0])t_r = float(bbox_coords[1].strip(' [').strip(']').split(',')[1])r_b = float(bbox_coords[2].strip(' [').strip(']').split(',')[0])b_r = float(bbox_coords[2].strip(' [').strip(']').split(',')[1])l_b = float(bbox_coords[3].strip(' [').strip(']').split(',')[0])b_l = float(bbox_coords[3].strip(' [').strip(']').split(',')[1].strip(']'))pts = np.array([[int(l_t), int(t_l)], [int(r_t) ,int(t_r)], [int(r_b) , int(b_r)], [int(l_b), int(b_l)]])if np.all(pts) > 0:word = crop(pts, image)folder = '/'.join( image_name.split('/')[:-1])dir = '/content/Pipeline/Crop Words/'if os.path.isdir(os.path.join(dir + folder)) == False :os.makedirs(os.path.join(dir + folder))try:file_name = os.path.join(dir + image_name)cv2.imwrite(file_name+'_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t ,t_r, r_b , b_r ,l_b, b_l), word)print('Image saved to '+file_name+'_{}_{}_{}_{}_{}_{}_{}_{}.jpg'.format(l_t, t_l, r_t ,t_r, r_b , b_r ,l_b, b_l))except:continuedata=pd.read_csv('PATH TO CSV')start = PATH TO TEST IMAGESfor image_num in range(data.shape[0]):image = cv2.imread(os.path.join(start, data['image_name'][image_num]))image_name = data['image_name'][image_num].strip('.jpg')score_bbox = data['word_bboxes'][image_num].split('),')generate_words(image_name, score_bbox, image)
运行脚本:
步骤6:识别(最后!)
现在你可以在裁剪的单词上盲目运行识别模块了。但如果你想让事情更有条理,修改如下所示。我们在每个图像文件夹中创建一个.txt文件,并将识别的单词与裁剪图像的名称一起保存。除此之外,预测的单词也保存在我们维护的CSV中。
import stringimport argparseimport torchimport torch.backends.cudnn as cudnnimport torch.utils.dataimport torch.nn.functional as Ffrom utils import CTCLabelConverter, AttnLabelConverterfrom dataset import RawDataset, AlignCollatefrom model import Modeldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')import pandas as pdimport osdef demo(opt):"""Open csv file wherein you are going to write the Predicted Words"""data = pd.read_csv('/content/Pipeline/data.csv')""" model configuration """if 'CTC' in opt.Prediction:converter = CTCLabelConverter(opt.character)else:converter = AttnLabelConverter(opt.character)opt.num_class = len(converter.character)if opt.rgb:opt.input_channel = 3model = Model(opt)print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,opt.SequenceModeling, opt.Prediction)model = torch.nn.DataParallel(model).to(device)# load modelprint('loading pretrained model from %s' % opt.saved_model)model.load_state_dict(torch.load(opt.saved_model, map_location=device))# prepare data. two demo images from https://github.com/bgshih/crnn#run-demoAlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDatasetdemo_loader = torch.utils.data.DataLoader(demo_data, batch_size=opt.batch_size,shuffle=False,num_workers=int(opt.workers),collate_fn=AlignCollate_demo, pin_memory=True)# predictmodel.eval()with torch.no_grad():for image_tensors, image_path_list in demo_loader:batch_size = image_tensors.size(0)image = image_tensors.to(device)# For max length predictionlength_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)if 'CTC' in opt.Prediction:preds = model(image, text_for_pred)# Select max probabilty (greedy decoding) then decode index to characterpreds_size = torch.IntTensor([preds.size(1)] * batch_size)_, preds_index = preds.max(2)# preds_index = preds_index.view(-1)preds_str = converter.decode(preds_index.data, preds_size.data)else:preds = model(image, text_for_pred, is_train=False)# select max probabilty (greedy decoding) then decode index to character_, preds_index = preds.max(2)preds_str = converter.decode(preds_index, length_for_pred)dashed_line = '-' * 80head = f'{"image_path":25s}\t {"predicted_labels":25s}\t confidence score'print(f'{dashed_line}\n{head}\n{dashed_line}')# log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')preds_prob = F.softmax(preds, dim=2)preds_max_prob, _ = preds_prob.max(dim=2)for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):start = PATH TO CROPPED WORDSpath = os.path.relpath(img_name, start)folder = os.path.dirname(path)image_name=os.path.basename(path)file_name='_'.join(image_name.split('_')[:-8])txt_file=os.path.join(start, folder, file_name)log = open(f'{txt_file}_log_demo_result_vgg.txt', 'a')if 'Attn' in opt.Prediction:pred_EOS = pred.find('[s]')pred = pred[:pred_EOS] # prune after "end of sentence" token ([s])pred_max_prob = pred_max_prob[:pred_EOS]# calculate confidence score (= multiply of pred_max_prob)confidence_score = pred_max_prob.cumprod(dim=0)[-1]print(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}')log.write(f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}\n')log.close()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images')parser.add_argument('--workers', type=int, help='number of data loading workers', default=4)parser.add_argument('--batch_size', type=int, default=192, help='input batch size')parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation")""" Data processing """parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')parser.add_argument('--rgb', action='store_true', help='use rgb input')parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')""" Model Architecture """parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet')parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor')parser.add_argument('--output_channel', type=int, default=512,help='the number of output channel of Feature extractor')parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')opt = parser.parse_args()""" vocab / character number configuration """if opt.sensitive:opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).cudnn.benchmark = Truecudnn.deterministic = Trueopt.num_gpu = torch.cuda.device_count()# print (opt.image_folder)# pred_words=demo(opt)demo(opt)
从Clova AI STR Github Repository下载权重后,你可以运行以下命令:
我们选择了这种网络组合,因为它们的准确性很高。现在CSV看起来是这样的。pred_words有检测框坐标和预测的单词,用冒号分隔。
结论
我们已经集成了两个准确的模型,创建了一个单一的检测和识别模块。现在你有了预测的单词和它们的边界框在一个单独的列中,你可以以任何你想要的方式对齐文本!
下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
交流群
欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
