/// <summary>
/// Fits the model to <paramref name="input"/> with an expectation-maximization (EM) loop.
/// </summary>
/// <param name="input">Data view from which the corpus is extracted.</param>
/// <remarks>
/// Each pass runs <c>Expectation</c> over the corpus and computes the relative change
/// <c>(previous - current) / previous</c> of the returned likelihood. A negative change
/// causes the pass to be rejected: the variational budget stored in <c>_parameters</c> is
/// loosened (this mutation persists after the call) and the next pass reuses the stored
/// per-document variational parameters. Otherwise the model is maximized and the
/// iteration counter advances.
/// </remarks>
internal void Fit(IDataView input)
{
    var corpus = ExtractCorpus(input);
    var llnaModel = LLNAModel.Create(_nbTopics, corpus, _parameters);
    _model = llnaModel;

    var statistics = new LLNASufficientStatistics(llnaModel);

    // Per-document state carried from one E-step to the next (one row per document).
    // Assumes NbDocuments and K are plain, side-effect-free getters.
    var documentCount = corpus.NbDocuments;
    var topicCount = llnaModel.K;
    var lambdaByDocument = Matrix<double>.Build.Dense(documentCount, topicCount, 0);
    var nuByDocument = Matrix<double>.Build.Dense(documentCount, topicCount, 0);
    // Filled by Expectation; not read anywhere in this method.
    var phiSumByDocument = Matrix<double>.Build.Dense(documentCount, topicCount, 0);

    double relativeChange = 1;
    double previousLikelihood = 0;
    bool restartVariational = true;
    bool likelihoodWorsened;
    int completedIterations = 0;

    // Expectation-maximization (EM) algorithm.
    do
    {
        var pass = Expectation(
            corpus,
            llnaModel,
            statistics,
            lambdaByDocument,
            nuByDocument,
            phiSumByDocument,
            restartVariational);

        // On the first pass previousLikelihood is 0, so this is a floating-point
        // division by zero: the result is infinite or NaN, never an exception.
        relativeChange = (previousLikelihood - pass.LHood) / previousLikelihood;

        // Kept as "< 0" (not inverted to ">= 0") so that a NaN takes the accept branch.
        likelihoodWorsened = relativeChange < 0;

        if (likelihoodWorsened)
        {
            // Reject the pass: keep the current variational parameters and give the
            // inner inference more room (more iterations, or a tighter tolerance).
            restartVariational = false;

            if (_parameters.VarMaxIter > 0)
            {
                _parameters.VarMaxIter += 10;
            }
            else
            {
                _parameters.VarConvergence /= 10;
            }
        }
        else
        {
            llnaModel.Maximize(statistics);
            previousLikelihood = pass.LHood;
            restartVariational = true;
            completedIterations++;
        }

        statistics.Reset();

        // NOTE: rejected passes do not advance completedIterations, so EmMaxIter
        // only bounds the number of accepted passes.
    }
    while (completedIterations < _parameters.EmMaxIter
           && (likelihoodWorsened || relativeChange > _parameters.EmConvergence));
}
/// <summary>
/// Runs variational inference on every document of <paramref name="corpus"/> and
/// accumulates the results into <paramref name="llnaSufficientStatistics"/>.
/// </summary>
/// <param name="corpus">Documents to process.</param>
/// <param name="llnaModel">Current model; supplies the topic count <c>K</c>.</param>
/// <param name="llnaSufficientStatistics">Updated once per document after inference.</param>
/// <param name="corpusLambda">Row i is read on a warm start and overwritten with document i's lambda.</param>
/// <param name="corpusNu">Row i is read on a warm start and overwritten with document i's nu.</param>
/// <param name="corpusPhiSum">Row i is overwritten with the column sum of document i's phi.</param>
/// <param name="resetVar">
/// When <c>true</c>, each document's variational parameters are initialised from the model;
/// when <c>false</c>, they are restored from <paramref name="corpusLambda"/> and
/// <paramref name="corpusNu"/> before inference.
/// </param>
/// <returns>
/// An <c>ExpectationResult</c> built from the mean <c>NIter</c> per document, the fraction
/// of documents flagged <c>Converged</c>, and the sum of the values returned by <c>Inference</c>.
/// </returns>
private ExpectationResult Expectation(
    Corpus corpus,
    LLNAModel llnaModel,
    LLNASufficientStatistics llnaSufficientStatistics,
    Matrix<double> corpusLambda,
    Matrix<double> corpusNu,
    Matrix<double> corpusPhiSum,
    bool resetVar)
{
    double likelihoodSum = 0;
    double iterationSum = 0;
    double convergedCount = 0;

    // Assumes NbDocuments is a plain getter whose value is stable during the pass.
    var documentCount = corpus.NbDocuments;

    for (int docIndex = 0; docIndex < documentCount; docIndex++)
    {
        Debug.WriteLine($"Document {docIndex}");

        var document = corpus.GetDocument(docIndex);
        var variational = new VariationalInferenceParameter(document.NbTerms, llnaModel.K);

        if (!resetVar)
        {
            // Warm start: restore the stored lambda/nu, re-derive zeta and phi from
            // them, and restart the iteration counter.
            variational.Lambda = corpusLambda.Row(docIndex);
            variational.Nu = corpusNu.Row(docIndex);
            variational.OptimizeZeta(llnaModel);
            variational.OptimizePhi(llnaModel, document);
            variational.NIter = 0;
        }
        else
        {
            variational.Init(llnaModel);
        }

        var documentLikelihood = Inference(variational, document, llnaModel);
        llnaSufficientStatistics.Update(variational, document);

        likelihoodSum += documentLikelihood;
        iterationSum += variational.NIter;
        if (variational.Converged)
        {
            convergedCount += 1;
        }

        // Persist this document's parameters for a possible warm start next pass.
        corpusLambda.SetRow(docIndex, variational.Lambda);
        corpusNu.SetRow(docIndex, variational.Nu);
        corpusPhiSum.SetRow(docIndex, variational.Phi.ColumnSum());
    }

    // The accumulators are doubles, so an empty corpus yields NaN here, not an exception.
    double meanIterations = iterationSum / documentCount;
    double convergedFraction = convergedCount / documentCount;

    return new ExpectationResult(meanIterations, convergedFraction, likelihoodSum);
}