/// <summary>
/// Optionally loads the feature names stored in the "FeatureNames.bin" entry of the
/// <see cref="ModelFileUtils.DirTrainingInfo"/> directory of a model repository.
/// </summary>
/// <param name="featureNames">
/// Receives the loaded feature names when the entry exists; set to <c>null</c> when it does not.
/// </param>
/// <param name="rep">The repository to read from. Must not be null.</param>
/// <returns>
/// <c>true</c> if the feature-names entry was found and loaded; <c>false</c> if no such entry exists,
/// in which case <paramref name="featureNames"/> is <c>null</c>.
/// </returns>
/// <remarks>
/// REVIEW: consider adding an overload that returns <see cref="VBuffer{DvText}"/>.
/// </remarks>
public static bool TryLoadFeatureNames(out FeatureNameCollection featureNames, RepositoryReader rep)
{
    Contracts.CheckValue(rep, nameof(rep));

    using (var entry = rep.OpenEntryOrNull(ModelFileUtils.DirTrainingInfo, "FeatureNames.bin"))
    {
        // Feature names are optional: a missing entry is reported through the return value, not an exception.
        if (entry == null)
        {
            featureNames = null;
            return false;
        }

        using (var loadContext = new ModelLoadContext(rep, entry, ModelFileUtils.DirTrainingInfo))
        {
            featureNames = FeatureNameCollection.Create(loadContext);
            return true;
        }
    }
}
/// <summary>
/// API scenario: after training, the learned artifacts should be inspectable. The scenario reads
/// the linear model's feature weights, the LDA transform's topic summary, and every split node
/// and leaf of a FastTree ensemble.
/// It only exercises the inspection APIs; the values it reads are not asserted on.
/// </summary>
public void IntrospectiveTraining()
{
    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline: sentiment text loader -> bag of words over "SentimentText" -> LDA topics as "Features".
        var loaderArgs = MakeSentimentTextLoaderArgs();
        var trainFile = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename));
        var loader = TextLoader.ReadFile(env, loaderArgs, trainFile);

        var bagArgs = new WordBagTransform.Arguments()
        {
            NgramLength = 1,
            Column = new[] { new WordBagTransform.Column() { Name = "Tokenize", Source = new[] { "SentimentText" } } }
        };
        var words = WordBagTransform.Create(env, bagArgs, loader);

        var ldaArgs = new LdaTransform.Arguments()
        {
            NumTopic = 10,
            NumIterations = 3,
            NumThreads = 1,
            Column = new[] { new LdaTransform.Column { Source = "Tokenize", Name = "Features" } }
        };
        var lda = new LdaTransform(env, ldaArgs, words);

        var cachedTrain = new CacheDataView(env, lda, prefetch: null);

        // Linear model: its coefficients are read back as feature weights.
        var linearTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 });
        var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
        var linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles));
        VBuffer<float> weights = default;
        linearPredictor.GetFeatureWeights(ref weights);

        // LDA transform: the learned topics are exposed through a summary.
        var topicSummary = lda.GetTopicSummary();

        // Tree ensemble: the trained predictor may be wrapped in a calibrator, so unwrap it to reach the trees.
        var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 2);
        var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles));
        var treePredictor = ftPredictor is CalibratedPredictorBase calibrated
            ? (FastTreeBinaryPredictor)calibrated.SubPredictor
            : (FastTreeBinaryPredictor)ftPredictor;

        var featureNameCollection = FeatureNameCollection.Create(trainRoles.Schema);
        foreach (var tree in treePredictor.GetTrees())
        {
            var lteChild = tree.LteChild;
            var gteChild = tree.GtChild;

            // Split (internal) nodes: GetNode(..., false, ...) describes node `split`.
            for (var split = 0; split < tree.NumNodes; split++)
            {
                var splitInfo = tree.GetNode(split, false, featureNameCollection).KeyValues;
                var gainValue = GetValue<double>(splitInfo, "GainValue");
                var splitGain = GetValue<double>(splitInfo, "SplitGain");
                var featureName = GetValue<string>(splitInfo, "SplitName");
                var previousLeafValue = GetValue<double>(splitInfo, "PreviousLeafValue");
                // Keep only the text after the first space of the threshold string
                // (presumably a comparison operator precedes the value — TODO confirm format).
                var threshold = GetValue<string>(splitInfo, "Threshold").Split(new[] { ' ' }, 2)[1];
                var nodeIndex = split;
            }

            // Leaves: GetNode(..., true, ...) describes leaf `leaf`; its node index is encoded as ~leaf.
            for (var leaf = 0; leaf < tree.NumLeaves; leaf++)
            {
                var leafInfo = tree.GetNode(leaf, true, featureNameCollection).KeyValues;
                var leafValue = GetValue<double>(leafInfo, "LeafValue");
                var extras = GetValue<string>(leafInfo, "Extras");
                var nodeIndex = ~leaf;
            }
        }
    }
}