/// <summary>
/// Wraps the default plan in a <see cref="MyLoopBlock"/> whose continuation predicate is
/// <c>i &lt; Iterations</c> (<c>i</c> is presumably the iteration index supplied by the loop block).
/// </summary>
/// <remarks>
/// For <see cref="MyLoopOperation.All"/> the loop body is a new <see cref="MyExecutionBlock"/>
/// containing <c>m_oneShotTasks</c> followed by <paramref name="defaultPlan"/>. For
/// <see cref="MyLoopOperation.Normal"/> and any other value, the loop body is
/// <paramref name="defaultPlan"/> alone.
/// </remarks>
/// <param name="defaultPlan">The plan to be repeated.</param>
/// <returns>The loop block wrapping the plan.</returns>
public virtual MyExecutionBlock CreateCustomExecutionPlan(MyExecutionBlock defaultPlan)
{
    if (LoopType == MyLoopOperation.All)
    {
        // One-shot tasks are bundled into the looped body together with the default plan.
        MyExecutionBlock iterationBody = new MyExecutionBlock(m_oneShotTasks, defaultPlan);
        return new MyLoopBlock(i => i < Iterations, iterationBody);
    }

    // MyLoopOperation.Normal and every other loop type: loop over the default plan only.
    return new MyLoopBlock(i => i < Iterations, defaultPlan);
}
/// <summary>
/// Rewrites the default execution plan so that the forward pass runs every step and
/// backpropagation through time (BPTT) runs once the last step of a sequence is reached.
/// </summary>
/// <remarks>
/// Steps, in order:
/// 1. Flatten: children of <paramref name="defaultPlan"/> whose runtime type is exactly
///    <see cref="MyExecutionBlock"/> are replaced by their own children; any other child
///    (including blocks of a derived type) is kept as a single entry.
/// 2. Remove disabled backprop tasks, except RBM learning/reconstruction tasks.
/// 3. Build the BPTT block. It starts with a loop that, while <c>TimeStep != -1</c>, runs:
///    the delta tasks in reverse plan order, the LSTM partial-derivative tasks, the
///    gradient-check tasks and <c>DecrementTimeStep</c>. The loop is followed by
///    <c>IncrementTimeStep</c>, <c>RunTemporalBlocksMode</c>, the weight-update tasks
///    (guarded by the active backprop task's <c>DisableLearning</c> flag) and
///    <c>DecrementTimeStep</c>. The whole block is guarded by
///    <c>TimeStep == SequenceLength - 1</c>.
/// 4. Strip the forward plan of the following:
///    - non-RBM backprop tasks and dropout-mask tasks;
///    - every task type the BPTT block executes;
///    - the time-step and temporal-mode tasks.
/// 5. Finish the forward plan:
///    - prepend <c>IncrementTimeStep</c>;
///    - place the output-delta tasks, and in front of them the Q-learning tasks (each group
///      reversed), right after the last forward task;
///    - append the BPTT block.
/// </remarks>
/// <param name="defaultPlan">The default plan; its (one-level flattened) children are copied into a new list.</param>
/// <returns>A new <see cref="MyExecutionBlock"/> holding the rewritten plan.</returns>
public virtual MyExecutionBlock CreateCustomExecutionPlan(MyExecutionBlock defaultPlan)
{
    List<IMyExecutable> newPlan = new List<IMyExecutable>();

    // Copy the default plan content. Only blocks whose runtime type is exactly MyExecutionBlock
    // are expanded into their node tasks; everything else is added as a single group task.
    foreach (IMyExecutable groupTask in defaultPlan.Children)
    {
        if (groupTask.GetType() == typeof(MyExecutionBlock))
        {
            foreach (IMyExecutable nodeTask in ((MyExecutionBlock)groupTask).Children)
            {
                newPlan.Add(nodeTask);
            }
        }
        else
        {
            newPlan.Add(groupTask);
        }
    }

    // Remove disabled group backprop tasks (they should be called from the individual layers).
    // RBM tasks are kept. Enabled backprop tasks are kept at this point.
    // NOTE(review): step 4 below removes every non-RBM backprop task anyway; this early pass
    // only matters if a disabled backprop task also matches one of the filters used to build
    // the BPTT block (delta / LSTM / gradient-check / update-weights). Confirm this is intended.
    newPlan.RemoveAll(task => task is MyAbstractBackpropTask
                              && !task.Enabled
                              && !(task is MyRBMLearningTask || task is MyRBMReconstructionTask));

    // BPTT single step: delta tasks in reverse plan order, then LSTM partial derivatives,
    // gradient checks, and finally a step back in time.
    List<IMyExecutable> deltaTasks = newPlan.Where(task => task is IMyDeltaTask).ToList();
    deltaTasks.Reverse();

    List<IMyExecutable> bpttSingleStep = new List<IMyExecutable>();
    bpttSingleStep.AddRange(deltaTasks);
    bpttSingleStep.AddRange(newPlan.Where(task => task is MyLSTMPartialDerivativesTask));
    bpttSingleStep.AddRange(newPlan.Where(task => task is MyGradientCheckTask));
    bpttSingleStep.Add(DecrementTimeStep);

    // Repeat the single step (backprop until unfolded) while TimeStep has not reached -1.
    MyExecutionBlock bpttLoop = new MyLoopBlock(i => TimeStep != -1, bpttSingleStep.ToArray());

    // The weight updates are skipped when there is no active backprop task or learning is
    // globally disabled. The condition is a lambda, so it is read when MyIfBlock evaluates it,
    // not fixed here.
    MyExecutionBlock updateWeightsIfNotDisabled = new MyIfBlock(
        () => GetActiveBackpropTask() != null && !GetActiveBackpropTask().DisableLearning,
        newPlan.Where(task => task is IMyUpdateWeightsTask).ToArray());

    var activeTask = GetActiveBackpropTask();
    if (activeTask != null && activeTask.DisableLearning)
    {
        MyLog.WARNING.WriteLine("Learning is globally disabled for the network " +
            this.Name + " in the " + activeTask.Name + " backprop task.");
    }

    // Full BPTT pass.
    List<IMyExecutable> bpttAllSteps = new List<IMyExecutable>
    {
        bpttLoop,
        IncrementTimeStep,
        RunTemporalBlocksMode,
        updateWeightsIfNotDisabled,
        DecrementTimeStep
    };

    // Run BPTT only when the current time step is the last one of the sequence.
    MyExecutionBlock bpttIfSequenceComplete = new MyIfBlock(
        () => TimeStep == SequenceLength - 1,
        bpttAllSteps.ToArray());

    // Strip the forward plan of the following tasks:
    // - group backprop tasks, except RBM ones (they should be called from the individual layers);
    // - dropout-mask tasks (TODO: include dropout in the new version of the planner);
    // - every task type that the BPTT block above executes;
    // - the time-step and temporal-mode tasks. IncrementTimeStep is re-added below.
    newPlan.RemoveAll(task => task is MyAbstractBackpropTask
                              && !(task is MyRBMLearningTask || task is MyRBMReconstructionTask));
    newPlan.RemoveAll(task => task is MyCreateDropoutMaskTask);
    newPlan.RemoveAll(task => task is IMyDeltaTask);
    newPlan.RemoveAll(task => task is MyGradientCheckTask);
    newPlan.RemoveAll(task => task is IMyUpdateWeightsTask);
    newPlan.RemoveAll(task => task is MyLSTMPartialDerivativesTask);
    newPlan.RemoveAll(task => task is MyIncrementTimeStepTask);
    newPlan.RemoveAll(task => task is MyDecrementTimeStepTask);
    newPlan.RemoveAll(task => task is MyRunTemporalBlocksModeTask);

    // Each step starts by incrementing the time step.
    newPlan.Insert(0, IncrementTimeStep);

    // Move output delta tasks (reversed) right after the last forward task.
    // NOTE(review): if no IMyForwardTask is present, FindLastIndex returns -1 and the tasks are
    // inserted at index 0, i.e. ahead of IncrementTimeStep.
    List<IMyExecutable> outputDeltaTasks = newPlan.Where(task => task is IMyOutputDeltaTask).ToList();
    newPlan.RemoveAll(task => task is IMyOutputDeltaTask);
    outputDeltaTasks.Reverse();
    newPlan.InsertRange(newPlan.FindLastIndex(task => task is IMyForwardTask) + 1, outputDeltaTasks);

    // Move Q-learning tasks (reversed) right after the last forward task, which places them
    // between the forward tasks and the output delta tasks inserted above.
    List<IMyExecutable> qLearningTasks = newPlan
        .Where(task => task is MyQLearningTask || task is MyQLearningBatchTask)
        .ToList();
    newPlan.RemoveAll(task => task is MyQLearningTask || task is MyQLearningBatchTask);
    qLearningTasks.Reverse();
    newPlan.InsertRange(newPlan.FindLastIndex(task => task is IMyForwardTask) + 1, qLearningTasks);

    newPlan.Add(bpttIfSequenceComplete);

    return new MyExecutionBlock(newPlan.ToArray());
}