Exemplo n.º 1
0
        /// <summary>
        /// Trains a SAR model from the given usage events and, when cold item placement is enabled, catalog items.
        /// </summary>
        /// <param name="settings">The training settings</param>
        /// <param name="usageEvents">The usage events to train on</param>
        /// <param name="catalogItems">The catalog items to train on; must not be null when <c>settings.EnableColdItemPlacement</c> is set</param>
        /// <param name="uniqueUsersCount">The number of users in the user id index file; negative values are rejected</param>
        /// <param name="uniqueUsageItemsCount">The number of usage items in the item id index file; negative values are rejected</param>
        /// <param name="cancellationToken">A cancellation token; when signaled during training it stops the SAR host execution</param>
        /// <returns>The model produced by the internal training routine</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="settings"/> or <paramref name="usageEvents"/> is null, or <paramref name="catalogItems"/> is null
        /// while cold item placement is enabled.
        /// </exception>
        /// <exception cref="ArgumentException">One of the count arguments is negative.</exception>
        public IPredictorModel Train(ITrainingSettings settings,
                                     IList<SarUsageEvent> usageEvents,
                                     IList<SarCatalogItem> catalogItems,
                                     int uniqueUsersCount,
                                     int uniqueUsageItemsCount,
                                     CancellationToken cancellationToken)
        {
            // ---- argument validation ----
            if (settings == null)
            {
                throw new ArgumentNullException(nameof(settings));
            }

            if (usageEvents == null)
            {
                throw new ArgumentNullException(nameof(usageEvents));
            }

            // The catalog is only mandatory when cold item placement was requested.
            if (settings.EnableColdItemPlacement && catalogItems == null)
            {
                throw new ArgumentNullException(nameof(catalogItems));
            }

            if (uniqueUsersCount < 0)
            {
                // Trace the failure before surfacing it to the caller.
                ArgumentException invalidUsersCount =
                    new ArgumentException($"{nameof(uniqueUsersCount)} must be a positive integer");
                _tracer.TraceWarning(invalidUsersCount.ToString());
                throw invalidUsersCount;
            }

            if (uniqueUsageItemsCount < 0)
            {
                ArgumentException invalidItemsCount =
                    new ArgumentException($"{nameof(uniqueUsageItemsCount)} must be a positive integer");
                _tracer.TraceWarning(invalidItemsCount.ToString());
                throw invalidItemsCount;
            }

            // Bail out early if cancellation was requested before any work started.
            cancellationToken.ThrowIfCancellationRequested();

            using (var sarEnvironment = new TlcEnvironment(verbose: true))
            {
                // Reset the feature weights captured during a previous run.
                _detectedFeatureWeights = null;
                try
                {
                    sarEnvironment.AddListener<ChannelMessage>(ChannelMessageListener);
                    IHost sarHost = sarEnvironment.Register("SarHost");

                    // Tie the caller's cancellation token to the host so a cancel request stops execution.
                    using (cancellationToken.Register(() => sarHost.StopExecution()))
                    {
                        _tracer.TraceInformation("Starting training model using SAR");
                        IPredictorModel trainedModel = TrainModel(
                            sarHost,
                            settings,
                            usageEvents,
                            catalogItems,
                            uniqueUsersCount,
                            uniqueUsageItemsCount);
                        return trainedModel;
                    }
                }
                finally
                {
                    // Always detach the listener, whether training succeeded or failed.
                    sarEnvironment.RemoveListener<ChannelMessage>(ChannelMessageListener);
                }
            }
        }
Exemplo n.º 2
0
        /// <summary>
        /// Initializes a new <see cref="SarScorer"/> around a trained recommender.
        /// </summary>
        /// <param name="recommender">A trained SAR recommender; must not be null</param>
        /// <param name="tracer">A message tracer used for logging; a <see cref="DefaultTracer"/> is used when null</param>
        /// <exception cref="ArgumentNullException"><paramref name="recommender"/> is null.</exception>
        public SarScorer(IUserHistoryToItemsRecommender recommender, ITracer tracer = null)
        {
            if (recommender == null)
            {
                throw new ArgumentNullException(nameof(recommender));
            }

            _recommender = recommender;
            _tracer = tracer != null ? tracer : new DefaultTracer();

            // Build the input schema and align the id column types with the recommender's id types.
            _usageDataSchema = SchemaDefinition.Create(typeof(SarUsageEvent));
            _usageDataSchema["user"].ColumnType = recommender.UserIdType;
            _usageDataSchema["Item"].ColumnType = recommender.ItemIdType;

            // Create the environment and forward its channel messages to the tracer.
            var scoringEnvironment = new TlcEnvironment(verbose: true);
            _environment = scoringEnvironment;
            scoringEnvironment.AddListener<ChannelMessage>(_tracer.TraceChannelMessage);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Computes precision and diversity metrics for a set of scores.
        /// </summary>
        /// <param name="usageEvents">The usage events the diversity evaluator is fed with</param>
        /// <param name="evaluationEvents">The evaluation events the precision evaluator is fed with</param>
        /// <param name="scores">The score results being evaluated</param>
        /// <param name="cancellationToken">A cancellation token, checked between the evaluation steps</param>
        /// <returns>
        /// The computed metrics, or an empty <see cref="ModelMetrics"/> when any of the three inputs has no elements.
        /// </returns>
        private ModelMetrics ComputeMetrics(
            IList<SarUsageEvent> usageEvents,
            IList<SarUsageEvent> evaluationEvents,
            IList<SarScoreResult> scores,
            CancellationToken cancellationToken)
        {
            // Nothing can be evaluated unless all three inputs contain data.
            bool allInputsPresent = scores.Any() && usageEvents.Any() && evaluationEvents.Any();
            if (!allInputsPresent)
            {
                _tracer.TraceWarning(
                    $"Operation '{nameof(ComputeMetrics)}' returning empty results. Scores: '{scores.Count}', Usage Events: '{usageEvents.Count}', Evaluation Events: '{evaluationEvents.Count}'");
                return new ModelMetrics();
            }

            // Convert every input collection to the evaluation format.
            List<SarEvaluationUsageEvent> trainRows =
                usageEvents.Select(usageEvent => ToEvaluationUsageEvent(usageEvent)).ToList();
            List<SarEvaluationUsageEvent> testRows =
                evaluationEvents.Select(evaluationEvent => ToEvaluationUsageEvent(evaluationEvent)).ToList();
            List<SarEvaluationUsageEvent> scoreRows =
                scores.Select(score => ToEvaluationUsageEvent(score)).ToList();

            using (var env = new TlcEnvironment())
            {
                env.AddListener<ChannelMessage>(_tracer.TraceChannelMessage);

                // Precision evaluator: scores against the evaluation events.
                var precisionArguments = new PrecisionAtKEvaluator.Arguments { k = MaxPrecisionK };
                var precisionEvaluator = new PrecisionAtKEvaluator(
                    env,
                    precisionArguments,
                    env.CreateStreamingDataView(scoreRows),
                    env.CreateStreamingDataView(testRows));
                cancellationToken.ThrowIfCancellationRequested();

                // Diversity evaluator: scores against the usage events.
                var diversityArguments = new DiversityAtKEvaluator.Arguments { buckets = DiversityBuckets };
                var diversityEvaluator = new DiversityAtKEvaluator(
                    env,
                    diversityArguments,
                    env.CreateStreamingDataView(scoreRows),
                    env.CreateStreamingDataView(trainRows));
                cancellationToken.ThrowIfCancellationRequested();

                // Precision metrics: one entry per evaluator row, percentage rounded to 3 decimals.
                IList<PrecisionAtKEvaluator.MetricItem> precisionRows =
                    precisionEvaluator.Evaluate()
                        .AsEnumerable<PrecisionAtKEvaluator.MetricItem>(env, false)
                        .ToList();
                var precisionResults = new List<PrecisionMetric>();
                foreach (PrecisionAtKEvaluator.MetricItem row in precisionRows)
                {
                    precisionResults.Add(new PrecisionMetric
                    {
                        K = (int)row.K,
                        Percentage = Math.Round(row.PrecisionAtK * 100, 3),
                        UsersInTest = (int?)row.TotalUsers
                    });
                }

                cancellationToken.ThrowIfCancellationRequested();

                // Diversity metrics: one percentile bucket per evaluator row.
                IList<DiversityAtKEvaluator.MetricItem> diversityRows =
                    diversityEvaluator.Evaluate()
                        .AsEnumerable<DiversityAtKEvaluator.MetricItem>(env, false)
                        .ToList();
                var percentileBuckets = new List<PercentileBucket>();
                foreach (DiversityAtKEvaluator.MetricItem row in diversityRows)
                {
                    percentileBuckets.Add(new PercentileBucket
                    {
                        Min = (int)row.BucketMin,
                        // A bucket limit of 101 is reported as 100.
                        Max = ((bool)(row.BucketLim == 101)) ? 100 : (int)row.BucketLim,
                        Percentage = Math.Round(row.RecommendedItemsFraction * 100, 3)
                    });
                }

                // The totals are read from the first diversity row.
                DiversityAtKEvaluator.MetricItem firstRow = diversityRows.First();
                var diversityResults = new ModelDiversityMetrics
                {
                    PercentileBuckets = percentileBuckets,
                    UniqueItemsRecommended = (int?)firstRow.DistinctRecommendations,
                    TotalItemsRecommended = (int?)firstRow.TotalRecommendations,
                    UniqueItemsInTrainSet = (int?)firstRow.TotalItemsEvaluated
                };

                var metrics = new ModelMetrics
                {
                    ModelPrecisionMetrics = precisionResults,
                    ModelDiversityMetrics = diversityResults
                };
                return metrics;
            }
        }
        /// <summary>
        /// Trains a SAR model and reports the catalog feature weights detected during training.
        /// </summary>
        /// <param name="settings">The training settings</param>
        /// <param name="usageEvents">The usage events to train on</param>
        /// <param name="catalogItems">The catalog items to train on; must not be null when <c>settings.EnableColdItemPlacement</c> is set</param>
        /// <param name="featureNames">The names of the catalog items features, in the same order as the feature values in the catalog</param>
        /// <param name="uniqueUsersCount">The number of users in the user id index file; negative values are rejected</param>
        /// <param name="uniqueUsageItemsCount">The number of usage items in the item id index file; negative values are rejected</param>
        /// <param name="catalogFeatureWeights">
        /// Receives a feature-name to weight map. It stays empty when no weights were detected, when
        /// <paramref name="featureNames"/> is null, or when the number of names and weights differ.
        /// </param>
        /// <param name="cancellationToken">A cancellation token; when signaled during training it stops the SAR host execution</param>
        /// <returns>The model produced by the internal training routine</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="settings"/> or <paramref name="usageEvents"/> is null, or <paramref name="catalogItems"/> is null
        /// while cold item placement is enabled.
        /// </exception>
        /// <exception cref="ArgumentException">One of the count arguments is negative.</exception>
        public IPredictorModel Train(ITrainingSettings settings,
                                     IList<SarUsageEvent> usageEvents,
                                     IList<SarCatalogItem> catalogItems,
                                     string[] featureNames,
                                     int uniqueUsersCount,
                                     int uniqueUsageItemsCount,
                                     out IDictionary<string, double> catalogFeatureWeights,
                                     CancellationToken cancellationToken)
        {
            // ---- argument validation ----
            if (settings == null)
            {
                throw new ArgumentNullException(nameof(settings));
            }

            if (usageEvents == null)
            {
                throw new ArgumentNullException(nameof(usageEvents));
            }

            // The catalog is only mandatory when cold item placement was requested.
            if (settings.EnableColdItemPlacement && catalogItems == null)
            {
                throw new ArgumentNullException(nameof(catalogItems));
            }

            if (uniqueUsersCount < 0)
            {
                // Trace the failure before surfacing it to the caller.
                ArgumentException invalidUsersCount =
                    new ArgumentException($"{nameof(uniqueUsersCount)} must be a positive integer");
                _tracer.TraceWarning(invalidUsersCount.ToString());
                throw invalidUsersCount;
            }

            if (uniqueUsageItemsCount < 0)
            {
                ArgumentException invalidItemsCount =
                    new ArgumentException($"{nameof(uniqueUsageItemsCount)} must be a positive integer");
                _tracer.TraceWarning(invalidItemsCount.ToString());
                throw invalidItemsCount;
            }

            // Bail out early if cancellation was requested before any work started.
            cancellationToken.ThrowIfCancellationRequested();

            using (var sarEnvironment = new TlcEnvironment(verbose: true))
            {
                // Reset the feature weights captured during a previous run.
                _detectedFeatureWeights = null;
                try
                {
                    sarEnvironment.AddListener<ChannelMessage>(ChannelMessageListener);
                    IHost sarHost = sarEnvironment.Register("SarHost");

                    // Tie the caller's cancellation token to the host so a cancel request stops execution.
                    using (cancellationToken.Register(() => sarHost.StopExecution()))
                    {
                        _tracer.TraceInformation("Starting training model using SAR");
                        IPredictorModel trainedModel = TrainModel(
                            sarHost,
                            settings,
                            usageEvents,
                            catalogItems,
                            uniqueUsersCount,
                            uniqueUsageItemsCount);

                        // Start from an empty map; it is only filled when names and weights line up.
                        catalogFeatureWeights = new Dictionary<string, double>();
                        bool weightsAvailable = _detectedFeatureWeights != null && featureNames != null;
                        if (weightsAvailable)
                        {
                            if (_detectedFeatureWeights.Length != featureNames.Length)
                            {
                                _tracer.TraceWarning(
                                    $"Found a mismatch between number of feature names ({featureNames.Length}) and the number of feature weights ({_detectedFeatureWeights.Length})");
                            }
                            else
                            {
                                // Pair each feature name with the weight at the same position.
                                int position = 0;
                                foreach (string featureName in featureNames)
                                {
                                    catalogFeatureWeights[featureName] = _detectedFeatureWeights[position];
                                    position++;
                                }
                            }
                        }

                        return trainedModel;
                    }
                }
                finally
                {
                    // Always detach the listener, whether training succeeded or failed.
                    sarEnvironment.RemoveListener<ChannelMessage>(ChannelMessageListener);
                }
            }
        }
Exemplo n.º 5
0
        /// <summary>
        /// Tries to read a sample of <paramref name="source"/> with the given loader arguments and determine
        /// the most common number of columns per row.
        /// </summary>
        /// <param name="ch">The channel that receives the trace message and, on success, the buffered loader messages</param>
        /// <param name="args">The text loader arguments (separator, quoting, sparse settings) to try</param>
        /// <param name="source">The source to read from</param>
        /// <param name="skipStrictValidation">
        /// When true, the uniformity check and the more-than-one-column check are both skipped.
        /// </param>
        /// <param name="result">
        /// On success, the detected split description; otherwise the default <see cref="ColumnSplitResult"/>.
        /// </param>
        /// <returns>
        /// True when the sample was read and passed validation; false when validation failed or a marked
        /// exception was raised. Exceptions that are not marked are rethrown.
        /// </returns>
        private static bool TryParseFile(IChannel ch, TextLoader.Arguments args, IMultiStreamSource source, bool skipStrictValidation, out ColumnSplitResult result)
        {
            result = default(ColumnSplitResult);
            try
            {
                // Messages from an unsuccessful attempt are not interesting, so the loader runs in a
                // temporary environment and its messages are only forwarded when the attempt succeeds.
                using (var scratchEnv = new TlcEnvironment(0, true))
                {
                    var bufferedMessages = new ConcurrentBag<ChannelMessage>();
                    scratchEnv.AddListener<ChannelMessage>((sender, message) => bufferedMessages.Add(message));

                    var sample = TextLoader.ReadFile(scratchEnv, args, source).Take(1000);

                    int columnIndex;
                    bool hasColumn = sample.Schema.TryGetColumnIndex("C", out columnIndex);
                    ch.Assert(hasColumn);

                    // Record the length of column "C" for every sampled row.
                    var rowLengths = new List<int>();
                    using (var rows = sample.GetRowCursor(col => col == columnIndex))
                    {
                        var readRow = rows.GetGetter<VBuffer<DvText>>(columnIndex);
                        var current = default(VBuffer<DvText>);
                        while (rows.MoveNext())
                        {
                            readRow(ref current);
                            rowLengths.Add(current.Length);
                        }
                    }

                    Contracts.Check(rowLengths.Count > 0);

                    // The group with the most rows gives the dominant column count.
                    var dominant = rowLengths
                        .GroupBy(length => length)
                        .OrderByDescending(group => group.Count())
                        .First();

                    if (!skipStrictValidation)
                    {
                        // Reject when too few rows share the dominant column count.
                        if (dominant.Count() < UniformColumnCountThreshold * rowLengths.Count)
                        {
                            return false;
                        }

                        // The "single" column case is only allowed when strict validation is skipped
                        // (the user explicitly specified the separator); otherwise the user will see a
                        // message informing that no columns could be detected.
                        if (dominant.Key <= 1)
                        {
                            return false;
                        }
                    }

                    result = new ColumnSplitResult(true, args.Separator, args.AllowQuoting, args.AllowSparse, dominant.Key);
                    ch.Trace("Discovered {0} columns using separator '{1}'", dominant.Key, args.Separator);

                    // Replay the loader's messages on the caller's channel now that the attempt succeeded.
                    foreach (var buffered in bufferedMessages)
                    {
                        ch.Send(buffered);
                    }

                    return true;
                }
            }
            catch (Exception ex)
            {
                if (ex.IsMarked())
                {
                    // For known exceptions, report failure so the caller can continue to the next separator candidate.
                    return false;
                }

                throw;
            }
        }