Пример #1
0
        /// <summary>
        /// Attempts to load the feature names stored in the "FeatureNames.bin" entry of the
        /// training-info directory of the given repository.
        /// REVIEW: consider adding an overload that returns <see cref="VBuffer{DvText}"/>.
        /// </summary>
        /// <param name="featureNames">
        /// Receives the loaded collection when the entry exists; set to <c>null</c> otherwise.
        /// </param>
        /// <param name="rep">The repository to read from. Checked with <c>Contracts.CheckValue</c>.</param>
        /// <returns>
        /// <c>true</c> if the entry was found and the collection was created from it;
        /// <c>false</c> if the repository has no such entry.
        /// </returns>
        public static bool TryLoadFeatureNames(out FeatureNameCollection featureNames, RepositoryReader rep)
        {
            Contracts.CheckValue(rep, nameof(rep));

            using (var entry = rep.OpenEntryOrNull(ModelFileUtils.DirTrainingInfo, "FeatureNames.bin"))
            {
                // No entry in the repository: report failure with a null result.
                if (entry == null)
                {
                    featureNames = null;
                    return false;
                }

                using (var loadContext = new ModelLoadContext(rep, entry, ModelFileUtils.DirTrainingInfo))
                {
                    featureNames = FeatureNameCollection.Create(loadContext);
                    return true;
                }
            }
        }
Пример #2
0
        /// <summary>
        /// Builds a text -> word-bag -> LDA pipeline over the sentiment training file, trains a
        /// linear classifier and a FastTree binary classifier on the resulting features, and then
        /// reads back details of the trained components: the linear feature weights, the LDA
        /// topic summary, and the key/value details of every node and leaf of each tree.
        /// The values that are read are not asserted on; the method only exercises the accessors.
        /// </summary>
        public void IntrospectiveTraining()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: loader -> unigram word bag -> LDA topics.
                var loaderArgs = MakeSentimentTextLoaderArgs();
                var source = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename));
                var loader = TextLoader.ReadFile(env, loaderArgs, source);

                var wordBagArgs = new WordBagTransform.Arguments()
                {
                    NgramLength = 1,
                    Column = new[]
                    {
                        new WordBagTransform.Column() { Name = "Tokenize", Source = new[] { "SentimentText" } }
                    }
                };
                var words = WordBagTransform.Create(env, wordBagArgs, loader);

                var ldaArgs = new LdaTransform.Arguments()
                {
                    NumTopic = 10,
                    NumIterations = 3,
                    NumThreads = 1,
                    Column = new[]
                    {
                        new LdaTransform.Column { Source = "Tokenize", Name = "Features" }
                    }
                };
                var lda = new LdaTransform(env, ldaArgs, words);

                // The LDA output is the training data; cache it so both trainers read the same view.
                var cachedTrain = new CacheDataView(env, lda, prefetch: null);

                // Train the first predictor (linear) and pull out its feature weights.
                var linearArgs = new LinearClassificationTrainer.Arguments { NumThreads = 1 };
                var linearTrainer = new LinearClassificationTrainer(env, linearArgs);
                var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
                var linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles));
                VBuffer<float> weights = default;
                linearPredictor.GetFeatureWeights(ref weights);

                var topicSummary = lda.GetTopicSummary();

                // Train the second predictor (FastTree with two trees).
                var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 2);
                var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles));

                // The trainer may hand back a calibrated wrapper; unwrap it to reach the tree ensemble.
                var treePredictor = ftPredictor is CalibratedPredictorBase calibrated
                    ? (FastTreeBinaryPredictor)calibrated.SubPredictor
                    : (FastTreeBinaryPredictor)ftPredictor;

                var featureNames = FeatureNameCollection.Create(trainRoles.Schema);
                foreach (var tree in treePredictor.GetTrees())
                {
                    var lteChild = tree.LteChild;
                    var gtChild = tree.GtChild;

                    // Get nodes.
                    for (var nodeOrdinal = 0; nodeOrdinal < tree.NumNodes; nodeOrdinal++)
                    {
                        var node = tree.GetNode(nodeOrdinal, false, featureNames);
                        var gainValue = GetValue<double>(node.KeyValues, "GainValue");
                        var splitGain = GetValue<double>(node.KeyValues, "SplitGain");
                        var featureName = GetValue<string>(node.KeyValues, "SplitName");
                        var previousLeafValue = GetValue<double>(node.KeyValues, "PreviousLeafValue");

                        // Keep only the part of the "Threshold" text after its first space.
                        var thresholdText = GetValue<string>(node.KeyValues, "Threshold");
                        var threshold = thresholdText.Split(new[] { ' ' }, 2)[1];
                        var nodeIndex = nodeOrdinal;
                    }

                    // Get leaves.
                    for (var leafOrdinal = 0; leafOrdinal < tree.NumLeaves; leafOrdinal++)
                    {
                        var leaf = tree.GetNode(leafOrdinal, true, featureNames);
                        var leafValue = GetValue<double>(leaf.KeyValues, "LeafValue");
                        var extras = GetValue<string>(leaf.KeyValues, "Extras");

                        // Leaf index is computed as the bitwise complement of the leaf ordinal.
                        var nodeIndex = ~leafOrdinal;
                    }
                }
            }
        }