Esempio n. 1
0
 /// <summary>
 /// Wraps <paramref name="defaultPlan"/> in a <see cref="MyLoopBlock"/> whose loop
 /// condition is <c>x &lt; Iterations</c>.
 /// </summary>
 /// <remarks>
 /// With <c>MyLoopOperation.All</c> the looped body is a new <see cref="MyExecutionBlock"/>
 /// made of <c>m_oneShotTasks</c> followed by <paramref name="defaultPlan"/>; any other loop
 /// type (including <c>MyLoopOperation.Normal</c>) loops <paramref name="defaultPlan"/> alone.
 /// </remarks>
 /// <param name="defaultPlan">The plan to place inside the loop.</param>
 /// <returns>A <see cref="MyLoopBlock"/> wrapping the selected loop body.</returns>
 public virtual MyExecutionBlock CreateCustomExecutionPlan(MyExecutionBlock defaultPlan)
 {
     MyExecutionBlock loopBody = LoopType == MyLoopOperation.All
         ? new MyExecutionBlock(m_oneShotTasks, defaultPlan)
         : defaultPlan;

     return new MyLoopBlock(x => x < Iterations, loopBody);
 }
        /// <summary>
        /// Rearranges the default execution plan for backpropagation through time (BPTT).
        /// </summary>
        /// <remarks>
        /// The returned plan contains, in order:
        /// <list type="number">
        /// <item><description>IncrementTimeStep;</description></item>
        /// <item><description>the tasks of <paramref name="defaultPlan"/> that are not rescheduled below
        /// (one level of plain <see cref="MyExecutionBlock"/> groups is flattened), with the Q-learning
        /// tasks and then the output delta tasks inserted right after the last forward task;</description></item>
        /// <item><description>a block that runs only when <c>TimeStep == SequenceLength - 1</c>. It executes
        /// the BPTT loop (delta, LSTM partial derivative and gradient check tasks plus DecrementTimeStep,
        /// repeated while <c>TimeStep != -1</c>). It then runs IncrementTimeStep, RunTemporalBlocksMode,
        /// the weight updates (skipped when there is no active backprop task or it has learning disabled)
        /// and DecrementTimeStep.</description></item>
        /// </list>
        /// A warning is logged if learning is disabled on the active backprop task when the plan is built.
        /// </remarks>
        /// <param name="defaultPlan">The default plan whose tasks are rearranged.</param>
        /// <returns>A new <see cref="MyExecutionBlock"/> holding the rearranged plan.</returns>
        public virtual MyExecutionBlock CreateCustomExecutionPlan(MyExecutionBlock defaultPlan)
        {
            List<IMyExecutable> newPlan = new List<IMyExecutable>();

            // Copy the default plan. Children of plain MyExecutionBlock groups are added individually
            // (node tasks); any other child (e.g. a group task) is added as-is.
            foreach (IMyExecutable groupTask in defaultPlan.Children)
            {
                if (groupTask.GetType() == typeof(MyExecutionBlock))
                {
                    foreach (IMyExecutable nodeTask in ((MyExecutionBlock)groupTask).Children)
                    {
                        newPlan.Add(nodeTask);
                    }
                }
                else
                {
                    newPlan.Add(groupTask);
                }
            }

            // Remove disabled backprop tasks (RBM tasks excepted) before the BPTT blocks below are
            // assembled from newPlan. The enabled (currently selected) backprop task is kept at this
            // point, since it handles batch learning. It is removed from the top-level plan further down.
            // NOTE(review): this only affects the result if a disabled backprop task also matches one of
            // the task types selected for the BPTT blocks below.
            newPlan.RemoveAll(task => task is MyAbstractBackpropTask
                && !task.Enabled
                && !(task is MyRBMLearningTask || task is MyRBMReconstructionTask));

            // One BPTT step: delta tasks in reverse plan order, then LSTM partial derivatives,
            // gradient checks, and finally a step back in time.
            List<IMyExecutable> bpttSingleStep = new List<IMyExecutable>();
            bpttSingleStep.AddRange(newPlan.Where(task => task is IMyDeltaTask).Reverse());
            bpttSingleStep.AddRange(newPlan.Where(task => task is MyLSTMPartialDerivativesTask));
            bpttSingleStep.AddRange(newPlan.Where(task => task is MyGradientCheckTask));
            bpttSingleStep.Add(DecrementTimeStep);

            // Repeat the single step while TimeStep != -1.
            MyExecutionBlock bpttLoop = new MyLoopBlock(i => TimeStep != -1, bpttSingleStep.ToArray());

            // Weight updates are skipped when learning is globally disabled. The active backprop task is
            // queried inside the lambda rather than captured here, so the condition reflects the state
            // at the time MyIfBlock evaluates it.
            MyExecutionBlock updateWeightsIfNotDisabled = new MyIfBlock(() =>
                {
                    var backpropTask = GetActiveBackpropTask();
                    return backpropTask != null && !backpropTask.DisableLearning;
                },
                newPlan.Where(task => task is IMyUpdateWeightsTask).ToArray()
            );

            var activeBackpropTask = GetActiveBackpropTask();
            if (activeBackpropTask != null && activeBackpropTask.DisableLearning)
            {
                MyLog.WARNING.WriteLine("Learning is globally disabled for the network " + this.Name + " in the " + activeBackpropTask.Name + " backprop task.");
            }

            // Full BPTT pass, executed only once the last step of the sequence has been reached.
            MyExecutionBlock bpttIfSequenceEnd = new MyIfBlock(() => TimeStep == SequenceLength - 1,
                new IMyExecutable[]
                {
                    bpttLoop,
                    IncrementTimeStep,
                    RunTemporalBlocksMode,
                    updateWeightsIfNotDisabled,
                    DecrementTimeStep
                }
            );

            // Remove from the top-level plan every task that is scheduled inside the BPTT blocks above
            // or not handled by this planner. Group backprop tasks should be called from the individual
            // layers, so they are removed too. RBM tasks are kept.
            // TODO - include dropout (MyCreateDropoutMaskTask) in the new version of planner
            // NOTE(review): if IMyOutputDeltaTask extends IMyDeltaTask, output delta tasks are already
            // removed here and the "move output delta tasks" step below finds nothing - confirm.
            newPlan.RemoveAll(task =>
                (task is MyAbstractBackpropTask && !(task is MyRBMLearningTask || task is MyRBMReconstructionTask))
                || task is MyCreateDropoutMaskTask
                || task is IMyDeltaTask
                || task is MyGradientCheckTask
                || task is IMyUpdateWeightsTask
                || task is MyLSTMPartialDerivativesTask
                || task is MyIncrementTimeStepTask
                || task is MyDecrementTimeStepTask
                || task is MyRunTemporalBlocksModeTask);

            // Every run of the plan starts by advancing the time step.
            newPlan.Insert(0, IncrementTimeStep);

            // Move the output delta tasks right after the last forward task, in reverse plan order.
            // With no forward task, FindLastIndex returns -1 and they go to the very front of the plan.
            List<IMyExecutable> outputDeltaTasks = newPlan.Where(task => task is IMyOutputDeltaTask).ToList();
            newPlan.RemoveAll(task => task is IMyOutputDeltaTask);
            newPlan.InsertRange(newPlan.FindLastIndex(task => task is IMyForwardTask) + 1, outputDeltaTasks.Reverse<IMyExecutable>());

            // Move the Q-learning tasks to the same position, which places them between the forward
            // tasks and the output delta tasks inserted above.
            List<IMyExecutable> qLearningTasks = newPlan.Where(task => task is MyQLearningTask || task is MyQLearningBatchTask).ToList();
            newPlan.RemoveAll(task => task is MyQLearningTask || task is MyQLearningBatchTask);
            newPlan.InsertRange(newPlan.FindLastIndex(task => task is IMyForwardTask) + 1, qLearningTasks.Reverse<IMyExecutable>());

            newPlan.Add(bpttIfSequenceEnd);

            return new MyExecutionBlock(newPlan.ToArray());
        }