Exemplo n.º 1
0
        public void TestNull4()
        {
            using (var learner = new VowpalWabbit<Context, ADF>("--cb_adf --rank_all --interact ab"))
            {
                // Shared context with a null Vector and a single action whose ADFID is null.
                var example = new Context
                {
                    ID = 25,
                    Vector = null,
                    ActionDependentFeatures = new[] { new ADF { ADFID = null } }.ToList()
                };

                // Label used for every Learn call: action 1, cost 1, logged probability 0.2.
                var cbLabel = new ContextualBanditLabel
                {
                    Action = 1,
                    Cost = 1,
                    Probability = 0.2f
                };

                // Round 1: ID set, feature values null.
                learner.Learn(example, example.ActionDependentFeatures, 0, cbLabel);
                var ranking = learner.Predict(example, example.ActionDependentFeatures);
                Assert.AreEqual(1, ranking.Length);

                // Round 2: ID nulled out as well; the single action must still be ranked.
                example.ID = null;
                learner.Learn(example, example.ActionDependentFeatures, 0, cbLabel);
                ranking = learner.Predict(example, example.ActionDependentFeatures);
                Assert.AreEqual(1, ranking.Length);
            }
        }
Exemplo n.º 2
0
 /// <summary>
 /// Evaluates the <paramref name="learnedAction"/> policy together with <paramref name="numActions"/>
 /// constant policies with respect to <paramref name="label"/>.
 /// </summary>
 /// <param name="learnedAction">The learned action.</param>
 /// <param name="numActions">The number of constant policies to be evaluated.</param>
 /// <param name="label">The label the policies are evaluated against.</param>
 /// <returns>
 /// A <see cref="PoliciesPerformance"/> built from the scalar predictions produced by learning
 /// on a <see cref="LearnedVsConstantPolicy"/> example with <paramref name="label"/>.
 /// </returns>
 public PoliciesPerformance Evaluate(uint learnedAction, int numActions, ContextualBanditLabel label)
 {
     var policyExample = new LearnedVsConstantPolicy(learnedAction, numActions);
     var scalars = this.vw.Learn(policyExample, label, VowpalWabbitPredictionType.Scalars);
     return new PoliciesPerformance(scalars);
 }
Exemplo n.º 3
0
 /// <summary>
 /// Builds an <see cref="EvalData"/> for <paramref name="policyName"/> whose JSON payload carries
 /// the policy name and the unbiased cost of <paramref name="actionTaken"/> given <paramref name="label"/>.
 /// </summary>
 /// <param name="label">Logged label (action, cost, probability).</param>
 /// <param name="policyName">Name of the policy being evaluated.</param>
 /// <param name="actionTaken">Action chosen by that policy.</param>
 /// <returns>The populated evaluation record.</returns>
 private static EvalData Create(ContextualBanditLabel label, string policyName, uint actionTaken)
 {
     // Cost estimate as computed by GetUnbiasedCost from the logged label and the policy's action.
     var unbiasedCost = VowpalWabbitContextualBanditUtil.GetUnbiasedCost(label.Action, actionTaken, label.Cost, label.Probability);
     var payload = new { name = policyName, cost = unbiasedCost };

     return new EvalData
     {
         PolicyName = policyName,
         JSON = JsonConvert.SerializeObject(payload)
     };
 }
Exemplo n.º 4
0
 /// <summary>
 /// Once <c>sinceLastUpdate</c> has reached <c>ModelUpdateInterval</c>, trains on every completed
 /// event flushed from <c>log</c>, then saves the model to a stream, passes it to
 /// <c>UpdateModel</c>, and resets <c>sinceLastUpdate</c> to 0.
 /// If <c>vwDisposed</c> is set, returns without flushing, training, or resetting the counter.
 /// </summary>
 private void updateModelMaybe()
 {
     if (sinceLastUpdate >= ModelUpdateInterval)
     {
         // Locking at this level ensures a batch of events is processed completely before
         // the next batch (finer locking would allow interleaving, violating time order).
         lock (this.vwLock)
         {
             // Exit gracefully if the object has been disposed
             if (vwDisposed)
             {
                 return;
             }
             foreach (var dp in log.FlushCompleteEvents())
             {
                 // The first entry of the logged action array is the action that gets labeled.
                 uint action = (uint)((int[])dp.InteractData.Value)[0];
                 // The negated reward is used as the label cost.
                 // NOTE(review): assumes Probabilities[0] is the probability of the taken (top-slot)
                 // action; confirm this matches how GenericTopSlotExplorerState orders its probabilities.
                 var  label  = new ContextualBanditLabel(action, -dp.Reward, ((GenericTopSlotExplorerState)dp.InteractData.ExplorerState).Probabilities[0]);
                 // String (json) contexts need to be handled specially, since the C# interface
                 // does not currently handle the CB label properly
                 if (typeof(TContext) == typeof(string))
                 {
                     // Manually insert the CB label fields into the context.
                     // _labelIndex is the zero-based action index (Action - 1), matching the
                     // `index:` argument used in the non-JSON branch below.
                     string labelStr = string.Format(CultureInfo.InvariantCulture, "\"_label_Action\":{0},\"_label_Cost\":{1},\"_label_Probability\":{2},\"_labelIndex\":{3},",
                                                     label.Action, label.Cost, label.Probability, label.Action - 1);
                     // Inserting at index 1 places the fields right after the first character.
                     // NOTE(review): assumes the JSON context starts with '{' (no leading whitespace) — TODO confirm.
                     string context = ((string)dp.InteractData.Context).Insert(1, labelStr);
                     // Serializer and parsed example collection are both disposed after learning.
                     using (var vwSerializer = new VowpalWabbitJsonSerializer(vwJson.Native))
                         using (VowpalWabbitExampleCollection vwExample = vwSerializer.ParseAndCreate(context))
                         {
                             vwExample.Learn();
                         }
                 }
                 else
                 {
                     vw.Learn((TContext)dp.InteractData.Context, label, index: (int)label.Action - 1);
                 }
             }
             // Save from whichever learner was trained above (JSON vs. typed), rewind, and publish.
             using (MemoryStream currModel = new MemoryStream())
             {
                 VowpalWabbit vwNative = (typeof(TContext) == typeof(string)) ? vwJson.Native : vw.Native;
                 vwNative.SaveModel(currModel);
                 currModel.Position = 0;
                 this.UpdateModel(currModel);
                 sinceLastUpdate = 0;
             }
         }
     }
 }
Exemplo n.º 5
0
 /// <summary>
 /// Once <c>sinceLastUpdate</c> has reached <c>ModelUpdateInterval</c>, trains the learner on every
 /// completed event flushed from <c>log</c>, then saves the model to a stream, hands it to
 /// <c>UpdateModel</c>, and resets <c>sinceLastUpdate</c> to 0.
 /// </summary>
 private void updateModelMaybe()
 {
     // Not enough progress since the previous update: nothing to do yet.
     if (!(sinceLastUpdate >= ModelUpdateInterval))
     {
         return;
     }

     foreach (var completed in log.FlushCompleteEvents())
     {
         var chosenActions = (int[])completed.InteractData.Value;
         uint topAction = (uint)chosenActions[0];
         var explorerState = (GenericTopSlotExplorerState)completed.InteractData.ExplorerState;

         // Cost is the negated reward; probability is looked up by the 1-based action id.
         var cbLabel = new ContextualBanditLabel(topAction, -completed.Reward, explorerState.Probabilities[topAction - 1]);
         vw.Learn((TContext)completed.InteractData.Context, cbLabel, index: (int)cbLabel.Action - 1);
     }

     // Snapshot the trained model, rewind the stream, and publish it.
     using (var snapshot = new MemoryStream())
     {
         vw.Native.SaveModel(snapshot);
         snapshot.Position = 0;
         this.UpdateModel(snapshot);
         sinceLastUpdate = 0;
     }
 }
Exemplo n.º 6
0
        public void TestNull3()
        {
            using (var learner = new VowpalWabbit<Context, ADF>("--cb_adf --rank_all --interact ac"))
            {
                var example = new Context
                {
                    ID = 25,
                    Vector = new float[] { 3 },
                    VectorC = new float[] { 2, 2, 3 },
                    ActionDependentFeatures = new[] { new ADF { ADFID = "23" } }.ToList()
                };

                // Label reused for every Learn call: action 1, cost 1, logged probability 0.2.
                var cbLabel = new ContextualBanditLabel
                {
                    Action = 1,
                    Cost = 1,
                    Probability = 0.2f
                };

                // Both interacted namespaces populated.
                learner.Learn(example, example.ActionDependentFeatures, 0, cbLabel);

                // Vector null, VectorC still populated.
                example.Vector = null;
                learner.Learn(example, example.ActionDependentFeatures, 0, cbLabel);

                // Vector repopulated, VectorC null.
                example.Vector = new float[] { 2 };
                example.VectorC = null;
                learner.Learn(example, example.ActionDependentFeatures, 0, cbLabel);

                // Both null.
                example.Vector = null;
                learner.Learn(example, example.ActionDependentFeatures, 0, cbLabel);
            }
        }
Exemplo n.º 7
0
        /// <summary>
        /// Trains a <c>--cb_explore_adf</c> model on simulated food purchases in two locations.
        /// Actions are cycled round-robin and logged with probability 1/numActions.
        /// The test then checks that, for each location, the top-scored action is the one
        /// with the highest simulated purchase probability.
        /// </summary>
        public void TrainNewVWModelWithMultiActionJsonDirectData()
        {
            string[] locations = new string[] { "HealthyTown", "LessHealthyTown" };
            int numLocations = locations.Length; // user location

            int numActions            = 3; // food item
            int numExamplesPerActions = 10000;
            var recorder = new FoodRecorder();

            using (var vw = new VowpalWabbit<FoodContext>(
                       new VowpalWabbitSettings("--cb_explore_adf --epsilon 0.2 --cb_type dr -q ::")
            {
                TypeInspector = JsonTypeInspector.Default,
                EnableStringExampleGeneration = true,
                EnableStringFloatCompact = true
            }))
            {
                // Learn
                var rand = new Random(0);
                for (int iE = 0; iE < numExamplesPerActions * numLocations; iE++)
                {
                    int iL = rand.Next(0, numLocations);

                    var context = new FoodContext {
                        Actions = new int[] { 1, 2, 3 }, UserLocation = locations[iL]
                    };
                    string key = "fooditem " + Guid.NewGuid().ToString();

                    // Round-robin logging policy; each action is recorded with probability 1/numActions.
                    int action = iE % numActions + 1;
                    recorder.Record(null, null, new EpsilonGreedyState {
                        Probability = 1.0f / numActions
                    }, null, key);

                    float cost = 0;

                    var draw = rand.NextDouble();
                    if (context.UserLocation == "HealthyTown")
                    {
                        // for HealthyTown, buy burger 1 with probability 0.1, burger 2 with probability 0.15, salad with probability 0.6
                        if ((action == 1 && draw < 0.1) || (action == 2 && draw < 0.15) || (action == 3 && draw < 0.6))
                        {
                            cost = -10;
                        }
                    }
                    else
                    {
                        // for LessHealthyTown, buy burger 1 with probability 0.4, burger 2 with probability 0.6, salad with probability 0.2
                        if ((action == 1 && draw < 0.4) || (action == 2 && draw < 0.6) || (action == 3 && draw < 0.2))
                        {
                            cost = -10;
                        }
                    }
                    var label = new ContextualBanditLabel
                    {
                        Action      = (uint)action,
                        Cost        = cost,
                        Probability = recorder.GetProb(key)
                    };
                    vw.Learn(context, label, index: (int)label.Action - 1);
                }

                // Expected best (1-based) action per location, i.e. the highest purchase probability above.
                var expectedActions = new Dictionary<string, uint>
                {
                    { "HealthyTown", 3 },
                    { "LessHealthyTown", 2 },
                };
                for (int iE = 0; iE < numExamplesPerActions; iE++)
                {
                    foreach (string location in locations)
                    {
                        var context = new FoodContext {
                            Actions = new int[] { 1, 2, 3 }, UserLocation = location
                        };
                        ActionScore[] predicts = vw.Predict(context, VowpalWabbitPredictionType.ActionScore);
                        // ActionScore.Action is compared +1 against the 1-based expected action.
                        Assert.AreEqual(expectedActions[location], predicts[0].Action + 1);
                    }
                }
            }
        }