/// <summary>
/// Infers the topic distribution of a single document and returns the topics
/// sorted by descending weight, with weights normalized to sum to 1.
/// Returns an empty array for an empty document.
/// </summary>
internal LDATopic[] Predict(LDAModel model, ReadOnlySpan<int> tokenIndices, int tokenCount, ReadOnlySpan<double> frequency, bool reset)
{
    // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe, so run it exactly
    // once via double-checked locking before the first prediction. It builds
    // tables on the native (C++) side.
    if (!_predictionPreparationDone)
    {
        lock (_preparationSyncRoot)
        {
            if (!_predictionPreparationDone)
            {
                _ldaTrainer.InitializeBeforeTest();
                _predictionPreparationDone = true;
            }
        }
    }

    if (tokenCount == 0)
    {
        return Array.Empty<LDATopic>();
    }

    var topicScores = _ldaTrainer.TestDoc(tokenIndices, frequency, tokenCount, model.NumberOfBurninIterations, reset);

    // Normalize the raw scores into a probability distribution and emit the
    // topics from most to least probable.
    var scale = 1f / topicScores.Sum(pair => pair.Value);
    return topicScores
        .OrderByDescending(pair => pair.Value)
        .Select(pair => new LDATopic(pair.Key, pair.Value * scale))
        .ToArray();
}
/// <summary>
/// Loads a previously trained model into the native trainer: allocates native
/// memory sized from the model, pushes each term's (topic, probability) row,
/// then runs the one-time native test preparation.
/// </summary>
internal void InitializePretrained(LDAModel model)
{
    _ldaTrainer.AllocateModelMemory(model.VocabularyBuckets, model.NumberOfTopics, model.MemBlockSize, model.AliasMemBlockSize);
    Debug.Assert(model.VocabularyBuckets == model.LDA_Data.Length);
    for (int termID = 0; termID < model.VocabularyBuckets; termID++)
    {
        var kvs = model.LDA_Data[termID];
        var topicId = kvs.Select(kv => kv.Key).ToArray();
        var topicProb = kvs.Select(kv => kv.Value).ToArray();
        var termTopicNum = topicId.Length;
        _ldaTrainer.SetModel(termID, topicId, topicProb, termTopicNum);
    }
    // Do the preparation for building tables in native C++.
    // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe, so use the same
    // double-checked locking pattern as Predict(): without the inner re-check,
    // two threads passing the outer check would both call the native
    // initialization sequentially under the lock.
    if (!_predictionPreparationDone)
    {
        lock (_preparationSyncRoot)
        {
            if (!_predictionPreparationDone)
            {
                _ldaTrainer.InitializeBeforeTest();
                _predictionPreparationDone = true;
            }
        }
    }
}
/// <summary>
/// Feeds one document to the native trainer. Documents shorter than the
/// configured minimum token count are skipped.
/// </summary>
/// <returns>
/// 0 when the document is skipped; otherwise the result of the native
/// LoadDoc call.
/// </returns>
internal int FeedTrain(LDAModel model, ReadOnlySpan<int> tokenIndices, int tokenCount, ReadOnlySpan<double> frequency)
{
    return tokenCount < model.MinimumTokenCountPerDocument
        ? 0
        : _ldaTrainer.LoadDoc(tokenIndices, frequency, tokenCount, model.VocabularyBuckets);
}
/// <summary>
/// Copies the trained state out of the native trainer into the managed model:
/// the native memory-block sizes and the per-term topic data for every
/// vocabulary bucket.
/// </summary>
internal void ReadModelFromTrainedLDA(LDAModel model)
{
    _ldaTrainer.GetModelStat(out var memBlockSize, out var aliasMemBlockSize);
    model.MemBlockSize = memBlockSize;
    model.AliasMemBlockSize = aliasMemBlockSize;

    int vocabCount = _ldaTrainer.NumVocab;
    Debug.Assert(vocabCount == model.VocabularyBuckets);

    // One row per vocabulary bucket, pulled straight from the native side.
    model.LDA_Data = Enumerable.Range(0, vocabCount)
        .Select(_ldaTrainer.GetModel)
        .ToArray();
}
/// <summary>
/// Creates an LdaState wrapping a native LdaSingleBox configured from the
/// given model's hyperparameters and the requested degree of parallelism.
/// </summary>
internal LdaState(LDAModel model, int numberOfThreads)
    : this()
{
    // Pull the configuration out of the model first (same read order as the
    // argument list), then hand it to the native trainer in one shot.
    var topicCount = model.NumberOfTopics;
    var vocabSize = model.VocabularyBuckets;
    var alphaSum = model.AlphaSum;
    var beta = model.Beta;
    var iterationLimit = model.MaximumNumberOfIterations;
    var likelihoodInterval = model.LikelihoodInterval;
    var samplingSteps = model.SamplingStepCount;
    var summaryTermCount = model.NumberOfSummaryTermsPerTopic;
    var maxTokensPerDoc = model.MaximumTokenCountPerDocument;

    _ldaTrainer = new LdaSingleBox(
        numTopic: topicCount,
        numVocab: vocabSize,
        alpha: alphaSum,
        beta: beta,
        numIter: iterationLimit,
        likelihoodInterval: likelihoodInterval,
        numThread: numberOfThreads,
        mhstep: samplingSteps,
        numSummaryTerms: summaryTermCount,
        denseOutput: false,
        maxDocToken: maxTokensPerDoc);
}