Example #1
0
        /// <summary>
        /// Checks that constructing an <see cref="LdaTransform"/> over a column in which
        /// every document is an all-zero vector fails with an
        /// <see cref="InvalidOperationException"/> carrying the expected message.
        /// </summary>
        public void TestLdaTransformEmptyDocumentException()
        {
            // Arrange: three documents, each with three zero-valued slots.
            var builder = new ArrayDataViewBuilder(Env);
            var data    = new[]
            {
                new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
                new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
                new[] { (Float)0.0, (Float)0.0, (Float)0.0 },
            };

            builder.AddColumn("Zeros", NumberType.Float, data);

            var srcView = builder.GetDataView();
            var col     = new LdaTransform.Column()
            {
                Source = "Zeros"
            };
            var args = new LdaTransform.Arguments()
            {
                Column = new[] { col }
            };

            // Act: ThrowsAny accepts InvalidOperationException or any subclass, which is
            // the same set of exceptions the previous catch clause accepted. A statement
            // lambda is used so the call binds to the Action overload unambiguously.
            var ex = Assert.ThrowsAny<InvalidOperationException>(() => { new LdaTransform(Env, args, srcView); });

            // Assert: expected value first, actual second, so failure output is labelled correctly.
            Assert.Equal(string.Format("The specified documents are all empty in column '{0}'.", col.Source), ex.Message);
        }
Example #2
0
        /// <summary>
        /// Runs <see cref="LdaTransform"/> with three topics over three one-hot documents
        /// and checks the topic vector produced for each row.
        /// </summary>
        public void TestLDATransform()
        {
            // Arrange: a 3x3 identity-like input, one non-zero slot per document.
            var builder = new ArrayDataViewBuilder(Env);
            var data    = new[]
            {
                new[] { (Float)1.0, (Float)0.0, (Float)0.0 },
                new[] { (Float)0.0, (Float)1.0, (Float)0.0 },
                new[] { (Float)0.0, (Float)0.0, (Float)1.0 },
            };

            builder.AddColumn("F1V", NumberType.Float, data);

            var srcView = builder.GetDataView();

            // NumTopic is set once; an earlier, immediately-overwritten assignment of 20 was dropped.
            LdaTransform.Column col = new LdaTransform.Column();
            col.Source   = "F1V";
            col.NumTopic = 3;
            col.NumSummaryTermPerTopic = 3;
            col.AlphaSum             = 3;
            col.NumThreads           = 1;
            col.ResetRandomGenerator = true;
            LdaTransform.Arguments args = new LdaTransform.Arguments();
            args.Column = new LdaTransform.Column[] { col };

            // Act
            LdaTransform ldaTransform = new LdaTransform(Env, args, srcView);

            // Expected topic vector per input row, in row order.
            var expectedRows = new[]
            {
                new[] { (Float)0.0, (Float)1.0, (Float)0.0 },
                new[] { (Float)0.0, (Float)1.0, (Float)0.0 },
                new[] { (Float)0.0, (Float)0.0, (Float)1.0 },
            };

            using (var cursor = ldaTransform.GetRowCursor(c => true))
            {
                // Column index 1 is the one the getter is requested for.
                var resultGetter = cursor.GetGetter <VBuffer <Float> >(1);
                var actualRows   = new VBuffer <Float>[expectedRows.Length];

                // Pull exactly one value per expected row, then confirm the cursor is exhausted.
                for (int row = 0; row < actualRows.Length; row++)
                {
                    Assert.True(cursor.MoveNext());
                    resultGetter(ref actualRows[row]);
                }
                Assert.False(cursor.MoveNext());

                // Assert: Assert.Equal reports the offending value on failure,
                // unlike Assert.True(a == b).
                for (int row = 0; row < expectedRows.Length; row++)
                {
                    Assert.Equal(expectedRows[row].Length, actualRows[row].Length);
                    for (int slot = 0; slot < expectedRows[row].Length; slot++)
                    {
                        Assert.Equal(expectedRows[row][slot], actualRows[row].GetItemOrDefault(slot));
                    }
                }
            }
        }
Example #3
0
        /// <summary>
        /// Entry point that builds an <see cref="LdaTransform"/> over <c>input.Data</c> and
        /// returns it both as the output data view and wrapped in a <see cref="TransformModel"/>.
        /// </summary>
        /// <param name="env">Host environment; validated before use.</param>
        /// <param name="input">LDA arguments, including the input data view.</param>
        /// <returns>The transform output holding the model and the transformed data.</returns>
        public static CommonOutputs.TransformOutput LightLda(IHostEnvironment env, LdaTransform.Arguments input)
        {
            // Validate the environment first, since it is then used to validate the arguments.
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(input, nameof(input));

            var host         = EntryPointUtils.CheckArgsAndCreateHost(env, "LightLda", input);
            var ldaTransform = new LdaTransform(host, input, input.Data);

            var output = new CommonOutputs.TransformOutput()
            {
                Model = new TransformModel(host, ldaTransform, input.Data),
                OutputData = ldaTransform
            };
            return output;
        }
Example #4
0
        // Scenario: train a linear classifier and a two-tree FastTree classifier on LDA
        // features, then read back what each stage learned (feature weights, topic
        // summary, per-node and per-leaf tree values). Most locals are intentionally
        // unused; they only show which values can be retrieved.
        public void IntrospectiveTraining()
        {
            using (var env = new LocalEnvironment(seed: 1, conc: 1))
            {
                // Pipeline: text loader -> unigram word bag -> LDA topics.
                var loaderArgs = MakeSentimentTextLoaderArgs();
                var trainFiles = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename));
                var loader     = TextLoader.ReadFile(env, loaderArgs, trainFiles);

                var bagColumn = new WordBagTransform.Column()
                {
                    Name = "Tokenize", Source = new[] { "SentimentText" }
                };
                var bagArgs = new WordBagTransform.Arguments()
                {
                    NgramLength = 1,
                    Column      = new[] { bagColumn }
                };
                var words = WordBagTransform.Create(env, bagArgs, loader);

                var ldaColumn = new LdaTransform.Column
                {
                    Source = "Tokenize", Name = "Features"
                };
                var ldaArgs = new LdaTransform.Arguments()
                {
                    NumTopic      = 10,
                    NumIterations = 3,
                    NumThreads    = 1,
                    Column        = new[] { ldaColumn }
                };
                var lda = new LdaTransform(env, ldaArgs, words);

                // Both trainers share the same cached view of the LDA output.
                var cachedTrain = new CacheDataView(env, lda, prefetch: null);
                var trainRoles  = new RoleMappedData(cachedTrain, label: "Label", feature: "Features");

                // First predictor: linear classifier, inspected through its feature weights.
                var linearArgs = new LinearClassificationTrainer.Arguments
                {
                    NumThreads = 1
                };
                var             linearTrainer   = new LinearClassificationTrainer(env, linearArgs);
                var             linearPredictor = linearTrainer.Train(new Runtime.TrainContext(trainRoles));
                VBuffer <float> weights         = default;
                linearPredictor.GetFeatureWeights(ref weights);

                var topicSummary = lda.GetTopicSummary();

                // Second predictor: FastTree, inspected tree by tree.
                var treeTrainer = new FastTreeBinaryClassificationTrainer(env, DefaultColumnNames.Label, DefaultColumnNames.Features, numTrees: 2);
                var ftPredictor = treeTrainer.Train(new Runtime.TrainContext(trainRoles));

                // If the trained predictor is calibrated, unwrap it to reach the tree ensemble.
                FastTreeBinaryPredictor treePredictor = ftPredictor is CalibratedPredictorBase calibrated
                    ? (FastTreeBinaryPredictor)calibrated.SubPredictor
                    : (FastTreeBinaryPredictor)ftPredictor;

                var featureNames = FeatureNameCollection.Create(trainRoles.Schema);
                foreach (var tree in treePredictor.GetTrees())
                {
                    var lteChild = tree.LteChild;
                    var gtChild  = tree.GtChild;

                    // Nodes: requested with the second GetNode argument set to false.
                    for (int nodeIndex = 0; nodeIndex < tree.NumNodes; nodeIndex++)
                    {
                        var interior          = tree.GetNode(nodeIndex, false, featureNames);
                        var gainValue         = GetValue <double>(interior.KeyValues, "GainValue");
                        var splitGain         = GetValue <double>(interior.KeyValues, "SplitGain");
                        var featureName       = GetValue <string>(interior.KeyValues, "SplitName");
                        var previousLeafValue = GetValue <double>(interior.KeyValues, "PreviousLeafValue");
                        // Keep only the part after the first space of the "Threshold" text.
                        var thresholdText     = GetValue <string>(interior.KeyValues, "Threshold");
                        var threshold         = thresholdText.Split(new[] { ' ' }, 2)[1];
                    }

                    // Leaves: requested with the second GetNode argument set to true.
                    for (int leafIndex = 0; leafIndex < tree.NumLeaves; leafIndex++)
                    {
                        var leaf      = tree.GetNode(leafIndex, true, featureNames);
                        var leafValue = GetValue <double>(leaf.KeyValues, "LeafValue");
                        var extras    = GetValue <string>(leaf.KeyValues, "Extras");
                        // The leaf's node index is the bitwise complement of its position.
                        var leafNodeIndex = ~leafIndex;
                    }
                }
            }
        }