/// <summary>
/// Performs one training iteration on the next block of the local training data.
/// This method is invoked repeatedly as long as this worker is neither paused nor stopped (possibly only once).
/// When the current epoch's block enumerator is exhausted, an epoch time-scale event is raised,
/// the local epoch number is incremented, the local iteration number is reset to zero and a fresh
/// block enumerator is started for the next epoch.
/// </summary>
/// <exception cref="InvalidOperationException">
/// If the worker has no epoch block enumerator (initialisation did not succeed) or the current block is null.
/// </exception>
protected virtual void DoWork()
{
    if (_epochBlockYield == null)
    {
        throw new InvalidOperationException($"Unable to do work in worker {this}, worker was not initialised successfully (epoch yield is null).");
    }

    // no more blocks in this yield, therefore epoch is done
    if (!_epochBlockYield.MoveNext())
    {
        InvokeTimeScaleEvent(TimeScale.Epoch);

        Logger.Debug($"Completed epoch {LocalEpochNumber} at iteration {LocalIterationNumber} in worker {this}.");

        LocalEpochNumber++;
        LocalIterationNumber = 0;

        // obtain the next epoch's enumerator first, then release the exhausted one (IEnumerator<T> is IDisposable)
        var nextEpochBlockYield = LocalTrainingDataIterator.Yield(Operator.Handler, Operator.Sigma).GetEnumerator();
        _epochBlockYield.Dispose();
        _epochBlockYield = nextEpochBlockYield;

        _epochBlockYield.MoveNext();
    }

    if (_epochBlockYield.Current == null)
    {
        throw new InvalidOperationException($"Unable to do work in worker {this} because current epoch block yield is null.");
    }

    Operator.PullProgress(this);

    bool useSessions = Operator.UseSessions;

    if (useSessions)
    {
        Operator.Handler.BeginSession();
    }

    try
    {
        Operator.Trainer.ProvideExternalInputData(LocalNetwork, _epochBlockYield.Current);
        Operator.Trainer.RunTrainingIteration(LocalNetwork, LocalOptimiser, GetPopulatedBufferRegistry(), Operator.Handler);
        Operator.Trainer.ProvideExternalOutputData(LocalNetwork, _epochBlockYield.Current);
    }
    finally
    {
        // pair every BeginSession with an EndSession, even if the training iteration throws
        if (useSessions)
        {
            Operator.Handler.EndSession();
        }
    }

    InvokeTimeScaleEvent(TimeScale.Iteration);

    LocalIterationNumber++;

    // push progress for this iteration
    Operator.PushProgress(this);
}
/// <summary>
/// Prepares this worker for work. This method is called every time the worker starts from a full stop.
/// Any epoch block enumerator left over from a previous run is disposed and discarded, and a new one is
/// created from the local training data iterator. If no block enumerable can be obtained, a warning is
/// logged and the worker is left without an enumerator, so <see cref="DoWork"/> will refuse to run.
/// </summary>
protected virtual void Initialise()
{
    Logger.Debug($"Initialising worker {this}...");

    // release and clear any enumerator from a previous run, so a failed initialisation cannot silently resume stale data
    _epochBlockYield?.Dispose();
    _epochBlockYield = null;

    IEnumerable<IDictionary<string, INDArray>> blockYieldEnumerable = LocalTrainingDataIterator?.Yield(Operator.Handler, Operator.Sigma);

    if (blockYieldEnumerable == null)
    {
        _logger.Warn($"Unable to yield block enumerable from local training iterator {LocalTrainingDataIterator} in worker {this}");

        return;
    }

    _epochBlockYield = blockYieldEnumerable.GetEnumerator();

    Logger.Debug($"Done initialising worker {this}.");
}