Beispiel #1
0
        /// <summary>
        /// Re-keys a minibatch from stream information to the model variables the trainer expects.
        /// </summary>
        /// <param name="args">Minibatch data keyed by stream information.</param>
        /// <param name="cloneData">
        /// When true, every entry is rebuilt around a deep clone of its value; when false the
        /// original <see cref="MinibatchData"/> instance is passed through unchanged.
        /// </param>
        /// <returns>The minibatch data keyed by the variable whose name equals the stream name.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="args"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No input or output variable has the name of a stream.</exception>
        private UnorderedMapVariableMinibatchData prepareForTrainerEx(UnorderedMapStreamInformationMinibatchData args, bool cloneData)
        {
            if (args == null)
            {
                throw new ArgumentNullException(nameof(args));
            }

            // The candidate variables are the same for every stream, so evaluate the union once
            // instead of rebuilding it on each iteration.
            var candidates = InputVariables.Union(OutputVariables).ToList();
            var converted  = new UnorderedMapVariableMinibatchData();

            // Enumerate the map directly: ElementAt(i) walks the map from the start on every call,
            // which made the original index loop quadratic in the number of streams.
            foreach (var entry in args)
            {
                var streamName = entry.Key.m_name;
                var variable   = candidates.FirstOrDefault(x => x.Name.Equals(streamName));
                if (variable == null)
                {
                    throw new InvalidOperationException("Variable doesn't exist!");
                }

                MinibatchData mbData;
                if (cloneData)
                {
                    // Rebuild the minibatch around a deep copy of the value, keeping all counters.
                    var source = entry.Value;
                    mbData = new MinibatchData(source.data.DeepClone(), source.numberOfSequences, source.numberOfSamples, source.sweepEnd);
                }
                else
                {
                    mbData = entry.Value;
                }

                converted.Add(variable, mbData);
            }

            return converted;
        }
Beispiel #2
0
        /// <summary>
        /// Called during training to fetch the next chunk of data of the requested size.
        /// </summary>
        /// <param name="minibatchSizeInSamples">Requested minibatch size, forwarded to the underlying source.</param>
        /// <param name="device">Device forwarded to the underlying source.</param>
        /// <returns>The minibatch data keyed by stream information.</returns>
        /// <exception cref="InvalidOperationException">The custom source produced no data.</exception>
        /// <exception cref="NotSupportedException">The source type is neither Default nor Custom.</exception>
        public UnorderedMapStreamInformationMinibatchData GetNextMinibatch(uint minibatchSizeInSamples, DeviceDescriptor device)
        {
            if (Type == MinibatchType.Default)
            {
                return defaultmb.GetNextMinibatch(minibatchSizeInSamples, device);
            }

            if (Type != MinibatchType.Custom)
            {
                throw new NotSupportedException("Unsupported Mini-batch-source type!");
            }

            var batch = nextBatch(custommb, StreamConfigurations, (int)minibatchSizeInSamples, 1, device);
            if (batch == null)
            {
                // nextBatch reports "nothing read" with null; fail with a clear message instead of
                // a NullReferenceException on the first dereference.
                throw new InvalidOperationException("The custom minibatch source returned no data.");
            }

            var mb = new UnorderedMapStreamInformationMinibatchData();
            foreach (var entry in batch)
            {
                // Fix for CNTK 2.6: the returned Value is not valid on its own, so it is cloned
                // before being wrapped in a new MinibatchData.
                var mbData = new MinibatchData(entry.Value.data.DeepClone(true), entry.Value.numberOfSequences, entry.Value.numberOfSamples, entry.Value.sweepEnd);

                // Describe the stream from the entry's own configuration key. Pairing the i-th
                // dictionary entry with StreamConfigurations[i] relied on dictionary enumeration
                // order matching the array order, which is not guaranteed.
                var si = new StreamInformation();
                si.m_definesMbSize = entry.Key.m_definesMbSize;
                si.m_storageFormat = entry.Value.data.StorageFormat;
                si.m_name          = entry.Key.m_streamName;

                mb.Add(si, mbData);
            }

            return mb;
        }
        /// <summary>
        /// Supplies the training loop with the next minibatch of the requested size.
        /// </summary>
        /// <param name="minibatchSizeInSamples">Requested minibatch size.</param>
        /// <param name="device">Device on which the batch values are created.</param>
        /// <returns>The minibatch data keyed by stream information.</returns>
        public UnorderedMapStreamInformationMinibatchData GetNextMinibatch(uint minibatchSizeInSamples, DeviceDescriptor device)
        {
            // Default and image sources are answered directly by defaultmb.
            if (Type == MinibatchType.Default || Type == MinibatchType.Image)
            {
                return defaultmb.GetNextMinibatch(minibatchSizeInSamples, device);
            }

            if (Type != MinibatchType.Custom)
            {
                throw new Exception("Unsupported Mini-batch-source type!");
            }

            // Custom source: read the raw sequences, then wrap each stream in a dense MinibatchData.
            var rawStreams = nextBatch(custommb, StreamConfigurations, (int)minibatchSizeInSamples);
            var minibatch  = new UnorderedMapStreamInformationMinibatchData();
            var sweepEnd   = custommb.EndOfStream;

            foreach (var stream in rawStreams)
            {
                var batchValue = Value.CreateBatchOfSequences<float>(new NDShape(1, stream.Key.m_dim), stream.Value, device);

                // Sequence count is the number of lists; sample count is the total number of elements.
                var data = new MinibatchData(batchValue, (uint)stream.Value.Count(), (uint)stream.Value.Sum(x => x.Count), sweepEnd);

                var info = new StreamInformation
                {
                    m_definesMbSize = stream.Key.m_definesMbSize,
                    m_storageFormat = StorageFormat.Dense,
                    m_name          = stream.Key.m_streamName
                };

                minibatch.Add(info, data);
            }

            return minibatch;
        }
        /// <summary>
        /// Creates a <see cref="MinibatchData"/> whose <c>data</c> member is set to the result of <c>GetValue()</c>.
        /// </summary>
        /// <returns>The newly created minibatch data.</returns>
        static MinibatchData GetMinibatchData()
        {
            return new MinibatchData { data = GetValue() };
        }
Beispiel #5
0
        /// <summary>
        /// Records the outcome of the latest move and, when that move changed the score and at least
        /// <c>batchSize</c> states are buffered, trains the model on the newest <c>batchSize</c> states.
        /// </summary>
        /// <param name="world">World whose score determines the reward of the current state.</param>
        /// <param name="action">The performed action; not read by this method.</param>
        private void OnWorldMovePerformed(World world, PlayAction action)
        {
            // The reward is the score gained since the previous move.
            state.Reward  = world.Score - previousScore;
            previousScore = world.Score;

            // Newest state first, so index 0 is always the most recent move.
            states.Insert(0, state);

            // Train only when the move produced a reward and enough history is available.
            if (state.Reward == 0 || states.Count < batchSize)
            {
                return;
            }

            states = states.Take(batchSize).ToList();
            Trace.WriteLine("Train batch");

            var values  = new float[states.First().Value.Length * states.Count];
            var actions = new float[World.PLAY_ACTION_COUNT * states.Count];

            // Walk from the newest state to the oldest, carrying the decayed reward along.
            float discounted = 0;
            for (int row = 0; row < states.Count; row++)
            {
                var sample = states[row];
                sample.Value.CopyTo(values, row * sample.Value.Length);

                discounted = decay * discounted + sample.Reward;

                Trace.WriteLine($"Train batch - Action: {sample.Action} Reward: {discounted}");

                // Expected output: softmax of a one-hot vector weighted by the decayed reward.
                var oneHot = CNTKHelper.CNTKHelper.OneHot((int)sample.Action, World.PLAY_ACTION_COUNT, discounted);
                var target = CNTKHelper.CNTKHelper.SoftMax(oneHot);
                target.CopyTo(actions, row * World.PLAY_ACTION_COUNT);
            }

            // Wrap the flat arrays into minibatches.
            uint sampleCount = (uint)states.Count;

            var inputs         = Value.CreateBatch<float>(model.Arguments[0].Shape, values, device);
            var inputMinibatch = new MinibatchData(inputs, sampleCount);

            var outputs         = Value.CreateBatch<float>(model.Output.Shape, actions, device);
            var outputMinibatch = new MinibatchData(outputs, sampleCount);

            var arguments = new Dictionary<Variable, MinibatchData>();
            arguments.Add(inputVariable, inputMinibatch);
            arguments.Add(actionVariable, outputMinibatch);

            // Five passes over the same batch; the counter runs 5..1 and is reported as progress.
            for (int epoc = 5; epoc > 0; epoc--)
            {
                trainer.TrainMinibatch(arguments, device);
                CNTKHelper.CNTKHelper.PrintTrainingProgress(trainer, epoc);
            }

            // Start collecting a fresh history.
            states.Clear();
        }
Beispiel #6
0
        /// <summary>
        /// Helper that reads the next batch of lines from the custom minibatch source and converts
        /// every stream found in them into a <see cref="MinibatchData"/>.
        /// </summary>
        /// <param name="stream">Reader over the data file; must be a <see cref="StreamReader"/>.</param>
        /// <param name="m_streamConfig">Stream configurations used to parse each line.</param>
        /// <param name="batchSize">Number of lines to read; zero or less reads all remaining lines.</param>
        /// <param name="iteration">Not used by this method.</param>
        /// <param name="device">Device on which the batch values are created.</param>
        /// <returns>One minibatch per stream, or null when no line could be read.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="stream"/> is null.</exception>
        /// <exception cref="InvalidCastException"><paramref name="stream"/> is not a <see cref="StreamReader"/>.</exception>
        private static Dictionary <StreamConfiguration, MinibatchData> nextBatch(TextReader stream,
                                                                                 StreamConfiguration[] m_streamConfig, int batchSize, int iteration, DeviceDescriptor device)
        {
            if (stream == null)
            {
                throw new ArgumentNullException(nameof(stream));
            }

            // Cast once; a reader of any other type fails here with InvalidCastException as before.
            var reader = (StreamReader)stream;

            // Wrap around to the beginning once the file has been consumed.
            if (reader.EndOfStream)
            {
                reader.BaseStream.Position = 0;
                // The reader keeps its own buffer and decoder state; both must be reset whenever
                // the position of the underlying stream is changed.
                reader.DiscardBufferedData();
            }

            // A batch size below 1 means "retrieve the whole data set".
            var lines = batchSize <= 0 ? ReadLineFromFile(reader) : ReadLineFromFile(reader).Take(batchSize);

            // Collect, per stream, one list of floats for every line read.
            var values = new Dictionary<StreamConfiguration, List<List<float>>>();
            foreach (var batchLine in lines)
            {
                var streams = batchLine.Split(MLFactory.m_cntkSpearator, StringSplitOptions.RemoveEmptyEntries);
                var dics    = processTextLine<float>(streams, m_streamConfig);

                foreach (var d in dics)
                {
                    List<List<float>> sequences;
                    if (!values.TryGetValue(d.Key, out sequences))
                    {
                        sequences = new List<List<float>>();
                        values.Add(d.Key, sequences);
                    }

                    sequences.Add(d.Value);
                }
            }

            // Record whether this batch reached the end of the file.
            var endOfStream = reader.EndOfStream;

            // Nothing was read; this is not expected to happen, callers receive null.
            if (values.Count == 0)
            {
                return(null);
            }

            // Create one minibatch per stream: sequence count is the number of lines, sample count
            // is the total number of values.
            var retVal = new Dictionary<StreamConfiguration, MinibatchData>();
            foreach (var d in values)
            {
                var v   = Value.CreateBatchOfSequences<float>(new NDShape(1, d.Key.m_dim), d.Value, device);
                var mbd = new MinibatchData(v, (uint)d.Value.Count, (uint)d.Value.Sum(x => x.Count), endOfStream);
                retVal.Add(d.Key, mbd);
            }

            return(retVal);
        }
Beispiel #7
0
        /// <summary>
        /// Generates one training sample for every ordered pair of distinct indices in
        /// [0, SIZE_IN) and trains the model on the resulting batch for 25 passes.
        /// </summary>
        public void Train()
        {
            uint nbTries = SIZE_IN * (SIZE_IN - 1);
            var  values  = new float[nbTries * SIZE_IN * SIZE_IN_NB_LAYER];
            var  actions = new float[nbTries * SIZE_OUT];

            // Generate training data: every (p, c) combination except p == c.
            int sample = 0;

            for (int p = 0; p < SIZE_IN; p++)
            {
                for (int c = 0; c < SIZE_IN; c++)
                {
                    if (p == c)
                    {
                        continue;
                    }

                    WorldToValue(c, p).CopyTo(values, sample * SIZE_IN * SIZE_IN_NB_LAYER);

                    // Slot 0 marks "Left" (p greater than c), slot 1 marks "Right".
                    int slot = p > c ? 0 : 1;
                    actions[sample * SIZE_OUT + slot] = 1;

                    sample++;
                }
            }

            // Wrap the flat arrays into minibatches.
            var inputs          = Value.CreateBatch<float>(model.Arguments[0].Shape, values, device);
            var outputs         = Value.CreateBatch<float>(model.Output.Shape, actions, device);
            var inputMinibatch  = new MinibatchData(inputs, nbTries);
            var outputMinibatch = new MinibatchData(outputs, nbTries);

            var arguments = new Dictionary<Variable, MinibatchData>();
            arguments.Add(inputVariable, inputMinibatch);
            arguments.Add(actionVariable, outputMinibatch);

            // Apply learning: the counter runs 25..1 and is reported as progress.
            for (int epoc = 25; epoc > 0; epoc--)
            {
                trainer.TrainMinibatch(arguments, device);

                CNTKHelper.CNTKHelper.PrintTrainingProgress(trainer, epoc);
            }
        }
Beispiel #8
0
        /// <summary>
        /// Evaluates <c>Expression</c> (feeding it the previous result when an input variable is
        /// configured) and returns the result as a single-entry minibatch stored under <c>Name</c>.
        /// </summary>
        /// <param name="device">Device used for the evaluation; the default device when null.</param>
        /// <returns>A minibatch containing the freshly computed value.</returns>
        public override Minibatch GetNextMinibatch(DeviceDescriptor device = null)
        {
            if (device == null)
            {
                device = DeviceDescriptor.UseDefaultDevice();
            }

            // Without an input variable the expression is evaluated with no inputs; otherwise the
            // value produced by the previous call is fed back in.
            var inputs = new Dictionary<Variable, Value>();
            if (InputVariable != null)
            {
                inputs.Add(InputVariable, PrevValue);
            }

            Value value = FunctionInvoke.Invoke(Expression, inputs, device, false)[0];

            // The sample count is the size of the last axis; a rank-0 value counts as one sample.
            int rank        = value.Shape.Rank;
            int sampleCount = rank == 0 ? 1 : value.Shape[rank - 1];

            // NOTE(review): the counter is incremented first and then tested as Iterations + 1;
            // confirm this matches the intended position of the sweep end.
            ++Iterations;
            bool sweepEnd = (Iterations + 1) % IterationsPerEpoch == 0;

            var minibatch = new Minibatch();
            minibatch.Add(Name, new MinibatchData(value, (uint)sampleCount, sweepEnd));

            // Remember the result so the next call can feed it back into the expression.
            PrevValue = value;

            return minibatch;
        }
Beispiel #9
0
        /// <summary>
        /// Repeats 100 times: wraps a deep-cloned 3x2 array in a <c>Value</c> and a
        /// <c>MinibatchData</c>, forces garbage collection, and then asserts that the data read
        /// back through <c>m.data</c> still has the original count, shape and contents.
        /// </summary>
        public void TestDataIngegrityAfterGarbageCollection()
        {
            for (var i = 0; i < 100; ++i)
            {
                var data = new float[] { 1, 2, 3, 4, 5, 6 };

                // Declared outside the unsafe block so they remain visible after it.
                // The use counts and addresses are captured but never asserted on.
                MinibatchData m;
                IntPtr sharedPtrAddress;
                IntPtr valueAddress;
                int c1;
                unsafe
                {
                    // The managed array is pinned while the native wrappers are created.
                    fixed (float* d = data)
                    {
                        var a = new NDArrayView(new int[] { 3, 2 }, data, DeviceDescriptor.CPUDevice);
                        // The Value is built on a deep clone of the view, not on the view itself.
                        var a2 = a.DeepClone();
                        var value = new Value(a2);
                        c1 = SwigMethods.GetSharedPtrUseCount(value);
                        sharedPtrAddress = SwigMethods.GetSwigPointerAddress(value);
                        valueAddress = SwigMethods.GetSharedPtrElementPointer(value);

                        m = new MinibatchData(value);
                    }
                }

                // From here on only m is referenced; the locals a, a2 and value are out of scope.
                var c2 = SwigMethods.GetSharedPtrUseCount(m.data);
                var sharedPtrAddress2 = SwigMethods.GetSwigPointerAddress(m.data);
                var valueAddress2 = SwigMethods.GetSharedPtrElementPointer(m.data);
                var c3 = SwigMethods.GetSharedPtrUseCount(m.data);

                // Request garbage collection three times before reading the data back.
                GC.Collect();
                GC.Collect();
                GC.Collect();

                var c4 = SwigMethods.GetSharedPtrUseCount(m.data);
                var sharedPtrAddress3 = SwigMethods.GetSwigPointerAddress(m.data);
                var valueAddress3 = SwigMethods.GetSharedPtrElementPointer(m.data);

                // The actual checks: element count, shape and contents must match the input.
                var ds = DataSourceFactory.FromValue(m.data);
                Assert.AreEqual(6, ds.Data.Count);
                CollectionAssert.AreEqual(new int[] { 3, 2 }, ds.Shape.Dimensions);
                CollectionAssert.AreEqual(new float[] { 1, 2, 3, 4, 5, 6 }, ds.TypedData);
            }
        }
Beispiel #10
0
        /// <summary>
        /// Re-keys minibatch data from stream information to the variables of the same name.
        /// </summary>
        /// <param name="args">Minibatch data keyed by stream information.</param>
        /// <param name="vars">Variables that are matched against the stream names.</param>
        /// <param name="type">
        /// Type of the minibatch source; for <c>MinibatchType.Custom</c> every entry is rebuilt
        /// around a deep clone of its value, otherwise the entry is passed through unchanged.
        /// </param>
        /// <returns>The minibatch data keyed by variable.</returns>
        /// <exception cref="InvalidOperationException">A stream has no variable with the same name.</exception>
        public static UnorderedMapVariableMinibatchData ToMinibatchData(UnorderedMapStreamInformationMinibatchData args, List <Variable> vars, MinibatchType type)
        {
            var arguments = new UnorderedMapVariableMinibatchData();

            foreach (var mbd in args)
            {
                var streamName = mbd.Key.m_name;
                var v          = vars.FirstOrDefault(x => x.Name == streamName);
                if (v == null)
                {
                    throw new InvalidOperationException("Stream is invalid!");
                }

                if (type == MinibatchType.Custom)
                {
                    // Clone with the four-argument constructor so that the number of sequences is
                    // carried over explicitly; the three-argument form does not receive
                    // numberOfSequences at all, so the original sequence count was not preserved.
                    var source = mbd.Value;
                    var clone  = new MinibatchData(source.data.DeepClone(), source.numberOfSequences, source.numberOfSamples, source.sweepEnd);
                    arguments.Add(v, clone);
                }
                else
                {
                    arguments.Add(v, mbd.Value);
                }
            }

            return(arguments);
        }
Beispiel #11
0
 /// <summary>
 /// Returns the native handle held by <paramref name="obj"/>, or a handle with a zero
 /// pointer when <paramref name="obj"/> is null.
 /// </summary>
 internal static global::System.Runtime.InteropServices.HandleRef getCPtr(MinibatchData obj)
 {
     if (obj == null)
     {
         return new global::System.Runtime.InteropServices.HandleRef(null, global::System.IntPtr.Zero);
     }

     return obj.swigCPtr;
 }
Beispiel #12
0
        /// <summary>
        /// Evaluates the model, trains it for five passes against <c>ExpectedOutput(TEST1_SIZE)</c>,
        /// evaluates it again and traces, per sample, the arg-max before and after training together
        /// with the number of samples whose post-training arg-max differs from the sample index.
        /// </summary>
        /// <param name="model">Model to train; its first argument receives <paramref name="inputValue"/>.</param>
        /// <param name="inputValue">Input batch used for both evaluation and training.</param>
        private void TrainAndEvaluateTest(Function model, Value inputValue)
        {
            // ---- Evaluate the model before training ----
            var inputDataMap = new Dictionary<Variable, Value>();
            inputDataMap.Add(model.Arguments[0], inputValue);

            // A null entry is filled in with the computed output by Evaluate.
            var outputDataMap = new Dictionary<Variable, Value>();
            outputDataMap.Add(model.Output, null);

            model.Evaluate(inputDataMap, outputDataMap, DeviceDescriptor.CPUDevice);

            var preTrainingOutput = outputDataMap[model.Output].GetDenseData<float>(model.Output);
            for (int sample = 0; sample < TEST1_SIZE; sample++)
            {
                Trace.WriteLine($"Argmax({sample}): {CNTKHelper.ArgMax(preTrainingOutput[sample].ToArray())}");
            }

            // ---- Train the model ----
            var labels       = CNTKLib.InputVariable(new int[] { TEST1_SIZE }, DataType.Float, "Error Input");
            var trainingLoss = CNTKLib.CrossEntropyWithSoftmax(new Variable(model), labels, "lossFunction");
            var prediction   = CNTKLib.ClassificationError(new Variable(model), labels, "classificationError");

            // Learning rate schedule handed to the SGD learner.
            var learningRatePerSample = new CNTK.TrainingParameterScheduleDouble(0.003125, 1);

            IList<Learner> parameterLearners = new List<Learner>();
            parameterLearners.Add(Learner.SGDLearner(model.Parameters(), learningRatePerSample));

            var trainer = Trainer.CreateTrainer(model, trainingLoss, prediction, parameterLearners);

            // The labels the model is trained towards.
            var expectedOutputValue = Value.CreateBatch<float>(new int[] { TEST1_SIZE }, ExpectedOutput(TEST1_SIZE), DeviceDescriptor.CPUDevice);

            var inputMiniBatch  = new MinibatchData(inputValue, TEST1_SIZE);
            var outputMiniBatch = new MinibatchData(expectedOutputValue, TEST1_SIZE);

            var arguments = new Dictionary<Variable, MinibatchData>();
            arguments.Add(model.Arguments[0], inputMiniBatch);
            arguments.Add(labels, outputMiniBatch);

            // Five passes over the same minibatch; training uses the device member, while both
            // evaluations use the CPU device.
            for (int pass = 0; pass < 5; pass++)
            {
                trainer.TrainMinibatch(arguments, device);
            }

            // ---- Evaluate the model after training ----
            outputDataMap = new Dictionary<Variable, Value>();
            outputDataMap.Add(model.Output, null);

            model.Evaluate(inputDataMap, outputDataMap, DeviceDescriptor.CPUDevice);

            var postTrainingOutput = outputDataMap[model.Output].GetDenseData<float>(model.Output);

            // A sample counts as failed when its post-training arg-max is not its own index.
            int nbFail = 0;
            for (int sample = 0; sample < TEST1_SIZE; sample++)
            {
                int before = CNTKHelper.ArgMax(preTrainingOutput[sample].ToArray());
                int after  = CNTKHelper.ArgMax(postTrainingOutput[sample].ToArray());

                nbFail += sample != after ? 1 : 0;

                Trace.WriteLine($"Argmax({sample}): {before} ==>  {after}");
            }

            Trace.WriteLine($"Failure rate = ({nbFail}/{TEST1_SIZE})");
        }
Beispiel #13
0
        /// <summary>
        /// Builds a supervised training set from every combination of coin and player position
        /// inside the world border, trains the model on it for 50 passes and, after each pass,
        /// traces how many samples the model predicts correctly. Sets <c>trained</c> when done.
        /// </summary>
        private void Train()
        {
            uint SIZE_IN_NB_LAYER = 2;
            uint SIZE_IN          = (World.SIZE - 2) * (World.SIZE - 2);
            // Every ordered pair of distinct inner cells: SIZE_IN * SIZE_IN minus the SIZE_IN equal pairs.
            uint nbTries          = SIZE_IN * (SIZE_IN - 1);

            var values  = new float[nbTries * SIZE_IN * SIZE_IN_NB_LAYER];
            var actions = new float[nbTries * World.PLAY_ACTION_COUNT];

            int index = 0;

            for (int i = 1; i < World.SIZE - 1; i++)     // i => Y
            {
                for (int j = 1; j < World.SIZE - 1; j++) // j => X
                {
                    var coinPosition = new Position(j, i);

                    for (int k = 1; k < World.SIZE - 1; k++)
                    {
                        for (int l = 1; l < World.SIZE - 1; l++)
                        {
                            // The player cannot stand on the coin.
                            if (k == i && l == j)
                            {
                                continue;
                            }

                            var playerPosition = new Position(l, k);

                            var worldValue = WorldToValue(coinPosition, playerPosition);
                            worldValue.CopyTo(values, index * SIZE_IN * SIZE_IN_NB_LAYER);

                            var dx = coinPosition.X - playerPosition.X;
                            var dy = coinPosition.Y - playerPosition.Y;

                            // Move along the axis with the larger distance. When both distances
                            // are equal no action is marked and the row stays all zeros.
                            if (Math.Abs(dx) > Math.Abs(dy))
                            {
                                var horizontal = dx > 0 ? PlayAction.Right : PlayAction.Left;
                                actions[index * World.PLAY_ACTION_COUNT + (int)horizontal] = 1;
                            }
                            else if (Math.Abs(dx) < Math.Abs(dy))
                            {
                                var vertical = dy > 0 ? PlayAction.Down : PlayAction.Up;
                                actions[index * World.PLAY_ACTION_COUNT + (int)vertical] = 1;
                            }

                            index++;
                        }
                    }
                }
            }
            Trace.WriteLine($"NbTries={nbTries} Index {index}");

            // Create Minibatches
            var inputs         = Value.CreateBatch <float>(model.Arguments[0].Shape, values, device);
            var inputMinibatch = new MinibatchData(inputs, nbTries);

            var outputs         = Value.CreateBatch <float>(model.Output.Shape, actions, device);
            var outputMinibatch = new MinibatchData(outputs, nbTries);

            // Apply learning
            var arguments = new Dictionary <Variable, MinibatchData>
            {
                { inputVariable, inputMinibatch },
                { actionVariable, outputMinibatch }
            };

            // Scratch buffer for the expected action row of the sample under test.
            var expectedValues = new float[World.PLAY_ACTION_COUNT];
            int epoc           = 0;

            while (epoc < 50)
            {
                trainer.TrainMinibatch(arguments, device);

                CNTKHelper.CNTKHelper.PrintTrainingProgress(trainer, epoc);

                epoc++;

                // Test: evaluate the model on the training inputs and compare with the labels.
                inputDataMap[inputVariable] = inputs;
                outputDataMap[model.Output] = null;

                model.Evaluate(inputDataMap, outputDataMap, DeviceDescriptor.CPUDevice);

                var testOutputs = outputDataMap[model].GetDenseData <float>(model);
                int testId      = 0;
                int success     = 0;
                foreach (var actualValues in testOutputs)
                {
                    // Compare each prediction with the label row of the SAME sample.
                    Array.Copy(actions, testId * World.PLAY_ACTION_COUNT, expectedValues, 0, World.PLAY_ACTION_COUNT);

                    var expectedAction = (PlayAction)CNTKHelper.CNTKHelper.ArgMax(expectedValues);
                    var actualAction   = (PlayAction)CNTKHelper.CNTKHelper.ArgMax(actualValues);

                    if (actualAction == expectedAction)
                    {
                        success++;
                    }

                    // Advance to the next sample; without this every prediction was checked
                    // against the first sample's expected action.
                    testId++;
                }
                Trace.WriteLine($"Success {success}/{testOutputs.Count}");
            }

            trained = true;
        }
Beispiel #14
0
        /// <summary>
        /// Records the outcome of the latest move and, when that move changed the score and at least
        /// <c>batchSize</c> states are buffered, trains the model for ten passes on the newest
        /// <c>batchSize</c> states and then clears the buffer.
        /// </summary>
        /// <param name="world">World whose score determines the reward of the current state.</param>
        /// <param name="action">The performed action; not read by this method.</param>
        public void OnWorldMovePerformed(World world, PlayAction action)
        {
            // The reward is the score gained since the previous move.
            state.Reward  = world.Score - previousScore;
            previousScore = world.Score;

            // Newest state first, so index 0 is always the most recent move.
            states.Insert(0, state);

            // Train only when the move produced a reward and enough history is available.
            if (state.Reward == 0 || states.Count < batchSize)
            {
                return;
            }

            states = states.Take(batchSize).ToList();

            var values  = new float[states.First().Value.Length * batchSize];
            var actions = new float[World.PLAY_ACTION_COUNT * batchSize];

            // Walk from the newest state to the oldest, carrying the decayed reward along.
            float discounted = 0;
            for (int row = 0; row < states.Count; row++)
            {
                var sample = states[row];
                sample.Value.CopyTo(values, row * sample.Value.Length);

                discounted = decay * discounted + sample.Reward;

                // Expected output: softmax of a one-hot vector weighted by the decayed reward.
                var oneHot = CNTKHelper.CNTKHelper.OneHot((int)sample.Action, World.PLAY_ACTION_COUNT, discounted);
                var target = CNTKHelper.CNTKHelper.SoftMax(oneHot);
                target.CopyTo(actions, row * World.PLAY_ACTION_COUNT);
            }

            // Wrap the flat arrays into minibatches.
            uint sampleCount = (uint)states.Count;

            var inputs         = Value.CreateBatch<float>(model.Arguments[0].Shape, values, device);
            var inputMinibatch = new MinibatchData(inputs, sampleCount);

            var outputs         = Value.CreateBatch<float>(model.Output.Shape, actions, device);
            var outputMinibatch = new MinibatchData(outputs, sampleCount);

            var arguments = new Dictionary<Variable, MinibatchData>();
            arguments.Add(inputVariable, inputMinibatch);
            arguments.Add(actionVariable, outputMinibatch);

            // Ten passes over the same batch. The counter is declared outside the loop because
            // its final value (zero) is what gets reported as progress afterwards.
            int epoc = 10;
            for (; epoc > 0; epoc--)
            {
                trainer.TrainMinibatch(arguments, device);
            }

            CNTKHelper.CNTKHelper.PrintTrainingProgress(trainer, epoc);

            // Start collecting a fresh history.
            states.Clear();
        }