/// <summary>
/// End-to-end check of the online trainer.
/// Generates events with <c>GenerateData(100)</c> and streams them to the trainer.
/// Waits until model, trackback and state blobs all exist, then asserts there is exactly one
/// trackback (covering all generated events) and one model blob.
/// Finally calls <c>TrainOffline</c> with that trackback's model id, event ids and the model blob URI.
/// </summary>
public async Task TestAzureTrainer()
{
    using (var trainer = new OnlineTrainerWrapper("--cb_explore_adf --epsilon 0.2 -q ab"))
    {
        // Await rather than .Wait(): blocking inside an async method can deadlock on a
        // synchronization context and wraps failures in AggregateException.
        await trainer.Blobs.Cleanup();

        // generate data
        var data = GenerateData(100).ToList();
        var dataMap = data.ToDictionary(d => d.EventId, d => d);

        // start listening for event hub; the checkpoint policy is sized to the number of
        // generated events (presumably one checkpoint after all of them - confirm policy semantics)
        await trainer.StartAsync(new CountingCheckpointPolicy(data.Count));

        // send data to event hub
        trainer.SendData(data);

        // wait for trainer to checkpoint: model, trackback and state blobs must all be present
        await trainer.PollTrainerCheckpoint(blobs =>
            blobs.ModelBlobs.Count > 0 &&
            blobs.ModelTrackbackBlobs.Count > 0 &&
            blobs.StateJsonBlobs.Count > 0);

        // download & parse trackback file
        trainer.Blobs.DownloadTrackbacksOrderedByTime();
        Assert.AreEqual(1, trainer.Blobs.Trackbacks.Count);

        var trackback = trainer.Blobs.Trackbacks[0];
        Assert.AreEqual(data.Count, trackback.EventIds.Count);
        Assert.AreEqual(1, trainer.Blobs.ModelBlobs.Count);

        trainer.TrainOffline(
            "train a model for this set of events",
            trackback.ModelId,
            dataMap,
            trackback.EventIds,
            trainer.Blobs.ModelBlobs[0].Uri);
    }
}
/// <summary>
/// Creates and starts an <see cref="OnlineTrainerWrapper"/> and sends <paramref name="data"/> to it.
/// Waits until exactly <paramref name="expectedNumStates"/> model, trackback and state-JSON blobs exist.
/// Then asserts that every downloaded trackback lists 100 event ids.
/// </summary>
/// <param name="args">
/// NOTE(review): currently ignored - the trainer is always created with
/// "--cb_explore_adf --epsilon 0.1 -q ab -l 0.1". Confirm whether callers expect this
/// argument to be honoured before wiring it through.
/// </param>
/// <param name="data">Events sent to the trainer.</param>
/// <param name="dataMap">NOTE(review): not used by this method.</param>
/// <param name="expectedNumStates">Exact number of model, trackback and state-JSON blobs to wait for.</param>
/// <param name="cleanBlobs">When true, blob storage is cleaned up before the trainer is started.</param>
/// <returns>The started trainer. On success the caller owns it and must dispose it.</returns>
private async Task<OnlineTrainerWrapper> RunTrainer(string args, IEnumerable<Context> data, Dictionary<string, Context> dataMap, int expectedNumStates, bool cleanBlobs)
{
    // Checkpoint size: passed to the checkpoint policy and expected in every trackback.
    const int eventsPerCheckpoint = 100;

    var trainer = new OnlineTrainerWrapper("--cb_explore_adf --epsilon 0.1 -q ab -l 0.1");
    try
    {
        if (cleanBlobs)
        {
            // Await rather than .Wait() to avoid blocking inside an async method.
            await trainer.Blobs.Cleanup();
        }

        // start listening for event hub
        await trainer.StartAsync(new CountingCheckpointPolicy(eventsPerCheckpoint));

        // send data to event hub
        trainer.SendData(data);

        await trainer.PollTrainerCheckpoint(blobs =>
            blobs.ModelBlobs.Count == expectedNumStates &&
            blobs.ModelTrackbackBlobs.Count == expectedNumStates &&
            blobs.StateJsonBlobs.Count == expectedNumStates);

        // download & parse trackback file
        trainer.Blobs.DownloadTrackbacksOrderedByTime();
        foreach (var trackback in trainer.Blobs.Trackbacks)
        {
            // due to checkpoint policy = 100
            Assert.AreEqual(eventsPerCheckpoint, trackback.EventIds.Count, $"{trackback.Blob.Uri} does not contain the expected 100 events. Actual: {trackback.EventIds.Count}");
        }

        return trainer;
    }
    catch
    {
        // On failure the caller never receives the trainer, so release it here
        // instead of leaking it, then propagate the original exception.
        trainer.Dispose();
        throw;
    }
}