MCTS_TreeNode uct()
{
    // Select the child maximizing the UCT (Upper Confidence bound for Trees)
    // score, then advance the shared simulation state by that child's action.
    // Precondition: this node is fully expanded (every entry of `children`
    // non-null with nVisits >= 1); otherwise the exploitation term divides
    // by zero / the loop dereferences a null child.
    MCTS_TreeNode selected = null;
    //Choose best action using UCT
    float bestValue = -float.MaxValue;

    foreach (MCTS_TreeNode child in children)
    {
        // Exploitation term: average reward, normalized into the reward
        // bounds observed so far.
        float uctValue = child.qValue / child.nVisits;
        uctValue = Normalize(uctValue, bounds[0], bounds[1]);
        // Exploration term: grows for rarely visited children.
        uctValue = uctValue + sqrt2 * Mathf.Sqrt(Mathf.Log(nVisits + 1) / child.nVisits);
        // Small noise breaks ties between children with equal scores.
        uctValue = AddNoise(uctValue);
        if (uctValue > bestValue)
        {
            selected  = child;
            bestValue = uctValue;
        }
    }
    if (selected == null)
    {
        // BUG FIX: the original logged this warning and then dereferenced
        // `selected` anyway, crashing with a NullReferenceException.
        // Bail out before touching the simulation state.
        Debug.Log("Warning! returning null...");
        return(null);
    }

    //Roll the state:
    currentState.SimulateForward((Agent.Action)(selected.index));
    birdPosList.Add(currentState.birdPos);

    return(selected);
}
// Example #2
// 0
    public override Action GetAction()
    {
        // Build a fresh search tree rooted at the current world state, run
        // MCTS for the configured iteration budget, and return the action
        // judged best at the root.
        mctsRootNode = new MCTS_TreeNode();
        mctsRootNode.mctsSearch(numIterations);

        int bestIdx = mctsRootNode.bestAction();
        return((Action)bestIdx);
    }
 public void mctsSearch(int numIterations)
 {
     // Run the MCTS phases (select, expand, simulate, backpropagate) for a
     // fixed number of iterations, recording the bird trajectory of each
     // iteration for debug line rendering.
     MyLineRenderer.Init();
     for (int iter = 0; iter < numIterations; iter++)
     {
         // Snapshot the world so every iteration simulates from the same root state.
         currentState.Save();
         birdPosList = new List <Vector3> {
             currentState.birdPos
         };

         MCTS_TreeNode leaf   = treePolicy();   // selection + expansion
         float         reward = leaf.rollOut(); // simulation
         backUp(leaf, reward);                  // backpropagation

         MyLineRenderer.lineStrips.Add(birdPosList);
     }
 }
    MCTS_TreeNode treePolicy()
    {
        // Descend from this node until we either hit the rollout depth limit
        // (return the node reached) or find a node with an untried action
        // (expand it and return the new child).
        MCTS_TreeNode node = this;

        while (true)
        {
            if (finishRollout(node.depth))
            {
                return(node);
            }
            if (node.notFullyExpanded())
            {
                return(node.expand());
            }
            node = node.uct();
        }
    }
 void backUp(MCTS_TreeNode node, float reward)
 {
     // Propagate the rollout reward from the leaf up to the root, updating
     // visit counts, accumulated value, and the observed reward bounds.
     for (MCTS_TreeNode cur = node; cur != null; cur = cur.parent)
     {
         cur.nVisits += 1;
         cur.qValue  += reward;
         if (reward < cur.bounds[0])
         {
             cur.bounds[0] = reward;
         }
         if (reward > cur.bounds[1])
         {
             cur.bounds[1] = reward;
         }
         // Decay before the next (parent) update: the effect of the reward
         // decreases exponentially as we get closer to the root.
         reward *= Agent_MCTS._.lambda;
     }
 }
 public MCTS_TreeNode(MCTS_TreeNode parent = null, int index = -1)
 {
     // One node of the search tree. Root nodes use the defaults
     // (no parent, index -1, depth 0); children record the action index
     // that led to them and sit one level below their parent.
     this.parent = parent;
     this.index  = index;

     children = new MCTS_TreeNode[numActions];
     qValue   = 0;
     nVisits  = 0;
     depth    = (parent != null) ? parent.depth + 1 : 0;

     // Lazily create the simulation state and trajectory buffer if they
     // do not exist yet (presumably shared across nodes — set up once).
     if (currentState == null)
     {
         currentState = new WorldState();
     }
     if (birdPosList == null)
     {
         birdPosList = new List <Vector3>();
     }
 }
    MCTS_TreeNode expand()
    {
        // Pick one of the not-yet-expanded actions uniformly at random
        // (each action gets a random draw; the highest draw among the
        // unexpanded ones wins), advance the simulation by that action,
        // and attach the resulting child node.
        // Precondition: at least one entry of `children` is null.
        int   chosen   = 0;
        float bestDraw = -1;

        for (int a = 0; a < numActions; a++)
        {
            // Draw for every slot (not just empty ones) so the random
            // sequence matches one call per action.
            float draw = Random.Range(0.0f, 1.0f);
            if (children[a] == null && draw > bestDraw)
            {
                bestDraw = draw;
                chosen   = a;
            }
        }

        currentState.SimulateForward((Agent.Action)(chosen));
        birdPosList.Add(currentState.birdPos);

        MCTS_TreeNode child = new MCTS_TreeNode(this, chosen);
        children[chosen] = child;

        return(child);
    }