/// <summary>
/// Learns from and predicts on a context whose <c>Vector</c> and single ADF <c>ADFID</c> are null,
/// first with <c>ID = 25</c> and then again after clearing <c>ID</c> to null.
/// Each round asserts that the prediction contains exactly one entry.
/// </summary>
public void TestNull4()
{
    using (var vw = new VowpalWabbit<Context, ADF>("--cb_adf --rank_all --interact ab"))
    {
        var ctx = new Context
        {
            ID = 25,
            Vector = null,
            ActionDependentFeatures = new[] { new ADF { ADFID = null } }.ToList()
        };
        var label = new ContextualBanditLabel { Action = 1, Cost = 1, Probability = 0.2f };

        // Pass 0 runs with ID = 25; pass 1 repeats the same round trip with ID cleared.
        for (int pass = 0; pass < 2; pass++)
        {
            if (pass == 1)
            {
                ctx.ID = null;
            }

            vw.Learn(ctx, ctx.ActionDependentFeatures, 0, label);
            var ranking = vw.Predict(ctx, ctx.ActionDependentFeatures);
            Assert.AreEqual(1, ranking.Length);
        }
    }
}
/// <summary>
/// Evaluates the policy that picks <paramref name="learnedAction"/> alongside
/// <paramref name="numActions"/> constant policies, all scored against <paramref name="label"/>.
/// </summary>
/// <param name="learnedAction">The action chosen by the learned policy.</param>
/// <param name="numActions">The number of constant policies to evaluate.</param>
/// <param name="label">The observed contextual bandit label to evaluate against.</param>
/// <returns>
/// A <see cref="PoliciesPerformance"/> constructed from the scalar predictions returned by
/// learning on a <c>LearnedVsConstantPolicy</c> example with <paramref name="label"/>.
/// </returns>
public PoliciesPerformance Evaluate(uint learnedAction, int numActions, ContextualBanditLabel label)
{
    var policies = new LearnedVsConstantPolicy(learnedAction, numActions);
    var scalars = this.vw.Learn(policies, label, VowpalWabbitPredictionType.Scalars);
    return new PoliciesPerformance(scalars);
}
/// <summary>
/// Builds an <see cref="EvalData"/> record for one policy. Its JSON payload holds the policy
/// name and the unbiased cost of <paramref name="actionTaken"/> given the logged <paramref name="label"/>.
/// </summary>
/// <param name="label">Logged label supplying the action, cost and probability.</param>
/// <param name="policyName">Name stored both on the record and inside the JSON payload.</param>
/// <param name="actionTaken">The action the evaluated policy would have taken.</param>
/// <returns>The populated <see cref="EvalData"/>.</returns>
private static EvalData Create(ContextualBanditLabel label, string policyName, uint actionTaken)
{
    var unbiasedCost = VowpalWabbitContextualBanditUtil.GetUnbiasedCost(
        label.Action,
        actionTaken,
        label.Cost,
        label.Probability);

    // Property names and order define the serialized JSON keys.
    var payload = new { name = policyName, cost = unbiasedCost };

    return new EvalData
    {
        PolicyName = policyName,
        JSON = JsonConvert.SerializeObject(payload)
    };
}
/// <summary>
/// When <c>sinceLastUpdate</c> has reached <c>ModelUpdateInterval</c>, trains on every completed
/// event flushed from <c>log</c>, saves the resulting model to a stream, passes it to
/// <c>UpdateModel</c> and resets <c>sinceLastUpdate</c> to 0.
/// Returns without doing anything if <c>vwDisposed</c> is set once the lock is held.
/// </summary>
private void updateModelMaybe()
{
    if (sinceLastUpdate >= ModelUpdateInterval)
    {
        // Locking at this level ensures a batch of events is processed completely before
        // the next batch (finer locking would allow interleaving, violating time order).
        lock (this.vwLock)
        {
            // Exit gracefully if the object has been disposed.
            if (vwDisposed)
            {
                return;
            }

            // String contexts are raw JSON and go through the separate JSON learner.
            bool isJsonContext = typeof(TContext) == typeof(string);

            foreach (var dp in log.FlushCompleteEvents())
            {
                uint action = (uint)((int[])dp.InteractData.Value)[0];

                // NOTE(review): assumes the top-slot explorer state stores the chosen (top-slot)
                // action's probability at index 0 — confirm against GenericTopSlotExplorerState.
                var label = new ContextualBanditLabel(
                    action,
                    -dp.Reward,
                    ((GenericTopSlotExplorerState)dp.InteractData.ExplorerState).Probabilities[0]);

                if (isJsonContext)
                {
                    // The C# interface does not currently handle the CB label for string (JSON)
                    // contexts, so the label fields are spliced into the JSON manually.
                    string context = InsertContextualBanditLabel((string)dp.InteractData.Context, label);
                    using (var vwSerializer = new VowpalWabbitJsonSerializer(vwJson.Native))
                    using (VowpalWabbitExampleCollection vwExample = vwSerializer.ParseAndCreate(context))
                    {
                        vwExample.Learn();
                    }
                }
                else
                {
                    vw.Learn((TContext)dp.InteractData.Context, label, index: (int)label.Action - 1);
                }
            }

            using (MemoryStream currModel = new MemoryStream())
            {
                VowpalWabbit vwNative = isJsonContext ? vwJson.Native : vw.Native;
                vwNative.SaveModel(currModel);
                currModel.Position = 0;
                this.UpdateModel(currModel);
                sinceLastUpdate = 0;
            }
        }
    }
}

/// <summary>
/// Returns <paramref name="jsonContext"/> with the fields <c>_label_Action</c>, <c>_label_Cost</c>,
/// <c>_label_Probability</c> and <c>_labelIndex</c> (= Action - 1) inserted as the first members
/// of the JSON object, right after its opening brace.
/// </summary>
/// <exception cref="System.InvalidOperationException">
/// <paramref name="jsonContext"/> is null or contains no opening brace.
/// </exception>
private static string InsertContextualBanditLabel(string jsonContext, ContextualBanditLabel label)
{
    int openBrace = jsonContext == null ? -1 : jsonContext.IndexOf('{');
    if (openBrace < 0)
    {
        throw new System.InvalidOperationException(
            "String contexts must be JSON objects to receive a contextual bandit label.");
    }

    string labelFields = string.Format(
        CultureInfo.InvariantCulture,
        "\"_label_Action\":{0},\"_label_Cost\":{1},\"_label_Probability\":{2},\"_labelIndex\":{3}",
        label.Action,
        label.Cost,
        label.Probability,
        label.Action - 1);

    int insertAt = openBrace + 1;

    // Only separate with a comma when existing members follow; otherwise "{}" would
    // become "{...,}", which is not valid JSON.
    string remainder = jsonContext.Substring(insertAt).TrimStart();
    string separator = remainder.Length > 0 && remainder[0] == '}' ? string.Empty : ",";

    return jsonContext.Insert(insertAt, labelFields + separator);
}
/// <summary>
/// Once <c>sinceLastUpdate</c> has reached <c>ModelUpdateInterval</c>, learns from every completed
/// event flushed from <c>log</c>. It then saves <c>vw</c>'s native model into a memory stream,
/// hands that stream to <c>UpdateModel</c> and resets <c>sinceLastUpdate</c> to 0.
/// Does nothing before the interval is reached.
/// </summary>
private void updateModelMaybe()
{
    if (sinceLastUpdate < ModelUpdateInterval)
    {
        return;
    }

    foreach (var completed in log.FlushCompleteEvents())
    {
        var interaction = completed.InteractData;
        uint chosenAction = (uint)((int[])interaction.Value)[0];
        var cost = -completed.Reward;

        // NOTE(review): assumes Probabilities is indexed by (1-based action id - 1) — confirm
        // against the explorer state's layout.
        var probability = ((GenericTopSlotExplorerState)interaction.ExplorerState).Probabilities[chosenAction - 1];

        var label = new ContextualBanditLabel(chosenAction, cost, probability);
        vw.Learn((TContext)interaction.Context, label, index: (int)label.Action - 1);
    }

    using (var modelStream = new MemoryStream())
    {
        vw.Native.SaveModel(modelStream);
        modelStream.Position = 0;
        this.UpdateModel(modelStream);
        sinceLastUpdate = 0;
    }
}
/// <summary>
/// Calls Learn four times with an interaction ("--interact ac") while the context's
/// <c>Vector</c> and <c>VectorC</c> are toggled between populated and null.
/// The test has no assertions; it passes as long as no call throws.
/// </summary>
public void TestNull3()
{
    using (var vw = new VowpalWabbit<Context, ADF>("--cb_adf --rank_all --interact ac"))
    {
        var context = new Context
        {
            ID = 25,
            Vector = new float[] { 3 },
            VectorC = new float[] { 2, 2, 3 },
            ActionDependentFeatures = new[] { new ADF { ADFID = "23" } }.ToList()
        };
        var label = new ContextualBanditLabel { Action = 1, Cost = 1, Probability = 0.2f };

        // Step 1: Vector and VectorC both populated.
        vw.Learn(context, context.ActionDependentFeatures, 0, label);

        // Step 2: Vector null, VectorC still populated.
        context.Vector = null;
        vw.Learn(context, context.ActionDependentFeatures, 0, label);

        // Step 3: Vector populated again, VectorC null.
        context.Vector = new float[] { 2 };
        context.VectorC = null;
        vw.Learn(context, context.ActionDependentFeatures, 0, label);

        // Step 4: both null.
        context.Vector = null;
        vw.Learn(context, context.ActionDependentFeatures, 0, label);
    }
}
/// <summary>
/// Trains a <c>--cb_explore_adf</c> model on simulated purchases and checks the learned top action
/// per location. Training plays actions round-robin with logged probability 1/numActions, and a
/// fixed-seed RNG chooses the location and the purchase outcome. The expected top action is 3 for
/// "HealthyTown" and 2 for "LessHealthyTown".
/// </summary>
public void TrainNewVWModelWithMultiActionJsonDirectData()
{
    // User locations.
    string[] locations = new string[] { "HealthyTown", "LessHealthyTown" };
    int numLocations = locations.Length;

    // Food items; actions are numbered 1..numActions.
    int numActions = 3;
    int numExamplesPerActions = 10000;
    var recorder = new FoodRecorder();

    using (var vw = new VowpalWabbit<FoodContext>(
        new VowpalWabbitSettings("--cb_explore_adf --epsilon 0.2 --cb_type dr -q ::")
        {
            TypeInspector = JsonTypeInspector.Default,
            EnableStringExampleGeneration = true,
            EnableStringFloatCompact = true
        }))
    {
        // Learn. A fixed seed keeps the location/outcome sequence reproducible.
        var rand = new Random(0);
        for (int iE = 0; iE < numExamplesPerActions * numLocations; iE++)
        {
            int iL = rand.Next(0, numLocations);
            var context = new FoodContext { Actions = new int[] { 1, 2, 3 }, UserLocation = locations[iL] };

            string key = "fooditem " + Guid.NewGuid().ToString();
            int action = iE % numActions + 1;
            recorder.Record(null, null, new EpsilonGreedyState { Probability = 1.0f / numActions }, null, key);

            // A purchase is rewarded with cost -10; no purchase costs 0.
            float cost = 0;
            var draw = rand.NextDouble();
            if (context.UserLocation == "HealthyTown")
            {
                // Healthy town: buy burger 1 with probability 0.1, burger 2 with probability 0.15, salad with probability 0.6.
                if ((action == 1 && draw < 0.1) || (action == 2 && draw < 0.15) || (action == 3 && draw < 0.6))
                {
                    cost = -10;
                }
            }
            else
            {
                // Less healthy town: buy burger 1 with probability 0.4, burger 2 with probability 0.6, salad with probability 0.2.
                if ((action == 1 && draw < 0.4) || (action == 2 && draw < 0.6) || (action == 3 && draw < 0.2))
                {
                    cost = -10;
                }
            }

            var label = new ContextualBanditLabel
            {
                Action = (uint)action,
                Cost = cost,
                Probability = recorder.GetProb(key)
            };
            vw.Learn(context, label, index: (int)label.Action - 1);
        }

        // Expected best (1-based) action per location.
        var expectedActions = new Dictionary<string, uint>
        {
            { "HealthyTown", 3 },
            { "LessHealthyTown", 2 }
        };

        // Predict: ActionScore.Action is 0-based, hence the +1.
        for (int iE = 0; iE < numExamplesPerActions; iE++)
        {
            foreach (string location in locations)
            {
                var context = new FoodContext { Actions = new int[] { 1, 2, 3 }, UserLocation = location };
                ActionScore[] predicts = vw.Predict(context, VowpalWabbitPredictionType.ActionScore);
                Assert.AreEqual(expectedActions[location], predicts[0].Action + 1);
            }
        }
    }
}