【Pytorch基础】线性模型

线性模型

一般流程

  1. 准备数据集(训练集,开发集,测试集)
  2. 选择模型(泛化能力,防止过拟合)
  3. 训练模型
  4. 测试模型

例子

学生每周学习时间与期末得分的关系

x(hours) y(points)
1 2
2 4
3 6
4 ?

设计模型

观察数据分布可得应采用线性模型:

其中 $\hat y$ 为预测值,不妨简化一下模型为:

我们的目的就是得到一个尽可能好的 $w$ 值。使模型的预测值越接近真实值,因此我们需要一个衡量接近程度的指标 $loss$,可用绝对值或差的平方表示单g个样本预测的损失为(Training Loss):

这里使用差的平方,其中 $y$ 为真实值。

因此,对于多样本预测的平均损失函数为(Mean Square Error):

1
2
3
4
5
6
7
8
# 定义模型函数
def forward(x):
return x * w;

# 定义损失函数
def loss(x, y):
y_predict = forward(x)
return (y - y_predict) ** 2

过程模拟

由于不知道 $w$ 的具体值因此我们给它一个随机初始值,假设 $w = 3$

x(hours) y(points) y_predict loss
1 2 3 1
2 4 6 4
3 6 9 9
MSE=14/3

可知本轮预测平均损失为 14/3

为找到最佳权重,可枚举权重值判断损失,损失最小为最佳

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 存放枚举到的权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
print("w=", w)
loss_sum = 0 # 损失和
for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
y_predict_val = forward(x_val) # 计算预测值
loss_val = loss(x_val, y_val) # 计算单样本损失
loss_sum += loss_val # 更新损失和
print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
print('MSE=',loss_sum / len(x_data))
w_list.append(w)
mse_list.append(loss_sum / len(x_data))

具体实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np
import matplotlib.pyplot as plt

# 准备数据集
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 定义模型函数
def forward(x):
return x * w;

# 定义损失函数
def loss(x, y):
y_predict = forward(x)
return (y - y_predict) ** 2

# 权重 w 的取值
w_list = []
# 对应权重的平均误差
mse_list = []

# 枚举权重,步长为 0.1
for w in np.arange(0.0, 4.1, 0.1): # 从 0.0 到 4.1
print("w=", w)
loss_sum = 0 # 损失和
for x_val, y_val in zip(x_data, y_data): # zip 函数传入可迭代对象
y_predict_val = forward(x_val) # 计算预测值
loss_val = loss(x_val, y_val) # 计算单样本损失
loss_sum += loss_val # 更新损失和
print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f'))
print('MSE=',loss_sum / len(x_data))
w_list.append(w)
mse_list.append(loss_sum / len(x_data))

得到每轮的预测结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
w= 0.0
1.0 2.0 0.00 4.00
2.0 4.0 0.00 16.00
3.0 6.0 0.00 36.00
MSE= 18.666666666666668
w= 0.1
1.0 2.0 0.10 3.61
2.0 4.0 0.20 14.44
3.0 6.0 0.30 32.49
MSE= 16.846666666666668
w= 0.2
1.0 2.0 0.20 3.24
2.0 4.0 0.40 12.96
3.0 6.0 0.60 29.16
MSE= 15.120000000000003
w= 0.30000000000000004
1.0 2.0 0.30 2.89
2.0 4.0 0.60 11.56
3.0 6.0 0.90 26.01
MSE= 13.486666666666665
w= 0.4
1.0 2.0 0.40 2.56
2.0 4.0 0.80 10.24
3.0 6.0 1.20 23.04
MSE= 11.946666666666667
w= 0.5
1.0 2.0 0.50 2.25
2.0 4.0 1.00 9.00
3.0 6.0 1.50 20.25
MSE= 10.5
w= 0.6000000000000001
1.0 2.0 0.60 1.96
2.0 4.0 1.20 7.84
3.0 6.0 1.80 17.64
MSE= 9.146666666666663
w= 0.7000000000000001
1.0 2.0 0.70 1.69
2.0 4.0 1.40 6.76
3.0 6.0 2.10 15.21
MSE= 7.886666666666666
w= 0.8
1.0 2.0 0.80 1.44
2.0 4.0 1.60 5.76
3.0 6.0 2.40 12.96
MSE= 6.719999999999999
w= 0.9
1.0 2.0 0.90 1.21
2.0 4.0 1.80 4.84
3.0 6.0 2.70 10.89
MSE= 5.646666666666666
w= 1.0
1.0 2.0 1.00 1.00
2.0 4.0 2.00 4.00
3.0 6.0 3.00 9.00
MSE= 4.666666666666667
w= 1.1
1.0 2.0 1.10 0.81
2.0 4.0 2.20 3.24
3.0 6.0 3.30 7.29
MSE= 3.779999999999999
w= 1.2000000000000002
1.0 2.0 1.20 0.64
2.0 4.0 2.40 2.56
3.0 6.0 3.60 5.76
MSE= 2.986666666666665
w= 1.3
1.0 2.0 1.30 0.49
2.0 4.0 2.60 1.96
3.0 6.0 3.90 4.41
MSE= 2.2866666666666657
w= 1.4000000000000001
1.0 2.0 1.40 0.36
2.0 4.0 2.80 1.44
3.0 6.0 4.20 3.24
MSE= 1.6799999999999995
w= 1.5
1.0 2.0 1.50 0.25
2.0 4.0 3.00 1.00
3.0 6.0 4.50 2.25
MSE= 1.1666666666666667
w= 1.6
1.0 2.0 1.60 0.16
2.0 4.0 3.20 0.64
3.0 6.0 4.80 1.44
MSE= 0.746666666666666
w= 1.7000000000000002
1.0 2.0 1.70 0.09
2.0 4.0 3.40 0.36
3.0 6.0 5.10 0.81
MSE= 0.4199999999999995
w= 1.8
1.0 2.0 1.80 0.04
2.0 4.0 3.60 0.16
3.0 6.0 5.40 0.36
MSE= 0.1866666666666665
w= 1.9000000000000001
1.0 2.0 1.90 0.01
2.0 4.0 3.80 0.04
3.0 6.0 5.70 0.09
MSE= 0.046666666666666586
w= 2.0
1.0 2.0 2.00 0.00
2.0 4.0 4.00 0.00
3.0 6.0 6.00 0.00
MSE= 0.0
w= 2.1
1.0 2.0 2.10 0.01
2.0 4.0 4.20 0.04
3.0 6.0 6.30 0.09
MSE= 0.046666666666666835
w= 2.2
1.0 2.0 2.20 0.04
2.0 4.0 4.40 0.16
3.0 6.0 6.60 0.36
MSE= 0.18666666666666698
w= 2.3000000000000003
1.0 2.0 2.30 0.09
2.0 4.0 4.60 0.36
3.0 6.0 6.90 0.81
MSE= 0.42000000000000054
w= 2.4000000000000004
1.0 2.0 2.40 0.16
2.0 4.0 4.80 0.64
3.0 6.0 7.20 1.44
MSE= 0.7466666666666679
w= 2.5
1.0 2.0 2.50 0.25
2.0 4.0 5.00 1.00
3.0 6.0 7.50 2.25
MSE= 1.1666666666666667
w= 2.6
1.0 2.0 2.60 0.36
2.0 4.0 5.20 1.44
3.0 6.0 7.80 3.24
MSE= 1.6800000000000008
w= 2.7
1.0 2.0 2.70 0.49
2.0 4.0 5.40 1.96
3.0 6.0 8.10 4.41
MSE= 2.2866666666666693
w= 2.8000000000000003
1.0 2.0 2.80 0.64
2.0 4.0 5.60 2.56
3.0 6.0 8.40 5.76
MSE= 2.986666666666668
w= 2.9000000000000004
1.0 2.0 2.90 0.81
2.0 4.0 5.80 3.24
3.0 6.0 8.70 7.29
MSE= 3.780000000000003
w= 3.0
1.0 2.0 3.00 1.00
2.0 4.0 6.00 4.00
3.0 6.0 9.00 9.00
MSE= 4.666666666666667
w= 3.1
1.0 2.0 3.10 1.21
2.0 4.0 6.20 4.84
3.0 6.0 9.30 10.89
MSE= 5.646666666666668
w= 3.2
1.0 2.0 3.20 1.44
2.0 4.0 6.40 5.76
3.0 6.0 9.60 12.96
MSE= 6.720000000000003
w= 3.3000000000000003
1.0 2.0 3.30 1.69
2.0 4.0 6.60 6.76
3.0 6.0 9.90 15.21
MSE= 7.886666666666668
w= 3.4000000000000004
1.0 2.0 3.40 1.96
2.0 4.0 6.80 7.84
3.0 6.0 10.20 17.64
MSE= 9.14666666666667
w= 3.5
1.0 2.0 3.50 2.25
2.0 4.0 7.00 9.00
3.0 6.0 10.50 20.25
MSE= 10.5
w= 3.6
1.0 2.0 3.60 2.56
2.0 4.0 7.20 10.24
3.0 6.0 10.80 23.04
MSE= 11.94666666666667
w= 3.7
1.0 2.0 3.70 2.89
2.0 4.0 7.40 11.56
3.0 6.0 11.10 26.01
MSE= 13.486666666666673
w= 3.8000000000000003
1.0 2.0 3.80 3.24
2.0 4.0 7.60 12.96
3.0 6.0 11.40 29.16
MSE= 15.120000000000005
w= 3.9000000000000004
1.0 2.0 3.90 3.61
2.0 4.0 7.80 14.44
3.0 6.0 11.70 32.49
MSE= 16.84666666666667
w= 4.0
1.0 2.0 4.00 4.00
2.0 4.0 8.00 16.00
3.0 6.0 12.00 36.00
MSE= 18.666666666666668

画出权重与平均损失的关系图

1
2
3
4
5
# 绘图(权重与平均损失的关系)
plt.plot(w_list, mse_list)
plt.ylabel('Loss')
plt.xlabel('W')
plt.show()

由上图可知,但 $w = 2.0$ 时损失最小,该点也是损失函数图像的最小值。