Example #1
        public override double update(StateTransition <int[], int[]> transition)
        {
            stats.cumulativeReward += transition.reward;

            int[] alloOldState = new int[2] {
                transition.oldState[0], transition.oldState[1]
            };
            int[] alloNewState = new int[2] {
                transition.newState[0], transition.newState[1]
            };
            int[] egoOldState = new int[8];
            Array.Copy(transition.oldState, 2, egoOldState, 0, 8);
            int[] egoNewState = new int[8];
            Array.Copy(transition.newState, 2, egoNewState, 0, 8);

            // load the transition into the history (keep roughly the 500 most recent samples)
            if (saHistory.Count > 500)
            {
                saHistory.Dequeue();
                sPrimeHistory.Dequeue();
            }
            double[] sa = new double[10];
            Array.Copy(egoOldState, sa, 8);
            sa[8] = transition.action[0];
            sa[9] = transition.action[1];
            Console.WriteLine("sa: " + string.Join(",", sa));
            Console.WriteLine("sprime: " + alloNewState[0] + "," + alloNewState[1]);

            //double[] dummy;
            //if (!inSample(sa, out dummy))
            //{
            saHistory.Enqueue(sa);
            sPrimeHistory.Enqueue(new double[3] {
                alloNewState[0] - alloOldState[0], alloNewState[1] - alloOldState[1], transition.reward
            });
            //}

            // run regression
            if (saHistory.Count > 50 && fullPredictionMode)
            {
                // run a single training epoch over the stored history (the returned error is discarded)
                teacher.RunEpoch(saHistory.ToArray(), sPrimeHistory.ToArray());
            }

            // update models with the current transition
            alloModel.update(new StateTransition <int[], int[]>(alloOldState, transition.action, transition.reward, alloNewState));
            egoModel.update(new StateTransition <int[], int[]>(egoOldState, transition.action, transition.reward, egoNewState));

            // transfer info from ego to allo models
            Console.WriteLine("current state: " + alloNewState[0] + "," + alloNewState[1]);
            Console.WriteLine("ego. state: " + string.Join(",", egoNewState));

            foreach (int[] a in availableActions)
            {
                sa = new double[10];
                Array.Copy(egoNewState, sa, 8);
                sa[8] = a[0];
                sa[9] = a[1];
                double[] predicted    = network.Compute(sa);// linearModel.Compute(sa);
                int[]    predictedAlo = { (int)Math.Round(predicted[0]) + alloNewState[0], (int)Math.Round(predicted[1]) + alloNewState[1] };
                double   reward       = predicted[2];

                // hand-coded prediction computed alongside the learned one (its outputs are not used below)
                double handCodedReward; int[] handCodedPredictedAlo;
                handCodedPrediction(egoNewState, a, out handCodedReward, alloNewState, out handCodedPredictedAlo);

                Console.WriteLine("action " + a[0] + "," + a[1] + " -> " + predictedAlo[0] + "," + predictedAlo[1] + " reward: " + reward);

                if (saHistory.Count >= 50)
                {
                    double[] matchingSample;
                    //if (inSample(sa, out matchingSample))
                    //{
                    if (alloModel.value(alloNewState, a) == alloModel.defaultQ)
                    {
                        if (fullPredictionMode)
                        {
                            alloModel.update(new StateTransition <int[], int[]>(alloNewState, a, reward, predictedAlo));
                        }
                        else
                        {
                            alloModel.Qtable[alloNewState][a] = egoModel.value(egoNewState, a);
                        }
                    }
                    //}
                }
            }


            return 0;
        }
        //private bool inSample(double[] sa, out double[] matchingSample)
        //{
        //    matchingSample = null;

        //    if (saHistory.Count == 0)
        //        return false;

        //    foreach (double[] sample in saHistory)
        //    {
        //        bool thisSampleMatches = true;
        //        for (int i=0; i<10; i++)
        //        {
        //            if (sa[i] != sample[i])
        //            {
        //                thisSampleMatches = false;
        //                break;
        //            }
        //        }
        //        if (thisSampleMatches)
        //        {
        //            matchingSample = sample;
        //            return true;
        //        }
        //    }
        //    return false;
        //}
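
The method above relies on class members that are not part of this excerpt. Below is a minimal sketch of how the history and regression members it uses might be declared; the field names are taken from the calls above, but the concrete types are assumptions (in particular, `ActivationNetwork`/`BackPropagationLearning` from AForge.NET are only one plausible pair behind the `network.Compute` and `teacher.RunEpoch` calls).

using System.Collections.Generic;
using AForge.Neuro;             // assumed: provides ActivationNetwork.Compute(double[]) -> double[]
using AForge.Neuro.Learning;    // assumed: provides BackPropagationLearning.RunEpoch(double[][], double[][]) -> double

// Sketch only: names match the members used in update(), types are guesses.
public abstract class EgoAlloAgentSketch
{
    protected Queue<double[]> saHistory = new Queue<double[]>();     // inputs: ego state (8 values) + action (2 values)
    protected Queue<double[]> sPrimeHistory = new Queue<double[]>(); // targets: [dx, dy, reward]
    protected ActivationNetwork network;                             // 10 inputs -> 3 outputs
    protected BackPropagationLearning teacher;                       // trains `network` on the stored history
    protected List<int[]> availableActions;                          // actions iterated in the transfer loop
    protected bool fullPredictionMode;                               // model-based transfer vs. copying Q-values
}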

Example #2
        public override double update(StateTransition <int[], int[]> transition)
        {
            steps++;
            stats.cumulativeReward += transition.reward;

            int[] alloOldState = new int[2] {
                transition.oldState[0], transition.oldState[1]
            };
            int[] alloNewState = new int[2] {
                transition.newState[0], transition.newState[1]
            };
            int[] egoOldState = new int[12];
            Array.Copy(transition.oldState, 2, egoOldState, 0, 12);
            int[] egoNewState = new int[12];
            Array.Copy(transition.newState, 2, egoNewState, 0, 12);


            //double[] sa = new double[10];
            //Array.Copy(egoOldState, sa, 8);
            //sa[8] = transition.action[0];
            //sa[9] = transition.action[1];
            //Console.WriteLine("sa: " + string.Join(",", sa));
            //Console.WriteLine("sprime: " + alloNewState[0] + "," + alloNewState[1]);


            // train the three ego prediction models: [0] learns the x displacement, [1] the y displacement,
            // [2] the reward; the new int[1] { -1 } next state is a dummy placeholder
            egoPredictionModels[0].update(new StateTransition <int[], int[]>(egoOldState, transition.action, transition.newState[0] - transition.oldState[0], new int[1] {
                -1
            }));
            egoPredictionModels[1].update(new StateTransition <int[], int[]>(egoOldState, transition.action, transition.newState[1] - transition.oldState[1], new int[1] {
                -1
            }));
            egoPredictionModels[2].update(new StateTransition <int[], int[]>(egoOldState, transition.action, transition.reward, new int[1] {
                -1
            }));


            // update models with the current transition
            alloLearner.update(new StateTransition <int[], int[]>(alloOldState, transition.action, transition.reward, alloNewState));
            // Array.ConvertAll here is an identity cast (the ego states are already int[]); it only copies the arrays
            egoLearner.update(new StateTransition <int[], int[]>(Array.ConvertAll(egoOldState, x => (int)x), transition.action, transition.reward, Array.ConvertAll(egoNewState, x => (int)x)));

            // transfer info from ego to allo models
            //Console.WriteLine("current state: " + alloNewState[0] + "," + alloNewState[1]);
            //Console.WriteLine("ego. state: " + string.Join(",", egoNewState));


            for (int i = 0; i < availableActions.Count; i++)
            {
                if (!visitedStates[i].ContainsKey(alloNewState))
                {
                    visitedStates[i].Add(alloNewState, 1);
                }
                else
                {
                    visitedStates[i][alloNewState]++;
                }

                if (steps >= 10 && visitedStates[i][alloNewState] <= updateTerminationStepCount)
                {
                    double predictedReward = egoPredictionModels[2].value(egoNewState, availableActions[i]);
                    double d0           = egoPredictionModels[0].value(egoNewState, availableActions[i]);
                    double d1           = egoPredictionModels[1].value(egoNewState, availableActions[i]);
                    int[]  predictedAlo = { (int)Math.Round(d0 + alloNewState[0]), (int)Math.Round(d1 + alloNewState[1]) };


                    //handCodedPrediction(Array.ConvertAll(egoNewState, x => (int)x), availableActions[i], out reward, alloNewState, out predictedAlo, 0.05);

                    Console.WriteLine("action " + availableActions[i][0] + "," + availableActions[i][1] + " -> " + predictedAlo[0] + "," + predictedAlo[1] + " reward: " + predictedReward);

                    //double[] matchingSample;
                    //if (inSample(sa, out matchingSample))
                    //{
                    //if (alloModel.value(alloNewState, availableActions[i]) == alloModel.defaultQ)
                    //{
                    if (fullPredictionMode && egoPredictionModels[0].T.GetStateValueTable(egoNewState, availableActions[i]).Values.Sum() > 1)
                    {
                        alloLearner.update(new StateTransition <int[], int[]>(alloNewState, availableActions[i], predictedReward, predictedAlo));
                    }
                    else if (!fullPredictionMode)
                    {
                        double setQvalue = egoLearner.value(Array.ConvertAll(egoNewState, x => (int)x), availableActions[i]);
                        alloLearner.Qtable[alloNewState][availableActions[i]] = setQvalue;
                    }
                    //}
                    //}
                }
            }


            return 0;
        }
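
As with the first example, this variant references members declared elsewhere in the class. A rough sketch of those declarations is given below; the types `ModelLearner` and `QLearner` are hypothetical stand-ins inferred from usage, not names taken from the original source.

using System.Collections.Generic;

// Sketch only: field names match the members used in the second update(), types are illustrative.
public abstract class EgoAlloLearnerSketch
{
    protected int steps;                                // incremented once per update
    protected int updateTerminationStepCount;           // visit-count cutoff for the speculative updates
    protected bool fullPredictionMode;                  // model-based transfer vs. direct Q-value copy
    protected List<int[]> availableActions;
    // per-action visit counts; int[] keys need a contents-based IEqualityComparer to match as intended
    protected List<Dictionary<int[], int>> visitedStates;
    protected ModelLearner[] egoPredictionModels;       // hypothetical type; [0] predicts dx, [1] dy, [2] reward
    protected QLearner alloLearner;                     // hypothetical type exposing update(), value() and Qtable
    protected QLearner egoLearner;
}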