Beispiel #1
0
    /// <summary>
    /// Registers <paramref name="p"/> in <c>Trainings</c> and launches <c>DoTraining</c> for it on a
    /// background task, without waiting for the task to finish.
    /// </summary>
    /// <param name="p">
    /// Predictor to train. Its <c>User</c> is retrieved inside the background task and used as the
    /// user session under which the training runs.
    /// </param>
    /// <exception cref="InvalidOperationException">
    /// An entry for this predictor already exists in <c>Trainings</c>.
    /// </exception>
    static void StartTrainingAsync(PredictorEntity p)
    {
        var cancellationSource = new CancellationTokenSource();
        var ctx   = new PredictorTrainingContext(p, cancellationSource.Token);
        var state = new PredictorTrainingState(cancellationSource, ctx);

        if (!Trainings.TryAdd(p.ToLite(), state))
        {
            throw new InvalidOperationException(PredictorMessage._0IsAlreadyBeingTrained.NiceToString(p));
        }

        // The caller's execution context is not flowed into the background task.
        using (ExecutionContext.SuppressFlow())
        {
            Task.Run(() =>
            {
                // The try/finally also covers the user lookup: if retrieving the user throws,
                // the predictor must not stay registered in Trainings forever.
                try
                {
                    var user = ExecutionMode.Global().Using(_ => p.User !.RetrieveAndRemember());
                    using (UserHolder.UserSession(user))
                    {
                        DoTraining(ctx);
                    }
                }
                finally
                {
                    Trainings.TryRemove(p.ToLite(), out var _);
                }
            });
        }
    }
Beispiel #2
0
    /// <summary>
    /// Runs the whole training pipeline for the predictor held by <paramref name="ctx"/>:
    /// validates the result saver (if any), retrieves the data, creates the codifications,
    /// trains the algorithm, saves the predictions (if a result saver is configured) and finally
    /// marks the predictor as <c>Trained</c>.
    /// </summary>
    /// <remarks>
    /// Failures do not propagate from the pipeline itself: on <see cref="OperationCanceledException"/>
    /// the predictor is re-retrieved, cleaned and saved as <c>Draft</c>; on any other exception the
    /// exception is logged and the re-retrieved predictor is saved as <c>Error</c> with a link to the log entry.
    /// </remarks>
    /// <param name="ctx">Training context carrying the predictor to train.</param>
    static void DoTraining(PredictorTrainingContext ctx)
    {
        using (HeavyProfiler.Log("DoTraining"))
        {
            try
            {
                // Check the configured result saver before doing any expensive work.
                if (ctx.Predictor.ResultSaver != null)
                {
                    ResultSavers.GetOrThrow(ctx.Predictor.ResultSaver).AssertValid(ctx.Predictor);
                }

                PredictorLogicQuery.RetrieveData(ctx);
                PredictorCodificationLogic.CreatePredictorCodifications(ctx);

                var trainer = Algorithms.GetOrThrow(ctx.Predictor.Algorithm);
                using (HeavyProfiler.Log("Train"))
                {
                    trainer.Train(ctx);
                }

                // The saver is looked up again here because ctx.Predictor is read afresh after training.
                if (ctx.Predictor.ResultSaver != null)
                {
                    using (HeavyProfiler.Log("ResultSaver"))
                    {
                        ResultSavers.GetOrThrow(ctx.Predictor.ResultSaver).SavePredictions(ctx);
                    }
                }

                ctx.Predictor.State = PredictorState.Trained;
                using (OperationLogic.AllowSave <PredictorEntity>())
                {
                    ctx.Predictor.Save();
                }
            }
            catch (OperationCanceledException)
            {
                // Cancelled: work on a freshly retrieved copy and send it back to Draft.
                var fresh = ctx.Predictor.ToLite().Retrieve();
                CleanTrained(fresh);
                fresh.State = PredictorState.Draft;
                using (OperationLogic.AllowSave <PredictorEntity>())
                {
                    fresh.Save();
                }
            }
            catch (Exception error)
            {
                // Failed: log the exception and record it on a freshly retrieved copy.
                error.Data["entity"] = ctx.Predictor;
                var logged = error.LogException();
                var fresh  = ctx.Predictor.ToLite().Retrieve();
                fresh.State             = PredictorState.Error;
                fresh.TrainingException = logged.ToLite();
                using (OperationLogic.AllowSave <PredictorEntity>())
                {
                    fresh.Save();
                }
            }
        }
    }
Beispiel #3
0
    /// <summary>
    /// Trains <paramref name="p"/> synchronously on the calling thread.
    /// </summary>
    /// <param name="p">Predictor to train. Its <c>User</c> is set to the current user and its state to <c>Training</c> before it is saved.</param>
    /// <param name="autoReset">
    /// When <c>true</c>, a predictor in state <c>Trained</c> or <c>Error</c> is untrained first, and one in
    /// state <c>Training</c> has its training cancelled first.
    /// </param>
    /// <param name="onReportProgres">
    /// Progress callback (message, optional progress). When <c>null</c>, a default handler writes the
    /// messages to the console.
    /// </param>
    /// <param name="cancellationToken">Token handed to the training context; when <c>null</c>, a token that is never cancelled is used.</param>
    public static void TrainSync(this PredictorEntity p, bool autoReset = true, Action <string, decimal?>?onReportProgres = null, CancellationToken?cancellationToken = null)
    {
        if (autoReset)
        {
            if (p.State == PredictorState.Trained || p.State == PredictorState.Error)
            {
                p.Execute(PredictorOperation.Untrain);
            }
            else if (p.State == PredictorState.Training)
            {
                p.Execute(PredictorOperation.CancelTraining);
            }
        }

        p.User  = UserHolder.Current.ToLite();
        p.State = PredictorState.Training;
        p.Save();

        var ctx = new PredictorTrainingContext(p, cancellationToken ?? new CancellationTokenSource().Token);

        // Tracks whether the most recent console message carried a progress value.
        var lastWithProgress = false;

        if (onReportProgres != null)
        {
            ctx.OnReportProgres += onReportProgres;
        }
        else
        {
            ctx.OnReportProgres += (message, progress) =>
            {
                if (progress == null)
                {
                    if (lastWithProgress)
                    {
                        Console.WriteLine();
                        // Reset so that only the first plain message after a progress run gets the separator line.
                        lastWithProgress = false;
                    }
                    SafeConsole.WriteLineColor(ConsoleColor.White, message);
                }
                else
                {
                    SafeConsole.WriteLineColor(ConsoleColor.White, $"{progress:P} - {message}");
                    lastWithProgress = true;
                }
            };
        }

        DoTraining(ctx);
    }
Beispiel #4
0
#pragma warning restore CS8618 // Non-nullable field is uninitialized.

        /// <summary>
        /// Encodes <paramref name="rows"/> into one flat <c>float</c> buffer that reserves
        /// <paramref name="codificationCount"/> slots per row, and returns it reshaped to
        /// (<c>rows.Count</c>, <paramref name="codificationCount"/>).
        /// </summary>
        /// <param name="ctx">Training context providing the main query and the sub queries.</param>
        /// <param name="rows">Main-query rows to encode, one output row each.</param>
        /// <param name="codificationCount">Number of slots reserved per row.</param>
        /// <param name="codificationByColumn">Codifications grouped by the column that produces them.</param>
        /// <exception cref="UnexpectedValueException">A column is neither a main column nor a sub-query column.</exception>
        NDArray CreateNDArray(PredictorTrainingContext ctx, List <ResultRow> rows, int codificationCount, Dictionary <PredictorColumnBase, List <PredictorCodification> > codificationByColumn)
        {
            using (HeavyProfiler.Log("CreateValue", () => $"Rows {rows.Count} Codifications {codificationCount}"))
            {
                var buffer = new float[rows.Count * codificationCount];

                for (int rowIndex = 0; rowIndex < rows.Count; rowIndex++)
                {
                    ResultRow row       = rows[rowIndex];
                    var       parentKey = ctx.MainQuery.GetParentKey(row);
                    int       rowOffset = rowIndex * codificationCount;

                    // Reads the raw (not yet encoded) value of a column for the current row.
                    object? ReadRawValue(PredictorColumnBase column)
                    {
                        if (column is PredictorColumnMain mainColumn)
                        {
                            return row[mainColumn.PredictorColumnIndex];
                        }

                        if (column is PredictorColumnSubQuery subColumn)
                        {
                            SubQuery   subQuery    = ctx.SubQueries.GetOrThrow(subColumn.SubQuery);
                            object?[]? groupValues = subQuery.GroupedValues.TryGetC(parentKey)?.TryGetC(subColumn.Keys);

                            // No sub-query values for this parent key / split keys combination.
                            if (groupValues == null)
                            {
                                return null;
                            }

                            return groupValues[subQuery.ColumnIndexToValueIndex[subColumn.PredictorColumnIndex]];
                        }

                        throw new UnexpectedValueException(column);
                    }

                    foreach (var entry in codificationByColumn)
                    {
                        PredictorColumnBase column   = entry.Key;
                        object?             rawValue = ReadRawValue(column);

                        using (HeavyProfiler.LogNoStackTrace("EncodeValue"))
                        {
                            ITensorFlowEncoding encoder = Encodings.GetOrThrow(column.Encoding);

                            // A missing value is replaced by the default derived from the column's first codification.
                            encoder.EncodeValue(rawValue ?? TensorFlowDefault.GetDefaultValue(entry.Value.FirstEx()), column, entry.Value, buffer, rowOffset);
                        }
                    }
                }

                using (HeavyProfiler.LogNoStackTrace("CreateBatch"))
                {
                    return np.array(buffer).reshape((rows.Count, codificationCount));
                }
            }
        }
Beispiel #5
0
        //Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
        /// <summary>
        /// Builds the dense network described by the predictor's <c>NeuralNetworkSettingsEntity</c>, trains it
        /// minibatch by minibatch inside a TensorFlow session, and attaches the files of the best saved
        /// model to the predictor.
        /// </summary>
        /// <remarks>
        /// <para>
        /// Training loss/accuracy are evaluated and persisted on the last <c>BestResultFromLast</c> iterations,
        /// every <c>SaveProgressEvery</c> iterations, and when <c>ctx.StopTraining</c> is set. Validation metrics are
        /// added on the last iterations, every <c>SaveValidationProgressEvery</c> iterations, and on stop.
        /// </para>
        /// <para>
        /// A model checkpoint is saved for each of the last iterations (or on stop); the checkpoint with the
        /// lowest validation loss wins, and the predictor is saved with its metrics and files.
        /// </para>
        /// <para>The predictor's model directory is deleted and recreated before training starts.</para>
        /// </remarks>
        /// <param name="ctx">Training context: predictor, codifications, data and progress reporting.</param>
        /// <exception cref="InvalidOperationException">Training ended without producing any candidate model.</exception>
        public void Train(PredictorTrainingContext ctx)
        {
            InitialSetup();

            tf.compat.v1.disable_eager_execution();
            var p = ctx.Predictor;

            var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

            Tensor inputPlaceholder  = tf.placeholder(tf.float32, new[] { -1, ctx.InputCodifications.Count }, "inputPlaceholder");
            Tensor outputPlaceholder = tf.placeholder(tf.float32, new[] { -1, ctx.OutputCodifications.Count }, "outputPlaceholder");

            // Chain the hidden layers, each one consuming the previous layer's tensor.
            Tensor currentTensor = inputPlaceholder;

            nn.HiddenLayers.ForEach((layer, i) =>
            {
                currentTensor = NetworkBuilder.DenseLayer(currentTensor, layer.Size, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
            });
            Tensor output           = NetworkBuilder.DenseLayer(currentTensor, ctx.OutputCodifications.Count, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");
            Tensor calculatedOutput = tf.identity(output, "calculatedOutput");

            Tensor loss     = NetworkBuilder.GetEvalFunction(nn.LossFunction, outputPlaceholder, calculatedOutput);
            Tensor accuracy = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, outputPlaceholder, calculatedOutput);

            // prepare for training
            Optimizer optimizer = NetworkBuilder.GetOptimizer(nn);

            Operation trainOperation = optimizer.minimize(loss);

            // A configured seed makes the split and the minibatch sampling use a seeded Random.
            Random rand = p.Settings.Seed == null ?
                          new Random() :
                          new Random(p.Settings.Seed.Value);

            var (training, validation) = ctx.SplitTrainValidation(rand);

            var minibatchSize  = nn.MinibatchSize;
            var numMinibatches = nn.NumMinibatches;

            Stopwatch             sw        = Stopwatch.StartNew();
            List <FinalCandidate> candidate = new List <FinalCandidate>();

            var config = new ConfigProto
            {
                IntraOpParallelismThreads = 1,
                InterOpParallelismThreads = 1,
                LogDevicePlacement        = true
            };

            ctx.ReportProgress($"Deleting Files");
            var dir = PredictorDirectory(ctx.Predictor);

            if (Directory.Exists(dir))
            {
                Directory.Delete(dir, true);
            }

            Directory.CreateDirectory(dir);

            ctx.ReportProgress($"Starting training...");

            var saver = tf.train.Saver();

            using (var sess = tf.Session(config))
            {
                sess.run(tf.global_variables_initializer());

                for (int i = 0; i < numMinibatches; i++)
                {
                    using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
                    {
                        // presumably NextElement picks one random row per call — TODO confirm
                        var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();

                        var inputValue  = CreateNDArray(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                        var outputValue = CreateNDArray(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                        using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                        {
                            sess.run(trainOperation,
                                     (inputPlaceholder, inputValue),
                                     (outputPlaceholder, outputValue));
                        }

                        // On a stop request continue with a freshly retrieved predictor instance.
                        if (ctx.StopTraining)
                        {
                            p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
                        }

                        var isLast = numMinibatches - nn.BestResultFromLast <= i;
                        if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
                        {
                            float loss_val;
                            float accuracy_val;

                            using (HeavyProfiler.Log("EvalTraining", () => i.ToString()))
                            {
                                (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                                                    (inputPlaceholder, inputValue),
                                                                    (outputPlaceholder, outputValue));
                            }

                            var ep = new EpochProgress
                            {
                                Ellapsed           = sw.ElapsedMilliseconds,
                                Epoch              = i,
                                // Minibatch i has already been trained at this point, so i + 1 minibatches have been seen.
                                TrainingExamples   = (i + 1) * minibatchSize,
                                LossTraining       = loss_val,
                                AccuracyTraining   = accuracy_val,
                                LossValidation     = null,
                                AccuracyValidation = null,
                            };

                            ctx.ReportProgress($"Training Minibatches Loss:{loss_val} / Accuracy:{accuracy_val}", (i + 1) / (decimal)numMinibatches);

                            ctx.Progresses.Enqueue(ep);

                            if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                            {
                                using (HeavyProfiler.LogNoStackTrace("EvalValidation"))
                                {
                                    var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                                    var inputValValue  = CreateNDArray(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn);
                                    var outputValValue = CreateNDArray(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn);

                                    (loss_val, accuracy_val) = sess.run((loss, accuracy),
                                                                        (inputPlaceholder, inputValValue),
                                                                        (outputPlaceholder, outputValValue));

                                    ep.LossValidation     = loss_val;
                                    ep.AccuracyValidation = accuracy_val;
                                }
                            }

                            var progress = ep.SaveEntity(ctx.Predictor);

                            if (isLast || ctx.StopTraining)
                            {
                                // Each candidate gets its own directory so the best one can be picked afterwards.
                                var modelDirectory = TrainingModelDirectory(ctx.Predictor, i);
                                Directory.CreateDirectory(modelDirectory);
                                saver.save(sess, Path.Combine(modelDirectory, ModelFileName));

                                using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                                {
                                    candidate.Add(new FinalCandidate
                                    {
                                        ModelIndex     = i,
                                        ResultTraining = new PredictorMetricsEmbedded {
                                            Accuracy = progress.AccuracyTraining, Loss = progress.LossTraining
                                        },
                                        ResultValidation = new PredictorMetricsEmbedded {
                                            Accuracy = progress.AccuracyValidation, Loss = progress.LossValidation
                                        },
                                    });
                                }
                            }
                        }

                        if (ctx.StopTraining)
                        {
                            break;
                        }
                    }
                }
            }

            // Without candidates (e.g. zero minibatches) there is no model to select; fail with a clear message.
            if (candidate.Count == 0)
            {
                throw new InvalidOperationException("Training finished without producing any candidate model.");
            }

            var best = candidate.WithMin(a => a.ResultValidation.Loss !.Value);

            p.ResultTraining   = best.ResultTraining;
            p.ResultValidation = best.ResultValidation;

            var files = Directory.GetFiles(TrainingModelDirectory(ctx.Predictor, best.ModelIndex));

            p.Files.AddRange(files.Select(path => new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, path)));

            using (OperationLogic.AllowSave <PredictorEntity>())
            {
                p.Save();
            }
        }
Beispiel #6
0
    /// <summary>
    /// Converts every codification in <c>ctx.Codifications</c> into a <c>PredictorCodificationEntity</c>
    /// linked to the context's predictor and bulk-inserts them.
    /// </summary>
    /// <remarks>
    /// String representations (<c>IsValue</c>, <c>SplitKey0..2</c>) are cut to the size declared for the
    /// corresponding field in the current schema.
    /// </remarks>
    /// <param name="ctx">Training context holding the predictor, its codifications and its sub queries.</param>
    public static void CreatePredictorCodifications(PredictorTrainingContext ctx)
    {
        // Maximum lengths taken from the schema definition of each string field.
        var isValueLimit   = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.IsValue)).Size !.Value;
        var splitKey0Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey0)).Size !.Value;
        var splitKey1Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey1)).Size !.Value;
        var splitKey2Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey2)).Size !.Value;

        ctx.ReportProgress($"Saving Codifications");
#pragma warning disable CS8619 // Nullability of reference types in value doesn't match target type. CSBUG

        // Renders a value as a string of at most 'limit' characters; null when token or value is missing.
        string? Stringify(QueryToken? token, object? obj, int limit)
        {
            if (token == null || obj == null)
            {
                return null;
            }

            return obj is Lite <Entity> lite
                ? lite.KeyLong().TryStart(limit)
                : FilterValueConverter.ToString(obj, token.Type)?.TryStart(limit);
        }

        ctx.Codifications.Select(codification =>
        {
            var entity = new PredictorCodificationEntity
            {
                Predictor           = ctx.Predictor.ToLite(),
                Index               = codification.Index,
                Usage               = codification.Column.Usage,
                OriginalColumnIndex = codification.Column.PredictorColumnIndex,
                IsValue             = Stringify(codification.Column.Token, codification.IsValue, isValueLimit),
                Average             = codification.Average,
                StdDev              = codification.StdDev,
                Min                 = codification.Min,
                Max                 = codification.Max,
            };

            // Only sub-query columns carry a sub-query index and split keys.
            if (codification.Column is PredictorColumnSubQuery subColumn)
            {
                // Renders the split key at 'index' (null when the sub query has no such split column or key).
                string? SplitKeyAt(int index, int limit)
                {
                    var splitToken = ctx.SubQueries[subColumn.SubQuery].SplitBy?.ElementAtOrDefault(index)?.Column.Token;
                    var splitValue = subColumn.Keys?.ElementAtOrDefault(index);
                    return Stringify(splitToken, splitValue, limit);
                }

                entity.SubQueryIndex = ctx.Predictor.SubQueries.IndexOf(subColumn.SubQuery);
                entity.SplitKey0     = SplitKeyAt(0, splitKey0Limit);
                entity.SplitKey1     = SplitKeyAt(1, splitKey1Limit);
                entity.SplitKey2     = SplitKeyAt(2, splitKey2Limit);
            }

            return entity;
        }).BulkInsertQueryIds(a => new { a.Index, a.Usage }, a => a.Predictor.Is(ctx.Predictor.ToLite()));
#pragma warning restore CS8619 // Nullability of reference types in value doesn't match target type.
    }
Beispiel #7
0
    /// <summary>
    /// Replaces the stored <c>PredictSimpleResultEntity</c> rows of the context's predictor: deletes the old
    /// ones in chunks, predicts every row of the training data in batches, stores the resulting
    /// regression or classification statistics on the predictor and, when <c>SaveAllResults</c> is set,
    /// bulk-inserts the individual results.
    /// </summary>
    /// <param name="ctx">Training context providing the predictor, its data and progress reporting.</param>
    public void SavePredictions(PredictorTrainingContext ctx)
    {
        using (HeavyProfiler.Log("SavePredictions"))
        {
            var p             = ctx.Predictor.ToLite();
            var outputColumn  = AssertOnlyOutput(ctx.Predictor);
            var isCategorical = outputColumn.Encoding.Is(DefaultColumnEncodings.OneHot);

            // For grouped main queries the non-aggregate columns act as keys (only the first three are stored).
            var keys = !ctx.Predictor.MainQuery.GroupResults ? null : ctx.Predictor.MainQuery.Columns.Where(c => !(c.Token.Token is AggregateToken)).ToList();
            var key0 = keys?.ElementAtOrDefault(0);
            var key1 = keys?.ElementAtOrDefault(1);
            var key2 = keys?.ElementAtOrDefault(2);

            using (HeavyProfiler.Log("Delete Old Predictions"))
            {
                ctx.ReportProgress($"Deleting old {typeof(PredictSimpleResultEntity).NicePluralName()}");
                {
                    var query      = Database.Query <PredictSimpleResultEntity>().Where(a => a.Predictor.Is(p));
                    int chunkSize  = 5000;
                    var totalCount = query.Count();
                    var deleted    = 0;
                    while (totalCount - deleted > 0)
                    {
                        int num = query.OrderBy(a => a.Id).Take(chunkSize).UnsafeDelete();
                        if (num == 0)
                        {
                            // Nothing left to delete although the initial count was not reached
                            // (e.g. rows removed concurrently): stop instead of looping forever.
                            break;
                        }

                        deleted += num;
                        ctx.ReportProgress($"Deleting old {typeof(PredictSimpleResultEntity).NicePluralName()}", deleted / (decimal)totalCount);
                    }
                }
            }

            using (HeavyProfiler.Log("SavePredictions"))
            {
                ctx.ReportProgress($"Creating {typeof(PredictSimpleResultEntity).NicePluralName()}");
                {
                    var dictionary = ctx.ToPredictDictionaries();
                    var toInsert   = new List <PredictSimpleResultEntity>();

                    var pc      = PredictorPredictLogic.CreatePredictContext(ctx.Predictor);
                    int grIndex = 0;
                    foreach (var gr in dictionary.Chunk(PredictionBatchSize))
                    {
                        using (HeavyProfiler.LogNoStackTrace("Group"))
                        {
                            ctx.ReportProgress($"Creating {typeof(PredictSimpleResultEntity).NicePluralName()}", (grIndex++ *PredictionBatchSize) / (decimal)dictionary.Count);

                            var inputs = gr.Select(a => a.Value).ToList();

                            var outputs = pc.Algorithm.PredictMultiple(pc, inputs);

                            using (HeavyProfiler.LogNoStackTrace("Create SimpleResults"))
                            {
                                for (int i = 0; i < inputs.Count; i++)
                                {
                                    PredictDictionary input  = inputs[i];
                                    PredictDictionary output = outputs[i];

                                    object?inValue  = input.MainQueryValues.GetOrThrow(outputColumn);
                                    object?outValue = output.MainQueryValues.GetOrThrow(outputColumn);

                                    // Categorical outputs fill the *Category members, numeric ones the *Value members.
                                    toInsert.Add(new PredictSimpleResultEntity
                                    {
                                        Predictor         = p,
                                        Target            = ctx.Predictor.MainQuery.GroupResults ? null : input.Entity,
                                        Type              = ctx.Validation.Contains(gr[i].Key) ? PredictionSet.Validation : PredictionSet.Training,
                                        Key0              = key0 == null ? null : input.MainQueryValues.GetOrThrow(key0)?.ToString(),
                                        Key1              = key1 == null ? null : input.MainQueryValues.GetOrThrow(key1)?.ToString(),
                                        Key2              = key2 == null ? null : input.MainQueryValues.GetOrThrow(key2)?.ToString(),
                                        OriginalValue     = isCategorical ? null : ReflectionTools.ChangeType <double?>(inValue),
                                        OriginalCategory  = isCategorical ? inValue?.ToString() : null,
                                        PredictedValue    = isCategorical ? null : ReflectionTools.ChangeType <double?>(outValue),
                                        PredictedCategory = isCategorical ? outValue?.ToString() : null,
                                    });
                                }
                            }
                        }
                    }

                    // Exactly one family of statistics is filled; the other one is cleared.
                    ctx.Predictor.RegressionTraining       = isCategorical ? null : GetRegressionStats(toInsert.Where(a => a.Type == PredictionSet.Training).ToList());
                    ctx.Predictor.RegressionValidation     = isCategorical ? null : GetRegressionStats(toInsert.Where(a => a.Type == PredictionSet.Validation).ToList());
                    ctx.Predictor.ClassificationTraining   = !isCategorical ? null : GetClassificationStats(toInsert.Where(a => a.Type == PredictionSet.Training).ToList());
                    ctx.Predictor.ClassificationValidation = !isCategorical ? null : GetClassificationStats(toInsert.Where(a => a.Type == PredictionSet.Validation).ToList());

                    using (OperationLogic.AllowSave <PredictorEntity>())
                    {
                        ctx.Predictor.Save();
                    }

                    if (SaveAllResults)
                    {
                        var groups = toInsert.Chunk(PredictionBatchSize).ToList();
                        foreach (var iter in groups.Iterate())
                        {
                            ctx.ReportProgress($"Inserting {typeof(PredictSimpleResultEntity).NicePluralName()}", iter.Position / (decimal)groups.Count);
                            iter.Value.BulkInsert();
                        }
                    }
                }
            }
        }
    }
Beispiel #8
0
 /// <summary>
 /// Stores the given cancellation source and training context in this state object.
 /// </summary>
 /// <param name="cancellationTokenSource">Value assigned to <c>CancellationTokenSource</c>.</param>
 /// <param name="context">Value assigned to <c>Context</c>.</param>
 public PredictorTrainingState(CancellationTokenSource cancellationTokenSource, PredictorTrainingContext context)
 {
     (CancellationTokenSource, Context) = (cancellationTokenSource, context);
 }
        //Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
        /// <summary>
        /// Builds the dense network described by the predictor's <c>NeuralNetworkSettingsEntity</c>, trains it
        /// minibatch by minibatch with a CNTK <c>Trainer</c>, and stores the best model on the predictor as
        /// the file "Model.cntk".
        /// </summary>
        /// <remarks>
        /// <para>
        /// An <c>EpochProgress</c> is enqueued for every minibatch. It is persisted on the last
        /// <c>BestResultFromLast</c> iterations, every <c>SaveProgressEvery</c> iterations, and when
        /// <c>ctx.StopTraining</c> is set; validation metrics are added on the last iterations, every
        /// <c>SaveValidationProgressEvery</c> iterations, and on stop.
        /// </para>
        /// <para>
        /// A serialized model is kept for each of the last iterations (or on stop); the one with the lowest
        /// validation loss wins, and the predictor is saved with its metrics and model file.
        /// </para>
        /// </remarks>
        /// <param name="ctx">Training context: predictor, codifications, data and progress reporting.</param>
        /// <exception cref="InvalidOperationException">Training ended without producing any candidate model.</exception>
        public void Train(PredictorTrainingContext ctx)
        {
            InitialSetup();
            var p = ctx.Predictor;

            var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

            DeviceDescriptor device         = GetDevice(nn);
            Variable         inputVariable  = Variable.InputVariable(new[] { ctx.InputCodifications.Count }, DataType.Float, "input");
            Variable         outputVariable = Variable.InputVariable(new[] { ctx.OutputCodifications.Count }, DataType.Float, "output");

            // Chain the hidden layers, each one consuming the previous layer's output.
            Variable currentVar = inputVariable;

            nn.HiddenLayers.ForEach((layer, i) =>
            {
                currentVar = NetworkBuilder.DenseLayer(currentVar, layer.Size, device, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
            });
            Function calculatedOutputs = NetworkBuilder.DenseLayer(currentVar, ctx.OutputCodifications.Count, device, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");

            Function loss      = NetworkBuilder.GetEvalFunction(nn.LossFunction, calculatedOutputs, outputVariable);
            Function evalError = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, calculatedOutputs, outputVariable);

            // prepare for training
            Learner learner = NetworkBuilder.GetInitializer(calculatedOutputs.Parameters(), nn);

            Trainer trainer = Trainer.CreateTrainer(calculatedOutputs, loss, evalError, new List <Learner> { learner });

            // A configured seed makes the split and the minibatch sampling use a seeded Random.
            Random rand = p.Settings.Seed == null ?
                          new Random() :
                          new Random(p.Settings.Seed.Value);

            var (training, validation) = ctx.SplitTrainValidation(rand);

            var minibatchSize  = nn.MinibatchSize;
            var numMinibatches = nn.NumMinibatches;

            Stopwatch             sw        = Stopwatch.StartNew();
            List <FinalCandidate> candidate = new List <FinalCandidate>();

            for (int i = 0; i < numMinibatches; i++)
            {
                using (HeavyProfiler.Log("MiniBatch", () => i.ToString()))
                {
                    ctx.ReportProgress("Training Minibatches", (i + 1) / (decimal)numMinibatches);

                    {
                        // presumably NextElement picks one random row per call — TODO confirm
                        var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();
                        using (Value inputValue = CreateValue(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                        using (Value outputValue = CreateValue(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                        using (HeavyProfiler.Log("TrainMinibatch", () => i.ToString()))
                        {
                            trainer.TrainMinibatch(new Dictionary <Variable, Value>()
                            {
                                { inputVariable, inputValue },
                                { outputVariable, outputValue },
                            }, false, device);
                        }
                    }

                    var ep = new EpochProgress
                    {
                        Ellapsed             = sw.ElapsedMilliseconds,
                        Epoch                = i,
                        TrainingExamples     = (int)trainer.TotalNumberOfSamplesSeen(),
                        LossTraining         = trainer.PreviousMinibatchLossAverage(),
                        EvaluationTraining   = trainer.PreviousMinibatchEvaluationAverage(),
                        LossValidation       = null,
                        EvaluationValidation = null,
                    };

                    // Enqueued before the validation values are filled in below.
                    ctx.Progresses.Enqueue(ep);

                    // On a stop request continue with a freshly retrieved predictor instance.
                    if (ctx.StopTraining)
                    {
                        p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
                    }

                    var isLast = numMinibatches - nn.BestResultFromLast <= i;
                    if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
                    {
                        if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                        {
                            using (HeavyProfiler.LogNoStackTrace("Validation"))
                            {
                                var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                                using (Value inputValValue = CreateValue(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                                using (Value outputValValue = CreateValue(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                                {
                                    var inputs = new Dictionary <Variable, Value>()
                                    {
                                        { inputVariable, inputValValue },
                                        { outputVariable, outputValValue },
                                    };

                                    ep.LossValidation       = loss.EvaluateAvg(inputs, device);
                                    ep.EvaluationValidation = evalError.EvaluateAvg(inputs, device);
                                }
                            }
                        }

                        var progress = ep.SaveEntity(ctx.Predictor);

                        if (isLast || ctx.StopTraining)
                        {
                            using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                            {
                                // Keep a serialized snapshot of the model so the best one can be picked afterwards.
                                candidate.Add(new FinalCandidate
                                {
                                    Model = calculatedOutputs.Save(),

                                    ResultTraining = new PredictorMetricsEmbedded {
                                        Evaluation = progress.EvaluationTraining, Loss = progress.LossTraining
                                    },
                                    ResultValidation = new PredictorMetricsEmbedded {
                                        Evaluation = progress.EvaluationValidation, Loss = progress.LossValidation
                                    },
                                });
                            }
                        }
                    }

                    if (ctx.StopTraining)
                    {
                        break;
                    }
                }
            }

            // Without candidates (e.g. zero minibatches) there is no model to select; fail with a clear message.
            if (candidate.Count == 0)
            {
                throw new InvalidOperationException("Training finished without producing any candidate model.");
            }

            var best = candidate.WithMin(a => a.ResultValidation.Loss !.Value);

            p.ResultTraining   = best.ResultTraining;
            p.ResultValidation = best.ResultValidation;

            var fp = new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, "Model.cntk", best.Model);

            p.Files.Add(fp);

            using (OperationLogic.AllowSave <PredictorEntity>())
            {
                p.Save();
            }
        }
Beispiel #10
0
    /// <summary>
    /// Executes the predictor's main query and each of its sub-queries, stores the results on
    /// <paramref name="ctx"/>, and generates the codifications for every main-query column and
    /// for every (split key, value column) combination of every sub-query.
    /// </summary>
    /// <param name="ctx">
    /// Training context. Its <c>Predictor</c> is read; its <c>MainQuery</c> and <c>SubQueries</c>
    /// are assigned, progress is reported through it, and the resulting codifications are
    /// handed to <c>SetCodifications</c>.
    /// </param>
    public static void RetrieveData(PredictorTrainingContext ctx)
    {
        using (HeavyProfiler.Log("RetrieveData"))
        {
            // --- Main query -------------------------------------------------------------
            ctx.ReportProgress($"Executing MainQuery for {ctx.Predictor}");
            QueryRequest request = GetMainQueryRequest(ctx.Predictor.MainQuery);
            ResultTable mainTable = QueryLogic.Queries.ExecuteQuery(request);

            ctx.MainQuery = new MainQuery
            {
                QueryRequest = request,
                ResultTable = mainTable,
            };

            if (request.GroupResults)
            {
                // Grouped results: the parent key is made of the values of the non-aggregate columns.
                var keyColumns = mainTable.Columns.Where(c => !(c.Column.Token is AggregateToken)).ToArray();
                ctx.MainQuery.GetParentKey = (ResultRow row) => row.GetValues(keyColumns);
            }
            else
            {
                // Non-grouped results: the parent key is just the row's entity.
                ctx.MainQuery.GetParentKey = (ResultRow row) => new object[] { row.Entity };
            }

            var algorithm = PredictorLogic.Algorithms.GetOrThrow(ctx.Predictor.Algorithm);

            // --- Sub-queries ------------------------------------------------------------
            ctx.SubQueries = new Dictionary<PredictorSubQueryEntity, SubQuery>();
            foreach (var subQueryEntity in ctx.Predictor.SubQueries)
            {
                ctx.ReportProgress($"Executing SubQuery {subQueryEntity}");
                QueryRequest groupRequest = ToMultiColumnQuery(ctx.Predictor.MainQuery, subQueryEntity);
                ResultTable groupTable = QueryLogic.Queries.ExecuteQuery(groupRequest);

                // Pair each result column with the sub-query column definition at the same position.
                var remaining = groupTable.Columns
                    .Zip(subQueryEntity.Columns, (column, definition) => (column, definition))
                    .ToList();

                // NOTE(review): Extract presumably removes the matched pairs from the list, so that
                // whatever is left afterwards are the value columns — keep these three statements in order.
                var parentKeyColumns = remaining
                    .Extract(p => p.definition.Usage == PredictorSubQueryColumnUsage.ParentKey)
                    .Select(p => p.column)
                    .ToArray();
                var splitByColumns = remaining
                    .Extract(p => p.definition.Usage == PredictorSubQueryColumnUsage.SplitBy)
                    .Select(p => p.column)
                    .ToArray();
                var valueColumns = remaining.Select(p => p.column).ToArray();

                // Parent key -> (split key -> values of the value columns).
                var groupedValues = groupTable.Rows.AgGroupToDictionary(
                    row => row.GetValues(parentKeyColumns),
                    rows => rows.ToDictionaryEx(
                        row => row.GetValues(splitByColumns),
                        row => row.GetValues(valueColumns),
                        ObjectArrayComparer.Instance));

                var executedSubQuery = new SubQuery
                {
                    SubQueryEntity = subQueryEntity,
                    QueryGroupRequest = groupRequest,
                    ResultTable = groupTable,
                    GroupedValues = groupedValues,
                    SplitBy = splitByColumns,
                    ValueColumns = valueColumns,
                    // Maps a result column's Index to its position inside valueColumns.
                    ColumnIndexToValueIndex = valueColumns
                        .Select((column, position) => KeyValuePair.Create(column.Index, position))
                        .ToDictionary(),
                };

                ctx.SubQueries.Add(subQueryEntity, executedSubQuery);
            }

            // --- Codifications ----------------------------------------------------------
            ctx.ReportProgress($"Creating Columns");
            var codifications = new List<PredictorCodification>();

            using (HeavyProfiler.Log("MainQuery"))
            {
                for (int index = 0; index < mainTable.Columns.Length; index++)
                {
                    var mainQueryColumn = ctx.Predictor.MainQuery.Columns[index];
                    using (HeavyProfiler.Log("Columns", () => mainQueryColumn.Token.Token.ToString()))
                    {
                        var mainColumn = new PredictorColumnMain(mainQueryColumn, index);
                        codifications.AddRange(
                            algorithm.GenerateCodifications(mainQueryColumn.Encoding, mainTable.Columns[index], mainColumn));
                    }
                }
            }

            var keyComparer = ObjectArrayComparer.Instance;
            foreach (var subQueryData in ctx.SubQueries.Values)
            {
                using (HeavyProfiler.Log("SubQuery", () => subQueryData.ToString()!))
                {
                    // Every distinct split key seen under any parent key, in sorted order.
                    var sortedSplitKeys = subQueryData.GroupedValues
                        .SelectMany(byParent => byParent.Value.Keys)
                        .Distinct(keyComparer)
                        .ToList();
                    sortedSplitKeys.Sort(keyComparer);

                    foreach (var splitKey in sortedSplitKeys)
                    {
                        using (HeavyProfiler.Log("Keys", () => splitKey.ToString(part => part?.ToString(), ", ")))
                        {
                            foreach (var valueColumn in subQueryData.ValueColumns)
                            {
                                var columnEntity = subQueryData.SubQueryEntity.Columns[valueColumn.Index];
                                using (HeavyProfiler.Log("Columns", () => columnEntity.Token.Token.ToString()))
                                {
                                    var subQueryColumn = new PredictorColumnSubQuery(
                                        columnEntity, valueColumn.Index, subQueryData.SubQueryEntity, splitKey);
                                    codifications.AddRange(
                                        algorithm.GenerateCodifications(columnEntity.Encoding, valueColumn, subQueryColumn));
                                }
                            }
                        }
                    }
                }
            }

            ctx.SetCodifications(codifications.ToArray());
        }
    }