示例#1
0
        /// <summary>
        /// Trains a VW model on randomly generated <c>TestRcv1Context</c> examples and returns the
        /// bytes of the saved model file.
        /// </summary>
        /// <param name="numExamples">Number of random examples to learn from; also part of the RNG seed and the model file name.</param>
        /// <param name="numFeatures">Number of features per generated context; also part of the RNG seed.</param>
        /// <param name="numActions">Number of actions per generated context.</param>
        /// <param name="vwArgs">Command line arguments used to construct the VW instance.</param>
        /// <returns>The raw content of the saved model file.</returns>
        public byte[] GetCBModelBlobContent(int numExamples, int numFeatures, int numActions, string vwArgs)
        {
            // Deterministic seed: the same arguments always yield the same sequence of generated examples.
            var rg = new Random(numExamples + numFeatures);

            // Use a private, uniquely named scratch directory. A shared relative "test" directory would be
            // wiped by the recursive delete below (together with anything else stored in it), and it was
            // never created explicitly before saving the model into it.
            string localOutputDir = Path.Combine(Path.GetTempPath(), "vw_test_" + Guid.NewGuid().ToString("N"));
            string vwFileName     = Path.Combine(localOutputDir, string.Format("test_vw_{0}.model", numExamples));

            Directory.CreateDirectory(localOutputDir);
            try
            {
                using (var vw = new VowpalWabbit<TestRcv1Context>(vwArgs))
                {
                    for (int ie = 0; ie < numExamples; ie++)
                    {
                        var context = TestRcv1Context.CreateRandom(numActions, numFeatures, rg);
                        vw.Learn(context, context.Label);
                    }

                    vw.Native.SaveModel(vwFileName);
                }

                // Read only after the VW instance has been disposed (same ordering as before).
                return File.ReadAllBytes(vwFileName);
            }
            finally
            {
                // Always clean up, even when training, saving or reading throws.
                Directory.Delete(localOutputDir, recursive: true);
            }
        }
示例#2
0
        /// <summary>
        /// Drives a decision service through a fixed number of choose/reward events while pushing
        /// freshly trained VW models into it through <c>UpdateModel(Stream)</c>, then checks that the
        /// expected number of experimental-unit fragments reached the mock join server.
        /// </summary>
        public void TestRcv1ModelUpdateFromStream()
        {
            joinServer.Reset();

            const int numActions  = 10;
            const int numFeatures = 1024;
            const int numEvents   = 100;

            // A new model is pushed on event 1 and then every modelUpdateInterval events (1, 51, ...).
            const int modelUpdateInterval = 50;

            // Derive the action count from numActions so the VW arguments cannot drift from the test setup.
            string vwArgs = string.Format("--cb_explore {0} --epsilon 0.5", numActions);

            commandCenter.CreateBlobs(createSettingsBlob: true, createModelBlob: false, vwArgs: vwArgs);

            var dsConfig = new DecisionServiceConfiguration(MockCommandCenter.SettingsBlobUri)
            {
                JoinServerType           = JoinServerType.CustomSolution,
                LoggingServiceAddress    = MockJoinServer.MockJoinServerAddress,
                // NOTE(review): TimeSpan.MinValue presumably disables background polling, so only the
                // explicit UpdateModel calls below change the model — confirm against DecisionService.
                PollingForModelPeriod    = TimeSpan.MinValue,
                PollingForSettingsPeriod = TimeSpan.MinValue
            };

            using (var ds = DecisionService
                            .Create<TestRcv1Context>(dsConfig, TypeInspector.Default)
                            // TODO: .WithEpsilonGreedy(.5f)
                            .ExploitUntilModelReady(new ConstantPolicy<TestRcv1Context>(ctx => ctx.Features.Count)))
            {
                // NOTE(review): every event reuses the same key; confirm the mock join server does not dedupe by key.
                string uniqueKey = "eventid";

                for (int i = 1; i <= numEvents; i++)
                {
                    // Per-event seed keeps the generated contexts reproducible across runs.
                    var rg = new Random(i);

                    if (i % modelUpdateInterval == 1)
                    {
                        int    modelIndex   = i / modelUpdateInterval;
                        byte[] modelContent = commandCenter.GetCBModelBlobContent(numExamples: 3 + modelIndex, numFeatures: numFeatures, numActions: numActions, vwArgs: vwArgs);
                        using (var modelStream = new MemoryStream(modelContent))
                        {
                            ds.UpdateModel(modelStream);
                        }
                    }

                    var context = TestRcv1Context.CreateRandom(numActions, numFeatures, rg);

                    int action = ds.ChooseAction(uniqueKey, context);

                    // Chosen actions must lie in the 1-based range [1, numActions].
                    Assert.IsTrue(action >= 1 && action <= numActions);

                    // Reward grows linearly from 1/numEvents to 1.
                    ds.ReportReward((float)i / numEvents, uniqueKey);
                }
            }

            // Checked after the service is disposed. Two fragments are expected per event,
            // presumably one for ChooseAction and one for ReportReward.
            Assert.AreEqual(2 * numEvents, joinServer.EventBatchList.Sum(b => b.ExperimentalUnitFragments.Count));
        }