Пример #1
0
        /// <summary>
        /// Builds the factory for the scorer that a schema bound mapper should use by default.
        /// The text-valued ScoreColumnKind metadata on the mapper's first output column is read; when that
        /// text names a loadable scorer class assignable to <see cref="IDataScorerTransform"/>, the factory
        /// instantiates that class. In every other case the factory instantiates a <see cref="GenericScorer"/>.
        /// </summary>
        /// <param name="mapper">The schema bound mapper to get the default scorer.</param>
        /// <param name="suffix">An optional suffix to append to the default column names.</param>
        /// <returns>The component factory created from the scorer-building delegate.</returns>
        public static TScorerFactory GetScorerComponent(
            ISchemaBoundMapper mapper,
            string suffix = null)
        {
            Contracts.AssertValue(mapper);

            ComponentCatalog.LoadableClassInfo scorerInfo = null;
            DvText kindText = default;

            bool hasScoreKind =
                mapper.OutputSchema.ColumnCount > 0 &&
                mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref kindText) &&
                kindText.HasChars;
            if (hasScoreKind)
            {
                // Accept the catalog entry only when it exists and really is a data scorer transform.
                var candidate = ComponentCatalog.GetLoadableClassInfo<SignatureDataScorer>(kindText.ToString());
                if (candidate != null && typeof(IDataScorerTransform).IsAssignableFrom(candidate.Type))
                {
                    scorerInfo = candidate;
                }
            }

            Func<IHostEnvironment, IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform> createScorer;
            if (scorerInfo != null)
            {
                createScorer = (env, data, innerMapper, trainSchema) =>
                {
                    object scorerArgs = scorerInfo.CreateArguments();
                    // The suffix can only be forwarded when the arguments derive from the common base.
                    if (scorerArgs is ScorerArgumentsBase withSuffix)
                    {
                        withSuffix.Suffix = suffix;
                    }
                    var extraCtorArgs = new object[] { data, innerMapper, trainSchema };
                    return (IDataScorerTransform)scorerInfo.CreateInstance(env, scorerArgs, extraCtorArgs);
                };
            }
            else
            {
                // No usable scorer was named by the metadata: fall back to the generic scorer.
                createScorer = (env, data, innerMapper, trainSchema) =>
                {
                    var genericArgs = new GenericScorer.Arguments() { Suffix = suffix };
                    return new GenericScorer(env, genericArgs, data, innerMapper, trainSchema);
                };
            }

            return ComponentFactoryUtils.CreateFromFunction(createScorer);
        }
            /// <summary>
            /// Yields a single recipe whose only learner is the multi-class Naive Bayes trainer, with empty
            /// settings and a pipeline node built from <c>Legacy.Trainers.NaiveBayesClassifier</c>.
            /// </summary>
            /// <param name="predictorType">The requested predictor type; not consulted by this recipe.</param>
            /// <param name="transforms">The transforms to attach to the suggested recipe.</param>
            /// <returns>A sequence containing exactly one suggested recipe.</returns>
            protected override IEnumerable<SuggestedRecipe> ApplyCore(Type predictorType,
                                                                      TransformInference.SuggestedTransform[] transforms)
            {
                // Initializer members are assigned in the order written, matching the original sequence.
                var naiveBayes = new SuggestedRecipe.SuggestedLearner
                {
                    LoadableClassInfo =
                        ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(Learners.MultiClassNaiveBayesTrainer.LoadName),
                    Settings = "",
                    PipelineNode = new TrainerPipelineNode(new Legacy.Trainers.NaiveBayesClassifier())
                };

                yield return new SuggestedRecipe(ToString(), transforms, new[] { naiveBayes });
            }
            /// <summary>
            /// Yields a single recipe with one SDCA learner. When <paramref name="predictorType"/> is
            /// <see cref="SignatureMultiClassClassifierTrainer"/>, the multi-class SDCA trainer is looked up and
            /// no pipeline node is set; otherwise the linear classification trainer is looked up and a pipeline
            /// node wrapping <c>Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier</c> is attached.
            /// Settings are empty in both cases.
            /// </summary>
            /// <param name="predictorType">The requested predictor signature type.</param>
            /// <param name="transforms">The transforms to attach to the suggested recipe.</param>
            /// <returns>A sequence containing exactly one suggested recipe.</returns>
            protected override IEnumerable<SuggestedRecipe> ApplyCore(Type predictorType,
                                                                      TransformInference.SuggestedTransform[] transforms)
            {
                var sdca = new SuggestedRecipe.SuggestedLearner();
                bool wantsMultiClass = predictorType == typeof(SignatureMultiClassClassifierTrainer);

                if (!wantsMultiClass)
                {
                    sdca.LoadableClassInfo =
                        ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(Learners.LinearClassificationTrainer.LoadNameValue);
                    sdca.PipelineNode = new TrainerPipelineNode(
                        new Legacy.Trainers.StochasticDualCoordinateAscentBinaryClassifier());
                }
                else
                {
                    sdca.LoadableClassInfo =
                        ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(Learners.SdcaMultiClassTrainer.LoadNameValue);
                }

                sdca.Settings = "";
                yield return new SuggestedRecipe(ToString(), transforms, new[] { sdca });
            }
Пример #4
0
            /// <summary>
            /// Yields a single recipe with one FastTree-based learner. When <paramref name="predictorType"/> is
            /// <see cref="SignatureMultiClassClassifierTrainer"/>, the "OVA" trainer is looked up with settings
            /// "p=FastTreeBinaryClassification" and no pipeline node; otherwise the FastTree binary
            /// classification trainer is looked up with empty settings and a pipeline node wrapping
            /// <c>Trainers.FastTreeBinaryClassifier</c>.
            /// </summary>
            /// <param name="predictorType">The requested predictor signature type.</param>
            /// <param name="transforms">The transforms to attach to the suggested recipe.</param>
            /// <returns>A sequence containing exactly one suggested recipe.</returns>
            protected override IEnumerable<SuggestedRecipe> ApplyCore(Type predictorType,
                                                                      TransformInference.SuggestedTransform[] transforms)
            {
                var fastTree = new SuggestedRecipe.SuggestedLearner();
                bool wantsMultiClass = predictorType == typeof(SignatureMultiClassClassifierTrainer);

                if (!wantsMultiClass)
                {
                    fastTree.LoadableClassInfo =
                        ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(FastTreeBinaryClassificationTrainer.LoadNameValue);
                    fastTree.Settings = "";
                    fastTree.PipelineNode = new TrainerPipelineNode(new Trainers.FastTreeBinaryClassifier());
                }
                else
                {
                    // Multi-class is handled by wrapping the binary FastTree learner in one-versus-all.
                    fastTree.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>("OVA");
                    fastTree.Settings = "p=FastTreeBinaryClassification";
                }

                yield return new SuggestedRecipe(ToString(), transforms, new[] { fastTree });
            }
Пример #5
0
        /// <summary>
        /// Validates a trainer sub-component and a data file argument, and returns the loadable class info
        /// registered for the trainer's kind under signature <typeparamref name="TSig"/>.
        /// The checks run in this order: the trainer must be specified, its kind must be known to the
        /// component catalog, the resolved class must be assignable to <see cref="ITrainer"/>, and
        /// <paramref name="dataFile"/> must not be null, empty or whitespace.
        /// </summary>
        /// <typeparam name="TSig">The signature type used for the catalog lookup.</typeparam>
        /// <param name="ectx">The exception context used to raise errors; checked with CheckValueOrNull.</param>
        /// <param name="trainer">The trainer sub-component to validate.</param>
        /// <param name="dataFile">The data file argument that must be defined.</param>
        /// <returns>The loadable class info of the trainer.</returns>
        public static ComponentCatalog.LoadableClassInfo CheckTrainer<TSig>(IExceptionContext ectx, SubComponent<ITrainer, TSig> trainer, string dataFile)
        {
            Contracts.CheckValueOrNull(ectx);
            ectx.CheckUserArg(trainer.IsGood(), nameof(TrainCommand.Arguments.Trainer), "A trainer is required.");

            ComponentCatalog.LoadableClassInfo trainerInfo = ComponentCatalog.GetLoadableClassInfo<TSig>(trainer.Kind);
            if (trainerInfo == null)
            {
                throw ectx.ExceptUserArg(nameof(TrainCommand.Arguments.Trainer), "Unknown trainer: '{0}'", trainer.Kind);
            }

            bool implementsTrainer = typeof(ITrainer).IsAssignableFrom(trainerInfo.Type);
            if (!implementsTrainer)
            {
                throw ectx.Except("Loadable class '{0}' does not implement 'ITrainer'", trainerInfo.LoadNames[0]);
            }

            bool dataFileMissing = string.IsNullOrWhiteSpace(dataFile);
            if (dataFileMissing)
            {
                throw ectx.ExceptUserArg(nameof(TrainCommand.Arguments.DataFile), "Data file must be defined.");
            }

            return trainerInfo;
        }
            /// <summary>
            /// Yields a single recipe with one Averaged Perceptron learner limited to 10 iterations. When
            /// <paramref name="predictorType"/> is <see cref="SignatureMultiClassClassifierTrainer"/>, the "OVA"
            /// trainer is looked up with settings "p=AveragedPerceptron{iter=10}" and no pipeline node;
            /// otherwise the Averaged Perceptron trainer is looked up with settings "iter=10" and a pipeline
            /// node wrapping <c>Legacy.Trainers.AveragedPerceptronBinaryClassifier</c> with NumIterations = 10.
            /// The recipe is constructed with <see cref="Int32.MaxValue"/> as its last argument.
            /// </summary>
            /// <param name="predictorType">The requested predictor signature type.</param>
            /// <param name="transforms">The transforms to attach to the suggested recipe.</param>
            /// <returns>A sequence containing exactly one suggested recipe.</returns>
            protected override IEnumerable<SuggestedRecipe> ApplyCore(Type predictorType,
                                                                      TransformInference.SuggestedTransform[] transforms)
            {
                var perceptron = new SuggestedRecipe.SuggestedLearner();
                bool wantsMultiClass = predictorType == typeof(SignatureMultiClassClassifierTrainer);

                if (!wantsMultiClass)
                {
                    perceptron.LoadableClassInfo =
                        ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>(Learners.AveragedPerceptronTrainer.LoadNameValue);
                    perceptron.Settings = "iter=10";

                    var binaryInput = new Legacy.Trainers.AveragedPerceptronBinaryClassifier();
                    binaryInput.NumIterations = 10;
                    perceptron.PipelineNode = new TrainerPipelineNode(binaryInput);
                }
                else
                {
                    perceptron.LoadableClassInfo = ComponentCatalog.GetLoadableClassInfo<SignatureTrainer>("OVA");
                    perceptron.Settings = "p=AveragedPerceptron{iter=10}";
                }

                var recipe = new SuggestedRecipe(ToString(), transforms, new[] { perceptron }, Int32.MaxValue);
                yield return recipe;
            }
Пример #7
0
        /// <summary>
        /// Picks the default scorer sub-component for a schema bound mapper. The text-valued ScoreColumnKind
        /// metadata on the mapper's first output column is read; when that text names a loadable scorer
        /// class assignable to <see cref="IDataScorerTransform"/>, that name is used. Otherwise the load name
        /// of <see cref="GenericScorer"/> is used.
        /// </summary>
        /// <param name="mapper">The schema bound mapper to get the default scorer for.</param>
        /// <returns>A sub-component carrying the selected scorer load name.</returns>
        public static SubComponent<IDataScorerTransform, SignatureDataScorer> GetScorerComponent(ISchemaBoundMapper mapper)
        {
            Contracts.AssertValue(mapper);

            string chosenName = null;
            DvText kindText = default;

            bool hasScoreKind =
                mapper.OutputSchema.ColumnCount > 0 &&
                mapper.OutputSchema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreColumnKind, 0, ref kindText) &&
                kindText.HasChars;
            if (hasScoreKind)
            {
                string candidateName = kindText.ToString();
                var candidateInfo = ComponentCatalog.GetLoadableClassInfo<SignatureDataScorer>(candidateName);

                // Keep the metadata-provided name only when it resolves to a real data scorer transform.
                if (candidateInfo != null && typeof(IDataScorerTransform).IsAssignableFrom(candidateInfo.Type))
                {
                    chosenName = candidateName;
                }
            }

            return new SubComponent<IDataScorerTransform, SignatureDataScorer>(chosenName ?? GenericScorer.LoadName);
        }
Пример #8
0
        /// <summary>
        /// Creates the command from its arguments. The constructor registers a host named
        /// "GenerateCandidates", then validates and stores, in this order: the data file, the response-file
        /// output folder (created if missing), the optional schema definition file, the sweeper (default
        /// "kdo"), the run mode (default cross validation), the indent flag, the optional test file, and
        /// the output data folder (falling back to the response-file output folder).
        /// </summary>
        /// <param name="env">The host environment; must not be null.</param>
        /// <param name="args">The command arguments; must not be null.</param>
        public GenerateSweepCandidatesCommand(IHostEnvironment env, Arguments args)
        {
            Contracts.CheckValue(env, nameof(env));
            _host = env.Register("GenerateCandidates");
            _host.CheckValue(args, nameof(args));

            var dataFiles = new MultiFileSource(args.DataFile);
            _host.CheckUserArg(dataFiles.Count > 0, nameof(args.DataFile), "dataFile is required");
            _dataFile = args.DataFile;

            _rspsOutFolder = Utils.CreateFolderIfNotExists(args.RspOutFolder);
            _host.CheckUserArg(_rspsOutFolder != null, nameof(args.RspOutFolder), "Provide a value rspOutFolder (or 'out', the short name).");

            bool hasSchemaDefinition = !string.IsNullOrWhiteSpace(args.SchemaDefinitionFile);
            if (hasSchemaDefinition)
            {
                Utils.CheckOptionalUserDirectory(args.SchemaDefinitionFile, nameof(args.SchemaDefinitionFile));
                _schemaDefinitionFile = args.SchemaDefinitionFile;
            }

            if (string.IsNullOrWhiteSpace(args.Sweeper))
            {
                _sweeper = "kdo";
            }
            else
            {
                var sweeperInfo = ComponentCatalog.GetLoadableClassInfo<SignatureSweeper>(args.Sweeper);
                // A missing catalog entry yields a null type here, which fails the comparison below.
                bool isSweeper = sweeperInfo?.SignatureTypes[0] == typeof(SignatureSweeper);
                _host.CheckUserArg(isSweeper, nameof(args.Sweeper),
                                   "Please specify a valid sweeper.");
                _sweeper = args.Sweeper;
            }

            if (string.IsNullOrWhiteSpace(args.Mode))
            {
                _mode = CrossValidationCommand.LoadName;
            }
            else
            {
                var modeInfo = ComponentCatalog.GetLoadableClassInfo<SignatureCommand>(args.Mode);
                Type modeType = modeInfo?.Type;
                bool isSupportedMode =
                    modeType == typeof(TrainCommand) ||
                    modeType == typeof(TrainTestCommand) ||
                    modeType == typeof(CrossValidationCommand);
                _host.CheckUserArg(isSupportedMode, nameof(args.Mode), "Invalid mode.");
                _mode = args.Mode;
            }

            _indented = args.Indent;

            if (string.IsNullOrWhiteSpace(args.TestFile))
            {
                // Without a test file, the train-test mode cannot be used.
                _host.CheckUserArg(_mode != TrainTestCommand.LoadName, nameof(args.TestFile), "testFile needs to be a valid file, for mode = TrainTest.");
            }
            else
            {
                var testFiles = new MultiFileSource(args.TestFile);
                _host.CheckUserArg(testFiles.Count > 0, nameof(args.TestFile), "testFile needs to be a valid file, if provided.");
                _testFile = args.TestFile;
            }

            _outputDataFolder = Utils.CreateFolderIfNotExists(args.OutputDataFolder) ?? _rspsOutFolder;
        }
Пример #9
0
        /// <summary>
        /// Computes one scale value per column in <paramref name="infos"/>: the median of pairwise distances
        /// between values sampled from that column of <paramref name="trainingData"/>. Squared L2 distance is
        /// used when the column's matrix generator resolves to <see cref="GaussianFourierSampler"/>, and L1
        /// distance otherwise. A value of 1 is used when at most one instance was sampled or when the median
        /// is 0. Despite the name of the result array, the statistic computed is a median, not a mean.
        /// </summary>
        /// <param name="host">The host used for assertions and as the default source of randomness.</param>
        /// <param name="infos">The column infos to process; asserted to be non-empty.</param>
        /// <param name="args">The transform arguments; supplies the optional seed and the matrix generators.</param>
        /// <param name="trainingData">The data view from which the column values are sampled.</param>
        /// <returns>An array with one entry per element of <paramref name="infos"/>.</returns>
        private static Float[] Train(IHost host, ColInfo[] infos, Arguments args, IDataView trainingData)
        {
            Contracts.AssertValue(host, "host");
            host.AssertNonEmpty(infos);

            var       avgDistances  = new Float[infos.Length];
            const int reservoirSize = 5000;

            // Mark the source columns of the requested infos so the cursor only activates those.
            bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
            for (int i = 0; i < infos.Length; i++)
            {
                activeColumns[infos[i].Source] = true;
            }

            // One reservoir sampler per column, all drawing from the same cursor.
            var reservoirSamplers = new ReservoirSamplerWithReplacement <VBuffer <Float> > [infos.Length];

            using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
            {
                // Use a dedicated seeded generator when a seed was supplied; otherwise use the host's generator.
                var rng = args.Seed.HasValue ? RandomUtils.Create(args.Seed) : host.Rand;
                for (int i = 0; i < infos.Length; i++)
                {
                    if (infos[i].TypeSrc.IsVector)
                    {
                        var get = cursor.GetGetter <VBuffer <Float> >(infos[i].Source);
                        reservoirSamplers[i] = new ReservoirSamplerWithReplacement <VBuffer <Float> >(rng, reservoirSize, get);
                    }
                    else
                    {
                        // Scalar column: wrap each value in a length-1 vector so both kinds of column
                        // can share the same sampler element type.
                        var   getOne = cursor.GetGetter <Float>(infos[i].Source);
                        Float val    = 0;
                        ValueGetter <VBuffer <Float> > get =
                            (ref VBuffer <Float> dst) =>
                        {
                            getOne(ref val);
                            dst = new VBuffer <float>(1, new[] { val });
                        };
                        reservoirSamplers[i] = new ReservoirSamplerWithReplacement <VBuffer <Float> >(rng, reservoirSize, get);
                    }
                }

                // Offer every row to every sampler.
                while (cursor.MoveNext())
                {
                    for (int i = 0; i < infos.Length; i++)
                    {
                        reservoirSamplers[i].Sample();
                    }
                }
                // NOTE(review): Lock() presumably finalizes each sampler so its sample can be read — confirm
                // against ReservoirSamplerWithReplacement.
                for (int i = 0; i < infos.Length; i++)
                {
                    reservoirSamplers[i].Lock();
                }
            }

            for (int iinfo = 0; iinfo < infos.Length; iinfo++)
            {
                var instanceCount = reservoirSamplers[iinfo].NumSampled;

                // If the number of pairs is at most the maximum reservoir size / 2, we go over all the pairs,
                // so we get all the examples. Otherwise, get a sample with replacement.
                VBuffer <Float>[] res;
                int resLength;
                if (instanceCount < reservoirSize && instanceCount * (instanceCount - 1) <= reservoirSize)
                {
                    res       = reservoirSamplers[iinfo].GetCache();
                    resLength = reservoirSamplers[iinfo].Size;
                    Contracts.Assert(resLength == instanceCount);
                }
                else
                {
                    res       = reservoirSamplers[iinfo].GetSample().ToArray();
                    resLength = res.Length;
                }

                // If the dataset contains only one valid Instance, then we can't learn anything anyway, so just return 1.
                if (instanceCount <= 1)
                {
                    avgDistances[iinfo] = 1;
                }
                else
                {
                    Float[] distances;

                    // Prefer the column-specific matrix generator; fall back to the transform-wide one.
                    var sub = args.Column[iinfo].MatrixGenerator;
                    if (!sub.IsGood())
                    {
                        sub = args.MatrixGenerator;
                    }
                    var  info     = ComponentCatalog.GetLoadableClassInfo(sub);
                    bool gaussian = info != null && info.Type == typeof(GaussianFourierSampler);

                    // If the number of pairs is at most the maximum reservoir size / 2, go over all the pairs.
                    if (resLength < reservoirSize)
                    {
                        // NOTE(review): this branch indexes res by instanceCount, which relies on resLength
                        // being below reservoirSize only when the cache branch above was taken — confirm
                        // that GetSample() always yields reservoirSize items.
                        distances = new Float[instanceCount * (instanceCount - 1) / 2];
                        int count = 0;
                        for (int i = 0; i < instanceCount; i++)
                        {
                            for (int j = i + 1; j < instanceCount; j++)
                            {
                                distances[count++] = gaussian ? VectorUtils.L2DistSquared(ref res[i], ref res[j])
                                    : VectorUtils.L1Distance(ref res[i], ref res[j]);
                            }
                        }
                        host.Assert(count == distances.Length);
                    }
                    else
                    {
                        // Pair up consecutive sampled items: (0,1), (2,3), ...
                        distances = new Float[reservoirSize / 2];
                        for (int i = 0; i < reservoirSize - 1; i += 2)
                        {
                            // For Gaussian kernels, we scale by the L2 distance squared, since the kernel function is exp(-gamma ||x-y||^2).
                            // For Laplacian kernels, we scale by the L1 distance, since the kernel function is exp(-gamma ||x-y||_1).
                            distances[i / 2] = gaussian ? VectorUtils.L2DistSquared(ref res[i], ref res[i + 1]) :
                                               VectorUtils.L1Distance(ref res[i], ref res[i + 1]);
                        }
                    }

                    // If by chance, in the random permutation all the pairs are the same instance we return 1.
                    Float median = MathUtils.GetMedianInPlace(distances, distances.Length);
                    avgDistances[iinfo] = median == 0 ? 1 : median;
                }
            }
            return(avgDistances);
        }