示例#1
0
        /// <summary>
        /// Explores twice with an epsilon of 0 and asserts that each chosen action equals the
        /// action the policy itself picks for the same context. Then checks that the recorder
        /// captured exactly two interactions, the first carrying the original context id.
        /// </summary>
        public void EpsilonGreedy()
        {
            uint numActions = 10;
            float epsilon = 0f;
            string uniqueKey = "ManagedTestId";

            var recorder = new TestRecorder<TestContext>();
            var policy = new TestPolicy();
            var mwtt = new MwtExplorer<TestContext>("mwt", recorder);
            var testContext = new TestContext { Id = 100 };

            var explorer = new EpsilonGreedyExplorer<TestContext>(policy, epsilon, numActions);

            // Reference answer: what the underlying policy chooses on its own.
            uint expectedAction = policy.ChooseAction(testContext);

            // Repeat the same exploration twice; both must agree with the policy.
            for (int attempt = 0; attempt < 2; attempt++)
            {
                uint chosenAction = mwtt.ChooseAction(explorer, uniqueKey, testContext);
                Assert.AreEqual(expectedAction, chosenAction);
            }

            var interactions = recorder.GetAllInteractions();
            Assert.AreEqual(2, interactions.Count);
            Assert.AreEqual(testContext.Id, interactions[0].Context.Id);
        }
示例#2
0
        /// <summary>
        /// Micro-benchmark of the epsilon-greedy exploration path. Each iteration times three
        /// phases separately: construction of the recorder/policy/explorers, the ChooseAction
        /// calls, and serialization of the recording. The first <c>numWarmup</c> iterations run
        /// the same work but are excluded from the totals. Per-iteration averages and the total
        /// are printed in microseconds.
        /// </summary>
        public static void Clock()
        {
            float epsilon = .2f;
            string uniqueKey = "clock";
            int numFeatures = 1000;
            int numIter = 1000;
            int numWarmup = 100;
            int numInteractions = 1;
            uint numActions = 10;

            // Accumulated elapsed time in milliseconds (measured iterations only).
            double timeInit = 0;
            double timeChoose = 0;
            double timeSerializedLog = 0;

            var stopwatch = new System.Diagnostics.Stopwatch();
            int totalIterations = numIter + numWarmup;

            for (int iteration = 0; iteration < totalIterations; iteration++)
            {
                bool measured = iteration >= numWarmup;

                // Phase 1 (timed): build the recorder, policy and explorers.
                stopwatch.Restart();
                var recorder = new StringRecorder<SimpleContext>();
                var policy = new StringPolicy();
                var mwt = new MwtExplorer<SimpleContext>("mwt", recorder);
                var explorer = new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions);
                if (measured)
                {
                    timeInit += stopwatch.Elapsed.TotalMilliseconds;
                }

                // Untimed: numFeatures features with ids 1..numFeatures, each valued 0.5.
                var features = new Feature[numFeatures];
                for (int k = 0; k < numFeatures; k++)
                {
                    features[k].Id = (uint)k + 1;
                    features[k].Value = 0.5f;
                }

                // Phase 2 (timed): context construction plus the ChooseAction calls.
                stopwatch.Restart();
                var context = new SimpleContext(features);
                for (int n = 0; n < numInteractions; n++)
                {
                    mwt.ChooseAction(explorer, uniqueKey, context);
                }
                if (measured)
                {
                    timeChoose += stopwatch.Elapsed.TotalMilliseconds;
                }

                // Phase 3 (timed): serialize everything recorded so far.
                stopwatch.Restart();
                recorder.GetRecording();
                if (measured)
                {
                    timeSerializedLog += stopwatch.Elapsed.TotalMilliseconds;
                }

                // NOTE(review): these extra, untimed ChooseAction calls do not feed any measurement;
                // they look like a leftover — confirm whether they are intentional.
                for (int n = 0; n < numInteractions; n++)
                {
                    mwt.ChooseAction(explorer, uniqueKey, context);
                }
            }

            Console.WriteLine("--- PER ITERATION ---");
            Console.WriteLine("# iterations: {0}, # interactions: {1}, # context features {2}", numIter, numInteractions, numFeatures);
            Console.WriteLine("Init: {0} micro", timeInit * 1000 / numIter);
            Console.WriteLine("Choose Action: {0} micro", timeChoose * 1000 / (numIter * numInteractions));
            Console.WriteLine("Get Serialized Log: {0} micro", timeSerializedLog * 1000 / numIter);
            Console.WriteLine("--- TOTAL TIME: {0} micro", (timeInit + timeChoose + timeSerializedLog) * 1000);
        }
示例#3
0
        /// <summary>
        /// Sends ten synthetic decisions to the decision service, then flushes it. For each
        /// decision, a random integer context in [0, 100) is explored under the key "0".."9".
        /// The decision is rewarded with 1 when the context is odd and 0 when it is even.
        /// </summary>
        /// <remarks>
        /// All of the work runs synchronously between the Increment and Decrement of
        /// AsyncManager.OutstandingOperations.
        /// NOTE(review): the authorization token is hard-coded in source; it should come from configuration.
        /// NOTE(review): if any service call throws, Decrement is never reached.
        /// </remarks>
        public void IndexAsync()
        {
            AsyncManager.OutstandingOperations.Increment();

            var serviceConfig = new DecisionServiceConfiguration<string>(
                authorizationToken: "c01ff675-5710-4814-a961-d03d2d6bce65",
                explorer: new EpsilonGreedyExplorer<string>(new MartPolicy(), epsilon: .2f, numActions: 10));

            var service = new DecisionService<string>(serviceConfig);
            var random = new Random();

            for (int eventIndex = 0; eventIndex < 10; eventIndex++)
            {
                int sample = random.Next(100);
                string key = eventIndex.ToString();
                float reward = (float)(sample % 2);

                service.ChooseAction(key, sample.ToString());
                service.ReportReward(reward, key);
            }

            service.Flush();

            AsyncManager.OutstandingOperations.Decrement();
        }
示例#4
0
        /// <summary>
        /// Lazily creates the shared <c>Explorer</c>, <c>Configuration</c> and <c>Service</c>,
        /// in that order. Then writes a default settings file if none exists at <c>settingsFile</c>.
        /// </summary>
        /// <param name="numActions">Number of actions for the epsilon-greedy explorer.</param>
        /// <param name="modelOutputDir">Assigned to the configuration's <c>BlobOutputDir</c>.</param>
        /// <param name="policyAction">Passed to the <c>MartPolicy</c> constructor.</param>
        /// <remarks>
        /// Each member is only created while it is still null. Arguments are therefore ignored
        /// for any member created by an earlier call. The null checks here are not synchronized.
        /// </remarks>
        public static void Create(uint numActions, string modelOutputDir, int policyAction)
        {
            if (Explorer == null)
            {
                var martPolicy = new MartPolicy<TContext>(policyAction);
                var epsilon = LoadSettings().Epsilon;
                Explorer = new EpsilonGreedyExplorer<TContext>(martPolicy, epsilon, numActions);
            }

            if (Configuration == null)
            {
                // Small batches, flushed at least every two seconds, with retrying uploads.
                Configuration = new DecisionServiceConfiguration<TContext>(appToken, Explorer)
                {
                    BlobOutputDir = modelOutputDir,
                    BatchConfig = new BatchingConfiguration
                    {
                        MaxDuration = TimeSpan.FromSeconds(2),
                        MaxBufferSizeInBytes = 1024,
                        MaxEventCount = 100,
                        MaxUploadQueueCapacity = 4,
                        UploadRetryPolicy = BatchUploadRetryPolicy.Retry
                    }
                };
            }

            if (Service == null)
            {
                Service = new DecisionService<TContext>(Configuration);
            }

            // An existing settings file is left untouched.
            if (File.Exists(settingsFile))
            {
                return;
            }

            string defaultSettingsJson = JsonConvert.SerializeObject(new DecisionServiceSettings());
            File.WriteAllText(settingsFile, defaultSettingsJson);
        }
示例#5
0
        /// <summary>
        /// Runs five exploration scenarios against a plain <c>TestContext</c>: epsilon-greedy,
        /// tau-first, bootstrap, softmax and generic. Asserts that every one of them throws an
        /// <see cref="ArgumentException"/> whose parameter name is "ctx" (case-insensitive).
        /// NOTE(review): presumably these explorers reject TestContext because it does not
        /// provide a variable action count — confirm.
        /// </summary>
        public void UsageBadVariableActionContext()
        {
            int numExceptionsCaught   = 0;
            int numExceptionsExpected = 5;

            // Runs one scenario. It counts only if it throws an ArgumentException (or a subclass)
            // naming "ctx". Any other exception type propagates and fails the test.
            Action<Action> tryCatchArgumentException = action =>
            {
                try
                {
                    action();
                }
                catch (ArgumentException ex)
                {
                    // ParamName is null when the exception was created without one, so compare
                    // null-safely; ordinal ignore-case avoids culture-dependent lowercasing.
                    if (string.Equals(ex.ParamName, "ctx", StringComparison.OrdinalIgnoreCase))
                    {
                        numExceptionsCaught++;
                    }
                }
            };

            tryCatchArgumentException(() =>
            {
                var mwt      = new MwtExplorer <TestContext>("test", new TestRecorder <TestContext>());
                var policy   = new TestPolicy <TestContext>();
                var explorer = new EpsilonGreedyExplorer <TestContext>(policy, 0.2f);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt      = new MwtExplorer <TestContext>("test", new TestRecorder <TestContext>());
                var policy   = new TestPolicy <TestContext>();
                var explorer = new TauFirstExplorer <TestContext>(policy, 10);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt      = new MwtExplorer <TestContext>("test", new TestRecorder <TestContext>());
                var policies = new TestPolicy <TestContext> [2];
                for (int i = 0; i < 2; i++)
                {
                    policies[i] = new TestPolicy <TestContext>(i * 2);
                }
                var explorer = new BootstrapExplorer <TestContext>(policies);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt      = new MwtExplorer <TestContext>("test", new TestRecorder <TestContext>());
                var scorer   = new TestScorer <TestContext>(10);
                var explorer = new SoftmaxExplorer <TestContext>(scorer, 0.5f);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });
            tryCatchArgumentException(() =>
            {
                var mwt      = new MwtExplorer <TestContext>("test", new TestRecorder <TestContext>());
                var scorer   = new TestScorer <TestContext>(10);
                var explorer = new GenericExplorer <TestContext>(scorer);
                mwt.ChooseAction(explorer, "key", new TestContext());
            });

            Assert.AreEqual(numExceptionsExpected, numExceptionsCaught);
        }
示例#6
0
        /// <summary>
        /// Runs the shared epsilon-greedy check (epsilon = 0, 10 actions). The context used is a
        /// <c>VariableActionTestContext</c> constructed with that same action count.
        /// </summary>
        public void EpsilonGreedyFixedActionUsingVariableActionInterface()
        {
            int numActions = 10;
            float epsilon = 0f;

            TestPolicy<VariableActionTestContext> policy = new TestPolicy<VariableActionTestContext>();
            VariableActionTestContext testContext = new VariableActionTestContext(numActions);
            EpsilonGreedyExplorer explorer = new EpsilonGreedyExplorer(epsilon);

            EpsilonGreedyWithContext(numActions, testContext, policy, explorer);
        }
示例#7
0
        /// <summary>
        /// Runs the shared epsilon-greedy check (epsilon = 0, 10 actions) with a
        /// <c>RegularTestContext</c>.
        /// </summary>
        public void EpsilonGreedy()
        {
            int numActions = 10;
            float epsilon = 0f;

            TestPolicy<RegularTestContext> policy = new TestPolicy<RegularTestContext>();
            RegularTestContext testContext = new RegularTestContext();
            EpsilonGreedyExplorer explorer = new EpsilonGreedyExplorer(epsilon);

            EpsilonGreedyWithContext(numActions, testContext, policy, explorer);
        }
示例#8
0
        /// <summary>
        /// Runs the shared epsilon-greedy check with epsilon = 0 and 10 actions. It uses the
        /// generic <c>EpsilonGreedyExplorer&lt;TestContext&gt;</c>, which wraps the test policy.
        /// </summary>
        public void EpsilonGreedy()
        {
            uint numActions = 10;
            float epsilon = 0f;

            TestPolicy<TestContext> policy = new TestPolicy<TestContext>();
            TestContext testContext = new TestContext();
            EpsilonGreedyExplorer<TestContext> explorer = new EpsilonGreedyExplorer<TestContext>(policy, epsilon, numActions);

            EpsilonGreedyWithContext(numActions, testContext, policy, explorer);
        }
示例#9
0
        /// <summary>
        /// End-to-end run of an epsilon-greedy explorer (epsilon = 0.5, 10 actions) over
        /// <c>SimpleContext</c>, delegating the actual checks to <c>EndToEnd</c>.
        /// </summary>
        public void EndToEndEpsilonGreedy()
        {
            uint numActions = 10;
            float epsilon = 0.5f;

            var recorder = new TestRecorder<SimpleContext>();
            var mwtt = new MwtExplorer<SimpleContext>("mwt", recorder);

            var policy = new TestSimplePolicy();
            var explorer = new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions);

            EndToEnd(mwtt, explorer, recorder);
        }
示例#10
0
        // TODO: refactor
        //static void TestBootstrap(JObject config)
        //{
        //    var outputFile = config["OutputFile"].Value<string>();
        //    var appId = config["AppId"].Value<string>();
        //    var numActions = config["NumberOfActions"].Value<uint>();
        //    var experimentalUnitIdList = config["ExperimentalUnitIdList"].ToObject<string[]>();
        //    var configPolicies = (JArray)config["PolicyConfigurations"];

        //    switch (config["ContextType"].Value<int>())
        //    {
        //        case 0: // fixed action context
        //        {
        //            var contextList = Enumerable
        //                .Range(0, experimentalUnitIdList.Length)
        //                .Select(i => new RegularTestContext { Id = i })
        //                .ToArray();

        //            ExploreBootstrap<RegularTestContext>(appId, configPolicies,
        //                numActions, experimentalUnitIdList, contextList, outputFile);

        //            break;
        //        }
        //        case 1: // variable action context
        //        {
        //            var contextList = Enumerable
        //                .Range(0, experimentalUnitIdList.Length)
        //                .Select(i => new VariableActionTestContext(numActions) { Id = i })
        //                .ToArray();

        //            ExploreBootstrap<VariableActionTestContext>(appId, configPolicies,
        //                numActions, experimentalUnitIdList, contextList, outputFile);

        //            break;
        //        }
        //    }
        //}

        /// <summary>
        /// Runs epsilon-greedy exploration over every experimental unit and appends the
        /// serialized recording to <paramref name="outputFile"/>.
        /// </summary>
        /// <param name="appId">Application id passed to <c>MwtExplorer.Create</c>.</param>
        /// <param name="policyType">Only 0 (fixed policy) is implemented; any other value does nothing.</param>
        /// <param name="configPolicy">For the fixed policy, its "Action" value is the action the policy returns.</param>
        /// <param name="epsilon">Exploration probability for the epsilon-greedy explorer.</param>
        /// <param name="numActions">Action count, used only when TContext is not an IVariableActionContext.</param>
        /// <param name="experimentalUnitIdList">Unique keys, paired by index with <paramref name="contextList"/>.</param>
        /// <param name="contextList">Contexts; must have at least as many entries as the key list.</param>
        /// <param name="outputFile">File the recording is appended to.</param>
        static void ExploreEpsilonGreedy<TContext>(
            string appId,
            int policyType,
            JToken configPolicy,
            float epsilon,
            int numActions,
            string[] experimentalUnitIdList,
            TContext[] contextList,
            string outputFile)
        {
            var recorder = new StringRecorder<TContext>();

            // Contexts implementing IVariableActionContext go through a VariableActionProvider
            // instead of the fixed numActions.
            bool hasVariableActions = typeof(IVariableActionContext).IsAssignableFrom(typeof(TContext));

            if (policyType != 0)
            {
                // Only the fixed policy is supported; other policy types are silently ignored.
                return;
            }

            var fixedAction = configPolicy["Action"].Value<uint>();
            var policy = new TestPolicy<TContext> { ActionToChoose = fixedAction };
            var explorer = new EpsilonGreedyExplorer(epsilon);

            var mwt = hasVariableActions
                ? MwtExplorer.Create(appId, new VariableActionProvider<TContext>(), recorder, explorer, policy)
                : MwtExplorer.Create(appId, numActions, recorder, explorer, policy);

            for (int unit = 0; unit < experimentalUnitIdList.Length; unit++)
            {
                mwt.ChooseAction(experimentalUnitIdList[unit], contextList[unit]);
            }

            string recording = recorder.GetRecording();
            File.AppendAllText(outputFile, recording);
        }
        /// <summary>
        /// Sample driver for the factory-based exploration API. The algorithm is chosen by a
        /// local string that is hard-coded to "greedy", so only the epsilon-greedy sample runs
        /// as written. Output is printed to the console. The bootstrap, softmax and generic
        /// samples are not yet ported and simply return.
        /// </summary>
        public static void Run()
        {
            string explorationType = "greedy";

            switch (explorationType)
            {
                case "greedy":
                {
                    // Epsilon-greedy with the built-in StringRecorder and SimpleContext types.
                    // The StringRecorder serializes the recorded interactions to a string.
                    var recorder = new StringRecorder<SimpleContext>();

                    int numActions = 10;
                    float epsilon = 0.2f;
                    var explorer = new EpsilonGreedyExplorer(epsilon);

                    // One MwtExplorer bundling the recorder, the explorer and a SimpleContext policy.
                    var mwtt = MwtExplorer.Create("mwt", numActions, recorder, explorer, new StringPolicy());

                    var context = new SimpleContext(new float[] { .5f, 1.3f, -.5f });

                    // A sample string that uniquely identifies this exploration event.
                    string uniqueKey = "eventid";
                    mwtt.ChooseAction(uniqueKey, context);

                    Console.WriteLine(recorder.GetRecording());
                    return;
                }

                case "tau-first":
                {
                    // Tau-first with custom recorder, policy and context types.
                    var recorder = new MyRecorder();

                    int numActions = 10;
                    int tau = 0;

                    var mwtt = MwtExplorer.Create("mwt", numActions, recorder, new TauFirstExplorer(tau), new MyPolicy());

                    mwtt.ChooseAction("key", new MyContext());
                    Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;
                }

                case "bootstrap":
                    // TODO: add support for bootstrap
                    //// Initialize Bootstrap explore algorithm using custom Recorder, Policy & Context types
                    //MyRecorder recorder = new MyRecorder();
                    ////MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    //uint numActions = 10;
                    //uint numbags = 2;
                    //MyPolicy[] policies = new MyPolicy[numbags];
                    //for (int i = 0; i < numbags; i++)
                    //{
                    //    policies[i] = new MyPolicy(i * 2);
                    //}
                    //var mwtt = MwtExplorer.Create("mwt", recorder, new BootstrapExplorer(numActions));
                    //uint action = mwtt.ChooseAction(new BootstrapExplorer<MyContext>(policies, numActions), "key", new MyContext());
                    //Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;

                case "softmax":
                    // TODO: add support for softmax
                    //// Initialize Softmax explore algorithm using custom Recorder, Scorer & Context types
                    //MyRecorder recorder = new MyRecorder();
                    //MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    //uint numActions = 10;
                    //float lambda = 0.5f;
                    //MyScorer scorer = new MyScorer(numActions);
                    //uint action = mwtt.ChooseAction(new SoftmaxExplorer<MyContext>(scorer, lambda, numActions), "key", new MyContext());

                    //Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;

                case "generic":
                    // TODO: add support for generic
                    //// Initialize Generic explore algorithm using custom Recorder, Scorer & Context types
                    //MyRecorder recorder = new MyRecorder();
                    //MwtExplorer<MyContext> mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    //uint numActions = 10;
                    //MyScorer scorer = new MyScorer(numActions);
                    //uint action = mwtt.ChooseAction(new GenericExplorer<MyContext>(scorer, numActions), "key", new MyContext());

                    //Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;

                default:
                    // TODO: report an unknown exploration type (marked "add error here" originally).
                    break;
            }
        }
示例#12
0
        /// <summary>
        /// Sample driver for the generic-explorer API. Explorers are passed to
        /// <c>MwtExplorer&lt;T&gt;.ChooseAction</c> per call. The algorithm is chosen by a local
        /// string hard-coded to "greedy", so only the epsilon-greedy sample runs as written.
        /// Output is printed to the console.
        /// </summary>
        public static void Run()
        {
            string explorationType = "greedy";

            switch (explorationType)
            {
                case "greedy":
                {
                    // Epsilon-greedy with the built-in StringRecorder and SimpleContext types.
                    // The StringRecorder serializes the recorded interactions to a string.
                    var recorder = new StringRecorder<SimpleContext>();
                    var mwtt = new MwtExplorer<SimpleContext>("mwt", recorder);

                    // Policy operating on SimpleContext.
                    var policy = new StringPolicy();

                    uint numActions = 10;
                    float epsilon = 0.2f;
                    var explorer = new EpsilonGreedyExplorer<SimpleContext>(policy, epsilon, numActions);

                    // Three sparse features (ids 1, 4, 9).
                    var context = new SimpleContext(new[]
                    {
                        new Feature { Id = 1, Value = 0.5f },
                        new Feature { Id = 4, Value = 1.3f },
                        new Feature { Id = 9, Value = -0.5f },
                    });

                    // A sample string that uniquely identifies this exploration event.
                    string uniqueKey = "eventid";
                    mwtt.ChooseAction(explorer, uniqueKey, context);

                    Console.WriteLine(recorder.GetRecording());
                    return;
                }

                case "tau-first":
                {
                    // Tau-first with custom recorder, policy and context types.
                    var recorder = new MyRecorder();
                    var mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    uint numActions = 10;
                    uint tau = 0;
                    var policy = new MyPolicy();

                    mwtt.ChooseAction(new TauFirstExplorer<MyContext>(policy, tau, numActions), "key", new MyContext());
                    Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;
                }

                case "bootstrap":
                {
                    // Bootstrap with custom recorder, policy and context types.
                    var recorder = new MyRecorder();
                    var mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    uint numActions = 10;
                    uint numbags = 2;
                    var policies = new MyPolicy[numbags];
                    for (int bag = 0; bag < numbags; bag++)
                    {
                        // Each bag's policy is constructed with bag * 2.
                        policies[bag] = new MyPolicy(bag * 2);
                    }

                    mwtt.ChooseAction(new BootstrapExplorer<MyContext>(policies, numActions), "key", new MyContext());
                    Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;
                }

                case "softmax":
                {
                    // Softmax with custom recorder, scorer and context types.
                    var recorder = new MyRecorder();
                    var mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    uint numActions = 10;
                    float lambda = 0.5f;
                    var scorer = new MyScorer(numActions);

                    mwtt.ChooseAction(new SoftmaxExplorer<MyContext>(scorer, lambda, numActions), "key", new MyContext());
                    Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;
                }

                case "generic":
                {
                    // Generic explorer with custom recorder, scorer and context types.
                    var recorder = new MyRecorder();
                    var mwtt = new MwtExplorer<MyContext>("mwt", recorder);

                    uint numActions = 10;
                    var scorer = new MyScorer(numActions);

                    mwtt.ChooseAction(new GenericExplorer<MyContext>(scorer, numActions), "key", new MyContext());
                    Console.WriteLine(String.Join(",", recorder.GetAllInteractions().Select(it => it.Action)));
                    return;
                }

                default:
                    // TODO: report an unknown exploration type (marked "add error here" originally).
                    break;
            }
        }