/// <summary>
/// Runs topic inference for a single document and returns its topics ordered
/// by descending weight, with the weights scaled so that they sum to one.
/// </summary>
/// <param name="model">Model supplying the number of burn-in iterations used for inference.</param>
/// <param name="tokenIndices">Token indices of the document, forwarded to the trainer.</param>
/// <param name="tokenCount">Number of tokens in the document; zero yields an empty result.</param>
/// <param name="frequency">Per-token frequencies, forwarded to the trainer.</param>
/// <param name="reset">Forwarded unchanged to the trainer's <c>TestDoc</c> call.</param>
/// <returns>
/// The inferred topics sorted from the highest to the lowest weight, or an empty
/// array when <paramref name="tokenCount"/> is zero.
/// </returns>
internal LDATopic[] Predict(LDAModel model, ReadOnlySpan<int> tokenIndices, int tokenCount, ReadOnlySpan<double> frequency, bool reset)
{
    // Prediction for a single document.
    // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe, so the flag is
    // re-checked under the lock to make sure the preparation runs only once.
    if (!_predictionPreparationDone)
    {
        lock (_preparationSyncRoot)
        {
            if (!_predictionPreparationDone)
            {
                // Do some preparation for building tables in native C++.
                _ldaTrainer.InitializeBeforeTest();
                _predictionPreparationDone = true;
            }
        }
    }

    if (tokenCount == 0)
    {
        return Array.Empty<LDATopic>();
    }

    var retTopics = _ldaTrainer.TestDoc(tokenIndices, frequency, tokenCount, model.NumberOfBurninIterations, reset);

    // Only normalize when the total weight is positive. Dividing by a zero total
    // would give an infinite factor and turn every weight into NaN; in that case
    // the raw weights are returned unchanged, matching the guard used by Output.
    var totalWeight = retTopics.Sum(kv => kv.Value);
    var normFactor = totalWeight > 0 ? 1f / totalWeight : 1f;

    return retTopics
        .OrderByDescending(t => t.Value)
        .Select(kv => new LDATopic(kv.Key, kv.Value * normFactor))
        .ToArray();
}
/// <summary>
/// Computes the topic weights of a single document given as a term-frequency
/// vector and writes them, normalized to sum to one when possible, into
/// <paramref name="dst"/>.
/// </summary>
/// <param name="src">Term frequencies of the document.</param>
/// <param name="dst">
/// Receives a vector of length <c>InfoEx.NumTopic</c>. Its current value and index
/// buffers are reused when they are large enough.
/// </param>
/// <param name="numBurninIter">Number of burn-in iterations passed to the trainer.</param>
/// <param name="reset">Forwarded unchanged to the trainer's test call.</param>
public void Output(ref VBuffer<Double> src, ref VBuffer<Float> dst, int numBurninIter, bool reset)
{
    // Prediction for a single document.
    // LdaSingleBox.InitializeBeforeTest() is NOT thread-safe, hence the lock
    // and the second check of the flag inside it.
    if (!_predictionPreparationDone)
    {
        lock (_preparationSyncRoot)
        {
            if (!_predictionPreparationDone)
            {
                // Let the native C++ side build its tables before the first prediction.
                _ldaTrainer.InitializeBeforeTest();
                _predictionPreparationDone = true;
            }
        }
    }

    int topicCount = InfoEx.NumTopic;
    var outValues = dst.Values;
    var outIndices = dst.Indices;

    // An empty document produces a vector with no explicit entries.
    if (src.Count == 0)
    {
        dst = new VBuffer<Float>(topicCount, 0, outValues, outIndices);
        return;
    }

    // Validate every frequency and stop accepting terms once the document
    // would reach the maximum number of tokens.
    int tokenTotal = 0;
    int acceptedTerms = 0;
    for (int term = 0; term < src.Count; term++)
    {
        int freq = GetFrequency(src.Values[term]);
        if (freq < 0)
        {
            // REVIEW: Should this log a warning message? And what should it produce?
            // For now an invalid frequency yields a dense vector filled with NaN.
            // REVIEW: Need a utility method to do this...
            if (Utils.Size(outValues) < topicCount)
            {
                outValues = new Float[topicCount];
            }

            for (int slot = 0; slot < topicCount; slot++)
            {
                outValues[slot] = Float.NaN;
            }

            dst = new VBuffer<Float>(topicCount, outValues, outIndices);
            return;
        }

        // Written as a subtraction so the comparison cannot overflow.
        if (InfoEx.NumMaxDocToken - freq <= tokenTotal)
        {
            break;
        }

        tokenTotal += freq;
        acceptedTerms++;
    }

    // REVIEW: Too much memory allocation here on each prediction.
    List<KeyValuePair<int, float>> topics;
    if (!src.IsDense)
    {
        // Sparse input: hand the trainer copies of the first src.Count indices and values.
        var termIds = src.Indices.Take(src.Count).ToArray();
        var termFreqs = src.Values.Take(src.Count).ToArray();
        topics = _ldaTrainer.TestDoc(termIds, termFreqs, acceptedTerms, numBurninIter, reset);
    }
    else
    {
        topics = _ldaTrainer.TestDocDense(src.Values, acceptedTerms, numBurninIter, reset);
    }

    int found = topics.Count;
    Contracts.Assert(found <= topicCount);

    if (Utils.Size(outValues) < found)
    {
        outValues = new Float[found];
    }

    // The index buffer is only needed when fewer topics than topicCount came back.
    bool sparseResult = found < topicCount;
    if (sparseResult && Utils.Size(outIndices) < found)
    {
        outIndices = new int[found];
    }

    double total = 0;
    for (int slot = 0; slot < found; slot++)
    {
        var entry = topics[slot];
        int topicId = entry.Key;
        Float weight = entry.Value;
        Contracts.Assert(weight >= 0);
        Contracts.Assert(0 <= topicId && topicId < topicCount);

        if (sparseResult)
        {
            // Topic ids are expected in strictly increasing order.
            Contracts.Assert(slot == 0 || outIndices[slot - 1] < topicId);
            outIndices[slot] = topicId;
        }
        else
        {
            Contracts.Assert(topicId == slot);
        }

        outValues[slot] = weight;
        total += weight;
    }

    // Scale the weights to sum to one; skipped when the total is not positive.
    if (total > 0)
    {
        for (int slot = 0; slot < found; slot++)
        {
            outValues[slot] = (Float)(outValues[slot] / total);
        }
    }

    dst = new VBuffer<Float>(topicCount, found, outValues, outIndices);
}