/// <summary>
/// Encodes a batch of prediction inputs into a single CNTK input batch. Each input becomes one sample of
/// <c>ctx.InputCodifications.Count</c> floats, filled by encoding every column in
/// <c>ctx.InputCodificationsByColumn</c>.
/// </summary>
/// <param name="ctx">Prediction context supplying the input codifications and their grouping by column.</param>
/// <param name="inputs">Inputs to encode; <c>inputs[i]</c> is written at offset <c>i * ctx.InputCodifications.Count</c>.</param>
/// <param name="device">Device on which the batch is created.</param>
/// <returns>A batch whose per-sample shape is <c>{ ctx.InputCodifications.Count }</c>.</returns>
/// <exception cref="ArgumentException"><paramref name="inputs"/> is empty.</exception>
/// <exception cref="InvalidOperationException">
/// The sub-query group dictionary of any input does not use <c>ObjectArrayComparer.Instance</c>.
/// </exception>
private Value GetValueForPredict(PredictorPredictContext ctx, List<PredictDictionary> inputs, DeviceDescriptor device)
{
    using (HeavyProfiler.Log("GetValueForPredict", () => $"Inputs {inputs.Count} Codifications {ctx.InputCodifications.Count}"))
    {
        if (inputs.Count == 0)
        {
            throw new ArgumentException("At least one input is required.", nameof(inputs));
        }

        // Sub-query groups are looked up with pcsq.Keys below. Presumably those keys are arrays that only match
        // by value under ObjectArrayComparer (TODO confirm). A miss makes TryGetC return null, and the value then
        // silently falls back to the codification default. So validate every input, not just the first one.
        if (inputs.Any(input => input.SubQueries.Values.Any(sq => sq.SubQueryGroups.Comparer != ObjectArrayComparer.Instance)))
        {
            throw new InvalidOperationException("Unexpected dictionary comparer");
        }

        int codificationCount = ctx.InputCodifications.Count;
        float[] inputValues = new float[inputs.Count * codificationCount];
        var groups = ctx.InputCodificationsByColumn;

        for (int i = 0; i < inputs.Count; i++)
        {
            PredictDictionary input = inputs[i];
            int offset = i * codificationCount;

            foreach (var kvp in groups)
            {
                PredictorColumnBase col = kvp.Key;
                object? value;
                if (col is PredictorColumnMain pcm)
                {
                    value = input.MainQueryValues.GetOrThrow(pcm.PredictorColumn);
                }
                else if (col is PredictorColumnSubQuery pcsq)
                {
                    var sq = input.SubQueries.GetOrThrow(pcsq.SubQuery);
                    // A missing group yields null; null is replaced by the codification default below.
                    var dic = sq.SubQueryGroups.TryGetC(pcsq.Keys);
                    value = dic == null ? null : dic.GetOrThrow(pcsq.PredictorSubQueryColumn);
                }
                else
                {
                    throw new UnexpectedValueException(col);
                }

                using (HeavyProfiler.LogNoStackTrace("EncodeValue"))
                {
                    var enc = Encodings.GetOrThrow(col.Encoding);
                    enc.EncodeValue(value ?? CNTKDefault.GetDefaultValue(kvp.Value.FirstOrDefault()), col, kvp.Value, inputValues, offset);
                }
            }
        }

        using (HeavyProfiler.LogNoStackTrace("CreateBatch"))
        {
            return Value.CreateBatch<float>(new int[] { codificationCount }, inputValues, device);
        }
    }
}
#pragma warning restore CS8618 // Non-nullable field is uninitialized.
/// <summary>
/// Builds the CNTK input batch for a set of training rows. Each row becomes one sample of
/// <paramref name="codificationCount"/> floats, filled by encoding every column in
/// <paramref name="codificationByColumn"/>.
/// </summary>
/// <param name="ctx">Training context supplying the main query and the grouped sub-query values.</param>
/// <param name="rows">Main-query rows; <c>rows[i]</c> is written at offset <c>i * codificationCount</c>.</param>
/// <param name="codificationCount">Number of floats per sample.</param>
/// <param name="codificationByColumn">Codifications grouped by the column they encode.</param>
/// <param name="device">Device on which the batch is created.</param>
/// <returns>A batch whose per-sample shape is <c>{ codificationCount }</c>.</returns>
Value CreateValue(PredictorTrainingContext ctx, List<ResultRow> rows, int codificationCount, Dictionary<PredictorColumnBase, List<PredictorCodification>> codificationByColumn, DeviceDescriptor device)
{
    using var profilerScope = HeavyProfiler.Log("CreateValue", () => $"Rows {rows.Count} Codifications {codificationCount}");

    var buffer = new float[rows.Count * codificationCount];

    for (int rowIndex = 0; rowIndex < rows.Count; rowIndex++)
    {
        ResultRow row = rows[rowIndex];
        var parentKey = ctx.MainQuery.GetParentKey(row);
        int sampleOffset = rowIndex * codificationCount;

        foreach (var entry in codificationByColumn)
        {
            PredictorColumnBase column = entry.Key;

            object? rawValue;
            switch (column)
            {
                case PredictorColumnMain mainColumn:
                    rawValue = row[mainColumn.PredictorColumnIndex];
                    break;

                case PredictorColumnSubQuery subQueryColumn:
                    SubQuery subQuery = ctx.SubQueries.GetOrThrow(subQueryColumn.SubQuery);
                    // Null when either the parent key or the group keys are absent;
                    // a null value is replaced by the codification default below.
                    object?[]? groupValues = subQuery.GroupedValues.TryGetC(parentKey)?.TryGetC(subQueryColumn.Keys);
                    rawValue = groupValues is null
                        ? null
                        : groupValues[subQuery.ColumnIndexToValueIndex[subQueryColumn.PredictorColumnIndex]];
                    break;

                default:
                    throw new UnexpectedValueException(column);
            }

            using (HeavyProfiler.LogNoStackTrace("EncodeValue"))
            {
                ICNTKEncoding encoder = Encodings.GetOrThrow(column.Encoding);
                List<PredictorCodification> codifications = entry.Value;
                encoder.EncodeValue(rawValue ?? CNTKDefault.GetDefaultValue(codifications.FirstOrDefault()), column, codifications, buffer, sampleOffset);
            }
        }
    }

    using (HeavyProfiler.LogNoStackTrace("CreateBatch"))
    {
        return Value.CreateBatch<float>(new int[] { codificationCount }, buffer, device);
    }
}