/// <summary>
/// Runs one forward pass of the network and stochastically samples an action
/// from the normalized output distribution, recording the step to the learning
/// session stream.
/// </summary>
/// <returns>ID of selected action neuron.</returns>
public int Execute()
{
    NeuralNetworkContext runtime = stackedRuntimeContext[0];

    // Forward pass over the first stacked runtime context.
    neuralNetwork.Execute(runtime);

    // Treat the normalized outputs as a probability distribution and sample one action.
    Utils.Normalize(runtime.outputData);
    int chosenAction = Utils.RandomChoice(runtime.outputData);

    // Persist the chosen action and the inputs that produced it so the
    // session can be replayed/learned from later.
    Utils.IntToStream(chosenAction, learningSessionStream);
    Utils.FloatArrayToStream(runtime.inputData, learningSessionStream);
    currentSession++;

    return chosenAction;
}
/// <summary>
/// Run single iteration of learning, either 1 forward or backward propagation.
/// For recurring networks this accumulates up to <c>maxUnrollLength</c> forward
/// steps in the stacked contexts, then back-propagates through the whole window
/// (truncated BPTT) and applies an Adagrad update. For non-recurring networks
/// each call does one forward + one backward step in slot 0. Epoch bookkeeping
/// (loss smoothing, goal check, data streaming/shuffling) runs when
/// <c>dataIndex</c> wraps past the end of <c>targetData</c>.
/// </summary>
public void Learn() {
    // Training stopped (or goal already reached) — nothing to do.
    if (!running) { return; }

    if (resetState) {
        // Fresh data/epoch: clear accumulated loss and all per-window state.
        resetState = false;
        newLoss = 0.0f;
        lossSampleCount = 0;
        derivatives.Reset();
        for (int i = 0; i < stackedRuntimeContext.Length; i++) {
            stackedRuntimeContext[i].Reset(true);
            stackedDerivativeMemory[i].Reset();
        }
        if (hasRecurring && stochasticSkipping) {
            // Randomize where the unroll windows fall by skipping a few samples
            // at the start, acting as a cheap 'shuffle' for sequence data.
            if (targetData.Length < maxUnrollLength) {
                skipN = 0;
            } else {
                // NOTE(review): the guard above tests targetData.Length (number of
                // samples) but the modulo here uses targetData[dataIndex].Length
                // (width of a single target vector) — these look inconsistent;
                // confirm which length the skip range was meant to derive from.
                skipN = (int)(Utils.NextInt(0, (targetData[dataIndex].Length % maxUnrollLength) + 1));
            }
        }
    }

    //run forwards for maxUnrollLength and then run backwards for maxUnrollLength backpropagating through recurring
    if (skipN > 0) {
        //skip random # at beginning to apply a 'shuffle'
        if (hasRecurring) {
            // Still execute forward so the recurring state advances past the
            // skipped sample; no loss/derivatives are taken for it.
            Array.Copy(inputData[dataIndex], stackedRuntimeContext[0].inputData, stackedRuntimeContext[0].inputData.Length);
            neuralNetwork.Execute(stackedRuntimeContext[0]);
        }
        skipN--;
    } else {
        // Recurring networks stack one context per unroll step; non-recurring
        // networks always reuse slot 0.
        int unrollIndex = unrollCount;
        if (!hasRecurring) {
            unrollIndex = 0;
        }
        // Load the current sample and run a forward pass, capturing the full
        // context needed later for backpropagation.
        Array.Copy(inputData[dataIndex], stackedRuntimeContext[unrollIndex].inputData, stackedRuntimeContext[unrollIndex].inputData.Length);
        neuralNetwork.Execute_FullContext(stackedRuntimeContext[unrollIndex], stackedFullContext[unrollIndex]);
        unrollCount++;
        if (hasRecurring) {
            // Window is full, or we hit the end of the data: back-propagate now.
            if (unrollCount >= maxUnrollLength || dataIndex + 1 >= targetData.Length) {
                //back propagate through stacked
                float nextLoss = 0.0f;
                // tdatIndex walks the target samples backwards in lockstep with
                // the stacked contexts; nunroll remembers the window size for averaging.
                int tdatIndex = dataIndex, nunroll = unrollCount;
                while (unrollCount-- > 0) {
                    // Cross-entropy passes the target class index; other loss
                    // types pass -1 (unused sentinel).
                    neuralNetwork.ExecuteBackwards(targetData[tdatIndex], stackedRuntimeContext[unrollCount], stackedFullContext[unrollCount], stackedDerivativeMemory[unrollCount], lossType, (lossType == LOSS_TYPE_CROSSENTROPY ?
                        crossEntropyLossTargets[tdatIndex] : -1));
                    if (lossType == LOSS_TYPE_AVERAGE) {
                        nextLoss += stackedDerivativeMemory[unrollCount].loss;
                    } else {
                        // Non-average loss types track the worst (max) step loss.
                        if (stackedDerivativeMemory[unrollCount].loss > nextLoss) {
                            nextLoss = stackedDerivativeMemory[unrollCount].loss;
                        }
                    }
                    tdatIndex--;
                }
                if (lossType == LOSS_TYPE_AVERAGE) {
                    newLoss += nextLoss / (float)nunroll;
                    lossSampleCount++;
                } else {
                    if (nextLoss > newLoss) {
                        newLoss = nextLoss;
                    }
                }
                //learn
                // Apply accumulated derivatives via Adagrad, then clear them.
                // NOTE(review): only stackedDerivativeMemory[0] is applied —
                // presumably all slots share the same underlying derivative
                // storage (`derivatives`); verify against their construction.
                adagradMemory.Apply(stackedDerivativeMemory[0]);
                derivatives.Reset();
                unrollCount = 0;
                //copy recurring state over
                // Carry the recurrent state from the end of this window into
                // slot 0 so the next window continues the sequence.
                // NOTE(review): if the window ended early (end of data before
                // maxUnrollLength steps), slot maxUnrollLength - 1 may not be
                // the last slot actually executed — confirm intended.
                CopyRecurringState(stackedRuntimeContext[maxUnrollLength - 1], stackedRuntimeContext[0]);
            } else {
                //copy recurring state into next
                CopyRecurringState(stackedRuntimeContext[unrollCount - 1], stackedRuntimeContext[unrollCount]);
            }
        } else {
            // Non-recurring: back-propagate this single sample immediately,
            // accumulating derivatives in slot 0.
            neuralNetwork.ExecuteBackwards(targetData[dataIndex], stackedRuntimeContext[unrollIndex], stackedFullContext[unrollIndex], stackedDerivativeMemory[unrollIndex], lossType, (lossType == LOSS_TYPE_CROSSENTROPY ? crossEntropyLossTargets[dataIndex] : -1));
            if (lossType == LOSS_TYPE_AVERAGE) {
                newLoss += stackedDerivativeMemory[unrollIndex].loss;
                lossSampleCount++;
            } else {
                if (stackedDerivativeMemory[unrollIndex].loss > newLoss) {
                    newLoss = stackedDerivativeMemory[unrollIndex].loss;
                }
            }
            // unrollCount doubles as a mini-batch counter here: apply the
            // accumulated gradients every maxUnrollLength samples or at end of data.
            if (unrollCount >= maxUnrollLength || dataIndex + 1 >= targetData.Length) {
                //learn
                adagradMemory.Apply(stackedDerivativeMemory[0]);
                derivatives.Reset();
                unrollCount = 0;
            }
        }
    }

    //advance index
    // Runs for both the skip path and the training path — skipped samples
    // still consume a data index.
    dataIndex++;
    if (dataIndex >= targetData.Length) {
        // End of epoch: finalize loss statistics.
        iterations++;
        dataIndex = 0;
        if (lossType == LOSS_TYPE_AVERAGE) {
            newLoss /= (float)lossSampleCount;
        }
        if (newLoss < bestLoss) {
            bestLoss = newLoss;
        }
        if (newLoss <= desiredLoss) {
            //hit goal, stop
            if (onReachedGoal != null) {
                onReachedGoal();
            }
            running = false;
            return;
        }
        // Exponential smoothing of the epoch loss and of its trend (delta).
        float lsl = smoothLoss;
        smoothLoss = smoothLoss * lossSmoothing + newLoss * (1.0f - lossSmoothing);
        lossDelta = lossDelta * lossSmoothing + (lsl - smoothLoss) * (1.0f - lossSmoothing);
        lossSampleCount = 0;
        newLoss = 0.0f;
        //stream new data
        if (onStreamNextData != null) {
            // Callback may swap in new data; it returns whether state must be reset.
            resetState = onStreamNextData(ref inputData, ref targetData);
            if (lossType == LOSS_TYPE_CROSSENTROPY) {
                // Rebuild the per-sample target class indices: the index of the
                // largest target value, or -1 when no positive target exists.
                crossEntropyLossTargets = new int[targetData.Length];
                for (int i = 0; i < targetData.Length; i++) {
                    int r = Utils.Largest(targetData[i], 0, targetData[i].Length);
                    if (targetData[i][r] > 0.0f) {
                        crossEntropyLossTargets[i] = r;
                    } else {
                        crossEntropyLossTargets[i] = -1;
                    }
                }
            }
        } else {
            resetState = true;
        }
        // Optionally shuffle the dataset between epochs.
        if (shuffleChance > 0.0f && Utils.NextFloat01() < shuffleChance) {
            Utils.Shuffle(inputData, targetData);
        }
    }
}