コード例 #1
0
        /// <summary>
        /// Brings the online trainer up. In order, it:
        /// records the start time,
        /// creates the performance counters and the learner,
        /// either resets the learner or resumes it from stored state,
        /// builds the train event processor factory,
        /// registers that factory with an event processor host on the joined event hub,
        /// starts the periodic performance-counter updater,
        /// and emits an "OnlineTrainer started" trace.
        /// </summary>
        /// <param name="settings">Trainer configuration. It is dereferenced immediately, so it must not be null.</param>
        /// <param name="state">Optional state handed to <c>FreshStart</c> when a fresh start is performed.</param>
        /// <param name="model">Optional model bytes. When non-null, a fresh start is performed even if <c>ForceFreshStart</c> is false.</param>
        private async Task StartInternalAsync(OnlineTrainerSettingsInternal settings, OnlineTrainerState state = null, byte[] model = null)
        {
            this.LastStartDateTimeUtc = DateTime.UtcNow;
            this.perfCounters = new PerformanceCounters(settings.Metadata.ApplicationID);

            // Learner construction; it reports delayed examples through DelayedExampleCallback.
            this.trainer = new Learner(settings, this.DelayedExampleCallback, this.perfCounters);

            // Resume from previously stored state only when no fresh start was requested
            // and no explicit model was supplied.
            if (!settings.ForceFreshStart && model == null)
            {
                await this.trainer.FindAndResumeFromState();
            }
            else
            {
                this.trainer.FreshStart(state, model);
            }

            this.trainProcessorFactory = new TrainEventProcessorFactory(settings, this.trainer, this.perfCounters);

            // The joined connection string carries the event hub name in EntityPath. Pull it out so
            // the hub name is passed separately, and clear it so the connection string handed to
            // the host no longer contains an EntityPath.
            var connectionBuilder = new ServiceBusConnectionStringBuilder(settings.JoinedEventHubConnectionString);
            var eventHubPath = connectionBuilder.EntityPath;
            connectionBuilder.EntityPath = string.Empty;

            this.eventProcessorHost = new EventProcessorHost(
                settings.Metadata.ApplicationID,
                eventHubPath,
                EventHubConsumerGroup.DefaultGroupName,
                connectionBuilder.ToString(),
                settings.StorageConnectionString);

            var processorOptions = new EventProcessorOptions { InitialOffsetProvider = this.InitialOffsetProvider };
            await this.eventProcessorHost.RegisterEventProcessorFactoryAsync(this.trainProcessorFactory, processorOptions);

            // Counter updates use a coarse 500 ms period so they do not run too often.
            // NOTE(review): exact firing semantics depend on SafeTimer; confirm.
            this.perfUpdater = new SafeTimer(TimeSpan.FromMilliseconds(500), this.UpdatePerformanceCounters);

            var startupProperties = new Dictionary<string, string>
            {
                { "CheckpointPolicy", settings.CheckpointPolicy.ToString() },
                { "VowpalWabbit", settings.Metadata.TrainArguments },
                { "ExampleTracing", settings.EnableExampleTracing.ToString() }
            };

            this.telemetry.TrackTrace("OnlineTrainer started", SeverityLevel.Information, startupProperties);
        }
コード例 #2
0
        /// <summary>
        /// Tears down what the start path created, in this order:
        /// the performance-counter timer,
        /// the event processor host registration,
        /// the train event processor factory (disposing it flushes the pipeline),
        /// the learner,
        /// and the performance counters.
        /// Each field is cleared after its resource is released. When nothing is running, only the
        /// "stopping" and "stopped" traces are emitted.
        /// Failures while unregistering the host are reported to telemetry and do not stop the teardown.
        /// NOTE(review): an exception from any Dispose call still propagates and skips the later steps.
        /// </summary>
        private async Task StopInternalAsync()
        {
            this.telemetry.TrackTrace("OnlineTrainer stopping", SeverityLevel.Verbose);

            // The one-minute argument is presumably how long Stop may wait for the timer. TODO confirm.
            this.perfUpdater?.Stop(TimeSpan.FromMinutes(1));
            this.perfUpdater = null;

            if (this.eventProcessorHost != null)
            {
                try
                {
                    await this.eventProcessorHost.UnregisterEventProcessorAsync();
                }
                catch (Exception ex)
                {
                    // Best-effort: record the failure and keep tearing down the rest.
                    this.telemetry.TrackException(ex);
                }
            }

            this.eventProcessorHost = null;

            // Disposing the factory flushes the processing pipeline.
            this.trainProcessorFactory?.Dispose();
            this.trainProcessorFactory = null;

            this.trainer?.Dispose();
            this.trainer = null;

            this.perfCounters?.Dispose();
            this.perfCounters = null;

            this.telemetry.TrackTrace("OnlineTrainer stopped", SeverityLevel.Verbose);
        }
コード例 #3
0
        /// <summary>
        /// Builds the three-stage dataflow pipeline used for training:
        /// <list type="bullet">
        /// <item>Stage 1, <c>Stage1_Deserialize</c>: up to 4 concurrent invocations, input capacity 1024.</item>
        /// <item>Stage 2, <c>Stage2_ProcessEvent</c>: one item at a time, input capacity 1024.
        /// It routes <c>TrainerResult</c> outputs to the evaluation operation and <c>CheckpointData</c>
        /// outputs to Stage 3, and discards anything else.</item>
        /// <item>Stage 3, the learner's <c>Checkpoint</c> method: one item at a time, input capacity 4.</item>
        /// </list>
        /// Completion of Stage 1 propagates to Stage 2. Completion of Stage 2 propagates to the
        /// evaluation operation and to Stage 3. A timer posts a <c>CheckpointEvaluateTriggerEvent</c>
        /// into Stage 2 every second to trigger checkpoint checking.
        /// </summary>
        /// <param name="settings">Trainer settings, forwarded to the evaluation operation.</param>
        /// <param name="trainer">Learner whose <c>Checkpoint</c> method backs Stage 3.</param>
        /// <param name="performanceCounters">Performance counters stored for use by the pipeline.</param>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        internal TrainEventProcessorFactory(OnlineTrainerSettingsInternal settings, Learner trainer, PerformanceCounters performanceCounters)
        {
            if (settings == null)
                throw new ArgumentNullException(nameof(settings));

            if (trainer == null)
                throw new ArgumentNullException(nameof(trainer));

            if (performanceCounters == null)
                throw new ArgumentNullException(nameof(performanceCounters));

            this.trainer = trainer;
            this.performanceCounters = performanceCounters;

            this.telemetry = new TelemetryClient();
            this.telemetry.Context.Component.Version = GetType().Assembly.GetName().Version.ToString();

            this.evalOperation = new EvalOperation(settings);
            this.latencyOperation = new LatencyOperation();

            // Stage 1: deserialization. The explicit Func cast selects the synchronous
            // TransformManyBlock constructor overload for the method group.
            this.deserializeBlock = new TransformManyBlock<PipelineData, PipelineData>(
                (Func<PipelineData, IEnumerable<PipelineData>>)this.Stage1_Deserialize,
                new ExecutionDataflowBlockOptions
                {
                    MaxDegreeOfParallelism = 4, // Math.Max(2, Environment.ProcessorCount - 1),
                    BoundedCapacity = 1024
                });
            this.deserializeBlock.Completion.Trace(this.telemetry, "Stage 1 - Deserialization");

            // Stage 2: event processing, strictly one item at a time.
            this.learnBlock = new TransformManyBlock<object, object>(
                (Func<object, IEnumerable<object>>)this.Stage2_ProcessEvent,
                new ExecutionDataflowBlockOptions
                {
                    MaxDegreeOfParallelism = 1,
                    BoundedCapacity = 1024
                });
            this.learnBlock.Completion.Trace(this.telemetry, "Stage 2 - Learning");

            // Stage 3: checkpointing, one at a time, with a small input buffer.
            this.checkpointBlock = new ActionBlock<object>(
                this.trainer.Checkpoint,
                new ExecutionDataflowBlockOptions
                {
                    MaxDegreeOfParallelism = 1,
                    BoundedCapacity = 4
                });

            // Trace Stage 3's own completion so a faulted checkpoint block is reported.
            this.checkpointBlock.Completion.Trace(this.telemetry, "Stage 3 - CheckPointing");

            // setup pipeline
            this.deserializeBlock.LinkTo(
                this.learnBlock,
                new DataflowLinkOptions { PropagateCompletion = true });

            this.learnBlock.LinkTo(
                this.evalOperation.TargetBlock,
                new DataflowLinkOptions { PropagateCompletion = true },
                obj => obj is TrainerResult);

            this.learnBlock.LinkTo(
                this.checkpointBlock,
                new DataflowLinkOptions { PropagateCompletion = true },
                obj => obj is CheckpointData);

            // Consume all unmatched outputs so they do not stay in Stage 2's output buffer.
            this.learnBlock.LinkTo(DataflowBlock.NullTarget<object>());

            // Start the checkpoint timer only after the pipeline is fully linked. No tick can then
            // enter a partially wired pipeline, and a failure while linking cannot leave a live
            // timer subscription behind.
            this.checkpointTrigger = Observable.Interval(TimeSpan.FromSeconds(1))
                .Select(_ => new CheckpointEvaluateTriggerEvent())
                .Subscribe(this.learnBlock.AsObserver());
        }