/// <summary>
/// Verifies that constructing an <c>LdaTransform</c> over a column whose documents are
/// all empty (all-zero vectors) throws <see cref="InvalidOperationException"/> with the
/// expected message.
/// </summary>
public void TestLdaTransformEmptyDocumentException()
{
    var builder = new ArrayDataViewBuilder(Env);
    var data = new[]
    {
        new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
        new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
        new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
    };
    builder.AddColumn("Zeros", NumberType.Float, data);

    var srcView = builder.GetDataView();
    var col = new LdaTransform.Column() { Source = "Zeros" };
    var args = new LdaTransform.Arguments() { Column = new[] { col } };

    // Assert.ThrowsAny replaces the original try/catch + Assert.True(false) pattern:
    // it fails with a precise message when nothing is thrown, and (unlike Assert.Throws)
    // matches derived exception types just as the original catch clause did.
    var ex = Assert.ThrowsAny<InvalidOperationException>(() => new LdaTransform(Env, args, srcView));

    // xUnit convention: expected value first, actual second (the original had them
    // swapped, which inverts the failure message).
    Assert.Equal(string.Format("The specified documents are all empty in column '{0}'.", col.Source), ex.Message);
}
/// <summary>
/// Trains a 3-topic LDA over three one-hot rows (with a fixed random generator for
/// determinism) and checks that each output row is assigned entirely to a single topic.
/// </summary>
public void TestLDATransform()
{
    var builder = new ArrayDataViewBuilder(Env);
    var data = new[]
    {
        new[] { (Float)1.0, (Float)0.0, (Float)0.0 },
        new[] { (Float)0.0, (Float)1.0, (Float)0.0 },
        new[] { (Float)0.0, (Float)0.0, (Float)1.0 },
    };
    builder.AddColumn("F1V", NumberType.Float, data);
    var srcView = builder.GetDataView();

    // Object initializer replaces the field-by-field setup. The original assigned
    // NumTopic = 20 and then immediately overwrote it with 3; the dead first
    // assignment is removed.
    var col = new LdaTransform.Column()
    {
        Source = "F1V",
        NumTopic = 3,
        NumSummaryTermPerTopic = 3,
        AlphaSum = 3,
        NumThreads = 1,
        // Fixed RNG so the exact-value assertions below are stable across runs.
        ResetRandomGenerator = true,
    };
    var args = new LdaTransform.Arguments() { Column = new[] { col } };

    var ldaTransform = new LdaTransform(Env, args, srcView);
    using (var cursor = ldaTransform.GetRowCursor(c => true))
    {
        var resultGetter = cursor.GetGetter<VBuffer<Float>>(1);
        VBuffer<Float> resultFirstRow = new VBuffer<Float>();
        VBuffer<Float> resultSecondRow = new VBuffer<Float>();
        VBuffer<Float> resultThirdRow = new VBuffer<Float>();

        // Exactly three rows come back, in input order.
        Assert.True(cursor.MoveNext());
        resultGetter(ref resultFirstRow);
        Assert.True(cursor.MoveNext());
        resultGetter(ref resultSecondRow);
        Assert.True(cursor.MoveNext());
        resultGetter(ref resultThirdRow);
        Assert.False(cursor.MoveNext());

        // Each row's topic distribution should put all mass on one topic.
        Assert.True(resultFirstRow.Length == 3);
        Assert.True(resultFirstRow.GetItemOrDefault(0) == 0);
        Assert.True(resultFirstRow.GetItemOrDefault(2) == 0);
        Assert.True(resultFirstRow.GetItemOrDefault(1) == 1.0);
        Assert.True(resultSecondRow.Length == 3);
        Assert.True(resultSecondRow.GetItemOrDefault(0) == 0);
        Assert.True(resultSecondRow.GetItemOrDefault(2) == 0);
        Assert.True(resultSecondRow.GetItemOrDefault(1) == 1.0);
        Assert.True(resultThirdRow.Length == 3);
        Assert.True(resultThirdRow.GetItemOrDefault(0) == 0);
        Assert.True(resultThirdRow.GetItemOrDefault(1) == 0);
        Assert.True(resultThirdRow.GetItemOrDefault(2) == 1.0);
    }
}
/// <summary>
/// Entry point that builds a LightLDA transform over the input data and returns it
/// wrapped in a <see cref="CommonOutputs.TransformOutput"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Transform arguments, including the input data view.</param>
public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(input, nameof(input));

    var host = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
    var xf = new LdaTransform(host, input, input.Data);
    var model = new TransformModel(host, xf, input.Data);

    return new CommonOutputs.TransformOutput()
    {
        Model = model,
        OutputData = xf,
    };
}
/// <summary>
/// End-to-end introspection test: builds a text pipeline (loader -> word bag -> LDA),
/// trains a linear predictor and a small tree ensemble on it, and then walks the
/// trained trees' nodes and leaves to exercise the introspection APIs
/// (feature weights, topic summary, per-node key/value metadata).
/// </summary>
public void IntrospectiveTraining()
{
    // Fixed seed and single concurrency for reproducible training.
    using (var env = new LocalEnvironment(seed: 1, conc: 1))
    {
        // Pipeline: load the sentiment dataset, tokenize "SentimentText" into a
        // unigram word bag, then reduce it to 10 LDA topics exposed as "Features".
        var loader = TextLoader.ReadFile(env, MakeSentimentTextLoaderArgs(), new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)));
        var words = WordBagTransform.Create(env, new WordBagTransform.Arguments() { NgramLength = 1, Column = new[] { new WordBagTransform.Column() { Name = "Tokenize", Source = new[] { "SentimentText" } } } }, loader);
        var lda = new LdaTransform(env, new LdaTransform.Arguments() { NumTopic = 10, NumIterations = 3, NumThreads = 1, Column = new[] { new LdaTransform.Column { Source = "Tokenize", Name = "Features" } } }, words);
        var trainData = lda;
        // Cache so both trainers below can re-scan the data without recomputing the pipeline.
        var cachedTrain = new CacheDataView(env, trainData, prefetch: null);

        // Train the first predictor.
        var linearTrainer = new LinearClassificationTrainer(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 });
        var trainRoles = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");
        var linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles));
        // Introspection: pull the learned weights and the LDA topic summary.
        // NOTE(review): results are only read, not asserted — this test appears to
        // check that the introspection calls succeed rather than specific values.
        VBuffer<float> weights = default;
        linearPredictor.GetFeatureWeights(ref weights);
        var topicSummary = lda.GetTopicSummary();

        // Train a small (2-tree) ensemble on the same role-mapped data.
        var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 2);
        var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles));
        FastTreeBinaryPredictor treePredictor;
        // The trainer may wrap the tree predictor in a calibrator; unwrap if so.
        if (ftPredictor is CalibratedPredictorBase calibrator)
        {
            treePredictor = (FastTreeBinaryPredictor)calibrator.SubPredictor;
        }
        else
        {
            treePredictor = (FastTreeBinaryPredictor)ftPredictor;
        }
        var featureNameCollection = FeatureNameCollection.Create(trainRoles.Schema);
        foreach (var tree in treePredictor.GetTrees())
        {
            var lteChild = tree.LteChild;
            var gteChild = tree.GtChild;
            // Get nodes.
            for (var i = 0; i < tree.NumNodes; i++)
            {
                // Second argument false: presumably selects internal nodes (vs. leaves
                // below) — TODO confirm against GetNode's signature.
                var node = tree.GetNode(i, false, featureNameCollection);
                var gainValue = GetValue<double>(node.KeyValues, "GainValue");
                var splitGain = GetValue<double>(node.KeyValues, "SplitGain");
                var featureName = GetValue<string>(node.KeyValues, "SplitName");
                var previousLeafValue = GetValue<double>(node.KeyValues, "PreviousLeafValue");
                // "Threshold" is formatted as "<op> <value>"; take the value part.
                var threshold = GetValue<string>(node.KeyValues, "Threshold").Split(new[] { ' ' }, 2)[1];
                var nodeIndex = i;
            }
            // Get leaves.
            for (var i = 0; i < tree.NumLeaves; i++)
            {
                var node = tree.GetNode(i, true, featureNameCollection);
                var leafValue = GetValue<double>(node.KeyValues, "LeafValue");
                var extras = GetValue<string>(node.KeyValues, "Extras");
                // Leaves use the bitwise-complement index convention (~i) to
                // distinguish them from internal node indices.
                var nodeIndex = ~i;
            }
        }
    }
}