/// <summary>
/// Creates a training object for <paramref name="network"/>, selects the gradient
/// computing algorithm from the network's structure and validates that the selected
/// algorithm is compatible with the requested training <paramref name="mode"/>.
/// </summary>
/// <param name="network">The network to train. Must not be null.</param>
/// <param name="allocator">The buffer allocator passed to the base class and to the error vector stack. Must not be null.</param>
/// <param name="mode">The training mode; stored in <c>Mode</c>.</param>
/// <exception cref="InvalidOperationException">
/// <paramref name="mode"/> is <c>TrainingMode.Streamed</c> and the selected gradient computing
/// algorithm is neither <c>RTLR</c> nor <c>None</c>.
/// </exception>
internal Training(NeuralNetwork network, BufferAllocator allocator, TrainingMode mode)
    : base(network, DetermineIterationRepeat(network), allocator)
{
    Contract.Requires(network != null);
    Contract.Requires(allocator != null);

    Mode = mode;

    // Select the gradient computing algorithm BEFORE validating the mode, so the
    // validation below tests the real algorithm and not the field's default value.
    if ((network.StructuralElementFlags & NNStructuralElement.GradientInformation) != 0)
    {
        if (network.IsRecurrent)
        {
            GCAlgo = network.RecurrentOptions.Algorithm == RLAlgorithm.BPTT
                ? GradientComputingAlgorithm.BPTT
                : GradientComputingAlgorithm.RTLR;
        }
        else
        {
            GCAlgo = GradientComputingAlgorithm.BP;
        }
    }
    else
    {
        GCAlgo = GradientComputingAlgorithm.None;
    }

    // Streamed training accepts only RTLR or no gradient computation at all.
    // The two inequalities must be combined with && ("neither RTLR nor None");
    // combining them with || is always true and would reject every streamed training.
    if (Mode == TrainingMode.Streamed
        && GCAlgo != GradientComputingAlgorithm.RTLR
        && GCAlgo != GradientComputingAlgorithm.None)
    {
        throw new InvalidOperationException("Only RTLR allowed for Streamed training. You have to use Recurrent NN with RTLR Algorithm in RecurrentOptions.");
    }

    if (GCAlgo == GradientComputingAlgorithm.BPTT)
    {
        savedErrorVectors = new ErrorVectorStack(network, allocator);
    }
}
/// <summary>
/// Performs the one-time initialization on the first call; every later call returns
/// immediately because <c>isInitialized</c> is already set.
/// </summary>
private void EnsureInitialized()
{
    if (isInitialized)
    {
        return;
    }

    // Algorithm initialization is called first (ordering kept from the original code).
    Network.Reset(NeuralNetworkResetTarget.Algorithms);

    bool usesBptt = GCAlgo == GradientComputingAlgorithm.BPTT;
    if (usesBptt)
    {
        // NOTE(review): this replaces any error vector stack assigned earlier
        // (e.g. by the constructor) - confirm that the replacement is intended.
        savedErrorVectors = new ErrorVectorStack(
            Network.RecurrentOptions.MaxIterations,
            Network.OutputInterfaceLength);
    }

    isInitialized = true;
}