コード例 #1
0
ファイル: Boxer.cs プロジェクト: jacattelona/PunchOut
    /// <summary>
    /// Apply the action vector produced by the brain for this step.
    /// </summary>
    /// <remarks>
    /// The vector and its loss are always recorded. The action itself is only attempted while
    /// fighting, and is skipped when its confidence is below <c>minConfidence</c>; a confidence of
    /// approximately zero is not filtered out. Afterwards, a dodge in progress is interrupted when
    /// <c>endDodgeEarly</c> is set and the opponent was punching on the previous step but is not now.
    /// </remarks>
    /// <param name="vectorAction">The action vector to take.</param>
    /// <param name="textAction">The name of the action (forwarded to the base implementation).</param>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        base.AgentAction(vectorAction, textAction);

        // Recorded even when the action is not executed below (read by IsGoodMove and CollectObservations).
        lastActions = vectorAction;
        lastLoss = MLActionFactory.GetLoss(vectorAction);

        if (!isFighting)
        {
            return;
        }

        var chosenAction = MLActionFactory.GetAction(vectorAction);
        var confidence = MLActionFactory.GetProbabilityFromVector(chosenAction, vectorAction);

        // NOTE(review): a confidence of ~0 bypasses the threshold; presumably this lets through
        // vectors that carry no probability information — confirm.
        var isUnsure = !Mathf.Approximately(confidence, 0) && confidence < minConfidence;
        if (isUnsure)
        {
            return;
        }

        TryToTakeAction(chosenAction);

        // Stop dodging as soon as the punch seen on the previous step has finished.
        if (endDodgeEarly
            && MLActionFactory.IsDodge(currentAction)
            && MLActionFactory.IsPunch(lastEnemyState)
            && !MLActionFactory.IsPunch(opponent.GetCurrentAction()))
        {
            dodgeAction.Interrupt();
        }

        lastEnemyState = opponent.GetCurrentAction();
    }
コード例 #2
0
ファイル: Boxer.cs プロジェクト: jacattelona/PunchOut
    /// <summary>
    /// Start the given punch or dodge, provided no move is in progress, <c>isAnim</c> is not set,
    /// and the corresponding move is off cooldown.
    /// </summary>
    /// <param name="action">The action to attempt; actions that are neither a punch nor a dodge are ignored.</param>
    private void TryToTakeAction(MLAction action)
    {
        if (IsPerformingMove() || isAnim)
        {
            return;
        }

        if (MLActionFactory.IsPunch(action))
        {
            if (!punchAction.IsOnCooldown())
            {
                var side = action == MLAction.PUNCH_LEFT ? Direction.LEFT : Direction.RIGHT;
                punchAction.Run(side);
            }
        }
        else if (MLActionFactory.IsDodge(action))
        {
            if (!dodgeAction.IsOnCooldown())
            {
                var side = action == MLAction.DODGE_LEFT ? Direction.LEFT : Direction.RIGHT;
                dodgeAction.Run(side);
            }
        }
    }
コード例 #3
0
    // One generator shared by every call. Constructing `new System.Random()` per call seeds it from
    // the system clock, so calls made within the same clock tick would all draw the same values and
    // the noise would not be independent from step to step.
    // NOTE(review): System.Random is not thread-safe; this assumes ExpertDecide is only called from
    // a single thread (the Unity main thread) — confirm.
    private static readonly System.Random noiseRandom = new System.Random();

    /// <summary>
    /// Scripted expert policy: counter the opponent's current move, with occasional random noise.
    /// </summary>
    /// <remarks>
    /// Opponent DODGE_LEFT is answered with PUNCH_RIGHT, DODGE_RIGHT with PUNCH_LEFT,
    /// PUNCH_LEFT with DODGE_LEFT and PUNCH_RIGHT with DODGE_RIGHT.
    /// NOTE(review): OnlyDoAndDodge and RepeatActions answer PUNCH_LEFT with DODGE_RIGHT instead;
    /// confirm which dodge side is intended.
    /// </remarks>
    /// <param name="vectorObs">The raw vector observation, decoded through MLInput.</param>
    /// <returns>
    /// With probability <c>noise</c>, a single-element vector holding a random integer in [0, 5);
    /// otherwise the counter-move vector, or <c>{ 0f }</c> when the opponent is neither punching
    /// nor dodging.
    /// </returns>
    public float[] ExpertDecide(List <float> vectorObs)
    {
        if (noiseRandom.NextDouble() < noise)
        {
            return new float[] { noiseRandom.Next(0, 5) };
        }

        var opponentAction = new MLInput(vectorObs.ToArray()).GetOpponentAction();

        if (opponentAction == MLAction.DODGE_LEFT)
        {
            return MLActionFactory.GetVectorAction(MLAction.PUNCH_RIGHT);
        }
        if (opponentAction == MLAction.DODGE_RIGHT)
        {
            return MLActionFactory.GetVectorAction(MLAction.PUNCH_LEFT);
        }
        if (opponentAction == MLAction.PUNCH_LEFT)
        {
            return MLActionFactory.GetVectorAction(MLAction.DODGE_LEFT);
        }
        if (opponentAction == MLAction.PUNCH_RIGHT)
        {
            return MLActionFactory.GetVectorAction(MLAction.DODGE_RIGHT);
        }

        return new float[] { 0f };
    }
コード例 #4
0
ファイル: Boxer.cs プロジェクト: jacattelona/PunchOut
    /// <summary>
    /// Judge whether this boxer's last action suited the opponent's current move.
    /// </summary>
    /// <returns>
    /// False when there is no opponent; true when the last action was NOTHING; otherwise true
    /// exactly when this boxer is dodging if and only if the opponent is punching.
    /// </returns>
    private bool IsGoodMove()
    {
        // Unity overloads == for destroyed objects, so keep the explicit == null comparison.
        if (opponent == null)
        {
            return false;
        }

        var theirMove = opponent.GetCurrentAction();
        var ourMove = MLActionFactory.GetAction(lastActions);

        if (ourMove == MLAction.NOTHING)
        {
            return true;
        }

        // Good: dodge a punch, or attack when no punch is coming.
        return MLActionFactory.IsPunch(theirMove) == MLActionFactory.IsDodge(ourMove);
    }
コード例 #5
0
    /// <summary>
    /// Step through a scripted move list, advancing when the current entry is finished.
    /// </summary>
    /// <remarks>
    /// A NOTHING entry lasts <c>nothingDuration</c> seconds from <c>nothingStartTime</c>. A punch or
    /// dodge entry is finished once it has been issued (<c>didAction</c>) and the observed own move
    /// is back to NOTHING. An entry counts as issued on the first step its move reports ready.
    /// </remarks>
    /// <param name="vectorObs">The raw vector observation, decoded through MLInput.</param>
    /// <param name="moves">The move list, indexed by <c>moveIdx</c>.</param>
    /// <returns>The move to send this step.</returns>
    private MLAction runMoves(List <float> vectorObs, MLAction[] moves)
    {
        MLInput input = new MLInput(vectorObs.ToArray());
        MLAction current = moves[moveIdx];

        bool isPunch = MLActionFactory.IsPunch(current);
        bool isDodge = MLActionFactory.IsDodge(current);

        if (current == MLAction.NOTHING)
        {
            if (Time.time - nothingStartTime >= nothingDuration)
            {
                return runNextMove(moves, input);
            }
        }
        else if (didAction && (isPunch || isDodge) && input.GetMyMove() == MLAction.NOTHING)
        {
            // The issued punch/dodge has played out.
            return runNextMove(moves, input);
        }

        if (!didAction)
        {
            if (isPunch)
            {
                didAction = input.IsPunchReady();
            }
            else if (isDodge)
            {
                didAction = input.IsDodgeReady();
            }
        }

        return current;
    }
コード例 #6
0
    /// <summary>
    /// Scripted decision: follow the move list in <c>moves</c>, but when idle and able to dodge,
    /// react to an opponent punch with a random dodge.
    /// </summary>
    /// <remarks>
    /// Against a punch, <c>Random.Range(0, 100)</c> (0-99) picks: 0-30 (31%) dodge to the same side
    /// as the punch, 31-50 (20%) dodge to the other side, otherwise (49%) continue the move list.
    /// </remarks>
    /// <returns>The action vector for this step; <c>NOTHING</c> when <paramref name="done"/> is set.</returns>
    public override float[] Decide(List <float> vectorObs, List <Texture2D> visualObs, float reward, bool done, List <float> memory)
    {
        if (done)
        {
            // Restart the move list from a clean state by resetting the same bookkeeping that
            // runNextMove resets. Otherwise a stale didAction or nothingStartTime from the previous
            // episode makes runMoves treat the first entry as already finished and skip it.
            moveIdx = 0;
            didAction = false;
            nothingStartTime = Time.time;
            return NOTHING;
        }

        MLInput input = new MLInput(vectorObs.ToArray());

        if (input.GetMyMove() == MLAction.NOTHING && input.IsDodgeReady())
        {
            var opponentAction = input.GetOpponentAction();
            if (opponentAction == MLAction.PUNCH_LEFT || opponentAction == MLAction.PUNCH_RIGHT)
            {
                var sameSideDodge = opponentAction == MLAction.PUNCH_LEFT ? LEFT_DODGE : RIGHT_DODGE;
                var otherSideDodge = opponentAction == MLAction.PUNCH_LEFT ? RIGHT_DODGE : LEFT_DODGE;

                int roll = Random.Range(0, 100);
                if (roll <= 30)
                {
                    return sameSideDodge;
                }
                if (roll <= 50)
                {
                    return otherSideDodge;
                }
                // Otherwise fall through and run the normal moves.
            }
        }

        return MLActionFactory.GetVectorAction(runMoves(vectorObs, moves));
    }
コード例 #7
0
    /// <summary>
    /// Scripted decision with a time-based curriculum of move sequences.
    /// </summary>
    /// <remarks>
    /// New behaviour (<c>useOldVersion</c> false): calls <c>Reset()</c> when <c>seqStartTime</c> is
    /// approximately 0. It plays <c>moveSequences[0]</c> until more than 15 s after
    /// <c>seqStartTime</c>, then <c>moveSequences[1]</c>, and <c>moveSequences[2]</c> once more than
    /// 30 s have passed; each step is chosen by runMoves. <paramref name="done"/> is not consulted here.
    /// Old behaviour: cycles through the vectors in <c>moves</c>, taking the next one only when both
    /// punch and dodge are ready, and restarts the cycle when <paramref name="done"/> is set.
    /// </remarks>
    /// <returns>The action vector for this step.</returns>
    public override float[] Decide(List <float> vectorObs, List <Texture2D> visualObs, float reward, bool done, List <float> memory)
    {
        if (useOldVersion)
        {
            if (done)
            {
                moveIdx = 0;
                return NOTHING;
            }

            MLInput legacyInput = new MLInput(vectorObs.ToArray());
            if (!legacyInput.IsPunchReady() || !legacyInput.IsDodgeReady())
            {
                return NOTHING;
            }

            float[] scripted = moves[moveIdx];
            moveIdx = (moveIdx + 1) % moves.Length;
            return scripted;
        }

        if (Mathf.Approximately(seqStartTime, 0))
        {
            Reset();
        }

        var elapsed = Time.time - seqStartTime;
        var targetSeq = seqIdx;
        if (seqIdx == 0 && elapsed > 15)
        {
            targetSeq = 1;
        }
        else if (seqIdx == 1 && elapsed > 30)
        {
            targetSeq = 2;
        }

        if (targetSeq != seqIdx)
        {
            // Entering a new sequence: start it from its first move with fresh bookkeeping.
            seqIdx = targetSeq;
            moveIdx = 0;
            didAction = false;
            nothingStartTime = Time.time;
        }

        MLAction nextMove = runMoves(vectorObs, moveSequences[seqIdx]);
        return MLActionFactory.GetVectorAction(nextMove);
    }
コード例 #8
0
    /// <summary>
    /// Keyboard-driven decision for a human player.
    /// </summary>
    /// <remarks>
    /// While a move is in progress, that move is repeated for up to <c>maxDodgeCount</c> (dodges)
    /// or <c>maxPunchCount</c> (punches) steps, then NOTHING is sent until the move ends. When idle,
    /// the counter is cleared and keys F / J / D / K map to PUNCH_LEFT / PUNCH_RIGHT / DODGE_LEFT /
    /// DODGE_RIGHT, checked in that order.
    /// </remarks>
    /// <returns>The action vector for this step.</returns>
    public override float[] Decide(List <float> vectorObs, List <Texture2D> visualObs, float reward, bool done, List <float> memory)
    {
        MLInput input = new MLInput(vectorObs.ToArray());
        MLAction heldMove = input.GetMyMove();

        if (heldMove == MLAction.NOTHING)
        {
            moveCount = 0;
        }
        else
        {
            var limit = MLActionFactory.IsDodge(heldMove) ? maxDodgeCount : maxPunchCount;
            if (moveCount < limit)
            {
                moveCount++;
                return MLActionFactory.GetVectorAction(heldMove);
            }
            return MLActionFactory.GetVectorAction(MLAction.NOTHING);
        }

        MLAction requested = MLAction.NOTHING;
        if (Input.GetKey(KeyCode.F))
        {
            requested = MLAction.PUNCH_LEFT;
        }
        else if (Input.GetKey(KeyCode.J))
        {
            requested = MLAction.PUNCH_RIGHT;
        }
        else if (Input.GetKey(KeyCode.D))
        {
            requested = MLAction.DODGE_LEFT;
        }
        else if (Input.GetKey(KeyCode.K))
        {
            requested = MLAction.DODGE_RIGHT;
        }

        return MLActionFactory.GetVectorAction(requested);
    }
コード例 #9
0
    /// <summary>
    /// Advance to the next entry of a scripted move list (wrapping around) and reset its bookkeeping.
    /// </summary>
    /// <remarks>
    /// <c>nothingStartTime</c> restarts now. <c>didAction</c> is set to whether the new entry's move
    /// is already ready (punch or dodge), or false for any other entry.
    /// </remarks>
    /// <param name="moves">The move list being played.</param>
    /// <param name="input">The decoded observation for this step.</param>
    /// <returns>The newly selected move.</returns>
    private MLAction runNextMove(MLAction[] moves, MLInput input)
    {
        didAction = false;
        nothingStartTime = Time.time;
        moveIdx = (moveIdx + 1) % moves.Length;

        MLAction upcoming = moves[moveIdx];
        bool upcomingIsPunch = MLActionFactory.IsPunch(upcoming);
        if (upcomingIsPunch || MLActionFactory.IsDodge(upcoming))
        {
            didAction = upcomingIsPunch ? input.IsPunchReady() : input.IsDodgeReady();
        }

        return upcoming;
    }
コード例 #10
0
ファイル: Evaluator.cs プロジェクト: jacattelona/PunchOut
    /// <summary>
    /// Per-frame comparison of the trainee against the coach.
    /// </summary>
    /// <remarks>
    /// Invokes <c>matchingEvent</c> when both are performing the same current action. It then takes
    /// the probability that the trainee's last action vector assigns to the coach's chosen action,
    /// adds its cross-entropy to <c>crossEntropy</c>, and folds it into an exponential moving average
    /// with smoothing 0.995. Finally it records the match result with <c>AddSample</c>.
    /// </remarks>
    private void Update()
    {
        var actionsMatch = trainee.currentAction == coach.currentAction;
        if (actionsMatch)
        {
            matchingEvent.Invoke();
        }

        var coachAction = MLActionFactory.GetAction(coach.lastActions);
        var traineeProbability = MLActionFactory.GetProbabilityFromVector(coachAction, trainee.lastActions);

        crossEntropy += MathUtils.CrossEntropy(traineeProbability);

        // Exponential moving average: each frame's sample contributes (1 - smoothing) of the result.
        float smoothing = 0.995f;
        runningAverageCorrectness = smoothing * runningAverageCorrectness + (1 - smoothing) * traineeProbability;

        AddSample(actionsMatch);
    }
コード例 #11
0
    /// <summary>
    /// Scripted policy that always performs <paramref name="action"/>, except on the first step of an
    /// opponent punch, where it dodges to the opposite side.
    /// </summary>
    /// <param name="vectorObs">The raw vector observation, decoded through MLInput.</param>
    /// <param name="action">The action to perform when not dodging.</param>
    /// <returns>The action vector for this step.</returns>
    public float[] OnlyDoAndDodge(List <float> vectorObs, MLAction action)
    {
        var opponentAction = new MLInput(vectorObs.ToArray()).GetOpponentAction();
        var opponentJustChanged = lastAction != opponentAction;
        lastAction = opponentAction;

        MLAction response = action;
        if (opponentJustChanged)
        {
            if (opponentAction == MLAction.PUNCH_LEFT)
            {
                response = MLAction.DODGE_RIGHT;
            }
            else if (opponentAction == MLAction.PUNCH_RIGHT)
            {
                response = MLAction.DODGE_LEFT;
            }
        }

        return MLActionFactory.GetVectorAction(response);
    }
コード例 #12
0
ファイル: Boxer.cs プロジェクト: jacattelona/PunchOut
    /// <summary>
    /// Collect the vector and text observations for the current step.
    /// </summary>
    /// <remarks>
    /// Vector observations, in order: punch ready, dodge ready, this boxer's encoded combo state,
    /// four opponent-move flags (PUNCH_RIGHT, PUNCH_LEFT, DODGE_LEFT, DODGE_RIGHT), then four
    /// own-move flags (PUNCH_LEFT, PUNCH_RIGHT, DODGE_LEFT, DODGE_RIGHT). The order presumably has
    /// to stay in sync with the trained model — confirm before changing it.
    /// The text observation is two booleans joined by a comma. The first is whether this step counts
    /// as a teacher sample. The second is true only on the step where <c>bufferSize</c> reached
    /// <c>maxBufferSize</c> and was reset.
    /// </remarks>
    public override void CollectObservations()
    {
        AddVectorObs(!punchAction.IsOnCooldown() && !punchAction.IsRunning());
        AddVectorObs(!dodgeAction.IsOnCooldown() && !dodgeAction.IsRunning());

        // NOTE(review): opponent punch flags are ordered RIGHT, LEFT, while this boxer's own punch
        // flags below are LEFT, RIGHT (dodges use LEFT, RIGHT in both); confirm the mirroring is intended.
        // NOTE(review): the opponent's combo state is not part of the observation; adding it would
        // change the vector observation size.
        bool[] opponentMoveFlags;
        if (opponent != null)
        {
            opponentMoveFlags = new bool[] {
                opponent.currentAction == MLAction.PUNCH_RIGHT,
                opponent.currentAction == MLAction.PUNCH_LEFT,
                opponent.currentAction == MLAction.DODGE_LEFT,
                opponent.currentAction == MLAction.DODGE_RIGHT
            };
        }
        else
        {
            opponentMoveFlags = new bool[] { false, false, false, false };
        }

        AddVectorObs(Encoder.encodeInt(comboTracker.GetState(), 0, comboTracker.GetTotalStates()));
        foreach (var flag in opponentMoveFlags)
        {
            AddVectorObs(flag);
        }
        AddVectorObs(currentAction == MLAction.PUNCH_LEFT);
        AddVectorObs(currentAction == MLAction.PUNCH_RIGHT);
        AddVectorObs(currentAction == MLAction.DODGE_LEFT);
        AddVectorObs(currentAction == MLAction.DODGE_RIGHT);

        if (bufferSize >= maxBufferSize)
        {
            // bufferSize has reached maxBufferSize: signal it once (second field true) and restart the count.
            SetTextObs((isTeacher && isFighting && MLActionFactory.GetAction(lastActions) != MLAction.NOTHING) + "," + true);
            bufferSize = 0;
            return;
        }

        var lastMove = MLActionFactory.GetAction(lastActions);
        if (MLActionFactory.IsPunch(lastMove))
        {
            punchCount++;
        }
        else if (MLActionFactory.IsDodge(lastMove))
        {
            dodgeCount++;
        }

        // A step is a teacher sample when a fighting teacher made a good move that was either an
        // actual action or one of the first nothingBufferSize NOTHING moves.
        var training = isTeacher && isFighting
            && (lastMove != MLAction.NOTHING || nothingBuffer < nothingBufferSize)
            && IsGoodMove();
        if (training)
        {
            bufferSize++;
        }

        // NOTE(review): nothingBuffer only ever grows in this method; unless it is reset elsewhere,
        // NOTHING moves stop qualifying as samples once nothingBufferSize is reached.
        if (lastMove == MLAction.NOTHING)
        {
            nothingBuffer++;
        }
        SetTextObs(training + "," + false);
    }
コード例 #13
0
 /// <summary>
 /// Scripted policy that always performs the same action.
 /// </summary>
 /// <param name="vectorObs">Unused; accepted so this policy shares the signature of the others.</param>
 /// <param name="action">The action to perform every step.</param>
 /// <returns>The encoded action vector for <paramref name="action"/>.</returns>
 public float[] OnlyDo(List <float> vectorObs, MLAction action)
 {
     var encoded = MLActionFactory.GetVectorAction(action);
     return encoded;
 }
コード例 #14
0
    /// <summary>
    /// Scripted policy that throws a short combo and, optionally, dodges newly started opponent punches.
    /// </summary>
    /// <param name="vectorObs">The raw vector observation, decoded through MLInput.</param>
    /// <param name="actions">
    /// The combo to throw. With 2 or 3 entries, the entry is chosen from this boxer's combo state.
    /// With any other count, the first entry is always used; an empty list throws when indexed.
    /// </param>
    /// <param name="shouldDodge">When true, dodge to the opposite side on the first step of an opponent punch.</param>
    /// <returns>The action vector for this step; NOTHING when a punch is not ready.</returns>
    public float[] RepeatActions(List <float> vectorObs, List <MLAction> actions, bool shouldDodge)
    {
        MLInput input = new MLInput(vectorObs.ToArray());

        if (input.IsDodgeReady())
        {
            var opponentAction = input.GetOpponentAction();
            var isNewOpponentAction = lastAction != opponentAction;
            lastAction = opponentAction;

            if (shouldDodge && isNewOpponentAction)
            {
                if (opponentAction == MLAction.PUNCH_LEFT)
                {
                    actionIdx = 0;
                    return MLActionFactory.GetVectorAction(MLAction.DODGE_RIGHT);
                }
                if (opponentAction == MLAction.PUNCH_RIGHT)
                {
                    actionIdx = 0;
                    return MLActionFactory.GetVectorAction(MLAction.DODGE_LEFT);
                }
            }
        }

        if (!input.IsPunchReady())
        {
            return MLActionFactory.GetVectorAction(MLAction.NOTHING);
        }

        int comboState = input.GetMyComboState();
        MLAction chosen;
        if (actions.Count == 2 || actions.Count == 3)
        {
            // Combo states 0, 1-2 and 3-4 select entries 0, 1 and 2, i.e. index (state + 1) / 2.
            // Negative states, or states past the last entry, produce NOTHING.
            int lastValidState = 2 * (actions.Count - 1);
            chosen = comboState >= 0 && comboState <= lastValidState
                ? actions[(comboState + 1) / 2]
                : MLAction.NOTHING;
        }
        else
        {
            chosen = actions[0];
        }

        return MLActionFactory.GetVectorAction(chosen);
    }