【Pytorch基础】处理多维特征的输入

回顾

 到目前为止,我们讨论的都是只有一个实数输入的模型。但实际情况要复杂的多,因此,如何处理多维输入是个非常重要的问题。

关于糖尿病的二分类问题

1. 准备数据集

上述样本的输入为 8 个指标,输出为两个类别(病情未来会加重 1、病情未来不会加重 0)。

1
2
3
4
5
6
7
8
9
import numpy as np
import torch

xy = np.loadtxt('https://project-preview-1257022783.cos.ap-chengdu.myqcloud.com/diabetes.csv.gz',delimiter=',',dtype=np.float32)

# 创建tensor
x_data = torch.from_numpy(xy[:,:-1]) # 所有行,最后一列不要
y_data = torch.from_numpy(xy[:,[-1]]) # 所有行,只要最后一列,-1加[]表示拿出来一个矩阵,而不是向量

多维度输入的逻辑回归模型

 上述数据集的输入不再是一个简单的实数,而是一个8维向量$x^{(i)}$,对于单个样本其模型为:

上述操作得到一个标量 $\hat{y}^{(i)}$, $\sigma$ 为 Logistic 函数。

Mini-Batch(N samples) 情况下

  Mini_Batch 的大小为$N$,即每次更新都根据$N$个样本来计算损失,其模型为:

其中:

即:

构造多层神经网络

先导知识:
对于矩阵运算 $y = A \times x$

其中

那么,上式表示为将$n$维空间映射到$m$维空间的一个线性变换。
因此,可以将矩阵看成一种空间变换的函数。
所以,self.linear = torch.nn.Linear(8,6) 就可以看做将一个 8 维空间经过线性变换映射到一个 6 维空间上。

但是,若我们在每一次线性变换后加入了非线性函数 $\sigma$ ,就可以实现非线性变换,使得模型可以拟合非线性问题。

 多层神经网络,就是通过拼接多次变换得到的:


注意:理论上,隐层数量越多模型的学习能力就越强。但是,太强的学习能力会导致模型连数据中的噪声都学习到了(过拟合)反而适得其反。一个号的模型应该要具有一定的泛化能力,不能去死扣细节而去抓住问题的主要矛盾。因此,层数的多少应该根据实际情况适当尝试调整,而不是一味地求多。

定义多层神经网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class Model(torch.nn.Module):
def __init__(self):
super(Model,self).__init__()
self.linear1 = torch.nn.Linear(8,6) # 输入维度为 8
self.linear2 = torch.nn.Linear(6,4)
self.linear3 = torch.nn.Linear(4,1)
self.activate = torch.nn.Sigmoid() # 激活函数

def forward(self, x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.activate(self.linear3(x))
return x

model = Model()

损失函数和优化器

1
2
criterion = torch.nn.BCELoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

训练模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
h_list = []
l_list = []
for epoch in range(10000):
# 前馈计算
y_pred = model(x_data)
loss = criterion(y_pred,y_data)
print(epoch, loss.item())
h_list.append(epoch)
l_list.append(loss.item())

# 反向传播
optimizer.zero_grad()
loss.backward()

# 更新参数
optimizer.step()

绘制收敛图

1
2
3
4
5
6
import matplotlib.pyplot as plt

plt.plot(h_list, l_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()