"Is That an Airplane?": A Simple Classifier with PyTorch

CIFAR-10 is a small dataset for general object recognition compiled by Alex Krizhevsky and Ilya Sutskever, students of Geoffrey Hinton. It contains RGB color images from 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck. Each image is 32×32 pixels, and the dataset holds 50,000 training images and 10,000 test images.

First, prepare the dataset. torchvision.datasets can download it directly, controlled by the download parameter: when it is True and the dataset is not yet in the target directory, it is fetched automatically. transforms.Normalize() then handles per-channel normalization, and it can be chained with transforms.ToTensor() via transforms.Compose().

import torch
import torch.nn as nn
from torchvision import datasets, transforms

data_path = r"./datasets"
# download=True fetches CIFAR-10 automatically if it is not already in data_path
cifar10 = datasets.CIFAR10(data_path, train=True, download=True)
cifar10_val = datasets.CIFAR10(data_path, train=False, download=True)

# Re-create the datasets with ToTensor so each PIL image becomes a C x H x W tensor in [0, 1]
tensor_cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())
tensor_cifar10_val = datasets.CIFAR10(data_path, train=False, download=False, transform=transforms.ToTensor())

img_t, _ = tensor_cifar10[99]  # peek at one sample
# Stack all training images along a new last dimension: shape (3, 32, 32, 50000)
imgs = torch.stack([img_t for img_t, _ in tensor_cifar10], dim=3)

'''
Normalize params are from
imgs.view(3, -1).mean(dim=1) and imgs.view(3, -1).std(dim=1)
'''

transformed_cifar10 = datasets.CIFAR10(data_path,
                                       train=True,
                                       download=False,
                                       transform=transforms.Compose([transforms.ToTensor(),
                                                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                                         (0.2470, 0.2435, 0.2616))]))
transformed_cifar10_val = datasets.CIFAR10(data_path,
                                       train=False,
                                       download=False,
                                       transform=transforms.Compose([transforms.ToTensor(),
                                                                    transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                                                         (0.2470, 0.2435, 0.2616))]))
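As a quick sanity check (this snippet is my addition, not part of the original code), you can recompute the per-channel statistics and look at one transformed sample; the printed values should match the constants passed to Normalize:

# Recompute the per-channel mean/std that the Normalize constants came from
print(imgs.view(3, -1).mean(dim=1))  # roughly (0.4914, 0.4822, 0.4465)
print(imgs.view(3, -1).std(dim=1))   # roughly (0.2470, 0.2435, 0.2616)

# Each transformed sample is a 3 x 32 x 32 tensor, now roughly zero-mean per channel
img_t, label = transformed_cifar10[0]
print(img_t.shape)                   # torch.Size([3, 32, 32])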

Since we only need to distinguish birds from airplanes, we keep just those two classes and then load the resulting subsets with DataLoader.

label_map = {0: 0, 2: 1}            # CIFAR-10 labels: 0 = airplane, 2 = bird
class_names = ['airplane', 'bird']  # a list, so the remapped label can index into it
cifar2 = [(img, label_map[label])
          for img, label in transformed_cifar10
          if label in [0, 2]]

cifar2_val = [(img, label_map[label])
              for img, label in transformed_cifar10_val
              if label in [0, 2]]

train_loader = torch.utils.data.DataLoader(cifar2, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(cifar2_val, batch_size=64, shuffle=False)
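To confirm the loaders deliver what the model will expect, it can help to pull one batch and check its shape (a small check of my own, assuming the batch size of 64 set above):

imgs, labels = next(iter(train_loader))
print(imgs.shape)   # torch.Size([64, 3, 32, 32]) for a full batch
print(labels[:10])  # a mix of 0 (airplane) and 1 (bird)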

Define our model:

model = nn.Sequential(nn.Linear(3072, 512),  # 3072 = 3 * 32 * 32 flattened input features
                      nn.Tanh(),
                      nn.Linear(512, 2),
                      nn.LogSoftmax(dim=1))
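Before training, a quick forward pass of one untrained batch shows the shapes involved and confirms that LogSoftmax outputs log-probabilities (again an illustrative check, not from the original post):

imgs, labels = next(iter(train_loader))
out = model(imgs.view(imgs.size(0), -1))  # flatten each 3x32x32 image into a 3072-dim vector
print(out.shape)                          # torch.Size([64, 2])
print(out.exp().sum(dim=1)[:5])           # each row of exp(out) sums to roughly 1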

It is worth understanding the Softmax function behind this output layer.

Definition of the Softmax function:
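In standard form, for a vector x with components x_i (together with the LogSoftmax variant the model actually uses):

\mathrm{Softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}, \qquad \mathrm{LogSoftmax}(x)_i = x_i - \log \sum_j e^{x_j}

The model ends with LogSoftmax rather than Softmax because the loss used below, nn.NLLLoss, expects log-probabilities; the combination of the two is equivalent to applying nn.CrossEntropyLoss to the raw scores.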

The simplest, most brute-force way to improve this model is to add more layers; a sketch of such a deeper variant is given below. With that noted, we move on to training.
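As an illustration of that idea (the layer widths below are arbitrary choices of mine, not taken from the original post), a deeper fully connected variant could look like this; the training loop that follows still uses the original two-layer model:

deeper_model = nn.Sequential(nn.Linear(3072, 1024),
                             nn.Tanh(),
                             nn.Linear(1024, 512),
                             nn.Tanh(),
                             nn.Linear(512, 128),
                             nn.Tanh(),
                             nn.Linear(128, 2),
                             nn.LogSoftmax(dim=1))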

loss_fn = nn.NLLLoss()
learning_rate = 1e-2
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
n_epochs = 100

# Train the model
for epoch in range(n_epochs):
    for imgs, labels in train_loader:
        batch_size = imgs.size(0)
        outputs = model(imgs.view(batch_size, -1))
        loss = loss_fn(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch: %d, Loss: %f" % (epoch, float(loss)))

We can now evaluate the trained model on the validation set.

correct = 0
total = 0

with torch.no_grad():
    for imgs, labels in val_loader:
        batch_size = imgs.size(0)
        outputs = model(imgs.view(batch_size, -1))
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += int((predicted == labels).sum())

print("Accuracy: %f" % (correct / total))

The accuracy comes out around 0.82, which is fairly good for such a simple fully connected model.