コード例 #1
0
        /// <summary>
        /// Checks that the model exposed by a <c>DecisionServiceLocal</c> changes at the
        /// configured cadence: after each completed event at first, and only after the
        /// second completed event once <c>ModelUpdateInterval</c> is set to 2.
        /// </summary>
        public async Task TestDSLocalModelUpdate()
        {
            string vwArguments = "--cb_explore_adf --epsilon 0.2 --cb_type dr -q ::";
            var service = new DecisionServiceLocal<FoodContext>(vwArguments, 1, TimeSpan.MaxValue);
            var foodContext = new FoodContext
            {
                Actions = new int[] { 1, 2, 3 },
                UserLocation = "HealthyTown"
            };
            string firstKey = Guid.NewGuid().ToString();
            string secondKey = Guid.NewGuid().ToString();

            // Phase 1: generate one interaction and make sure the model changed
            // (the model updates on every example initially).
            byte[] snapshot = service.Model;
            await service.ChooseActionAsync(firstKey, foodContext, 1);
            service.ReportRewardAndComplete(1.0f, firstKey);
            Assert.IsFalse(service.Model.SequenceEqual(snapshot));

            // Phase 2: switch to an update every two examples.
            snapshot = service.Model;
            service.ModelUpdateInterval = 2;

            // First example of the pair: the model must still equal the snapshot.
            await service.ChooseActionAsync(firstKey, foodContext, 1);
            service.ReportRewardAndComplete(1.0f, firstKey);
            Assert.IsTrue(service.Model.SequenceEqual(snapshot));

            // Second example of the pair: the model must now differ from the snapshot.
            await service.ChooseActionAsync(secondKey, foodContext, 1);
            // NOTE(review): the reward is reported against the first key even though the
            // decision just above used the second key - confirm this is intentional.
            service.ReportRewardAndComplete(2.0f, firstKey);
            Assert.IsFalse(service.Model.SequenceEqual(snapshot));
        }
コード例 #2
0
        /// <summary>
        /// Trains a <c>--cb_explore_adf</c> VowpalWabbit model on simulated food purchases
        /// for two user locations, then verifies that the top-ranked predicted action is
        /// 3 for "HealthyTown" and 2 for "LessHealthyTown".
        /// </summary>
        /// <remarks>
        /// Actions are chosen round-robin (1, 2, 3) and logged with probability
        /// 1 / numActions. A simulated purchase yields a cost of -10, otherwise 0.
        /// The random generator is seeded with 0; the order of the calls made on it
        /// (location first, purchase draw second) is kept fixed so the generated
        /// sequence stays the same.
        /// </remarks>
        public void TrainNewVWModelWithMultiActionJsonDirectData()
        {
            string[] locations = new string[] { "HealthyTown", "LessHealthyTown" };

            // Derived from the array so the two can never drift apart.
            int numLocations = locations.Length; // user location

            int numActions            = 3; // food item
            int numExamplesPerActions = 10000;
            var recorder = new FoodRecorder();

            // Purchase probability per action, indexed by (action - 1):
            // burger 1, burger 2, salad.
            double[] healthyTownPurchaseProbabilities     = { 0.1, 0.15, 0.6 };
            double[] lessHealthyTownPurchaseProbabilities = { 0.4, 0.6, 0.2 };

            using (var vw = new VowpalWabbit<FoodContext>(
                       new VowpalWabbitSettings("--cb_explore_adf --epsilon 0.2 --cb_type dr -q ::")
            {
                TypeInspector = JsonTypeInspector.Default,
                EnableStringExampleGeneration = true,
                EnableStringFloatCompact = true
            }))
            {
                // Learn
                var rand = new Random(0);
                int totalExamples = numExamplesPerActions * numLocations;
                for (int iE = 0; iE < totalExamples; iE++)
                {
                    int iL = rand.Next(0, numLocations);

                    var context = new FoodContext
                    {
                        Actions = new int[] { 1, 2, 3 },
                        UserLocation = locations[iL]
                    };
                    string key = "fooditem " + Guid.NewGuid().ToString();

                    // Round-robin over the 1-based action ids.
                    int action = iE % numActions + 1;
                    recorder.Record(null, null, new EpsilonGreedyState
                    {
                        Probability = 1.0f / numActions
                    }, null, key);

                    double[] purchaseProbabilities = context.UserLocation == "HealthyTown"
                        ? healthyTownPurchaseProbabilities
                        : lessHealthyTownPurchaseProbabilities;

                    // A purchase is rewarded with a negative cost; no purchase costs nothing.
                    var draw = rand.NextDouble();
                    float cost = draw < purchaseProbabilities[action - 1] ? -10 : 0;

                    var label = new ContextualBanditLabel
                    {
                        Action      = (uint)action,
                        Cost        = cost,
                        Probability = recorder.GetProb(key)
                    };

                    // The label's action id is 1-based; Learn receives it shifted down by one.
                    vw.Learn(context, label, index: (int)label.Action - 1);
                }

                // Predict: the first returned action score (plus one) must be the expected action.
                var expectedActions = new Dictionary<string, uint>
                {
                    { "HealthyTown", 3 },
                    { "LessHealthyTown", 2 }
                };
                for (int iE = 0; iE < numExamplesPerActions; iE++)
                {
                    foreach (string location in locations)
                    {
                        var context = new FoodContext
                        {
                            Actions = new int[] { 1, 2, 3 },
                            UserLocation = location
                        };
                        ActionScore[] predicts = vw.Predict(context, VowpalWabbitPredictionType.ActionScore);
                        Assert.AreEqual(expectedActions[location], predicts[0].Action + 1);
                    }
                }
            }
        }
コード例 #3
0
 /// <summary>
 /// Returns the action-dependent features stored on the given context.
 /// </summary>
 /// <param name="context">The context whose <c>ActionDependentFeatures</c> are read.</param>
 /// <returns>The value of <c>context.ActionDependentFeatures</c>.</returns>
 public static IReadOnlyCollection<FoodFeature> GetFeaturesFromContext(FoodContext context)
 {
     IReadOnlyCollection<FoodFeature> features = context.ActionDependentFeatures;
     return features;
 }
コード例 #4
0
        /// <summary>
        /// Exercises <c>InMemoryLogger</c> with manually completed events, with events
        /// that complete after a 10ms experimental unit duration, with both mechanisms
        /// combined, and finally with concurrent inserts from several threads.
        /// </summary>
        public void TestDSLocalInMemoryLogger()
        {
            // Logger whose events are completed manually
            var manualLogger = new InMemoryLogger<FoodContext, int>(TimeSpan.MaxValue);
            // Logger whose events complete automatically after 10ms (experimental unit duration)
            var timedLogger = new InMemoryLogger<FoodContext, int>(TimeSpan.FromMilliseconds(10));
            var foodContext = new FoodContext
            {
                Actions = new int[] { 1, 2, 3 },
                UserLocation = "HealthyTown"
            };
            string firstKey = Guid.NewGuid().ToString();
            string secondKey = Guid.NewGuid().ToString();

            // --- Manually completed events must show up in the flush ---
            manualLogger.Record(foodContext, 1, null, null, firstKey);
            manualLogger.Record(foodContext, 2, null, null, secondKey);
            manualLogger.ReportRewardAndComplete(firstKey, 2.0f);
            manualLogger.ReportRewardAndComplete(secondKey, 2.0f);
            var manualEvents = manualLogger.FlushCompleteEvents();

            Assert.IsTrue(manualEvents.Length == 2);
            string[] flushedKeys = { manualEvents[0].Key, manualEvents[1].Key };
            bool bothKeysFlushed = flushedKeys.Contains(firstKey) && flushedKeys.Contains(secondKey);
            Assert.IsTrue(bothKeysFlushed);

            // --- The experimental unit duration must complete events on its own ---
            timedLogger.Record(foodContext, 1, null, null, firstKey);
            // The tick resolution in Windows is typically 15ms, so give some allowance
            Thread.Sleep(20);
            var timedEvents = timedLogger.FlushCompleteEvents();

            Assert.IsTrue(timedEvents.Length == 1);
            // No reward was reported, so the reward should be the default value
            Assert.IsTrue((timedEvents[0].Key == firstKey) && (timedEvents[0].Reward == 0.0));

            // --- Experimental unit and manual completion used together ---
            timedLogger.Record(foodContext, 1, null, null, firstKey);
            timedLogger.Record(foodContext, 2, null, null, secondKey);
            timedLogger.ReportRewardAndComplete(firstKey, 2.0f);

            // Only the manually completed event is expected in the immediate flush...
            timedEvents = timedLogger.FlushCompleteEvents();
            Assert.IsTrue((timedEvents.Length == 1) && (timedEvents[0].Key == firstKey));

            // ...and the other one after its duration has elapsed.
            Thread.Sleep(50);
            timedEvents = timedLogger.FlushCompleteEvents();
            Assert.IsTrue((timedEvents.Length == 1) && (timedEvents[0].Key == secondKey));

            // --- Multithreaded inserts must yield correct results ---
            const int NumThreads = 16;
            const int NumEventsPerThread = 100;

            // Work done by every thread: log fresh events into both loggers.
            ThreadStart logEvents = () =>
            {
                int remaining = NumEventsPerThread;
                while (remaining-- > 0)
                {
                    string eventKey = Guid.NewGuid().ToString();
                    // Manual logger: record and complete right away
                    manualLogger.Record(foodContext, 1, null, null, eventKey);
                    manualLogger.ReportRewardAndComplete(eventKey, 3.0f);
                    // Experimental unit logger: record and report a reward without completing
                    timedLogger.Record(foodContext, 1, null, null, eventKey);
                    timedLogger.ReportReward(eventKey, 4.0f);
                }
            };

            var workers = new List<Thread>(NumThreads);
            while (workers.Count < NumThreads)
            {
                workers.Add(new Thread(logEvents));
            }

            // Start every thread before waiting on any of them.
            workers.ForEach(worker => worker.Start());
            workers.ForEach(worker => worker.Join());

            manualEvents = manualLogger.FlushCompleteEvents();
            Assert.IsTrue(manualEvents.Length == NumThreads * NumEventsPerThread);

            Thread.Sleep(50);
            timedEvents = timedLogger.FlushCompleteEvents();
            Assert.IsTrue(timedEvents.Length == NumThreads * NumEventsPerThread);

            // The reward information must have been recorded before each event expired
            foreach (var completedEvent in timedEvents)
            {
                Assert.IsTrue(completedEvent.Reward == 4.0);
            }
        }