예제 #1
0
 /// <summary>
 /// Builds a <c>DecisionServiceClient</c> over <c>string</c> contexts whose policy is a
 /// <c>VWJsonExplorer</c> constructed from <c>config.ModelStream</c> and <c>config.DevelopmentMode</c>.
 /// </summary>
 /// <param name="config">Client configuration; also passed to <c>DownloadMetadata</c>.</param>
 /// <param name="metaData">Optional application metadata forwarded to <c>DownloadMetadata</c>; that call's result is what the client receives.</param>
 /// <returns>The newly constructed client.</returns>
 public static DecisionServiceClient <string> CreateJson(DecisionServiceConfiguration config, ApplicationClientMetadata metaData = null)
 {
     // Metadata is resolved before the explorer is built, the same order in which the
     // constructor arguments would be evaluated if written inline.
     var resolvedMetaData = DownloadMetadata(config, metaData);
     var jsonExplorer     = new VWJsonExplorer(config.ModelStream, config.DevelopmentMode);

     return new DecisionServiceClient<string>(config, resolvedMetaData, jsonExplorer);
 }
예제 #2
0
 /// <summary>
 /// Builds a <c>DecisionServiceClient</c> for contexts of type <typeparamref name="TContext"/> whose policy is a
 /// <c>VWExplorer</c> constructed from <c>config.ModelStream</c>, the given type inspector and
 /// <c>config.DevelopmentMode</c>.
 /// </summary>
 /// <typeparam name="TContext">Type of the context objects handled by the client.</typeparam>
 /// <param name="config">Client configuration; also passed to <c>DownloadMetadata</c>.</param>
 /// <param name="typeInspector">Optional type inspector handed to the <c>VWExplorer</c> constructor as-is (may be null).</param>
 /// <param name="metaData">Optional application metadata forwarded to <c>DownloadMetadata</c>; that call's result is what the client receives.</param>
 /// <returns>The newly constructed client.</returns>
 public static DecisionServiceClient <TContext> Create <TContext>(DecisionServiceConfiguration config, ITypeInspector typeInspector = null, ApplicationClientMetadata metaData = null)
 {
     // Metadata is resolved before the explorer is built, the same order in which the
     // constructor arguments would be evaluated if written inline.
     var resolvedMetaData = DownloadMetadata(config, metaData);
     var vwExplorer       = new VWExplorer<TContext>(config.ModelStream, typeInspector, config.DevelopmentMode);

     return new DecisionServiceClient<TContext>(config, resolvedMetaData, vwExplorer);
 }
예제 #3
0
 /// <summary>
 /// Builds a <c>DecisionServiceClient</c> for contexts of type <typeparamref name="TContext"/> that uses a
 /// caller-supplied context mapper for both the third and the fourth constructor argument.
 /// </summary>
 /// <typeparam name="TContext">Type of the context objects handled by the client.</typeparam>
 /// <param name="config">Client configuration; also passed to <c>DownloadMetadata</c>.</param>
 /// <param name="contextMapper">Mapper from a context to an <c>ActionProbability[]</c>; passed twice to the client constructor.</param>
 /// <param name="metaData">Optional application metadata forwarded to <c>DownloadMetadata</c>; that call's result is what the client receives.</param>
 /// <returns>The newly constructed client.</returns>
 public static DecisionServiceClient <TContext> Create <TContext>(DecisionServiceConfiguration config, IContextMapper <TContext, ActionProbability[]> contextMapper, ApplicationClientMetadata metaData = null)
 {
     var resolvedMetaData = DownloadMetadata(config, metaData);

     // TODO: cleanup. The same mapper is supplied a second time, which means the context mapper
     // passed in needs to be able to score from the beginning.
     return new DecisionServiceClient<TContext>(config, resolvedMetaData, contextMapper, contextMapper);
 }
예제 #4
0
        /// <summary>
        /// Resolves the application metadata that a client is constructed with.
        /// </summary>
        /// <remarks>
        /// The supplied <paramref name="metaData"/> is returned unchanged only when <c>config.OfflineMode</c> is set
        /// AND <paramref name="metaData"/> is not null. In every other case (online mode, or no metadata supplied) the
        /// metadata is fetched via <c>ApplicationMetadataUtil.DownloadMetadata</c> from <c>config.SettingsBlobUri</c>,
        /// replacing whatever the caller passed in. After a download, if <c>config.LogAppInsights</c> is set, an
        /// Application Insights trace listener is appended to <c>Trace.Listeners</c>.
        /// NOTE(review): a listener is appended on every call that downloads; confirm that repeated client creation
        /// adding several listeners is intended.
        /// </remarks>
        /// <param name="config">Client configuration; must not be null.</param>
        /// <param name="metaData">Metadata supplied by the caller, or null.</param>
        /// <returns>The supplied metadata or the downloaded metadata, as described above.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="config"/> is null.</exception>
        private static ApplicationClientMetadata DownloadMetadata(DecisionServiceConfiguration config, ApplicationClientMetadata metaData)
        {
            // The factories call this helper before the client constructor runs, so a null config has to be
            // rejected here; otherwise it surfaces as a NullReferenceException on config.OfflineMode.
            if (config == null)
            {
                throw new System.ArgumentNullException("config");
            }

            if (!config.OfflineMode || metaData == null)
            {
                metaData = ApplicationMetadataUtil.DownloadMetadata <ApplicationClientMetadata>(config.SettingsBlobUri);

                // Only read AppInsightsKey when metadata was actually obtained. A null result is passed through
                // so that the client constructor can report it with its own, more descriptive error.
                // NOTE(review): whether the download can return null is not visible from here.
                if (config.LogAppInsights && metaData != null)
                {
                    Trace.Listeners.Add(new ApplicationInsights.TraceListener.ApplicationInsightsTraceListener(metaData.AppInsightsKey));
                }
            }

            return(metaData);
        }
예제 #5
0
        /// <summary>
        /// Creates a decision service client: selects a recorder (offline or join-service based), optionally
        /// sets up blob downloaders for settings and model, and creates the underlying <c>MwtExplorer</c>.
        /// </summary>
        /// <param name="config">
        /// Client configuration; must not be null. If its <c>InteractionUploadConfiguration</c> or
        /// <c>ObservationUploadConfiguration</c> is null, this constructor assigns a new
        /// <c>JoinUploader.BatchingConfiguration</c> to it (the passed-in object is modified).
        /// </param>
        /// <param name="metaData">
        /// Application metadata. Must not be null in online mode. It is also read after the offline/online branch
        /// (<c>TrainArguments</c>, and <c>InitialExplorationEpsilon</c> when no <paramref name="initialExplorer"/>
        /// is supplied), so a null value in offline mode is not guarded against.
        /// </param>
        /// <param name="internalPolicy">Policy stored on the client; also probed for
        /// <c>INumberOfActionsProvider</c> when the train arguments carry no <c>--cb_explore N</c>.</param>
        /// <param name="initialPolicy">Optional initial policy; cannot be combined with <paramref name="initialExplorer"/>.</param>
        /// <param name="initialFullExplorer">Optional; defaults to <c>new PermutationExplorer(1)</c>.</param>
        /// <param name="initialExplorer">Optional; defaults to an <c>EpsilonGreedyInitialExplorer</c> built from
        /// <c>metaData.InitialExplorationEpsilon</c>. Cannot be combined with <paramref name="initialPolicy"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="config"/> is null, or offline mode is selected
        /// and <c>config.OfflineApplicationID</c> is null.</exception>
        /// <exception cref="Exception">Both <paramref name="initialExplorer"/> and <paramref name="initialPolicy"/>
        /// are supplied, or online mode is selected and <paramref name="metaData"/> is null.</exception>
        /// <exception cref="ArgumentException">No number-of-actions provider could be determined.</exception>
        public DecisionServiceClient(
            DecisionServiceConfiguration config,
            ApplicationClientMetadata metaData,
            IContextMapper <TContext, ActionProbability[]> internalPolicy,
            IContextMapper <TContext, ActionProbability[]> initialPolicy = null,
            IFullExplorer <int[]> initialFullExplorer = null,
            IInitialExplorer <ActionProbability[], int[]> initialExplorer = null)
        {
            if (config == null)
            {
                throw new ArgumentNullException("config");
            }

            // Pure argument validation: performed up front, before the recorder, join-service logger or
            // blob downloaders are constructed, so that a rejected call does not leave those objects
            // behind, already wired to a client instance that is never handed to the caller.
            if (initialExplorer != null && initialPolicy != null)
            {
                throw new Exception("Initial Explorer and Default Policy are both specified but only one can be used.");
            }

            // Fill in default batching configurations on the caller's config object.
            if (config.InteractionUploadConfiguration == null)
            {
                config.InteractionUploadConfiguration = new JoinUploader.BatchingConfiguration(config.DevelopmentMode);
            }

            if (config.ObservationUploadConfiguration == null)
            {
                config.ObservationUploadConfiguration = new JoinUploader.BatchingConfiguration(config.DevelopmentMode);
            }

            this.config = config;
            string appId = string.Empty;

            this.metaData = metaData;

            if (config.OfflineMode)
            {
                // Offline: record locally; the application id must come from the configuration.
                this.recorder = new OfflineRecorder();
                if (config.OfflineApplicationID == null)
                {
                    throw new ArgumentNullException("OfflineApplicationID", "Offline Application ID must be set explicitly in offline mode.");
                }
                appId = config.OfflineApplicationID;
            }
            else
            {
                if (metaData == null)
                {
                    throw new Exception("Unable to locate a registered MWT application.");
                }

                if (this.recorder == null)
                {
                    var joinServerLogger = new JoinServiceLogger <TContext, int[]>(metaData.ApplicationID, config.DevelopmentMode); // TODO: check token remove
                    switch (config.JoinServerType)
                    {
                        case JoinServerType.CustomSolution:
                            joinServerLogger.InitializeWithCustomAzureJoinServer(
                                config.LoggingServiceAddress,
                                config.InteractionUploadConfiguration);
                            break;

                        // Azure Stream Analytics is also the fallback for any other value.
                        case JoinServerType.AzureStreamAnalytics:
                        default:
                            joinServerLogger.InitializeWithAzureStreamAnalyticsJoinServer(
                                metaData.EventHubInteractionConnectionString,
                                metaData.EventHubObservationConnectionString,
                                config.InteractionUploadConfiguration,
                                config.ObservationUploadConfiguration);
                            break;
                    }
                    this.recorder = (IRecorder <TContext, int[]>)joinServerLogger;
                }

                // TimeSpan.Zero selects the default poll delay; TimeSpan.MinValue disables the settings downloader.
                var settingsBlobPollDelay = config.PollingForSettingsPeriod == TimeSpan.Zero ? DecisionServiceConstants.PollDelay : config.PollingForSettingsPeriod;
                if (settingsBlobPollDelay != TimeSpan.MinValue)
                {
                    this.settingsDownloader             = new AzureBlobBackgroundDownloader(config.SettingsBlobUri, settingsBlobPollDelay, downloadImmediately: false, storageConnectionString: config.AzureStorageConnectionString);
                    this.settingsDownloader.Downloaded += this.UpdateSettings;
                    this.settingsDownloader.Failed     += settingsDownloader_Failed;
                }

                // Same convention for the model downloader.
                var modelBlobPollDelay = config.PollingForModelPeriod == TimeSpan.Zero ? DecisionServiceConstants.PollDelay : config.PollingForModelPeriod;
                if (modelBlobPollDelay != TimeSpan.MinValue)
                {
                    this.modelDownloader             = new AzureBlobBackgroundDownloader(metaData.ModelBlobUri, modelBlobPollDelay, downloadImmediately: true, storageConnectionString: config.AzureStorageConnectionString);
                    this.modelDownloader.Downloaded += this.UpdateContextMapper;
                    this.modelDownloader.Failed     += modelDownloader_Failed;
                }

                appId = metaData.ApplicationID;
            }

            // The logger is the recorder itself when it implements ILogger, otherwise null.
            this.logger         = this.recorder as ILogger;
            this.internalPolicy = internalPolicy;
            this.initialPolicy  = initialPolicy;

            var explorer = new GenericTopSlotExplorer();

            // explorer used if model not ready and defaultAction provided
            // NOTE(review): this dereferences metaData, which the offline branch above does not null-check.
            if (initialExplorer == null)
            {
                initialExplorer = new EpsilonGreedyInitialExplorer(this.metaData.InitialExplorationEpsilon);
            }

            // explorer used if model not ready and no default action provided
            if (initialFullExplorer == null)
            {
                initialFullExplorer = new PermutationExplorer(1);
            }

            // A fixed number of actions can be read from the train arguments ("--cb_explore <N>").
            var match = Regex.Match(metaData.TrainArguments ?? string.Empty, @"--cb_explore\s+(?<numActions>\d+)");

            if (match.Success)
            {
                var numActions = int.Parse(match.Groups["numActions"].Value);
                this.numActionsProvider = new ConstantNumActionsProvider(numActions);

                this.mwtExplorer = MwtExplorer.Create(appId,
                                                      numActions, this.recorder, explorer, initialPolicy, initialFullExplorer, initialExplorer);
            }
            else
            {
                // NOTE(review): initialExplorer was defaulted above and is never null here, so this
                // condition is always true; confirm whether the original (pre-default) argument was meant.
                if (initialExplorer != null || metaData.InitialExplorationEpsilon == 1f) // only needed when full exploration
                {
                    // Prefer the policy as the source of the action count, then fall back to the explorer.
                    numActionsProvider = internalPolicy as INumberOfActionsProvider <TContext>;
                    if (numActionsProvider == null)
                    {
                        numActionsProvider = explorer as INumberOfActionsProvider <TContext>;
                    }

                    if (numActionsProvider == null)
                    {
                        throw new ArgumentException("Explorer must implement INumberOfActionsProvider interface");
                    }
                }

                this.mwtExplorer = MwtExplorer.Create(appId,
                                                      numActionsProvider, this.recorder, explorer, initialPolicy, initialFullExplorer, initialExplorer);
            }
        }