예제 #1
0
        /// <summary>
        /// Encodes the values of one prediction request into a float buffer and returns it
        /// as a two-dimensional array.
        /// </summary>
        /// <param name="ctx">Prediction context; supplies the input codifications and their grouping by column.</param>
        /// <param name="input">Main-query values and sub-query group dictionaries to read the raw values from.</param>
        /// <returns>
        /// The encoded buffer reshaped to <c>(-1, ctx.InputCodifications.Count)</c>.
        /// </returns>
        /// <exception cref="System.Exception">
        /// Some sub-query group dictionary does not use <c>ObjectArrayComparer.Instance</c> as its comparer.
        /// </exception>
        /// <exception cref="UnexpectedValueException">
        /// A column is neither a <c>PredictorColumnMain</c> nor a <c>PredictorColumnSubQuery</c>.
        /// </exception>
        private NDArray GetValueForPredict(PredictorPredictContext ctx, PredictDictionary input)
        {
            using (HeavyProfiler.Log("GetValueForPredict", () => $"Inputs Codifications {ctx.InputCodifications.Count}"))
            {
                // Every sub-query group dictionary must be keyed with the shared array comparer.
                bool allGroupsUseArrayComparer = input.SubQueries.Values
                    .All(query => query.SubQueryGroups.Comparer == ObjectArrayComparer.Instance);

                if (!allGroupsUseArrayComparer)
                {
                    throw new Exception("Unexpected dictionary comparer");
                }

                float[] encoded = new float[ctx.InputCodifications.Count];

                foreach (var entry in ctx.InputCodificationsByColumn)
                {
                    PredictorColumnBase column = entry.Key;
                    var codifications = entry.Value;

                    // Locate the raw value for this column; it stays null when the sub-query has no matching group.
                    object? rawValue;
                    switch (column)
                    {
                        case PredictorColumnMain main:
                            rawValue = input.MainQueryValues.GetOrThrow(main.PredictorColumn);
                            break;

                        case PredictorColumnSubQuery sub:
                            var subQueryInput = input.SubQueries.GetOrThrow(sub.SubQuery);
                            var groupValues = subQueryInput.SubQueryGroups.TryGetC(sub.Keys);

                            if (groupValues == null)
                            {
                                rawValue = null;
                            }
                            else
                            {
                                rawValue = groupValues.GetOrThrow(sub.PredictorSubQueryColumn);
                            }
                            break;

                        default:
                            throw new UnexpectedValueException(column);
                    }

                    using (HeavyProfiler.LogNoStackTrace("EncodeValue"))
                    {
                        var encoder = Encodings.GetOrThrow(column.Encoding);

                        // A null raw value is replaced by the default derived from the column's first codification.
                        var valueToEncode = rawValue ?? TensorFlowDefault.GetDefaultValue(codifications.FirstEx());

                        // Single request: the encoder is always handed offset 0.
                        encoder.EncodeValue(valueToEncode, column, codifications, encoded, 0);
                    }
                }

                using (HeavyProfiler.LogNoStackTrace("CreateBatch"))
                {
                    var flat = np.array(encoded);
                    return flat.reshape(-1, encoded.Length);
                }
            }
        }
예제 #2
0
#pragma warning restore CS8618 // Non-nullable field is uninitialized.

        /// <summary>
        /// Encodes a list of main-query rows into one flat float buffer and returns it
        /// reshaped to <c>(rows.Count, codificationCount)</c>.
        /// </summary>
        /// <param name="ctx">Training context; supplies the main query (for parent keys) and the sub-queries.</param>
        /// <param name="rows">Main-query result rows to encode, one output row per element.</param>
        /// <param name="codificationCount">Number of buffer slots reserved per row.</param>
        /// <param name="codificationByColumn">Codifications to encode, grouped by the column they belong to.</param>
        /// <returns>The encoded buffer as an array of shape <c>(rows.Count, codificationCount)</c>.</returns>
        /// <exception cref="UnexpectedValueException">
        /// A column is neither a <c>PredictorColumnMain</c> nor a <c>PredictorColumnSubQuery</c>.
        /// </exception>
        NDArray CreateNDArray(PredictorTrainingContext ctx, List <ResultRow> rows, int codificationCount, Dictionary <PredictorColumnBase, List <PredictorCodification> > codificationByColumn)
        {
            using (HeavyProfiler.Log("CreateValue", () => $"Rows {rows.Count} Codifications {codificationCount}"))
            {
                float[] encoded = new float[rows.Count * codificationCount];

                for (int rowIndex = 0; rowIndex < rows.Count; rowIndex++)
                {
                    ResultRow row = rows[rowIndex];
                    var parentKey = ctx.MainQuery.GetParentKey(row);

                    // Offset handed to the encoders for this row's slice of the flat buffer.
                    int rowOffset = rowIndex * codificationCount;

                    foreach (var entry in codificationByColumn)
                    {
                        PredictorColumnBase column = entry.Key;
                        List<PredictorCodification> codifications = entry.Value;

                        // Locate the raw value; it stays null when the sub-query has no group for this row/keys.
                        object? rawValue;
                        switch (column)
                        {
                            case PredictorColumnMain main:
                                rawValue = row[main.PredictorColumnIndex];
                                break;

                            case PredictorColumnSubQuery sub:
                                SubQuery subQuery = ctx.SubQueries.GetOrThrow(sub.SubQuery);
                                object?[]? groupValues = subQuery.GroupedValues.TryGetC(parentKey)?.TryGetC(sub.Keys);

                                if (groupValues == null)
                                {
                                    rawValue = null;
                                }
                                else
                                {
                                    rawValue = groupValues[subQuery.ColumnIndexToValueIndex[sub.PredictorColumnIndex]];
                                }
                                break;

                            default:
                                throw new UnexpectedValueException(column);
                        }

                        using (HeavyProfiler.LogNoStackTrace("EncodeValue"))
                        {
                            ITensorFlowEncoding encoder = Encodings.GetOrThrow(column.Encoding);

                            // A null raw value is replaced by the default derived from the column's first codification.
                            var valueToEncode = rawValue ?? TensorFlowDefault.GetDefaultValue(codifications.FirstEx());

                            encoder.EncodeValue(valueToEncode, column, codifications, encoded, rowOffset);
                        }
                    }
                }

                using (HeavyProfiler.LogNoStackTrace("CreateBatch"))
                {
                    var flat = np.array(encoded);
                    return flat.reshape((rows.Count, codificationCount));
                }
            }
        }