Пример #1
0
        /// <summary>
        /// Registers <paramref name="p"/> in <c>Trainings</c> and launches its training on a
        /// background task without flowing the caller's execution context.
        /// </summary>
        /// <param name="p">Predictor to train. Its <c>User</c> is retrieved inside the task and used as the user session.</param>
        /// <exception cref="InvalidOperationException">The predictor is already registered in <c>Trainings</c>.</exception>
        static void StartTrainingAsync(PredictorEntity p)
        {
            var cancellationSource = new CancellationTokenSource();

            var ctx = new PredictorTrainingContext(p, cancellationSource.Token);

            var state = new PredictorTrainingState(cancellationSource, ctx);

            if (!Trainings.TryAdd(p.ToLite(), state))
            {
                // The state was never published, so nobody else can reference the source.
                cancellationSource.Dispose();
                throw new InvalidOperationException(PredictorMessage._0IsAlreadyBeingTrained.NiceToString(p));
            }

            using (ExecutionContext.SuppressFlow())
            {
                Task.Run(() =>
                {
                    // The try/finally wraps the user retrieval as well: if that step throws,
                    // the predictor must still be unregistered, otherwise it would stay
                    // "already being trained" until the process restarts.
                    try
                    {
                        var user = ExecutionMode.Global().Using(_ => p.User!.RetrieveAndRemember());
                        using (UserHolder.UserSession(user))
                        {
                            DoTraining(ctx);
                        }
                    }
                    finally
                    {
                        Trainings.TryRemove(p.ToLite(), out var _);
                    }
                });
            }
        }
Пример #2
0
        /// <summary>
        /// Projects every codification in <paramref name="ctx"/> to a
        /// <see cref="PredictorCodificationEntity"/> and hands the sequence to
        /// <c>BulkInsertQueryIds</c>.
        /// </summary>
        /// <param name="ctx">Training context supplying the predictor, its codifications and its sub-queries.</param>
        public static void CreatePredictorCodifications(PredictorTrainingContext ctx)
        {
            // Schema sizes of the string fields; used below as truncation limits.
            var isValueLimit   = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.IsValue)).Size.Value;
            var splitKey0Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey0)).Size.Value;
            var splitKey1Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey1)).Size.Value;
            var splitKey2Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey2)).Size.Value;

            ctx.ReportProgress($"Saving Codifications");

            // Renders a value as text cut to 'limit' characters; null when either the token or the value is missing.
            string? ToStringValue(QueryToken? token, object? obj, int limit)
            {
                if (token == null || obj == null)
                {
                    return null;
                }

                return obj is Lite<Entity> lite
                    ? lite.KeyLong().TryStart(limit)
                    : FilterValueConverter.ToString(obj, token.Type)?.TryStart(limit);
            }

            var entities = ctx.Codifications.Select(pc =>
            {
                var entity = new PredictorCodificationEntity
                {
                    Predictor           = ctx.Predictor.ToLite(),
                    Index               = pc.Index,
                    Usage               = pc.Column.Usage,
                    OriginalColumnIndex = pc.Column.PredictorColumnIndex,
                    IsValue             = ToStringValue(pc.Column.Token, pc.IsValue, isValueLimit),
                    Average             = pc.Average,
                    StdDev              = pc.StdDev,
                    Min                 = pc.Min,
                    Max                 = pc.Max,
                };

                if (!(pc.Column is PredictorColumnSubQuery subQueryColumn))
                {
                    return entity;
                }

                // Sub-query columns additionally record up to three split keys.
                var splitBy = ctx.SubQueries[subQueryColumn.SubQuery].SplitBy;

                string? SplitKeyAt(int index, int limit)
                {
                    var keyToken = splitBy?.ElementAtOrDefault(index)?.Column.Token;
                    var keyValue = subQueryColumn.Keys?.ElementAtOrDefault(index);
                    return ToStringValue(keyToken, keyValue, limit);
                }

                entity.SubQueryIndex = ctx.Predictor.SubQueries.IndexOf(subQueryColumn.SubQuery);
                entity.SplitKey0     = SplitKeyAt(0, splitKey0Limit);
                entity.SplitKey1     = SplitKeyAt(1, splitKey1Limit);
                entity.SplitKey2     = SplitKeyAt(2, splitKey2Limit);

                return entity;
            });

            entities.BulkInsertQueryIds(a => new { a.Index, a.Usage }, a => a.Predictor == ctx.Predictor.ToLite());
        }
Пример #3
0
        /// <summary>
        /// Runs the full training pipeline for <c>ctx.Predictor</c>: optional result-saver
        /// validation, data retrieval, codification creation, algorithm training and optional
        /// prediction saving. The predictor ends up in state <c>Trained</c>, in <c>Draft</c> when
        /// an <see cref="OperationCanceledException"/> is caught, or in <c>Error</c> for any other exception.
        /// </summary>
        /// <param name="ctx">Training context carrying the predictor being trained.</param>
        static void DoTraining(PredictorTrainingContext ctx)
        {
            // Persists a predictor while saving PredictorEntity is temporarily allowed.
            void SaveAllowed(PredictorEntity entity)
            {
                using (OperationLogic.AllowSave<PredictorEntity>())
                {
                    entity.Save();
                }
            }

            using (HeavyProfiler.Log("DoTraining"))
            {
                try
                {
                    // Validate the result saver up front, before any query is executed.
                    if (ctx.Predictor.ResultSaver != null)
                    {
                        ResultSavers.GetOrThrow(ctx.Predictor.ResultSaver).AssertValid(ctx.Predictor);
                    }

                    PredictorLogicQuery.RetrieveData(ctx);
                    PredictorCodificationLogic.CreatePredictorCodifications(ctx);

                    var trainer = Algorithms.GetOrThrow(ctx.Predictor.Algorithm);
                    using (HeavyProfiler.Log("Train"))
                    {
                        trainer.Train(ctx);
                    }

                    if (ctx.Predictor.ResultSaver != null)
                    {
                        using (HeavyProfiler.Log("ResultSaver"))
                        {
                            ResultSavers.GetOrThrow(ctx.Predictor.ResultSaver).SavePredictions(ctx);
                        }
                    }

                    ctx.Predictor.State = PredictorState.Trained;
                    SaveAllowed(ctx.Predictor);
                }
                catch (OperationCanceledException)
                {
                    // Cancelled: reload the predictor, clean it and put it back to Draft.
                    var reloaded = ctx.Predictor.ToLite().RetrieveAndForget();
                    CleanTrained(reloaded);
                    reloaded.State = PredictorState.Draft;
                    SaveAllowed(reloaded);
                }
                catch (Exception ex)
                {
                    // Failed: log the exception and link it from a reloaded predictor.
                    ex.Data["entity"] = ctx.Predictor;
                    var logged = ex.LogException();
                    var reloaded = ctx.Predictor.ToLite().RetrieveAndForget();
                    reloaded.State             = PredictorState.Error;
                    reloaded.TrainingException = logged.ToLite();
                    SaveAllowed(reloaded);
                }
            }
        }
Пример #4
0
        /// <summary>
        /// Trains <paramref name="p"/> synchronously on the calling thread.
        /// </summary>
        /// <param name="p">Predictor to train. Its <c>User</c> and <c>State</c> are overwritten and saved before training starts.</param>
        /// <param name="autoReset">
        /// When true, a predictor in state <c>Trained</c> or <c>Error</c> is untrained first, and one in
        /// state <c>Training</c> has its training cancelled first.
        /// </param>
        /// <param name="onReportProgres">
        /// Progress callback. When null, progress is written to the console instead.
        /// </param>
        /// <param name="cancellationToken">Token used for the training context; when null, <see cref="CancellationToken.None"/> is used.</param>
        public static void TrainSync(this PredictorEntity p, bool autoReset = true, Action<string, decimal?>? onReportProgres = null, CancellationToken? cancellationToken = null)
        {
            if (autoReset)
            {
                if (p.State == PredictorState.Trained || p.State == PredictorState.Error)
                {
                    p.Execute(PredictorOperation.Untrain);
                }
                else if (p.State == PredictorState.Training)
                {
                    p.Execute(PredictorOperation.CancelTraining);
                }
            }

            p.User  = UserHolder.Current.ToLite();
            p.State = PredictorState.Training;
            p.Save();

            // No caller token means "not cancellable"; CancellationToken.None expresses that
            // without allocating a CancellationTokenSource that nobody would ever dispose.
            var ctx = new PredictorTrainingContext(p, cancellationToken ?? CancellationToken.None);

            if (onReportProgres != null)
            {
                ctx.OnReportProgres += onReportProgres;
            }
            else
            {
                // Tracks whether the console cursor is still on a same-line progress report.
                var lastWithProgress = false;

                ctx.OnReportProgres += (message, progress) =>
                {
                    if (progress == null)
                    {
                        if (lastWithProgress)
                        {
                            // Finish the pending progress line once, then clear the flag so that
                            // later plain messages are not each preceded by a blank line.
                            Console.WriteLine();
                            lastWithProgress = false;
                        }

                        Console.WriteLine(message);
                    }
                    else
                    {
                        SafeConsole.WriteSameLine($"{progress:P} - {message}");
                        lastWithProgress = true;
                    }
                };
            }

            DoTraining(ctx);
        }
Пример #5
0
        /// <summary>
        /// Executes the predictor's main query and each of its sub-queries, stores the results in
        /// <paramref name="ctx"/> (<c>MainQuery</c>, <c>SubQueries</c>) and then asks the predictor's
        /// algorithm to generate the codifications for every main-query column and for every
        /// (split key, value column) combination of each sub-query.
        /// </summary>
        /// <param name="ctx">Training context that supplies the predictor and receives the results.</param>
        public static void RetrieveData(PredictorTrainingContext ctx)
        {
            using (HeavyProfiler.Log("RetrieveData"))
            {
                ctx.ReportProgress($"Executing MainQuery for {ctx.Predictor}");
                QueryRequest mainQueryRequest = GetMainQueryRequest(ctx.Predictor.MainQuery);
                ResultTable  mainResult       = QueryLogic.Queries.ExecuteQuery(mainQueryRequest);

                ctx.MainQuery = new MainQuery
                {
                    QueryRequest = mainQueryRequest,
                    ResultTable  = mainResult,
                };

                // The parent key identifies a main-query row: the row entity when the query is not
                // grouped, otherwise the values of all non-aggregate columns.
                if (!mainQueryRequest.GroupResults)
                {
                    ctx.MainQuery.GetParentKey = (ResultRow row) => new object[] { row.Entity };
                }
                else
                {
                    var rcs = mainResult.Columns.Where(a => !(a.Column.Token is AggregateToken)).ToArray();
                    ctx.MainQuery.GetParentKey = (ResultRow row) => row.GetValues(rcs);
                }

                var algorithm = PredictorLogic.Algorithms.GetOrThrow(ctx.Predictor.Algorithm);

                ctx.SubQueries = new Dictionary <PredictorSubQueryEntity, SubQuery>();
                foreach (var sqe in ctx.Predictor.SubQueries)
                {
                    ctx.ReportProgress($"Executing SubQuery {sqe}");
                    QueryRequest queryGroupRequest = ToMultiColumnQuery(ctx.Predictor.MainQuery, sqe);
                    ResultTable  groupResult       = QueryLogic.Queries.ExecuteQuery(queryGroupRequest);

                    // Pair each result column with the sub-query column definition at the same position.
                    var pairs = groupResult.Columns.Zip(sqe.Columns, (rc, sqc) => (rc, sqc)).ToList();

                    // NOTE(review): Extract presumably removes the matching items from 'pairs' and
                    // returns them, so that 'values' below is what remains after both calls — confirm.
                    // If so, the order of these three statements matters.
                    var parentKeys = pairs.Extract(a => a.sqc.Usage == PredictorSubQueryColumnUsage.ParentKey).Select(a => a.rc).ToArray();
                    var splitKeys  = pairs.Extract(a => a.sqc.Usage == PredictorSubQueryColumnUsage.SplitBy).Select(a => a.rc).ToArray();
                    var values     = pairs.Select(a => a.rc).ToArray();

                    // Two-level lookup: parent key values -> split key values -> value-column values.
                    var groupedValues = groupResult.Rows.AgGroupToDictionary(
                        row => row.GetValues(parentKeys),
                        gr => gr.ToDictionaryEx(
                            row => row.GetValues(splitKeys),
                            row => row.GetValues(values),
                            ObjectArrayComparer.Instance));

                    ctx.SubQueries.Add(sqe, new SubQuery
                    {
                        SubQueryEntity          = sqe,
                        QueryGroupRequest       = queryGroupRequest,
                        ResultTable             = groupResult,
                        GroupedValues           = groupedValues,
                        SplitBy                 = splitKeys,
                        ValueColumns            = values,
                        // Maps a result column's Index to its position inside 'values'.
                        ColumnIndexToValueIndex = values.Select((r, i) => KeyValuePair.Create(r.Index, i)).ToDictionary()
                    });
                }

                ctx.ReportProgress($"Creating Columns");
                var codifications = new List <PredictorCodification>();

                // Codifications for the main query: one GenerateCodifications call per result column,
                // using the predictor column definition at the same position.
                using (HeavyProfiler.Log("MainQuery"))
                {
                    for (int i = 0; i < mainResult.Columns.Length; i++)
                    {
                        var col = ctx.Predictor.MainQuery.Columns[i];
                        using (HeavyProfiler.Log("Columns", () => col.Token.Token.ToString()))
                        {
                            var mainCol           = new PredictorColumnMain(col, i);
                            var mainCodifications = algorithm.GenerateCodifications(col.Encoding, mainResult.Columns[i], mainCol);
                            codifications.AddRange(mainCodifications);
                        }
                    }
                }

                // Codifications for the sub-queries: one GenerateCodifications call per
                // (distinct split key, value column) combination.
                foreach (var sq in ctx.SubQueries.Values)
                {
                    using (HeavyProfiler.Log("SubQuery", () => sq.ToString() !))
                    {
                        // Collect the split keys seen under any parent key, de-duplicated and then sorted.
                        var distinctKeys = sq.GroupedValues.SelectMany(a => a.Value.Keys).Distinct(ObjectArrayComparer.Instance).ToList();

                        distinctKeys.Sort(ObjectArrayComparer.Instance);

                        foreach (var ks in distinctKeys)
                        {
                            using (HeavyProfiler.Log("Keys", () => ks.ToString(k => k?.ToString(), ", ")))
                            {
                                foreach (var vc in sq.ValueColumns)
                                {
                                    // The result column's Index is used to look up the matching sub-query column definition.
                                    var col = sq.SubQueryEntity.Columns[vc.Index];
                                    using (HeavyProfiler.Log("Columns", () => col.Token.Token.ToString()))
                                    {
                                        var subCol = new PredictorColumnSubQuery(col, vc.Index, sq.SubQueryEntity, ks);
                                        var subQueryCodifications = algorithm.GenerateCodifications(col.Encoding, vc, subCol);
                                        codifications.AddRange(subQueryCodifications);
                                    }
                                }
                            }
                        }
                    }
                }

                ctx.SetCodifications(codifications.ToArray());
            }
        }
Пример #6
0
 /// <summary>
 /// Creates a training state from the given cancellation source and training context.
 /// </summary>
 /// <param name="cancellationTokenSource">Stored in <c>CancellationTokenSource</c>.</param>
 /// <param name="context">Stored in <c>Context</c>.</param>
 public PredictorTrainingState(CancellationTokenSource cancellationTokenSource, PredictorTrainingContext context)
 {
     this.CancellationTokenSource = cancellationTokenSource;
     this.Context = context;
 }
Пример #7
0
        /// <summary>
        /// Replaces the stored <see cref="PredictSimpleResultEntity"/> rows of <c>ctx.Predictor</c>:
        /// deletes the old ones in chunks, predicts every entry of the training context in batches,
        /// stores the regression/classification statistics on the predictor and, when
        /// <c>SaveAllResults</c> is set, bulk-inserts the individual results.
        /// </summary>
        /// <param name="ctx">Training context supplying the predictor, the data and the validation set.</param>
        public void SavePredictions(PredictorTrainingContext ctx)
        {
            using (HeavyProfiler.Log("SavePredictions"))
            {
                var p             = ctx.Predictor.ToLite();
                var outputColumn  = AssertOnlyOutput(ctx.Predictor);
                var isCategorical = outputColumn.Encoding.Is(DefaultColumnEncodings.OneHot);

                // For grouped main queries the first three non-aggregate columns are stored as Key0..Key2.
                var keys = !ctx.Predictor.MainQuery.GroupResults ? null : ctx.Predictor.MainQuery.Columns.Where(c => !(c.Token.Token is AggregateToken)).ToList();
                var key0 = keys?.ElementAtOrDefault(0);
                var key1 = keys?.ElementAtOrDefault(1);
                var key2 = keys?.ElementAtOrDefault(2);

                using (HeavyProfiler.Log("Delete Old Predictions"))
                {
                    ctx.ReportProgress($"Deleting old {typeof(PredictSimpleResultEntity).NicePluralName()}");

                    var query      = Database.Query<PredictSimpleResultEntity>().Where(a => a.Predictor == p);
                    int chunkSize  = 5000;
                    var totalCount = query.Count();
                    var deleted    = 0;
                    while (totalCount - deleted > 0)
                    {
                        int num = query.OrderBy(a => a.Id).Take(chunkSize).UnsafeDelete();
                        if (num == 0)
                        {
                            // Nothing was deleted although the initial count says rows remain
                            // (e.g. they were removed concurrently). Without this exit the loop
                            // would never terminate.
                            break;
                        }

                        deleted += num;
                        ctx.ReportProgress($"Deleting old {typeof(PredictSimpleResultEntity).NicePluralName()}", deleted / (decimal)totalCount);
                    }
                }

                using (HeavyProfiler.Log("SavePredictions"))
                {
                    ctx.ReportProgress($"Creating {typeof(PredictSimpleResultEntity).NicePluralName()}");

                    var dictionary = ctx.ToPredictDictionaries();
                    var toInsert   = new List<PredictSimpleResultEntity>();

                    var pc      = PredictorPredictLogic.CreatePredictContext(ctx.Predictor);
                    int grIndex = 0;
                    foreach (var gr in dictionary.GroupsOf(PredictionBatchSize))
                    {
                        using (HeavyProfiler.LogNoStackTrace("Group"))
                        {
                            ctx.ReportProgress($"Creating {typeof(PredictSimpleResultEntity).NicePluralName()}", (grIndex++ * PredictionBatchSize) / (decimal)dictionary.Count);

                            var inputs = gr.Select(a => a.Value).ToList();

                            var outputs = pc.Algorithm.PredictMultiple(pc, inputs);

                            using (HeavyProfiler.LogNoStackTrace("Create SimpleResults"))
                            {
                                // inputs[i], outputs[i] and gr[i] are read by the same position.
                                for (int i = 0; i < inputs.Count; i++)
                                {
                                    PredictDictionary input  = inputs[i];
                                    PredictDictionary output = outputs[i];

                                    object? inValue  = input.MainQueryValues.GetOrThrow(outputColumn);
                                    object? outValue = output.MainQueryValues.GetOrThrow(outputColumn);

                                    toInsert.Add(new PredictSimpleResultEntity
                                    {
                                        Predictor         = p,
                                        Target            = ctx.Predictor.MainQuery.GroupResults ? null : input.Entity,
                                        Type              = ctx.Validation.Contains(gr[i].Key) ? PredictionSet.Validation : PredictionSet.Training,
                                        Key0              = key0 == null ? null : input.MainQueryValues.GetOrThrow(key0)?.ToString(),
                                        Key1              = key1 == null ? null : input.MainQueryValues.GetOrThrow(key1)?.ToString(),
                                        Key2              = key2 == null ? null : input.MainQueryValues.GetOrThrow(key2)?.ToString(),
                                        OriginalValue     = isCategorical ? null : ReflectionTools.ChangeType<double?>(inValue),
                                        OriginalCategory  = isCategorical ? inValue?.ToString() : null,
                                        PredictedValue    = isCategorical ? null : ReflectionTools.ChangeType<double?>(outValue),
                                        PredictedCategory = isCategorical ? outValue?.ToString() : null,
                                    });
                                }
                            }
                        }
                    }

                    // Statistics are computed separately for the training and the validation partition;
                    // only the regression or the classification pair is filled, the other is cleared.
                    var training   = toInsert.Where(a => a.Type == PredictionSet.Training).ToList();
                    var validation = toInsert.Where(a => a.Type == PredictionSet.Validation).ToList();

                    ctx.Predictor.RegressionTraining       = isCategorical ? null : GetRegressionStats(training);
                    ctx.Predictor.RegressionValidation     = isCategorical ? null : GetRegressionStats(validation);
                    ctx.Predictor.ClassificationTraining   = !isCategorical ? null : GetClassificationStats(training);
                    ctx.Predictor.ClassificationValidation = !isCategorical ? null : GetClassificationStats(validation);

                    using (OperationLogic.AllowSave<PredictorEntity>())
                    {
                        ctx.Predictor.Save();
                    }

                    if (SaveAllResults)
                    {
                        var groups = toInsert.GroupsOf(PredictionBatchSize).ToList();
                        foreach (var iter in groups.Iterate())
                        {
                            ctx.ReportProgress($"Inserting {typeof(PredictSimpleResultEntity).NicePluralName()}", iter.Position / (decimal)groups.Count);
                            iter.Value.BulkInsert();
                        }
                    }
                }
            }
        }