예제 #1
0
        /// <summary>
        /// Create method corresponding to SignatureDataTransform.
        /// Validates the arguments, lets <c>TrainCore</c> fill a per-slot score buffer for the
        /// feature column, and builds a single drop-slots column description from those scores.
        /// When no column description is produced, the input is returned through
        /// <c>NopTransform.CreateIfNeeded</c>; otherwise a <see cref="DropSlotsTransform"/> is applied.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments; validated by <c>args.Check</c>.</param>
        /// <param name="input">The data view to transform; must not be null.</param>
        /// <returns>Either a no-op transform over <paramref name="input"/> or a drop-slots transform.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            args.Check(host);

            // TrainCore writes the per-slot scores into this buffer (passed by ref).
            var slotScores = default(VBuffer<float>);
            TrainCore(host, input, args, ref slotScores);

            using (var channel = host.Start("Dropping Slots"))
            {
                var dropColumn = CreateDropSlotsColumn(args, ref slotScores, out int keptSlots);

                // Nothing to drop: avoid wrapping the input in a real transform.
                if (dropColumn == null)
                {
                    channel.Info("No features are being dropped.");
                    return NopTransform.CreateIfNeeded(host, input);
                }

                channel.Info(MessageSensitivity.Schema, "Selected {0} slots out of {1} in column '{2}'", keptSlots, slotScores.Length, args.FeatureColumn);

                var dropArgs = new DropSlotsTransform.Arguments { Column = new[] { dropColumn } };
                return new DropSlotsTransform(host, dropArgs, input);
            }
        }
예제 #2
0
        /// <summary>
        /// Create method corresponding to SignatureDataTransform.
        /// Computes mutual-information scores for every slot of the requested columns against
        /// <c>args.LabelColumn</c>, derives a score threshold from <c>args.SlotsInOutput</c>, and
        /// drops the slots that <c>CreateDropSlotsColumns</c> marks for removal.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments. Column must be non-empty, SlotsInOutput positive,
        /// LabelColumn non-whitespace and NumBins greater than 1.</param>
        /// <param name="input">The data view to transform; must not be null.</param>
        /// <returns>A no-op transform when nothing is dropped, otherwise a <see cref="DropSlotsTransform"/>.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));
            host.CheckUserArg(args.SlotsInOutput > 0, nameof(args.SlotsInOutput));
            host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn));
            host.Check(args.NumBins > 1, "numBins must be greater than 1.");

            using (var ch = host.Start("Selecting Slots"))
            {
                ch.Info("Computing mutual information");
                var sw = Stopwatch.StartNew();

                // De-duplicate the requested columns while preserving the order the user gave.
                // HashSet<T> enumeration order is unspecified, so the set is used only for
                // membership. The order of colArr determines the per-column order seen by
                // TrainCore, CreateDropSlotsColumns and the log output below.
                var seen = new HashSet<string>();
                var uniqueCols = new List<string>();
                foreach (var col in args.Column)
                {
                    if (seen.Add(col))
                    {
                        uniqueCols.Add(col);
                    }
                    else
                    {
                        ch.Warning("Column '{0}' specified multiple times.", col);
                    }
                }
                var colArr = uniqueCols.ToArray();

                // TrainCore fills colSizes with the slot count of each column in colArr.
                var colSizes = new int[colArr.Length];
                var scores = MutualInformationFeatureSelectionUtils.TrainCore(host, input, args.LabelColumn, colArr,
                                                                              args.NumBins, colSizes);
                sw.Stop();
                ch.Info("Finished mutual information computation in {0}", sw.Elapsed);

                ch.Info("Selecting features to drop");
                var threshold = ComputeThreshold(scores, args.SlotsInOutput, out int tiedScoresToKeep);

                var columns = CreateDropSlotsColumns(colArr, colArr.Length, scores, threshold, tiedScoresToKeep, out int[] selectedCount);

                if (columns.Count == 0)
                {
                    ch.Info("No features are being dropped.");
                    return NopTransform.CreateIfNeeded(host, input);
                }

                for (int i = 0; i < selectedCount.Length; i++)
                {
                    ch.Info("Selected {0} slots out of {1} in column '{2}'", selectedCount[i], colSizes[i], colArr[i]);
                }
                ch.Info("Total number of slots selected: {0}", selectedCount.Sum());

                var dsArgs = new DropSlotsTransform.Arguments();
                dsArgs.Column = columns.ToArray();
                var ds = new DropSlotsTransform(host, dsArgs, input);
                ch.Done();
                return ds;
            }
        }
예제 #3
0
        /// <summary>
        /// Trims the per-instance metrics view so the sorted-cluster columns keep only their first
        /// <c>_numTopClusters</c> slots. The two columns are
        /// <c>ClusteringPerInstanceEvaluator.SortedClusters</c> and
        /// <c>ClusteringPerInstanceEvaluator.SortedClusterScores</c>. Each one that is present and has
        /// more than <c>_numTopClusters</c> slots is handled by wrapping the view in its own
        /// <see cref="DropSlotsTransform"/>.
        /// </summary>
        /// <param name="perInst">The per-instance metrics view.</param>
        /// <param name="schema">Not used by this implementation.</param>
        /// <returns>The input view, possibly wrapped by one or two drop-slots transforms.</returns>
        protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
        {
            // The columns are processed in this fixed order. Each lookup runs against the view
            // produced by the previous iteration, exactly as the two sequential checks did before.
            string[] sortedColumns =
            {
                ClusteringPerInstanceEvaluator.SortedClusters,
                ClusteringPerInstanceEvaluator.SortedClusterScores,
            };

            foreach (string colName in sortedColumns)
            {
                if (!perInst.Schema.TryGetColumnIndex(colName, out int index))
                {
                    continue;
                }

                var type = perInst.Schema.GetColumnType(index);
                if (_numTopClusters >= type.VectorSize)
                {
                    continue;
                }

                // Wrap with a DropSlots transform to pick only the first _numTopClusters slots.
                // The range sets only Min. This assumes an unset Max means "through the last slot";
                // confirm against DropSlotsTransform.Range.
                var args = new DropSlotsTransform.Arguments
                {
                    Column = new DropSlotsTransform.Column[]
                    {
                        new DropSlotsTransform.Column()
                        {
                            Name  = colName,
                            Slots = new[] {
                                new DropSlotsTransform.Range()
                                {
                                    Min = _numTopClusters
                                }
                            }
                        }
                    }
                };
                perInst = new DropSlotsTransform(Host, args, perInst);
            }

            return perInst;
        }
        /// <summary>
        /// Create method corresponding to SignatureDataTransform.
        /// Scores the slots of every column in <c>args.Column</c> with
        /// <c>CountFeatureSelectionUtils.Train</c>, then asks <c>CreateDropSlotsColumns</c> which slots
        /// to drop. When no column has anything to drop, the input is returned through
        /// <c>NopTransform.CreateIfNeeded</c>; otherwise a <see cref="DropSlotsTransform"/> is applied.
        /// </summary>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="args">Transform arguments. Column must be non-empty and Count positive.</param>
        /// <param name="input">The data view to transform; must not be null.</param>
        /// <returns>A no-op transform when nothing is dropped, otherwise a drop-slots transform.</returns>
        public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(args, nameof(args));
            host.CheckValue(input, nameof(input));
            host.CheckUserArg(Utils.Size(args.Column) > 0, nameof(args.Column));
            host.CheckUserArg(args.Count > 0, nameof(args.Count));

            // Scoring runs before the channel is opened. colSizes receives the slot count
            // of each column and is reported in the per-column log lines below.
            var slotScores = CountFeatureSelectionUtils.Train(host, input, args.Column, out int[] colSizes);
            int columnCount = args.Column.Length;

            using (var ch = host.Start("Dropping Slots"))
            {
                var dropColumns = CreateDropSlotsColumns(args, columnCount, slotScores, out int[] keptPerColumn);

                // Nothing to drop in any column: avoid wrapping the input in a real transform.
                if (dropColumns.Count <= 0)
                {
                    ch.Info("No features are being dropped.");
                    return NopTransform.CreateIfNeeded(host, input);
                }

                for (int col = 0; col < keptPerColumn.Length; col++)
                {
                    ch.Info(MessageSensitivity.Schema, "Selected {0} slots out of {1} in column '{2}'", keptPerColumn[col], colSizes[col], args.Column[col]);
                }
                ch.Info("Total number of slots selected: {0}", keptPerColumn.Sum());

                var dropArgs = new DropSlotsTransform.Arguments { Column = dropColumns.ToArray() };
                ch.Done();
                return new DropSlotsTransform(host, dropArgs, input);
            }
        }