/// <summary>
/// Predicts outputs for the given inputs by running a forward pass through the
/// two-layer network, applying a sigmoid activation after each layer.
/// </summary>
/// <param name="xData">Input data to predict from.</param>
/// <returns>A new <see cref="NumYArray"/> wrapping the output-layer activations.</returns>
public NumYArray Predict(NumYArray xData)
{
    // Hidden layer: affine transform with w1/b1, then sigmoid.
    var hiddenPre = new NumYArray(NumY.Dot(xData, this.w1) + this.b1);
    var hiddenOut = new NumYArray(NumY.Sigmoid(hiddenPre));

    // Output layer: affine transform with w2/b2, then sigmoid.
    var outputPre = new NumYArray(NumY.Dot(hiddenOut, this.w2) + this.b2);
    var outputAct = new NumYArray(NumY.Sigmoid(outputPre));

    // Wrap the final activations in a new NumYArray before handing them back.
    return new NumYArray(outputAct);
}
/// <summary>
/// Trains the two-layer network on the stored training data (<c>xData</c> / <c>yData</c>)
/// with full-batch gradient descent. A training run is spread over several calls:
/// each call performs at most <c>maxTrainNum * learnSpeed</c> updates and then returns,
/// so that other processing can run in between.
/// </summary>
/// <param name="epochs">Total number of parameter updates to perform over the whole run (across all calls).</param>
/// <param name="learningRate">Step size used for the parameter updates.</param>
/// <param name="isTrainNow">
/// <c>false</c> to start a new run (resets the progress counter and the per-call update budget);
/// <c>true</c> to continue a run started by an earlier call.
/// </param>
/// <returns>
/// <c>true</c> if the run is still in progress and the method should be called again with
/// <paramref name="isTrainNow"/> set to <c>true</c>; <c>false</c> once <paramref name="epochs"/>
/// updates have been performed (the stored training data is then cleared), or when there is
/// no training data to train on.
/// </returns>
public bool Train(int epochs, float learningRate, bool isTrainNow)
{
    var rows = xData.Get();

    // Guard: no training data (e.g. it was already cleared by a finished run).
    // Without this, the code below throws DivideByZeroException / ArgumentOutOfRangeException.
    if (rows == null || rows.Count == 0)
    {
        totalTrainNum = 0;
        nowTrainNum = 0;
        return false;
    }

    if (!isTrainNow)
    {
        // Starting a new run: reset progress and size the per-call update budget.
        // Larger data sets get fewer updates per call. Float division is required here;
        // integer division would truncate before RoundToInt could round.
        totalTrainNum = 0;
        maxTrainNum = Mathf.RoundToInt(100f / rows.Count);
        if (maxTrainNum <= 0)
        {
            maxTrainNum = 1;
        }
        learnSpeed = 1;
    }

    // Normalizer for the output gradient; constant for the whole call, so computed once.
    // NOTE(review): this is the length of the first row of xData. Backprop normally averages
    // over the number of samples; if Get() returns one entry per sample, that would be
    // rows.Count instead. Confirm against NumYArray's layout before changing.
    int m = rows[0].Count;

    nowTrainNum = 0;

    // Run this call's share of updates, then return so other processing can proceed.
    // layerZ1 … db2 are assigned without a local declaration; presumably they are instance fields.
    while (nowTrainNum < maxTrainNum * learnSpeed)
    {
        // Forward propagation.
        layerZ1 = NumY.Dot(xData, this.w1) + this.b1;
        layerA1 = NumY.Sigmoid(layerZ1);
        layerZ2 = NumY.Dot(layerA1, this.w2) + this.b2;
        layerA2 = NumY.Sigmoid(layerZ2);

        // Backpropagation: gradients for each layer's pre-activation, weights and biases.
        dlayerZ2 = (layerA2 - yData) / m;
        dw2 = NumY.Dot(layerA1.T, dlayerZ2);
        db2 = NumY.Sum(dlayerZ2, 0);
        dlayerZ1 = NumY.Dot(dlayerZ2, w2.T) * NumY.SigmoidDerivative(layerZ1);
        dw1 = NumY.Dot(xData.T, dlayerZ1);
        db1 = NumY.Sum(dlayerZ1, 0);

        // Gradient-descent parameter update.
        w2 -= learningRate * dw2;
        b2 -= learningRate * db2;
        w1 -= learningRate * dw1;
        b1 -= learningRate * db1;

        nowTrainNum++;
        totalTrainNum++;

        if (totalTrainNum >= epochs)
        {
            // Run finished: clear the stored training data and reset the counters.
            this.xData = new NumYArray();
            this.yData = new NumYArray();
            totalTrainNum = 0;
            nowTrainNum = 0;
            return false;
        }
    }

    nowTrainNum = 0;
    return true;
}