/// <summary>
/// Merge the given <c>Cost</c> data with the data in this instance.
/// </summary>
/// <remarks>
/// Scalars (cost, percent-correct) are summed; each gradient array of
/// <paramref name="otherCost"/> is added element-wise into the
/// corresponding gradient array held here.
/// </remarks>
/// <param name="otherCost">Cost data to fold into this instance.</param>
public virtual void Merge(Classifier.Cost otherCost)
{
    // Accumulate the scalar statistics.
    this.cost += otherCost.GetCost();
    this.percentCorrect += otherCost.GetPercentCorrect();

    // Accumulate gradients: 2D matrices use AddInPlace, the 1D bias
    // vector uses PairwiseAddInPlace.
    ArrayMath.AddInPlace(this.gradW1, otherCost.GetGradW1());
    ArrayMath.PairwiseAddInPlace(this.gradb1, otherCost.GetGradb1());
    ArrayMath.AddInPlace(this.gradW2, otherCost.GetGradW2());
    ArrayMath.AddInPlace(this.gradE, otherCost.GetGradE());
}
/// <summary>
/// Update classifier weights using the given training cost information.
/// </summary>
/// <remarks>
/// Applies one AdaGrad step to each parameter group (W1, b1, W2 and,
/// when enabled by <c>config.doWordEmbeddingGradUpdate</c>, the word
/// embedding matrix E): the squared gradient is accumulated into the
/// per-parameter history (eg2*) and the parameter is moved by
/// <c>adaAlpha * g / sqrt(history + adaEps)</c>.
/// </remarks>
/// <param name="cost">
/// Cost information as returned by
/// <see cref="ComputeCostFunction(int, double, double)"/>.
/// </param>
/// <param name="adaAlpha">Global AdaGrad learning rate</param>
/// <param name="adaEps">
/// Epsilon value for numerical stability in AdaGrad's division
/// </param>
public virtual void TakeAdaGradientStep(Classifier.Cost cost, double adaAlpha, double adaEps)
{
    ValidateTraining();

    double[][] gradW1 = cost.GetGradW1();
    double[][] gradW2 = cost.GetGradW2();
    double[][] gradE = cost.GetGradE();
    double[] gradb1 = cost.GetGradb1();

    // Hidden-layer weights W1.
    for (int row = 0; row < W1.Length; ++row)
    {
        for (int col = 0; col < W1[row].Length; ++col)
        {
            double g = gradW1[row][col];
            eg2W1[row][col] += g * g;
            W1[row][col] -= adaAlpha * g / System.Math.Sqrt(eg2W1[row][col] + adaEps);
        }
    }

    // Hidden-layer bias b1.
    for (int idx = 0; idx < b1.Length; ++idx)
    {
        double g = gradb1[idx];
        eg2b1[idx] += g * g;
        b1[idx] -= adaAlpha * g / System.Math.Sqrt(eg2b1[idx] + adaEps);
    }

    // Output-layer weights W2.
    for (int row = 0; row < W2.Length; ++row)
    {
        for (int col = 0; col < W2[row].Length; ++col)
        {
            double g = gradW2[row][col];
            eg2W2[row][col] += g * g;
            W2[row][col] -= adaAlpha * g / System.Math.Sqrt(eg2W2[row][col] + adaEps);
        }
    }

    // Word embeddings E — only when the configuration asks for it.
    if (config.doWordEmbeddingGradUpdate)
    {
        for (int row = 0; row < E.Length; ++row)
        {
            for (int col = 0; col < E[row].Length; ++col)
            {
                double g = gradE[row][col];
                eg2E[row][col] += g * g;
                E[row][col] -= adaAlpha * g / System.Math.Sqrt(eg2E[row][col] + adaEps);
            }
        }
    }
}