线性模型
一般流程
- 准备数据集(训练集,开发集,测试集)
- 选择模型(泛化能力,防止过拟合)
- 训练模型
- 测试模型
例子
学生每周学习时间与期末得分的关系
x(hours) |
y(points) |
1 |
2 |
2 |
4 |
3 |
6 |
4 |
? |
设计模型
观察数据分布可得应采用线性模型:
其中 $\hat y$ 为预测值,不妨简化一下模型为:
我们的目的就是得到一个尽可能好的 $w$ 值。使模型的预测值越接近真实值,因此我们需要一个衡量接近程度的指标 $loss$,可用绝对值或差的平方表示单g个样本预测的损失为(Training Loss):
这里使用差的平方,其中 $y$ 为真实值。
因此,对于多样本预测的平均损失函数为(Mean Square Error):
1 2 3 4 5 6 7 8
| def forward(x): return x * w;
def loss(x, y): y_predict = forward(x) return (y - y_predict) ** 2
|
过程模拟
由于不知道 $w$ 的具体值因此我们给它一个随机初始值,假设 $w = 3$
x(hours) |
y(points) |
y_predict |
loss |
1 |
2 |
3 |
1 |
2 |
4 |
6 |
4 |
3 |
6 |
9 |
9 |
|
|
|
MSE=14/3 |
可知本轮预测平均损失为 14/3
为找到最佳权重,可枚举权重值判断损失,损失最小为最佳
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| w_list = []
mse_list = []
for w in np.arange(0.0, 4.1, 0.1): print("w=", w) loss_sum = 0 for x_val, y_val in zip(x_data, y_data): y_predict_val = forward(x_val) loss_val = loss(x_val, y_val) loss_sum += loss_val print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f')) print('MSE=',loss_sum / len(x_data)) w_list.append(w) mse_list.append(loss_sum / len(x_data))
|
具体实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| import numpy as np import matplotlib.pyplot as plt
x_data = [1.0, 2.0, 3.0] y_data = [2.0, 4.0, 6.0]
def forward(x): return x * w;
def loss(x, y): y_predict = forward(x) return (y - y_predict) ** 2
w_list = []
mse_list = []
for w in np.arange(0.0, 4.1, 0.1): print("w=", w) loss_sum = 0 for x_val, y_val in zip(x_data, y_data): y_predict_val = forward(x_val) loss_val = loss(x_val, y_val) loss_sum += loss_val print('\t\t',x_val, y_val, format(y_predict_val, '0.2f'),format(loss_val,'0.2f')) print('MSE=',loss_sum / len(x_data)) w_list.append(w) mse_list.append(loss_sum / len(x_data))
|
得到每轮的预测结果
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
| w= 0.0 1.0 2.0 0.00 4.00 2.0 4.0 0.00 16.00 3.0 6.0 0.00 36.00 MSE= 18.666666666666668 w= 0.1 1.0 2.0 0.10 3.61 2.0 4.0 0.20 14.44 3.0 6.0 0.30 32.49 MSE= 16.846666666666668 w= 0.2 1.0 2.0 0.20 3.24 2.0 4.0 0.40 12.96 3.0 6.0 0.60 29.16 MSE= 15.120000000000003 w= 0.30000000000000004 1.0 2.0 0.30 2.89 2.0 4.0 0.60 11.56 3.0 6.0 0.90 26.01 MSE= 13.486666666666665 w= 0.4 1.0 2.0 0.40 2.56 2.0 4.0 0.80 10.24 3.0 6.0 1.20 23.04 MSE= 11.946666666666667 w= 0.5 1.0 2.0 0.50 2.25 2.0 4.0 1.00 9.00 3.0 6.0 1.50 20.25 MSE= 10.5 w= 0.6000000000000001 1.0 2.0 0.60 1.96 2.0 4.0 1.20 7.84 3.0 6.0 1.80 17.64 MSE= 9.146666666666663 w= 0.7000000000000001 1.0 2.0 0.70 1.69 2.0 4.0 1.40 6.76 3.0 6.0 2.10 15.21 MSE= 7.886666666666666 w= 0.8 1.0 2.0 0.80 1.44 2.0 4.0 1.60 5.76 3.0 6.0 2.40 12.96 MSE= 6.719999999999999 w= 0.9 1.0 2.0 0.90 1.21 2.0 4.0 1.80 4.84 3.0 6.0 2.70 10.89 MSE= 5.646666666666666 w= 1.0 1.0 2.0 1.00 1.00 2.0 4.0 2.00 4.00 3.0 6.0 3.00 9.00 MSE= 4.666666666666667 w= 1.1 1.0 2.0 1.10 0.81 2.0 4.0 2.20 3.24 3.0 6.0 3.30 7.29 MSE= 3.779999999999999 w= 1.2000000000000002 1.0 2.0 1.20 0.64 2.0 4.0 2.40 2.56 3.0 6.0 3.60 5.76 MSE= 2.986666666666665 w= 1.3 1.0 2.0 1.30 0.49 2.0 4.0 2.60 1.96 3.0 6.0 3.90 4.41 MSE= 2.2866666666666657 w= 1.4000000000000001 1.0 2.0 1.40 0.36 2.0 4.0 2.80 1.44 3.0 6.0 4.20 3.24 MSE= 1.6799999999999995 w= 1.5 1.0 2.0 1.50 0.25 2.0 4.0 3.00 1.00 3.0 6.0 4.50 2.25 MSE= 1.1666666666666667 w= 1.6 1.0 2.0 1.60 0.16 2.0 4.0 3.20 0.64 3.0 6.0 4.80 1.44 MSE= 0.746666666666666 w= 1.7000000000000002 1.0 2.0 1.70 0.09 2.0 4.0 3.40 0.36 3.0 6.0 5.10 0.81 MSE= 0.4199999999999995 w= 1.8 1.0 2.0 1.80 0.04 2.0 4.0 3.60 0.16 3.0 6.0 5.40 0.36 MSE= 0.1866666666666665 w= 1.9000000000000001 1.0 2.0 1.90 0.01 2.0 4.0 3.80 0.04 3.0 6.0 5.70 0.09 MSE= 0.046666666666666586 w= 2.0 1.0 2.0 2.00 0.00 2.0 4.0 4.00 0.00 3.0 6.0 6.00 0.00 MSE= 0.0 w= 2.1 1.0 2.0 2.10 0.01 2.0 4.0 4.20 0.04 3.0 6.0 6.30 0.09 MSE= 0.046666666666666835 w= 2.2 1.0 2.0 2.20 0.04 2.0 4.0 4.40 0.16 3.0 6.0 6.60 0.36 MSE= 0.18666666666666698 w= 2.3000000000000003 1.0 2.0 2.30 0.09 2.0 4.0 4.60 0.36 3.0 6.0 6.90 0.81 MSE= 0.42000000000000054 w= 2.4000000000000004 1.0 2.0 2.40 0.16 2.0 4.0 4.80 0.64 3.0 6.0 7.20 1.44 MSE= 0.7466666666666679 w= 2.5 1.0 2.0 2.50 0.25 2.0 4.0 5.00 1.00 3.0 6.0 7.50 2.25 MSE= 1.1666666666666667 w= 2.6 1.0 2.0 2.60 0.36 2.0 4.0 5.20 1.44 3.0 6.0 7.80 3.24 MSE= 1.6800000000000008 w= 2.7 1.0 2.0 2.70 0.49 2.0 4.0 5.40 1.96 3.0 6.0 8.10 4.41 MSE= 2.2866666666666693 w= 2.8000000000000003 1.0 2.0 2.80 0.64 2.0 4.0 5.60 2.56 3.0 6.0 8.40 5.76 MSE= 2.986666666666668 w= 2.9000000000000004 1.0 2.0 2.90 0.81 2.0 4.0 5.80 3.24 3.0 6.0 8.70 7.29 MSE= 3.780000000000003 w= 3.0 1.0 2.0 3.00 1.00 2.0 4.0 6.00 4.00 3.0 6.0 9.00 9.00 MSE= 4.666666666666667 w= 3.1 1.0 2.0 3.10 1.21 2.0 4.0 6.20 4.84 3.0 6.0 9.30 10.89 MSE= 5.646666666666668 w= 3.2 1.0 2.0 3.20 1.44 2.0 4.0 6.40 5.76 3.0 6.0 9.60 12.96 MSE= 6.720000000000003 w= 3.3000000000000003 1.0 2.0 3.30 1.69 2.0 4.0 6.60 6.76 3.0 6.0 9.90 15.21 MSE= 7.886666666666668 w= 3.4000000000000004 1.0 2.0 3.40 1.96 2.0 4.0 6.80 7.84 3.0 6.0 10.20 17.64 MSE= 9.14666666666667 w= 3.5 1.0 2.0 3.50 2.25 2.0 4.0 7.00 9.00 3.0 6.0 10.50 20.25 MSE= 10.5 w= 3.6 1.0 2.0 3.60 2.56 2.0 4.0 7.20 10.24 3.0 6.0 10.80 23.04 MSE= 11.94666666666667 w= 3.7 1.0 2.0 3.70 2.89 2.0 4.0 7.40 11.56 3.0 6.0 11.10 26.01 MSE= 13.486666666666673 w= 3.8000000000000003 1.0 2.0 3.80 3.24 2.0 4.0 7.60 12.96 3.0 6.0 11.40 29.16 MSE= 15.120000000000005 w= 3.9000000000000004 1.0 2.0 3.90 3.61 2.0 4.0 7.80 14.44 3.0 6.0 11.70 32.49 MSE= 16.84666666666667 w= 4.0 1.0 2.0 4.00 4.00 2.0 4.0 8.00 16.00 3.0 6.0 12.00 36.00 MSE= 18.666666666666668
|
画出权重与平均损失的关系图
1 2 3 4 5
| plt.plot(w_list, mse_list) plt.ylabel('Loss') plt.xlabel('W') plt.show()
|
由上图可知,但 $w = 2.0$ 时损失最小,该点也是损失函数图像的最小值。