/// <summary>
/// Trains the LDA model on <paramref name="documents"/> and publishes the result on this
/// instance: <c>Data.Vocabulary</c>, <c>Data.StopWords</c> and <c>State</c> are assigned
/// only after training has completed.
/// </summary>
/// <param name="documents">
/// Corpus to train on. It is passed to <c>InitializeVocabulary</c> and then iterated again
/// here to feed the trainer, so it must support being enumerated more than once.
/// </param>
/// <param name="threads">Forwarded to the <c>LdaState</c> constructor.</param>
/// <param name="stopwords">
/// Words to exclude. When <c>null</c>, <c>StopWords.Snowball.For(Language)</c> is used instead.
/// </param>
/// <exception cref="ArgumentNullException"><paramref name="documents"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">
/// <c>Data.NumberOfTopics</c> is not greater than 1, or the corpus turned out to be empty.
/// </exception>
public void Train(IEnumerable<IDocument> documents, int threads, IEnumerable<string> stopwords = null)
{
    if (documents is null)
    {
        throw new ArgumentNullException(nameof(documents));
    }

    // Stop words are stored as hashes of their characters.
    var stopWords = new HashSet<uint>((stopwords ?? StopWords.Snowball.For(Language)).Select(s => Hash(s.AsSpan())));

    if (Data.NumberOfTopics <= 1)
    {
        throw new InvalidOperationException($"Invalid number of topics ({nameof(Data)}.{nameof(Data.NumberOfTopics)}), must be > 1");
    }

    var state = new LdaState(Data, threads);

    // First pass over the corpus: obtain the sizes needed to allocate the training buffers.
    var (count, corpusSize) = InitializeVocabulary(documents, stopWords);

    if (count == 0 || corpusSize == 0)
    {
        throw new InvalidOperationException("Empty corpus, nothing to train LDA model");
    }

    var vocabulary = new ConcurrentDictionary<int, string>();

    state.AllocateDataMemory(count, corpusSize);

    // Second pass over the corpus: feed every sufficiently long document to the trainer.
    foreach (var doc in documents)
    {
        GetTokensAndFrequencies(doc, vocabulary, stopWords, out var tokenCount, out var tokenIndices, out var tokenFrequencies);

        try
        {
            if (tokenCount >= Data.MinimumTokenCountPerDocument)
            {
                state.FeedTrain(Data, tokenIndices, tokenCount, tokenFrequencies);
            }
        }
        finally
        {
            // Hand the buffers back to the shared pools even when FeedTrain throws.
            ArrayPool<int>.Shared.Return(tokenIndices);
            ArrayPool<double>.Shared.Return(tokenFrequencies);
        }
    }

    state.CompleteTrain();
    state.ReadModelFromTrainedLDA(Data);

    Data.Vocabulary = vocabulary;
    Data.StopWords = stopWords;
    State = state;
}
/// <summary>
/// Trains one LDA state per entry of <c>Infos</c> from <paramref name="trainingData"/> and
/// stores each trained state in the matching slot of <paramref name="states"/>.
/// </summary>
/// <remarks>
/// The data is swept twice. The first sweep counts, per column, the number of usable
/// documents, the corpus size and the vocabulary size, so that memory can be allocated
/// before any data is fed. The second sweep feeds the rows to the allocated states.
/// </remarks>
/// <param name="ch">Channel used for assertions, warnings and exceptions.</param>
/// <param name="trainingData">Data view providing the source columns referenced by <c>Infos</c>.</param>
/// <param name="states">Output array; must have the same length as <c>Infos</c>.</param>
private void Train(IChannel ch, IDataView trainingData, LdaState[] states)
{
    Host.AssertValue(ch);
    ch.AssertValue(trainingData);
    ch.AssertValue(states);
    ch.Assert(states.Length == Infos.Length);

    bool[] activeColumns = new bool[trainingData.Schema.ColumnCount];
    int[] numVocabs = new int[Infos.Length];

    for (int i = 0; i < Infos.Length; i++)
    {
        activeColumns[Infos[i].Source] = true;
        numVocabs[i] = 0;
    }

    // The current LDA implementation needs its memory allocated before data is fed in, so the
    // data is swept twice: once to pre-compute the memory requirements and once to actually
    // feed the data. An alternative would be to prepare these two values externally and put
    // them at the beginning of the input file.
    long[] corpusSize = new long[Infos.Length];
    int[] numDocArray = new int[Infos.Length];

    using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
    {
        var getters = new ValueGetter<VBuffer<Double>>[Utils.Size(Infos)];
        for (int i = 0; i < Infos.Length; i++)
        {
            corpusSize[i] = 0;
            numDocArray[i] = 0;
            getters[i] = RowCursorUtils.GetVecGetterAs<Double>(NumberType.R8, cursor, Infos[i].Source);
        }

        VBuffer<Double> src = default(VBuffer<Double>);
        long rowCount = 0;

        while (cursor.MoveNext())
        {
            ++rowCount;
            for (int i = 0; i < Infos.Length; i++)
            {
                int docSize = 0;
                getters[i](ref src);

                // Compute the number of term instances in this document.
                for (int termID = 0; termID < src.Count; termID++)
                {
                    int termFreq = GetFrequency(src.Values[termID]);
                    if (termFreq < 0)
                    {
                        // A negative frequency marks the value as unusable: ignore this row.
                        docSize = 0;
                        break;
                    }

                    if (docSize >= _exes[i].NumMaxDocToken - termFreq)
                    {
                        // Control the document length.
                        break;
                    }

                    // The term is legal, so add it.
                    docSize += termFreq;
                }

                // Ignore empty documents.
                if (docSize == 0)
                {
                    continue;
                }

                numDocArray[i]++;

                // At the beginning of each doc there is a cursor variable, hence the "+ 1".
                // Computed in 64-bit arithmetic so that a large NumMaxDocToken cannot make
                // the per-document contribution wrap before it is added to the long total.
                corpusSize[i] += 2L * docSize + 1;

                // Increase the vocabulary size if needed.
                if (numVocabs[i] < src.Length)
                {
                    numVocabs[i] = src.Length;
                }
            }
        }

        for (int i = 0; i < Infos.Length; ++i)
        {
            if (numDocArray[i] != rowCount)
            {
                ch.Assert(numDocArray[i] < rowCount);
                ch.Warning($"Column '{Infos[i].Name}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values.");
            }
        }
    }

    // Initialize all LDA states.
    for (int i = 0; i < Infos.Length; i++)
    {
        // Validate before constructing the state so that no LdaState is created only to be
        // abandoned by the throw.
        if (numDocArray[i] == 0 || corpusSize[i] == 0)
        {
            throw ch.Except("The specified documents are all empty in column '{0}'.", Infos[i].Name);
        }

        var state = new LdaState(Host, _exes[i], numVocabs[i]);
        state.AllocateDataMemory(numDocArray[i], corpusSize[i]);
        states[i] = state;
    }

    using (var cursor = trainingData.GetRowCursor(col => activeColumns[col]))
    {
        // Accumulated as long to match corpusSize; an int total could wrap on a large corpus
        // and make the consistency assert below fail spuriously.
        long[] docSizeCheck = new long[Infos.Length];

        // This could be optimized so that if multiple trainers consume the same column, it is
        // fed into the train method once.
        var getters = new ValueGetter<VBuffer<Double>>[Utils.Size(Infos)];
        for (int i = 0; i < Infos.Length; i++)
        {
            docSizeCheck[i] = 0;
            getters[i] = RowCursorUtils.GetVecGetterAs<Double>(NumberType.R8, cursor, Infos[i].Source);
        }

        VBuffer<Double> src = default(VBuffer<Double>);

        while (cursor.MoveNext())
        {
            for (int i = 0; i < Infos.Length; i++)
            {
                getters[i](ref src);
                docSizeCheck[i] += states[i].FeedTrain(Host, ref src);
            }
        }

        for (int i = 0; i < Infos.Length; i++)
        {
            // The amount fed in the second sweep must match what the first sweep predicted.
            Host.Assert(corpusSize[i] == docSizeCheck[i]);
            states[i].CompleteTrain();
        }
    }
}