/// <summary>
/// Feeds each stored entry of a vector-valued column into <paramref name="builder"/> for the given label.
/// </summary>
/// <param name="builder">Count-table builder that receives one <c>IncrementSlot</c> call per stored entry.</param>
/// <param name="iCol">Index of the column being counted, passed through to the builder.</param>
/// <param name="srcBuffer">The row's vector value. It is only read, never modified.</param>
/// <param name="labelKey">Label key the entries are counted under.</param>
private void IncrementVec(MultiCountTableBuilderBase builder, int iCol, ref VBuffer<uint> srcBuffer, uint labelKey)
{
    var slotValues = srcBuffer.GetValues();

    if (!srcBuffer.IsDense)
    {
        // Sparse representation: only explicitly stored entries are visited, and each entry's
        // slot number comes from the parallel index span.
        var slotIndices = srcBuffer.GetIndices();
        for (int k = 0; k < slotIndices.Length; k++)
        {
            builder.IncrementSlot(iCol, slotIndices[k], slotValues[k], labelKey);
        }
        return;
    }

    // Dense representation: the position in the value span is itself the slot number.
    int length = srcBuffer.Length;
    for (int slot = 0; slot < length; slot++)
    {
        builder.IncrementSlot(iCol, slot, slotValues[slot], labelKey);
    }
}
/// <summary>
/// Makes a single pass over <paramref name="trainingData"/> and, for every row with a non-negative label key,
/// feeds each input column's value(s) into <paramref name="builder"/> under that label.
/// </summary>
/// <param name="trainingData">Data to count over.</param>
/// <param name="cols">
/// Input columns to count. Element <c>i</c> is reported to the builder as column <c>i</c>. Scalar columns are
/// counted in slot 0. Vector columns are counted per slot.
/// </param>
/// <param name="builder">Builder that accumulates the counts.</param>
/// <param name="labelColumn">Label column, read through <c>GetLabelGetter</c>. Rows whose label key is negative are skipped.</param>
private void TrainTables(IDataView trainingData, List<DataViewSchema.Column> cols, MultiCountTableBuilderBase builder, DataViewSchema.Column labelColumn)
{
    // NOTE(review): the loop bound comes from _columns, but cols is what gets indexed.
    // This assumes cols holds exactly one entry per _columns entry, in the same order. TODO confirm at call sites.
    var colCount = _columns.Length;
    using (var cursor = trainingData.GetRowCursor(cols.Prepend(labelColumn)))
    {
        // Populate getters. Whether a column is a vector, and its declared vector size, cannot change
        // between rows, so both are resolved once here rather than re-evaluated on every row.
        var singleGetters = new ValueGetter<uint>[colCount];
        var vectorGetters = new ValueGetter<VBuffer<uint>>[colCount];
        var isVector = new bool[colCount];
        var vectorSizes = new int[colCount];
        for (int i = 0; i < colCount; i++)
        {
            var colType = cols[i].Type;
            if (colType is VectorDataViewType)
            {
                isVector[i] = true;
                vectorSizes[i] = colType.GetVectorSize();
                vectorGetters[i] = cursor.GetGetter<VBuffer<uint>>(cols[i]);
            }
            else
            {
                singleGetters[i] = cursor.GetGetter<uint>(cols[i]);
            }
        }

        var labelGetter = GetLabelGetter(cursor, labelColumn);
        long labelKey = 0;
        uint srcSingleValue = 0;
        var srcBuffer = default(VBuffer<uint>);
        while (cursor.MoveNext())
        {
            labelGetter(ref labelKey);
            if (labelKey < 0)
            {
                // Invalid label, skip the data.
                continue;
            }

            for (int i = 0; i < colCount; i++)
            {
                if (isVector[i])
                {
                    vectorGetters[i](ref srcBuffer);
                    _host.Check(srcBuffer.Length == vectorSizes[i], "value count mismatch");
                    IncrementVec(builder, i, ref srcBuffer, (uint)labelKey);
                }
                else
                {
                    singleGetters[i](ref srcSingleValue);
                    builder.IncrementSlot(i, 0, srcSingleValue, (uint)labelKey);
                }
            }
        }
    }
}