Ejemplo n.º 1
0
        /// <summary>
        /// Builds a <see cref="VowpalWabbitSerializer{TExample}"/> for the user example type <typeparamref name="TExample"/>.
        /// </summary>
        /// <typeparam name="TExample">The user example type to serialize.</typeparam>
        /// <param name="settings">Serializer settings handed to the created serializer.</param>
        /// <returns>
        /// A serializer for <typeparamref name="TExample"/>. If no serialization function could be compiled
        /// (the type exposes no features), the wrapped function simply returns <c>null</c>.
        /// </returns>
        /// <remarks>
        /// In DEBUG builds every produced example is a <c>VowpalWabbitDebugExample</c> that pairs the native
        /// example with the output of the string-visitor serializer.
        /// </remarks>
        public static VowpalWabbitSerializer<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings)
        {
            var nativeFunc = CreateSerializer<TExample, VowpalWabbitInterfaceVisitor, VowpalWabbitExample>();

#if DEBUG
            var stringFunc = CreateSerializer<TExample, VowpalWabbitStringVisitor, string>();

            // Default: no features were found, so no serializer was generated and every call yields null.
            Func<VowpalWabbit, TExample, ILabel, VowpalWabbitExample> effectiveFunc = (a, b, c) => null;

            if (nativeFunc != null)
            {
                effectiveFunc = (vw, example, label) =>
                    new VowpalWabbitDebugExample(
                        nativeFunc(vw, example, label),
                        stringFunc(vw, example, label));
            }

            return new VowpalWabbitSerializer<TExample>(effectiveFunc, settings);
#else
            // No features were found, so no serializer was generated: substitute one that yields null.
            if (nativeFunc == null)
            {
                nativeFunc = (a, b, c) => null;
            }

            return new VowpalWabbitSerializer<TExample>(nativeFunc, settings);
#endif
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Creates the Vowpal Wabbit instance used for training and the JSON reference resolver.
        /// </summary>
        /// <param name="vwSettings">
        /// Settings for the new instance. They are mutated in place: string example generation (when example
        /// tracing is enabled), thread-safe example pooling and the example limit are set here.
        /// </param>
        private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
        {
            var tracingEnabled = this.settings.EnableExampleTracing;
            if (tracingEnabled)
            {
                // Example tracing: generate string examples using the compact float format.
                vwSettings.EnableStringExampleGeneration = true;
                vwSettings.EnableStringFloatCompact = true;
            }

            vwSettings.EnableThreadSafeExamplePooling = true;
            vwSettings.MaxExamples = 64 * 1024;

            this.vw = new VW.VowpalWabbit(vwSettings);

            // Cached request items use a one-hour sliding expiration; removals go to CacheEntryRemovedCallback.
            this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
                this.delayedExampleCallback,
                cacheRequestItemPolicyFactory: key =>
                {
                    var policy = new CacheItemPolicy();
                    policy.SlidingExpiration = TimeSpan.FromHours(1);
                    policy.RemovedCallback = this.CacheEntryRemovedCallback;
                    return policy;
                });
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Sample: learns one example through <see cref="VowpalWabbitThreadedLearning"/>, which is configured for up to
        /// 16 parallel instances with round-robin distribution. The model is then saved to "m1.model".
        /// </summary>
        /// <returns>A task that completes once learning has finished and the model save has been awaited.</returns>
        public static async Task MultiThreadedLearning()
        {
            var example = new MyExample { Income = 40, Age = 25 };
            var label = new SimpleLabel { Label = 1 };

            var parallelism = new ParallelOptions { MaxDegreeOfParallelism = 16 };
            var settings = new VowpalWabbitSettings
            {
                ParallelOptions = parallelism,
                ExampleCountPerRun = 2000,
                ExampleDistribution = VowpalWabbitExampleDistribution.RoundRobin
            };

            using (var threadedVw = new VowpalWabbitThreadedLearning(settings))
            {
                using (var typedVw = threadedVw.Create<MyExample>())
                {
                    // The scalar prediction returned by Learn is not needed in this sample.
                    await typedVw.Learn(example, label, VowpalWabbitPredictionType.Scalar);
                }

                // Request the model save before Complete(); the save task is awaited afterwards.
                var pendingSave = threadedVw.SaveModel("m1.model");

                await threadedVw.Complete();

                await pendingSave;
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// Trains a single example with multi-threaded Vowpal Wabbit and persists the resulting model as "m1.model".
        /// </summary>
        /// <returns>A task representing the whole learn / complete / save sequence.</returns>
        public static async Task MultiThreadedLearning()
        {
            var example = new MyExample { Income = 40, Age = 25 };
            var label = new SimpleLabel { Label = 1 };

            var settings = new VowpalWabbitSettings();
            settings.ParallelOptions = new ParallelOptions { MaxDegreeOfParallelism = 16 };
            settings.ExampleCountPerRun = 2000;
            settings.ExampleDistribution = VowpalWabbitExampleDistribution.RoundRobin;

            using (var learner = new VowpalWabbitThreadedLearning(settings))
            {
                // The typed wrapper is disposed before the save; only the threaded learner is used afterwards.
                using (var typedLearner = learner.Create<MyExample>())
                {
                    await typedLearner.Learn(example, label, VowpalWabbitPredictionType.Scalar);
                }

                var saveModelTask = learner.SaveModel("m1.model");
                await learner.Complete();
                await saveModelTask;
            }
        }
Ejemplo n.º 5
0
        /// <summary>
        /// Creates a validator that holds three things:
        /// a string-generating instance with its serializer, an instance using the caller's settings unchanged
        /// with its serializer, and a factory-built serializer bound to the string-generating native instance.
        /// </summary>
        /// <param name="settings">Base settings; shallow copies with string example generation enabled are used where noted.</param>
        internal VowpalWabbitExampleValidator(VowpalWabbitSettings settings)
        {
            // Instance that also renders examples as strings.
            var stringSettings = settings.ShallowCopy(enableStringExampleGeneration: true);
            this.vw = new VowpalWabbit<TExample>(stringSettings);
            this.serializer = this.vw.Serializer.Func(this.vw.Native);

            // Instance configured exactly as requested by the caller.
            this.vwNative = new VowpalWabbit<TExample>(settings);
            this.serializerNative = this.vwNative.Serializer.Func(this.vwNative.Native);

            // Serializer produced by the factory (from its own shallow copy), bound to the string-generating native instance.
            var factorySettings = settings.ShallowCopy(enableStringExampleGeneration: true);
            this.factorySerializer = VowpalWabbitSerializerFactory
                .CreateSerializer<TExample>(factorySettings)
                .Create(this.vw.Native);
        }
        /// <summary>
        /// Records the start time, creates the Vowpal Wabbit instance used for online training and
        /// builds the JSON reference resolver.
        /// </summary>
        /// <param name="vwSettings">
        /// Settings for the new instance. They are mutated in place: string example generation (when example
        /// tracing is enabled), thread-safe example pooling and the example limit are set here.
        /// </param>
        /// <exception cref="ArgumentException">
        /// The instance's command line contains neither <c>--cb_explore</c> nor <c>--cb_explore_adf</c>.
        /// </exception>
        /// <remarks>
        /// Any failure while creating or validating the instance is reported to telemetry and then rethrown with
        /// its original stack trace. On a validation failure <c>this.vw</c> still refers to the rejected instance.
        /// </remarks>
        private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
        {
            if (this.settings.EnableExampleTracing)
            {
                vwSettings.EnableStringExampleGeneration = true;
                vwSettings.EnableStringFloatCompact = true;
            }

            vwSettings.EnableThreadSafeExamplePooling = true;
            vwSettings.MaxExamples = 64 * 1024;

            try
            {
                this.startDateTime = DateTime.UtcNow;
                this.vw = new VW.VowpalWabbit(vwSettings);
                var cmdLine = vw.Arguments.CommandLine;

                // "--cb_explore" is a substring of "--cb_explore_adf", so this single check accepts both reductions.
                if (!cmdLine.Contains("--cb_explore"))
                {
                    throw new ArgumentException("Only cb_explore and cb_explore_adf are supported");
                }
            }
            catch (Exception ex)
            {
                this.telemetry.TrackException(ex, new Dictionary<string, string>
                {
                    { "help", "Invalid model. For help go to https://github.com/JohnLangford/vowpal_wabbit/wiki/Azure-Trainer" }
                });

                // 'throw;' keeps the original stack trace; 'throw ex;' would reset it to this line.
                throw;
            }

            // Cached request items use a one-hour sliding expiration; removals go to CacheEntryRemovedCallback.
            this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
                this.delayedExampleCallback,
                cacheRequestItemPolicyFactory:
                    key => new CacheItemPolicy()
                    {
                        SlidingExpiration = TimeSpan.FromHours(1),
                        RemovedCallback = this.CacheEntryRemovedCallback
                    });
        }
Ejemplo n.º 7
0
        /// <summary>
        /// Initializes a serializer around the compiled <paramref name="serializer"/> function. If
        /// <typeparamref name="TExample"/> carries a <see cref="CacheableAttribute"/> and example caching is enabled,
        /// an example cache is also prepared.
        /// </summary>
        /// <param name="serializer">Compiled serialization function; must not be null.</param>
        /// <param name="settings">Serializer settings; a default <see cref="VowpalWabbitSettings"/> is used when null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="serializer"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// The equality comparer type named by the <see cref="CacheableAttribute"/> does not implement
        /// <see cref="IEqualityComparer{T}"/> for <typeparamref name="TExample"/>.
        /// </exception>
        internal VowpalWabbitSerializer(Func<VowpalWabbit, TExample, ILabel, VowpalWabbitExample> serializer, VowpalWabbitSettings settings)
        {
            if (serializer == null)
            {
                throw new ArgumentNullException(nameof(serializer));
            }
            Contract.Ensures(this.settings != null);
            Contract.EndContractBlock();

            this.serializer = serializer;
            this.settings = settings ?? new VowpalWabbitSettings();

            var cacheable = (CacheableAttribute)typeof(TExample)
                .GetCustomAttributes(typeof(CacheableAttribute), true)
                .FirstOrDefault();

            // Caching applies only to types marked cacheable, and only when it is enabled in the settings.
            if (cacheable == null || !this.settings.EnableExampleCaching)
            {
                return;
            }

            var comparerType = cacheable.EqualityComparer;
            if (comparerType != null)
            {
                if (!typeof(IEqualityComparer<TExample>).IsAssignableFrom(comparerType))
                {
                    throw new ArgumentException(
                        string.Format(
                            CultureInfo.InvariantCulture,
                            "EqualityComparer ({1}) specified in [Cachable] of {0} must implement IEqualityComparer<{0}>",
                            typeof(TExample),
                            comparerType));
                }

                var comparer = (IEqualityComparer<TExample>)Activator.CreateInstance(comparerType);
                this.exampleCache = new Dictionary<TExample, CacheEntry>(comparer);
            }
            else
            {
                // No custom comparer: rely on the default equality of TExample.
                this.exampleCache = new Dictionary<TExample, CacheEntry>();
            }

#if DEBUG
            this.reverseLookup = new Dictionary<VowpalWabbitExample, CacheEntry>(new ReferenceEqualityComparer<VowpalWabbitExample>());
#endif
        }
        /// <summary>
        /// Initializes a new instance of the <see cref="VowpalWabbitThreadedPredictionBase{TVowpalWabbit}"/> class.
        /// </summary>
        /// <param name="settings">The initial settings; <c>settings.Model</c> is the source handed to the object pool's factory.</param>
        protected VowpalWabbitThreadedPredictionBase(VowpalWabbitSettings settings)
        {
            this.settings = settings;

            // The pool's factory maps the model to a child via CreateVowpalWabbitChild,
            // or to default(TVowpalWabbit) when there is no model.
            this.vwPool = new ObjectPool<VowpalWabbitModel, TVowpalWabbit>(
                ObjectFactory.Create(
                    settings.Model,
                    m =>
                    {
                        if (m != null)
                        {
                            return CreateVowpalWabbitChild(m);
                        }

                        return default(TVowpalWabbit);
                    }));
        }
Ejemplo n.º 9
0
            /// <summary>
            /// Prepares everything needed to serialize a shared-plus-action-dependent example:
            /// <list type="bullet">
            /// <item>a compiler for the shared (non-multi) features, or <c>null</c> if there are none;</item>
            /// <item>a compiler for <c>TActionDependentFeature</c>;</item>
            /// <item>an accessor that extracts the action-dependent features from an example.</item>
            /// </list>
            /// </summary>
            /// <param name="settings">Serializer settings. Must not be null.</param>
            /// <param name="schema">Schema of <c>TExample</c>. Must not be null.</param>
            /// <param name="multiFeature">The feature of <paramref name="schema"/> that yields the action-dependent features. Must not be null.</param>
            public VowpalWabbitMultiExampleSerializerCompilerImpl(VowpalWabbitSettings settings, Schema schema, FeatureExpression multiFeature)
            {
                Contract.Requires(settings != null);
                Contract.Requires(schema != null);
                Contract.Requires(multiFeature != null);

                var nonMultiFeatures = schema.Features.Where(fe => fe != multiFeature).ToList();

                // settings is required non-null and is dereferenced unconditionally, so CustomFeaturizer is read
                // directly (a "settings == null ? null : ..." guard here could never take its null branch).
                this.sharedSerializerCompiler = nonMultiFeatures.Count == 0
                    ? null
                    : new VowpalWabbitSingleExampleSerializerCompiler<TExample>(
                        new Schema { Features = nonMultiFeatures },
                        settings.CustomFeaturizer,
                        !settings.EnableStringExampleGeneration);

                this.adfSerializerComputer = new VowpalWabbitSingleExampleSerializerCompiler<TActionDependentFeature>(
                    settings.TypeInspector.CreateSchema(settings, typeof(TActionDependentFeature)),
                    settings.CustomFeaturizer,
                    !settings.EnableStringExampleGeneration);

                var exampleParameter = Expression.Parameter(typeof(TExample), "example");

                // CODE condition1 && condition2 && condition3 ...
                // An empty list of validity checks means the multi feature is always considered valid.
                var validityChecks = multiFeature.ValueValidExpressionFactories
                    .Select(factory => factory(exampleParameter))
                    .ToList();

                Expression condition = validityChecks.Count == 0
                    ? (Expression)Expression.Constant(true)
                    : validityChecks.Skip(1).Aggregate(validityChecks[0], (cond, check) => Expression.AndAlso(cond, check));

                var multiExpression = multiFeature.ValueExpressionFactory(exampleParameter);

                // CODE example => (IEnumerable<TActionDependentFeature>)(condition ? example._multi : null)
                var expr = Expression.Lambda<Func<TExample, IEnumerable<TActionDependentFeature>>>(
                    Expression.Condition(
                        condition,
                        multiExpression,
                        Expression.Constant(null, multiExpression.Type),
                        typeof(IEnumerable<TActionDependentFeature>)),
                    exampleParameter);

                this.adfAccessor = (Func<TExample, IEnumerable<TActionDependentFeature>>)expr.CompileToFunc();
            }
        /// <summary>
        /// Creates a validator with a string-generating instance and an instance using the caller's settings.
        /// Each instance's serializer function is captured only when its serializer is a
        /// <c>VowpalWabbitSingleExampleSerializerCompiler&lt;TExample&gt;</c>.
        /// </summary>
        /// <param name="settings">Base settings; shallow copies with string example generation enabled are used where noted.</param>
        internal VowpalWabbitExampleValidator(VowpalWabbitSettings settings)
        {
            // String-generating instance.
            this.vw = new VowpalWabbit<TExample>(settings.ShallowCopy(enableStringExampleGeneration: true));

            var stringCompiler = this.vw.Serializer as VowpalWabbitSingleExampleSerializerCompiler<TExample>;
            if (stringCompiler != null)
            {
                this.serializer = stringCompiler.Func(this.vw.Native);
            }

            // Instance using the caller's settings unchanged.
            this.vwNative = new VowpalWabbit<TExample>(settings);

            var nativeCompiler = this.vwNative.Serializer as VowpalWabbitSingleExampleSerializerCompiler<TExample>;
            if (nativeCompiler != null)
            {
                this.serializerNative = nativeCompiler.Func(this.vwNative.Native);
            }

            // Factory-built serializer bound to the string-generating native instance.
            var factorySettings = settings.ShallowCopy(enableStringExampleGeneration: true);
            this.factorySerializer = VowpalWabbitSerializerFactory
                .CreateSerializer<TExample>(factorySettings)
                .Create(this.vw.Native);
        }
        /// <summary>
        /// Starts training from scratch.
        /// </summary>
        /// <remarks>
        /// Installs <paramref name="state"/> (or a new <see cref="OnlineTrainerState"/>). It then re-creates the
        /// Vowpal Wabbit instance from the configured training arguments plus
        /// <c>--save_resume --preserve_performance_counters</c>.
        /// </remarks>
        /// <param name="state">State to install; a new <see cref="OnlineTrainerState"/> when null.</param>
        /// <param name="model">Serialized model bytes to load; when null no model stream is set.</param>
        internal void FreshStart(OnlineTrainerState state = null, byte[] model = null)
        {
            var initialState = state ?? new OnlineTrainerState();

            this.telemetry.TrackTrace("Fresh Start", SeverityLevel.Information);

            this.state = initialState;

            // --save_resume keeps the extra state needed to resume learning later with new data.
            var vwSettings = new VowpalWabbitSettings("--save_resume --preserve_performance_counters " + this.settings.Metadata.TrainArguments);

            if (model != null)
            {
                vwSettings.ModelStream = new MemoryStream(model);
            }

            this.InitializeVowpalWabbit(vwSettings);
        }
Ejemplo n.º 12
0
            /// <summary>
            /// Replays the events in <paramref name="eventOrder"/> through an offline learner. It then asserts that the
            /// re-saved offline model is byte-identical to the re-saved online model at <paramref name="onlineModelUri"/>.
            /// </summary>
            /// <param name="message">Prefix for the assertion failure message.</param>
            /// <param name="modelId">Model id passed to the offline learner via <c>--id</c>.</param>
            /// <param name="data">Events keyed by id; each entry supplies the JSON that is learned.</param>
            /// <param name="eventOrder">Event ids in the order they are replayed.</param>
            /// <param name="onlineModelUri">Location of the online model to compare against.</param>
            /// <param name="trainArguments">Training arguments; <c>this.trainArguments</c> is used when null.</param>
            internal void TrainOffline(string message, string modelId, Dictionary<string, Context> data, IEnumerable<string> eventOrder, Uri onlineModelUri, string trainArguments = null)
            {
                var arguments = trainArguments ?? this.trainArguments;

                // train model offline using trackback
                var offlineSettings = new VowpalWabbitSettings(arguments + $" --id {modelId} --save_resume --preserve_performance_counters -f offline.model");

                using (var offlineVw = new VowpalWabbitJson(offlineSettings))
                {
                    foreach (var eventId in eventOrder)
                    {
                        // TODO: validate eval output
                        offlineVw.Learn(data[eventId].JSON, VowpalWabbitPredictionType.ActionProbabilities);
                    }
                }

                // Load each model and write it back out (plus a readable text dump); the re-saved files are compared below.
                new VowpalWabbit("-i offline.model --save_resume --readable_model offline.model.txt -f offline.reset_perf_counters.model").Dispose();

                Blobs.DownloadFile(onlineModelUri, "online.model");
                new VowpalWabbit("-i online.model --save_resume --readable_model online.model.txt -f online.reset_perf_counters.model").Dispose();

                // validate that the model is the same
                CollectionAssert.AreEqual(
                    File.ReadAllBytes("offline.reset_perf_counters.model"),
                    File.ReadAllBytes("online.reset_perf_counters.model"),
                    $"{message}. Offline and online model differs. Compare online.model.txt with offline.model.txt to compare");
            }
Ejemplo n.º 13
0
 /// <summary>
 /// Creates the schema for <paramref name="type"/>. Only properties whose attribute argument is non-null are
 /// accepted as features or labels.
 /// </summary>
 /// <param name="settings">Unused by this inspector.</param>
 /// <param name="type">The type to inspect.</param>
 /// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
 public Schema CreateSchema(VowpalWabbitSettings settings, Type type) =>
     TypeInspector.CreateSchema(
         type,
         featurePropertyPredicate: (property, attribute) => attribute != null,
         labelPropertyPredicate: (property, attribute) => attribute != null);
Ejemplo n.º 14
0
 /// <summary>
 /// Creates the schema for <paramref name="type"/>, accepting every property as both feature and label candidate.
 /// </summary>
 /// <param name="settings">Unused by this inspector.</param>
 /// <param name="type">The type to inspect.</param>
 /// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
 public Schema CreateSchema(VowpalWabbitSettings settings, Type type) =>
     TypeInspector.CreateSchema(
         type,
         featurePropertyPredicate: (property, attribute) => true,
         labelPropertyPredicate: (property, attribute) => true);
Ejemplo n.º 15
0
 /// <summary>
 /// Initializes a new instance of the <see cref="VowpalWabbit{TExample}"/> class on top of a newly created
 /// native <see cref="VowpalWabbit"/> instance.
 /// </summary>
 /// <param name="settings">Settings used to construct the underlying native instance.</param>
 public VowpalWabbit(VowpalWabbitSettings settings) : this(new VowpalWabbit(settings)) { }
Ejemplo n.º 16
0
        /// <summary>
        /// End-to-end test of the Azure online trainer. It sends 100 generated events through the input event hub and
        /// waits for the trainer to checkpoint. It then replays the events listed in the trackback file through an
        /// offline <c>VowpalWabbitJson</c> learner and asserts that the offline model file is byte-identical to the
        /// online model blob.
        /// </summary>
        public async Task TestAzureTrainer()
        {
            var storageConnectionString       = GetConfiguration("storageConnectionString");
            var inputEventHubConnectionString = GetConfiguration("inputEventHubConnectionString");
            var evalEventHubConnectionString  = GetConfiguration("evalEventHubConnectionString");

            // same arguments are used by the online trainer (Metadata.TrainArguments) and the offline replay below
            var trainArguments = "--cb_explore_adf --epsilon 0.2 -q ab";

            // register with AppInsights to collect exceptions
            var exceptions = RegisterAppInsightExceptionHook();

            // cleanup blobs
            var blobs = new ModelBlobs(storageConnectionString);
            await blobs.Cleanup();

            // events keyed by EventId so the trackback order can be replayed
            var data = GenerateData(100).ToDictionary(d => d.EventId, d => d);

            // start listening for event hub
            using (var trainProcesserHost = new LearnEventProcessorHost())
            {
                // checkpoint after exactly data.Count events so the checkpointed model covers all sent events
                await trainProcesserHost.StartAsync(new OnlineTrainerSettingsInternal
                {
                    CheckpointPolicy = new CountingCheckpointPolicy(data.Count),
                    JoinedEventHubConnectionString = inputEventHubConnectionString,
                    EvalEventHubConnectionString   = evalEventHubConnectionString,
                    StorageConnectionString        = storageConnectionString,
                    Metadata = new OnlineTrainerSettings
                    {
                        ApplicationID  = "vwunittest",
                        TrainArguments = trainArguments
                    },
                    EnableExampleTracing     = true,
                    EventHubStartDateTimeUtc = DateTime.UtcNow // ignore any events that arrived before this time
                });

                // send events to event hub
                // NOTE(review): eventHubInputClient is never closed — confirm whether it should be.
                var eventHubInputClient = EventHubClient.CreateFromConnectionString(inputEventHubConnectionString);
                data.Values.ForEach(c => eventHubInputClient.Send(new EventData(c.JSONAsBytes)
                {
                    PartitionKey = c.Index.ToString()
                }));

                // wait for trainer to checkpoint
                await blobs.PollTrainerCheckpoint(exceptions);

                // download & parse trackback file
                var trackback = blobs.DownloadTrackback();
                Assert.AreEqual(data.Count, trackback.EventIds.Count);

                // train model offline using trackback
                // (-f offline.json.model is presumably written when the instance is disposed — TODO confirm)
                var settings = new VowpalWabbitSettings(trainArguments + $" --id {trackback.ModelId} --save_resume --readable_model offline.json.model.txt -f offline.json.model");
                using (var vw = new VowpalWabbitJson(settings))
                {
                    foreach (var id in trackback.EventIds)
                    {
                        var json = data[id].JSON;

                        var progressivePrediction = vw.Learn(json, VowpalWabbitPredictionType.ActionProbabilities);
                        // TODO: validate eval output
                    }

                    // NOTE(review): offline.json.2.model is saved but never compared below.
                    vw.Native.SaveModel("offline.json.2.model");
                }

                // download online model
                new CloudBlob(blobs.ModelBlob.Uri, blobs.BlobClient.Credentials).DownloadToFile("online.model", FileMode.Create);

                // validate that the model is the same
                CollectionAssert.AreEqual(
                    File.ReadAllBytes("offline.json.model"),
                    File.ReadAllBytes("online.model"),
                    "Offline and online model differs. Run to 'vw -i online.model --readable_model online.model.txt' to compare");
            }
        }
Ejemplo n.º 17
0
 /// <summary>
 /// Creates the schema for <paramref name="type"/> via <c>JsonTypeInspector</c>, passing along the property
 /// configuration from <paramref name="settings"/>.
 /// </summary>
 /// <param name="settings">Settings supplying <c>PropertyConfiguration</c>.</param>
 /// <param name="type">The type to inspect.</param>
 /// <returns>The schema produced by <c>JsonTypeInspector.CreateSchema</c>.</returns>
 public Schema CreateSchema(VowpalWabbitSettings settings, Type type) =>
     JsonTypeInspector.CreateSchema(type, settings.PropertyConfiguration);
Ejemplo n.º 18
0
        /// <summary>
        /// Verifies that <c>VowpalWabbitThreadedLearning</c> produces the same model bytes and performance statistics as
        /// two all-reduce nodes coordinated through a spanning tree. It checks this first when fed pre-serialized
        /// string examples, then when fed managed CbAdfShared/CbAdfAction examples.
        /// </summary>
        internal async Task TestAllReduceInternal()
        {
            // 1000 generated shared examples, each with 10 actions (per Generator.GenerateShared(10))
            var data = Enumerable.Range(1, 1000).Select(_ => Generator.GenerateShared(10)).ToList();

            // string-generating serializers, used to turn each example into VW text lines
            var stringSerializerCompiler = (VowpalWabbitSingleExampleSerializerCompiler<CbAdfShared>)
                VowpalWabbitSerializerFactory.CreateSerializer<CbAdfShared>(new VowpalWabbitSettings { EnableStringExampleGeneration = true });
            var stringSerializerAdfCompiler = (VowpalWabbitSingleExampleSerializerCompiler<CbAdfAction>)
                VowpalWabbitSerializerFactory.CreateSerializer<CbAdfAction>(new VowpalWabbitSettings { EnableStringExampleGeneration = true });

            // one block of lines per example: the shared line followed by one line per action
            var stringData = new List<List<string>>();

            VowpalWabbitPerformanceStatistics statsExpected;
            using (var spanningTree = new SpanningTreeClr())
            {
                spanningTree.Start();

                // reference run: two nodes (--total 2) each learning half of the data
                using (var vw1 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 1 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy") { EnableStringExampleGeneration = true }))
                using (var vw2 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 0 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy") { EnableStringExampleGeneration = true } ))
                {
                    var stringSerializer = stringSerializerCompiler.Func(vw1);
                    var stringSerializerAdf = stringSerializerAdfCompiler.Func(vw1);

                    // serialize
                    foreach (var d in data)
                    {
                        var block = new List<string>();

                        using (var context = new VowpalWabbitMarshalContext(vw1))
                        {
                            stringSerializer(context, d.Item1, SharedLabel.Instance);
                            block.Add(context.ToString());
                        }

                        // only the action whose index matches d.Item3.Action gets the label
                        block.AddRange(d.Item2.Select((a, i) =>
                            {
                                using (var context = new VowpalWabbitMarshalContext(vw1))
                                {
                                    stringSerializerAdf(context, a, i == d.Item3.Action ? d.Item3 : null);
                                    return context.ToString();
                                }
                            }));

                        stringData.Add(block);
                    }

                    // NOTE(review): if Ingest returns a Task, StartNew yields Task<Task> and WhenAll would not
                    // wait for the inner work — confirm Ingest is synchronous.
                    await Task.WhenAll(
                        Task.Factory.StartNew(() => Ingest(vw1, stringData.Take(500))),
                        Task.Factory.StartNew(() => Ingest(vw2, stringData.Skip(500))));

                    vw1.SaveModel("expected.1.model");
                    vw2.SaveModel("expected.2.model");

                    statsExpected = vw1.PerformanceStatistics;
                }
            }

            // skip header
            // NOTE(review): presumably the first 0x15 bytes hold data that may differ between runs — confirm.
            var expected1Model = File.ReadAllBytes("expected.1.model").Skip(0x15).ToList();
            var expected2Model = File.ReadAllBytes("expected.2.model").Skip(0x15).ToList();

            // threaded run: 2 instances with round-robin distribution, same reduction arguments as the reference
            var settings = new VowpalWabbitSettings("--cb_adf --rank_all --interact xy")
            {
                ParallelOptions = new ParallelOptions
                {
                    MaxDegreeOfParallelism = 2
                },
                ExampleCountPerRun = 2000,
                ExampleDistribution = VowpalWabbitExampleDistribution.RoundRobin
            };

            // variant 1: feed the pre-serialized string blocks
            using (var vw = new VowpalWabbitThreadedLearning(settings))
            {
                await Task.WhenAll(
                    Task.Factory.StartNew(() => Ingest(vw, stringData.Take(500))),
                    Task.Factory.StartNew(() => Ingest(vw, stringData.Skip(500))));

                // important to enqueue the request before Complete() is called
                var statsTask = vw.PerformanceStatistics;
                var modelSave = vw.SaveModel("actual.model");

                await vw.Complete();

                var statsActual = await statsTask;
                VWTestHelper.AssertEqual(statsExpected, statsActual);

                await modelSave;

                // skip header
                var actualModel = File.ReadAllBytes("actual.model").Skip(0x15).ToList();

                CollectionAssert.AreEqual(expected1Model, actualModel);
                CollectionAssert.AreEqual(expected2Model, actualModel);
            }

            // variant 2: feed the managed examples through a typed wrapper
            using (var vw = new VowpalWabbitThreadedLearning(settings))
            {
                // NOTE(review): vwManaged is never disposed — confirm the wrapper needs no disposal.
                var vwManaged = vw.Create<CbAdfShared, CbAdfAction>();

                await Task.WhenAll(
                    Task.Factory.StartNew(() => Ingest(vwManaged, data.Take(500))),
                    Task.Factory.StartNew(() => Ingest(vwManaged, data.Skip(500))));

                // important to enqueue the request before Complete() is called
                var statsTask = vw.PerformanceStatistics;
                var modelSave = vw.SaveModel("actual.managed.model");

                await vw.Complete();

                var statsActual = await statsTask;
                VWTestHelper.AssertEqual(statsExpected, statsActual);

                await modelSave;

                // skip header
                var actualModel = File.ReadAllBytes("actual.managed.model").Skip(0x15).ToList();

                CollectionAssert.AreEqual(expected1Model, actualModel);
                CollectionAssert.AreEqual(expected2Model, actualModel);
            }
        }
Ejemplo n.º 19
0
 /// <summary>
 /// Creates the schema for <paramref name="type"/>, treating every property as a possible feature and label.
 /// </summary>
 /// <param name="settings">Unused by this inspector.</param>
 /// <param name="type">The type to inspect.</param>
 /// <returns>The schema produced by <c>TypeInspector.CreateSchema</c>.</returns>
 public Schema CreateSchema(VowpalWabbitSettings settings, Type type) =>
     TypeInspector.CreateSchema(
         type,
         featurePropertyPredicate: (p, a) => true,
         labelPropertyPredicate: (p, a) => true);
 /// <summary>
 /// Creates a JSON example validator: a string-generating Vowpal Wabbit instance plus a JSON serializer bound to it.
 /// </summary>
 /// <param name="settings">Base settings; a shallow copy with string example generation enabled is used.</param>
 internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
 {
     var stringSettings = settings.ShallowCopy(enableStringExampleGeneration: true);
     this.vw = new VowpalWabbit(stringSettings);
     this.jsonSerializer = new VowpalWabbitJsonSerializer(this.vw);
 }
Ejemplo n.º 21
0
        /// <summary>
        /// Tries to resume training from the model blob named by <c>this.state.ModelName</c> in the configured
        /// storage container. If the loaded model's expanded arguments differ from the configured training
        /// arguments, the loaded model is discarded and a fresh instance is created instead.
        /// </summary>
        /// <returns>
        /// <c>false</c> (after a telemetry trace) when no model name is set, the container is missing, or the model
        /// blob is missing. Otherwise <c>true</c>, including when the loaded model was discarded in favour of a fresh instance.
        /// </returns>
        private async Task <bool> TryLoadModel()
        {
            // find the model blob
            if (string.IsNullOrEmpty(this.state.ModelName))
            {
                this.telemetry.TrackTrace("Model not specified");
                return(false);
            }

            var container = this.blobClient.GetContainerReference(this.settings.StorageContainerName);

            if (!await container.ExistsAsync())
            {
                this.telemetry.TrackTrace($"Storage container missing '{this.settings.StorageContainerName}'");
                return(false);
            }

            var modelBlob = container.GetBlockBlobReference(this.state.ModelName);

            if (!await modelBlob.ExistsAsync())
            {
                this.telemetry.TrackTrace($"Model blob '{this.state.ModelName}' is missing");
                return(false);
            }

            // load the model
            using (var modelStream = await modelBlob.OpenReadAsync())
            {
                // only ModelStream is set here, so the instance is configured from the stored model
                this.InitializeVowpalWabbit(new VowpalWabbitSettings {
                    ModelStream = modelStream
                });
                this.telemetry.TrackTrace($"Model loaded {this.state.ModelName}", SeverityLevel.Verbose);

                // validate that the loaded VW model has the same settings as requested by C&C
                var newSettings = new VowpalWabbitSettings(this.settings.Metadata.TrainArguments);
                using (var newVW = new VW.VowpalWabbit(newSettings))
                {
                    // copy the loaded model's ID — presumably so the ID does not show up as an argument difference; TODO confirm
                    newVW.ID = this.vw.ID;

                    // save the VW instance to a model and load again to get fully expanded parameters.
                    string newVWarguments;
                    using (var tempModel = new MemoryStream())
                    {
                        newVW.SaveModel(tempModel);
                        tempModel.Position = 0;

                        using (var tempVW = new VW.VowpalWabbit(new VowpalWabbitSettings {
                            ModelStream = tempModel
                        }))
                        {
                            newVWarguments = CleanVowpalWabbitArguments(tempVW.Arguments.CommandLine);
                        }
                    }

                    var oldVWarguments = CleanVowpalWabbitArguments(this.vw.Arguments.CommandLine);

                    // this is the expanded command line
                    if (newVWarguments != oldVWarguments)
                    {
                        this.telemetry.TrackTrace("New VowpalWabbit settings found. Discarding existing model",
                                                  SeverityLevel.Information,
                                                  new Dictionary <string, string>
                        {
                            { "TrainArguments", newVW.Arguments.CommandLine },
                            { "NewExpandedArguments", newVWarguments },
                            { "OldExpandedArgumentsCleaned", oldVWarguments },
                            { "OldExpandedArguments", this.vw.Arguments.CommandLine },
                        });

                        // discard old, use fresh
                        this.vw.Dispose();
                        this.vw = null;

                        // NOTE(review): newSettings already configured newVW above; InitializeVowpalWabbit mutates it
                        // and builds a second instance from the same object — confirm settings reuse is supported.
                        this.InitializeVowpalWabbit(newSettings);
                    }
                }
            }

            // store the initial model
            // NOTE(review): recorded even when the loaded model was discarded above — confirm this is intended.
            this.settings.InitialVowpalWabbitModel = this.state.ModelName;

            return(true);
        }
Ejemplo n.º 22
0
        /// <summary>
        /// Creates the Vowpal Wabbit instance used for online training and the JSON reference resolver.
        /// </summary>
        /// <param name="vwSettings">
        /// Settings for the new instance. They are mutated in place: string example generation (when example
        /// tracing is enabled), thread-safe example pooling and the example limit are set here.
        /// </param>
        /// <exception cref="ArgumentException">
        /// The instance's command line contains neither <c>--cb_explore</c> nor <c>--cb_explore_adf</c>.
        /// </exception>
        /// <remarks>
        /// Any failure while creating or validating the instance is reported to telemetry and then rethrown with
        /// its original stack trace. On a validation failure <c>this.vw</c> still refers to the rejected instance.
        /// </remarks>
        private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
        {
            if (this.settings.EnableExampleTracing)
            {
                vwSettings.EnableStringExampleGeneration = true;
                vwSettings.EnableStringFloatCompact = true;
            }

            vwSettings.EnableThreadSafeExamplePooling = true;
            vwSettings.MaxExamples = 64 * 1024;

            try
            {
                this.vw = new VW.VowpalWabbit(vwSettings);
                var cmdLine = vw.Arguments.CommandLine;

                // "--cb_explore" is a substring of "--cb_explore_adf", so this single check accepts both reductions.
                if (!cmdLine.Contains("--cb_explore"))
                {
                    throw new ArgumentException("Only cb_explore and cb_explore_adf are supported");
                }
            }
            catch (Exception ex)
            {
                this.telemetry.TrackException(ex, new Dictionary<string, string>
                {
                    { "help", "Invalid model. For help go to https://github.com/JohnLangford/vowpal_wabbit/wiki/Azure-Trainer" }
                });

                // 'throw;' keeps the original stack trace; 'throw ex;' would reset it to this line.
                throw;
            }

            // Cached request items use a one-hour sliding expiration; removals go to CacheEntryRemovedCallback.
            this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
                this.delayedExampleCallback,
                cacheRequestItemPolicyFactory:
                    key => new CacheItemPolicy()
                    {
                        SlidingExpiration = TimeSpan.FromHours(1),
                        RemovedCallback = this.CacheEntryRemovedCallback
                    });
        }
        /// <summary>
        /// Creates a serializer for the given type and settings.
        /// </summary>
        /// <typeparam name="TExample">The user type to serialize.</typeparam>
        /// <param name="settings">
        /// Serializer settings. If <c>Schema</c> is set it is used as-is; otherwise the schema is derived from
        /// <typeparamref name="TExample"/> using <c>settings.TypeInspector</c> (or <c>TypeInspector.Default</c> when unset).
        /// When <c>null</c>, default settings are used.
        /// </param>
        /// <returns>
        /// A multi-example serializer compiler if the derived schema contains the configured multi property,
        /// otherwise a single-example serializer compiler, or <c>null</c> if the schema has no features.
        /// </returns>
        public static IVowpalWabbitSerializerCompiler<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings = null)
        {
            // The parameter is optional, but every path below reads from it. Fall back to defaults
            // instead of failing with a NullReferenceException.
            if (settings == null)
            {
                settings = new VowpalWabbitSettings();
            }

            Schema schema = null;
            Type cacheKey = null;

            if (settings.Schema != null)
            {
                // An explicitly supplied schema is never cached.
                schema = settings.Schema;
            }
            else
            {
                // only cache non-string generating serializer
                if (!settings.EnableStringExampleGeneration)
                {
                    cacheKey = typeof(TExample);
                    object serializer;

                    if (SerializerCache.TryGetValue(cacheKey, out serializer))
                    {
                        return (IVowpalWabbitSerializerCompiler<TExample>)serializer;
                    }
                }

                ITypeInspector typeInspector = settings.TypeInspector;
                if (typeInspector == null)
                {
                    typeInspector = TypeInspector.Default;
                }

                // TODO: enhance caching based on feature list & featurizer set
                // NOTE(review): the cache key is only the type, so a serializer cached for one TypeInspector /
                // CustomFeaturizer is returned for later calls with different ones.
                // if no feature mapping is provided, use [Feature] annotation on provided type.
                schema = typeInspector.CreateSchema(settings, typeof(TExample));

                // Types with a multi property get a dedicated multi-example compiler (not cached).
                var multiExampleSerializerCompiler = VowpalWabbitMultiExampleSerializerCompiler.TryCreate<TExample>(settings, schema);
                if (multiExampleSerializerCompiler != null)
                {
                    return multiExampleSerializerCompiler;
                }
            }

            // need at least a single feature to do something sensible
            if (schema == null || schema.Features.Count == 0)
            {
                return null;
            }

            var newSerializer = new VowpalWabbitSingleExampleSerializerCompiler<TExample>(
                schema,
                settings.CustomFeaturizer,
                !settings.EnableStringExampleGeneration);

            if (cacheKey != null)
            {
                SerializerCache[cacheKey] = newSerializer;
            }

            return newSerializer;
        }
Ejemplo n.º 24
0
        /// <summary>
        /// Initializes a new instance of the <see cref="VowpalWabbitThreadedLearning"/> class.
        /// </summary>
        /// <param name="settings">
        /// Common settings used for vw instances. <c>settings.ParallelOptions</c> must be set; its
        /// <c>MaxDegreeOfParallelism</c> determines how many vw instances (AllReduce nodes) and action blocks are created.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="settings"/> or <c>settings.ParallelOptions</c> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <c>settings.ParallelOptions.MaxDegreeOfParallelism</c> is less than 1 (e.g. the ParallelOptions default of -1).
        /// </exception>
        public VowpalWabbitThreadedLearning(VowpalWabbitSettings settings)
        {
            if (settings == null)
            {
                throw new ArgumentNullException("settings");
            }

            if (settings.ParallelOptions == null)
            {
                // The single-string ArgumentNullException constructor treats its argument as the parameter name,
                // so the parameter name and the explanation are passed separately.
                throw new ArgumentNullException("settings", "settings.ParallelOptions must be set");
            }

            int degreeOfParallelism = settings.ParallelOptions.MaxDegreeOfParallelism;
            if (degreeOfParallelism < 1)
            {
                // -1 means "unlimited" for ParallelOptions, but here it sizes the per-node arrays below.
                throw new ArgumentOutOfRangeException("settings", "settings.ParallelOptions.MaxDegreeOfParallelism must be at least 1");
            }
            Contract.EndContractBlock();

            this.Settings = settings;

            // ParallelOptions.CancellationToken is a struct (defaulting to CancellationToken.None), so it never needs defaulting here.

            switch (this.Settings.ExampleDistribution)
            {
                case VowpalWabbitExampleDistribution.UniformRandom:
                    this.exampleDistributor = _ => this.random.Next(this.observers.Length);
                    break;

                case VowpalWabbitExampleDistribution.RoundRobin:
                    this.exampleDistributor = localExampleCount => (int)(localExampleCount % this.observers.Length);
                    break;
            }

            this.exampleCount = 0;
            this.syncActions = new ConcurrentList<Action<VowpalWabbit>>();

            this.vws = new VowpalWabbit[degreeOfParallelism];
            this.actionBlocks = new ActionBlock<Action<VowpalWabbit>>[degreeOfParallelism];
            this.observers = new IObserver<Action<VowpalWabbit>>[degreeOfParallelism];

            // setup AllReduce chain: node 0 is the root, every other node references it.
            for (int i = 0; i < degreeOfParallelism; i++)
            {
                var nodeSettings = (VowpalWabbitSettings)settings.Clone();
                if (i > 0)
                {
                    nodeSettings.Root = this.vws[0];
                }
                nodeSettings.Node = (uint)i;

                // loop-local so each action block closes over its own instance
                var vw = this.vws[i] = new VowpalWabbit(nodeSettings);

                // MaxDegreeOfParallelism = 1: each vw instance is driven by a single consumer.
                this.actionBlocks[i] = new ActionBlock<Action<VowpalWabbit>>(
                    action => action(vw),
                    new ExecutionDataflowBlockOptions
                    {
                        MaxDegreeOfParallelism = 1,
                        TaskScheduler = settings.ParallelOptions.TaskScheduler,
                        CancellationToken = settings.ParallelOptions.CancellationToken,
                        BoundedCapacity = (int)settings.MaxExampleQueueLengthPerInstance
                    });
            }

            // get observers to allow for blocking calls
            this.observers = this.actionBlocks.Select(ab => ab.AsObserver()).ToArray();

            this.completionTasks = new Task[degreeOfParallelism];

            // root: perform the final AllReduce, then run any queued synchronization actions
            {
                var root = this.vws[0];
                this.completionTasks[0] = this.actionBlocks[0].Completion
                    .ContinueWith(_ =>
                    {
                        root.EndOfPass();

                        foreach (var syncAction in this.syncActions.RemoveAll())
                        {
                            syncAction(root);
                        }
                    });
            }

            for (int i = 1; i < this.vws.Length; i++)
            {
                // perform final AllReduce
                var vw = this.vws[i];
                this.completionTasks[i] = this.actionBlocks[i].Completion
                    .ContinueWith(_ => vw.EndOfPass(), this.Settings.ParallelOptions.CancellationToken);
            }
        }
Ejemplo n.º 25
0
 /// <summary>
 /// Creates a dynamic wrapper around a newly constructed <see cref="VowpalWabbit"/> instance and
 /// starts with empty per-type lookup tables for serializers and serialize methods.
 /// </summary>
 /// <param name="settings">Arguments passed to native instance.</param>
 public VowpalWabbitDynamic(VowpalWabbitSettings settings)
 {
     vw = new VowpalWabbit(settings);

     // both tables are keyed by System.Type
     serializers = new Dictionary<Type, IDisposable>();
     serializeMethods = new Dictionary<Type, MethodInfo>();
 }
Ejemplo n.º 26
0
        /// <summary>
        /// Returns a multi-example serializer compiler for <typeparamref name="TExample"/> if its schema contains the
        /// configured multi property; otherwise returns <c>null</c>.
        /// </summary>
        /// <exception cref="ArgumentException">The multi property is neither an array nor an <c>IEnumerable&lt;&gt;</c>.</exception>
        public static IVowpalWabbitSerializerCompiler<TExample> TryCreate<TExample>(VowpalWabbitSettings settings, Schema schema)
        {
            // look for the feature named like the configured multi property
            var multiFeature = schema.Features
                .FirstOrDefault(feature => feature.Name == settings.PropertyConfiguration.MultiProperty);

            if (multiFeature == null)
            {
                return null;
            }

            // The multi property must be an array or IEnumerable<>; its element type becomes the
            // second type argument of the compiler implementation.
            var elementType = InspectionHelper.GetEnumerableElementType(multiFeature.FeatureType);
            if (elementType == null)
            {
                var message = settings.PropertyConfiguration.MultiProperty + " property must be array or IEnumerable<>. Actual type: " + multiFeature.FeatureType;
                throw new ArgumentException(message);
            }

            var openCompilerType = typeof(VowpalWabbitMultiExampleSerializerCompilerImpl<,>);
            var closedCompilerType = openCompilerType.MakeGenericType(typeof(TExample), elementType);
            var compiler = Activator.CreateInstance(closedCompilerType, settings, schema, multiFeature);

            return (IVowpalWabbitSerializerCompiler<TExample>)compiler;
        }
Ejemplo n.º 27
0
 /// <summary>
 /// Creates a new <c>VowpalWabbitThreadedPrediction&lt;TContext&gt;</c> pool from <paramref name="settings"/>.
 /// </summary>
 /// <param name="settings">Settings handed to the new pool.</param>
 /// <returns>The newly created pool.</returns>
 protected override VowpalWabbitThreadedPredictionBase<VowpalWabbit<TContext>> CreatePool(VowpalWabbitSettings settings)
 {
     var pool = new VowpalWabbitThreadedPrediction<TContext>(settings);
     return pool;
 }
 /// <summary>
 /// Creates a validator whose VW instance is built from a shallow copy of <paramref name="settings"/>
 /// with string example generation enabled.
 /// </summary>
 /// <param name="settings">Base settings to copy.</param>
 internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
 {
     var validatorSettings = settings.ShallowCopy(enableStringExampleGeneration: true);
     this.vw = new VowpalWabbit(validatorSettings);
 }
Ejemplo n.º 29
0
        /// <summary>
        /// Checks that <see cref="VowpalWabbitThreadedLearning"/> reproduces a manually wired two-node AllReduce run.
        /// </summary>
        /// <remarks>
        /// 1000 generated shared/action examples are first serialized to VW string format. Two native instances
        /// (node 1 and node 0, joined through a local <c>SpanningTreeClr</c>) each learn one half. Their saved models and
        /// vw1's performance statistics become the expected values. The same data is then fed twice through a 2-way
        /// <see cref="VowpalWabbitThreadedLearning"/>: once as strings, and once through the typed
        /// <c>Create&lt;CbAdfShared, CbAdfAction&gt;()</c> wrapper. Each run must match vw1's statistics and produce a model
        /// whose bytes, after the first 0x15 bytes, equal both expected models.
        /// </remarks>
        internal async Task TestAllReduceInternal()
        {
            var data = Enumerable.Range(1, 1000).Select(_ => Generator.GenerateShared(10)).ToList();

            // string serializers are compiled once; their output (stringData) is reused for every instance below
            var stringSerializerCompiled    = VowpalWabbitSerializerFactory.CreateSerializer <CbAdfShared>(new VowpalWabbitSettings(enableStringExampleGeneration: true));
            var stringSerializerAdfCompiled = VowpalWabbitSerializerFactory.CreateSerializer <CbAdfAction>(new VowpalWabbitSettings(enableStringExampleGeneration: true));

            // one inner list per generated example: the shared line followed by one line per action
            var stringData = new List <List <string> >();

            VowpalWabbitPerformanceStatistics statsExpected;

            using (var spanningTree = new SpanningTreeClr())
            {
                spanningTree.Start();

                // two AllReduce peers (--total 2) that rendezvous through the spanning tree on localhost
                using (var vw1 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 1 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy", enableStringExampleGeneration: true)))
                    using (var vw2 = new VowpalWabbit(new VowpalWabbitSettings(@"--total 2 --node 0 --unique_id 0 --span_server localhost --cb_adf --rank_all --interact xy", enableStringExampleGeneration: true)))
                    {
                        var stringSerializer    = stringSerializerCompiled.Func(vw1);
                        var stringSerializerAdf = stringSerializerAdfCompiled.Func(vw1);

                        // serialize
                        foreach (var d in data)
                        {
                            var block = new List <string>();

                            using (var context = new VowpalWabbitMarshalContext(vw1))
                            {
                                stringSerializer(context, d.Item1, SharedLabel.Instance);
                                block.Add(context.StringExample.ToString());
                            }

                            // only the action whose index equals d.Item3.Action is serialized with the label
                            block.AddRange(d.Item2.Select((a, i) =>
                            {
                                using (var context = new VowpalWabbitMarshalContext(vw1))
                                {
                                    stringSerializerAdf(context, a, i == d.Item3.Action ? d.Item3 : null);
                                    return(context.StringExample.ToString());
                                }
                            }));

                            stringData.Add(block);
                        }

                        // each peer learns one half concurrently
                        await Task.WhenAll(
                            Task.Factory.StartNew(() => Ingest(vw1, stringData.Take(500))),
                            Task.Factory.StartNew(() => Ingest(vw2, stringData.Skip(500))));

                        vw1.SaveModel("expected.1.model");
                        vw2.SaveModel("expected.2.model");

                        statsExpected = vw1.PerformanceStatistics;
                    }
            }

            // NOTE(review): the first 0x15 bytes are excluded from comparison; presumably they hold data that
            // legitimately differs between instances — TODO confirm.
            // skip header
            var expected1Model = File.ReadAllBytes("expected.1.model").Skip(0x15).ToList();
            var expected2Model = File.ReadAllBytes("expected.2.model").Skip(0x15).ToList();

            // same learner arguments as the peers, but AllReduce is handled internally over 2 instances
            var settings = new VowpalWabbitSettings("--cb_adf --rank_all --interact xy",
                                                    parallelOptions: new ParallelOptions
            {
                MaxDegreeOfParallelism = 2
            },
                                                    exampleCountPerRun: 2000,
                                                    exampleDistribution: VowpalWabbitExampleDistribution.RoundRobin);

            // run 1: pre-serialized string examples
            using (var vw = new VowpalWabbitThreadedLearning(settings))
            {
                await Task.WhenAll(
                    Task.Factory.StartNew(() => Ingest(vw, stringData.Take(500))),
                    Task.Factory.StartNew(() => Ingest(vw, stringData.Skip(500))));

                // important to enqueue the request before Complete() is called
                var statsTask = vw.PerformanceStatistics;
                var modelSave = vw.SaveModel("actual.model");

                await vw.Complete();

                var statsActual = await statsTask;
                VWTestHelper.AssertEqual(statsExpected, statsActual);

                await modelSave;

                // skip header
                var actualModel = File.ReadAllBytes("actual.model").Skip(0x15).ToList();

                CollectionAssert.AreEqual(expected1Model, actualModel);
                CollectionAssert.AreEqual(expected2Model, actualModel);
            }

            // run 2: typed examples through the managed wrapper
            using (var vw = new VowpalWabbitThreadedLearning(settings))
            {
                var vwManaged = vw.Create <CbAdfShared, CbAdfAction>();

                await Task.WhenAll(
                    Task.Factory.StartNew(() => Ingest(vwManaged, data.Take(500))),
                    Task.Factory.StartNew(() => Ingest(vwManaged, data.Skip(500))));

                // important to enqueue the request before Complete() is called
                var statsTask = vw.PerformanceStatistics;
                var modelSave = vw.SaveModel("actual.managed.model");

                await vw.Complete();

                var statsActual = await statsTask;
                VWTestHelper.AssertEqual(statsExpected, statsActual);

                await modelSave;

                // skip header
                var actualModel = File.ReadAllBytes("actual.managed.model").Skip(0x15).ToList();

                CollectionAssert.AreEqual(expected1Model, actualModel);
                CollectionAssert.AreEqual(expected2Model, actualModel);
            }
        }
Ejemplo n.º 30
0
        /// <summary>
        /// Loads the model blob named by <c>state.ModelName</c> from the configured storage container and
        /// initializes <c>this.vw</c> from it. If the loaded model's cleaned, fully expanded arguments differ from
        /// those produced by <c>settings.Metadata.TrainArguments</c>, the loaded instance is disposed and a fresh one
        /// is created from the train arguments instead.
        /// </summary>
        /// <returns>
        /// <c>false</c> if no model name is set, the storage container does not exist, or the model blob does not
        /// exist; otherwise <c>true</c>, after recording the model name in <c>settings.InitialVowpalWabbitModel</c>.
        /// </returns>
        /// <remarks>Exceptions from storage calls or <c>InitializeVowpalWabbit</c> are not caught here.</remarks>
        private async Task<bool> TryLoadModel()
        {
            // find the model blob
            if (string.IsNullOrEmpty(this.state.ModelName))
            {
                this.telemetry.TrackTrace("Model not specified");
                return false;
            }

            var container = this.blobClient.GetContainerReference(this.settings.StorageContainerName);
            if (!await container.ExistsAsync())
            {
                this.telemetry.TrackTrace($"Storage container missing '{this.settings.StorageContainerName}'");
                return false;
            }

            var modelBlob = container.GetBlockBlobReference(this.state.ModelName);
            if (!await modelBlob.ExistsAsync())
            {
                this.telemetry.TrackTrace($"Model blob '{this.state.ModelName}' is missing");
                return false;
            }

            // load the model
            using (var modelStream = await modelBlob.OpenReadAsync())
            {
                this.InitializeVowpalWabbit(new VowpalWabbitSettings { ModelStream = modelStream });
                this.telemetry.TrackTrace($"Model loaded {this.state.ModelName}", SeverityLevel.Verbose);

                // validate that the loaded VW model has the same settings as requested by C&C
                var newSettings = new VowpalWabbitSettings(this.settings.Metadata.TrainArguments);
                using (var newVW = new VW.VowpalWabbit(newSettings))
                {
                    // NOTE(review): copying the loaded model's ID onto the reference instance presumably keeps the
                    // ID from showing up as an argument difference below — TODO confirm.
                    newVW.ID = this.vw.ID;

                    // save the VW instance to a model and load again to get fully expanded parameters.
                    string newVWarguments;
                    using (var tempModel = new MemoryStream())
                    {
                        newVW.SaveModel(tempModel);
                        tempModel.Position = 0;

                        using (var tempVW = new VW.VowpalWabbit(new VowpalWabbitSettings { ModelStream = tempModel }))
                        {
                            newVWarguments = CleanVowpalWabbitArguments(tempVW.Arguments.CommandLine);
                        }
                    }

                    var oldVWarguments = CleanVowpalWabbitArguments(this.vw.Arguments.CommandLine);

                    // this is the expanded command line; both sides went through CleanVowpalWabbitArguments
                    if (newVWarguments != oldVWarguments)
                    {
                        this.telemetry.TrackTrace("New VowpalWabbit settings found. Discarding existing model",
                            SeverityLevel.Information,
                            new Dictionary<string, string>
                            {
                            { "TrainArguments", newVW.Arguments.CommandLine },
                            { "NewExpandedArguments", newVWarguments },
                            { "OldExpandedArgumentsCleaned", oldVWarguments },
                            { "OldExpandedArguments", this.vw.Arguments.CommandLine },
                            });

                        // discard old, use fresh (newSettings is reused; InitializeVowpalWabbit modifies it in place)
                        this.vw.Dispose();
                        this.vw = null;

                        this.InitializeVowpalWabbit(newSettings);
                    }
                }
            }

            // store the initial model
            this.settings.InitialVowpalWabbitModel = this.state.ModelName;

            return true;
        }
 /// <summary>
 /// Creates a validator backed by a VW instance built from a clone of <paramref name="settings"/>
 /// with string example generation enabled.
 /// </summary>
 /// <param name="settings">Base settings; the flag is set on the clone, not on this object.</param>
 internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
 {
     var validatorSettings = (VowpalWabbitSettings)settings.Clone();
     validatorSettings.EnableStringExampleGeneration = true;

     this.vw = new VowpalWabbit(validatorSettings);
 }
Ejemplo n.º 32
0
 /// <summary>
 /// Builds a schema for <paramref name="type"/> that selects only the feature and label properties
 /// whose attribute argument is non-null (i.e. explicitly annotated members).
 /// </summary>
 /// <param name="settings">Not used by this implementation.</param>
 /// <param name="type">The type to inspect.</param>
 public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
 {
     var schema = TypeInspector.CreateSchema(
         type,
         featurePropertyPredicate: (property, attribute) => attribute != null,
         labelPropertyPredicate: (property, attribute) => attribute != null);

     return schema;
 }
Ejemplo n.º 33
0
        /// <summary>
        /// Creates <c>this.vw</c> from <paramref name="vwSettings"/> and sets up the JSON reference resolver.
        /// </summary>
        /// <param name="vwSettings">
        /// Settings for the new instance; modified in place (tracing flags, thread-safe pooling, MaxExamples).
        /// </param>
        private void InitializeVowpalWabbit(VowpalWabbitSettings vwSettings)
        {
            // When example tracing is on, also produce string examples with compact float formatting.
            bool traceExamples = this.settings.EnableExampleTracing;
            if (traceExamples)
            {
                vwSettings.EnableStringExampleGeneration = true;
                vwSettings.EnableStringFloatCompact = true;
            }

            const int MaxExamples = 64 * 1024;
            vwSettings.EnableThreadSafeExamplePooling = true;
            vwSettings.MaxExamples = MaxExamples;

            this.vw = new VW.VowpalWabbit(vwSettings);

            // Each cached request item slide-expires after one hour; removals go to CacheEntryRemovedCallback.
            this.referenceResolver = new VowpalWabbitJsonReferenceResolver(
                this.delayedExampleCallback,
                cacheRequestItemPolicyFactory: key =>
                {
                    var policy = new CacheItemPolicy();
                    policy.SlidingExpiration = TimeSpan.FromHours(1);
                    policy.RemovedCallback = this.CacheEntryRemovedCallback;
                    return policy;
                });

            //this.vwAllReduce = new VowpalWabbitThreadedLearning(vwSettings.ShallowCopy(
            //    maxExampleQueueLengthPerInstance: 4*1024,
            //    parallelOptions: new ParallelOptions
            //    {
            //        MaxDegreeOfParallelism = 2,
            //    },
            //    exampleDistribution: VowpalWabbitExampleDistribution.RoundRobin,
            //    exampleCountPerRun: 128 * 1024));
        }
Ejemplo n.º 34
0
        /// <summary>
        /// Creates a serializer for the given type and settings.
        /// </summary>
        /// <typeparam name="TExample">The user type to serialize.</typeparam>
        /// <param name="settings">
        /// Serializer settings. If <c>AllFeatures</c> is set, those feature expressions are used as-is; otherwise
        /// features are extracted from <typeparamref name="TExample"/> according to <c>FeatureDiscovery</c>.
        /// When <c>null</c>, default settings are used.
        /// </param>
        /// <returns>The compiled serializer, or <c>null</c> if no features were found.</returns>
        public static VowpalWabbitSerializerCompiled<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings = null)
        {
            // The parameter is optional, but every path below reads from it. Fall back to defaults
            // instead of failing with a NullReferenceException.
            if (settings == null)
            {
                settings = new VowpalWabbitSettings();
            }

            List<FeatureExpression> allFeatures = null;
            Type cacheKey = null;

            if (settings.AllFeatures != null)
            {
                // An explicitly supplied feature list is never cached.
                allFeatures = settings.AllFeatures;
            }
            else
            {
                // only cache non-string generating serializer
                if (!settings.EnableStringExampleGeneration)
                {
                    cacheKey = typeof(TExample);
                    object serializer;

                    if (SerializerCache.TryGetValue(cacheKey, out serializer))
                    {
                        return (VowpalWabbitSerializerCompiled<TExample>)serializer;
                    }
                }

                // TODO: enhance caching based on feature list & featurizer set
                // NOTE(review): the cache key is only the type, so a serializer cached under one FeatureDiscovery /
                // CustomFeaturizer is returned for later calls with different ones.
                // if no feature mapping is provided, use [Feature] annotation on provided type.
                Func<PropertyInfo, FeatureAttribute, bool> propertyPredicate = null;
                switch (settings.FeatureDiscovery)
                {
                    case VowpalWabbitFeatureDiscovery.Default:
                        propertyPredicate = (_, attr) => attr != null;
                        break;

                    case VowpalWabbitFeatureDiscovery.All:
                        propertyPredicate = (_, __) => true;
                        break;
                }

                allFeatures = AnnotationInspector.ExtractFeatures(typeof(TExample), propertyPredicate).ToList();
            }

            // need at least a single feature to do something sensible
            if (allFeatures == null || allFeatures.Count == 0)
            {
                return null;
            }

            var newSerializer = new VowpalWabbitSerializerCompiled<TExample>(
                allFeatures,
                settings.CustomFeaturizer,
                !settings.EnableStringExampleGeneration);

            if (cacheKey != null)
            {
                SerializerCache[cacheKey] = newSerializer;
            }

            return newSerializer;
        }
Ejemplo n.º 35
0
 /// <summary>
 /// Wraps a newly constructed <see cref="VowpalWabbit"/> instance built from <paramref name="settings"/>.
 /// </summary>
 /// <param name="settings">Arguments passed to native instance.</param>
 public VowpalWabbitDynamic(VowpalWabbitSettings settings)
 {
     var instance = new VowpalWabbit(settings);
     this.vw = instance;
 }
Ejemplo n.º 36
0
 /// <summary>
 /// Builds a schema for <paramref name="type"/> via <c>JsonTypeInspector</c>, using the property
 /// configuration from <paramref name="settings"/>.
 /// </summary>
 public Schema CreateSchema(VowpalWabbitSettings settings, Type type)
 {
     var propertyConfiguration = settings.PropertyConfiguration;
     return JsonTypeInspector.CreateSchema(type, propertyConfiguration);
 }
Ejemplo n.º 37
0
 /// <summary>
 /// Sub classes must override and create a new VW pool.
 /// </summary>
 /// <param name="settings">Settings for the new pool.</param>
 /// <returns>The newly created pool.</returns>
 protected abstract VowpalWabbitThreadedPredictionBase <TVowpalWabbit> CreatePool(VowpalWabbitSettings settings);
Ejemplo n.º 38
0
 /// <summary>
 /// Creates a dynamic wrapper that owns a new <see cref="VowpalWabbit"/> instance.
 /// </summary>
 /// <param name="settings">Arguments passed to native instance.</param>
 public VowpalWabbitDynamic(VowpalWabbitSettings settings)
 {
     vw = new VowpalWabbit(settings);
 }
Ejemplo n.º 39
0
 /// <summary>
 /// Creates a validator whose VW instance generates string examples.
 /// </summary>
 /// <param name="settings">Base settings; they are cloned before string example generation is switched on.</param>
 internal VowpalWabbitExampleJsonValidator(VowpalWabbitSettings settings)
 {
     var copy = (VowpalWabbitSettings)settings.Clone();
     copy.EnableStringExampleGeneration = true;
     this.vw = new VowpalWabbit(copy);
 }
        /// <summary>
        /// Returns a multi-example serializer compiler when <paramref name="allFeatures"/> contains the
        /// <c>_multi</c> feature; otherwise returns <c>null</c>.
        /// </summary>
        /// <exception cref="ArgumentException">The <c>_multi</c> feature is neither an array nor an <c>IEnumerable&lt;&gt;</c>.</exception>
        internal static IVowpalWabbitSerializerCompiler<TExample> TryCreate<TExample>(VowpalWabbitSettings settings, List<FeatureExpression> allFeatures)
        {
            var multiFeature = allFeatures.FirstOrDefault(candidate => candidate.Name == VowpalWabbitConstants.MultiProperty);
            if (multiFeature == null)
            {
                // not a multi-example type
                return null;
            }

            // element type of the array / IEnumerable<> backing the _multi property
            var elementType = InspectionHelper.GetEnumerableElementType(multiFeature.FeatureType);
            if (elementType == null)
            {
                throw new ArgumentException("_multi property must be array or IEnumerable<>. Actual type: " + multiFeature.FeatureType);
            }

            var compilerType = typeof(VowpalWabbitMultiExampleSerializerCompilerImpl<,>)
                .MakeGenericType(typeof(TExample), elementType);

            var compiler = Activator.CreateInstance(compilerType, settings, allFeatures, multiFeature);
            return (IVowpalWabbitSerializerCompiler<TExample>)compiler;
        }
Ejemplo n.º 41
0
        /// <summary>
        /// Creates a serializer for the given type and settings.
        /// </summary>
        /// <typeparam name="TExample">The user type to serialize.</typeparam>
        /// <param name="settings">
        /// Serializer settings. If <c>Schema</c> is set it is used as-is. Otherwise the schema is derived from
        /// <typeparamref name="TExample"/>: via <c>AnnotationJsonInspector</c> for <c>FeatureDiscovery.Json</c>, or via
        /// <c>AnnotationInspector</c> with attribute-based (<c>Default</c>) or all-property (<c>All</c>) selection.
        /// When <c>null</c>, default settings are used.
        /// </param>
        /// <returns>
        /// A multi-example serializer compiler (Json discovery with a multi property), a single-example serializer
        /// compiler, or <c>null</c> if the schema has no features.
        /// </returns>
        public static IVowpalWabbitSerializerCompiler<TExample> CreateSerializer<TExample>(VowpalWabbitSettings settings = null)
        {
            // The parameter is optional, but every path below reads from it. Fall back to defaults
            // instead of failing with a NullReferenceException.
            if (settings == null)
            {
                settings = new VowpalWabbitSettings();
            }

            Schema schema = null;
            Type cacheKey = null;

            if (settings.Schema != null)
            {
                // An explicitly supplied schema is never cached.
                schema = settings.Schema;
            }
            else
            {
                // only cache non-string generating serializer
                if (!settings.EnableStringExampleGeneration)
                {
                    cacheKey = typeof(TExample);
                    object serializer;

                    if (SerializerCache.TryGetValue(cacheKey, out serializer))
                    {
                        return (IVowpalWabbitSerializerCompiler<TExample>)serializer;
                    }
                }

                // NOTE(review): the cache key is only the type, so a serializer cached under one FeatureDiscovery /
                // CustomFeaturizer is returned for later calls with different ones.
                if (settings.FeatureDiscovery == VowpalWabbitFeatureDiscovery.Json)
                {
                    schema = AnnotationJsonInspector.CreateSchema(typeof(TExample), settings.PropertyConfiguration);

                    // Types with a multi property get a dedicated multi-example compiler (not cached).
                    var multiExampleSerializerCompiler = VowpalWabbitMultiExampleSerializerCompiler.TryCreate<TExample>(settings, schema);
                    if (multiExampleSerializerCompiler != null)
                    {
                        return multiExampleSerializerCompiler;
                    }
                }
                else
                {
                    // TODO: enhance caching based on feature list & featurizer set
                    // if no feature mapping is provided, use [Feature] annotation on provided type.
                    Func<PropertyInfo, FeatureAttribute, bool> propertyPredicate = null;
                    Func<PropertyInfo, LabelAttribute, bool> labelPredicate = null;
                    switch (settings.FeatureDiscovery)
                    {
                        case VowpalWabbitFeatureDiscovery.Default:
                            propertyPredicate = (_, attr) => attr != null;
                            labelPredicate = (_, attr) => attr != null;
                            break;

                        case VowpalWabbitFeatureDiscovery.All:
                            propertyPredicate = (_, __) => true;
                            labelPredicate = (_, __) => true;
                            break;
                    }

                    schema = AnnotationInspector.CreateSchema(typeof(TExample), propertyPredicate, labelPredicate);
                }
            }

            // need at least a single feature to do something sensible
            if (schema == null || schema.Features.Count == 0)
            {
                return null;
            }

            var newSerializer = new VowpalWabbitSingleExampleSerializerCompiler<TExample>(
                schema,
                settings.CustomFeaturizer,
                !settings.EnableStringExampleGeneration);

            if (cacheKey != null)
            {
                SerializerCache[cacheKey] = newSerializer;
            }

            return newSerializer;
        }