/// <summary>
/// Trains a throw-away VowpalWabbit model on randomly generated <c>TestRcv1Context</c>
/// examples and returns the bytes of the saved model file.
/// </summary>
/// <param name="numExamples">Number of random examples to learn from; also contributes to the random seed and the model file name.</param>
/// <param name="numFeatures">Feature count passed to <c>TestRcv1Context.CreateRandom</c>; also contributes to the random seed.</param>
/// <param name="numActions">Action count passed to <c>TestRcv1Context.CreateRandom</c>.</param>
/// <param name="vwArgs">Arguments used to construct the VowpalWabbit instance.</param>
/// <returns>The raw contents of the model file written by <c>SaveModel</c>.</returns>
public byte[] GetCBModelBlobContent(int numExamples, int numFeatures, int numActions, string vwArgs)
{
    // Deterministic seed: the same (numExamples, numFeatures) pair always produces the same examples.
    Random rg = new Random(numExamples + numFeatures);

    // Use a private scratch directory rather than a fixed "test" folder relative to the working
    // directory: recursively deleting a well-known shared path can wipe unrelated files and
    // races with concurrent callers writing the same file name.
    string localOutputDir = Path.Combine(Path.GetTempPath(), "vw_test_" + Guid.NewGuid().ToString("N"));
    string vwFileName = Path.Combine(localOutputDir, string.Format("test_vw_{0}.model", numExamples));
    Directory.CreateDirectory(localOutputDir);

    try
    {
        using (var vw = new VowpalWabbit<TestRcv1Context>(vwArgs))
        {
            // Learn from numExamples randomly generated contexts.
            for (int ie = 0; ie < numExamples; ie++)
            {
                var context = TestRcv1Context.CreateRandom(numActions, numFeatures, rg);
                vw.Learn(context, context.Label);
            }

            vw.Native.SaveModel(vwFileName);
        }

        // Read the model file only after the VowpalWabbit instance has been disposed.
        return File.ReadAllBytes(vwFileName);
    }
    finally
    {
        // Clean up even if training or saving throws, so failed runs leave nothing behind.
        Directory.Delete(localOutputDir, recursive: true);
    }
}
/// <summary>
/// Verifies that a DecisionService's model can be replaced from a stream while it is serving
/// requests. Every chosen action must lie in [1, numActions], and after the service is disposed
/// the mock join server must have received the expected number of fragments.
/// </summary>
public void TestRcv1ModelUpdateFromStream()
{
    joinServer.Reset();

    const int numEvents = 100;
    const int modelUpdateInterval = 50;
    int numActions = 10;
    int numFeatures = 1024;
    string vwArgs = "--cb_explore 10 --epsilon 0.5";

    commandCenter.CreateBlobs(createSettingsBlob: true, createModelBlob: false, vwArgs: vwArgs);

    var dsConfig = new DecisionServiceConfiguration(MockCommandCenter.SettingsBlobUri)
    {
        JoinServerType = JoinServerType.CustomSolution,
        LoggingServiceAddress = MockJoinServer.MockJoinServerAddress,
        // NOTE(review): TimeSpan.MinValue presumably disables background polling so that only the
        // explicit UpdateModel calls below change the model — confirm against DecisionServiceConfiguration.
        PollingForModelPeriod = TimeSpan.MinValue,
        PollingForSettingsPeriod = TimeSpan.MinValue
    };

    using (var ds = DecisionService
        .Create<TestRcv1Context>(dsConfig, TypeInspector.Default)
        // TODO: configure epsilon-greedy exploration, e.g. .WithEpsilonGreedy(.5f)
        .ExploitUntilModelReady(new ConstantPolicy<TestRcv1Context>(ctx => ctx.Features.Count)))
    {
        string uniqueKey = "eventid";

        for (int i = 1; i <= numEvents; i++)
        {
            Random rg = new Random(i);

            // Push a freshly trained model on the first event of each interval (i = 1, 51, ...);
            // the k-th pushed model is trained on 3 + k examples.
            if (i % modelUpdateInterval == 1)
            {
                int modelIndex = i / modelUpdateInterval;
                byte[] modelContent = commandCenter.GetCBModelBlobContent(
                    numExamples: 3 + modelIndex,
                    numFeatures: numFeatures,
                    numActions: numActions,
                    vwArgs: vwArgs);

                using (var modelStream = new MemoryStream(modelContent))
                {
                    ds.UpdateModel(modelStream);
                }
            }

            var context = TestRcv1Context.CreateRandom(numActions, numFeatures, rg);
            int action = ds.ChooseAction(uniqueKey, context);

            // Verify the actions are in the expected (1-based) range.
            Assert.IsTrue(action >= 1 && action <= numActions);

            ds.ReportReward((float)i / numEvents, uniqueKey);
        }
    }

    // Presumably one fragment per ChooseAction plus one per ReportReward, i.e. 2 * numEvents (= 200).
    Assert.AreEqual(2 * numEvents, joinServer.EventBatchList.Sum(b => b.ExperimentalUnitFragments.Count));
}