Esempio n. 1
0
        /// <summary>
        /// End-to-end check of the online trainer. It runs these steps:
        /// <list type="number">
        /// <item>Cleans the blob storage.</item>
        /// <item>Sends 100 generated events through the event hub.</item>
        /// <item>Waits until at least one model, trackback and state blob has been written.</item>
        /// <item>Asserts that exactly one trackback (covering every sent event) and one model blob exist.</item>
        /// <item>Hands the trackback's event ids, the event map and the model blob URI to
        /// <c>TrainOffline</c>.</item>
        /// </list>
        /// </summary>
        public async Task TestAzureTrainer()
        {
            using (var trainer = new OnlineTrainerWrapper("--cb_explore_adf --epsilon 0.2 -q ab"))
            {
                // Start from empty storage. Awaited rather than .Wait(): blocking on a Task inside an
                // async method can tie up or deadlock the thread, and .Wait() would wrap any failure
                // in an AggregateException instead of surfacing the original exception.
                await trainer.Blobs.Cleanup();

                // Generate data. Materialized once so the exact same events are sent and later looked up.
                var data = GenerateData(100).ToList();
                var dataMap = data.ToDictionary(d => d.EventId, d => d);

                // Start listening for the event hub. The policy is sized to the number of generated
                // events, presumably so that a single checkpoint covers all of them (asserted below).
                await trainer.StartAsync(new CountingCheckpointPolicy(data.Count));

                // Send data to the event hub.
                trainer.SendData(data);

                // Wait for the trainer to checkpoint: at least one model, trackback and state blob must exist.
                await trainer.PollTrainerCheckpoint(blobs => blobs.ModelBlobs.Count > 0 && blobs.ModelTrackbackBlobs.Count > 0 && blobs.StateJsonBlobs.Count > 0);

                // Download & parse the trackback file(s).
                trainer.Blobs.DownloadTrackbacksOrderedByTime();
                Assert.AreEqual(1, trainer.Blobs.Trackbacks.Count);

                // The single trackback must list every event that was sent.
                var trackback = trainer.Blobs.Trackbacks[0];
                Assert.AreEqual(data.Count, trackback.EventIds.Count);
                Assert.AreEqual(1, trainer.Blobs.ModelBlobs.Count);

                trainer.TrainOffline("train a model for this set of events", trackback.ModelId, dataMap, trackback.EventIds, trainer.Blobs.ModelBlobs[0].Uri);
            }
        }
Esempio n. 2
0
        /// <summary>
        /// Starts an online trainer, streams <paramref name="data"/> through the event hub and waits until
        /// exactly <paramref name="expectedNumStates"/> model, trackback and state blobs exist. It then
        /// downloads the trackbacks and asserts that each one contains 100 events, which is the
        /// checkpoint size used here.
        /// </summary>
        /// <param name="args">
        /// Trainer command line passed to <c>OnlineTrainerWrapper</c>. When null or whitespace, the
        /// default <c>"--cb_explore_adf --epsilon 0.1 -q ab -l 0.1"</c> is used.
        /// </param>
        /// <param name="data">Events to send to the event hub (enumerated by <c>SendData</c>).</param>
        /// <param name="dataMap">Not used by this method; kept so existing callers keep compiling.</param>
        /// <param name="expectedNumStates">Exact number of model, trackback and state blobs to wait for.</param>
        /// <param name="cleanBlobs">When true, blob storage is cleaned before the trainer starts.</param>
        /// <returns>
        /// The running trainer. Ownership passes to the caller, who must dispose it. If any step fails,
        /// the trainer is disposed here and the exception is rethrown.
        /// </returns>
        private async Task<OnlineTrainerWrapper> RunTrainer(string args, IEnumerable<Context> data, Dictionary<string, Context> dataMap, int expectedNumStates, bool cleanBlobs)
        {
            // Fallback trainer arguments for callers that pass no explicit command line.
            const string defaultArgs = "--cb_explore_adf --epsilon 0.1 -q ab -l 0.1";

            // Checkpoint size. It drives both the policy and the per-trackback assertion, so the two
            // cannot drift apart.
            const int eventsPerCheckpoint = 100;

            var trainer = new OnlineTrainerWrapper(string.IsNullOrWhiteSpace(args) ? defaultArgs : args);
            try
            {
                if (cleanBlobs)
                {
                    // Awaited rather than .Wait() to avoid blocking a thread inside an async method.
                    await trainer.Blobs.Cleanup();
                }

                // Start listening for the event hub.
                await trainer.StartAsync(new CountingCheckpointPolicy(eventsPerCheckpoint));

                // Send data to the event hub.
                trainer.SendData(data);

                await trainer.PollTrainerCheckpoint(blobs =>
                                                    blobs.ModelBlobs.Count == expectedNumStates &&
                                                    blobs.ModelTrackbackBlobs.Count == expectedNumStates &&
                                                    blobs.StateJsonBlobs.Count == expectedNumStates);

                // Download & parse the trackback files.
                trainer.Blobs.DownloadTrackbacksOrderedByTime();

                foreach (var trackback in trainer.Blobs.Trackbacks)
                {
                    // Each trackback should cover exactly one checkpoint's worth of events.
                    Assert.AreEqual(eventsPerCheckpoint, trackback.EventIds.Count, $"{trackback.Blob.Uri} does not contain the expected {eventsPerCheckpoint} events. Actual: {trackback.EventIds.Count}");
                }

                return trainer;
            }
            catch
            {
                // The caller never receives the trainer on failure, so release it here before rethrowing.
                trainer.Dispose();
                throw;
            }
        }