示例#1
0
        /// <summary>
        /// Builds an <c>MwtExplorer</c> over a fixed number of actions and assigns the given policy to it.
        /// </summary>
        /// <param name="appId">Application identifier passed to the explorer's constructor.</param>
        /// <param name="numActions">Number of actions; wrapped in a <c>ConstantActionsProvider</c>.</param>
        /// <param name="recorder">Recorder passed to the explorer's constructor.</param>
        /// <param name="explorer">Exploration strategy passed to the explorer's constructor.</param>
        /// <param name="policy">Optional policy stored in the explorer's <c>Policy</c> member; may be null.</param>
        /// <param name="initialExplorer">Optional initial explorer passed to the explorer's constructor; may be null.</param>
        /// <returns>The newly constructed explorer.</returns>
        public static MwtExplorer<TContext, TAction, TPolicyValue> Create<TContext, TAction, TPolicyValue>(
            string appId,
            int numActions,
            IRecorder<TContext, TAction> recorder,
            IExplorer<TAction, TPolicyValue> explorer,
            IContextMapper<TContext, TPolicyValue> policy = null,
            IFullExplorer<TAction> initialExplorer = null)
        {
            var actionsProvider = new ConstantActionsProvider<TContext>(numActions);

            // The policy is not a constructor argument, so it is set right after construction.
            return new MwtExplorer<TContext, TAction, TPolicyValue>(appId, actionsProvider, recorder, explorer, initialExplorer)
            {
                Policy = policy
            };
        }
示例#2
0
 ExploitUntilModelReady <TContext>
     (this DecisionServiceClient <TContext> that, IContextMapper <TContext, ActionProbability[]> initialPolicy)
 {
     // Extension method on DecisionServiceClient<TContext>: stores the supplied policy
     // in the client's InitialPolicy member.
     // NOTE(review): the method name suggests this policy is used until a trained model
     // is available - confirm against the code that reads InitialPolicy.
     that.InitialPolicy = initialPolicy;
     // Return the same client instance that was passed in so calls can be chained.
     return(that);
 }
示例#3
0
        /// <summary>
        /// Makes one <c>ChooseAction</c> call against an explorer built from the given pieces and
        /// asserts that exactly one interaction was recorded, carrying the Id of the supplied context.
        /// </summary>
        /// <param name="numActions">Number of actions passed to <c>MwtExplorer.Create</c>.</param>
        /// <param name="testContext">Context passed to <c>ChooseAction</c>; its Id is compared with the recorded one.</param>
        /// <param name="explorer">Exploration strategy passed to <c>MwtExplorer.Create</c>.</param>
        /// <param name="scorer">Policy passed to <c>MwtExplorer.Create</c>.</param>
        private static void GenericWithContext<TContext>(int numActions, TContext testContext, IExplorer<int, float[]> explorer, IContextMapper<TContext, float[]> scorer)
            where TContext : RegularTestContext
        {
            const string eventKey = "ManagedTestId";

            var testRecorder = new TestRecorder<TContext>();
            var mwt = MwtExplorer.Create("mwt", numActions, testRecorder, explorer, scorer);

            // The chosen action is not inspected here; only the recorded interaction is checked.
            mwt.ChooseAction(eventKey, testContext);

            var recorded = testRecorder.GetAllInteractions();
            Assert.AreEqual(1, recorded.Count);
            Assert.AreEqual(testContext.Id, recorded[0].Context.Id);
        }
示例#4
0
        /// <summary>
        /// Calls <c>ChooseAction</c> once per context, then asserts that every action was picked at least
        /// once and that one interaction was recorded per context, with Ids equal to their position.
        /// </summary>
        /// <param name="numActions">Number of actions passed to <c>MwtExplorer.Create</c>; also the size of the hit-count table.</param>
        /// <param name="explorer">Exploration strategy passed to <c>MwtExplorer.Create</c>.</param>
        /// <param name="scorer">Policy passed to <c>MwtExplorer.Create</c>.</param>
        /// <param name="contexts">Contexts fed to the explorer in order; the final loop expects <c>contexts[i].Id == i</c> in the recorded data.</param>
        private static void SoftmaxWithContext<TContext>(int numActions, IExplorer<int, float[]> explorer, IContextMapper<TContext, float[]> scorer, TContext[] contexts)
            where TContext : RegularTestContext
        {
            var testRecorder = new TestRecorder<TContext>();
            var mwt = MwtExplorer.Create("mwt", numActions, testRecorder, explorer, scorer);

            var hitsPerAction = new uint[numActions];
            var keySource = new Random();

            foreach (TContext context in contexts)
            {
                // Each call gets a freshly drawn random string as its event key.
                string eventKey = keySource.NextDouble().ToString();
                int picked = mwt.ChooseAction(eventKey, context);

                // Action ids are one-based, the table is zero-based.
                hitsPerAction[picked - 1]++;
            }

            // Every action must have been selected at least once.
            foreach (uint hits in hitsPerAction)
            {
                Assert.IsTrue(hits > 0);
            }

            var recorded = testRecorder.GetAllInteractions();
            Assert.AreEqual(contexts.Length, recorded.Count);

            for (int position = 0; position < contexts.Length; position++)
            {
                Assert.AreEqual(position, recorded[position].Context.Id);
            }
        }
示例#5
0
        /// <summary>
        /// Initializes the client: validates the arguments, selects a recorder, optionally creates the
        /// settings and model blob downloaders, and builds the underlying explorer via <c>MwtExplorer.Create</c>.
        /// </summary>
        /// <remarks>
        /// All argument validation runs before any recorder, join-service logger or background downloader
        /// is created, so a rejected call does not leave such objects behind.
        /// </remarks>
        /// <param name="config">
        /// Client configuration; must not be null. If its interaction or observation upload configuration
        /// is null, it is replaced by a new <c>BatchingConfiguration</c> built from <c>config.DevelopmentMode</c>.
        /// </param>
        /// <param name="metaData">Application metadata; must not be null.</param>
        /// <param name="internalPolicy">
        /// Policy stored on the client; also probed for <c>INumberOfActionsProvider</c> when the train
        /// arguments do not contain <c>--cb_explore &lt;n&gt;</c>.
        /// </param>
        /// <param name="initialPolicy">Optional initial policy; cannot be combined with <paramref name="initialExplorer"/>.</param>
        /// <param name="initialFullExplorer">Optional; defaults to <c>new PermutationExplorer(1)</c>.</param>
        /// <param name="initialExplorer">
        /// Optional; defaults to an <c>EpsilonGreedyInitialExplorer</c> built from
        /// <c>metaData.InitialExplorationEpsilon</c>. Cannot be combined with <paramref name="initialPolicy"/>.
        /// </param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="config"/> is null; offline mode is on and <c>OfflineApplicationID</c> is null;
        /// or offline mode is on and <paramref name="metaData"/> is null.
        /// </exception>
        /// <exception cref="Exception">
        /// Offline mode is off and <paramref name="metaData"/> is null; or both
        /// <paramref name="initialExplorer"/> and <paramref name="initialPolicy"/> are supplied.
        /// </exception>
        /// <exception cref="ArgumentException">
        /// The number of actions is not in the train arguments and neither the internal policy nor the
        /// explorer implements <c>INumberOfActionsProvider</c>.
        /// </exception>
        public DecisionServiceClient(
            DecisionServiceConfiguration config,
            ApplicationClientMetadata metaData,
            IContextMapper <TContext, ActionProbability[]> internalPolicy,
            IContextMapper <TContext, ActionProbability[]> initialPolicy = null,
            IFullExplorer <int[]> initialFullExplorer = null,
            IInitialExplorer <ActionProbability[], int[]> initialExplorer = null)
        {
            if (config == null)
            {
                throw new ArgumentNullException("config");
            }

            if (config.InteractionUploadConfiguration == null)
            {
                config.InteractionUploadConfiguration = new JoinUploader.BatchingConfiguration(config.DevelopmentMode);
            }

            if (config.ObservationUploadConfiguration == null)
            {
                config.ObservationUploadConfiguration = new JoinUploader.BatchingConfiguration(config.DevelopmentMode);
            }

            // Validate everything up front, in the same precedence order as before, so that no
            // logger or background downloader is created for a call that is going to be rejected.
            if (config.OfflineMode)
            {
                if (config.OfflineApplicationID == null)
                {
                    throw new ArgumentNullException("OfflineApplicationID", "Offline Application ID must be set explicitly in offline mode.");
                }
            }
            else if (metaData == null)
            {
                throw new Exception("Unable to locate a registered MWT application.");
            }

            if (initialExplorer != null && initialPolicy != null)
            {
                throw new Exception("Initial Explorer and Default Policy are both specified but only one can be used.");
            }

            // Only reachable in offline mode (online mode rejected a null value above). The metadata
            // is dereferenced further down in both modes, so fail with a descriptive error instead of
            // a NullReferenceException.
            if (metaData == null)
            {
                throw new ArgumentNullException("metaData", "Application metadata is required to configure exploration.");
            }

            this.config   = config;
            this.metaData = metaData;

            string appId;

            if (config.OfflineMode)
            {
                this.recorder = new OfflineRecorder();
                appId         = config.OfflineApplicationID;
            }
            else
            {
                if (this.recorder == null)
                {
                    var joinServerLogger = new JoinServiceLogger <TContext, int[]>(metaData.ApplicationID, config.DevelopmentMode); // TODO: check token remove
                    switch (config.JoinServerType)
                    {
                    case JoinServerType.CustomSolution:
                        joinServerLogger.InitializeWithCustomAzureJoinServer(
                            config.LoggingServiceAddress,
                            config.InteractionUploadConfiguration);
                        break;

                    case JoinServerType.AzureStreamAnalytics:
                    default:
                        joinServerLogger.InitializeWithAzureStreamAnalyticsJoinServer(
                            metaData.EventHubInteractionConnectionString,
                            metaData.EventHubObservationConnectionString,
                            config.InteractionUploadConfiguration,
                            config.ObservationUploadConfiguration);
                        break;
                    }
                    this.recorder = (IRecorder <TContext, int[]>)joinServerLogger;
                }

                // TimeSpan.Zero selects the default poll delay; TimeSpan.MinValue skips creating the downloader.
                var settingsBlobPollDelay = config.PollingForSettingsPeriod == TimeSpan.Zero ? DecisionServiceConstants.PollDelay : config.PollingForSettingsPeriod;
                if (settingsBlobPollDelay != TimeSpan.MinValue)
                {
                    this.settingsDownloader             = new AzureBlobBackgroundDownloader(config.SettingsBlobUri, settingsBlobPollDelay, downloadImmediately: false, storageConnectionString: config.AzureStorageConnectionString);
                    this.settingsDownloader.Downloaded += this.UpdateSettings;
                    this.settingsDownloader.Failed     += settingsDownloader_Failed;
                }

                var modelBlobPollDelay = config.PollingForModelPeriod == TimeSpan.Zero ? DecisionServiceConstants.PollDelay : config.PollingForModelPeriod;
                if (modelBlobPollDelay != TimeSpan.MinValue)
                {
                    this.modelDownloader             = new AzureBlobBackgroundDownloader(metaData.ModelBlobUri, modelBlobPollDelay, downloadImmediately: true, storageConnectionString: config.AzureStorageConnectionString);
                    this.modelDownloader.Downloaded += this.UpdateContextMapper;
                    this.modelDownloader.Failed     += modelDownloader_Failed;
                }

                appId = metaData.ApplicationID;
            }

            // Null when the recorder does not implement ILogger.
            this.logger         = this.recorder as ILogger;
            this.internalPolicy = internalPolicy;
            this.initialPolicy  = initialPolicy;

            var explorer = new GenericTopSlotExplorer();

            // explorer used if model not ready and defaultAction provided
            if (initialExplorer == null)
            {
                initialExplorer = new EpsilonGreedyInitialExplorer(this.metaData.InitialExplorationEpsilon);
            }

            // explorer used if model not ready and no default action provided
            if (initialFullExplorer == null)
            {
                initialFullExplorer = new PermutationExplorer(1);
            }

            // A fixed action count may be given in the train arguments as "--cb_explore <n>".
            var match = Regex.Match(metaData.TrainArguments ?? string.Empty, @"--cb_explore\s+(?<numActions>\d+)");

            if (match.Success)
            {
                var numActions = int.Parse(match.Groups["numActions"].Value);
                this.numActionsProvider = new ConstantNumActionsProvider(numActions);

                this.mwtExplorer = MwtExplorer.Create(appId,
                                                      numActions, this.recorder, explorer, initialPolicy, initialFullExplorer, initialExplorer);
            }
            else
            {
                // NOTE(review): initialExplorer has already been defaulted above, so the first operand
                // is always true here and a provider is always required - confirm this is intended.
                if (initialExplorer != null || metaData.InitialExplorationEpsilon == 1f) // only needed when full exploration
                {
                    numActionsProvider = internalPolicy as INumberOfActionsProvider <TContext>;
                    if (numActionsProvider == null)
                    {
                        numActionsProvider = explorer as INumberOfActionsProvider <TContext>;
                    }

                    if (numActionsProvider == null)
                    {
                        throw new ArgumentException("Explorer must implement INumberOfActionsProvider interface");
                    }
                }

                this.mwtExplorer = MwtExplorer.Create(appId,
                                                      numActionsProvider, this.recorder, explorer, initialPolicy, initialFullExplorer, initialExplorer);
            }
        }
示例#6
0
 /// <summary>
 /// Creates a <c>DecisionServiceClient</c> that uses the same context mapper as both its
 /// internal policy and its initial policy.
 /// </summary>
 /// <param name="config">Configuration passed to <c>DownloadMetadata</c> and to the client constructor.</param>
 /// <param name="contextMapper">Mapper passed to the constructor twice (internal policy and initial policy).</param>
 /// <param name="metaData">Optional metadata passed to <c>DownloadMetadata</c>, whose result is given to the constructor.</param>
 /// <returns>The newly constructed client.</returns>
 public static DecisionServiceClient<TContext> Create<TContext>(DecisionServiceConfiguration config, IContextMapper<TContext, ActionProbability[]> contextMapper, ApplicationClientMetadata metaData = null)
 {
     var resolvedMetaData = DownloadMetadata(config, metaData);

     // TODO: cleanup. this means that the context mapper passed in needs to be able to score from the beginning
     var client = new DecisionServiceClient<TContext>(config, resolvedMetaData, contextMapper, contextMapper);

     return client;
 }