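Preface:
The goal of backpropagation is to compute the partial derivative of the cost function C with respect to any weight w or bias b in the network. Once we have these partial derivatives, we update each weight and bias by subtracting the product of a constant α (the learning rate) and the partial derivative of the cost function with respect to that parameter. This is the popular gradient descent algorithm.
The partial derivatives give the direction of steepest ascent, so we take a small step in the opposite direction, the direction of steepest descent, which leads toward a local minimum of the cost function. The rest of this article works through backpropagation in PyTorch on a concrete example.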
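As a minimal illustration of the update rule described above, the sketch below runs plain gradient descent on a made-up one-parameter cost C(w) = (w - 3)^2. The cost function, the starting value, and the learning rate alpha = 0.1 are all assumptions chosen for illustration; none of them come from the original article.

```python
def cost(w):
    """Toy cost function (an assumption for illustration): C(w) = (w - 3)^2."""
    return (w - 3.0) ** 2


def d_cost(w):
    """Derivative dC/dw = 2 * (w - 3), worked out by hand."""
    return 2.0 * (w - 3.0)


def gradient_descent(w0, alpha=0.1, steps=100):
    """Repeatedly step against the derivative: w <- w - alpha * dC/dw."""
    w = w0
    for _ in range(steps):
        w = w - alpha * d_cost(w)
    return w


w_final = gradient_descent(w0=0.0)
print(w_final, cost(w_final))
```

Each step moves w a fraction of the way toward the minimum at w = 3. A larger alpha takes bigger steps but can overshoot, which is why the learning rate is kept small.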
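Problem:
The task is to fit a quadratic model, y = w0*x^2 + w1*x + w2, to the data and use it for prediction, training the weights so that the value of the loss function (MSE) decreases.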
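Written out, the model, the per-sample loss, and the gradients that backpropagation has to produce look as follows. The symbols \hat{y}, \ell, and \alpha are notation introduced here for illustration. The squared error of a single sample is assumed, which is what the loss function in the code computes.

```latex
\begin{align*}
\hat{y} &= w_0 x^2 + w_1 x + w_2, &
\ell &= (\hat{y} - y)^2, \\
\frac{\partial \ell}{\partial w_0} &= 2(\hat{y} - y)\,x^2, &
\frac{\partial \ell}{\partial w_1} &= 2(\hat{y} - y)\,x, \\
\frac{\partial \ell}{\partial w_2} &= 2(\hat{y} - y), &
w_i &\leftarrow w_i - \alpha \frac{\partial \ell}{\partial w_i}.
\end{align*}
```

For the first sample (x = 1, y = 2) with all weights equal to 1, the prediction is 3, so every partial derivative equals 2(3 - 2) = 2.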
The code is as follows:
import torch
import matplotlib.pyplot as plt
import os

# Workaround for the duplicate OpenMP runtime error on some setups
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# Weights, all initialized to 1
w = torch.tensor([1.0, 1.0, 1.0])
w.requires_grad = True  # gradients must be computed for w


# Forward pass: quadratic model
def forward(x):
    return w[0] * (x ** 2) + w[1] * x + w[2]


# Loss for a single sample (squared error)
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2


# Training
print('predict (before tranining) ', 4, forward(4).item())
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):  # 100 epochs, matching the output shown below (progress 0 to 99)
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()  # backward pass: fills w.grad
        print('\tgrad: ', x, y, w.grad.data)
        w.data = w.data - 0.01 * w.grad.data  # gradient descent step
        w.grad.data.zero_()  # reset the accumulated gradient
    print('progress: ', epoch, l.item())  # loss of the last sample in this epoch
    epoch_list.append(epoch)
    w_list.append(w.data)
    loss_list.append(l.item())
print('predict (after tranining) ', 4, forward(4).item())

# Plot the loss curve
plt.plot(epoch_list, loss_list, 'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()
The results are as follows:
predict (before tranining) 4 21.0
    grad: 1.0 2.0 tensor([2., 2., 2.])
    grad: 2.0 4.0 tensor([22.8800, 11.4400, 5.7200])
    grad: 3.0 6.0 tensor([77.0472, 25.6824, 8.5608])
progress: 0 18.321826934814453
    grad: 1.0 2.0 tensor([-1.1466, -1.1466, -1.1466])
    grad: 2.0 4.0 tensor([-15.5367, -7.7683, -3.8842])
    grad: 3.0 6.0 tensor([-30.4322, -10.1441, -3.3814])
progress: 1 2.858394145965576
    grad: 1.0 2.0 tensor([0.3451, 0.3451, 0.3451])
    grad: 2.0 4.0 tensor([2.4273, 1.2137, 0.6068])
    grad: 3.0 6.0 tensor([19.4499, 6.4833, 2.1611])
progress: 2 1.1675907373428345
    grad: 1.0 2.0 tensor([-0.3224, -0.3224, -0.3224])
    grad: 2.0 4.0 tensor([-5.8458, -2.9229, -1.4614])
    grad: 3.0 6.0 tensor([-3.8829, -1.2943, -0.4314])
progress: 3 0.04653334245085716
    grad: 1.0 2.0 tensor([0.0137, 0.0137, 0.0137])
    grad: 2.0 4.0 tensor([-1.9141, -0.9570, -0.4785])
    grad: 3.0 6.0 tensor([6.8557, 2.2852, 0.7617])
progress: 4 0.14506366848945618
    grad: 1.0 2.0 tensor([-0.1182, -0.1182, -0.1182])
    grad: 2.0 4.0 tensor([-3.6644, -1.8322, -0.9161])
    grad: 3.0 6.0 tensor([1.7455, 0.5818, 0.1939])
progress: 5 0.009403289295732975
    grad: 1.0 2.0 tensor([-0.0333, -0.0333, -0.0333])
    grad: 2.0 4.0 tensor([-2.7739, -1.3869, -0.6935])
    grad: 3.0 6.0 tensor([4.0140, 1.3380, 0.4460])
progress: 6 0.04972923547029495
    grad: 1.0 2.0 tensor([-0.0501, -0.0501, -0.0501])
    grad: 2.0 4.0 tensor([-3.1150, -1.5575, -0.7788])
    grad: 3.0 6.0 tensor([2.8534, 0.9511, 0.3170])
progress: 7 0.025129113346338272
    grad: 1.0 2.0 tensor([-0.0205, -0.0205, -0.0205])
    grad: 2.0 4.0 tensor([-2.8858, -1.4429, -0.7215])
    grad: 3.0 6.0 tensor([3.2924, 1.0975, 0.3658])
progress: 8 0.03345605731010437
    grad: 1.0 2.0 tensor([-0.0134, -0.0134, -0.0134])
    grad: 2.0 4.0 tensor([-2.9247, -1.4623, -0.7312])
    grad: 3.0 6.0 tensor([2.9909, 0.9970, 0.3323])
progress: 9 0.027609655633568764
    grad: 1.0 2.0 tensor([0.0033, 0.0033, 0.0033])
    grad: 2.0 4.0 tensor([-2.8414, -1.4207, -0.7103])
    grad: 3.0 6.0 tensor([3.0377, 1.0126, 0.3375])
progress: 10 0.02848036028444767
    grad: 1.0 2.0 tensor([0.0148, 0.0148, 0.0148])
    grad: 2.0 4.0 tensor([-2.8174, -1.4087, -0.7043])
    grad: 3.0 6.0 tensor([2.9260, 0.9753, 0.3251])
progress: 11 0.02642466314136982
    grad: 1.0 2.0 tensor([0.0280, 0.0280, 0.0280])
    grad: 2.0 4.0 tensor([-2.7682, -1.3841, -0.6920])
    grad: 3.0 6.0 tensor([2.8915, 0.9638, 0.3213])
progress: 12 0.025804826989769936
    grad: 1.0 2.0 tensor([0.0397, 0.0397, 0.0397])
    grad: 2.0 4.0 tensor([-2.7330, -1.3665, -0.6832])
    grad: 3.0 6.0 tensor([2.8243, 0.9414, 0.3138])
progress: 13 0.02462013065814972
    grad: 1.0 2.0 tensor([0.0514, 0.0514, 0.0514])
    grad: 2.0 4.0 tensor([-2.6934, -1.3467, -0.6734])
    grad: 3.0 6.0 tensor([2.7756, 0.9252, 0.3084])
progress: 14 0.023777369409799576
    grad: 1.0 2.0 tensor([0.0624, 0.0624, 0.0624])
    grad: 2.0 4.0 tensor([-2.6580, -1.3290, -0.6645])
    grad: 3.0 6.0 tensor([2.7213, 0.9071, 0.3024])
progress: 15 0.0228563379496336
    grad: 1.0 2.0 tensor([0.0731, 0.0731, 0.0731])
    grad: 2.0 4.0 tensor([-2.6227, -1.3113, -0.6557])
    grad: 3.0 6.0 tensor([2.6725, 0.8908, 0.2969])
progress: 16 0.022044027224183083
    grad: 1.0 2.0 tensor([0.0833, 0.0833, 0.0833])
    grad: 2.0 4.0 tensor([-2.5893, -1.2946, -0.6473])
    grad: 3.0 6.0 tensor([2.6240, 0.8747, 0.2916])
progress: 17 0.02125072106719017
    grad: 1.0 2.0 tensor([0.0931, 0.0931, 0.0931])
    grad: 2.0 4.0 tensor([-2.5568, -1.2784, -0.6392])
    grad: 3.0 6.0 tensor([2.5780, 0.8593, 0.2864])
progress: 18 0.020513182505965233
    grad: 1.0 2.0 tensor([0.1025, 0.1025, 0.1025])
    grad: 2.0 4.0 tensor([-2.5258, -1.2629, -0.6314])
    grad: 3.0 6.0 tensor([2.5335, 0.8445, 0.2815])
progress: 19 0.019810274243354797
    grad: 1.0 2.0 tensor([0.1116, 0.1116, 0.1116])
    grad: 2.0 4.0 tensor([-2.4958, -1.2479, -0.6239])
    grad: 3.0 6.0 tensor([2.4908, 0.8303, 0.2768])
progress: 20 0.019148115068674088
    grad: 1.0 2.0 tensor([0.1203, 0.1203, 0.1203])
    grad: 2.0 4.0 tensor([-2.4669, -1.2335, -0.6167])
    grad: 3.0 6.0 tensor([2.4496, 0.8165, 0.2722])
progress: 21 0.018520694226026535
    grad: 1.0 2.0 tensor([0.1286, 0.1286, 0.1286])
    grad: 2.0 4.0 tensor([-2.4392, -1.2196, -0.6098])
    grad: 3.0 6.0 tensor([2.4101, 0.8034, 0.2678])
progress: 22 0.017927465960383415
    grad: 1.0 2.0 tensor([0.1367, 0.1367, 0.1367])
    grad: 2.0 4.0 tensor([-2.4124, -1.2062, -0.6031])
    grad: 3.0 6.0 tensor([2.3720, 0.7907, 0.2636])
progress: 23 0.01736525259912014
    grad: 1.0 2.0 tensor([0.1444, 0.1444, 0.1444])
    grad: 2.0 4.0 tensor([-2.3867, -1.1933, -0.5967])
    grad: 3.0 6.0 tensor([2.3354, 0.7785, 0.2595])
progress: 24 0.016833148896694183
    grad: 1.0 2.0 tensor([0.1518, 0.1518, 0.1518])
    grad: 2.0 4.0 tensor([-2.3619, -1.1810, -0.5905])
    grad: 3.0 6.0 tensor([2.3001, 0.7667, 0.2556])
progress: 25 0.01632905937731266
    grad: 1.0 2.0 tensor([0.1589, 0.1589, 0.1589])
    grad: 2.0 4.0 tensor([-2.3380, -1.1690, -0.5845])
    grad: 3.0 6.0 tensor([2.2662, 0.7554, 0.2518])
progress: 26 0.01585075818002224
    grad: 1.0 2.0 tensor([0.1657, 0.1657, 0.1657])
    grad: 2.0 4.0 tensor([-2.3151, -1.1575, -0.5788])
    grad: 3.0 6.0 tensor([2.2336, 0.7445, 0.2482])
progress: 27 0.015397666022181511
    grad: 1.0 2.0 tensor([0.1723, 0.1723, 0.1723])
    grad: 2.0 4.0 tensor([-2.2929, -1.1465, -0.5732])
    grad: 3.0 6.0 tensor([2.2022, 0.7341, 0.2447])
progress: 28 0.014967591501772404
    grad: 1.0 2.0 tensor([0.1786, 0.1786, 0.1786])
    grad: 2.0 4.0 tensor([-2.2716, -1.1358, -0.5679])
    grad: 3.0 6.0 tensor([2.1719, 0.7240, 0.2413])
progress: 29 0.014559715054929256
    grad: 1.0 2.0 tensor([0.1846, 0.1846, 0.1846])
    grad: 2.0 4.0 tensor([-2.2511, -1.1255, -0.5628])
    grad: 3.0 6.0 tensor([2.1429, 0.7143, 0.2381])
progress: 30 0.014172340743243694
    grad: 1.0 2.0 tensor([0.1904, 0.1904, 0.1904])
    grad: 2.0 4.0 tensor([-2.2313, -1.1157, -0.5578])
    grad: 3.0 6.0 tensor([2.1149, 0.7050, 0.2350])
progress: 31 0.013804304413497448
    grad: 1.0 2.0 tensor([0.1960, 0.1960, 0.1960])
    grad: 2.0 4.0 tensor([-2.2123, -1.1061, -0.5531])
    grad: 3.0 6.0 tensor([2.0879, 0.6960, 0.2320])
progress: 32 0.013455045409500599
    grad: 1.0 2.0 tensor([0.2014, 0.2014, 0.2014])
    grad: 2.0 4.0 tensor([-2.1939, -1.0970, -0.5485])
    grad: 3.0 6.0 tensor([2.0620, 0.6873, 0.2291])
progress: 33 0.013122711330652237
    grad: 1.0 2.0 tensor([0.2065, 0.2065, 0.2065])
    grad: 2.0 4.0 tensor([-2.1763, -1.0881, -0.5441])
    grad: 3.0 6.0 tensor([2.0370, 0.6790, 0.2263])
progress: 34 0.01280694268643856
    grad: 1.0 2.0 tensor([0.2114, 0.2114, 0.2114])
    grad: 2.0 4.0 tensor([-2.1592, -1.0796, -0.5398])
    grad: 3.0 6.0 tensor([2.0130, 0.6710, 0.2237])
progress: 35 0.012506747618317604
    grad: 1.0 2.0 tensor([0.2162, 0.2162, 0.2162])
    grad: 2.0 4.0 tensor([-2.1428, -1.0714, -0.5357])
    grad: 3.0 6.0 tensor([1.9899, 0.6633, 0.2211])
progress: 36 0.012220758944749832
    grad: 1.0 2.0 tensor([0.2207, 0.2207, 0.2207])
    grad: 2.0 4.0 tensor([-2.1270, -1.0635, -0.5317])
    grad: 3.0 6.0 tensor([1.9676, 0.6559, 0.2186])
progress: 37 0.01194891706109047
    grad: 1.0 2.0 tensor([0.2251, 0.2251, 0.2251])
    grad: 2.0 4.0 tensor([-2.1118, -1.0559, -0.5279])
    grad: 3.0 6.0 tensor([1.9462, 0.6487, 0.2162])
progress: 38 0.011689926497638226
    grad: 1.0 2.0 tensor([0.2292, 0.2292, 0.2292])
    grad: 2.0 4.0 tensor([-2.0971, -1.0485, -0.5243])
    grad: 3.0 6.0 tensor([1.9255, 0.6418, 0.2139])
progress: 39 0.01144315768033266
    grad: 1.0 2.0 tensor([0.2333, 0.2333, 0.2333])
    grad: 2.0 4.0 tensor([-2.0829, -1.0414, -0.5207])
    grad: 3.0 6.0 tensor([1.9057, 0.6352, 0.2117])
progress: 40 0.011208509095013142
    grad: 1.0 2.0 tensor([0.2371, 0.2371, 0.2371])
    grad: 2.0 4.0 tensor([-2.0693, -1.0346, -0.5173])
    grad: 3.0 6.0 tensor([1.8865, 0.6288, 0.2096])
progress: 41 0.0109840864315629
    grad: 1.0 2.0 tensor([0.2408, 0.2408, 0.2408])
    grad: 2.0 4.0 tensor([-2.0561, -1.0280, -0.5140])
    grad: 3.0 6.0 tensor([1.8681, 0.6227, 0.2076])
progress: 42 0.010770938359200954
    grad: 1.0 2.0 tensor([0.2444, 0.2444, 0.2444])
    grad: 2.0 4.0 tensor([-2.0434, -1.0217, -0.5108])
    grad: 3.0 6.0 tensor([1.8503, 0.6168, 0.2056])
progress: 43 0.010566935874521732
    grad: 1.0 2.0 tensor([0.2478, 0.2478, 0.2478])
    grad: 2.0 4.0 tensor([-2.0312, -1.0156, -0.5078])
    grad: 3.0 6.0 tensor([1.8332, 0.6111, 0.2037])
progress: 44 0.010372749529778957
    grad: 1.0 2.0 tensor([0.2510, 0.2510, 0.2510])
    grad: 2.0 4.0 tensor([-2.0194, -1.0097, -0.5048])
    grad: 3.0 6.0 tensor([1.8168, 0.6056, 0.2019])
progress: 45 0.010187389329075813
    grad: 1.0 2.0 tensor([0.2542, 0.2542, 0.2542])
    grad: 2.0 4.0 tensor([-2.0080, -1.0040, -0.5020])
    grad: 3.0 6.0 tensor([1.8009, 0.6003, 0.2001])
progress: 46 0.010010283440351486
    grad: 1.0 2.0 tensor([0.2572, 0.2572, 0.2572])
    grad: 2.0 4.0 tensor([-1.9970, -0.9985, -0.4992])
    grad: 3.0 6.0 tensor([1.7856, 0.5952, 0.1984])
progress: 47 0.00984097272157669
    grad: 1.0 2.0 tensor([0.2600, 0.2600, 0.2600])
    grad: 2.0 4.0 tensor([-1.9864, -0.9932, -0.4966])
    grad: 3.0 6.0 tensor([1.7709, 0.5903, 0.1968])
progress: 48 0.009679674170911312
    grad: 1.0 2.0 tensor([0.2628, 0.2628, 0.2628])
    grad: 2.0 4.0 tensor([-1.9762, -0.9881, -0.4940])
    grad: 3.0 6.0 tensor([1.7568, 0.5856, 0.1952])
progress: 49 0.009525291621685028
    grad: 1.0 2.0 tensor([0.2655, 0.2655, 0.2655])
    grad: 2.0 4.0 tensor([-1.9663, -0.9832, -0.4916])
    grad: 3.0 6.0 tensor([1.7431, 0.5810, 0.1937])
progress: 50 0.00937769003212452
    grad: 1.0 2.0 tensor([0.2680, 0.2680, 0.2680])
    grad: 2.0 4.0 tensor([-1.9568, -0.9784, -0.4892])
    grad: 3.0 6.0 tensor([1.7299, 0.5766, 0.1922])
progress: 51 0.009236648678779602
    grad: 1.0 2.0 tensor([0.2704, 0.2704, 0.2704])
    grad: 2.0 4.0 tensor([-1.9476, -0.9738, -0.4869])
    grad: 3.0 6.0 tensor([1.7172, 0.5724, 0.1908])
progress: 52 0.00910158734768629
    grad: 1.0 2.0 tensor([0.2728, 0.2728, 0.2728])
    grad: 2.0 4.0 tensor([-1.9387, -0.9694, -0.4847])
    grad: 3.0 6.0 tensor([1.7050, 0.5683, 0.1894])
progress: 53 0.00897257961332798
    grad: 1.0 2.0 tensor([0.2750, 0.2750, 0.2750])
    grad: 2.0 4.0 tensor([-1.9301, -0.9651, -0.4825])
    grad: 3.0 6.0 tensor([1.6932, 0.5644, 0.1881])
progress: 54 0.008848887868225574
    grad: 1.0 2.0 tensor([0.2771, 0.2771, 0.2771])
    grad: 2.0 4.0 tensor([-1.9219, -0.9609, -0.4805])
    grad: 3.0 6.0 tensor([1.6819, 0.5606, 0.1869])
progress: 55 0.008730598725378513
    grad: 1.0 2.0 tensor([0.2792, 0.2792, 0.2792])
    grad: 2.0 4.0 tensor([-1.9139, -0.9569, -0.4785])
    grad: 3.0 6.0 tensor([1.6709, 0.5570, 0.1857])
progress: 56 0.00861735362559557
    grad: 1.0 2.0 tensor([0.2811, 0.2811, 0.2811])
    grad: 2.0 4.0 tensor([-1.9062, -0.9531, -0.4765])
    grad: 3.0 6.0 tensor([1.6604, 0.5535, 0.1845])
progress: 57 0.008508718572556973
    grad: 1.0 2.0 tensor([0.2830, 0.2830, 0.2830])
    grad: 2.0 4.0 tensor([-1.8987, -0.9493, -0.4747])
    grad: 3.0 6.0 tensor([1.6502, 0.5501, 0.1834])
progress: 58 0.008404706604778767
    grad: 1.0 2.0 tensor([0.2848, 0.2848, 0.2848])
    grad: 2.0 4.0 tensor([-1.8915, -0.9457, -0.4729])
    grad: 3.0 6.0 tensor([1.6404, 0.5468, 0.1823])
progress: 59 0.008305158466100693
    grad: 1.0 2.0 tensor([0.2865, 0.2865, 0.2865])
    grad: 2.0 4.0 tensor([-1.8845, -0.9423, -0.4711])
    grad: 3.0 6.0 tensor([1.6309, 0.5436, 0.1812])
progress: 60 0.00820931326597929
    grad: 1.0 2.0 tensor([0.2882, 0.2882, 0.2882])
    grad: 2.0 4.0 tensor([-1.8778, -0.9389, -0.4694])
    grad: 3.0 6.0 tensor([1.6218, 0.5406, 0.1802])
progress: 61 0.008117804303765297
    grad: 1.0 2.0 tensor([0.2898, 0.2898, 0.2898])
    grad: 2.0 4.0 tensor([-1.8713, -0.9356, -0.4678])
    grad: 3.0 6.0 tensor([1.6130, 0.5377, 0.1792])
progress: 62 0.008029798977077007
    grad: 1.0 2.0 tensor([0.2913, 0.2913, 0.2913])
    grad: 2.0 4.0 tensor([-1.8650, -0.9325, -0.4662])
    grad: 3.0 6.0 tensor([1.6045, 0.5348, 0.1783])
progress: 63 0.007945418357849121
    grad: 1.0 2.0 tensor([0.2927, 0.2927, 0.2927])
    grad: 2.0 4.0 tensor([-1.8589, -0.9294, -0.4647])
    grad: 3.0 6.0 tensor([1.5962, 0.5321, 0.1774])
progress: 64 0.007864190265536308
    grad: 1.0 2.0 tensor([0.2941, 0.2941, 0.2941])
    grad: 2.0 4.0 tensor([-1.8530, -0.9265, -0.4632])
    grad: 3.0 6.0 tensor([1.5884, 0.5295, 0.1765])
progress: 65 0.007786744274199009
    grad: 1.0 2.0 tensor([0.2954, 0.2954, 0.2954])
    grad: 2.0 4.0 tensor([-1.8473, -0.9236, -0.4618])
    grad: 3.0 6.0 tensor([1.5807, 0.5269, 0.1756])
progress: 66 0.007711691781878471
    grad: 1.0 2.0 tensor([0.2967, 0.2967, 0.2967])
    grad: 2.0 4.0 tensor([-1.8417, -0.9209, -0.4604])
    grad: 3.0 6.0 tensor([1.5733, 0.5244, 0.1748])
progress: 67 0.007640169933438301
    grad: 1.0 2.0 tensor([0.2979, 0.2979, 0.2979])
    grad: 2.0 4.0 tensor([-1.8364, -0.9182, -0.4591])
    grad: 3.0 6.0 tensor([1.5662, 0.5221, 0.1740])
progress: 68 0.007570972666144371
    grad: 1.0 2.0 tensor([0.2991, 0.2991, 0.2991])
    grad: 2.0 4.0 tensor([-1.8312, -0.9156, -0.4578])
    grad: 3.0 6.0 tensor([1.5593, 0.5198, 0.1733])
progress: 69 0.007504733745008707
    grad: 1.0 2.0 tensor([0.3002, 0.3002, 0.3002])
    grad: 2.0 4.0 tensor([-1.8262, -0.9131, -0.4566])
    grad: 3.0 6.0 tensor([1.5527, 0.5176, 0.1725])
progress: 70 0.007440924644470215
    grad: 1.0 2.0 tensor([0.3012, 0.3012, 0.3012])
    grad: 2.0 4.0 tensor([-1.8214, -0.9107, -0.4553])
    grad: 3.0 6.0 tensor([1.5463, 0.5154, 0.1718])
progress: 71 0.007379599846899509
    grad: 1.0 2.0 tensor([0.3022, 0.3022, 0.3022])
    grad: 2.0 4.0 tensor([-1.8167, -0.9083, -0.4542])
    grad: 3.0 6.0 tensor([1.5401, 0.5134, 0.1711])
progress: 72 0.007320486940443516
    grad: 1.0 2.0 tensor([0.3032, 0.3032, 0.3032])
    grad: 2.0 4.0 tensor([-1.8121, -0.9060, -0.4530])
    grad: 3.0 6.0 tensor([1.5341, 0.5114, 0.1705])
progress: 73 0.007263725157827139
    grad: 1.0 2.0 tensor([0.3041, 0.3041, 0.3041])
    grad: 2.0 4.0 tensor([-1.8077, -0.9038, -0.4519])
    grad: 3.0 6.0 tensor([1.5283, 0.5094, 0.1698])
progress: 74 0.007209045812487602
    grad: 1.0 2.0 tensor([0.3050, 0.3050, 0.3050])
    grad: 2.0 4.0 tensor([-1.8034, -0.9017, -0.4508])
    grad: 3.0 6.0 tensor([1.5227, 0.5076, 0.1692])
progress: 75 0.007156429346650839
    grad: 1.0 2.0 tensor([0.3058, 0.3058, 0.3058])
    grad: 2.0 4.0 tensor([-1.7992, -0.8996, -0.4498])
    grad: 3.0 6.0 tensor([1.5173, 0.5058, 0.1686])
progress: 76 0.007105532102286816
    grad: 1.0 2.0 tensor([0.3066, 0.3066, 0.3066])
    grad: 2.0 4.0 tensor([-1.7952, -0.8976, -0.4488])
    grad: 3.0 6.0 tensor([1.5121, 0.5040, 0.1680])
progress: 77 0.00705681974068284
    grad: 1.0 2.0 tensor([0.3073, 0.3073, 0.3073])
    grad: 2.0 4.0 tensor([-1.7913, -0.8956, -0.4478])
    grad: 3.0 6.0 tensor([1.5070, 0.5023, 0.1674])
progress: 78 0.007009552326053381
    grad: 1.0 2.0 tensor([0.3081, 0.3081, 0.3081])
    grad: 2.0 4.0 tensor([-1.7875, -0.8937, -0.4469])
    grad: 3.0 6.0 tensor([1.5021, 0.5007, 0.1669])
progress: 79 0.006964194122701883
    grad: 1.0 2.0 tensor([0.3087, 0.3087, 0.3087])
    grad: 2.0 4.0 tensor([-1.7838, -0.8919, -0.4459])
    grad: 3.0 6.0 tensor([1.4974, 0.4991, 0.1664])
progress: 80 0.006920332089066505
    grad: 1.0 2.0 tensor([0.3094, 0.3094, 0.3094])
    grad: 2.0 4.0 tensor([-1.7802, -0.8901, -0.4450])
    grad: 3.0 6.0 tensor([1.4928, 0.4976, 0.1659])
progress: 81 0.006878111511468887
    grad: 1.0 2.0 tensor([0.3100, 0.3100, 0.3100])
    grad: 2.0 4.0 tensor([-1.7767, -0.8883, -0.4442])
    grad: 3.0 6.0 tensor([1.4884, 0.4961, 0.1654])
progress: 82 0.006837360095232725
    grad: 1.0 2.0 tensor([0.3106, 0.3106, 0.3106])
    grad: 2.0 4.0 tensor([-1.7733, -0.8867, -0.4433])
    grad: 3.0 6.0 tensor([1.4841, 0.4947, 0.1649])
progress: 83 0.006797831039875746
    grad: 1.0 2.0 tensor([0.3111, 0.3111, 0.3111])
    grad: 2.0 4.0 tensor([-1.7700, -0.8850, -0.4425])
    grad: 3.0 6.0 tensor([1.4800, 0.4933, 0.1644])
progress: 84 0.006760062649846077
    grad: 1.0 2.0 tensor([0.3117, 0.3117, 0.3117])
    grad: 2.0 4.0 tensor([-1.7668, -0.8834, -0.4417])
    grad: 3.0 6.0 tensor([1.4759, 0.4920, 0.1640])
progress: 85 0.006723103579133749
    grad: 1.0 2.0 tensor([0.3122, 0.3122, 0.3122])
    grad: 2.0 4.0 tensor([-1.7637, -0.8818, -0.4409])
    grad: 3.0 6.0 tensor([1.4720, 0.4907, 0.1636])
progress: 86 0.00668772729113698
    grad: 1.0 2.0 tensor([0.3127, 0.3127, 0.3127])
    grad: 2.0 4.0 tensor([-1.7607, -0.8803, -0.4402])
    grad: 3.0 6.0 tensor([1.4682, 0.4894, 0.1631])
progress: 87 0.006653300020843744
    grad: 1.0 2.0 tensor([0.3131, 0.3131, 0.3131])
    grad: 2.0 4.0 tensor([-1.7577, -0.8789, -0.4394])
    grad: 3.0 6.0 tensor([1.4646, 0.4882, 0.1627])
progress: 88 0.0066203586757183075
    grad: 1.0 2.0 tensor([0.3135, 0.3135, 0.3135])
    grad: 2.0 4.0 tensor([-1.7548, -0.8774, -0.4387])
    grad: 3.0 6.0 tensor([1.4610, 0.4870, 0.1623])
progress: 89 0.0065881176851689816
    grad: 1.0 2.0 tensor([0.3139, 0.3139, 0.3139])
    grad: 2.0 4.0 tensor([-1.7520, -0.8760, -0.4380])
    grad: 3.0 6.0 tensor([1.4576, 0.4859, 0.1620])
progress: 90 0.0065572685562074184
    grad: 1.0 2.0 tensor([0.3143, 0.3143, 0.3143])
    grad: 2.0 4.0 tensor([-1.7493, -0.8747, -0.4373])
    grad: 3.0 6.0 tensor([1.4542, 0.4847, 0.1616])
progress: 91 0.0065271081402897835
    grad: 1.0 2.0 tensor([0.3147, 0.3147, 0.3147])
    grad: 2.0 4.0 tensor([-1.7466, -0.8733, -0.4367])
    grad: 3.0 6.0 tensor([1.4510, 0.4837, 0.1612])
progress: 92 0.00649801641702652
    grad: 1.0 2.0 tensor([0.3150, 0.3150, 0.3150])
    grad: 2.0 4.0 tensor([-1.7441, -0.8720, -0.4360])
    grad: 3.0 6.0 tensor([1.4478, 0.4826, 0.1609])
progress: 93 0.0064699104987084866
    grad: 1.0 2.0 tensor([0.3153, 0.3153, 0.3153])
    grad: 2.0 4.0 tensor([-1.7415, -0.8708, -0.4354])
    grad: 3.0 6.0 tensor([1.4448, 0.4816, 0.1605])
progress: 94 0.006442630663514137
    grad: 1.0 2.0 tensor([0.3156, 0.3156, 0.3156])
    grad: 2.0 4.0 tensor([-1.7391, -0.8695, -0.4348])
    grad: 3.0 6.0 tensor([1.4418, 0.4806, 0.1602])
progress: 95 0.006416172254830599
    grad: 1.0 2.0 tensor([0.3159, 0.3159, 0.3159])
    grad: 2.0 4.0 tensor([-1.7366, -0.8683, -0.4342])
    grad: 3.0 6.0 tensor([1.4389, 0.4796, 0.1599])
progress: 96 0.006390606984496117
    grad: 1.0 2.0 tensor([0.3161, 0.3161, 0.3161])
    grad: 2.0 4.0 tensor([-1.7343, -0.8671, -0.4336])
    grad: 3.0 6.0 tensor([1.4361, 0.4787, 0.1596])
progress: 97 0.0063657015562057495
    grad: 1.0 2.0 tensor([0.3164, 0.3164, 0.3164])
    grad: 2.0 4.0 tensor([-1.7320, -0.8660, -0.4330])
    grad: 3.0 6.0 tensor([1.4334, 0.4778, 0.1593])
progress: 98 0.0063416799530386925
    grad: 1.0 2.0 tensor([0.3166, 0.3166, 0.3166])
    grad: 2.0 4.0 tensor([-1.7297, -0.8649, -0.4324])
    grad: 3.0 6.0 tensor([1.4308, 0.4769, 0.1590])
progress: 99 0.00631808303296566
predict (after tranining) 4 8.544171333312988
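The first three gradient lines of the log can be reproduced without autograd, which is a handy sanity check on what backward() computes. The sketch below applies the chain rule by hand to the same model and per-sample squared error. The helper names manual_grad and sgd_step are hypothetical and not part of the original code.

```python
def forward(w, x):
    """Quadratic model: w[0] * x^2 + w[1] * x + w[2]."""
    return w[0] * x ** 2 + w[1] * x + w[2]


def manual_grad(w, x, y):
    """Gradient of (forward(w, x) - y)^2 with respect to w, by the chain rule."""
    err = forward(w, x) - y
    return [2 * err * x ** 2, 2 * err * x, 2 * err]


def sgd_step(w, grad, lr=0.01):
    """One gradient descent update with the learning rate used in the article."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


w = [1.0, 1.0, 1.0]  # same initial weights as the article
grads = []
for x, y in zip([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]):
    g = manual_grad(w, x, y)
    grads.append(g)
    print(x, y, [round(v, 4) for v in g])
    w = sgd_step(w, g)
```

Rounded to four decimals, these values should agree with the first three tensors in the log: [2, 2, 2], [22.88, 11.44, 5.72], and [77.0472, 25.6824, 8.5608].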
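The loss trends downward as the number of iterations increases, with some fluctuation in the first few epochs, as shown in the figure below: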
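As can be seen, the predicted value at x=4 is about 8.5, which still differs from the true value of 8. The gap can be narrowed by increasing the number of iterations or by adjusting the learning rate, the initial parameters, and so on.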
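To try the first suggestion, more iterations, without the printing and plotting code, here is a small pure-Python sketch of the same per-sample update with the gradients written out by hand. The names train and predict and their parameters are hypothetical. Since the three data points lie exactly on y = 2x, the weights (0, 2, 0) fit them perfectly, so a longer run would be expected to push the prediction at x = 4 toward 8. The exact values depend on the settings, so none are quoted here.

```python
X_DATA = [1.0, 2.0, 3.0]
Y_DATA = [2.0, 4.0, 6.0]


def predict(w, x):
    """Quadratic model: w[0] * x^2 + w[1] * x + w[2]."""
    return w[0] * x ** 2 + w[1] * x + w[2]


def train(epochs, lr=0.01, w_init=(1.0, 1.0, 1.0)):
    """Per-sample gradient descent on the squared error, as in the article."""
    w = list(w_init)
    for _ in range(epochs):
        for x, y in zip(X_DATA, Y_DATA):
            err = predict(w, x) - y
            grad = (2 * err * x ** 2, 2 * err * x, 2 * err)
            w = [wi - lr * gi for wi, gi in zip(w, grad)]
    return w


# Compare the prediction at x = 4 for increasing numbers of epochs.
for epochs in (100, 1000, 100_000):
    print(epochs, predict(train(epochs), 4.0))
```

The lr and w_init arguments can be varied in the same way to explore the other two suggestions. A learning rate that is too large can make the updates diverge.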
Original article: https://blog.csdn.net/weixin_43821559/article/details/123296140