/// <summary>
/// Checks that one SGD step (lr = 0.001) changes the value of <c>SimpleLink.constParam</c>.
/// Test name: "the Link's Parameter is updated by the optimizer".
/// </summary>
public void LinkのParameterがoptimizerで更新される()
{
    var optimizer = new optimizers.SGD(lr: 0.001f);
    var link = new SimpleLink();
    optimizer.Setup(link);

    // Column input [1, 1, 1]^T against column target [1, 2, 3]^T.
    var loss = MeanSquaredError.ForwardStatic(
        link.Forward(new Variable(Matrix<float>.Build.DenseOfArray(new float[,] { { 1, 1, 1 } }).Transpose())),
        new Variable(Matrix<float>.Build.DenseOfArray(new float[,] { { 1, 2, 3 } }).Transpose())
    );

    // Snapshot a copy rather than keeping the reference. If Update() ever modified the
    // parameter matrix in place, `before` and `after` would be the same object and the
    // assertion below would compare a matrix with itself.
    var before = link.constParam.Value.Clone();

    optimizer.ZeroGrads();
    loss.Backward();
    optimizer.Update();

    var after = link.constParam.Value;

    // NOTE(review): with delta: 0, presumably any difference at all counts as
    // "not almost equal". Confirm against Helper.AssertMatrixNotAlmostEqual.
    Helper.AssertMatrixNotAlmostEqual(before, after, delta: 0);
}
/// <summary>
/// Checks that repeated SGD steps (lr = 0.05) on <c>SimpleLink</c> drive the MSE loss below 0.1
/// within at most 100 evaluations. The initial loss must start above 0.1.
/// Test name: "running iterations reaches the optimal value".
/// </summary>
public void Iterationを回すと最適値になる()
{
    var optimizer = new optimizers.SGD(lr: 0.05f);
    var link = new SimpleLink();
    optimizer.Setup(link);

    // Each call builds brand-new Variables/matrices, so every loss evaluation
    // starts from fresh input and target objects.
    System.Func<Variable> makeInput = () =>
        new Variable(Matrix<float>.Build.DenseOfArray(new float[,] { { 1, 1, 1 } }).Transpose());
    System.Func<Variable> makeTarget = () =>
        new Variable(Matrix<float>.Build.DenseOfArray(new float[,] { { 1, 2, 3 } }).Transpose());

    // Precondition: the untrained link is not already at the optimum.
    var initialLoss = MeanSquaredError.ForwardStatic(link.Forward(makeInput()), makeTarget());
    Assert.Greater(initialLoss.Value[0, 0], 0.1f);

    // Evaluate, then either stop (converged) or take one SGD step; at most 100 evaluations.
    // NOTE(review): the parameters produced by the 100th update are never evaluated.
    var converged = false;
    var evaluations = 0;
    while (!converged && evaluations < 100)
    {
        var currentLoss = MeanSquaredError.ForwardStatic(link.Forward(makeInput()), makeTarget());

        if (currentLoss.Value[0, 0] < 0.1f)
        {
            converged = true;
        }
        else
        {
            optimizer.ZeroGrads();
            currentLoss.Backward();
            optimizer.Update();
        }

        evaluations++;
    }

    Assert.True(converged);
}
/// <summary>
/// Checks that Adam training of <c>VerySmallChain</c> produces MSE losses matching fixed
/// reference values (tolerance 0.01). These are presumably recorded from the same run in
/// Python chainer, per the test name "gives the same values as chainer (python)".
/// The expected numbers depend on the exact sequence of forward/backward/update calls below.
/// </summary>
public void chainer_pythonと同じ値になる()
{
    var chain = new VerySmallChain();
    var optimizer = new chainer.optimizers.Adam();
    var input = new Variable(builder.DenseOfArray(new float[,] { { 4, 3, 2 } }));
    var target = new Variable(builder.DenseOfArray(new float[,] { { 100 } }));
    optimizer.Setup(chain);

    // Pin the initial fc parameters the reference losses were computed from.
    Helper.AssertMatrixAlmostEqual(chain.fc._Params["W"].Value, builder.DenseOfArray(new float[,] { { -1, 0, 1 } }));
    Helper.AssertMatrixAlmostEqual(chain.fc._Params["b"].Value, builder.DenseOfArray(new float[,] { { 1 } }));

    // Evaluation #0 (no updates yet). Assuming fc computes W·x + b, the prediction is
    // -4 + 0 + 2 + 1 = -1, so the squared error is (100 - (-1))^2 = 10201.
    var loss = MeanSquaredError.ForwardStatic(chain.Forward(input), target);

    // One Adam step driven by whatever `loss` currently holds. The lambda captures the
    // variable, not its value, so reassigning `loss` later retargets the step.
    System.Action applyGradients = () =>
    {
        optimizer.ZeroGrads();
        loss.Backward();
        optimizer.Update();
    };

    Helper.AssertMatrixAlmostEqual(loss.Value, builder.DenseOfArray(new float[,] { { 10201 } }), delta: 0.01f);
    applyGradients();

    // Evaluation #1 (after 1 update).
    loss = MeanSquaredError.ForwardStatic(chain.Forward(input), target);
    Helper.AssertMatrixAlmostEqual(loss.Value, builder.DenseOfArray(new float[,] { { 10198.9794921875f } }), delta: 0.01f);
    applyGradients();

    // Evaluation #2 (after 2 updates). This loss is checked only; no step is taken from it.
    loss = MeanSquaredError.ForwardStatic(chain.Forward(input), target);
    Helper.AssertMatrixAlmostEqual(loss.Value, builder.DenseOfArray(new float[,] { { 10196.9609375f } }), delta: 0.01f);

    // 100 further forward + update rounds.
    var remaining = 100;
    while (remaining-- > 0)
    {
        loss = MeanSquaredError.ForwardStatic(chain.Forward(input), target);
        applyGradients();
    }

    // Final evaluation after 102 updates in total.
    loss = MeanSquaredError.ForwardStatic(chain.Forward(input), target);
    Helper.AssertMatrixAlmostEqual(loss.Value, builder.DenseOfArray(new float[,] { { 9996.3515625f } }), delta: 0.01f);
}