//Errors with CNTK: https://github.com/Microsoft/CNTK/issues/2614
/// <summary>
/// Trains a CNTK dense network for the predictor in <paramref name="ctx"/> and stores the best model.
/// </summary>
/// <remarks>
/// Builds one dense layer per entry of <c>HiddenLayers</c> plus an output layer, then runs
/// <c>NumMinibatches</c> training steps on minibatches drawn at random (with replacement) from the
/// training split. Progress is enqueued every step and saved every <c>SaveProgressEvery</c> steps;
/// validation metrics are computed every <c>SaveValidationProgressEvery</c> steps. A model snapshot
/// is kept for each of the last <c>BestResultFromLast</c> steps (or for the step at which training was
/// stopped), and the snapshot with the lowest validation loss is saved as <c>Model.cntk</c>.
/// </remarks>
/// <param name="ctx">
/// Training context: supplies the predictor, input/output codifications, the train/validation split,
/// progress reporting and the stop flag.
/// </param>
/// <exception cref="InvalidOperationException">
/// No final candidate model was recorded (e.g. <c>NumMinibatches</c> or <c>BestResultFromLast</c> is not positive
/// and training was not stopped), so there is no model to pick.
/// </exception>
public void Train(PredictorTrainingContext ctx)
{
    InitialSetup();

    var p = ctx.Predictor;

    var nn = (NeuralNetworkSettingsEntity)p.AlgorithmSettings;

    DeviceDescriptor device = GetDevice(nn);
    Variable inputVariable = Variable.InputVariable(new[] { ctx.InputCodifications.Count }, DataType.Float, "input");
    Variable outputVariable = Variable.InputVariable(new[] { ctx.OutputCodifications.Count }, DataType.Float, "output");

    // Chain the hidden layers: each one consumes the previous layer's output.
    Variable currentVar = inputVariable;
    nn.HiddenLayers.ForEach((layer, i) =>
    {
        currentVar = NetworkBuilder.DenseLayer(currentVar, layer.Size, device, layer.Activation, layer.Initializer, p.Settings.Seed ?? 0, "hidden" + i);
    });
    Function calculatedOutputs = NetworkBuilder.DenseLayer(currentVar, ctx.OutputCodifications.Count, device, nn.OutputActivation, nn.OutputInitializer, p.Settings.Seed ?? 0, "output");

    Function loss = NetworkBuilder.GetEvalFunction(nn.LossFunction, calculatedOutputs, outputVariable);
    Function evalError = NetworkBuilder.GetEvalFunction(nn.EvalErrorFunction, calculatedOutputs, outputVariable);

    // prepare for training
    Learner learner = NetworkBuilder.GetInitializer(calculatedOutputs.Parameters(), nn);

    Trainer trainer = Trainer.CreateTrainer(calculatedOutputs, loss, evalError, new List<Learner>() { learner });

    // A configured seed makes the split and minibatch sampling reproducible.
    Random rand = p.Settings.Seed == null ?
        new Random() :
        new Random(p.Settings.Seed.Value);

    var (training, validation) = ctx.SplitTrainValidation(rand);

    var minibatchSize = nn.MinibatchSize;
    var numMinibatches = nn.NumMinibatches;

    Stopwatch sw = Stopwatch.StartNew();
    List<FinalCandidate> candidate = new List<FinalCandidate>();
    for (int i = 0; i < numMinibatches; i++)
    {
        // Per-iteration copy: the profiler receives lambdas, and a `for` counter is shared across
        // iterations, so a deferred evaluation of `i` could report a later iteration.
        var iteration = i;
        using (HeavyProfiler.Log("MiniBatch", () => iteration.ToString()))
        {
            ctx.ReportProgress("Training Minibatches", (i + 1) / (decimal)numMinibatches);

            {
                var trainMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(training)).ToList();

                using (Value inputValue = CreateValue(ctx, trainMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                using (Value outputValue = CreateValue(ctx, trainMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                {
                    using (HeavyProfiler.Log("TrainMinibatch", () => iteration.ToString()))
                        trainer.TrainMinibatch(new Dictionary<Variable, Value>()
                        {
                            { inputVariable, inputValue },
                            { outputVariable, outputValue },
                        }, false, device);
                }
            }

            var ep = new EpochProgress
            {
                Ellapsed = sw.ElapsedMilliseconds,
                Epoch = i,
                TrainingExamples = (int)trainer.TotalNumberOfSamplesSeen(),
                LossTraining = trainer.PreviousMinibatchLossAverage(),
                EvaluationTraining = trainer.PreviousMinibatchEvaluationAverage(),
                LossValidation = null,
                EvaluationValidation = null,
            };

            ctx.Progresses.Enqueue(ep);

            if (ctx.StopTraining)
            {
                // Re-read the predictor entity before the final save below.
                p = ctx.Predictor = ctx.Predictor.ToLite().RetrieveAndRemember();
            }

            // True for the last BestResultFromLast iterations: those are kept as final candidates.
            var isLast = numMinibatches - nn.BestResultFromLast <= i;
            if (isLast || (i % nn.SaveProgressEvery) == 0 || ctx.StopTraining)
            {
                if (isLast || (i % nn.SaveValidationProgressEvery) == 0 || ctx.StopTraining)
                {
                    using (HeavyProfiler.LogNoStackTrace("Validation"))
                    {
                        var validateMinibatch = 0.To(minibatchSize).Select(_ => rand.NextElement(validation)).ToList();

                        using (Value inputValValue = CreateValue(ctx, validateMinibatch, ctx.InputCodifications.Count, ctx.InputCodificationsByColumn, device))
                        using (Value outputValValue = CreateValue(ctx, validateMinibatch, ctx.OutputCodifications.Count, ctx.OutputCodificationsByColumn, device))
                        {
                            var inputs = new Dictionary<Variable, Value>()
                            {
                                { inputVariable, inputValValue },
                                { outputVariable, outputValValue },
                            };

                            ep.LossValidation = loss.EvaluateAvg(inputs, device);
                            ep.EvaluationValidation = evalError.EvaluateAvg(inputs, device);
                        }
                    }
                }

                var progress = ep.SaveEntity(ctx.Predictor);

                if (isLast || ctx.StopTraining)
                {
                    // Validation always runs on this path, so the candidate's validation loss is set.
                    using (HeavyProfiler.LogNoStackTrace("FinalCandidate"))
                    {
                        candidate.Add(new FinalCandidate
                        {
                            Model = calculatedOutputs.Save(),

                            ResultTraining = new PredictorMetricsEmbedded { Evaluation = progress.EvaluationTraining, Loss = progress.LossTraining },
                            ResultValidation = new PredictorMetricsEmbedded { Evaluation = progress.EvaluationValidation, Loss = progress.LossValidation },
                        });
                    }
                }
            }

            if (ctx.StopTraining)
            {
                break;
            }
        }
    }

    // Without this guard an empty list (NumMinibatches <= 0, or BestResultFromLast <= 0 without a stop)
    // fails inside WithMin / on `best` with an unhelpful error.
    if (candidate.Count == 0)
        throw new InvalidOperationException($"No final candidate model was recorded after {numMinibatches} minibatches; check {nameof(nn.NumMinibatches)} and {nameof(nn.BestResultFromLast)}.");

    var best = candidate.WithMin(a => a.ResultValidation.Loss!.Value);

    p.ResultTraining = best.ResultTraining;
    p.ResultValidation = best.ResultValidation;

    var fp = new Entities.Files.FilePathEmbedded(PredictorFileType.PredictorFile, "Model.cntk", best.Model);

    p.Files.Add(fp);

    using (OperationLogic.AllowSave<PredictorEntity>())
        p.Save();
}
/// <summary>
/// Registers the dashboard operations: Save, Delete, Clone and RegenerateCachedQueries.
/// </summary>
/// <remarks>
/// RegenerateCachedQueries deletes the dashboard's existing cached queries (their files on commit),
/// re-executes every combined cached-query definition, serializes each result table to a
/// <c>CachedQuery.json</c> file and saves one <c>CachedQueryEntity</c> per combined definition.
/// Queries requesting all rows are capped at <c>MaxRows</c>; exceeding that limit throws.
/// </remarks>
public static void Register()
{
    new Execute(DashboardOperation.Save)
    {
        CanBeNew = true,
        CanBeModified = true,
        Execute = (cp, _) => { }
    }.Register();

    new Delete(DashboardOperation.Delete)
    {
        Delete = (cp, _) =>
        {
            // Collect the part contents before deleting the dashboard, then delete them too.
            var parts = cp.Parts.Select(a => a.Content).ToList();
            cp.Delete();
            Database.DeleteList(parts);
        }
    }.Register();

    new ConstructFrom<DashboardEntity>(DashboardOperation.Clone)
    {
        Construct = (cp, _) => cp.Clone()
    }.Register();

    new Execute(DashboardOperation.RegenerateCachedQueries)
    {
        CanExecute = c => c.CacheQueryConfiguration == null ? ValidationMessage._0IsNotSet.NiceToString(ReflectionTools.GetPropertyInfo(() => c.CacheQueryConfiguration)) : null,
        AvoidImplicitSave = true,
        Execute = (db, _) =>
        {
            var cq = db.CacheQueryConfiguration!;

            var oldCachedQueries = db.CachedQueries().ToList();
            oldCachedQueries.ForEach(a => a.File.DeleteFileOnCommit());
            db.CachedQueries().UnsafeDelete();

            var definitions = DashboardLogic.GetCachedQueryDefinitions(db).ToList();

            var combined = DashboardLogic.CombineCachedQueryDefinitions(definitions);

            foreach (var c in combined)
            {
                var qr = c.QueryRequest;

                if (qr.Pagination is Pagination.All)
                {
                    // Request one row more than allowed so an over-limit result can be detected.
                    qr = qr.Clone();
                    qr.Pagination = new Pagination.Firsts(cq.MaxRows + 1);
                }

                var now = Clock.Now;

                Stopwatch sw = Stopwatch.StartNew();

                var rt = Connector.CommandTimeoutScope(cq.TimeoutForQueries).Using(_ => QueryLogic.Queries.ExecuteQuery(qr));

                var queryDuration = sw.ElapsedMilliseconds;

                if (c.QueryRequest.Pagination is Pagination.All)
                {
                    // Only receiving the extra (MaxRows + 1)-th row proves the full result exceeds the limit;
                    // exactly MaxRows rows is a complete result that fits.
                    if (rt.Rows.Length > cq.MaxRows)
                    {
                        throw new ApplicationException($"The query for {c.UserAssets.CommaAnd(a => a.KeyLong())} has returned more than {cq.MaxRows} rows: " +
                            JsonSerializer.Serialize(QueryRequestTS.FromQueryRequest(c.QueryRequest), EntityJsonContext.FullJsonSerializerOptions));
                    }
                    else
                    {
                        // The whole result fit, so store it with the original "all rows" pagination.
                        rt = new ResultTable(rt.AllColumns(), null, new Pagination.All());
                    }
                }

                // From here on the stopwatch measures serialization plus file upload.
                sw.Restart();

                var json = new CachedQueryJS
                {
                    CreationDate = now,
                    QueryRequest = QueryRequestTS.FromQueryRequest(c.QueryRequest),
                    ResultTable = rt,
                };

                var bytes = JsonSerializer.SerializeToUtf8Bytes(json, EntityJsonContext.FullJsonSerializerOptions);

                var file = new Entities.Files.FilePathEmbedded(CachedQueryFileType.CachedQuery, "CachedQuery.json", bytes).SaveFile();

                var uploadDuration = sw.ElapsedMilliseconds;

                new CachedQueryEntity
                {
                    CreationDate = now,
                    UserAssets = c.UserAssets.ToMList(),
                    // NOTE(review): the +1 for non-grouped queries presumably counts the row entity column — confirm.
                    NumColumns = qr.Columns.Count + (qr.GroupResults ? 0 : 1),
                    NumRows = rt.Rows.Length,
                    QueryDuration = queryDuration,
                    UploadDuration = uploadDuration,
                    File = file,
                    Dashboard = db.ToLite(),
                }.Save();
            }
        }
    }.SetMinimumTypeAllowed(TypeAllowedBasic.Read).Register();
}