/// <summary>
/// Asserts that the examples created from JSON are equivalent to the examples parsed
/// from the given VW string lines.
/// </summary>
/// <remarks>
/// Each JSON-derived example is compared twice against the corresponding string example:
/// once directly, and once after round-tripping its <c>VowpalWabbitString</c> through
/// <c>ParseLine</c>. Both comparisons must yield a null diff.
/// </remarks>
/// <param name="lines">Expected examples in VW string format, one example per entry.</param>
/// <param name="jsonReader">Reader supplying the JSON that is parsed into a multi-line example collection.</param>
/// <param name="labelComparator">Optional comparator forwarded to <c>Diff</c>.</param>
/// <param name="label">Optional label forwarded to <c>ParseAndCreate</c>.</param>
/// <param name="index">Optional index forwarded to <c>ParseAndCreate</c>.</param>
/// <param name="extension">Optional JSON extension registered on the serializer before parsing.</param>
public void Validate(string[] lines, JsonReader jsonReader, IVowpalWabbitLabelComparator labelComparator = null, ILabel label = null, int? index = null, VowpalWabbitJsonExtension extension = null)
{
    // Keeps the ArgumentNullException the former LINQ Count() call raised for a null array.
    if (lines == null)
    {
        throw new System.ArgumentNullException("lines");
    }

    // Arrays expose Length directly; the LINQ Count() extension is unnecessary here.
    VowpalWabbitExample[] strExamples = new VowpalWabbitExample[lines.Length];

    try
    {
        for (int i = 0; i < lines.Length; i++)
        {
            strExamples[i] = this.vw.ParseLine(lines[i]);
        }

        using (var jsonSerializer = new VowpalWabbitJsonSerializer(this.vw))
        {
            if (extension != null)
            {
                // extensions are not supported with native JSON parsing
                jsonSerializer.RegisterExtension(extension);
            }

            using (var jsonExample = (VowpalWabbitMultiLineExampleCollection)jsonSerializer.ParseAndCreate(jsonReader, label, index))
            {
                // Flatten to [shared?, example0, example1, ...] so it lines up with the string examples.
                var jsonExamples = new List<VowpalWabbitExample>();
                if (jsonExample.SharedExample != null)
                {
                    jsonExamples.Add(jsonExample.SharedExample);
                }

                jsonExamples.AddRange(jsonExample.Examples);

                Assert.AreEqual(strExamples.Length, jsonExamples.Count);

                for (int i = 0; i < strExamples.Length; i++)
                {
                    using (var strJsonExample = this.vw.ParseLine(jsonExamples[i].VowpalWabbitString))
                    {
                        // Direct comparison: string example vs. JSON example.
                        var diff = strExamples[i].Diff(this.vw, jsonExamples[i], labelComparator);
                        Assert.IsNull(diff, diff + " generated string: '" + jsonExamples[i].VowpalWabbitString + "'");

                        // Round-trip comparison: string example vs. re-parsed string form of the JSON example.
                        diff = strExamples[i].Diff(this.vw, strJsonExample, labelComparator);
                        Assert.IsNull(diff, diff);
                    }
                }
            }
        }
    }
    finally
    {
        // Entries may be null if ParseLine threw part-way through the array.
        foreach (var ex in strExamples)
        {
            if (ex != null)
            {
                ex.Dispose();
            }
        }
    }
}
/// <summary>
/// Verifies that <c>_label_cost</c> / <c>_label_probability</c> in JSON end up on the
/// <c>ContextualBanditLabel</c> of the created example, for both a single-line example
/// and a multi-line example selected through <c>_labelindex</c>, and that a registered
/// extension can capture <c>_eventid</c>.
/// </summary>
public void TestJsonLabelExtraction()
{
    using (var vw = new VowpalWabbit("--cb_adf --rank_all"))
    {
        // Single-line example with an extension that captures the event id.
        using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
        {
            string eventId = null;
            jsonSerializer.RegisterExtension((state, property) =>
            {
                // Expected value goes first so a failure message reads correctly.
                Assert.AreEqual("_eventid", property);
                Assert.IsTrue(state.Reader.Read());

                eventId = (string)state.Reader.Value;
                return true;
            });

            jsonSerializer.Parse("{\"_eventid\":\"abc123\",\"a\":1,\"_label_cost\":-1,\"_label_probability\":0.3}");

            Assert.AreEqual("abc123", eventId);

            using (var examples = jsonSerializer.CreateExamples())
            {
                var single = examples as VowpalWabbitSingleLineExampleCollection;
                Assert.IsNotNull(single);

                var label = single.Example.Label as ContextualBanditLabel;
                Assert.IsNotNull(label);

                Assert.AreEqual(-1, label.Cost);
                Assert.AreEqual(0.3, label.Probability, 0.0001);
            }
        }

        // Multi-line example: the label is attached to the example at _labelindex (1).
        using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
        {
            jsonSerializer.Parse("{\"_multi\":[{\"_text\":\"w1 w2\", \"a\":{\"x\":1}}, {\"_text\":\"w2 w3\"}], \"_labelindex\":1, \"_label_cost\":-1, \"_label_probability\":0.3}");

            using (var examples = jsonSerializer.CreateExamples())
            {
                var multi = examples as VowpalWabbitMultiLineExampleCollection;
                Assert.IsNotNull(multi);

                Assert.AreEqual(2, multi.Examples.Length);

                // Example 0 is not the labelled one: cost and probability are expected to be 0.
                var label = multi.Examples[0].Label as ContextualBanditLabel;
                // Fail with a clear assertion instead of a NullReferenceException if the cast fails.
                Assert.IsNotNull(label);
                Assert.AreEqual(0, label.Cost);
                Assert.AreEqual(0, label.Probability);

                label = multi.Examples[1].Label as ContextualBanditLabel;
                Assert.IsNotNull(label);
                Assert.AreEqual(-1, label.Cost);
                Assert.AreEqual(0.3, label.Probability, 0.0001);
            }
        }
    }
}
/// <summary>
/// Pipeline stage 1: parses <c>data.JSON</c> into a VW example and extracts the event
/// metadata properties <c>_eventid</c>, <c>_timestamp</c>, <c>_ProbabilityOfDrop</c>,
/// <c>_p</c> and <c>_a</c> onto <paramref name="data"/>.
/// </summary>
/// <param name="data">Pipeline item whose <c>JSON</c> is parsed; results are stored back on it.</param>
/// <returns>
/// A sequence containing <paramref name="data"/> when an example was created. The sequence is
/// empty when parsing throws (the exception is tracked and the faulty-example counters are
/// incremented) or when <c>ParseAndCreate</c> returned no example yet.
/// </returns>
private IEnumerable<PipelineData> Stage1_Deserialize(PipelineData data)
{
    try
    {
        using (var jsonReader = new JsonTextReader(new StringReader(data.JSON)))
        {
            //jsonReader.FloatParser = Util.ReadDoubleString;
            // jsonReader.ArrayPool = pool;

            VowpalWabbitJsonSerializer vwJsonSerializer = null;
            try
            {
                vwJsonSerializer = new VowpalWabbitJsonSerializer(this.trainer.VowpalWabbit, this.trainer.ReferenceResolver);

                vwJsonSerializer.RegisterExtension((state, property) =>
                {
                    if (TryExtractProperty(state, property, "_eventid", JsonToken.String, reader => data.EventId = (string)reader.Value))
                    {
                        return true;
                    }
                    else if (TryExtractProperty(state, property, "_timestamp", JsonToken.Date, reader => data.Timestamp = (DateTime)reader.Value))
                    {
                        return true;
                    }
                    // reader.Value is a boxed object; Json.NET boxes Float tokens as double by default,
                    // so a direct (float) unboxing cast would throw InvalidCastException.
                    // Convert.ToSingle accepts boxed float, double and decimal alike.
                    else if (TryExtractProperty(state, property, "_ProbabilityOfDrop", JsonToken.Float, reader => data.ProbabilityOfDrop = Convert.ToSingle(reader.Value ?? 0f)))
                    {
                        return true;
                    }
                    else if (TryExtractArrayProperty<float>(state, property, "_p", arr => data.Probabilities = arr))
                    {
                        return true;
                    }
                    else if (TryExtractArrayProperty<int>(state, property, "_a", arr => data.Actions = arr))
                    {
                        return true;
                    }

                    return false;
                });

                data.Example = vwJsonSerializer.ParseAndCreate(jsonReader);

                // The single-string ArgumentNullException constructor takes the parameter name,
                // not the message; use the (paramName, message) overload so the text is reported as a message.
                if (data.Probabilities == null)
                {
                    throw new ArgumentNullException("data.Probabilities", "Missing probabilities (_p)");
                }

                if (data.Actions == null)
                {
                    throw new ArgumentNullException("data.Actions", "Missing actions (_a)");
                }

                if (data.Example == null)
                {
                    // Unable to create the example due to missing data; it will be triggered later.
                    // NOTE(review): data.Example is null on this path, so this assigns null to
                    // UserContext — confirm whether `data` was intended.
                    vwJsonSerializer.UserContext = data.Example;

                    // Drop the local reference so the finally block does not dispose the serializer.
                    vwJsonSerializer = null;
                }
            }
            finally
            {
                if (vwJsonSerializer != null)
                {
                    vwJsonSerializer.Dispose();
                }
            }

            this.performanceCounters.Stage1_JSON_DeserializePerSec.Increment();

            // Example creation is delayed: count it as pending and emit nothing for now.
            if (data.Example == null)
            {
                this.performanceCounters.Feature_Requests_Pending.Increment();
                yield break;
            }
        }
    }
    catch (Exception ex)
    {
        this.telemetry.TrackException(ex, new Dictionary<string, string> { { "JSON", data.JSON } });

        this.performanceCounters.Stage2_Faulty_Examples_Total.Increment();
        this.performanceCounters.Stage2_Faulty_ExamplesPerSec.Increment();

        yield break;
    }

    yield return data;
}
/// <summary>
/// Train VW on offline data.
/// </summary>
/// <param name="arguments">Base arguments.</param>
/// <param name="inputFile">Path to input file.</param>
/// <param name="predictionFile">
/// Name of the output prediction file. If null, <paramref name="inputFile"/> + ".prediction" is used.
/// </param>
/// <param name="reloadInterval">The TimeSpan interval to reload model.</param>
/// <param name="learningRate">
/// Learning rate must be specified here otherwise on Reload it will be reset.
/// </param>
/// <param name="cacheFilePrefix">
/// The prefix of the cache file name to use. For example: prefix = "test" => "test.vw.cache"
/// If none or null, the input file name is used, e.g. "input.dataset" => "input.vw.cache"
/// !!! IMPORTANT !!!: Always use a new cache name if a different dataset or reload interval is used.
/// </param>
/// <remarks>
/// Both learning rates and cache file are added to initial training arguments as well as Reload arguments.
/// Each input line is parsed as JSON and learned; one JSON prediction record
/// (<c>nr</c>, <c>as</c>, <c>p</c>) is written per successfully processed line.
/// Lines that throw during parsing or learning are counted and skipped.
/// </remarks>
public static void Train(string arguments, string inputFile, string predictionFile = null, TimeSpan? reloadInterval = null, float? learningRate = null, string cacheFilePrefix = null)
{
    // Format with the invariant culture: the value becomes a VW command line argument, and a
    // comma decimal separator (e.g. "0,5" under de-DE) would not be a valid number there.
    var learningArgs = learningRate == null
        ? string.Empty
        : " -l " + learningRate.Value.ToString(System.Globalization.CultureInfo.InvariantCulture);

    // A fresh cache file index is used for the initial model and for every reload.
    int cacheIndex = 0;
    var cacheArgs = (Func<int, string>)(i => $" --cache_file {cacheFilePrefix ?? Path.GetFileNameWithoutExtension(inputFile)}-{i}.vw.cache");

    using (var reader = new StreamReader(inputFile))
    using (var prediction = new StreamWriter(predictionFile ?? inputFile + ".prediction"))
    using (var vw = new VowpalWabbit(new VowpalWabbitSettings(arguments + learningArgs + cacheArgs(cacheIndex++)) { Verbose = true }))
    {
        string line;
        int lineNr = 0;

        // NOTE(review): counted but currently not reported anywhere.
        int invalidExamples = 0;
        DateTime? lastTimestamp = null;

        while ((line = reader.ReadLine()) != null)
        {
            try
            {
                bool reload = false;

                using (var jsonSerializer = new VowpalWabbitJsonSerializer(vw))
                {
                    if (reloadInterval != null)
                    {
                        // Watch the event timestamps and request a reload once more than
                        // reloadInterval has elapsed since the last recorded timestamp.
                        jsonSerializer.RegisterExtension((state, property) =>
                        {
                            if (property.Equals("_timestamp", StringComparison.Ordinal))
                            {
                                var eventTimestamp = state.Reader.ReadAsDateTime();

                                if (lastTimestamp == null)
                                {
                                    lastTimestamp = eventTimestamp;
                                }
                                else if (lastTimestamp + reloadInterval < eventTimestamp)
                                {
                                    reload = true;
                                    lastTimestamp = eventTimestamp;
                                }

                                return true;
                            }

                            return false;
                        });
                    }

                    using (var example = jsonSerializer.ParseAndCreate(line))
                    {
                        var pred = example.Learn(VowpalWabbitPredictionType.ActionScore);

                        prediction.WriteLine(JsonConvert.SerializeObject(
                            new
                            {
                                nr = lineNr,
                                @as = pred.Select(x => x.Action),
                                p = pred.Select(x => x.Score)
                            }));
                    }

                    // Reload only after the current example has been learned and released.
                    if (reload)
                    {
                        vw.Reload(learningArgs + cacheArgs(cacheIndex++));
                    }
                }
            }
            catch (Exception)
            {
                // Best effort: a bad line must not abort the whole training run.
                invalidExamples++;
            }

            lineNr++;
        }
    }

    // A TPL Dataflow variant of this loop (parallel JSON deserialization feeding a
    // single-threaded learn block and a prediction writer block, using a private Event
    // holder class) previously lived here as commented-out code. It was abandoned
    // because it leaked memory and gave not much gain.
}
/// <summary>
/// Pipeline stage 1: parses <c>data.JSON</c> into a VW example and extracts the
/// <c>_eventid</c> and <c>_timestamp</c> properties onto <paramref name="data"/>.
/// </summary>
/// <param name="data">Pipeline item whose <c>JSON</c> is parsed; results are stored back on it.</param>
/// <returns>
/// A sequence containing <paramref name="data"/> when an example was created. The sequence is
/// empty when parsing throws (the exception is tracked and the faulty-example counters are
/// incremented) or when <c>ParseAndCreate</c> returned no example yet.
/// </returns>
private IEnumerable<PipelineData> Stage1_Deserialize(PipelineData data)
{
    try
    {
        using (var jsonReader = new JsonTextReader(new StringReader(data.JSON)))
        {
            //jsonReader.FloatParser = Util.ReadDoubleString;
            // jsonReader.ArrayPool = pool;

            VowpalWabbitJsonSerializer vwJsonSerializer = null;
            try
            {
                vwJsonSerializer = new VowpalWabbitJsonSerializer(this.trainer.VowpalWabbit, this.trainer.ReferenceResolver);

                vwJsonSerializer.RegisterExtension((state, property) =>
                {
                    if (property.Equals("_eventid", StringComparison.OrdinalIgnoreCase))
                    {
                        // Fail when the read fails OR the token has the wrong type. With "&&" the
                        // type check was skipped whenever Read() succeeded.
                        if (!state.Reader.Read() || state.Reader.TokenType != JsonToken.String)
                        {
                            throw new VowpalWabbitJsonException(state.Reader, "Expected string");
                        }

                        data.EventId = (string)state.Reader.Value;
                        return true;
                    }
                    else if (property.Equals("_timestamp", StringComparison.OrdinalIgnoreCase))
                    {
                        if (!state.Reader.Read() || state.Reader.TokenType != JsonToken.Date)
                        {
                            throw new VowpalWabbitJsonException(state.Reader, "Expected date");
                        }

                        data.Timestamp = (DateTime)state.Reader.Value;

                        // The value token has been consumed, so report the property as handled
                        // (same as the _eventid branch) instead of falling through to false.
                        return true;
                    }

                    return false;
                });

                data.Example = vwJsonSerializer.ParseAndCreate(jsonReader);

                if (data.Example == null)
                {
                    // Unable to create the example due to missing data; it will be triggered later.
                    // NOTE(review): data.Example is null on this path, so this assigns null to
                    // UserContext — confirm whether `data` was intended.
                    vwJsonSerializer.UserContext = data.Example;

                    // Drop the local reference so the finally block does not dispose the serializer.
                    vwJsonSerializer = null;
                }
            }
            finally
            {
                if (vwJsonSerializer != null)
                {
                    vwJsonSerializer.Dispose();
                }
            }

            this.performanceCounters.Stage1_JSON_DeserializePerSec.Increment();

            // Example creation is delayed: count it as pending and emit nothing for now.
            if (data.Example == null)
            {
                this.performanceCounters.Feature_Requests_Pending.Increment();
                yield break;
            }
        }
    }
    catch (Exception ex)
    {
        this.telemetry.TrackException(ex, new Dictionary<string, string> { { "JSON", data.JSON } });

        this.performanceCounters.Stage2_Faulty_Examples_Total.Increment();
        this.performanceCounters.Stage2_Faulty_ExamplesPerSec.Increment();

        yield break;
    }

    yield return data;
}