/// <summary>
/// For each contextual bandit estimator type, runs two identically configured learners over the
/// same generated examples: one predicts and then learns, the other only learns. The predictions of
/// the first must match those of the second, and no single top action may account for 80% or more
/// of the examples.
/// </summary>
public void TestCbProgressiveValidation()
{
    const int numExamples = 1024;
    var cbTypes = new[] { "ips", "dr", "mtr" };

    foreach (var cbType in cbTypes)
    {
        var trainArguments = $"--cb_explore_adf --epsilon 0.1 --bag 3 -q ab --power_t 0 -l 0.1 --cb_type {cbType} --random_seed 50";

        // how often each action ended up first in the learn-only prediction
        var topActionCounts = new int[3];

        using (var predictThenLearn = new VowpalWabbitJson(trainArguments))
        using (var learnOnly = new VowpalWabbitJson(trainArguments))
        {
            foreach (var example in GenerateData(numExamples))
            {
                var json = example.JSON;

                var predicted = predictThenLearn.Predict(json, VowpalWabbitPredictionType.ActionProbabilities);
                var progressive = predictThenLearn.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);
                var reference = learnOnly.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);

                // both outputs of the first learner are compared against the learn-only output
                AreEqual(predicted, reference, cbType);
                AreEqual(progressive, reference, cbType);

                topActionCounts[reference[0].Action]++;
            }
        }

        foreach (var count in topActionCounts)
        {
            Assert.IsTrue(count < numExamples * 0.8, $"Unexpected action distribution: {count}");
        }

        Debug.WriteLine($"cb_types: {cbType} " + string.Join(",", topActionCounts.Select((hits, action) => $"{action}:{hits}")));
    }
}
/// <summary>
/// Replays the given events through a fresh offline learner, downloads the online model, passes both
/// models through a load/save round trip and asserts that the two resulting files are byte-identical.
/// </summary>
/// <param name="message">Prefix of the assertion failure message.</param>
/// <param name="modelId">Value handed to VW through <c>--id</c>.</param>
/// <param name="data">Contexts that are looked up by the ids listed in <paramref name="eventOrder"/>.</param>
/// <param name="eventOrder">Ids of the events to learn from, in replay order.</param>
/// <param name="onlineModelUri">Location the online model is downloaded from.</param>
/// <param name="trainArguments">Optional VW arguments; the instance-level arguments are used when null.</param>
internal void TrainOffline(string message, string modelId, Dictionary<string, Context> data, IEnumerable<string> eventOrder, Uri onlineModelUri, string trainArguments = null)
{
    // allow the caller to override the instance-level training arguments
    var effectiveArguments = trainArguments ?? this.trainArguments;

    // train model offline using trackback
    var offlineSettings = new VowpalWabbitSettings(
        effectiveArguments + $" --id {modelId} --save_resume --preserve_performance_counters -f offline.model");

    using (var offlineLearner = new VowpalWabbitJson(offlineSettings))
    {
        foreach (var eventId in eventOrder)
        {
            // TODO: validate eval output
            offlineLearner.Learn(data[eventId].JSON, VowpalWabbitPredictionType.ActionProbabilities);
        }
    }

    // Round trip of the offline model: the instance is created and immediately disposed; the body is
    // intentionally empty because only the files named in the arguments are of interest.
    using (var offlineRoundTrip = new VowpalWabbit("-i offline.model --save_resume --readable_model offline.model.txt -f offline.reset_perf_counters.model"))
    {
    }

    Blobs.DownloadFile(onlineModelUri, "online.model");

    // Same round trip for the downloaded online model.
    using (var onlineRoundTrip = new VowpalWabbit("-i online.model --save_resume --readable_model online.model.txt -f online.reset_perf_counters.model"))
    {
    }

    // validate that the model is the same
    var offlineBytes = File.ReadAllBytes("offline.reset_perf_counters.model");
    var onlineBytes = File.ReadAllBytes("online.reset_perf_counters.model");

    CollectionAssert.AreEqual(
        offlineBytes,
        onlineBytes,
        $"{message}. Offline and online model differs. Compare online.model.txt with offline.model.txt to compare");
}
/// <summary>
/// Trains a cb_explore_adf model on one repeated labelled example, saves it, and verifies that the
/// learned ranking is reproduced when the saved model is loaded in three ways: directly in test-only
/// mode, as a seed through a shared <c>VowpalWabbitModel</c>, and through a threaded prediction pool.
/// </summary>
public void TestCbAdfExplore()
{
    const string modelFile = "cbadfexplore.model";

    var json = JsonConvert.SerializeObject(new
    {
        U = new { age = "18" },
        _multi = new[]
        {
            new { G = new { _text = "this rocks" }, K = new { constant = 1, doc = "1" } },
            new { G = new { _text = "something NYC" }, K = new { constant = 1, doc = "2" } },
        },
        _label_Action = 2,
        _label_Probability = 0.1,
        _label_Cost = -1,
        _labelIndex = 1
    });

    using (var vw = new VowpalWabbitJson("--cb_explore_adf --bag 4 --epsilon 0.0001 --cb_type mtr --marginal K -q UG -b 26 --power_t 0 --l1 1e-9 -l 4e-3"))
    {
        for (int i = 0; i < 50; i++)
        {
            var pred = vw.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);
            Assert.AreEqual(2, pred.Length);

            // the ranking is only checked once the learner has seen more than 40 repetitions
            if (i > 40)
            {
                AssertActionOneDominates(pred);
            }
        }

        vw.Native.SaveModel(modelFile);
    }

    // Every model stream below is opened in its own using statement so the file handle is released
    // deterministically before the next section reopens the same file.
    using (var modelStream = File.Open(modelFile, FileMode.Open))
    using (var vw = new VowpalWabbitJson(new VowpalWabbitSettings { Arguments = "-t", ModelStream = modelStream }))
    {
        var predObj = vw.Predict(json, VowpalWabbitPredictionType.Dynamic);
        Assert.IsInstanceOfType(predObj, typeof(ActionScore[]));
        AssertActionOneDominates((ActionScore[])predObj);
    }

    using (var modelStream = File.Open(modelFile, FileMode.Open))
    using (var vwModel = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = modelStream }))
    using (var vwSeeded = new VowpalWabbitJson(new VowpalWabbitSettings { Model = vwModel }))
    {
        var pred = vwSeeded.Predict(json, VowpalWabbitPredictionType.ActionProbabilities);
        AssertActionOneDominates(pred);
    }

    using (var modelStream = File.Open(modelFile, FileMode.Open))
    using (var vwModel = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = modelStream }))
    using (var vwPool = new VowpalWabbitJsonThreadedPrediction(vwModel))
    using (var vw = vwPool.GetOrCreate())
    {
        var predObj = vw.Value.Predict(json, VowpalWabbitPredictionType.Dynamic);
        Assert.IsInstanceOfType(predObj, typeof(ActionScore[]));
        AssertActionOneDominates((ActionScore[])predObj);
    }
}

/// <summary>
/// Asserts that action 1 is ranked first with a score above 0.9 and action 0 second with a score below 0.1.
/// </summary>
/// <param name="pred">Ranked action scores returned by a learn or predict call.</param>
private static void AssertActionOneDominates(ActionScore[] pred)
{
    Assert.AreEqual(1, (int)pred[0].Action);
    Assert.IsTrue(pred[0].Score > .9);
    Assert.AreEqual(0, (int)pred[1].Action);
    Assert.IsTrue(pred[1].Score < .1);
}
/// <summary>
/// Command line entry point: feeds the examples found in the input file to a VW instance created
/// from the remaining command line arguments. Any exception is reported on standard error.
/// </summary>
/// <param name="args">Path of the input file, followed by the arguments handed to VW.</param>
public static void Main(string[] args)
{
    // without an input file all that can be done is to print the usage
    if (args.Length == 0)
    {
        var executableName = Path.GetFileName(Assembly.GetExecutingAssembly().Location);
        Console.Error.WriteLine("Usage: {0} <input.json> <vw arg1> <vw arg2> ...", executableName);
        return;
    }

    try
    {
        var inputPath = args[0];
        var vwArguments = string.Join(" ", args.Skip(1));
        var fileMode = DetectedFileMode(inputPath);

        using (var vw = new VowpalWabbitJson(vwArguments))
        {
            if (fileMode == FileMode.JsonArray)
            {
                using (var jsonReader = new JsonTextReader(new StreamReader(inputPath)))
                {
                    // only input whose first token opens an array is processed
                    if (!jsonReader.Read() || jsonReader.TokenType != JsonToken.StartArray)
                    {
                        return;
                    }

                    while (jsonReader.Read())
                    {
                        // capture the token before Learn advances the reader
                        var token = jsonReader.TokenType;

                        if (token == JsonToken.EndArray)
                        {
                            // end reading
                            return;
                        }

                        if (token == JsonToken.StartObject)
                        {
                            vw.Learn(jsonReader);
                        }

                        // every other token (e.g. EndObject) is skipped
                    }
                }
            }
            else if (fileMode == FileMode.JsonNewLine)
            {
                using (var lineReader = new StreamReader(inputPath))
                {
                    string line;
                    while ((line = lineReader.ReadLine()) != null)
                    {
                        // blank lines carry no example
                        if (!string.IsNullOrWhiteSpace(line))
                        {
                            vw.Learn(line);
                        }
                    }
                }
            }
        }
    }
    catch (Exception e)
    {
        Console.Error.WriteLine("Exception: {0}.\n{1}", e.Message, e.StackTrace);
    }
}
/// <summary>
/// End-to-end check of the online trainer: starts the event processor host, sends generated events to
/// the input event hub, waits for a trainer checkpoint, retrains offline in the order listed by the
/// downloaded trackback and asserts that the offline and the online model files are byte-identical.
/// </summary>
public async Task TestAzureTrainer()
{
    var storageConnectionString = GetConfiguration("storageConnectionString");
    var inputEventHubConnectionString = GetConfiguration("inputEventHubConnectionString");
    var evalEventHubConnectionString = GetConfiguration("evalEventHubConnectionString");

    const string trainArguments = "--cb_explore_adf --epsilon 0.2 -q ab";

    // register with AppInsights to collect exceptions
    var exceptions = RegisterAppInsightExceptionHook();

    // cleanup blobs
    var blobs = new ModelBlobs(storageConnectionString);
    await blobs.Cleanup();

    var data = GenerateData(100).ToDictionary(d => d.EventId, d => d);

    // start listening for event hub
    using (var trainProcesserHost = new LearnEventProcessorHost())
    {
        var trainerSettings = new OnlineTrainerSettingsInternal
        {
            CheckpointPolicy = new CountingCheckpointPolicy(data.Count),
            JoinedEventHubConnectionString = inputEventHubConnectionString,
            EvalEventHubConnectionString = evalEventHubConnectionString,
            StorageConnectionString = storageConnectionString,
            Metadata = new OnlineTrainerSettings
            {
                ApplicationID = "vwunittest",
                TrainArguments = trainArguments
            },
            EnableExampleTracing = true,
            EventHubStartDateTimeUtc = DateTime.UtcNow // ignore any events that arrived before this time
        };

        await trainProcesserHost.StartAsync(trainerSettings);

        // send events to event hub
        var eventHubInputClient = EventHubClient.CreateFromConnectionString(inputEventHubConnectionString);
        data.Values.ForEach(context =>
            eventHubInputClient.Send(new EventData(context.JSONAsBytes) { PartitionKey = context.Index.ToString() }));

        // wait for trainer to checkpoint
        await blobs.PollTrainerCheckpoint(exceptions);

        // download & parse trackback file
        var trackback = blobs.DownloadTrackback();
        Assert.AreEqual(data.Count, trackback.EventIds.Count);

        // train model offline using trackback
        var offlineSettings = new VowpalWabbitSettings(
            trainArguments + $" --id {trackback.ModelId} --save_resume --readable_model offline.json.model.txt -f offline.json.model");

        using (var offlineLearner = new VowpalWabbitJson(offlineSettings))
        {
            foreach (var eventId in trackback.EventIds)
            {
                // TODO: validate eval output
                offlineLearner.Learn(data[eventId].JSON, VowpalWabbitPredictionType.ActionProbabilities);
            }

            offlineLearner.Native.SaveModel("offline.json.2.model");
        }

        // download online model
        var onlineModelBlob = new CloudBlob(blobs.ModelBlob.Uri, blobs.BlobClient.Credentials);
        onlineModelBlob.DownloadToFile("online.model", FileMode.Create);

        // validate that the model is the same
        var offlineBytes = File.ReadAllBytes("offline.json.model");
        var onlineBytes = File.ReadAllBytes("online.model");

        CollectionAssert.AreEqual(
            offlineBytes,
            onlineBytes,
            "Offline and online model differs. Run to 'vw -i online.model --readable_model online.model.txt' to compare");
    }
}
/// <summary>
/// Program entry point. Creates a VW instance from all arguments after the first and lets it learn
/// from each example in the file named by the first argument; failures are written to standard error.
/// </summary>
/// <param name="args">Input file path first, VW arguments afterwards.</param>
public static void Main(string[] args)
{
    if (args.Length == 0)
    {
        // no input file given: show how the tool is meant to be invoked
        Console.Error.WriteLine(
            "Usage: {0} <input.json> <vw arg1> <vw arg2> ...",
            Path.GetFileName(Assembly.GetExecutingAssembly().Location));
        return;
    }

    try
    {
        var json = args[0];
        var vwArguments = string.Join(" ", args.Skip(1));
        var fileMode = DetectedFileMode(json);

        using (var vw = new VowpalWabbitJson(vwArguments))
        {
            switch (fileMode)
            {
                case FileMode.JsonArray:
                    using (var reader = new JsonTextReader(new StreamReader(json)))
                    {
                        // stop unless the very first token opens an array
                        var startsWithArray = reader.Read() && reader.TokenType == JsonToken.StartArray;
                        if (!startsWithArray)
                        {
                            return;
                        }

                        while (reader.Read())
                        {
                            if (reader.TokenType == JsonToken.StartObject)
                            {
                                vw.Learn(reader);
                            }
                            else if (reader.TokenType == JsonToken.EndArray)
                            {
                                // end reading
                                return;
                            }

                            // anything else, such as EndObject, is skipped
                        }
                    }

                    break;

                case FileMode.JsonNewLine:
                    using (var reader = new StreamReader(json))
                    {
                        for (var line = reader.ReadLine(); line != null; line = reader.ReadLine())
                        {
                            // lines without content are ignored
                            if (!string.IsNullOrWhiteSpace(line))
                            {
                                vw.Learn(line);
                            }
                        }
                    }

                    break;
            }
        }
    }
    catch (Exception e)
    {
        Console.Error.WriteLine("Exception: {0}.\n{1}", e.Message, e.StackTrace);
    }
}