/// <summary>
/// Registers <paramref name="p"/> in <c>Trainings</c> and launches its training on a background task.
/// </summary>
/// <param name="p">Predictor to train. Its <c>User</c> is retrieved on the background task and used as the session user while training runs.</param>
/// <exception cref="InvalidOperationException">An entry for this predictor could not be added to <c>Trainings</c> because one already exists.</exception>
static void StartTrainingAsync(PredictorEntity p)
{
    var cancellationSource = new CancellationTokenSource();
    var ctx = new PredictorTrainingContext(p, cancellationSource.Token);
    var state = new PredictorTrainingState(cancellationSource, ctx);

    // One key is used for both registration and removal.
    var key = p.ToLite();

    if (!Trainings.TryAdd(key, state))
    {
        throw new InvalidOperationException(PredictorMessage._0IsAlreadyBeingTrained.NiceToString(p));
    }

    // The background task is started with execution-context flow suppressed.
    using (ExecutionContext.SuppressFlow())
    {
        Task.Run(() =>
        {
            // The try/finally covers the user lookup and the session setup as well as the
            // training itself: if any of them throws, the registration is still removed.
            // Otherwise a failed lookup would leave the predictor registered forever and
            // every later attempt would fail with "already being trained".
            try
            {
                var user = ExecutionMode.Global().Using(_ => p.User!.RetrieveAndRemember());
                using (UserHolder.UserSession(user))
                {
                    DoTraining(ctx);
                }
            }
            finally
            {
                Trainings.TryRemove(key, out var _);
            }
        });
    }
}
/// <summary>
/// Projects every codification in <paramref name="ctx"/> into a
/// <see cref="PredictorCodificationEntity"/> and bulk-inserts the result.
/// </summary>
/// <param name="ctx">Training context providing the predictor, its codifications and its sub-queries.</param>
public static void CreatePredictorCodifications(PredictorTrainingContext ctx)
{
    // Sizes of the string fields, read from the schema; used to cut the stored text.
    var isValueLimit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.IsValue)).Size.Value;
    var splitKey0Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey0)).Size.Value;
    var splitKey1Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey1)).Size.Value;
    var splitKey2Limit = ((FieldValue)Schema.Current.Field((PredictorCodificationEntity e) => e.SplitKey2)).Size.Value;

    ctx.ReportProgress($"Saving Codifications");

    // Text form of a value for storage: null when either the token or the value is missing,
    // the lite key for lites, and the filter-value text otherwise; passed through TryStart(limit).
    string? Stringify(QueryToken? token, object? obj, int limit)
    {
        if (token == null || obj == null)
        {
            return null;
        }

        return obj is Lite<Entity> lite
            ? lite.KeyLong().TryStart(limit)
            : FilterValueConverter.ToString(obj, token.Type)?.TryStart(limit);
    }

    var entities = ctx.Codifications.Select(pc =>
    {
        var columnToken = pc.Column.Token;

        var entity = new PredictorCodificationEntity
        {
            Predictor = ctx.Predictor.ToLite(),
            Index = pc.Index,
            Usage = pc.Column.Usage,
            OriginalColumnIndex = pc.Column.PredictorColumnIndex,
            IsValue = Stringify(columnToken, pc.IsValue, isValueLimit),
            Average = pc.Average,
            StdDev = pc.StdDev,
            Min = pc.Min,
            Max = pc.Max,
        };

        // Sub-query columns additionally record their sub-query index and up to three split keys.
        if (pc.Column is PredictorColumnSubQuery subQueryColumn)
        {
            string? SplitKeyAt(int position, int limit)
            {
                var token = ctx.SubQueries[subQueryColumn.SubQuery].SplitBy?.ElementAtOrDefault(position)?.Column.Token;
                var key = subQueryColumn.Keys?.ElementAtOrDefault(position);
                return Stringify(token, key, limit);
            }

            entity.SubQueryIndex = ctx.Predictor.SubQueries.IndexOf(subQueryColumn.SubQuery);
            entity.SplitKey0 = SplitKeyAt(0, splitKey0Limit);
            entity.SplitKey1 = SplitKeyAt(1, splitKey1Limit);
            entity.SplitKey2 = SplitKeyAt(2, splitKey2Limit);
        }

        return entity;
    });

    entities.BulkInsertQueryIds(a => new { a.Index, a.Usage }, a => a.Predictor == ctx.Predictor.ToLite());
}
/// <summary>
/// Runs one full training pass for the predictor in <paramref name="ctx"/>.
/// </summary>
/// <remarks>
/// Steps, in order: validate the result saver (when one is set), retrieve the data, create the
/// codifications, train the algorithm, save predictions (when a result saver is set) and store
/// the predictor as <c>Trained</c>.
/// On <see cref="OperationCanceledException"/> the predictor is re-retrieved, cleaned with
/// <c>CleanTrained</c> and stored as <c>Draft</c>. On any other exception the exception is logged
/// and the re-retrieved predictor is stored as <c>Error</c> with a reference to the logged exception.
/// </remarks>
/// <param name="ctx">Training context holding the predictor being trained.</param>
static void DoTraining(PredictorTrainingContext ctx)
{
    // Saves the given predictor inside an OperationLogic.AllowSave scope.
    void SaveInAllowSaveScope(PredictorEntity predictor)
    {
        using (OperationLogic.AllowSave<PredictorEntity>())
        {
            predictor.Save();
        }
    }

    using (HeavyProfiler.Log("DoTraining"))
    {
        try
        {
            if (ctx.Predictor.ResultSaver != null)
            {
                ResultSavers.GetOrThrow(ctx.Predictor.ResultSaver).AssertValid(ctx.Predictor);
            }

            PredictorLogicQuery.RetrieveData(ctx);
            PredictorCodificationLogic.CreatePredictorCodifications(ctx);

            var algorithm = Algorithms.GetOrThrow(ctx.Predictor.Algorithm);
            using (HeavyProfiler.Log("Train"))
            {
                algorithm.Train(ctx);
            }

            if (ctx.Predictor.ResultSaver != null)
            {
                using (HeavyProfiler.Log("ResultSaver"))
                {
                    ResultSavers.GetOrThrow(ctx.Predictor.ResultSaver).SavePredictions(ctx);
                }
            }

            ctx.Predictor.State = PredictorState.Trained;
            SaveInAllowSaveScope(ctx.Predictor);
        }
        catch (OperationCanceledException)
        {
            // Cancellation: work on a freshly retrieved copy rather than the context's instance.
            var fresh = ctx.Predictor.ToLite().RetrieveAndForget();
            CleanTrained(fresh);
            fresh.State = PredictorState.Draft;
            SaveInAllowSaveScope(fresh);
        }
        catch (Exception ex)
        {
            // Failure: attach the predictor to the exception, log it, and link the log entry.
            ex.Data["entity"] = ctx.Predictor;
            var logged = ex.LogException();

            var fresh = ctx.Predictor.ToLite().RetrieveAndForget();
            fresh.State = PredictorState.Error;
            fresh.TrainingException = logged.ToLite();
            SaveInAllowSaveScope(fresh);
        }
    }
}
/// <summary>
/// Trains <paramref name="p"/> on the calling thread.
/// </summary>
/// <param name="p">Predictor to train. Its <c>User</c> and <c>State</c> are overwritten and saved before training starts.</param>
/// <param name="autoReset">
/// When <c>true</c>, a predictor in state <c>Trained</c> or <c>Error</c> is first untrained, and a
/// predictor in state <c>Training</c> has its training cancelled, by executing the matching operation.
/// </param>
/// <param name="onReportProgres">
/// Progress callback subscribed to the training context. When <c>null</c>, progress is written to the console.
/// </param>
/// <param name="cancellationToken">Token handed to the training context; when <c>null</c>, <see cref="CancellationToken.None"/> is used.</param>
public static void TrainSync(this PredictorEntity p, bool autoReset = true, Action<string, decimal?>? onReportProgres = null, CancellationToken? cancellationToken = null)
{
    if (autoReset)
    {
        if (p.State == PredictorState.Trained || p.State == PredictorState.Error)
        {
            p.Execute(PredictorOperation.Untrain);
        }
        else if (p.State == PredictorState.Training)
        {
            p.Execute(PredictorOperation.CancelTraining);
        }
    }

    p.User = UserHolder.Current.ToLite();
    p.State = PredictorState.Training;
    p.Save();

    // No source has to be allocated (and left undisposed) just to obtain a token nobody can cancel.
    var ctx = new PredictorTrainingContext(p, cancellationToken ?? CancellationToken.None);

    if (onReportProgres != null)
    {
        ctx.OnReportProgres += onReportProgres;
    }
    else
    {
        // True while the console cursor still sits on a line written by WriteSameLine.
        var lastWithProgress = false;

        ctx.OnReportProgres += (message, progress) =>
        {
            if (progress == null)
            {
                if (lastWithProgress)
                {
                    // Finish the pending same-line output once; without the reset every
                    // later plain message would be preceded by an extra blank line.
                    Console.WriteLine();
                    lastWithProgress = false;
                }

                Console.WriteLine(message);
            }
            else
            {
                SafeConsole.WriteSameLine($"{progress:P} - {message}");
                lastWithProgress = true;
            }
        };
    }

    DoTraining(ctx);
}
/// <summary>
/// Executes the predictor's main query and every sub-query, stores the results in
/// <paramref name="ctx"/>, and generates the codifications for all resulting columns.
/// </summary>
/// <param name="ctx">
/// Training context. Its <c>MainQuery</c> and <c>SubQueries</c> are assigned here and its
/// codifications are set at the end through <c>SetCodifications</c>.
/// </param>
public static void RetrieveData(PredictorTrainingContext ctx)
{
    using (HeavyProfiler.Log("RetrieveData"))
    {
        ctx.ReportProgress($"Executing MainQuery for {ctx.Predictor}");
        QueryRequest mainRequest = GetMainQueryRequest(ctx.Predictor.MainQuery);
        ResultTable mainTable = QueryLogic.Queries.ExecuteQuery(mainRequest);

        ctx.MainQuery = new MainQuery
        {
            QueryRequest = mainRequest,
            ResultTable = mainTable,
        };

        // The parent key of a row is its non-aggregate column values for grouped queries,
        // and the row's entity otherwise.
        if (mainRequest.GroupResults)
        {
            var keyColumns = mainTable.Columns.Where(a => !(a.Column.Token is AggregateToken)).ToArray();
            ctx.MainQuery.GetParentKey = (ResultRow row) => row.GetValues(keyColumns);
        }
        else
        {
            ctx.MainQuery.GetParentKey = (ResultRow row) => new object[] { row.Entity };
        }

        var algorithm = PredictorLogic.Algorithms.GetOrThrow(ctx.Predictor.Algorithm);

        ctx.SubQueries = new Dictionary<PredictorSubQueryEntity, SubQuery>();
        foreach (var sqe in ctx.Predictor.SubQueries)
        {
            ctx.ReportProgress($"Executing SubQuery {sqe}");
            QueryRequest groupRequest = ToMultiColumnQuery(ctx.Predictor.MainQuery, sqe);
            ResultTable groupTable = QueryLogic.Queries.ExecuteQuery(groupRequest);

            // NOTE(review): Extract presumably removes the matching pairs from the list, so the
            // order of the next three statements matters (what is left are the value columns) — confirm.
            var remaining = groupTable.Columns.Zip(sqe.Columns, (rc, sqc) => (rc, sqc)).ToList();
            var parentKeys = remaining.Extract(a => a.sqc.Usage == PredictorSubQueryColumnUsage.ParentKey).Select(a => a.rc).ToArray();
            var splitKeys = remaining.Extract(a => a.sqc.Usage == PredictorSubQueryColumnUsage.SplitBy).Select(a => a.rc).ToArray();
            var values = remaining.Select(a => a.rc).ToArray();

            // parent key -> (split key -> value-column values)
            var groupedValues = groupTable.Rows.AgGroupToDictionary(
                row => row.GetValues(parentKeys),
                gr => gr.ToDictionaryEx(
                    row => row.GetValues(splitKeys),
                    row => row.GetValues(values),
                    ObjectArrayComparer.Instance));

            var subQuery = new SubQuery
            {
                SubQueryEntity = sqe,
                QueryGroupRequest = groupRequest,
                ResultTable = groupTable,
                GroupedValues = groupedValues,
                SplitBy = splitKeys,
                ValueColumns = values,
                ColumnIndexToValueIndex = values.Select((r, i) => KeyValuePair.Create(r.Index, i)).ToDictionary()
            };

            ctx.SubQueries.Add(sqe, subQuery);
        }

        ctx.ReportProgress($"Creating Columns");
        var codifications = new List<PredictorCodification>();

        using (HeavyProfiler.Log("MainQuery"))
        {
            for (int i = 0; i < mainTable.Columns.Length; i++)
            {
                var col = ctx.Predictor.MainQuery.Columns[i];
                using (HeavyProfiler.Log("Columns", () => col.Token.Token.ToString()))
                {
                    var predictorColumn = new PredictorColumnMain(col, i);
                    var generated = algorithm.GenerateCodifications(col.Encoding, mainTable.Columns[i], predictorColumn);
                    codifications.AddRange(generated);
                }
            }
        }

        foreach (var sq in ctx.SubQueries.Values)
        {
            using (HeavyProfiler.Log("SubQuery", () => sq.ToString()!))
            {
                // Every distinct split key seen in any group, sorted with the same comparer.
                var distinctKeys = sq.GroupedValues
                    .SelectMany(a => a.Value.Keys)
                    .Distinct(ObjectArrayComparer.Instance)
                    .ToList();
                distinctKeys.Sort(ObjectArrayComparer.Instance);

                foreach (var ks in distinctKeys)
                {
                    using (HeavyProfiler.Log("Keys", () => ks.ToString(k => k?.ToString(), ", ")))
                    {
                        foreach (var vc in sq.ValueColumns)
                        {
                            var col = sq.SubQueryEntity.Columns[vc.Index];
                            using (HeavyProfiler.Log("Columns", () => col.Token.Token.ToString()))
                            {
                                var predictorColumn = new PredictorColumnSubQuery(col, vc.Index, sq.SubQueryEntity, ks);
                                var generated = algorithm.GenerateCodifications(col.Encoding, vc, predictorColumn);
                                codifications.AddRange(generated);
                            }
                        }
                    }
                }
            }
        }

        ctx.SetCodifications(codifications.ToArray());
    }
}
/// <summary>
/// Creates a training state from a cancellation source and a training context.
/// </summary>
/// <param name="cancellationTokenSource">Stored in <c>CancellationTokenSource</c>.</param>
/// <param name="context">Stored in <c>Context</c>.</param>
public PredictorTrainingState(CancellationTokenSource cancellationTokenSource, PredictorTrainingContext context)
{
    this.CancellationTokenSource = cancellationTokenSource;
    this.Context = context;
}
/// <summary>
/// Replaces the stored <see cref="PredictSimpleResultEntity"/> rows of the predictor in
/// <paramref name="ctx"/> and updates the predictor's regression or classification statistics.
/// </summary>
/// <remarks>
/// Old results are deleted in chunks. New results are produced by running the algorithm's
/// <c>PredictMultiple</c> over batches of <c>PredictionBatchSize</c> inputs. The statistics are
/// computed from the new results and saved with the predictor. The results themselves are only
/// inserted when <c>SaveAllResults</c> is set.
/// </remarks>
/// <param name="ctx">Training context providing the predictor, the validation set and the prediction inputs.</param>
public void SavePredictions(PredictorTrainingContext ctx)
{
    using (HeavyProfiler.Log("SavePredictions"))
    {
        var p = ctx.Predictor.ToLite();
        var outputColumn = AssertOnlyOutput(ctx.Predictor);
        var isCategorical = outputColumn.Encoding.Is(DefaultColumnEncodings.OneHot);

        // For grouped main queries the non-aggregate columns identify a row (up to three are stored).
        var keys = !ctx.Predictor.MainQuery.GroupResults ? null : ctx.Predictor.MainQuery.Columns.Where(c => !(c.Token.Token is AggregateToken)).ToList();
        var key0 = keys?.ElementAtOrDefault(0);
        var key1 = keys?.ElementAtOrDefault(1);
        var key2 = keys?.ElementAtOrDefault(2);

        var resultsName = typeof(PredictSimpleResultEntity).NicePluralName();

        using (HeavyProfiler.Log("Delete Old Predictions"))
        {
            ctx.ReportProgress($"Deleting old {resultsName}");

            var query = Database.Query<PredictSimpleResultEntity>().Where(a => a.Predictor == p);
            int chunkSize = 5000;
            var totalCount = query.Count();
            var deleted = 0;
            while (totalCount - deleted > 0)
            {
                int num = query.OrderBy(a => a.Id).Take(chunkSize).UnsafeDelete();
                if (num == 0)
                {
                    // Nothing left to delete although fewer rows than initially counted were
                    // removed here; without this exit the loop would never terminate.
                    break;
                }

                deleted += num;
                ctx.ReportProgress($"Deleting old {resultsName}", deleted / (decimal)totalCount);
            }
        }

        using (HeavyProfiler.Log("SavePredictions"))
        {
            ctx.ReportProgress($"Creating {resultsName}");

            var dictionary = ctx.ToPredictDictionaries();
            var toInsert = new List<PredictSimpleResultEntity>();

            var pc = PredictorPredictLogic.CreatePredictContext(ctx.Predictor);
            int grIndex = 0;
            foreach (var gr in dictionary.GroupsOf(PredictionBatchSize))
            {
                using (HeavyProfiler.LogNoStackTrace("Group"))
                {
                    ctx.ReportProgress($"Creating {resultsName}", (grIndex++ * PredictionBatchSize) / (decimal)dictionary.Count);

                    var inputs = gr.Select(a => a.Value).ToList();
                    var outputs = pc.Algorithm.PredictMultiple(pc, inputs);

                    using (HeavyProfiler.LogNoStackTrace("Create SimpleResults"))
                    {
                        for (int i = 0; i < inputs.Count; i++)
                        {
                            PredictDictionary input = inputs[i];
                            PredictDictionary output = outputs[i];

                            object? inValue = input.MainQueryValues.GetOrThrow(outputColumn);
                            object? outValue = output.MainQueryValues.GetOrThrow(outputColumn);

                            toInsert.Add(new PredictSimpleResultEntity
                            {
                                Predictor = p,
                                Target = ctx.Predictor.MainQuery.GroupResults ? null : input.Entity,
                                Type = ctx.Validation.Contains(gr[i].Key) ? PredictionSet.Validation : PredictionSet.Training,
                                Key0 = key0 == null ? null : input.MainQueryValues.GetOrThrow(key0)?.ToString(),
                                Key1 = key1 == null ? null : input.MainQueryValues.GetOrThrow(key1)?.ToString(),
                                Key2 = key2 == null ? null : input.MainQueryValues.GetOrThrow(key2)?.ToString(),
                                OriginalValue = isCategorical ? null : ReflectionTools.ChangeType<double?>(inValue),
                                OriginalCategory = isCategorical ? inValue?.ToString() : null,
                                PredictedValue = isCategorical ? null : ReflectionTools.ChangeType<double?>(outValue),
                                PredictedCategory = isCategorical ? outValue?.ToString() : null,
                            });
                        }
                    }
                }
            }

            // Each subset is built once; exactly one of the two statistics kinds consumes it.
            var training = toInsert.Where(a => a.Type == PredictionSet.Training).ToList();
            var validation = toInsert.Where(a => a.Type == PredictionSet.Validation).ToList();

            ctx.Predictor.RegressionTraining = isCategorical ? null : GetRegressionStats(training);
            ctx.Predictor.RegressionValidation = isCategorical ? null : GetRegressionStats(validation);
            ctx.Predictor.ClassificationTraining = !isCategorical ? null : GetClassificationStats(training);
            ctx.Predictor.ClassificationValidation = !isCategorical ? null : GetClassificationStats(validation);

            using (OperationLogic.AllowSave<PredictorEntity>())
            {
                ctx.Predictor.Save();
            }

            if (SaveAllResults)
            {
                var groups = toInsert.GroupsOf(PredictionBatchSize).ToList();
                foreach (var iter in groups.Iterate())
                {
                    ctx.ReportProgress($"Inserting {resultsName}", iter.Position / (decimal)groups.Count);
                    iter.Value.BulkInsert();
                }
            }
        }
    }
}