Example #1
0
        /// <summary>
        /// Asserts that the JSON read from <paramref name="jsonReader"/> yields the same examples as the
        /// VW text-format <paramref name="lines"/>. Each JSON example is compared both directly and after
        /// re-parsing its generated VW string as text.
        /// </summary>
        /// <param name="lines">Expected examples in VW text format, one per entry; a shared example (if the JSON has one) must come first.</param>
        /// <param name="jsonReader">Reader over the JSON to parse; it must produce a multi-line example collection.</param>
        /// <param name="labelComparator">Optional comparator forwarded to <c>Diff</c> for label comparison.</param>
        /// <param name="label">Optional label forwarded to <c>ParseAndCreate</c>.</param>
        /// <param name="index">Optional index forwarded to <c>ParseAndCreate</c>.</param>
        /// <param name="extension">Optional extension registered on the JSON serializer before parsing.</param>
        public void Validate(string[] lines, JsonReader jsonReader, IVowpalWabbitLabelComparator labelComparator = null, ILabel label = null, int? index = null, VowpalWabbitJsonExtension extension = null)
        {
            // Array.Length rather than the LINQ Count() extension: no enumeration needed for an array.
            VowpalWabbitExample[] strExamples = new VowpalWabbitExample[lines.Length];

            try
            {
                for (int i = 0; i < lines.Length; i++)
                {
                    strExamples[i] = this.vw.ParseLine(lines[i]);
                }

                using (var jsonSerializer = new VowpalWabbitJsonSerializer(this.vw))
                {
                    if (extension != null)
                    {
                        // extensions are not supported with native JSON parsing
                        jsonSerializer.RegisterExtension(extension);
                    }

                    // Bind the parsed collection before type-checking it, so it is disposed even if the JSON
                    // unexpectedly produces something other than a multi-line collection. A cast inside the
                    // using header would throw before the using took ownership, leaking the object.
                    using (var parsed = jsonSerializer.ParseAndCreate(jsonReader, label, index))
                    {
                        var jsonExample = parsed as VowpalWabbitMultiLineExampleCollection;
                        Assert.IsNotNull(jsonExample, "Expected JSON to produce a VowpalWabbitMultiLineExampleCollection");

                        // Flatten into the same order as 'lines': shared example first, then the action examples.
                        var jsonExamples = new List<VowpalWabbitExample>();

                        if (jsonExample.SharedExample != null)
                        {
                            jsonExamples.Add(jsonExample.SharedExample);
                        }

                        jsonExamples.AddRange(jsonExample.Examples);

                        Assert.AreEqual(strExamples.Length, jsonExamples.Count);

                        for (int i = 0; i < strExamples.Length; i++)
                        {
                            using (var strJsonExample = this.vw.ParseLine(jsonExamples[i].VowpalWabbitString))
                            {
                                // Text example vs. JSON example.
                                var diff = strExamples[i].Diff(this.vw, jsonExamples[i], labelComparator);
                                Assert.IsNull(diff, diff + " generated string: '" + jsonExamples[i].VowpalWabbitString + "'");

                                // Text example vs. the JSON example's generated VW string re-parsed as text.
                                diff = strExamples[i].Diff(this.vw, strJsonExample, labelComparator);
                                Assert.IsNull(diff, diff);
                            }
                        }
                    }
                }
            }
            finally
            {
                // ParseLine may have thrown part-way through, leaving trailing slots null.
                foreach (var ex in strExamples)
                {
                    if (ex != null)
                    {
                        ex.Dispose();
                    }
                }
            }
        }
Example #2
0
        public void TestJsonLabelExtraction()
        {
            using (var vw = new VowpalWabbit("--cb_adf --rank_all"))
            {
                // Single-line example: a custom extension consumes "_eventid"; the "_label_cost" and
                // "_label_probability" properties are expected to become a ContextualBanditLabel.
                using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
                {
                    string eventId = null;
                    jsonSerializer.RegisterExtension((state, property) =>
                    {
                        // Expected value first, actual second, so failure messages read correctly.
                        Assert.AreEqual("_eventid", property);
                        Assert.IsTrue(state.Reader.Read());

                        eventId = (string)state.Reader.Value;
                        return true;
                    });

                    jsonSerializer.Parse("{\"_eventid\":\"abc123\",\"a\":1,\"_label_cost\":-1,\"_label_probability\":0.3}");

                    Assert.AreEqual("abc123", eventId);

                    using (var examples = jsonSerializer.CreateExamples())
                    {
                        var single = examples as VowpalWabbitSingleLineExampleCollection;
                        Assert.IsNotNull(single);

                        var label = single.Example.Label as ContextualBanditLabel;
                        Assert.IsNotNull(label);

                        Assert.AreEqual(-1, label.Cost);
                        Assert.AreEqual(0.3, label.Probability, 0.0001);
                    }
                }

                // Multi-line example: "_labelindex":1 should attach the label to the second action only.
                using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
                {
                    jsonSerializer.Parse("{\"_multi\":[{\"_text\":\"w1 w2\", \"a\":{\"x\":1}}, {\"_text\":\"w2 w3\"}], \"_labelindex\":1, \"_label_cost\":-1, \"_label_probability\":0.3}");

                    using (var examples = jsonSerializer.CreateExamples())
                    {
                        var multi = examples as VowpalWabbitMultiLineExampleCollection;
                        Assert.IsNotNull(multi);

                        Assert.AreEqual(2, multi.Examples.Length);

                        // The unlabeled action is expected to carry a zero cost/probability label.
                        var label = multi.Examples[0].Label as ContextualBanditLabel;
                        Assert.IsNotNull(label);

                        Assert.AreEqual(0, label.Cost);
                        Assert.AreEqual(0, label.Probability);

                        label = multi.Examples[1].Label as ContextualBanditLabel;
                        Assert.IsNotNull(label);

                        Assert.AreEqual(-1, label.Cost);
                        Assert.AreEqual(0.3, label.Probability, 0.0001);
                    }
                }
            }
        }
Example #3
0
        /// <summary>
        /// Pipeline stage 1: parses <c>data.JSON</c> into a VW example and, via a serializer extension,
        /// extracts the event metadata "_eventid", "_timestamp", "_ProbabilityOfDrop", "_p" and "_a" into <paramref name="data"/>.
        /// </summary>
        /// <param name="data">The event to deserialize; its properties are populated in place.</param>
        /// <returns>
        /// <paramref name="data"/> once its example has been created. Yields nothing when deserialization or
        /// validation fails (the exception is sent to telemetry), or when example creation is deferred because
        /// data is still missing (the serializer is then kept alive with <paramref name="data"/> as its user context).
        /// </returns>
        private IEnumerable<PipelineData> Stage1_Deserialize(PipelineData data)
        {
            try
            {
                using (var jsonReader = new JsonTextReader(new StringReader(data.JSON)))
                {
                    //jsonReader.FloatParser = Util.ReadDoubleString;
                    // jsonReader.ArrayPool = pool;

                    VowpalWabbitJsonSerializer vwJsonSerializer = null;
                    try
                    {
                        vwJsonSerializer = new VowpalWabbitJsonSerializer(this.trainer.VowpalWabbit, this.trainer.ReferenceResolver);

                        vwJsonSerializer.RegisterExtension((state, property) =>
                        {
                            if (TryExtractProperty(state, property, "_eventid", JsonToken.String, reader => data.EventId = (string)reader.Value))
                            {
                                return true;
                            }
                            else if (TryExtractProperty(state, property, "_timestamp", JsonToken.Date, reader => data.Timestamp = (DateTime)reader.Value))
                            {
                                return true;
                            }
                            // Json.NET exposes JsonToken.Float values as a boxed double (or decimal), so a direct
                            // (float) unbox throws InvalidCastException; Convert accepts any boxed numeric type.
                            else if (TryExtractProperty(state, property, "_ProbabilityOfDrop", JsonToken.Float, reader => data.ProbabilityOfDrop = Convert.ToSingle(reader.Value ?? 0f, System.Globalization.CultureInfo.InvariantCulture)))
                            {
                                return true;
                            }
                            else if (TryExtractArrayProperty<float>(state, property, "_p", arr => data.Probabilities = arr))
                            {
                                return true;
                            }
                            else if (TryExtractArrayProperty<int>(state, property, "_a", arr => data.Actions = arr))
                            {
                                return true;
                            }

                            return false;
                        });

                        data.Example = vwJsonSerializer.ParseAndCreate(jsonReader);

                        // "_p" and "_a" are only captured during parsing, so they can only be validated afterwards.
                        if (data.Probabilities == null || data.Actions == null)
                        {
                            // Release the already-created example; otherwise it would leak once the exception is caught below.
                            if (data.Example != null)
                            {
                                data.Example.Dispose();
                                data.Example = null;
                            }

                            if (data.Probabilities == null)
                            {
                                throw new ArgumentNullException(nameof(data), "Missing probabilities (_p)");
                            }

                            throw new ArgumentNullException(nameof(data), "Missing actions (_a)");
                        }

                        if (data.Example == null)
                        {
                            // unable to create example due to missing data
                            // will be trigger later
                            // Attach the event itself (data.Example is always null in this branch) so the
                            // deferred completion can find which PipelineData it belongs to.
                            vwJsonSerializer.UserContext = data;
                            // make sure the serializer is not deallocated by the finally below
                            vwJsonSerializer = null;
                        }
                    }
                    finally
                    {
                        if (vwJsonSerializer != null)
                        {
                            vwJsonSerializer.Dispose();
                        }
                    }

                    performanceCounters.Stage1_JSON_DeserializePerSec.Increment();

                    // delayed
                    if (data.Example == null)
                    {
                        this.performanceCounters.Feature_Requests_Pending.Increment();
                        yield break;
                    }
                }
            }
            catch (Exception ex)
            {
                this.telemetry.TrackException(ex, new Dictionary<string, string> {
                    { "JSON", data.JSON }
                });

                // NOTE(review): Stage2_* counters are incremented for a stage 1 failure — confirm this is intended.
                this.performanceCounters.Stage2_Faulty_Examples_Total.Increment();
                this.performanceCounters.Stage2_Faulty_ExamplesPerSec.Increment();

                yield break;
            }

            yield return data;
        }
Example #4
0
        /// <summary>
        /// Train VW on offline data.
        /// </summary>
        /// <param name="arguments">Base arguments.</param>
        /// <param name="inputFile">Path to input file; each line is parsed as one JSON example.</param>
        /// <param name="predictionFile">
        /// Name of the output prediction file. Defaults to <paramref name="inputFile"/> + ".prediction".
        /// Each output line is a JSON object { nr: line number, as: actions, p: scores } for one input line
        /// that was learned successfully.
        /// </param>
        /// <param name="reloadInterval">
        /// The TimeSpan interval to reload model. When set, each example's "_timestamp" property is read and the
        /// model is reloaded after learning an example whose timestamp exceeds the last reload point by more than this interval.
        /// </param>
        /// <param name="learningRate">
        /// Learning rate must be specified here otherwise on Reload it will be reset.
        /// </param>
        /// <param name="cacheFilePrefix">
        /// The prefix of the cache file name to use. Every (re)load uses its own cache file "{prefix}-{n}.vw.cache",
        /// with n counting up from 0. For example: prefix = "test" => "test-0.vw.cache", "test-1.vw.cache", ...
        /// If null or empty, the input file name without directory and extension is used, e.g. "input.dataset" => "input-0.vw.cache"
        /// (the directory is dropped, so the cache is created relative to the current working directory).
        /// !!! IMPORTANT !!!: Always use a new cache name if a different dataset or reload interval is used.
        /// </param>
        /// <remarks>
        /// Both learning rates and cache file are added to initial training arguments as well as Reload arguments.
        /// Lines that fail to parse or learn are skipped and counted; the totals are written to the console at the end.
        /// </remarks>
        public static void Train(string arguments, string inputFile, string predictionFile = null, TimeSpan? reloadInterval = null, float? learningRate = null, string cacheFilePrefix = null)
        {
            // VW's command line expects '.' as decimal separator; plain interpolation would use the current
            // culture (e.g. "0,5" under de-DE) and produce an invalid argument.
            var learningArgs = learningRate == null
                ? string.Empty
                : " -l " + learningRate.Value.ToString(System.Globalization.CultureInfo.InvariantCulture);

            var cachePrefix = string.IsNullOrEmpty(cacheFilePrefix) ? Path.GetFileNameWithoutExtension(inputFile) : cacheFilePrefix;
            int cacheIndex = 0;
            var cacheArgs = (Func<int, string>)(i => $" --cache_file {cachePrefix}-{i}.vw.cache");

            using (var reader = new StreamReader(inputFile))
            using (var prediction = new StreamWriter(predictionFile ?? inputFile + ".prediction"))
            using (var vw = new VowpalWabbit(new VowpalWabbitSettings(arguments + learningArgs + cacheArgs(cacheIndex++)) { Verbose = true }))
            {
                string line;
                int lineNr = 0;
                int invalidExamples = 0;
                DateTime? lastTimestamp = null;

                while ((line = reader.ReadLine()) != null)
                {
                    try
                    {
                        // Set by the "_timestamp" extension while the line is parsed.
                        bool reload = false;
                        using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
                        {
                            if (reloadInterval != null)
                            {
                                jsonSerializer.RegisterExtension((state, property) =>
                                {
                                    if (!property.Equals("_timestamp", StringComparison.Ordinal))
                                    {
                                        return false;
                                    }

                                    var eventTimestamp = state.Reader.ReadAsDateTime();
                                    if (lastTimestamp == null)
                                    {
                                        lastTimestamp = eventTimestamp;
                                    }
                                    else if (lastTimestamp + reloadInterval < eventTimestamp)
                                    {
                                        reload = true;
                                        lastTimestamp = eventTimestamp;
                                    }

                                    return true;
                                });
                            }

                            using (var example = jsonSerializer.ParseAndCreate(line))
                            {
                                var pred = example.Learn(VowpalWabbitPredictionType.ActionScore);

                                prediction.WriteLine(JsonConvert.SerializeObject(new
                                {
                                    nr = lineNr,
                                    @as = pred.Select(x => x.Action),
                                    p = pred.Select(x => x.Score)
                                }));
                            }

                            if (reload)
                            {
                                // Each reload gets a fresh cache file; the learning rate is re-applied because Reload would reset it.
                                vw.Reload(learningArgs + cacheArgs(cacheIndex++));
                            }
                        }
                    }
                    catch (Exception)
                    {
                        // Skip malformed/unlearnable lines; the tally is reported after the loop.
                        invalidExamples++;
                    }

                    lineNr++;
                }

                // Surface the skipped-line count instead of swallowing failures silently.
                Console.WriteLine($"Examples {lineNr}. Invalid: {invalidExamples}");
            }

            // NOTE: a TPL Dataflow pipeline (parallel JSON deserialization feeding a single learner) was tried here
            // previously, but it leaked memory and brought little gain, so the sequential loop above is used.
        }
Example #5
0
        /// <summary>
        /// Pipeline stage 1: parses <c>data.JSON</c> into a VW example and, via a serializer extension,
        /// extracts the "_eventid" and "_timestamp" properties into <paramref name="data"/>.
        /// </summary>
        /// <param name="data">The event to deserialize; its properties are populated in place.</param>
        /// <returns>
        /// <paramref name="data"/> once its example has been created. Yields nothing when deserialization fails
        /// (the exception is sent to telemetry), or when example creation is deferred because data is still
        /// missing (the serializer is then kept alive with <paramref name="data"/> as its user context).
        /// </returns>
        private IEnumerable<PipelineData> Stage1_Deserialize(PipelineData data)
        {
            try
            {
                using (var jsonReader = new JsonTextReader(new StringReader(data.JSON)))
                {
                    //jsonReader.FloatParser = Util.ReadDoubleString;
                    // jsonReader.ArrayPool = pool;

                    VowpalWabbitJsonSerializer vwJsonSerializer = null;
                    try
                    {
                        vwJsonSerializer = new VowpalWabbitJsonSerializer(this.trainer.VowpalWabbit, this.trainer.ReferenceResolver);

                        vwJsonSerializer.RegisterExtension((state, property) =>
                        {
                            if (property.Equals("_eventid", StringComparison.OrdinalIgnoreCase))
                            {
                                // Fail if the reader could not advance OR the value is not a string
                                // ('&&' here would let non-string tokens through to the cast below).
                                if (!state.Reader.Read() || state.Reader.TokenType != JsonToken.String)
                                {
                                    throw new VowpalWabbitJsonException(state.Reader, "Expected string");
                                }

                                data.EventId = (string)state.Reader.Value;
                                return true;
                            }

                            if (property.Equals("_timestamp", StringComparison.OrdinalIgnoreCase))
                            {
                                if (!state.Reader.Read() || state.Reader.TokenType != JsonToken.Date)
                                {
                                    throw new VowpalWabbitJsonException(state.Reader, "Expected date");
                                }

                                data.Timestamp = (DateTime)state.Reader.Value;

                                // The value token has already been consumed, so report the property as handled
                                // (as the "_eventid" branch does) instead of handing an advanced reader back.
                                return true;
                            }

                            return false;
                        });

                        data.Example = vwJsonSerializer.ParseAndCreate(jsonReader);

                        if (data.Example == null)
                        {
                            // unable to create example due to missing data
                            // will be trigger later
                            // Attach the event itself (data.Example is always null in this branch) so the
                            // deferred completion can find which PipelineData it belongs to.
                            vwJsonSerializer.UserContext = data;
                            // make sure the serializer is not deallocated by the finally below
                            vwJsonSerializer = null;
                        }
                    }
                    finally
                    {
                        if (vwJsonSerializer != null)
                        {
                            vwJsonSerializer.Dispose();
                        }
                    }

                    performanceCounters.Stage1_JSON_DeserializePerSec.Increment();

                    // delayed
                    if (data.Example == null)
                    {
                        this.performanceCounters.Feature_Requests_Pending.Increment();
                        yield break;
                    }
                }
            }
            catch (Exception ex)
            {
                this.telemetry.TrackException(ex, new Dictionary<string, string> {
                    { "JSON", data.JSON }
                });

                // NOTE(review): Stage2_* counters are incremented for a stage 1 failure — confirm this is intended.
                this.performanceCounters.Stage2_Faulty_Examples_Total.Increment();
                this.performanceCounters.Stage2_Faulty_ExamplesPerSec.Increment();

                yield break;
            }

            yield return data;
        }