0%

解析登录网址

首先我们要定位到登录的网址以便driver访问。

将浏览器打开到登录页面,可以看到对应的网址为

1
https://jaccount.sjtu.edu.cn/jaccount/jalogin?sid=jaoauth220160718&client=CAzoSoajfvlBJKYW7q11rIlS5ucaE%2FIWcLi2eNuNEJDa&returl=CPDcg4bZMUVfWbZWFi2BIAB%2BzZprZ2JnKxq%2Bd6MsjeS0nMw73qkvsCxG47FeRZmpnkzSb7Gf%2FCPIbwa%2BNH267zBN%2BexUD2RL%2BhgfQzpYNJ96UJUYoajLq%2Fqgx6g%2BL2Ol6i3p5RWYYIcKoV28CmbnNbPia3K1RfdgUgktPh4yNrgYsIciWlBvtMDF%2FLrJfK4rGV9Z%2BbdiXK6EFaQJId4mZXj45XlS2pmQtR5P6qOeSj3nvLDaxZ6bablX7cNO5IhPAmaj%2BMOH%2Fx5F7YpIbYoZg3R01iV6Rd7oO8pSmHFTFkvhJFWPTWI1C01nBDaKato1IDv%2BP0%2FF8C6e%2BTDf6MOGmge0SAXgweU0o4I2v06Xb5nDWAGGTAtyB2Vm0qxD7RB6C74n77jiGIGju4y13a%2Fi10RO4pmfPcik9st2unhTWfVMqDtYd%2BJ9jCjl8UOoel8G99qS7ET%2FVN%2FbP%2BfUkBJb7I8%3D&se=CBZz3nk3K70uQiJvss5rVxhD4A5GYffm65sDqZ6Zh2PbEFVqISnzjU2JVjCO1VYIYuNZMv%2F8W6s3

很明显这是拼接了随机数的网址而不是真正应该访问的,为了得到真正的登录页面的网址,需要动态解析。

打开一开始的登录选项,要点击校内用户登录才会加载出真正的登录页面,因此可以猜测这是一个动态加载的页面。所以我们打开页面控制台,点击Network选项,一开始是空白的,等待新加载的信息。

image-20220512114437557

然后点击登录按钮,会跳转到登录页面。左边的控制台已经加载了许多信息,要在这些信息中找到加载网址的信息,一般来说我们查看第一条条目即可。

image-20220512114740458

选中第一条,可以看到请求的url,这就是真正的登录页面的网址,这里可以复制这个地址打开看看是不是正确的,如果不是就继续找。

image-20220512114923472

模拟登录

得到网址后,可以进行登录。登录一般来说有两种方式,一种是用requests库的post请求,一种是用selenium的driver直接进行网页模拟登录。这里选择用driver的方式主要基于以下两个方面:

  • 目的上看,我们想要自动登录后还能够操纵我们的网页干其他事情,用driver的话刚刚好
  • 从可行性的角度上看,登录过程我们需要验证码识别,如果用requests库抓取,前面提到抓取下来的验证码是不一样的,每次抓取都会改变内容。因此我能想到的解决方法就是用driver的页面进行静态的截图,实际上这符合我们人眼识别的过程。

因此,我们对于验证码的识别就是根据截图来进行的。同样,这里使用我们之前训练的模型,因此要把模型先放到model.py里面。下面给出处理验证码的代码文件。首先定义验证码信息参数,如长度、大小等。然后我们要把模型初始化,加载参数的过程会比较慢(推理很快),我们不想浪费这段时间,由于用driver打开网页也需要一段时间,因此就可以开一个线程来加载参数,与此同时打开网页,可以节省一些时间。

然后定义了截图操作,先截全屏再截下验证码。这里验证码的位置信息通过网页源代码(html)的信息获取,用xpath、class、id定位都行,在webdriver都定义了相关的操作。需要注意的是scaling_ratio这一变量,要截下正确的验证码位置与我们屏幕的缩放比例有关。可以在桌面上右键点击显示设置,然后查看屏幕的缩放与布局,根据缩放的比例设置变量。

image-20220512121304774

最后还有一个验证码解码的方法,实际上这跟前面predict.py的操作类似。所有代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# captcha.py
from model import CNN_sjtu
import torch.nn
import torch
from PIL import Image
from torchvision.transforms import Compose, ToTensor, Resize,Normalize
import time
from selenium.webdriver.common.by import By
from threading import Thread
import warnings

#pytorch版本关系,会有warning,实际上可以忽略
warnings.filterwarnings("ignore")

#定义验证码信息参数
numchar = 4
alphabet = 'abcdefghijklmnopqrstuvwxyz'
width = 100
height = 40

#模型初始化
model_net = CNN_sjtu(num_class=len(alphabet), num_char=int(numchar), width=width, height=height)
model_net = model_net.cuda()
model_net.eval()

def load_net():#load参数 比较慢,用一个额外的线程先load
global model_net
print("load net......")
model_net.load_state_dict(torch.load(r'C:\Users\14242\PycharmProjects'
r'\DL\Pytorch_project\captcha-CNN-验证码识别\weights/model_sjtu_4.path'))#参数位置

thread_load= Thread(target=load_net)
thread_load.start()

def decode_captcha(img_path):
global model_net
with torch.no_grad():
img=Image.open(img_path)
img = img.convert('RGB')
transforms = Compose([Resize((height, width)), ToTensor(), Normalize(0, 1)])
img = transforms(img)
img = img.view(1, 3, height, width).cuda()
output = model_net(img)
output = output.view(-1, len(alphabet))
output = torch.nn.functional.softmax(output, dim=1)
output = torch.argmax(output, dim=1)
output = output.view(-1, numchar)[0]
return ''.join([alphabet[i] for i in output.cpu().detach().numpy()])

def get_snap(driver): # 对目标网页进行截屏。这里截的是全屏
driver.save_screenshot('full_snap.png')
page_snap_obj=Image.open('full_snap.png')
return page_snap_obj

def get_image(driver): # 对验证码所在位置进行定位,然后截取验证码图片
scaling_ratio=1#系统显示的缩放比例
img = driver.find_element(By.XPATH, '//*[@id="captcha-img"]')#获取验证码在网页中的位置信息
time.sleep(0.10)#等待一会
location = img.location#位置
size = img.size#大小
left = location['x']*scaling_ratio
top = location['y']*scaling_ratio
right = left + size['width']*scaling_ratio
bottom = top + size['height']*scaling_ratio
page_snap_obj = get_snap(driver)#获取整个页面截图
image_obj = page_snap_obj.crop((left, top, right, bottom))#截下验证码
image_obj.save('captcha.png')

下面是我们登录的主程序,注意这种方式要先下载对应于谷歌浏览器版本号的chromedriver.exe放到当前的目录下。

通过send_keys()方法填入用户名、密码、验证码;接着通过click()方法自动点击登录按钮即可,注意操作网页的过程一定要适当地等待加载(sleep)。

由于验证码可能识别错误,因此要循环判断是否成功登录,直到成功再停止。我的方法是判断当前页面的title是否发生了改变,这是一个比较简单方便的方式。

最后,由于一开始我们生成了两个截图,可以用os.remove()将其删除。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# login.py
from captcha import decode_captcha as decode_c
from captcha import get_image
from selenium import webdriver
import time
import os
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.by import By

#调用谷歌driver的基本操作
options = Options()
driver = webdriver.Chrome(options=options)#调用当前路径下的chromedriver.exe
driver.maximize_window()#最大化

#截图的保存地址
captcha_path='./captcha.png'
snap_path='./full_snap.png'

#两个地址链接
url = 'https://i.sjtu.edu.cn/jaccountlogin'
url ='https://oc.sjtu.edu.cn/login/openid_connect'

driver.get(url) #打开网页

cur_title=driver.title
while(driver.title==cur_title):#失败则一直试,因为验证码可能错误,两次应该就能成功
time.sleep(0.10) # 加载等待
get_image(driver)
print("推理验证码......")
captcha_res=decode_c(captcha_path)
driver.find_element(By.NAME, 'user').send_keys('username') # 填入用户名
driver.find_element(By.NAME, 'pass').send_keys('password') # 填入密码
driver.find_element(By.NAME,'captcha').send_keys(captcha_res) # 填入验证码
driver.find_element(By.ID,"submit-button").click()

os.remove(captcha_path)
os.remove(snap_path)
print("ending......")

简介

前段时间打算用谷歌driver自动登录学校的网站,方便后面继续开发脚本,主要还是想着以后也许会用到爬虫爬取图像等数据用来训练。

这个项目从最开始使用 pytorch 搭模型、爬验证码数据、以及训练改进,到后来觉得登录学校网站不用手打验证码挺方便的,就不断优化程序最后完成了一个快速登录的可执行文件,下面是一个demo演示。

这里打算开个坑写写训练的过程和用谷歌driver登录网站的设计思路,以及最后的一些优化日志。

验证码数据集获取

解析验证码网址

我们打开网址的登录页面,检查页面源代码,定位到验证码的位置,可以看到一个验证码的网页。

验证码网址解析1

打开这个网页,显示的就是验证码的图片,但这个验证码不是原来的验证码,说明验证码是动态加载的,尽管网址一样但是内容是不相同的。可以使用爬虫工具下载这个图片,为了观察这个网址以便批量下载,我们刷新登录页面,并再次打开一个新的验证码网址。可以看到唯一的变化就是网址后缀上的 uuid 变化了。uuid 是一种标识码,后端算法会根据这个 uuid 生成一个验证码。

所以可以先随机生成一个 uuid 标识,然后拼接成完整的验证码网址,通过爬虫下载图片。

验证码解析2

验证码解析3

批量下载验证码

使用python(环境:python3.7)提供一系列的操作:

  • 生成uuid并拼接网址

    1
    2
    3
    import uuid #该库用于生成uuid,有多种方式
    uuidx=str(uuid.uuid4())
    url = 'https://jaccount.sjtu.edu.cn/jaccount/captcha?uuid='+uuidx
  • 通过爬虫下载验证码图片

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    import requests
    origin_path='captcha-sjtu/origin.jpg'
    # 构造请求头
    headers={'User-Agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
    'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.90 Safari/537.36 Edg/89.0.774.57'}
    # 发送请求
    res = requests.get(url=url,headers=headers)
    # 把获取的二进制写成图片
    with open(origin_path, 'wb') as f:
    f.write(res.content)
  • 获得了一张图片后,我们需要将它命名为它的验证码识别结果,作为训练数据和测试数据的标签。但我们不可能人眼去识别和手动修改,这样工作量太大了,因此可以选择外接库来帮我们识别(本意不是识别出验证码,而是学习自己搭网络来训练,因此尽管有外接库,还是希望自己能完成一个网络)。这里选择ddddocr这个库,用法很简单

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    import ddddocr
    import os
    ocr = ddddocr.DdddOcr(use_gpu=True) #实例化一个识别器,使用gpu
    with open(origin_path, 'rb') as f:
    img_bytes = f.read() #读入二进制数据
    try: #可能上面的uuid拼接的url弄的不是一个验证码
    res = ocr.classification(img_bytes) #识别的结果
    except:
    continue #如果url不是验证码就跳过,无所谓。不用try语句的话会中断
    newname='captcha-sjtu/train/'+res+'.jpg'
    try:
    os.rename(origin_path, newname) #可能已经有一个同名的了,为了不让程序中断,还是使用try
    except:
    continue
  • 至此就处理完了一张图片,下面为批量处理的完整代码

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    import requests
    import time
    import uuid
    import os
    import ddddocr
    ocr = ddddocr.DdddOcr(use_gpu=True)
    origin_path='captcha-sjtu/origin.jpg'
    for i in range(10000):
    # 每爬取500个,歇1秒,确保服务器不会受影响
    if i%500 == 0:
    print("-----",i/1000,'组-----------')
    time.sleep(1)
    # 生成随机数
    uuidx=str(uuid.uuid4())
    url = 'https://jaccount.sjtu.edu.cn/jaccount/captcha?uuid='+uuidx
    # 构造请求头
    headers={'User-Agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
    'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/89.0.4389.90 Safari/537.36 Edg/89.0.774.57'}
    # 发送请求
    res = requests.get(url=url,headers=headers)
    # 把获取的二进制写成图片
    with open(origin_path, 'wb') as f:
    f.write(res.content)
    # 再读取回来
    with open(origin_path, 'rb') as f:
    img_bytes = f.read()
    try:
    res = ocr.classification(img_bytes)#可能上面的uuid拼接的url弄的不是一个验证码
    except:
    continue
    newname='captcha-sjtu/train/'+res+'.jpg'
    try:
    os.rename(origin_path, newname)#可能已经有一个同名的了
    except:
    continue

pytorch 模型搭建

构建 Dataset

将数据集分成9:1,并预留一些数据用于验证。首先定义一些参数设置,写在setting.py内。通过观察验证码,可以得到图片的宽度和高度,以及内容和长度。这里验证码的内容都是小写字母,因此我们只需要小写字母的字母表;而验证码长度是4或者5,对于CNN网络来说,需要接受一个固定大小的输入,然后输出一个固定大小的标签等等。

因此我们没办法同时识别长度为4和长度为5的验证码,但根据观察发现这两种长度的验证码出现的频率几乎是1:1的,那么在登录验证时只需要不断 try 尝试即可,而且并不会造成很多的时间浪费。以下为setting的内容。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
#  setting.py 
width_sjtu=100 #图片尺寸
height_sjtu=40
alphabet_sjtu='abcdefghijklmnopqrstuvwxyz'#全是小写字母
#验证码长度
numchar=4
#train:
# 遍历数据集训练的次数
max_epoch=100
# 批处理数量
batch_size=128
# 学习率
base_lr=0.0003

# 训练数据存放路径
train_data_path_sjtu='./captcha-sjtu/train'
# 测试数据存放路径
test_data_path_sjtu= './captcha-sjtu/test'
# 预测数据
samples_path_sjtu = './captcha-sjtu/predict'
# 是否使用gpu
use_gpu= True
# gpu并行处理进程数
num_workers= 0
# 训练后的模型输出的路径
model_path='./weights'

随后,我们要重写Dataset,大部分工作在重写_getitem_()方法,返回处理后的图像和标签,于是我们需要先考虑这个标签的形式。我们似乎可以使用一个字母表大小(里面也可以包括数字等等)的一维向量,比如全是小写字母那么我的向量长度就为26,然后将验证码图片出现的字母映射在向量对应的索引处:如果字母出现则为1,不出现则为0。但这种方法,一方面没办法表示验证码的顺序,一方面没办法识别有重复字母的验证码。

因此考虑对每个字母都建立一个长度为26的向量进行映射,因此向量的总长度就是验证码长度×字母表长度

对于一个图片,前面使用了它的名称作为验证码结果,因为我们下载的时候并没有区别长度,因此这里长度不一致的数据要剔除,执行一次continue即可。然后,我们把验证码字符串的每一个字符都映射到一个向量上,在python中可以使用str.find(),这里的str即为我们的字母表。然后我们把这些向量都拼接起来,就构成了label。

make_dataset()函数会返回图片路径、图片label,最后我们Dataset中使用它重写_getitem_()方法。对应的dataset.py文件如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# datasets.py
import os
from PIL import Image
import torch
from torch.utils.data import Dataset

def img_loader(img_path):
img = Image.open(img_path)
# 将图像转换为 RGB
return img.convert('RGB')

# 处理数据集所在文件夹下的数据
def make_dataset(data_path, alphabet, num_class, num_char):
# 获取数据集所在文件夹的所有文件名
img_names = os.listdir(data_path)
samples = []
for img_name in img_names:
# 拼接每个图像数据集的路径
img_path = os.path.join(data_path, img_name)
# 找出该图像的label
target_str = img_name.split('.')[0]
# 判断lable和结果的长度是否一致
if len(target_str) != num_char:
continue

target = []
# 创建每个数据的target数组 4 * alphabet,这里使用one hot
for char in target_str:
#------如果只看小写要映射成小写-------------------
# if ord(char)>=65 or ord(char)<=90:
# char=chr(ord(char)+32)
#---------------------------------

vec = [0] * num_class
vec[alphabet.find(char)] = 1
target += vec#要四个数组,如果写在一个数组内,重复的表示不出来

# 加入数据集
samples.append((img_path, target))
# 返回数据集
return samples


class CaptchaData(Dataset):
def __init__(self, data_path, num_class=62, num_char=4, transform=None, target_transform=None,
alphabet="0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"):
super(Dataset, self).__init__();
self.data_path = data_path
self.num_class = num_class
self.num_char = num_char
self.transform = transform
self.target_transform = target_transform
self.alphabet = alphabet
self.samples = make_dataset(self.data_path, self.alphabet,
self.num_class, self.num_char)

def __len__(self):
return len(self.samples)

def __getitem__(self, index):
img_path, target = self.samples[index]
img = img_loader(img_path)
# 如果有传入预处理函数,就预处理数据集
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)

return img, torch.Tensor(target)

构建CNN model

接着,我们搭建一个CNN网络,这个网络不能太大,因为我们的标签向量本身很大,如果只是用自己机器训练的话,gpu内存可能不够用(全连接层的参数尤其多)。要查看gpu的占用,可以在任务管理器–性能–GPU1处实时监控专用GPU内存利用率(GPU0一般是处理器的而非显卡),如下图

GPU内存占用监控示意图

我最后选择了一个可以训练的模型,使用四层卷积和两层全连接层,每次卷积后使用一个2×2的最大池化,接着批归一化,最后使用ReLU激活函数。因为学校网址的验证码本身不是很复杂,训练后可以针对长度4和长度5的验证码都可以达到96%的准确率。

实际上,并不需要多高的准确率,因为try一次的时间甚至不到1秒钟(当然调整网络结构优化识别性能是一件很有趣的事情)。model.py的代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#  model.py
import torch.nn as nn
class CNN_sjtu(nn.Module):#4长度和5长度准确率都在96%左右,验证都是正确的
def __init__(self, num_class=36, num_char=4, width=100, height=40):
super(CNN_sjtu, self).__init__()
self.num_class = num_class
self.num_char = num_char
# 卷积层后,全连接层的一维数组输入长度
# 512是卷积处理后图片的通道数,长度和宽度各除以16是因为图像经过了四次2*2池化层(MaxPool2d)
self.line_size = int(512 * (width // 2 // 2 // 2 // 2) * (height // 2 // 2 // 2 // 2))
self.conv1 = nn.Sequential(
# 输入的是RGB图像,所以是3通道。
# 这里设置该层有16个卷积核,所以输出是16通道
# padding(1,1)表示在图像上下左右各加1行、1列,保证在卷积后图像大小不变
nn.Conv2d(3, 16, 3, padding=(1, 1)),
# 池化层,保留图像每2*2片段像素的最大值
nn.MaxPool2d(2, 2),
# 对每个通道的图像都归一化,防止梯度爆炸
nn.BatchNorm2d(16),
# 激活函数
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 64, 3, padding=(1, 1)),
nn.MaxPool2d(2, 2),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.conv3 = nn.Sequential(
nn.Conv2d(64, 512, 3, padding=(1, 1)),
nn.MaxPool2d(2, 2),
nn.BatchNorm2d(512),
nn.ReLU()
)
self.conv4 = nn.Sequential(
nn.Conv2d(512, 512, 3, padding=(1, 1)),
nn.MaxPool2d(2, 2),
nn.BatchNorm2d(512),
nn.ReLU()
)
# 全连接层
self.fc = nn.Sequential(
nn.Linear(self.line_size, self.line_size),
# nn.Identity(),
# 输出应为 验证码长度*字符的分类数
nn.Linear(self.line_size, self.num_char * self.num_class)
)

def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
# resize输入数组的尺寸,相当于flatten
x = x.view(-1, self.line_size)
x = self.fc(x)

return x

训练模型

训练代码如下,每一轮迭代都会保存模型参数到给定文件夹。最终观察表现最好的参数模型,手动删除其他不好的,并可以修改一下命名,以防之后重新训练覆盖了这个参数模型。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#train.py
import torch
from model import *
from datasets import CaptchaData
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Resize,Normalize
import time
import os
from setting import *

#导入设定参数
alphabet=alphabet_sjtu
width=width_sjtu
height=height_sjtu

# 训练数据存放路径
train_data_path=train_data_path_sjtu
# 测试数据存放路径
test_data_path=test_data_path_sjtu
# 预测数据
samples_path=samples_path_sjtu

if not os.path.exists(model_path):
os.makedirs(model_path)

device =torch.device("cuda")#设置gpu

# 计算准确度
def calculat_acc(output, target):
output, target = output.view(-1, len(alphabet)), target.view(-1, len(alphabet)) #字母有26个就是26列
output = nn.functional.softmax(output, dim=1) #缩放到0-1区间,所有元素和为1
output = torch.argmax(output, dim=1) #返回每一列得分最高的索引值,说明预测的是这个位置的字母
target = torch.argmax(target, dim=1) #然后dim=1这个维度会消失
output, target = output.view(-1, int(numchar)), target.view(-1, int(numchar))
correct_list = []
for i, j in zip(target, output):
if torch.equal(i, j): #如果两个列表相等(相同大小和元素)
correct_list.append(1)
else:
correct_list.append(0)
acc = sum(correct_list) / len(correct_list)
return acc


def train():
# 数据shape的预处理,缩放、转tensor,以及图像处理基本都要使用归一化
transforms = Compose([Resize((height, width)), ToTensor(),Normalize(0, 1)])
# 创建训练数据集对象
train_dataset = CaptchaData(train_data_path, num_class=len(alphabet), num_char=int(numchar), transform=transforms, alphabet=alphabet)
# 初始化DataLoader,之后训练的数据由它按照我们的要求如batch_size等提供
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
shuffle=True, drop_last=False)
# 创建测试数据集对象
test_dataset = CaptchaData(test_data_path, num_class=len(alphabet), num_char=int(numchar), transform=transforms, alphabet=alphabet)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size,
num_workers=num_workers,drop_last=False)

print("训练数据量:",train_dataset.__len__(),'\t测试数据量:',test_dataset.__len__())

# 初始化模型
cnn = CNN_sjtu(num_class=len(alphabet), num_char=int(numchar), width=width, height=height)
if use_gpu:
cnn=cnn.to(device)

#----------------损失函数及优化方法---------------------
# 使用Adam优化方法
optimizer = torch.optim.Adam(cnn.parameters(), lr=base_lr)
# 使用多标签分类的损失函数
criterion = nn.MultiLabelSoftMarginLoss()

#----------------开始迭代训练-------------------------
# 训练我们指定的epoch次
print("开始训练...")
for epoch in range(max_epoch):
start_ = time.time()
loss_history = []
acc_history = []
# 切换到训练模式
cnn.train()
for img, target in train_data_loader:
if use_gpu:
img = img.to(device)
target = target.to(device)
# 获取神经网络的输出
output = cnn(img)
# 计算损失函数
loss = criterion(output, target)
# 初始化梯度
optimizer.zero_grad()
# 反向传播计算梯度
loss.backward()
# 优化参数
optimizer.step()
# 计算准确度
acc = calculat_acc(output, target)
acc_history.append(float(acc))
loss_history.append(float(loss))
print('epoch:{},train_loss: {:.4}|train_acc: {:.4}'.format(
epoch,
torch.mean(torch.Tensor(loss_history)),
torch.mean(torch.Tensor(acc_history)),
))

with torch.no_grad():
loss_history = []
acc_history = []
# 切换到测试模式
cnn.eval()
for img, target in test_data_loader:
if torch.cuda.is_available():
img = img.to(device)
target = target.to(device)
output = cnn(img)

acc = calculat_acc(output, target)
acc_history.append(float(acc))
print('test_loss: {:.4}|test_acc: {:.4}'.format(
torch.mean(torch.Tensor(loss_history)),
torch.mean(torch.Tensor(acc_history)),
))
print('epoch: {}|time: {:.4f}'.format(epoch, time.time() - start_))
torch.save(cnn.state_dict(), os.path.join(model_path, "model_{}.path".format(epoch)))#每个epoch保存一次参数


if __name__ == "__main__":
train()

模型效果测试

最后我们可以看看我们模型的效果,下面是一些预测的示例。

验证码识别示例1

验证码识别示例2

验证码识别示例3

预测部分的代码如下,至此,整个模型训练部分就完成了,登录网站的部分见下一篇博客。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#predict.py
import torch
from PIL import Image
from model import *
from torchvision.transforms import Compose, ToTensor, Resize,Normalize
import matplotlib.pyplot as plt
import os
import random
from setting import *

#----------参数设定-------------
model_net = CNN_sjtu()
alphabet=alphabet_sjtu
width=width_sjtu
height=height_sjtu
samples_path=samples_path_sjtu+'/4'
numchar=4

# 获取模型
def load_net():
global model_net
# 初始化模型
model_net = CNN_sjtu(num_class=len(alphabet), num_char=int(numchar), width=width, height=height)
# 读取参数模型
if use_gpu:
model_net = model_net.cuda()
model_net.eval()
model_net.load_state_dict(torch.load('./weights/model_sjtu_4.path')) #加载参数模型
else:
model_net.eval()
model_net.load_state_dict(torch.load(model_path, map_location='cpu'))

# 预测验证码
def predict_image(img):
global model_net
with torch.no_grad():
img = img.convert('RGB')
transforms = Compose([Resize((height, width)), ToTensor(),Normalize(0, 1)]) #图像变换
img = transforms(img)

if use_gpu:
img = img.view(1, 3, height, width).cuda()
else:
img = img.view(1, 3, height, width)
output = model_net(img) #推理

output = output.view(-1, len(alphabet))
output = nn.functional.softmax(output, dim=1)
output = torch.argmax(output, dim=1)
output = output.view(-1, numchar)[0]
return ''.join([alphabet[i] for i in output.cpu().detach().numpy()]) #转换成numpy类型需要先从gpu加载到cpu,然后可以获得字母表的索引


if __name__ == "__main__":
load_net()
# 枚举数据所在文件夹
img_names = os.listdir(samples_path)
random.shuffle(img_names)
samples = []
for img_name in img_names:
# 拼接每个数据的路径
img_path = os.path.join(samples_path, img_name)
img = Image.open(img_path)
v_code = predict_image(img)
plt.figure()
plt.title("{}".format(v_code))
plt.imshow(img)
plt.show()

hello world

​ 这里是 Jy 的博客,会follow一些项目的过程,作为学习和记录,以及一些个人的学期总结和年度总结~

GitHub

​ 这里是我的GitHub地址:https://github.com/Chen-Jin-yuan