コード例 #1
0
            /// <summary>
            /// Trains and evaluates one candidate pipeline on a randomly sized slice of the training data,
            /// then records the outcome on the candidate, in <c>_sortedSampledElements</c> and in <c>_history</c>.
            /// </summary>
            /// <param name="utils">Probability helper used to draw the randomized training-set size.</param>
            /// <param name="stopwatch">Stopwatch that is restarted to time the train/test run.</param>
            /// <param name="candidate">Pipeline to run; its <c>PerformanceSummary</c> is replaced.</param>
            /// <param name="numOfTrainingRows">Maximum number of training rows to use.</param>
            /// <exception cref="InvalidOperationException">
            /// The test metric is NaN or infinite and an equal key is already recorded, so no distinct key can be formed.
            /// </exception>
            private void ProcessPipeline(Sweeper.Algorithms.SweeperProbabilityUtils utils, Stopwatch stopwatch, PipelinePattern candidate, int numOfTrainingRows)
            {
                // Create a randomized number of rows to do train/test with.
                // NOTE(review): assumes NormalRVs(count, mean, stdDev), i.e. one draw centred on
                // numOfTrainingRows with a standard deviation of 10% of it — confirm.
                int randomizedNumberOfRows =
                    (int)Math.Floor(utils.NormalRVs(1, numOfTrainingRows, (double)numOfTrainingRows / 10).First());

                // Reflect draws above the maximum back below it.
                if (randomizedNumberOfRows > numOfTrainingRows)
                {
                    randomizedNumberOfRows = numOfTrainingRows - (randomizedNumberOfRows - numOfTrainingRows);
                }

                // A far-tail draw (or the reflection of one) can land at zero or below, which would train on
                // no rows at all; keep the size within [1, numOfTrainingRows] whenever any rows exist.
                randomizedNumberOfRows = Math.Max(Math.Min(1, numOfTrainingRows),
                                                  Math.Min(randomizedNumberOfRows, numOfTrainingRows));

                // Run pipeline, and time how long it takes
                stopwatch.Restart();
                candidate.RunTrainTestExperiment(_trainData.Take(randomizedNumberOfRows),
                                                 _testData, Metric, TrainerKind, out var testMetricVal, out var trainMetricVal);
                stopwatch.Stop();

                // Handle key collisions on the sorted list by moving the metric to a strictly larger key.
                while (_sortedSampledElements.ContainsKey(testMetricVal))
                {
                    testMetricVal = NextKeyAbove(testMetricVal);
                }

                // Save performance score
                candidate.PerformanceSummary =
                    new RunSummary(testMetricVal, randomizedNumberOfRows, stopwatch.ElapsedMilliseconds, trainMetricVal);
                _sortedSampledElements.Add(candidate.PerformanceSummary.MetricValue, candidate);
                _history.Add(candidate);

                // Returns a value strictly greater than 'value' so the collision loop above always makes progress.
                double NextKeyAbove(double value)
                {
                    if (double.IsNaN(value) || double.IsInfinity(value))
                    {
                        // NaN and infinities compare equal to themselves and adding a constant does not
                        // change them, so no distinct key can be derived.
                        throw new InvalidOperationException(
                            $"Cannot record pipeline result: test metric {value} collides with an existing entry and cannot be disambiguated.");
                    }

                    double nudged = value + 1e-10;
                    if (nudged != value)
                    {
                        return nudged;
                    }

                    // For large magnitudes 1e-10 is below double precision; step to the adjacent representable
                    // value instead (away from zero for positives, toward zero for negatives — both increase it).
                    // 'value' is non-zero here because 0 + 1e-10 != 0.
                    long bits = BitConverter.DoubleToInt64Bits(value);
                    return BitConverter.Int64BitsToDouble(value > 0 ? bits + 1 : bits - 1);
                }
            }
コード例 #2
0
        /// <summary>
        /// Auto-detect purpose for the data view columns.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="data">The data to use for inference.</param>
        /// <param name="columnIndices">Indices of columns that we're interested in.</param>
        /// <param name="args">Additional arguments to inference; <c>MaxRowsToRead</c> bounds how many rows are examined.</param>
        /// <returns>The result includes the array of auto-detected column purposes.</returns>
        /// <remarks>
        /// <paramref name="env"/>, <paramref name="data"/>, <paramref name="columnIndices"/> and <paramref name="args"/>
        /// are validated with the host's <c>CheckValue</c>. Inference fails through the channel's <c>Check</c> when any
        /// column is left without a suggested purpose after all experts have run.
        /// </remarks>
        public static InferenceResult InferPurposes(IHostEnvironment env, IDataView data, IEnumerable <int> columnIndices, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("InferPurposes");

            host.CheckValue(data, nameof(data));
            host.CheckValue(columnIndices, nameof(columnIndices));
            // args is dereferenced below (MaxRowsToRead), so validate it like the other inputs.
            host.CheckValue(args, nameof(args));

            InferenceResult result;

            using (var ch = host.Start("InferPurposes"))
            {
                // Only a bounded prefix of the data is examined.
                var takenData = data.Take(args.MaxRowsToRead);
                var cols      = columnIndices.Select(x => new IntermediateColumn(takenData, x)).ToArray();

                // Each expert runs in its own channel and updates the shared column objects in place.
                foreach (var expert in GetExperts())
                {
                    using (var expertChannel = host.Start(expert.GetType().ToString()))
                    {
                        expert.Apply(expertChannel, cols);
                        expertChannel.Done();
                    }
                }

                ch.Check(cols.All(x => x.IsPurposeSuggested), "Purpose inference must be conclusive");

                result = new InferenceResult(cols.Select(x => x.GetColumn()).ToArray());

                ch.Info("Automatic purpose inference complete");
                ch.Done();
            }
            return(result);
        }
コード例 #3
0
        /// <summary>
        /// Runs every purpose expert over the requested columns of a bounded sample of <paramref name="data"/>.
        /// </summary>
        /// <param name="context">Not referenced by this method.</param>
        /// <param name="data">Source data; only its first <c>MaxRowsToRead</c> rows are sampled.</param>
        /// <param name="columnIndices">Indices of the columns to examine.</param>
        /// <returns>The intermediate columns, in <paramref name="columnIndices"/> order, after all experts have run.</returns>
        private static IntermediateColumn[] InferPurposes(MLContext context, IDataView data, IEnumerable <int> columnIndices)
        {
            var sample = data.Take(MaxRowsToRead);
            IntermediateColumn[] columns = columnIndices
                .Select(index => new IntermediateColumn(sample, index))
                .ToArray();

            // Experts annotate the same column objects, so later experts see earlier suggestions.
            foreach (var expert in GetExperts())
            {
                expert.Apply(columns);
            }

            return columns;
        }
コード例 #4
0
        /// <summary>
        /// Auto-detect purpose for the data view columns.
        /// </summary>
        /// <param name="env">The host environment to use.</param>
        /// <param name="data">The data to use for inference.</param>
        /// <param name="columnIndices">Indices of columns that we're interested in.</param>
        /// <param name="args">Additional arguments to inference; <c>MaxRowsToRead</c> bounds how many rows are examined.</param>
        /// <param name="dataRoles">(Optional) User defined Role mappings for data. A role whose name parses to a defined
        /// <see cref="ColumnPurpose"/> value becomes the suggested purpose of the selected column with the same name.
        /// Other roles, and roles on columns not listed in <paramref name="columnIndices"/>, are ignored.</param>
        /// <returns>The result includes the array of auto-detected column purposes.</returns>
        public static InferenceResult InferPurposes(IHostEnvironment env, IDataView data, IEnumerable <int> columnIndices, Arguments args,
                                                    RoleMappedData dataRoles = null)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("InferPurposes");

            host.CheckValue(data, nameof(data));
            host.CheckValue(columnIndices, nameof(columnIndices));
            // args is dereferenced below (MaxRowsToRead), so validate it like the other inputs.
            host.CheckValue(args, nameof(args));

            InferenceResult result;

            using (var ch = host.Start("InferPurposes"))
            {
                // Only a bounded prefix of the data is examined.
                var takenData = data.Take(args.MaxRowsToRead);
                var cols      = columnIndices.Select(x => new IntermediateColumn(takenData, x)).ToList();

                if (dataRoles != null)
                {
                    foreach (var item in dataRoles.Schema.GetColumnRoles())
                    {
                        // Roles that do not name a defined ColumnPurpose are left for the experts to infer,
                        // instead of silently assigning default(ColumnPurpose) to the column.
                        if (!Enum.TryParse(item.Key.Value, out ColumnPurpose purpose) ||
                            !Enum.IsDefined(typeof(ColumnPurpose), purpose))
                        {
                            continue;
                        }

                        // The role may refer to a column that was not selected via columnIndices.
                        var col = cols.Find(x => x.ColumnName == item.Value.Name);
                        if (col == null)
                        {
                            continue;
                        }

                        col.SuggestedPurpose = purpose;
                    }
                }

                // Each expert runs in its own channel and receives a fresh array over the same column objects.
                foreach (var expert in GetExperts())
                {
                    using (var expertChannel = host.Start(expert.GetType().ToString()))
                    {
                        expert.Apply(expertChannel, cols.ToArray());
                        expertChannel.Done();
                    }
                }

                ch.Check(cols.All(x => x.IsPurposeSuggested), "Purpose inference must be conclusive");

                result = new InferenceResult(cols.Select(x => x.GetColumn()).ToArray());

                ch.Info("Automatic purpose inference complete");
                ch.Done();
            }
            return(result);
        }
コード例 #5
0
        /// <summary>
        /// Automatically infer transforms for the data view
        /// </summary>
        /// <remarks>
        /// Only the first <c>MaxRowsToRead</c> rows are sampled, and columns that are hidden in the schema are skipped.
        /// Each expert's <c>IncludeFeaturesOverride</c> is set from the accumulated flag before it runs. The flag is then
        /// OR-ed with the value the expert leaves behind, so once any expert sets it, it stays set for later experts.
        /// </remarks>
        /// <returns>All suggestions from all experts, in expert order.</returns>
        public static SuggestedTransform[] InferTransforms(MLContext env, IDataView data, PurposeInference.Column[] purposes)
        {
            var sample = data.Take(MaxRowsToRead);
            var visibleColumns = purposes
                .Where(purpose => !sample.Schema[purpose.ColumnIndex].IsHidden)
                .Select(purpose => new IntermediateColumn(sample, purpose))
                .ToArray();

            var suggestions      = new List <SuggestedTransform>();
            bool featuresOverride = false;

            foreach (var expert in GetExperts())
            {
                expert.IncludeFeaturesOverride = featuresOverride;

                // Materialize before reading the flag back, so any update made while producing suggestions is seen.
                var produced = expert.Apply(visibleColumns).ToArray();
                featuresOverride = expert.IncludeFeaturesOverride || featuresOverride;

                suggestions.AddRange(produced);
            }

            return suggestions.ToArray();
        }
コード例 #6
0
        /// <summary>
        /// Auto-detect purpose for the data view columns.
        /// </summary>
        /// <remarks>
        /// Only the first <c>MaxRowsToRead</c> rows are sampled. The column named <paramref name="label"/> is fixed as
        /// <see cref="ColumnPurpose.Label"/>, and that rule takes precedence over <paramref name="columnOverrides"/>.
        /// Columns named in <paramref name="columnOverrides"/> keep the given purpose. Every remaining column is
        /// passed to the purpose experts.
        /// </remarks>
        /// <returns>One column descriptor per schema column, in schema order.</returns>
        public static PurposeInference.Column[] InferPurposes(MLContext context, IDataView data, string label,
                                                              IDictionary <string, ColumnPurpose> columnOverrides = null)
        {
            data = data.Take(MaxRowsToRead);

            var everyColumn = new List <IntermediateColumn>();
            var unresolved  = new List <IntermediateColumn>();

            for (var index = 0; index < data.Schema.Count; index++)
            {
                string name = data.Schema[index].Name;

                // Determine whether the caller already fixed this column's purpose.
                ColumnPurpose? knownPurpose = null;
                if (name == label)
                {
                    knownPurpose = ColumnPurpose.Label;
                }
                else if (columnOverrides != null && columnOverrides.TryGetValue(name, out var overridden))
                {
                    knownPurpose = overridden;
                }

                IntermediateColumn current;
                if (knownPurpose.HasValue)
                {
                    current = new IntermediateColumn(data, index, knownPurpose.Value);
                }
                else
                {
                    current = new IntermediateColumn(data, index);
                    unresolved.Add(current);
                }

                everyColumn.Add(current);
            }

            // Each expert gets its own array snapshot of the columns still awaiting inference.
            foreach (var expert in GetExperts())
            {
                expert.Apply(unresolved.ToArray());
            }

            return everyColumn.Select(column => column.GetColumn()).ToArray();
        }
コード例 #7
0
        /// <summary>
        /// Chooses a trainer signature type from the label columns and from the distinct, non-empty label
        /// values found in the first 1000 rows of <paramref name="data"/>.
        /// </summary>
        /// <param name="data">The data view to sample label values from.</param>
        /// <param name="columns">Column purposes previously inferred for <paramref name="data"/>.</param>
        /// <returns>
        /// Clustering when no column is a label, multi-output regression when several are. Otherwise, by distinct
        /// label count: 1 gives anomaly detection and 2 gives binary classification. Above 2, an R4 label with any
        /// value that parses as a float and is greater than 50, negative, or contains '.' gives regression. Otherwise
        /// an R4, TX or key label gives ranking when a Group column exists, and multi-class classification when it
        /// does not. <c>null</c> when no rule applies.
        /// </returns>
        public static Type InferPredictorCategoryType(IDataView data, PurposeInference.Column[] columns)
        {
            List <PurposeInference.Column> labels = columns.Where(col => col.Purpose == ColumnPurpose.Label).ToList();

            if (labels.Count == 0)
            {
                return(typeof(SignatureClusteringTrainer));
            }

            if (labels.Count > 1)
            {
                return(typeof(SignatureMultiOutputRegressorTrainer));
            }

            PurposeInference.Column label             = labels[0];
            HashSet <string>        uniqueLabelValues = new HashSet <string>();

            // Only a bounded sample of rows is inspected.
            data = data.Take(1000);
            using (var cursor = data.GetRowCursor(index => index == label.ColumnIndex))
            {
                ValueGetter <DvText> getter = DataViewUtils.PopulateGetterArray(cursor, new List <int> {
                    label.ColumnIndex
                })[0];
                while (cursor.MoveNext())
                {
                    var currentLabel = new DvText();
                    getter(ref currentLabel);
                    string currentLabelString = currentLabel.ToString();

                    // Empty labels are ignored; HashSet.Add is already a no-op for duplicates.
                    if (!String.IsNullOrEmpty(currentLabelString))
                    {
                        uniqueLabelValues.Add(currentLabelString);
                    }
                }
            }

            if (uniqueLabelValues.Count == 1)
            {
                return(typeof(SignatureAnomalyDetectorTrainer));
            }

            if (uniqueLabelValues.Count == 2)
            {
                return(typeof(SignatureBinaryClassifierTrainer));
            }

            if (uniqueLabelValues.Count > 2)
            {
                // Label values are machine-readable data, so parse them with the invariant culture rather
                // than the current thread culture (which may use ',' as the decimal separator).
                if ((label.ItemKind == DataKind.R4) &&
                    uniqueLabelValues.Any(val =>
                    {
                        float fVal;
                        return float.TryParse(val,
                                              System.Globalization.NumberStyles.Float | System.Globalization.NumberStyles.AllowThousands,
                                              System.Globalization.CultureInfo.InvariantCulture,
                                              out fVal)
                               && (fVal > 50 || fVal < 0 || val.Contains('.'));
                    }))
                {
                    return(typeof(SignatureRegressorTrainer));
                }

                if (label.ItemKind == DataKind.R4 ||
                    label.ItemKind == DataKind.TX ||
                    data.Schema.GetColumnType(label.ColumnIndex).IsKey)
                {
                    if (columns.Any(col => col.Purpose == ColumnPurpose.Group))
                    {
                        return(typeof(SignatureRankerTrainer));
                    }
                    else
                    {
                        return(typeof(SignatureMultiClassClassifierTrainer));
                    }
                }
            }

            return(null);
        }