/// <summary>
/// Builds the serializer for the given example user type.
/// </summary>
/// <typeparam name="TExample">The example user type.</typeparam>
/// <param name="settings">The serializer settings, handed through to the resulting serializer.</param>
/// <returns>A serializer for the given user example type.</returns>
public static VowpalWabbitSerializer<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings)
{
    var nativeFunc = CreateSerializer<TExample, VowpalWabbitInterfaceVisitor, VowpalWabbitExample>();

#if DEBUG
    // Debug builds also compile the string-visitor variant and pair both results
    // in a VowpalWabbitDebugExample.
    var stringFunc = CreateSerializer<TExample, VowpalWabbitStringVisitor, string>();

    Func<VowpalWabbit, TExample, ILabel, VowpalWabbitExample> debugFunc;
    if (nativeFunc != null)
    {
        debugFunc = (vw, example, label) =>
            new VowpalWabbitDebugExample(
                nativeFunc(vw, example, label),
                stringFunc(vw, example, label));
    }
    else
    {
        // no features found => no serializer was generated; substitute a stub returning null
        debugFunc = (a, b, c) => null;
    }

    return new VowpalWabbitSerializer<TExample>(debugFunc, settings);
#else
    if (nativeFunc == null)
    {
        // no features found => no serializer was generated; substitute a stub returning null
        nativeFunc = (a, b, c) => null;
    }

    return new VowpalWabbitSerializer<TExample>(nativeFunc, settings);
#endif
}
/// <summary>
/// Applies the trainer's fixed options to <paramref name="vwSettings"/>, then creates the
/// VW instance and the JSON reference resolver.
/// </summary>
/// <param name="vwSettings">Settings for the new VW instance; modified in place.</param>
private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
{
    bool tracingRequested = this.settings.EnableExampleTracing;
    if (tracingRequested)
    {
        vwSettings.EnableStringExampleGeneration = true;
        vwSettings.EnableStringFloatCompact = true;
    }

    vwSettings.EnableThreadSafeExamplePooling = true;
    vwSettings.MaxExamples = 64 * 1024;

    this.vw = new VW.VowpalWabbit(vwSettings);

    this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
        this.delayedExampleCallback,
        cacheRequestItemPolicyFactory: key =>
        {
            // every cache entry slides for one hour and reports its removal
            var policy = new CacheItemPolicy();
            policy.SlidingExpiration = TimeSpan.FromHours(1);
            policy.RemovedCallback = this.CacheEntryRemovedCallback;
            return policy;
        });

    // Disabled threaded-learning setup, kept for reference:
    //this.vwAllReduce = new VowpalWabbitThreadedLearning(vwSettings.ShallowCopy(
    //    maxExampleQueueLengthPerInstance: 4*1024,
    //    parallelOptions: new ParallelOptions
    //    {
    //        MaxDegreeOfParallelism = 2,
    //    },
    //    exampleDistribution: VowpalWabbitExampleDistribution.RoundRobin,
    //    exampleCountPerRun: 128 * 1024));
}
/// <summary>
/// Sample: learns one example through a 16-way threaded learner and saves the model to "m1.model".
/// </summary>
public static async Task MultiThreadedLearning()
{
    var parallelOptions = new ParallelOptions { MaxDegreeOfParallelism = 16 };

    var settings = new VowpalWabbitSettings
    {
        ParallelOptions = parallelOptions,
        ExampleCountPerRun = 2000,
        ExampleDistribution = VowpalWabbitExampleDistribution.RoundRobin
    };

    var sample = new MyExample { Income = 40, Age = 25 };
    var sampleLabel = new SimpleLabel { Label = 1 };

    using (var learner = new VowpalWabbitThreadedLearning(settings))
    {
        using (var typedLearner = learner.Create<MyExample>())
        {
            var scalarPrediction = await typedLearner.Learn(sample, sampleLabel, VowpalWabbitPredictionType.Scalar);
        }

        // request the save first, then complete the learner, then await the save
        var pendingSave = learner.SaveModel("m1.model");
        await learner.Complete();
        await pendingSave;
    }
}
/// <summary>
/// Sample: feeds a single labelled example to a threaded learner (16 instances, round-robin)
/// and writes the resulting model to "m1.model".
/// </summary>
public static async Task MultiThreadedLearning()
{
    var example = new MyExample { Income = 40, Age = 25 };
    var label = new SimpleLabel { Label = 1 };

    var settings = new VowpalWabbitSettings();
    settings.ParallelOptions = new ParallelOptions { MaxDegreeOfParallelism = 16 };
    settings.ExampleCountPerRun = 2000;
    settings.ExampleDistribution = VowpalWabbitExampleDistribution.RoundRobin;

    using (var threaded = new VowpalWabbitThreadedLearning(settings))
    {
        using (var managed = threaded.Create<MyExample>())
        {
            var prediction = await managed.Learn(example, label, VowpalWabbitPredictionType.Scalar);
        }

        // the save is requested before Complete() and awaited after it
        var saveTask = threaded.SaveModel("m1.model");
        await threaded.Complete();
        await saveTask;
    }
}
/// <summary>
/// Creates the two VW instances and three serializers held by the validator: one instance
/// with string example generation enabled, one with the settings as given, and a serializer
/// obtained from <c>VowpalWabbitSerializerFactory</c>.
/// </summary>
/// <param name="settings">Base settings; copies with string generation enabled are derived from it.</param>
internal VowpalWabbitExampleValidator(VowpalWabbitSettings settings)
{
    // instance with string example generation switched on
    var stringSettings = settings.ShallowCopy(enableStringExampleGeneration: true);
    this.vw = new VowpalWabbit<TExample>(stringSettings);
    this.serializer = this.vw.Serializer.Func(this.vw.Native);

    // instance using the caller's settings unchanged
    this.vwNative = new VowpalWabbit<TExample>(settings);
    this.serializerNative = this.vwNative.Serializer.Func(this.vwNative.Native);

    // serializer obtained through the factory, bound to the string-generating instance
    var factorySettings = settings.ShallowCopy(enableStringExampleGeneration: true);
    var factoryCompiler = VowpalWabbitSerializerFactory.CreateSerializer<TExample>(factorySettings);
    this.factorySerializer = factoryCompiler.Create(this.vw.Native);
}
/// <summary>
/// Applies the trainer's fixed options to <paramref name="vwSettings"/>, creates the VW
/// instance, verifies it is a cb_explore / cb_explore_adf model and creates the JSON
/// reference resolver.
/// </summary>
/// <param name="vwSettings">Settings for the new VW instance; modified in place.</param>
/// <exception cref="ArgumentException">
/// The command line of the created instance contains neither "--cb_explore" nor "--cb_explore_adf".
/// </exception>
private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
{
    if (this.settings.EnableExampleTracing)
    {
        vwSettings.EnableStringExampleGeneration = true;
        vwSettings.EnableStringFloatCompact = true;
    }

    vwSettings.EnableThreadSafeExamplePooling = true;
    vwSettings.MaxExamples = 64 * 1024;

    try
    {
        this.startDateTime = DateTime.UtcNow;
        this.vw = new VW.VowpalWabbit(vwSettings);

        var cmdLine = vw.Arguments.CommandLine;
        if (!(cmdLine.Contains("--cb_explore") || cmdLine.Contains("--cb_explore_adf")))
        {
            throw new ArgumentException("Only cb_explore and cb_explore_adf are supported");
        }
    }
    catch (Exception ex)
    {
        this.telemetry.TrackException(
            ex,
            new Dictionary<string, string>
            {
                { "help", "Invalid model. For help go to https://github.com/JohnLangford/vowpal_wabbit/wiki/Azure-Trainer" }
            });

        // Rethrow with "throw;" so the original stack trace is kept ("throw ex;" would reset it).
        throw;
    }

    this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
        this.delayedExampleCallback,
        cacheRequestItemPolicyFactory:
            key => new CacheItemPolicy()
            {
                SlidingExpiration = TimeSpan.FromHours(1),
                RemovedCallback = this.CacheEntryRemovedCallback
            });

    //this.vwAllReduce = new VowpalWabbitThreadedLearning(vwSettings.ShallowCopy(
    //    maxExampleQueueLengthPerInstance: 4*1024,
    //    parallelOptions: new ParallelOptions
    //    {
    //        MaxDegreeOfParallelism = 2,
    //    },
    //    exampleDistribution: VowpalWabbitExampleDistribution.RoundRobin,
    //    exampleCountPerRun: 128 * 1024));
}
/// <summary>
/// Wraps a compiled serializer function and, when the example type carries
/// <see cref="CacheableAttribute"/> and caching is enabled in the settings, sets up the example cache.
/// </summary>
/// <param name="serializer">The serializer function; must not be null.</param>
/// <param name="settings">The settings; a new default instance is used when null.</param>
/// <exception cref="ArgumentNullException"><paramref name="serializer"/> is null.</exception>
/// <exception cref="ArgumentException">
/// The comparer type named by the attribute does not implement <c>IEqualityComparer&lt;TExample&gt;</c>.
/// </exception>
internal VowpalWabbitSerializer(Func<VowpalWabbit, TExample, ILabel, VowpalWabbitExample> serializer, VowpalWabbitSettings settings)
{
    if (serializer == null)
    {
        throw new ArgumentNullException("serializer");
    }
    Contract.Ensures(this.settings != null);
    Contract.EndContractBlock();

    this.serializer = serializer;
    this.settings = settings ?? new VowpalWabbitSettings();

    object[] declaredAttributes = typeof(TExample).GetCustomAttributes(typeof(CacheableAttribute), true);
    var cacheable = (CacheableAttribute)declaredAttributes.FirstOrDefault();

    // nothing further to set up unless the type is marked cacheable and caching is switched on
    if (cacheable == null || !this.settings.EnableExampleCaching)
    {
        return;
    }

    var comparerType = cacheable.EqualityComparer;
    if (comparerType == null)
    {
        this.exampleCache = new Dictionary<TExample, CacheEntry>();
    }
    else if (typeof(IEqualityComparer<TExample>).IsAssignableFrom(comparerType))
    {
        var comparerInstance = (IEqualityComparer<TExample>)Activator.CreateInstance(comparerType);
        this.exampleCache = new Dictionary<TExample, CacheEntry>(comparerInstance);
    }
    else
    {
        throw new ArgumentException(
            string.Format(
                CultureInfo.InvariantCulture,
                "EqualityComparer ({1}) specified in [Cachable] of {0} must implement IEqualityComparer<{0}>",
                typeof(TExample),
                comparerType));
    }

#if DEBUG
    // debug-only lookup keyed by example reference
    this.reverseLookup = new Dictionary<VowpalWabbitExample, CacheEntry>(new ReferenceEqualityComparer<VowpalWabbitExample>());
#endif
}
/// <summary>
/// Initializes a new instance of the <see cref="VowpalWabbitThreadedPredictionBase{TVowpalWabbit}"/> class.
/// </summary>
/// <param name="settings">The initial settings to use; <c>settings.Model</c> seeds the object pool.</param>
protected VowpalWabbitThreadedPredictionBase(VowpalWabbitSettings settings)
{
    this.settings = settings;

    // Pool factory: no model yields default(TVowpalWabbit), otherwise a child of the model is created.
    var factory = ObjectFactory.Create(
        settings.Model,
        m => m == null ? default(TVowpalWabbit) : CreateVowpalWabbitChild(m));

    this.vwPool = new ObjectPool<VowpalWabbitModel, TVowpalWabbit>(factory);
}
/// <summary>
/// Sets up the compilers for the shared (non-multi) features and for the action dependent
/// features, and compiles an accessor that extracts the action dependent features from an example.
/// </summary>
/// <param name="settings">The serializer settings.</param>
/// <param name="schema">The schema of <c>TExample</c>.</param>
/// <param name="multiFeature">The feature holding the action dependent features.</param>
public VowpalWabbitMultiExampleSerializerCompilerImpl(VowpalWabbitSettings settings, Schema schema, FeatureExpression multiFeature)
{
    Contract.Requires(settings != null);
    Contract.Requires(schema != null);
    Contract.Requires(multiFeature != null);

    // everything except the multi feature is serialized as the shared example
    var sharedFeatures = schema.Features.Where(fe => fe != multiFeature).ToList();

    if (sharedFeatures.Count == 0)
    {
        this.sharedSerializerCompiler = null;
    }
    else
    {
        this.sharedSerializerCompiler = new VowpalWabbitSingleExampleSerializerCompiler<TExample>(
            new Schema { Features = sharedFeatures },
            settings == null ? null : settings.CustomFeaturizer,
            !settings.EnableStringExampleGeneration);
    }

    this.adfSerializerComputer = new VowpalWabbitSingleExampleSerializerCompiler<TActionDependentFeature>(
        settings.TypeInspector.CreateSchema(settings, typeof(TActionDependentFeature)),
        settings == null ? null : settings.CustomFeaturizer,
        !settings.EnableStringExampleGeneration);

    var exampleParameter = Expression.Parameter(typeof(TExample), "example");

    // CODE condition1 && condition2 && condition3 ...
    var validity = multiFeature.ValueValidExpressionFactories
        .Skip(1)
        .Aggregate(
            multiFeature.ValueValidExpressionFactories.First()(exampleParameter),
            (accumulated, nextFactory) => Expression.AndAlso(accumulated, nextFactory(exampleParameter)));

    var multiValue = multiFeature.ValueExpressionFactory(exampleParameter);

    // CODE example => (IEnumerable<TActionDependentFeature>)(example._multi != null ? example._multi : null)
    var guardedValue = Expression.Condition(
        validity,
        multiValue,
        Expression.Constant(null, multiValue.Type),
        typeof(IEnumerable<TActionDependentFeature>));

    var accessorLambda = Expression.Lambda<Func<TExample, IEnumerable<TActionDependentFeature>>>(
        guardedValue,
        exampleParameter);

    this.adfAccessor = (Func<TExample, IEnumerable<TActionDependentFeature>>)accessorLambda.CompileToFunc();
}
/// <summary>
/// Creates the two VW instances held by the validator (one with string example generation
/// enabled, one with the settings as given) plus a serializer obtained from the factory.
/// The per-instance serializers are only set when the instance's serializer is a
/// <c>VowpalWabbitSingleExampleSerializerCompiler&lt;TExample&gt;</c>.
/// </summary>
/// <param name="settings">Base settings; copies with string generation enabled are derived from it.</param>
internal VowpalWabbitExampleValidator(VowpalWabbitSettings settings)
{
    this.vw = new VowpalWabbit<TExample>(settings.ShallowCopy(enableStringExampleGeneration: true));

    var stringCompiler = this.vw.Serializer as VowpalWabbitSingleExampleSerializerCompiler<TExample>;
    if (stringCompiler != null)
    {
        this.serializer = stringCompiler.Func(this.vw.Native);
    }

    this.vwNative = new VowpalWabbit<TExample>(settings);

    var nativeCompiler = this.vwNative.Serializer as VowpalWabbitSingleExampleSerializerCompiler<TExample>;
    if (nativeCompiler != null)
    {
        this.serializerNative = nativeCompiler.Func(this.vwNative.Native);
    }

    var factorySettings = settings.ShallowCopy(enableStringExampleGeneration: true);
    this.factorySerializer = VowpalWabbitSerializerFactory
        .CreateSerializer<TExample>(factorySettings)
        .Create(this.vw.Native);
}
/// <summary>
/// Resets the trainer: installs the given (or a new) state and initializes VW from the
/// configured train arguments, optionally seeded with a model.
/// </summary>
/// <param name="state">The state to start from; a new <c>OnlineTrainerState</c> is used when null.</param>
/// <param name="model">Optional model bytes, passed to VW as a stream.</param>
internal void FreshStart(OnlineTrainerState state = null, byte[] model = null)
{
    state = state ?? new OnlineTrainerState();

    this.telemetry.TrackTrace("Fresh Start", SeverityLevel.Information);

    // start from scratch
    this.state = state;

    // save extra state so learning can be resumed later with new data
    var arguments = "--save_resume --preserve_performance_counters " + this.settings.Metadata.TrainArguments;
    var vwSettings = new VowpalWabbitSettings(arguments);

    if (model != null)
    {
        vwSettings.ModelStream = new MemoryStream(model);
    }

    this.InitializeVowpalWabbit(vwSettings);
}
/// <summary>
/// Replays the events in <paramref name="eventOrder"/> through an offline VW instance,
/// re-saves both the offline and the downloaded online model through a plain VW instance,
/// and asserts that the two re-saved model files are byte-identical.
/// </summary>
/// <param name="message">Prefix for the assertion failure message.</param>
/// <param name="modelId">Value passed to VW via <c>--id</c>.</param>
/// <param name="data">Event contexts keyed by event id.</param>
/// <param name="eventOrder">The event ids in the order they are learned.</param>
/// <param name="onlineModelUri">Location of the online model to compare against.</param>
/// <param name="trainArguments">Train arguments; the instance-level arguments are used when null.</param>
internal void TrainOffline(string message, string modelId, Dictionary<string, Context> data, IEnumerable<string> eventOrder, Uri onlineModelUri, string trainArguments = null)
{
    // allow override
    trainArguments = trainArguments ?? this.trainArguments;

    // train model offline using trackback
    var offlineSettings = new VowpalWabbitSettings(trainArguments + $" --id {modelId} --save_resume --preserve_performance_counters -f offline.model");
    using (var offlineVw = new VowpalWabbitJson(offlineSettings))
    {
        foreach (var eventId in eventOrder)
        {
            var eventJson = data[eventId].JSON;
            var progressivePrediction = offlineVw.Learn(eventJson, VowpalWabbitPredictionType.ActionProbabilities);
            // TODO: validate eval output
        }
    }

    // load and immediately dispose: the command line makes VW write the readable and re-saved models
    using (var offlineResave = new VowpalWabbit("-i offline.model --save_resume --readable_model offline.model.txt -f offline.reset_perf_counters.model"))
    {
    }

    Blobs.DownloadFile(onlineModelUri, "online.model");

    using (var onlineResave = new VowpalWabbit("-i online.model --save_resume --readable_model online.model.txt -f online.reset_perf_counters.model"))
    {
    }

    // validate that the model is the same
    byte[] offlineBytes = File.ReadAllBytes("offline.reset_perf_counters.model");
    byte[] onlineBytes = File.ReadAllBytes("online.reset_perf_counters.model");

    CollectionAssert.AreEqual(
        offlineBytes,
        onlineBytes,
        $"{message}. Offline and online model differs. Compare online.model.txt with offline.model.txt to compare");
}
/// <summary>
/// Creates the schema for <paramref name="type"/>, accepting a feature or label property
/// only when the attribute passed to the predicate is non-null.
/// </summary>
/// <param name="settings">Not used by this implementation.</param>
/// <param name="type">The type to create the schema for.</param>
/// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
{
    var schema = TypeInspector.CreateSchema(
        type,
        featurePropertyPredicate: (property, attribute) => attribute != null,
        labelPropertyPredicate: (property, attribute) => attribute != null);

    return schema;
}
/// <summary>
/// Creates the schema for <paramref name="type"/> with predicates that accept every
/// feature and label property.
/// </summary>
/// <param name="settings">Not used by this implementation.</param>
/// <param name="type">The type to create the schema for.</param>
/// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
{
    var schema = TypeInspector.CreateSchema(
        type,
        featurePropertyPredicate: (property, attribute) => true,
        labelPropertyPredicate: (property, attribute) => true);

    return schema;
}
/// <summary>
/// Initializes a new instance of the <see cref="VowpalWabbit{TExample}"/> class by creating
/// a native <c>VowpalWabbit</c> from the settings and delegating to the constructor that accepts it.
/// </summary>
/// <param name="settings">Arguments passed to native instance.</param>
public VowpalWabbit(VowpalWabbitSettings settings)
    : this(new VowpalWabbit(settings))
{
}
/// <summary>
/// End-to-end test: starts the online trainer, sends 100 generated events through the
/// input event hub, waits for a checkpoint, retrains offline in trackback order and
/// asserts that the offline and online model files are byte-identical.
/// </summary>
public async Task TestAzureTrainer()
{
    var storageConnectionString = GetConfiguration("storageConnectionString");
    var inputEventHubConnectionString = GetConfiguration("inputEventHubConnectionString");
    var evalEventHubConnectionString = GetConfiguration("evalEventHubConnectionString");

    var trainArguments = "--cb_explore_adf --epsilon 0.2 -q ab";

    // register with AppInsights to collect exceptions
    var exceptions = RegisterAppInsightExceptionHook();

    // cleanup blobs
    var blobs = new ModelBlobs(storageConnectionString);
    await blobs.Cleanup();

    var data = GenerateData(100).ToDictionary(d => d.EventId);

    // start listening for event hub
    using (var trainProcesserHost = new LearnEventProcessorHost())
    {
        var trainerMetadata = new OnlineTrainerSettings
        {
            ApplicationID = "vwunittest",
            TrainArguments = trainArguments
        };

        var trainerSettings = new OnlineTrainerSettingsInternal
        {
            CheckpointPolicy = new CountingCheckpointPolicy(data.Count),
            JoinedEventHubConnectionString = inputEventHubConnectionString,
            EvalEventHubConnectionString = evalEventHubConnectionString,
            StorageConnectionString = storageConnectionString,
            Metadata = trainerMetadata,
            EnableExampleTracing = true,
            EventHubStartDateTimeUtc = DateTime.UtcNow // ignore any events that arrived before this time
        };

        await trainProcesserHost.StartAsync(trainerSettings);

        // send events to event hub
        var eventHubInputClient = EventHubClient.CreateFromConnectionString(inputEventHubConnectionString);
        data.Values.ForEach(c => eventHubInputClient.Send(new EventData(c.JSONAsBytes) { PartitionKey = c.Index.ToString() }));

        // wait for trainer to checkpoint
        await blobs.PollTrainerCheckpoint(exceptions);

        // download & parse trackback file
        var trackback = blobs.DownloadTrackback();
        Assert.AreEqual(data.Count, trackback.EventIds.Count);

        // train model offline using trackback
        var offlineSettings = new VowpalWabbitSettings(trainArguments + $" --id {trackback.ModelId} --save_resume --readable_model offline.json.model.txt -f offline.json.model");
        using (var offlineVw = new VowpalWabbitJson(offlineSettings))
        {
            foreach (var eventId in trackback.EventIds)
            {
                var eventJson = data[eventId].JSON;
                var progressivePrediction = offlineVw.Learn(eventJson, VowpalWabbitPredictionType.ActionProbabilities);
                // TODO: validate eval output
            }

            offlineVw.Native.SaveModel("offline.json.2.model");
        }

        // download online model
        var onlineModelBlob = new CloudBlob(blobs.ModelBlob.Uri, blobs.BlobClient.Credentials);
        onlineModelBlob.DownloadToFile("online.model", FileMode.Create);

        // validate that the model is the same
        CollectionAssert.AreEqual(
            File.ReadAllBytes("offline.json.model"),
            File.ReadAllBytes("online.model"),
            "Offline and online model differs. Run to 'vw -i online.model --readable_model online.model.txt' to compare");
    }
}
/// <summary>
/// Creates the schema for <paramref name="type"/> using the JSON type inspector and the
/// property configuration from <paramref name="settings"/>.
/// </summary>
/// <param name="settings">Supplies <c>PropertyConfiguration</c>.</param>
/// <param name="type">The type to create the schema for.</param>
/// <returns>The schema produced by <c>JsonTypeInspector.CreateSchema</c>.</returns>
public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
{
    var propertyConfiguration = settings.PropertyConfiguration;
    return JsonTypeInspector.CreateSchema(type, propertyConfiguration);
}
/// <summary>
/// Trains the same 1000 generated examples three ways: two explicit AllReduce nodes
/// connected through a local spanning tree, a threaded learner fed with string examples,
/// and a threaded learner fed with managed examples. Asserts that performance statistics
/// and model bytes (after a 0x15-byte header) match the expected ones.
/// </summary>
internal async Task TestAllReduceInternal()
{
    var data = Enumerable.Range(1, 1000).Select(_ => Generator.GenerateShared(10)).ToList();

    var sharedCompiler = (VowpalWabbitSingleExampleSerializerCompiler<CbAdfShared>)
        VowpalWabbitSerializerFactory.CreateSerializer<CbAdfShared>(new VowpalWabbitSettings { EnableStringExampleGeneration = true });
    var actionCompiler = (VowpalWabbitSingleExampleSerializerCompiler<CbAdfAction>)
        VowpalWabbitSerializerFactory.CreateSerializer<CbAdfAction>(new VowpalWabbitSettings { EnableStringExampleGeneration = true });

    var stringData = new List<List<string>>();

    VowpalWabbitPerformanceStatistics statsExpected;
    using (var spanningTree = new SpanningTreeClr())
    {
        spanningTree.Start();

        using (var vw1 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 1 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy") { EnableStringExampleGeneration = true }))
        using (var vw2 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 0 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy") { EnableStringExampleGeneration = true }))
        {
            var serializeShared = sharedCompiler.Func(vw1);
            var serializeAction = actionCompiler.Func(vw1);

            // serialize
            foreach (var d in data)
            {
                var block = new List<string>();

                using (var sharedContext = new VowpalWabbitMarshalContext(vw1))
                {
                    serializeShared(sharedContext, d.Item1, SharedLabel.Instance);
                    block.Add(sharedContext.ToString());
                }

                // only the chosen action carries the label
                block.AddRange(d.Item2.Select((a, i) =>
                {
                    using (var actionContext = new VowpalWabbitMarshalContext(vw1))
                    {
                        serializeAction(actionContext, a, i == d.Item3.Action ? d.Item3 : null);
                        return actionContext.ToString();
                    }
                }));

                stringData.Add(block);
            }

            await Task.WhenAll(
                Task.Factory.StartNew(() => Ingest(vw1, stringData.Take(500))),
                Task.Factory.StartNew(() => Ingest(vw2, stringData.Skip(500))));

            vw1.SaveModel("expected.1.model");
            vw2.SaveModel("expected.2.model");

            statsExpected = vw1.PerformanceStatistics;
        }
    }

    // skip header
    var expected1Model = File.ReadAllBytes("expected.1.model").Skip(0x15).ToList();
    var expected2Model = File.ReadAllBytes("expected.2.model").Skip(0x15).ToList();

    var settings = new VowpalWabbitSettings("--cb_adf --rank_all --interact xy")
    {
        ParallelOptions = new ParallelOptions { MaxDegreeOfParallelism = 2 },
        ExampleCountPerRun = 2000,
        ExampleDistribution = VowpalWabbitExampleDistribution.RoundRobin
    };

    // threaded learner fed with the string examples
    using (var threaded = new VowpalWabbitThreadedLearning(settings))
    {
        await Task.WhenAll(
            Task.Factory.StartNew(() => Ingest(threaded, stringData.Take(500))),
            Task.Factory.StartNew(() => Ingest(threaded, stringData.Skip(500))));

        // important to enqueue the request before Complete() is called
        var statsTask = threaded.PerformanceStatistics;
        var modelSave = threaded.SaveModel("actual.model");

        await threaded.Complete();

        var statsActual = await statsTask;
        VWTestHelper.AssertEqual(statsExpected, statsActual);

        await modelSave;

        // skip header
        var actualModel = File.ReadAllBytes("actual.model").Skip(0x15).ToList();

        CollectionAssert.AreEqual(expected1Model, actualModel);
        CollectionAssert.AreEqual(expected2Model, actualModel);
    }

    // threaded learner fed with the managed examples
    using (var threaded = new VowpalWabbitThreadedLearning(settings))
    {
        var managed = threaded.Create<CbAdfShared, CbAdfAction>();

        await Task.WhenAll(
            Task.Factory.StartNew(() => Ingest(managed, data.Take(500))),
            Task.Factory.StartNew(() => Ingest(managed, data.Skip(500))));

        // important to enqueue the request before Complete() is called
        var statsTask = threaded.PerformanceStatistics;
        var modelSave = threaded.SaveModel("actual.managed.model");

        await threaded.Complete();

        var statsActual = await statsTask;
        VWTestHelper.AssertEqual(statsExpected, statsActual);

        await modelSave;

        // skip header
        var actualModel = File.ReadAllBytes("actual.managed.model").Skip(0x15).ToList();

        CollectionAssert.AreEqual(expected1Model, actualModel);
        CollectionAssert.AreEqual(expected2Model, actualModel);
    }
}
/// <summary>
/// Creates the schema for <paramref name="type"/> with predicates that accept every
/// feature and label property.
/// </summary>
/// <param name="settings">Not used by this implementation.</param>
/// <param name="type">The type to create the schema for.</param>
/// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
{
    return TypeInspector.CreateSchema(
        type,
        featurePropertyPredicate: (property, attribute) => true,
        labelPropertyPredicate: (property, attribute) => true);
}
/// <summary>
/// Creates the validator's VW instance from a copy of the settings with string example
/// generation enabled, plus a JSON serializer bound to that instance.
/// </summary>
/// <param name="settings">Base settings; a shallow copy is used.</param>
internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
{
    var stringSettings = settings.ShallowCopy(enableStringExampleGeneration: true);

    this.vw = new VowpalWabbit(stringSettings);
    this.jsonSerializer = new VowpalWabbitJsonSerializer(this.vw);
}
/// <summary>
/// Tries to restore VW from the model blob named in the trainer state. If the expanded
/// arguments of the loaded model differ from those produced by the configured train
/// arguments, the loaded model is discarded and VW is re-initialized from the train arguments.
/// </summary>
/// <returns>
/// <c>false</c> when no model name is set or the storage container or model blob does not exist;
/// otherwise <c>true</c>.
/// </returns>
private async Task<bool> TryLoadModel()
{
    // find the model blob
    if (string.IsNullOrEmpty(this.state.ModelName))
    {
        this.telemetry.TrackTrace("Model not specified");
        return false;
    }

    var modelContainer = this.blobClient.GetContainerReference(this.settings.StorageContainerName);
    bool containerExists = await modelContainer.ExistsAsync();
    if (!containerExists)
    {
        this.telemetry.TrackTrace($"Storage container missing '{this.settings.StorageContainerName}'");
        return false;
    }

    var blob = modelContainer.GetBlockBlobReference(this.state.ModelName);
    bool blobExists = await blob.ExistsAsync();
    if (!blobExists)
    {
        this.telemetry.TrackTrace($"Model blob '{this.state.ModelName}' is missing");
        return false;
    }

    // load the model
    using (var modelStream = await blob.OpenReadAsync())
    {
        this.InitializeVowpalWabbit(new VowpalWabbitSettings { ModelStream = modelStream });
        this.telemetry.TrackTrace($"Model loaded {this.state.ModelName}", SeverityLevel.Verbose);

        // validate that the loaded VW model has the same settings as requested by C&C
        var requestedSettings = new VowpalWabbitSettings(this.settings.Metadata.TrainArguments);
        using (var requestedVw = new VW.VowpalWabbit(requestedSettings))
        {
            requestedVw.ID = this.vw.ID;

            // save the VW instance to a model and load again to get fully expanded parameters.
            string requestedArguments;
            using (var roundTrip = new MemoryStream())
            {
                requestedVw.SaveModel(roundTrip);
                roundTrip.Position = 0;

                using (var reloadedVw = new VW.VowpalWabbit(new VowpalWabbitSettings { ModelStream = roundTrip }))
                {
                    requestedArguments = CleanVowpalWabbitArguments(reloadedVw.Arguments.CommandLine);
                }
            }

            // this is the expanded command line
            var loadedArguments = CleanVowpalWabbitArguments(this.vw.Arguments.CommandLine);

            if (requestedArguments != loadedArguments)
            {
                this.telemetry.TrackTrace(
                    "New VowpalWabbit settings found. Discarding existing model",
                    SeverityLevel.Information,
                    new Dictionary<string, string>
                    {
                        { "TrainArguments", requestedVw.Arguments.CommandLine },
                        { "NewExpandedArguments", requestedArguments },
                        { "OldExpandedArgumentsCleaned", loadedArguments },
                        { "OldExpandedArguments", this.vw.Arguments.CommandLine },
                    });

                // discard old, use fresh
                this.vw.Dispose();
                this.vw = null;

                this.InitializeVowpalWabbit(requestedSettings);
            }
        }
    }

    // store the initial model
    this.settings.InitialVowpalWabbitModel = this.state.ModelName;

    return true;
}
/// <summary>
/// Applies the trainer's fixed options to <paramref name="vwSettings"/>, creates the VW
/// instance, verifies it is a cb_explore / cb_explore_adf model and creates the JSON
/// reference resolver.
/// </summary>
/// <param name="vwSettings">Settings for the new VW instance; modified in place.</param>
/// <exception cref="ArgumentException">
/// The command line of the created instance contains neither "--cb_explore" nor "--cb_explore_adf".
/// </exception>
private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
{
    if (this.settings.EnableExampleTracing)
    {
        vwSettings.EnableStringExampleGeneration = true;
        vwSettings.EnableStringFloatCompact = true;
    }

    vwSettings.EnableThreadSafeExamplePooling = true;
    vwSettings.MaxExamples = 64 * 1024;

    try
    {
        this.vw = new VW.VowpalWabbit(vwSettings);

        var cmdLine = vw.Arguments.CommandLine;
        if (!(cmdLine.Contains("--cb_explore") || cmdLine.Contains("--cb_explore_adf")))
        {
            throw new ArgumentException("Only cb_explore and cb_explore_adf are supported");
        }
    }
    catch (Exception ex)
    {
        this.telemetry.TrackException(
            ex,
            new Dictionary<string, string>
            {
                { "help", "Invalid model. For help go to https://github.com/JohnLangford/vowpal_wabbit/wiki/Azure-Trainer" }
            });

        // Rethrow with "throw;" so the original stack trace is kept ("throw ex;" would reset it).
        throw;
    }

    this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
        this.delayedExampleCallback,
        cacheRequestItemPolicyFactory:
            key => new CacheItemPolicy()
            {
                SlidingExpiration = TimeSpan.FromHours(1),
                RemovedCallback = this.CacheEntryRemovedCallback
            });

    //this.vwAllReduce = new VowpalWabbitThreadedLearning(vwSettings.ShallowCopy(
    //    maxExampleQueueLengthPerInstance: 4*1024,
    //    parallelOptions: new ParallelOptions
    //    {
    //        MaxDegreeOfParallelism = 2,
    //    },
    //    exampleDistribution: VowpalWabbitExampleDistribution.RoundRobin,
    //    exampleCountPerRun: 128 * 1024));
}
/// <summary>
/// Creates a serializer for the given type and settings.
/// </summary>
/// <typeparam name="TExample">The user type to serialize.</typeparam>
/// <param name="settings">
/// The serializer settings. When null (the default), a new default
/// <see cref="VowpalWabbitSettings"/> instance is used.
/// </param>
/// <returns>
/// A multi-example compiler when one can be created for the schema, otherwise a
/// single-example compiler; <c>null</c> when the schema has no features.
/// </returns>
public static IVowpalWabbitSerializerCompiler<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings = null)
{
    // The parameter is optional, but the code below reads several settings properties
    // unconditionally. Normalise null up front so the default call no longer throws
    // a NullReferenceException.
    if (settings == null)
    {
        settings = new VowpalWabbitSettings();
    }

    Schema schema = null;
    Type cacheKey = null;

    if (settings.Schema != null)
    {
        schema = settings.Schema;
    }
    else
    {
        // only cache non-string generating serializer
        if (!settings.EnableStringExampleGeneration)
        {
            cacheKey = typeof(TExample);

            object serializer;
            if (SerializerCache.TryGetValue(cacheKey, out serializer))
            {
                return (IVowpalWabbitSerializerCompiler<TExample>)serializer;
            }
        }

        ITypeInspector typeInspector = settings.TypeInspector;
        if (typeInspector == null)
        {
            typeInspector = TypeInspector.Default;
        }

        // TODO: enhance caching based on feature list & featurizer set
        // if no feature mapping is provided, use [Feature] annotation on provided type.
        schema = typeInspector.CreateSchema(settings, typeof(TExample));

        var multiExampleSerializerCompiler = VowpalWabbitMultiExampleSerializerCompiler.TryCreate<TExample>(settings, schema);
        if (multiExampleSerializerCompiler != null)
        {
            return multiExampleSerializerCompiler;
        }
    }

    // need at least a single feature to do something sensible
    if (schema == null || schema.Features.Count == 0)
    {
        return null;
    }

    var newSerializer = new VowpalWabbitSingleExampleSerializerCompiler<TExample>(
        schema,
        settings.CustomFeaturizer,
        !settings.EnableStringExampleGeneration);

    if (cacheKey != null)
    {
        SerializerCache[cacheKey] = newSerializer;
    }

    return newSerializer;
}
/// <summary>
/// Initializes a new instance of the <see cref="VowpalWabbitThreadedLearning"/> class.
/// Creates one VW instance and one single-threaded action block per degree of parallelism;
/// instance 0 is the root, every other instance references it through its settings.
/// </summary>
/// <param name="settings">Common settings used for vw instances.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="settings"/> or <c>settings.ParallelOptions</c> is null.
/// </exception>
public VowpalWabbitThreadedLearning(VowpalWabbitSettings settings)
{
    if (settings == null)
    {
        throw new ArgumentNullException("settings");
    }

    if (settings.ParallelOptions == null)
    {
        // The single-argument constructor takes the parameter NAME, so the explanatory
        // text has to go through the (paramName, message) overload.
        throw new ArgumentNullException("settings", "settings.ParallelOptions must be set");
    }
    Contract.EndContractBlock();

    this.Settings = settings;

    // Note: CancellationToken is a struct, so it can never be null; the former
    // "CancellationToken == null" fallback was unreachable and has been removed.

    switch (this.Settings.ExampleDistribution)
    {
        case VowpalWabbitExampleDistribution.UniformRandom:
            this.exampleDistributor = _ => this.random.Next(this.observers.Length);
            break;
        case VowpalWabbitExampleDistribution.RoundRobin:
            this.exampleDistributor = localExampleCount => (int)(localExampleCount % this.observers.Length);
            break;
    }

    this.exampleCount = 0;
    this.syncActions = new ConcurrentList<Action<VowpalWabbit>>();

    int nodeCount = settings.ParallelOptions.MaxDegreeOfParallelism;
    this.vws = new VowpalWabbit[nodeCount];
    this.actionBlocks = new ActionBlock<Action<VowpalWabbit>>[nodeCount];

    // setup AllReduce chain
    for (int i = 0; i < nodeCount; i++)
    {
        var nodeSettings = (VowpalWabbitSettings)settings.Clone();
        if (i > 0)
        {
            // every non-root node points at the root instance
            nodeSettings.Root = this.vws[0];
        }
        nodeSettings.Node = (uint)i;

        // declared inside the loop so each action block captures its own instance
        var vw = this.vws[i] = new VowpalWabbit(nodeSettings);

        this.actionBlocks[i] = new ActionBlock<Action<VowpalWabbit>>(
            action => action(vw),
            new ExecutionDataflowBlockOptions
            {
                MaxDegreeOfParallelism = 1,
                TaskScheduler = settings.ParallelOptions.TaskScheduler,
                CancellationToken = settings.ParallelOptions.CancellationToken,
                BoundedCapacity = (int)settings.MaxExampleQueueLengthPerInstance
            });
    }

    // get observers to allow for blocking calls
    this.observers = this.actionBlocks.Select(ab => ab.AsObserver()).ToArray();

    this.completionTasks = new Task[nodeCount];

    // root closure
    {
        var vw = this.vws[0];

        this.completionTasks[0] = this.actionBlocks[0].Completion
            .ContinueWith(_ =>
            {
                // perform final AllReduce
                vw.EndOfPass();

                // execute synchronization actions
                foreach (var syncAction in this.syncActions.RemoveAll())
                {
                    syncAction(vw);
                }
            });
    }

    for (int i = 1; i < this.vws.Length; i++)
    {
        // perform final AllReduce
        var vw = this.vws[i];

        this.completionTasks[i] = this.actionBlocks[i].Completion
            .ContinueWith(_ => vw.EndOfPass(), this.Settings.ParallelOptions.CancellationToken);
    }
}
/// <summary>
/// Initializes a new instance of the <see cref="VowpalWabbitDynamic"/> class with a native
/// VW instance and empty per-type lookup tables for serializers and serialize methods.
/// </summary>
/// <param name="settings">Arguments passed to native instance.</param>
public VowpalWabbitDynamic(VowpalWabbitSettings settings)
{
    // per-type lookups start out empty
    this.serializeMethods = new Dictionary<Type, MethodInfo>();
    this.serializers = new Dictionary<Type, IDisposable>();

    this.vw = new VowpalWabbit(settings);
}
/// <summary>
/// Creates a multi-example serializer compiler when the schema has a feature whose name
/// equals the configured multi property.
/// </summary>
/// <typeparam name="TExample">The user example type.</typeparam>
/// <param name="settings">The serializer settings; supplies <c>PropertyConfiguration.MultiProperty</c>.</param>
/// <param name="schema">The schema of <typeparamref name="TExample"/>.</param>
/// <returns>The compiler, or <c>null</c> when the schema has no such feature.</returns>
/// <exception cref="ArgumentException">The multi feature's type yields no enumerable element type.</exception>
public static IVowpalWabbitSerializerCompiler<TExample> TryCreate<TExample>(VowpalWabbitSettings settings, Schema schema)
{
    // check for _multi
    var multi = schema.Features.FirstOrDefault(fe => fe.Name == settings.PropertyConfiguration.MultiProperty);
    if (multi == null)
    {
        return null;
    }

    // multi example path
    // IEnumerable<> or Array
    var elementType = InspectionHelper.GetEnumerableElementType(multi.FeatureType);
    if (elementType == null)
    {
        throw new ArgumentException(settings.PropertyConfiguration.MultiProperty + " property must be array or IEnumerable<>. Actual type: " + multi.FeatureType);
    }

    // close the generic implementation over (TExample, element type) and instantiate it via reflection
    var openCompilerType = typeof(VowpalWabbitMultiExampleSerializerCompilerImpl<,>);
    var closedCompilerType = openCompilerType.MakeGenericType(typeof(TExample), elementType);

    var instance = Activator.CreateInstance(closedCompilerType, settings, schema, multi);
    return (IVowpalWabbitSerializerCompiler<TExample>)instance;
}
/// <summary>
/// Creates the prediction pool for <c>TContext</c> from the given settings.
/// </summary>
/// <param name="settings">The settings passed to the new pool.</param>
/// <returns>A new <c>VowpalWabbitThreadedPrediction&lt;TContext&gt;</c>.</returns>
protected override VowpalWabbitThreadedPredictionBase<VowpalWabbit<TContext>> CreatePool(VowpalWabbitSettings settings)
{
    var pool = new VowpalWabbitThreadedPrediction<TContext>(settings);
    return pool;
}
/// <summary>
/// Creates the validator's VW instance from a copy of the settings with string example
/// generation enabled.
/// </summary>
/// <param name="settings">Base settings; a shallow copy is used.</param>
internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
{
    var stringSettings = settings.ShallowCopy(enableStringExampleGeneration: true);
    this.vw = new VowpalWabbit(stringSettings);
}
/// <summary>
/// Trains the same 1000 generated examples three ways: two explicit AllReduce nodes
/// connected through a local spanning tree, a threaded learner fed with string examples,
/// and a threaded learner fed with managed examples. Asserts that performance statistics
/// and model bytes (after a 0x15-byte header) match the expected ones.
/// </summary>
internal async Task TestAllReduceInternal()
{
    var data = Enumerable.Range(1, 1000).Select(_ => Generator.GenerateShared(10)).ToList();

    var sharedCompiled = VowpalWabbitSerializerFactory.CreateSerializer<CbAdfShared>(new VowpalWabbitSettings(enableStringExampleGeneration: true));
    var actionCompiled = VowpalWabbitSerializerFactory.CreateSerializer<CbAdfAction>(new VowpalWabbitSettings(enableStringExampleGeneration: true));

    var stringData = new List<List<string>>();

    VowpalWabbitPerformanceStatistics statsExpected;
    using (var spanningTree = new SpanningTreeClr())
    {
        spanningTree.Start();

        using (var vw1 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 1 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy", enableStringExampleGeneration: true)))
        using (var vw2 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 0 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy", enableStringExampleGeneration: true)))
        {
            var serializeShared = sharedCompiled.Func(vw1);
            var serializeAction = actionCompiled.Func(vw1);

            // serialize
            foreach (var d in data)
            {
                var block = new List<string>();

                using (var sharedContext = new VowpalWabbitMarshalContext(vw1))
                {
                    serializeShared(sharedContext, d.Item1, SharedLabel.Instance);
                    block.Add(sharedContext.StringExample.ToString());
                }

                // only the chosen action carries the label
                block.AddRange(d.Item2.Select((a, i) =>
                {
                    using (var actionContext = new VowpalWabbitMarshalContext(vw1))
                    {
                        serializeAction(actionContext, a, i == d.Item3.Action ? d.Item3 : null);
                        return actionContext.StringExample.ToString();
                    }
                }));

                stringData.Add(block);
            }

            await Task.WhenAll(
                Task.Factory.StartNew(() => Ingest(vw1, stringData.Take(500))),
                Task.Factory.StartNew(() => Ingest(vw2, stringData.Skip(500))));

            vw1.SaveModel("expected.1.model");
            vw2.SaveModel("expected.2.model");

            statsExpected = vw1.PerformanceStatistics;
        }
    }

    // skip header
    var expected1Model = File.ReadAllBytes("expected.1.model").Skip(0x15).ToList();
    var expected2Model = File.ReadAllBytes("expected.2.model").Skip(0x15).ToList();

    var settings = new VowpalWabbitSettings(
        "--cb_adf --rank_all --interact xy",
        parallelOptions: new ParallelOptions { MaxDegreeOfParallelism = 2 },
        exampleCountPerRun: 2000,
        exampleDistribution: VowpalWabbitExampleDistribution.RoundRobin);

    // threaded learner fed with the string examples
    using (var threaded = new VowpalWabbitThreadedLearning(settings))
    {
        await Task.WhenAll(
            Task.Factory.StartNew(() => Ingest(threaded, stringData.Take(500))),
            Task.Factory.StartNew(() => Ingest(threaded, stringData.Skip(500))));

        // important to enqueue the request before Complete() is called
        var statsTask = threaded.PerformanceStatistics;
        var modelSave = threaded.SaveModel("actual.model");

        await threaded.Complete();

        var statsActual = await statsTask;
        VWTestHelper.AssertEqual(statsExpected, statsActual);

        await modelSave;

        // skip header
        var actualModel = File.ReadAllBytes("actual.model").Skip(0x15).ToList();

        CollectionAssert.AreEqual(expected1Model, actualModel);
        CollectionAssert.AreEqual(expected2Model, actualModel);
    }

    // threaded learner fed with the managed examples
    using (var threaded = new VowpalWabbitThreadedLearning(settings))
    {
        var managed = threaded.Create<CbAdfShared, CbAdfAction>();

        await Task.WhenAll(
            Task.Factory.StartNew(() => Ingest(managed, data.Take(500))),
            Task.Factory.StartNew(() => Ingest(managed, data.Skip(500))));

        // important to enqueue the request before Complete() is called
        var statsTask = threaded.PerformanceStatistics;
        var modelSave = threaded.SaveModel("actual.managed.model");

        await threaded.Complete();

        var statsActual = await statsTask;
        VWTestHelper.AssertEqual(statsExpected, statsActual);

        await modelSave;

        // skip header
        var actualModel = File.ReadAllBytes("actual.managed.model").Skip(0x15).ToList();

        CollectionAssert.AreEqual(expected1Model, actualModel);
        CollectionAssert.AreEqual(expected2Model, actualModel);
    }
}
/// <summary>
/// Tries to restore the VW instance from the model blob named in the current state.
/// After loading, the expanded command line of the loaded model is compared (after cleaning)
/// with the one produced by the requested train arguments; on mismatch the loaded instance
/// is disposed and a fresh instance is created from the requested arguments.
/// </summary>
/// <returns>
/// <c>false</c> when no model name is set or the storage container / model blob does not
/// exist; <c>true</c> otherwise.
/// </returns>
private async Task<bool> TryLoadModel()
{
    // without a model name there is nothing to look up
    if (string.IsNullOrEmpty(this.state.ModelName))
    {
        this.telemetry.TrackTrace("Model not specified");
        return false;
    }

    var modelContainer = this.blobClient.GetContainerReference(this.settings.StorageContainerName);
    var containerExists = await modelContainer.ExistsAsync();
    if (!containerExists)
    {
        this.telemetry.TrackTrace($"Storage container missing '{this.settings.StorageContainerName}'");
        return false;
    }

    var blob = modelContainer.GetBlockBlobReference(this.state.ModelName);
    var blobExists = await blob.ExistsAsync();
    if (!blobExists)
    {
        this.telemetry.TrackTrace($"Model blob '{this.state.ModelName}' is missing");
        return false;
    }

    // load the model from the blob
    using (var blobStream = await blob.OpenReadAsync())
    {
        var loadSettings = new VowpalWabbitSettings { ModelStream = blobStream };
        this.InitializeVowpalWabbit(loadSettings);
        this.telemetry.TrackTrace($"Model loaded {this.state.ModelName}", SeverityLevel.Verbose);

        // validate that the loaded VW model has the same settings as requested by C&C
        var requestedSettings = new VowpalWabbitSettings(this.settings.Metadata.TrainArguments);
        using (var requestedVW = new VW.VowpalWabbit(requestedSettings))
        {
            requestedVW.ID = this.vw.ID;

            // round-trip the requested instance through a saved model so that its
            // command line is fully expanded before comparing
            string requestedArguments;
            using (var scratchStream = new MemoryStream())
            {
                requestedVW.SaveModel(scratchStream);
                scratchStream.Position = 0;

                var reloadSettings = new VowpalWabbitSettings { ModelStream = scratchStream };
                using (var reloadedVW = new VW.VowpalWabbit(reloadSettings))
                {
                    requestedArguments = CleanVowpalWabbitArguments(reloadedVW.Arguments.CommandLine);
                }
            }

            // the loaded model already carries the expanded command line
            var loadedArguments = CleanVowpalWabbitArguments(this.vw.Arguments.CommandLine);

            if (requestedArguments != loadedArguments)
            {
                var traceProperties = new Dictionary<string, string>
                {
                    { "TrainArguments", requestedVW.Arguments.CommandLine },
                    { "NewExpandedArguments", requestedArguments },
                    { "OldExpandedArgumentsCleaned", loadedArguments },
                    { "OldExpandedArguments", this.vw.Arguments.CommandLine },
                };

                this.telemetry.TrackTrace(
                    "New VowpalWabbit settings found. Discarding existing model",
                    SeverityLevel.Information,
                    traceProperties);

                // drop the loaded instance and start over with the requested settings
                this.vw.Dispose();
                this.vw = null;

                this.InitializeVowpalWabbit(requestedSettings);
            }
        }
    }

    // remember which model we started from
    this.settings.InitialVowpalWabbitModel = this.state.ModelName;

    return true;
}
/// <summary>
/// Initializes a new instance of the <see cref="VowpalWabbitExampleJsonValidator"/> class.
/// </summary>
/// <param name="settings">
/// Settings for the underlying VW instance. A clone of them is used, with string example
/// generation switched on.
/// </param>
internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
{
    // configure a copy rather than the object handed in
    var validatorSettings = (VowpalWabbitSettings)settings.Clone();
    validatorSettings.EnableStringExampleGeneration = true;

    this.vw = new VowpalWabbit(validatorSettings);
}
/// <summary>
/// Builds the schema for <paramref name="type"/> via <c>TypeInspector.CreateSchema</c>.
/// Both predicates accept a property only when the attribute passed alongside it is non-null.
/// </summary>
/// <param name="settings">Not consulted by this implementation.</param>
/// <param name="type">The type to inspect.</param>
/// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
{
    var schema = TypeInspector.CreateSchema(
        type,
        featurePropertyPredicate: (property, attribute) => attribute != null,
        labelPropertyPredicate: (property, attribute) => attribute != null);

    return schema;
}
/// <summary>
/// Creates a serializer for the given type and settings.
/// </summary>
/// <typeparam name="TExample">The user type to serialize.</typeparam>
/// <param name="settings">
/// Optional serializer settings. When <c>null</c>, features are discovered using
/// <see cref="VowpalWabbitFeatureDiscovery.Default"/>, no custom featurizer is passed and
/// string example generation is treated as disabled.
/// </param>
/// <returns>
/// A serializer for <typeparamref name="TExample"/> (taken from the cache when available),
/// or <c>null</c> if no features were found.
/// </returns>
public static VowpalWabbitSerializerCompiled<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings = null)
{
    // settings is optional: resolve the values read from it up front so a null argument
    // falls back to defaults instead of raising a NullReferenceException further down.
    var enableStringExampleGeneration = settings != null && settings.EnableStringExampleGeneration;
    var featureDiscovery = settings != null
        ? settings.FeatureDiscovery
        : VowpalWabbitFeatureDiscovery.Default;

    List<FeatureExpression> allFeatures = null;
    Type cacheKey = null;

    if (settings != null && settings.AllFeatures != null)
    {
        allFeatures = settings.AllFeatures;
    }
    else
    {
        // only cache non-string generating serializer
        if (!enableStringExampleGeneration)
        {
            cacheKey = typeof(TExample);

            object serializer;
            if (SerializerCache.TryGetValue(cacheKey, out serializer))
            {
                return (VowpalWabbitSerializerCompiled<TExample>)serializer;
            }
        }

        // TODO: enhance caching based on feature list & featurizer set
        // if no feature mapping is provided, use [Feature] annotation on provided type.
        Func<PropertyInfo, FeatureAttribute, bool> propertyPredicate = null;
        switch (featureDiscovery)
        {
            case VowpalWabbitFeatureDiscovery.Default:
                propertyPredicate = (_, attr) => attr != null;
                break;
            case VowpalWabbitFeatureDiscovery.All:
                propertyPredicate = (_, __) => true;
                break;
        }

        allFeatures = AnnotationInspector.ExtractFeatures(typeof(TExample), propertyPredicate).ToList();
    }

    // need at least a single feature to do something sensible
    if (allFeatures == null || allFeatures.Count == 0)
    {
        return null;
    }

    var newSerializer = new VowpalWabbitSerializerCompiled<TExample>(
        allFeatures,
        settings == null ? null : settings.CustomFeaturizer,
        !enableStringExampleGeneration);

    if (cacheKey != null)
    {
        SerializerCache[cacheKey] = newSerializer;
    }

    return newSerializer;
}
/// <summary>
/// Initializes a new instance of the <see cref="VowpalWabbitDynamic"/> class and creates
/// the <see cref="VowpalWabbit"/> instance it wraps.
/// </summary>
/// <param name="settings">Settings handed to the <see cref="VowpalWabbit"/> constructor.</param>
public VowpalWabbitDynamic(VowpalWabbitSettings settings)
{
    var nativeInstance = new VowpalWabbit(settings);
    this.vw = nativeInstance;
}
/// <summary>
/// Builds the schema for <paramref name="type"/> via <c>JsonTypeInspector.CreateSchema</c>,
/// using the property configuration carried by <paramref name="settings"/>.
/// </summary>
/// <param name="settings">Supplies the <c>PropertyConfiguration</c> passed to the inspector.</param>
/// <param name="type">The type to inspect.</param>
/// <returns>The schema produced by <c>JsonTypeInspector.CreateSchema</c>.</returns>
public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
{
    var propertyConfiguration = settings.PropertyConfiguration;

    return JsonTypeInspector.CreateSchema(type, propertyConfiguration);
}
/// <summary>
/// Sub classes must override and create a new VW pool.
/// </summary>
/// <param name="settings">
/// Settings for the pool to create.
/// NOTE(review): presumably the settings the pooled VW instances are built from; confirm against implementers.
/// </param>
/// <returns>The pool created by the sub class.</returns>
protected abstract VowpalWabbitThreadedPredictionBase <TVowpalWabbit> CreatePool(VowpalWabbitSettings settings);
/// <summary>
/// Creates a multi-example serializer compiler when <paramref name="allFeatures"/> contains
/// a feature named <c>VowpalWabbitConstants.MultiProperty</c>.
/// </summary>
/// <typeparam name="TExample">The user type to serialize.</typeparam>
/// <param name="settings">Forwarded to the compiler implementation's constructor.</param>
/// <param name="allFeatures">The features to search; also forwarded to the compiler.</param>
/// <returns>The compiler instance, or <c>null</c> when no multi feature is present.</returns>
/// <exception cref="ArgumentException">
/// The multi feature's type yields no enumerable element type.
/// </exception>
internal static IVowpalWabbitSerializerCompiler<TExample> TryCreate<TExample>(VowpalWabbitSettings settings, List<FeatureExpression> allFeatures)
{
    // look for the _multi feature; without it this is not the multi example path
    var multiProperty = allFeatures.FirstOrDefault(feature => feature.Name == VowpalWabbitConstants.MultiProperty);
    if (multiProperty == null)
    {
        return null;
    }

    // element type of the IEnumerable<> or array
    var elementType = InspectionHelper.GetEnumerableElementType(multiProperty.FeatureType);
    if (elementType != null)
    {
        // close the generic implementation over (TExample, element type) and instantiate it
        var openCompilerType = typeof(VowpalWabbitMultiExampleSerializerCompilerImpl<,>);
        var closedCompilerType = openCompilerType.MakeGenericType(typeof(TExample), elementType);
        var compiler = Activator.CreateInstance(closedCompilerType, settings, allFeatures, multiProperty);

        return (IVowpalWabbitSerializerCompiler<TExample>)compiler;
    }

    throw new ArgumentException("_multi property must be array or IEnumerable<>. Actual type: " + multiProperty.FeatureType);
}
/// <summary>
/// Creates a serializer for the given type and settings.
/// </summary>
/// <typeparam name="TExample">The user type to serialize.</typeparam>
/// <param name="settings">
/// Optional serializer settings. When <c>null</c>, the schema is discovered using
/// <see cref="VowpalWabbitFeatureDiscovery.Default"/>, no custom featurizer is passed and
/// string example generation is treated as disabled.
/// </param>
/// <returns>
/// A serializer compiler for <typeparamref name="TExample"/> (taken from the cache when
/// available), or <c>null</c> if the schema contains no features.
/// </returns>
public static IVowpalWabbitSerializerCompiler<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings = null)
{
    // settings is optional: resolve the values read from it up front so a null argument
    // falls back to defaults instead of raising a NullReferenceException further down.
    var enableStringExampleGeneration = settings != null && settings.EnableStringExampleGeneration;
    var featureDiscovery = settings != null
        ? settings.FeatureDiscovery
        : VowpalWabbitFeatureDiscovery.Default;

    Schema schema = null;
    Type cacheKey = null;

    if (settings != null && settings.Schema != null)
    {
        schema = settings.Schema;
    }
    else
    {
        // only cache non-string generating serializer
        if (!enableStringExampleGeneration)
        {
            cacheKey = typeof(TExample);

            object serializer;
            if (SerializerCache.TryGetValue(cacheKey, out serializer))
            {
                return (IVowpalWabbitSerializerCompiler<TExample>)serializer;
            }
        }

        if (featureDiscovery == VowpalWabbitFeatureDiscovery.Json)
        {
            // featureDiscovery can only be Json when settings is non-null (null maps to Default)
            schema = AnnotationJsonInspector.CreateSchema(typeof(TExample), settings.PropertyConfiguration);

            var multiExampleSerializerCompiler = VowpalWabbitMultiExampleSerializerCompiler.TryCreate<TExample>(settings, schema);
            if (multiExampleSerializerCompiler != null)
            {
                return multiExampleSerializerCompiler;
            }
        }
        else
        {
            // TODO: enhance caching based on feature list & featurizer set
            // if no feature mapping is provided, use [Feature] annotation on provided type.
            Func<PropertyInfo, FeatureAttribute, bool> propertyPredicate = null;
            Func<PropertyInfo, LabelAttribute, bool> labelPredicate = null;
            switch (featureDiscovery)
            {
                case VowpalWabbitFeatureDiscovery.Default:
                    propertyPredicate = (_, attr) => attr != null;
                    labelPredicate = (_, attr) => attr != null;
                    break;
                case VowpalWabbitFeatureDiscovery.All:
                    propertyPredicate = (_, __) => true;
                    labelPredicate = (_, __) => true;
                    break;
            }

            schema = AnnotationInspector.CreateSchema(typeof(TExample), propertyPredicate, labelPredicate);
        }
    }

    // need at least a single feature to do something sensible
    if (schema == null || schema.Features.Count == 0)
    {
        return null;
    }

    var newSerializer = new VowpalWabbitSingleExampleSerializerCompiler<TExample>(
        schema,
        settings == null ? null : settings.CustomFeaturizer,
        !enableStringExampleGeneration);

    if (cacheKey != null)
    {
        SerializerCache[cacheKey] = newSerializer;
    }

    return newSerializer;
}