コード例 #1
0
        /// <summary>
        /// Stops the running trainer and starts it again with a forced fresh start,
        /// optionally handing the supplied state and/or model to the new trainer.
        /// Afterwards a checkpoint is triggered; if neither state nor model was supplied,
        /// the latest model blob is overwritten with an empty payload.
        /// </summary>
        /// <param name="state">Optional trainer state passed through to the restarted trainer.</param>
        /// <param name="model">Optional serialized model passed through to the restarted trainer.</param>
        private async Task ResetInternalAsync(OnlineTrainerState state = null, byte[] model = null)
        {
            // nothing to reset if the trainer was never started
            if (this.trainer == null)
            {
                this.telemetry.TrackTrace("Online Trainer resetting skipped as trainer hasn't started yet.", SeverityLevel.Information);
                return;
            }

            bool stateSupplied = state != null;
            bool modelSupplied = model != null;

            // the client model is only refreshed when the caller provided something to start from
            bool updateClientModel = stateSupplied || modelSupplied;

            var traceMessage = "Online Trainer resetting"
                + (stateSupplied ? "; state supplied" : string.Empty)
                + (modelSupplied ? $"; model of size {model.Length} supplied." : string.Empty);

            this.telemetry.TrackTrace(traceMessage, SeverityLevel.Information);

            // capture the settings before stopping, they are reused for the restart
            var restartSettings = this.trainer.Settings;

            await this.StopInternalAsync();

            restartSettings.ForceFreshStart = true;
            restartSettings.CheckpointPolicy.Reset();

            await this.StartInternalAsync(restartSettings, state, model);

            // persist the fresh model right away so the reset is not lost if the process dies
            var checkpointTrigger = new CheckpointTriggerEvent { UpdateClientModel = updateClientModel };
            await this.trainProcessorFactory.LearnBlock.SendAsync(checkpointTrigger);

            if (updateClientModel)
            {
                return;
            }

            // blank out the currently deployed model so clients don't keep using the old one
            var deployedModel = await this.trainer.GetLatestModelBlob();

            this.telemetry.TrackTrace($"Resetting client visible model: {deployedModel.Uri}", SeverityLevel.Information);
            await deployedModel.UploadFromByteArrayAsync(new byte[0], 0, 0);
        }
コード例 #2
0
        /// <summary>
        /// Starts learning from scratch: installs the given (or a new, empty) trainer state
        /// and initializes Vowpal Wabbit with the configured train arguments, optionally
        /// loading an initial model.
        /// </summary>
        /// <param name="state">State to start from; a new <c>OnlineTrainerState</c> is used when null.</param>
        /// <param name="model">Optional serialized model supplied to Vowpal Wabbit as its model stream.</param>
        internal void FreshStart(OnlineTrainerState state = null, byte[] model = null)
        {
            // fall back to an empty state when the caller did not provide one
            var initialState = state ?? new OnlineTrainerState();

            this.telemetry.TrackTrace("Fresh Start", SeverityLevel.Information);

            // start from scratch
            this.state = initialState;

            // save extra state so learning can be resumed later with new data
            const string resumeArguments = "--save_resume --preserve_performance_counters ";
            var vwSettings = new VowpalWabbitSettings(resumeArguments + this.settings.Metadata.TrainArguments);

            if (model != null)
            {
                vwSettings.ModelStream = new MemoryStream(model);
            }

            this.InitializeVowpalWabbit(vwSettings);
        }
コード例 #3
0
        /// <summary>
        /// Starts learning from scratch: installs the given (or a new, empty) trainer state
        /// and initializes Vowpal Wabbit. Without a model the configured train arguments are
        /// appended to the base arguments; with a model only the base arguments are passed
        /// and the model is supplied as the model stream.
        /// </summary>
        /// <param name="state">State to start from; a new <c>OnlineTrainerState</c> is used when null.</param>
        /// <param name="model">Optional serialized model supplied to Vowpal Wabbit as its model stream.</param>
        internal void FreshStart(OnlineTrainerState state = null, byte[] model = null)
        {
            // fall back to an empty state when the caller did not provide one
            var initialState = state ?? new OnlineTrainerState();

            this.telemetry.TrackTrace("Fresh Start", SeverityLevel.Information);

            // start from scratch
            this.state = initialState;

            // save extra state so learning can be resumed later with new data
            const string saveResumeArgument = "--save_resume";

            VowpalWabbitSettings vwSettings;
            if (model == null)
            {
                // no model: the train arguments are passed alongside the base arguments
                vwSettings = new VowpalWabbitSettings(saveResumeArgument + " " + this.settings.Metadata.TrainArguments);
            }
            else
            {
                // model supplied: only the base arguments are passed, the model is loaded from the stream
                vwSettings = new VowpalWabbitSettings(saveResumeArgument);
                vwSettings.ModelStream = new MemoryStream(model);
            }

            this.InitializeVowpalWabbit(vwSettings);
        }
コード例 #4
0
 /// <summary>
 /// Resets the trainer by running <c>ResetInternalAsync</c> through <c>SafeExecute</c>.
 /// </summary>
 /// <param name="state">Optional trainer state forwarded to the reset.</param>
 /// <param name="model">Optional serialized model forwarded to the reset.</param>
 public async Task ResetModelAsync(OnlineTrainerState state = null, byte[] model = null)
 {
     await this.SafeExecute(async () =>
     {
         await this.ResetInternalAsync(state, model);
     });
 }
コード例 #5
0
        /// <summary>
        /// Creates the learner, the event processor factory and the event processor host
        /// from the given settings, registers the factory with the host, starts the
        /// performance counter timer and traces the start.
        /// </summary>
        /// <param name="settings">Trainer settings used to configure all components.</param>
        /// <param name="state">Optional state forwarded to the learner on a fresh start.</param>
        /// <param name="model">Optional serialized model; supplying one forces a fresh start.</param>
        private async Task StartInternalAsync(OnlineTrainerSettingsInternal settings, OnlineTrainerState state = null, byte[] model = null)
        {
            this.LastStartDateTimeUtc = DateTime.UtcNow;
            this.perfCounters = new PerformanceCounters(settings.Metadata.ApplicationID);

            // learner: resume from stored state unless a fresh start is forced or a model was supplied
            this.trainer = new Learner(settings, this.DelayedExampleCallback, this.perfCounters);

            bool resumeFromStoredState = !settings.ForceFreshStart && model == null;
            if (resumeFromStoredState)
            {
                await this.trainer.FindAndResumeFromState();
            }
            else
            {
                this.trainer.FreshStart(state, model);
            }

            // processor factory
            this.trainProcessorFactory = new TrainEventProcessorFactory(settings, this.trainer, this.perfCounters);

            // event processor host: the entity path is taken out of the connection string
            // and passed to the host separately as the event hub name
            var connectionBuilder = new ServiceBusConnectionStringBuilder(settings.JoinedEventHubConnectionString);
            var eventHubPath = connectionBuilder.EntityPath;
            connectionBuilder.EntityPath = string.Empty;

            this.eventProcessorHost = new EventProcessorHost(
                settings.Metadata.ApplicationID,
                eventHubPath,
                EventHubConsumerGroup.DefaultGroupName,
                connectionBuilder.ToString(),
                settings.StorageConnectionString);

            // used by this.InitialOffsetProvider if no checkpointed state is found
            this.eventHubStartDateTimeUtc = settings.EventHubStartDateTimeUtc;

            var processorOptions = new EventProcessorOptions { InitialOffsetProvider = this.InitialOffsetProvider };
            await this.eventProcessorHost.RegisterEventProcessorFactoryAsync(this.trainProcessorFactory, processorOptions);

            // don't perform too often
            var perfUpdateInterval = TimeSpan.FromMilliseconds(500);
            this.perfUpdater = new SafeTimer(perfUpdateInterval, this.UpdatePerformanceCounters);

            var traceProperties = new Dictionary<string, string>
            {
                ["CheckpointPolicy"] = settings.CheckpointPolicy.ToString(),
                ["VowpalWabbit"] = settings.Metadata.TrainArguments,
                ["ExampleTracing"] = settings.EnableExampleTracing.ToString()
            };

            this.telemetry.TrackTrace("OnlineTrainer started", SeverityLevel.Information, traceProperties);
        }