private void EndToEnd(MwtExplorer<SimpleContext> mwtt, IExplorer<SimpleContext> explorer, TestRecorder<SimpleContext> recorder)
{
    uint numActions = 10;
    Random rand = new Random();

    // Generate a batch of interactions with random features, and a random reward for each.
    List<float> rewards = new List<float>();
    for (int i = 0; i < 1000; i++)
    {
        Feature[] f = new Feature[rand.Next(800, 1201)];
        for (int j = 0; j < f.Length; j++)
        {
            f[j].Id = (uint)(j + 1);
            f[j].Value = (float)rand.NextDouble();
        }

        SimpleContext c = new SimpleContext(f);
        mwtt.ChooseAction(explorer, i.ToString(), c);

        rewards.Add((float)rand.NextDouble());
    }

    // Convert the recorded exploration tuples into the Interaction format
    // expected by the reward reporter.
    var testInteractions = recorder.GetAllInteractions();
    Interaction[] partialInteractions = new Interaction[testInteractions.Count];
    for (int i = 0; i < testInteractions.Count; i++)
    {
        partialInteractions[i] = new Interaction()
        {
            ApplicationContext = new OldSimpleContext(testInteractions[i].Context.GetFeatures(), null),
            ChosenAction = testInteractions[i].Action,
            Probability = testInteractions[i].Probability,
            Id = testInteractions[i].UniqueKey
        };
    }

    // Join each reward to its interaction via the interaction's unique id.
    MwtRewardReporter mrr = new MwtRewardReporter(partialInteractions);
    for (int i = 0; i < partialInteractions.Length; i++)
    {
        Assert.AreEqual(true, mrr.ReportReward(partialInteractions[i].GetId(), rewards[i]));
    }

    // Train a policy offline on the joined data, then evaluate it.
    Interaction[] completeInteractions = mrr.GetAllInteractions();
    MwtOptimizer mop = new MwtOptimizer(completeInteractions, numActions);

    string modelFile = "model";
    mop.OptimizePolicyVWCSOAA(modelFile);
    Assert.IsTrue(System.IO.File.Exists(modelFile));

    float evaluatedValue = mop.EvaluatePolicyVWCSOAA(modelFile);
    Assert.IsFalse(float.IsNaN(evaluatedValue));

    System.IO.File.Delete(modelFile);
}
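// NOTE: a minimal sketch of the recorder plumbing the test above relies on. These are
// illustrative stand-ins, not the library's definitions: TestInteraction<TContext> and
// this TestRecorder<TContext> body are hypothetical, and the sketch assumes the
// exploration library's IRecorder<TContext> interface exposes
// Record(context, action, probability, uniqueKey).
class TestInteraction<TContext>
{
    public TContext Context;
    public uint Action;
    public float Probability;
    public string UniqueKey;
}

class TestRecorder<TContext> : IRecorder<TContext>
{
    private readonly List<TestInteraction<TContext>> interactions = new List<TestInteraction<TContext>>();

    // Invoked by MwtExplorer each time ChooseAction makes an exploration decision.
    public void Record(TContext context, uint action, float probability, string uniqueKey)
    {
        interactions.Add(new TestInteraction<TContext>
        {
            Context = context,
            Action = action,
            Probability = probability,
            UniqueKey = uniqueKey
        });
    }

    public List<TestInteraction<TContext>> GetAllInteractions()
    {
        return interactions;
    }
}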
public static void Run()
{
    string interactionFile = "serialized.txt";
    MwtLogger logger = new MwtLogger(interactionFile);

    MwtExplorer mwt = new MwtExplorer("test", logger);

    uint numActions = 10;
    float epsilon = 0.2f;
    uint tau = 0;
    uint bags = 2;
    float lambda = 0.5f;
    int policyParams = 1003;
    CustomParams customParams = new CustomParams() { Value1 = policyParams, Value2 = policyParams + 1 };

    /*** Initialize Epsilon-Greedy explore algorithm using a default policy function that accepts parameters ***/
    mwt.InitializeEpsilonGreedy<int>(epsilon, new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams, numActions);

    /*** Initialize Epsilon-Greedy explore algorithm using a stateless default policy function ***/
    //mwt.InitializeEpsilonGreedy(epsilon, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);

    /*** Initialize Tau-First explore algorithm using a default policy function that accepts parameters ***/
    //mwt.InitializeTauFirst<CustomParams>(tau, new StatefulPolicyDelegate<CustomParams>(SampleStatefulPolicyFunc), customParams, numActions);

    /*** Initialize Tau-First explore algorithm using a stateless default policy function ***/
    //mwt.InitializeTauFirst(tau, new StatelessPolicyDelegate(SampleStatelessPolicyFunc), numActions);

    /*** Initialize Bagging explore algorithm using default policy functions that accept parameters ***/
    //StatefulPolicyDelegate<int>[] funcs =
    //{
    //    new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc),
    //    new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc2)
    //};
    //int[] parameters = { policyParams, policyParams };
    //mwt.InitializeBagging<int>(bags, funcs, parameters, numActions);

    /*** Initialize Bagging explore algorithm using stateless default policy functions ***/
    //StatelessPolicyDelegate[] funcs =
    //{
    //    new StatelessPolicyDelegate(SampleStatelessPolicyFunc),
    //    new StatelessPolicyDelegate(SampleStatelessPolicyFunc2)
    //};
    //mwt.InitializeBagging(bags, funcs, numActions);

    /*** Initialize Softmax explore algorithm using a default scorer function that accepts parameters ***/
    //mwt.InitializeSoftmax<int>(lambda, new StatefulScorerDelegate<int>(SampleStatefulScorerFunc), policyParams, numActions);

    /*** Initialize Softmax explore algorithm using a stateless default scorer function ***/
    //mwt.InitializeSoftmax(lambda, new StatelessScorerDelegate(SampleStatelessScorerFunc), numActions);

    // Build a context with two features plus arbitrary application data to log alongside it.
    FEATURE[] f = new FEATURE[2];
    f[0].X = 0.5f;
    f[0].Index = 1;
    f[1].X = 0.9f;
    f[1].Index = 2;

    string otherContext = "Some other context data that might be helpful to log";
    CONTEXT context = new CONTEXT(f, otherContext);

    UInt32 chosenAction = mwt.ChooseAction(context, "myId");

    INTERACTION[] interactions = mwt.GetAllInteractions();
    mwt.Uninitialize();

    // Join a reward to the logged interaction via its unique key.
    MwtRewardReporter mrr = new MwtRewardReporter(interactions);

    string joinKey = "myId";
    float reward = 0.5f;
    if (!mrr.ReportReward(joinKey, reward))
    {
        throw new Exception();
    }

    // Evaluate an existing policy offline, then optimize and evaluate a new one.
    MwtOptimizer mot = new MwtOptimizer(interactions, numActions);
    float eval1 = mot.EvaluatePolicy(new StatefulPolicyDelegate<int>(SampleStatefulPolicyFunc), policyParams);
    mot.OptimizePolicyVWCSOAA("model_file");
    float eval2 = mot.EvaluatePolicyVWCSOAA("model_file");

    Console.WriteLine(chosenAction);
    Console.WriteLine(interactions);

    logger.Flush();

    // Create a new logger to read back interaction data
    logger = new MwtLogger(interactionFile);
    INTERACTION[] inters = logger.GetAllInteractions();

    // Load and save reward data to file
    string rewardFile = "rewards.txt";
    RewardStore rewardStore = new RewardStore(rewardFile);
    rewardStore.Add(new float[2] { 1.0f, 0.4f });
    rewardStore.Flush();

    // Read back reward data
    rewardStore = new RewardStore(rewardFile);
    float[] rewards = rewardStore.GetAllRewards();
}
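// NOTE: Run() references sample default policy/scorer functions that are defined elsewhere
// in the sample. A minimal sketch of what they could look like, assuming the stateful
// delegates receive (policyParams, context) and the scorer fills one score per action;
// the bodies below are arbitrary illustrations, not the library's defaults.
private static UInt32 SampleStatefulPolicyFunc(int policyParams, CONTEXT applicationContext)
{
    // Derive a 1-based action deterministically from the policy parameters.
    return (UInt32)(policyParams % 10 + 1);
}

private static UInt32 SampleStatelessPolicyFunc(CONTEXT applicationContext)
{
    // Always exploit the first action.
    return 1;
}

private static void SampleStatefulScorerFunc(int policyParams, CONTEXT applicationContext, float[] scores)
{
    // Assign a simple increasing score to each action; Softmax converts
    // these scores into a sampling distribution over actions.
    for (int i = 0; i < scores.Length; i++)
    {
        scores[i] = policyParams + i;
    }
}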