Example #1
0
        public void E2ERankerStochasticRewards(DecisionServiceDeployment deployment)
        {
            // End-to-end ranker test: upload the food-context data twice and check that the
            // fraction of correct choices rises above 80% on the second pass.
            // NOTE(review): assumes UploadFoodContextData returns the fraction of correct
            // choices for the pass (the original named the result "percentCorrect") — confirm.
            const float startingEpsilon = .5f;
            const string learnerArguments = "--cb_explore_adf --cb_type dr -q :: --epsilon 0.2";

            deployment.ConfigureDecisionService(trainArguments: learnerArguments, initialExplorationEpsilon: startingEpsilon);

            var settingsUri = deployment.SettingsUrl;
            deployment.OnlineTrainerWaitForStartup();

            // First pass: the score is expected to stay below the initial exploration epsilon.
            float firstPassScore = UploadFoodContextData(deployment, settingsUri, firstPass: true);
            Assert.IsTrue(firstPassScore < startingEpsilon);

            // Second pass: the score is expected to exceed 0.8.
            float secondPassScore = UploadFoodContextData(deployment, settingsUri, firstPass: false);
            Assert.IsTrue(secondPassScore > .8f);
        }
Example #2
0
        public async Task SimplePolicyTest(DecisionServiceDeployment deployment)
        {
            // End-to-end policy test with 4 actions and 4 feature values ("a".."d"):
            //   1. Exploration: with initial exploration epsilon 1, every (feature, action) pair
            //      should be observed (16 keys) and no single pair may exceed 8% of the total.
            //   2. Exploitation: after downloading the trained model, each observed
            //      (feature, action) pair should account for ~25% of choices (within 0.1).
            // NOTE(review): this.freq is presumably populated during exploration by SendEvents
            // or the upload handlers (defined elsewhere) — confirm.

            // Maximum number of model-download attempts, one per second. The retry loop and the
            // success assertion share this value so they cannot disagree.
            const int maxModelDownloadAttempts = 120;

            deployment.OnlineTrainerWaitForStartup();

            deployment.ConfigureDecisionService("--cb_explore 4 --epsilon 0", initialExplorationEpsilon: 1, isExplorationEnabled: true);

            // 4 Actions
            // NOTE(review): batch sizes are overridden from the defaults here; the reason is
            // undocumented (original author asked "why does this need to be different from default?").
            var config = new DecisionServiceConfiguration(deployment.SettingsUrl)
            {
                InteractionUploadConfiguration = new BatchingConfiguration
                {
                    MaxEventCount = 64
                },
                ObservationUploadConfiguration = new BatchingConfiguration
                {
                    MaxEventCount = 64
                },
                PollingForModelPeriod = TimeSpan.FromMinutes(5)
            };

            config.InteractionUploadConfiguration.ErrorHandler   += JoinServiceBatchConfiguration_ErrorHandler;
            config.InteractionUploadConfiguration.SuccessHandler += JoinServiceBatchConfiguration_SuccessHandler;
            this.features = new string[] { "a", "b", "c", "d" };
            this.freq     = new Dictionary <string, int>();
            // Fixed seed keeps the generated event stream reproducible between runs.
            this.rnd      = new Random(123);

            deployment.OnlineTrainerReset();

            // Phase 1: exploration — send 50 batches of 128 events.
            {
                var expectedEvents = 0;
                using (var client = Microsoft.Research.MultiWorldTesting.ClientLibrary.DecisionService.Create <MyContext>(config))
                {
                    for (int i = 0; i < 50; i++)
                    {
                        expectedEvents += SendEvents(client, 128);
                    }
                }

                // TODO: flush doesn't work
                // Assert.AreEqual(expectedEvents, this.eventCount);
            }

            // 4 actions times 4 feature values
            Assert.AreEqual(4 * 4, freq.Keys.Count);

            Console.WriteLine("Exploration");
            var total = freq.Values.Sum();

            foreach (var pair in freq.Keys.OrderBy(x => x))
            {
                var share = freq[pair] / (float)total;
                // Print before asserting so a failing pair is visible in the output.
                Console.WriteLine("{0} | {1}", pair, share);
                Assert.IsTrue(share < 0.08, string.Format("Exploration share for {0} too high: {1}", pair, share));
            }

            freq.Clear();

            // Presumably gives the online trainer time to publish a model built from the
            // uploaded events — NOTE(review): confirm the required delay.
            await Task.Delay(TimeSpan.FromMinutes(2));

            // Phase 2: exploitation with the downloaded model.
            // TODO: update eps: 0
            using (var client = Microsoft.Research.MultiWorldTesting.ClientLibrary.DecisionService.Create <MyContext>(config))
            {
                int attempt;
                for (attempt = 0; attempt < maxModelDownloadAttempts; attempt++)
                {
                    try
                    {
                        // Await instead of .Wait(): never block on a Task inside an async method.
                        await client.DownloadModelAndUpdate(System.Threading.CancellationToken.None);
                        break;
                    }
                    catch (Exception ex)
                    {
                        // The model may not be available yet; record why and retry shortly.
                        Console.WriteLine("Model download attempt {0} failed: {1}", attempt + 1, ex.Message);
                        await Task.Delay(TimeSpan.FromSeconds(1));
                    }
                }

                Assert.IsTrue(attempt < maxModelDownloadAttempts, "Unable to download model");

                for (int i = 0; i < 1024; i++)
                {
                    var key = Guid.NewGuid().ToString();

                    // Cycle through the feature values so each one is used equally often.
                    var featureIndex = i % features.Length;

                    var action = client.ChooseAction(key, new MyContext {
                        Feature = features[featureIndex]
                    });

                    var stat = string.Format("'{0}' '{1}' ", features[featureIndex], action);

                    // Increment the tally; a missing key reads as 0, so its first hit counts as 1
                    // (previously the first hit was stored as 0, undercounting every pair).
                    int count;
                    freq.TryGetValue(stat, out count);
                    freq[stat] = count + 1;
                }
            }

            Console.WriteLine("Exploitation");
            total = freq.Values.Sum();
            foreach (var pair in freq.Keys.OrderBy(x => x))
            {
                var share = freq[pair] / (float)total;
                // Print before asserting so a failing pair is visible in the output.
                Console.WriteLine("{0} | {1}", pair, share);
                Assert.AreEqual(0.25f, share, 0.1, string.Format("Exploitation share for {0} out of range: {1}", pair, share));
            }
        }