/// <summary>
/// Creates the plain (no batch normalization) network. Layers are constructed in this order:
/// an <c>AffineLayer(inputSize, hiddenSize)</c>, a <c>ReLULayer</c>,
/// an <c>AffineLayer(hiddenSize, outputSize)</c> and a <c>SoftmaxLayer</c>.
/// The optimizer is a <c>MomentumOptimizer</c> with coefficient 0.9.
/// </summary>
/// <param name="inputSize">Stored in the <c>inputSize</c> field; first argument of the first affine layer.</param>
/// <param name="hiddenSize">Stored in the <c>hiddenSize</c> field; width shared by both affine layers.</param>
/// <param name="outputSize">Stored in the <c>outputSize</c> field; second argument of the last affine layer.</param>
public MNISTNetwork(int inputSize, int hiddenSize, int outputSize) : base()
{
    // Record the dimensions the network was built with.
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.outputSize = outputSize;

    // Keep this construction order: affine layers may draw random initial weights,
    // so reordering could change the initial state.
    // NOTE(review): this network uses the 2-arg AffineLayer (default scale), unlike the
    // He-scaled layers of the batch-normalization variant — confirm that is intended.
    affine1 = new AffineLayer(inputSize, hiddenSize);
    relu = new ReLULayer();
    affine2 = new AffineLayer(hiddenSize, outputSize);
    softmax = new SoftmaxLayer();

    const float momentum = 0.9f;
    optimizer = new MomentumOptimizer(momentum);
}
/// <summary>
/// Creates the batch-normalized network. Layers are constructed in this order:
/// Affine(inputSize, hiddenSize) -> BatchNorm -> ReLU -> Affine(hiddenSize, hiddenSize) -> BatchNorm -> ReLU
/// -> Affine(hiddenSize, outputSize) -> Softmax. Each affine layer receives a scale of
/// <c>sqrt(2 / n)</c>, where n is that layer's first size argument.
/// The optimizer is a <c>MomentumOptimizer</c> with coefficient 0.9.
/// </summary>
/// <param name="inputSize">Stored in the <c>inputSize</c> field; first argument of the first affine layer.</param>
/// <param name="hiddenSize">Stored in the <c>hiddenSize</c> field; width of the hidden layers.</param>
/// <param name="outputSize">Stored in the <c>outputSize</c> field; second argument of the last affine layer.</param>
public MNISTBatchNormalizationNetwork(int inputSize, int hiddenSize, int outputSize) : base()
{
    // Record the dimensions the network was built with.
    this.inputSize = inputSize;
    this.hiddenSize = hiddenSize;
    this.outputSize = outputSize;

    // He-style scale sqrt(2 / fan_in), passed as the third AffineLayer argument.
    // Presumably this is the std of the initial weights — confirm against AffineLayer.
    float inputLayerScale = Mathf.Sqrt(2.0f / inputSize);
    float hiddenLayerScale = Mathf.Sqrt(2.0f / hiddenSize);

    // Keep this construction order: layers may draw random initial weights.
    // NOTE(review): both BatchNormalizationLayer arguments are hiddenSize — confirm
    // what the two parameters mean.
    affine1 = new AffineLayer(inputSize, hiddenSize, inputLayerScale);
    bn1 = new BatchNormalizationLayer(hiddenSize, hiddenSize);
    relu1 = new ReLULayer();

    affine2 = new AffineLayer(hiddenSize, hiddenSize, hiddenLayerScale);
    bn2 = new BatchNormalizationLayer(hiddenSize, hiddenSize);
    relu2 = new ReLULayer();

    affine3 = new AffineLayer(hiddenSize, outputSize, hiddenLayerScale);
    softmax = new SoftmaxLayer();

    const float momentum = 0.9f;
    optimizer = new MomentumOptimizer(momentum);
}
/// <summary>
/// Builds a simple model with <c>CostType.Square</c> and initialises it with
/// <c>InitWithConstWeights(1)</c>. It then runs <c>passCount</c> forward/backward passes
/// toward a target of <c>targetValue</c> (0). After every backward pass it asserts that each
/// checked weight is strictly smaller than it was before that pass.
/// </summary>
/// <remarks>
/// Only column 0 of each layer's <c>Weights.Primal</c> is checked. The loops cover
/// layers <c>0 .. LayersCount - 2</c>, so the last layer is skipped.
/// NOTE(review): confirm that skipping the last layer is intended.
/// </remarks>
public void OverShootingMakeWeightsSmaller()
{
    const int targetValue = 0;
    const int passCount = 10;

    Model model = InitSimpleModel(10, CostType.Square);
    model.InitWithConstWeights(1);

    for (int i = 0; i < passCount; i++)
    {
        model.ForwardPass(new Matrix(0));

        // Deep-copy the weights before the update so the in-place backward pass
        // cannot alias the snapshot.
        var weightsBefore = new List<Matrix>();
        for (int k = 0; k < model.LayersCount - 1; k++)
        {
            weightsBefore.Add(Utils.DeepCopy(model[k].Weights.Primal));
        }

        // Allocated per pass in case BackwardPass modifies its argument.
        Matrix target = new Matrix(targetValue);
        model.BackwardPass(target);

        for (int k = 0; k < model.LayersCount - 1; k++)
        {
            AffineLayer modelLayer = model[k];
            Matrix prevModelWeights = weightsBefore[k];

            // Only column 0 of each weight matrix is inspected.
            for (int j = 0; j < modelLayer.Weights.Primal.Rows; j++)
            {
                double current = modelLayer.Weights.Primal[j, 0];
                double prev = prevModelWeights[j, 0];
                Assert.IsTrue(current < prev, "current = {0} prev = {1}", current, prev);
            }
        }
    }
}