Example #1
0
        /// <summary>
        /// For each --cb_type setting ("ips", "dr", "mtr"), runs two identically configured learners over
        /// the same generated data. Learner #1 calls Predict and then Learn on every example; learner #2
        /// calls Learn only. The test asserts that both of learner #1's results equal learner #2's result,
        /// and that no single action is ranked first in 80% or more of the examples.
        /// </summary>
        public void TestCbProgressiveValidation()
        {
            const int exampleCount = 1024;
            string[] cbTypes = { "ips", "dr", "mtr" };

            foreach (var cbType in cbTypes)
            {
                var trainArguments = $"--cb_explore_adf --epsilon 0.1 --bag 3 -q ab --power_t 0 -l 0.1 --cb_type {cbType} --random_seed 50";
                var topActionCounts = new int[3];

                using (var predictThenLearn = new VowpalWabbitJson(trainArguments))
                using (var learnOnly = new VowpalWabbitJson(trainArguments))
                {
                    foreach (var example in GenerateData(exampleCount))
                    {
                        var json = example.JSON;

                        // learner #1: explicit Predict, then Learn on the same example
                        var predicted = predictThenLearn.Predict(json, VowpalWabbitPredictionType.ActionProbabilities);
                        var learned = predictThenLearn.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);

                        // learner #2: Learn only; its returned prediction must match both results above
                        var progressive = learnOnly.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);

                        AreEqual(predicted, progressive, cbType);
                        AreEqual(learned, progressive, cbType);

                        // tally which action was ranked first
                        topActionCounts[progressive[0].Action]++;
                    }
                }

                // no action may win the top slot in 80% or more of the examples
                for (var slot = 0; slot < topActionCounts.Length; slot++)
                {
                    var count = topActionCounts[slot];
                    Assert.IsTrue(count < exampleCount * 0.8, $"Unexpected action distribution: {count}");
                }

                var histogram = string.Join(",", topActionCounts.Select((hits, idx) => $"{idx}:{hits}"));
                Debug.WriteLine($"cb_types: {cbType} " + histogram);
            }
        }
Example #2
0
            /// <summary>
            /// Replays the events listed in <paramref name="eventOrder"/> through a fresh JSON learner
            /// (written to offline.model). It then re-saves that model and the downloaded online model
            /// with the same VW options, and asserts the two re-saved files are byte-identical.
            /// </summary>
            /// <param name="message">Prefix for the assertion failure message.</param>
            /// <param name="modelId">Value passed to VW via --id.</param>
            /// <param name="data">Events keyed by id; every id in <paramref name="eventOrder"/> must be present.</param>
            /// <param name="eventOrder">Event ids in the order they are learned.</param>
            /// <param name="onlineModelUri">Blob location of the online model to compare against.</param>
            /// <param name="trainArguments">Optional override; when null, this instance's trainArguments are used.</param>
            internal void TrainOffline(string message, string modelId, Dictionary <string, Context> data, IEnumerable <string> eventOrder, Uri onlineModelUri, string trainArguments = null)
            {
                var effectiveArguments = trainArguments ?? this.trainArguments;

                // train model offline using trackback order
                var offlineSettings = new VowpalWabbitSettings(effectiveArguments + $" --id {modelId} --save_resume --preserve_performance_counters -f offline.model");
                using (var offlineLearner = new VowpalWabbitJson(offlineSettings))
                {
                    foreach (var eventId in eventOrder)
                    {
                        // TODO: validate eval output
                        offlineLearner.Learn(data[eventId].JSON, VowpalWabbitPredictionType.ActionProbabilities);
                    }
                }

                // Load each model with -i and write it back out with identical options, so both files are
                // produced the same way. No examples are processed: the work happens in construction/disposal.
                new VowpalWabbit("-i offline.model --save_resume --readable_model offline.model.txt -f offline.reset_perf_counters.model").Dispose();

                Blobs.DownloadFile(onlineModelUri, "online.model");
                new VowpalWabbit("-i online.model --save_resume --readable_model online.model.txt -f online.reset_perf_counters.model").Dispose();

                // validate that the model is the same
                var offlineBytes = File.ReadAllBytes("offline.reset_perf_counters.model");
                var onlineBytes = File.ReadAllBytes("online.reset_perf_counters.model");
                CollectionAssert.AreEqual(
                    offlineBytes,
                    onlineBytes,
                    $"{message}. Offline and online model differs. Compare online.model.txt with offline.model.txt to compare");
            }
Example #3
0
        /// <summary>
        /// Trains a --cb_explore_adf model on one repeated labelled example and saves it to
        /// cbadfexplore.model. It then checks that the same ranking (action 1 first with score &gt; .9,
        /// action 0 second with score &lt; .1) is produced in three ways:
        /// loading the saved model with -t, seeding a VowpalWabbitJson from a shared VowpalWabbitModel,
        /// and predicting through VowpalWabbitJsonThreadedPrediction.
        /// </summary>
        public void TestCbAdfExplore()
        {
            var json = JsonConvert.SerializeObject(new
            {
                U      = new { age = "18" },
                _multi = new[]
                {
                    new
                    {
                        G = new { _text = "this rocks" },
                        K = new { constant = 1, doc = "1" }
                    },
                    new
                    {
                        G = new { _text = "something NYC" },
                        K = new { constant = 1, doc = "2" }
                    },
                },
                _label_Action      = 2,
                _label_Probability = 0.1,
                _label_Cost        = -1,
                _labelIndex        = 1
            });

            // Shared expectation for every prediction path below. The length check turns a wrong-sized
            // result into an assertion failure instead of an IndexOutOfRangeException.
            System.Action<ActionScore[]> assertExpectedRanking = scores =>
            {
                Assert.AreEqual(2, scores.Length);

                Assert.AreEqual(1, (int)scores[0].Action);
                Assert.IsTrue(scores[0].Score > .9);

                Assert.AreEqual(0, (int)scores[1].Action);
                Assert.IsTrue(scores[1].Score < .1);
            };

            using (var vw = new VowpalWabbitJson("--cb_explore_adf --bag 4 --epsilon 0.0001 --cb_type mtr --marginal K -q UG -b 26 --power_t 0 --l1 1e-9 -l 4e-3"))
            {
                for (int i = 0; i < 50; i++)
                {
                    var pred = vw.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);
                    Assert.AreEqual(2, pred.Length);

                    // only the last iterations are checked, giving the model time to converge
                    if (i > 40)
                    {
                        assertExpectedRanking(pred);
                    }
                }

                vw.Native.SaveModel("cbadfexplore.model");
            }

            // The model file is only read back, so open it read-only with shared read access.
            // File.Open(path, FileMode.Open) would request ReadWrite access with FileShare.None.
            // The test owns each stream and disposes it after VW is done; disposing a FileStream
            // twice is harmless if VW also closes it.
            using (var modelStream = File.OpenRead("cbadfexplore.model"))
            using (var vw = new VowpalWabbitJson(new VowpalWabbitSettings {
                Arguments = "-t", ModelStream = modelStream
            }))
            {
                var predObj = vw.Predict(json, VowpalWabbitPredictionType.Dynamic);
                Assert.IsInstanceOfType(predObj, typeof(ActionScore[]));
                assertExpectedRanking((ActionScore[])predObj);
            }

            using (var modelStream = File.OpenRead("cbadfexplore.model"))
            using (var vwModel = new VowpalWabbitModel(new VowpalWabbitSettings {
                ModelStream = modelStream
            }))
            using (var vwSeeded = new VowpalWabbitJson(new VowpalWabbitSettings {
                Model = vwModel
            }))
            {
                assertExpectedRanking(vwSeeded.Predict(json, VowpalWabbitPredictionType.ActionProbabilities));
            }

            using (var modelStream = File.OpenRead("cbadfexplore.model"))
            using (var vwModel = new VowpalWabbitModel(new VowpalWabbitSettings {
                ModelStream = modelStream
            }))
            using (var vwPool = new VowpalWabbitJsonThreadedPrediction(vwModel))
            using (var vw = vwPool.GetOrCreate())
            {
                var predObj = vw.Value.Predict(json, VowpalWabbitPredictionType.Dynamic);
                Assert.IsInstanceOfType(predObj, typeof(ActionScore[]));
                assertExpectedRanking((ActionScore[])predObj);
            }
        }
Example #4
0
        /// <summary>
        /// Command-line entry point: learns every example in a JSON input file with VowpalWabbit.
        /// </summary>
        /// <param name="args">
        /// args[0] is the input file path. All remaining arguments are joined with spaces and passed to VW.
        /// The input layout (one JSON array of examples, or one JSON example per line) is chosen by
        /// <c>DetectedFileMode</c>.
        /// </param>
        /// <remarks>
        /// A missing input argument or any exception sets <c>Environment.ExitCode</c> to 1, so scripts
        /// invoking this tool can detect the failure. Previously the process always exited with 0.
        /// </remarks>
        public static void Main(string[] args)
        {
            // The usage text expects args[0] to be a .json file; this method does not check the extension.
            if (args.Length == 0)
            {
                Console.Error.WriteLine(
                    "Usage: {0} <input.json> <vw arg1> <vw arg2> ...",
                    Path.GetFileName(Assembly.GetExecutingAssembly().Location));
                Environment.ExitCode = 1;
                return;
            }

            try
            {
                var json        = args[0];
                var vwArguments = string.Join(" ", args.Skip(1));

                var fileMode = DetectedFileMode(json);

                using (var vw = new VowpalWabbitJson(vwArguments))
                {
                    switch (fileMode)
                    {
                    case FileMode.JsonArray:
                        using (var reader = new JsonTextReader(new StreamReader(json)))
                        {
                            // empty input: nothing to learn
                            if (!reader.Read())
                            {
                                return;
                            }

                            // input does not start with an array: nothing to learn
                            if (reader.TokenType != JsonToken.StartArray)
                            {
                                return;
                            }

                            while (reader.Read())
                            {
                                switch (reader.TokenType)
                                {
                                case JsonToken.StartObject:
                                    // presumably Learn consumes this object from the reader; EndObject tokens are skipped below
                                    vw.Learn(reader);
                                    break;

                                case JsonToken.EndObject:
                                    // skip
                                    break;

                                case JsonToken.EndArray:
                                    // end reading
                                    return;
                                }
                            }
                        }
                        break;

                    case FileMode.JsonNewLine:
                        using (var reader = new StreamReader(json))
                        {
                            string line;

                            while ((line = reader.ReadLine()) != null)
                            {
                                // blank lines are allowed between examples
                                if (string.IsNullOrWhiteSpace(line))
                                {
                                    continue;
                                }

                                vw.Learn(line);
                            }
                        }
                        break;
                    }
                }
            }
            catch (Exception e)
            {
                Console.Error.WriteLine("Exception: {0}.\n{1}", e.Message, e.StackTrace);
                Environment.ExitCode = 1;
            }
        }
Example #5
0
        /// <summary>
        /// End-to-end test of the Azure online trainer. It:
        /// starts the trainer host, sends generated events to the input Event Hub, waits for a trainer
        /// checkpoint, replays the same events offline in the order recorded in the downloaded trackback,
        /// and asserts that the offline model file is byte-identical to the online model blob.
        /// </summary>
        public async Task TestAzureTrainer()
        {
            var storageConnectionString       = GetConfiguration("storageConnectionString");
            var inputEventHubConnectionString = GetConfiguration("inputEventHubConnectionString");
            var evalEventHubConnectionString  = GetConfiguration("evalEventHubConnectionString");

            var trainArguments = "--cb_explore_adf --epsilon 0.2 -q ab";

            // register with AppInsights to collect exceptions
            var exceptions = RegisterAppInsightExceptionHook();

            // cleanup blobs
            var blobs = new ModelBlobs(storageConnectionString);
            await blobs.Cleanup();

            var data = GenerateData(100).ToDictionary(d => d.EventId, d => d);

            // start listening for event hub
            using (var trainProcessorHost = new LearnEventProcessorHost())
            {
                await trainProcessorHost.StartAsync(new OnlineTrainerSettingsInternal
                {
                    CheckpointPolicy = new CountingCheckpointPolicy(data.Count),
                    JoinedEventHubConnectionString = inputEventHubConnectionString,
                    EvalEventHubConnectionString   = evalEventHubConnectionString,
                    StorageConnectionString        = storageConnectionString,
                    Metadata = new OnlineTrainerSettings
                    {
                        ApplicationID  = "vwunittest",
                        TrainArguments = trainArguments
                    },
                    EnableExampleTracing     = true,
                    EventHubStartDateTimeUtc = DateTime.UtcNow // ignore any events that arrived before this time
                });

                // Send events to the input event hub. Each send is awaited instead of blocking inside this
                // async method. The client was previously never closed; it is now always closed afterwards.
                var eventHubInputClient = EventHubClient.CreateFromConnectionString(inputEventHubConnectionString);
                try
                {
                    foreach (var context in data.Values)
                    {
                        await eventHubInputClient.SendAsync(new EventData(context.JSONAsBytes)
                        {
                            PartitionKey = context.Index.ToString()
                        });
                    }
                }
                finally
                {
                    await eventHubInputClient.CloseAsync();
                }

                // wait for trainer to checkpoint
                await blobs.PollTrainerCheckpoint(exceptions);

                // download & parse trackback file
                var trackback = blobs.DownloadTrackback();
                Assert.AreEqual(data.Count, trackback.EventIds.Count);

                // train model offline using trackback
                var settings = new VowpalWabbitSettings(trainArguments + $" --id {trackback.ModelId} --save_resume --readable_model offline.json.model.txt -f offline.json.model");
                using (var vw = new VowpalWabbitJson(settings))
                {
                    foreach (var id in trackback.EventIds)
                    {
                        // TODO: validate eval output
                        vw.Learn(data[id].JSON, VowpalWabbitPredictionType.ActionProbabilities);
                    }

                    vw.Native.SaveModel("offline.json.2.model");
                }

                // download online model
                new CloudBlob(blobs.ModelBlob.Uri, blobs.BlobClient.Credentials).DownloadToFile("online.model", FileMode.Create);

                // validate that the model is the same
                CollectionAssert.AreEqual(
                    File.ReadAllBytes("offline.json.model"),
                    File.ReadAllBytes("online.model"),
                    "Offline and online model differs. Run to 'vw -i online.model --readable_model online.model.txt' to compare");
            }
        }
Example #6
0
        /// <summary>
        /// Trains a --cb_explore_adf model on one repeated labelled example and saves it to
        /// cbadfexplore.model. It then checks that the same ranking (action 1 first with score &gt; .9,
        /// action 0 second with score &lt; .1) is produced in three ways:
        /// loading the saved model with -t, seeding a VowpalWabbitJson from a shared VowpalWabbitModel,
        /// and predicting through VowpalWabbitJsonThreadedPrediction.
        /// </summary>
        public void TestCbAdfExplore()
        {
            var json = JsonConvert.SerializeObject(new
            {
                U = new { age = "18" },
                _multi = new[]
                {
                    new
                    {
                        G = new { _text = "this rocks" },
                        K = new { constant = 1, doc = "1" }
                    },
                    new
                    {
                        G = new { _text = "something NYC" },
                        K = new { constant = 1, doc = "2" }
                    },
                },
                _label_Action = 2,
                _label_Probability = 0.1,
                _label_Cost = -1,
                _labelIndex = 1
            });

            // Shared expectation for every prediction path below. The length check turns a wrong-sized
            // result into an assertion failure instead of an IndexOutOfRangeException.
            System.Action<ActionScore[]> assertExpectedRanking = scores =>
            {
                Assert.AreEqual(2, scores.Length);

                Assert.AreEqual(1, (int)scores[0].Action);
                Assert.IsTrue(scores[0].Score > .9);

                Assert.AreEqual(0, (int)scores[1].Action);
                Assert.IsTrue(scores[1].Score < .1);
            };

            using (var vw = new VowpalWabbitJson("--cb_explore_adf --bag 4 --epsilon 0.0001 --cb_type mtr --marginal K -q UG -b 26 --power_t 0 --l1 1e-9 -l 4e-3"))
            {
                for (int i = 0; i < 50; i++)
                {
                    var pred = vw.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);
                    Assert.AreEqual(2, pred.Length);

                    // only the last iterations are checked, giving the model time to converge
                    if (i > 40)
                    {
                        assertExpectedRanking(pred);
                    }
                }

                vw.Native.SaveModel("cbadfexplore.model");
            }

            // The model file is only read back, so open it read-only with shared read access.
            // File.Open(path, FileMode.Open) would request ReadWrite access with FileShare.None.
            // The test owns each stream and disposes it after VW is done; disposing a FileStream
            // twice is harmless if VW also closes it.
            using (var modelStream = File.OpenRead("cbadfexplore.model"))
            using (var vw = new VowpalWabbitJson(new VowpalWabbitSettings { Arguments = "-t", ModelStream = modelStream }))
            {
                var predObj = vw.Predict(json, VowpalWabbitPredictionType.Dynamic);
                Assert.IsInstanceOfType(predObj, typeof(ActionScore[]));
                assertExpectedRanking((ActionScore[])predObj);
            }

            using (var modelStream = File.OpenRead("cbadfexplore.model"))
            using (var vwModel = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = modelStream }))
            using (var vwSeeded = new VowpalWabbitJson(new VowpalWabbitSettings { Model = vwModel }))
            {
                assertExpectedRanking(vwSeeded.Predict(json, VowpalWabbitPredictionType.ActionProbabilities));
            }

            using (var modelStream = File.OpenRead("cbadfexplore.model"))
            using (var vwModel = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = modelStream }))
            using (var vwPool = new VowpalWabbitJsonThreadedPrediction(vwModel))
            using (var vw = vwPool.GetOrCreate())
            {
                var predObj = vw.Value.Predict(json, VowpalWabbitPredictionType.Dynamic);
                Assert.IsInstanceOfType(predObj, typeof(ActionScore[]));
                assertExpectedRanking((ActionScore[])predObj);
            }
        }
Example #7
0
        /// <summary>
        /// Command-line entry point: learns every example in a JSON input file with VowpalWabbit.
        /// </summary>
        /// <param name="args">
        /// args[0] is the input file path. All remaining arguments are joined with spaces and passed to VW.
        /// The input layout (one JSON array of examples, or one JSON example per line) is chosen by
        /// <c>DetectedFileMode</c>.
        /// </param>
        /// <remarks>
        /// A missing input argument or any exception sets <c>Environment.ExitCode</c> to 1, so scripts
        /// invoking this tool can detect the failure. Previously the process always exited with 0.
        /// </remarks>
        public static void Main(string[] args)
        {
            // The usage text expects args[0] to be a .json file; this method does not check the extension.
            if (args.Length == 0)
            {
                Console.Error.WriteLine(
                    "Usage: {0} <input.json> <vw arg1> <vw arg2> ...",
                    Path.GetFileName(Assembly.GetExecutingAssembly().Location));
                Environment.ExitCode = 1;
                return;
            }

            try
            {
                var json = args[0];
                var vwArguments = string.Join(" ", args.Skip(1));

                var fileMode = DetectedFileMode(json);

                using (var vw = new VowpalWabbitJson(vwArguments))
                {
                    switch (fileMode)
                    {
                        case FileMode.JsonArray:
                            using (var reader = new JsonTextReader(new StreamReader(json)))
                            {
                                // empty input: nothing to learn
                                if (!reader.Read())
                                {
                                    return;
                                }

                                // input does not start with an array: nothing to learn
                                if (reader.TokenType != JsonToken.StartArray)
                                {
                                    return;
                                }

                                while (reader.Read())
                                {
                                    switch (reader.TokenType)
                                    {
                                        case JsonToken.StartObject:
                                            // presumably Learn consumes this object from the reader; EndObject tokens are skipped below
                                            vw.Learn(reader);
                                            break;
                                        case JsonToken.EndObject:
                                            // skip
                                            break;
                                        case JsonToken.EndArray:
                                            // end reading
                                            return;
                                    }
                                }
                            }
                            break;
                        case FileMode.JsonNewLine:
                            using (var reader = new StreamReader(json))
                            {
                                string line;

                                while ((line = reader.ReadLine()) != null)
                                {
                                    // blank lines are allowed between examples
                                    if (string.IsNullOrWhiteSpace(line))
                                    {
                                        continue;
                                    }

                                    vw.Learn(line);
                                }
                            }
                            break;
                    }
                }
            }
            catch (Exception e)
            {
                Console.Error.WriteLine("Exception: {0}.\n{1}", e.Message, e.StackTrace);
                Environment.ExitCode = 1;
            }
        }