/// <summary>
/// Builds a getter that reads the current slot of <paramref name="cursor"/> as a vector of
/// <see cref="Single"/> label values.
/// </summary>
/// <remarks>
/// R4 slots are read as-is, and R8 or boolean slots are converted to R4. Key-typed slots are
/// remapped so that key value k becomes k - 1. The key value 0, and any value above the key
/// count, becomes NaN. Any other item type fails the contract check.
/// </remarks>
/// <param name="cursor">Slot cursor whose slot item type selects the conversion.</param>
/// <returns>A getter that fills a label vector for the cursor's current slot.</returns>
public static ValueGetter<VBuffer<Single>> GetLabelGetter(ISlotCursor cursor)
{
    var itemType = cursor.GetSlotType().ItemType;

    // Floating point and boolean labels need no key remapping.
    if (itemType == NumberType.R4)
    {
        return cursor.GetGetter<Single>();
    }
    if (itemType == NumberType.R8 || itemType.IsBool)
    {
        return GetVecGetterAs<Single>(NumberType.R4, cursor);
    }

    Contracts.Check(itemType.IsKey, "Only floating point number, boolean, and key type values can be used as label.");
    Contracts.Assert(TestGetLabelGetter(itemType) == null);

    // A key count of zero is treated as "no upper bound": every positive key value is accepted.
    ulong declaredCount = (ulong)itemType.KeyCount;
    ulong keyMax = declaredCount == 0 ? ulong.MaxValue : declaredCount;

    var getKeys = RowCursorUtils.GetVecGetterAs<ulong>(NumberType.U8, cursor);
    var keys = default(VBuffer<ulong>);

    return (ref VBuffer<Single> dst) =>
    {
        getKeys(ref keys);

        // The default key (0) maps to NaN, not to the default float 0f. Implicit entries of the
        // source therefore cannot stay implicit in the output, so every slot is written.
        Single[] buffer = dst.Values;
        Utils.EnsureSize(ref buffer, keys.Length);
        foreach (var pair in keys.Items(all: true))
        {
            bool inRange = pair.Value > 0 && pair.Value <= keyMax;
            buffer[pair.Key] = inRange ? pair.Value - 1 : Single.NaN;
        }
        dst = new VBuffer<Single>(keys.Length, buffer, dst.Indices);
    };
}
/// <summary>
/// Prints the per-label metrics of a single fold as a table on <paramref name="ch"/>.
/// The multi-output regression evaluator prints only these per-label metrics for each fold.
/// </summary>
/// <param name="ch">Channel that receives the table and is used for error reporting.</param>
/// <param name="metrics">Metric data views keyed by metric kind; the
/// <c>MetricKinds.OverallMetrics</c> entry is required.</param>
/// <remarks>
/// Which columns are printed: every non-hidden column whose type is a known-size vector of R8,
/// except the IsWeighted and stratification columns. Each such column is printed as one row
/// with one value per label. All of these columns must have the same vector size.
///
/// Which rows are printed: only rows whose stratification value is 0, or every row when no
/// stratification column exists. Of those, one unweighted row is printed and, when the view has
/// an IsWeighted column, one weighted row.
///
/// Exceptions are raised through <paramref name="ch"/> in two cases: the overall metrics are
/// missing, or a second unstratified row of the same weighted/unweighted kind is found.
/// </remarks>
protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDataView> metrics)
{
    IDataView fold;
    if (!metrics.TryGetValue(MetricKinds.OverallMetrics, out fold))
    {
        throw ch.Except("No overall metrics found");
    }

    int isWeightedCol;
    bool needWeighted = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out isWeightedCol);

    int stratCol;
    bool hasStrats = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out stratCol);
    int stratVal;
    bool hasStratVals = fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);
    ch.Assert(hasStrats == hasStratVals);

    int colCount = fold.Schema.ColumnCount;
    var vBufferGetters = new ValueGetter<VBuffer<double>>[colCount];

    using (var cursor = fold.GetRowCursor(col => true))
    {
        DvBool isWeighted = DvBool.False;
        ValueGetter<DvBool> isWeightedGetter;
        if (needWeighted)
        {
            isWeightedGetter = cursor.GetGetter<DvBool>(isWeightedCol);
        }
        else
        {
            isWeightedGetter = (ref DvBool dst) => dst = DvBool.False;
        }

        ValueGetter<uint> stratGetter;
        if (hasStrats)
        {
            var stratType = cursor.Schema.GetColumnType(stratCol);
            stratGetter = RowCursorUtils.GetGetterAs<uint>(stratType, cursor, stratCol);
        }
        else
        {
            stratGetter = (ref uint dst) => dst = 0;
        }

        // Find the per-label metric columns (known-size R8 vectors) and the shared label count.
        int labelCount = 0;
        for (int i = 0; i < colCount; i++)
        {
            if (fold.Schema.IsHidden(i) || (needWeighted && i == isWeightedCol) ||
                (hasStrats && (i == stratCol || i == stratVal)))
            {
                continue;
            }

            var type = fold.Schema.GetColumnType(i);
            if (type.IsKnownSizeVector && type.ItemType == NumberType.R8)
            {
                vBufferGetters[i] = cursor.GetGetter<VBuffer<double>>(i);
                if (labelCount == 0)
                {
                    labelCount = type.VectorSize;
                }
                else
                {
                    ch.Check(labelCount == type.VectorSize, "All vector metrics should contain the same number of slots");
                }
            }
        }

        // Header row: a blank 12-wide cell plus a space, then one 20-wide column per label.
        var sb = new StringBuilder();
        sb.AppendLine("Per-label metrics:");
        sb.AppendFormat("{0,12} ", " ");
        for (int j = 0; j < labelCount; j++)
        {
            sb.AppendFormat(" {0,20}", new DvText(string.Format("Label_{0}", j)));
        }
        sb.AppendLine();

        VBuffer<Double> metricVals = default(VBuffer<Double>);
        bool foundWeighted = !needWeighted;
        bool foundUnweighted = false;
        uint strat = 0;
        while (cursor.MoveNext())
        {
            // Stratified rows (strat > 0) are never printed, yet they carry an IsWeighted value.
            // They must be skipped before the duplicate bookkeeping below. Otherwise a
            // stratified row read before both overall rows would be reported as a duplicate.
            stratGetter(ref strat);
            if (strat > 0)
            {
                continue;
            }

            isWeightedGetter(ref isWeighted);
            if ((foundWeighted && isWeighted.IsTrue) || (foundUnweighted && isWeighted.IsFalse))
            {
                throw ch.Except("Multiple {0} rows found in overall metrics data view",
                    isWeighted.IsTrue ? "weighted" : "unweighted");
            }
            if (isWeighted.IsTrue)
            {
                foundWeighted = true;
            }
            else
            {
                foundUnweighted = true;
            }

            for (int i = 0; i < colCount; i++)
            {
                if (vBufferGetters[i] == null)
                {
                    continue;
                }

                vBufferGetters[i](ref metricVals);
                ch.Assert(metricVals.Length == labelCount);

                sb.AppendFormat("{0}{1,12}:", isWeighted.IsTrue ? "Weighted " : "", fold.Schema.GetColumnName(i));
                foreach (var metric in metricVals.Items(all: true))
                {
                    sb.AppendFormat(" {0,20:G20}", metric.Value);
                }
                sb.AppendLine();
            }

            // Both required overall rows have been printed; later rows are not needed.
            if (foundUnweighted && foundWeighted)
            {
                break;
            }
        }
        ch.Assert(foundUnweighted && foundWeighted);
        ch.Info(sb.ToString());
    }
}