Пример #1
0
        /// <summary>
        /// Replaces raw cell values with human-readable labels.
        /// For the TaskType/WorldType columns a <see cref="Type"/> value is shown by its
        /// <see cref="DisplayNameAttribute"/> (falling back to the type's name); for the status
        /// column the value's member <see cref="DescriptionAttribute"/> is shown when present.
        /// Values that cannot be formatted are left unchanged.
        /// </summary>
        private void dataGridView1_CellFormatting(object sender, DataGridViewCellFormattingEventArgs e)
        {
            DataGridViewColumn column = dataGridViewLearningTasks.Columns[e.ColumnIndex];

            if (column == TaskType || column == WorldType)
            {
                // Only Type instances can be labelled; anything else (including null) is left as-is
                // instead of throwing a NullReferenceException from the failed 'as' cast.
                Type typeValue = e.Value as Type;
                if (typeValue != null)
                {
                    DisplayNameAttribute displayNameAtt = typeValue
                        .GetCustomAttributes(typeof(DisplayNameAttribute), true)
                        .FirstOrDefault() as DisplayNameAttribute;

                    e.Value = displayNameAtt != null ? displayNameAtt.DisplayName : typeValue.Name;
                }
            }
            else if (column == statusDataGridViewTextBoxColumn && e.Value != null)
            {
                // Null (e.g. an empty/new row) is skipped: it cannot be cast/unboxed and dereferenced.
                TrainingResult result = (TrainingResult)e.Value;

                // GetMember returns an empty array when ToString() does not name a member
                // (for an enum: an undefined numeric value or a flags combination), so guard [0].
                var members = result.GetType().GetMember(result.ToString());
                DescriptionAttribute descriptionAtt = members.Length == 0
                    ? null
                    : members[0].GetCustomAttributes(typeof(DescriptionAttribute), true).FirstOrDefault() as DescriptionAttribute;

                if (descriptionAtt != null)
                {
                    e.Value = descriptionAtt.Description;
                }
            }
        }
Пример #2
0
    /// <summary>
    /// Runs the current exercise for one boxer and applies the resulting outcome to that boxer.
    /// </summary>
    /// <param name="worldData">World state, handed to the exercise by reference.</param>
    /// <param name="boxerIndex">Index of the boxer in <c>worldData.Boxers</c>.</param>
    private void train(ref DataPool worldData, int boxerIndex)
    {
        // The exercise receives worldData by ref, so the boxer is looked up only after it has run.
        TrainingResult outcome = exercise.train(ref worldData, boxerIndex);

        worldData.Boxers[boxerIndex].applyTrainingResults(outcome);
    }
Пример #3
0
 /// <summary>
 /// Trains <paramref name="network"/> for <paramref name="numberOfEpochs"/> consecutive epochs.
 /// </summary>
 /// <param name="network">Network to train.</param>
 /// <param name="numberOfEpochs">How many epochs to run.</param>
 /// <returns>One result per epoch, in the order the epochs were run.</returns>
 public TrainingResult[] TrainForMultipleEpochs(INetwork network, int numberOfEpochs)
 {
     var epochResults = new TrainingResult[numberOfEpochs];

     for (int epoch = 0; epoch < epochResults.Length; epoch++)
     {
         epochResults[epoch] = TrainForOneEpoch(network);
     }

     return epochResults;
 }
Пример #4
0
    /// <summary>
    /// Captures the current round (bomb-plant site and elapsed time, round numbering, map) and the
    /// per-player state of both teams into a new <c>TrainingResult</c>.
    /// </summary>
    /// <returns>
    /// The snapshot, or null (after logging to the console) when either team does not have
    /// exactly five players.
    /// </returns>
    private TrainingResult SnapshotCurrentResult()
    {
        const int PlayersPerTeam = 5;

        Player[] cts = GetCTs().ToArray();
        Player[] ts  = GetTs().ToArray();

        if (cts.Length != PlayersPerTeam || ts.Length != PlayersPerTeam)
        {
            Console.WriteLine("Not 5 players on a team!");
            return(null);
        }

        // Rounds 1-30 use 15-round halves; later rounds (presumably overtime) use 3-round halves.
        int roundsPerHalf = roundNumber <= 30 ? 15 : 3;

        TrainingResult result = new TrainingResult()
        {
            bombplant_site          = bombsite ?? 'U',   // 'U' when no site is known
            elapsed_since_bombplant = (float)(secondsPerTick * bombPlantTotalElapsedTicks),
            round_number            = roundNumber,
            map_id                  = _demoParser.Map,
            rounds_per_half         = roundsPerHalf,
            round_of_half           = (roundNumber - 1) % roundsPerHalf + 1   // 1-based position within the half
        };

        for (int i = 0; i < PlayersPerTeam; i++)
        {
            int slot = i + 1;   // player slots are 1-based

            WritePlayerState(result, Team.CounterTerrorist, slot, cts[i]);

            // Defuse kits are recorded for the CT side only.
            result.SetPlayerHasDefuseKit(slot, cts[i].HasDefuseKit);
        }

        for (int i = 0; i < PlayersPerTeam; i++)
        {
            WritePlayerState(result, Team.Terrorist, i + 1, ts[i]);
        }

        return(result);
    }

    /// <summary>
    /// Writes the state recorded for every player on either side (alive, equipment value, active
    /// weapon, helmet, kevlar, HP) for <paramref name="player"/> into <paramref name="result"/>.
    /// </summary>
    private void WritePlayerState(TrainingResult result, Team team, int slot, Player player)
    {
        result.SetPlayerIsAlive(team, slot, player.IsAlive);
        result.SetPlayerEquipmentValue(team, slot, player.CurrentEquipmentValue);
        result.SetPlayerEquippedWeapon(team, slot, player.ActiveWeapon?.Weapon ?? EquipmentElement.Unknown);
        result.SetPlayerHasHelmet(team, slot, player.HasHelmet);
        result.SetPlayerHasKevlar(team, slot, player.Armor > 0);
        result.SetPlayerHP(team, slot, player.HP);
    }
Пример #5
0
        /// <summary>
        /// Runs the training's calculation and packages its total time and calorie count.
        /// </summary>
        /// <param name="training">Training to evaluate; its <c>Calculate()</c> runs before the totals are read.</param>
        /// <returns>A new result holding <c>TotalTime</c> and <c>TotalCaloriesCount</c>.</returns>
        public TrainingResult CalculateResult(Training training)
        {
            training.Calculate();

            return new TrainingResult
            {
                TotalTime          = training.CalculateTotalTime(),
                TotalCaloriesCount = training.CalculateCalories()
            };
        }
Пример #6
0
 /// <summary>
 /// Comparison that orders training results by ascending <c>score</c>.
 /// </summary>
 /// <returns>Negative, zero or positive as <paramref name="lhs"/> scores lower than, equal to,
 /// or higher than <paramref name="rhs"/>.</returns>
 private int CompareResults(TrainingResult lhs, TrainingResult rhs)
 {
     // CompareTo defines a total order even for NaN (NaN sorts first). The previous '<' / '>'
     // checks reported NaN as equal to every score, an inconsistent comparison that can
     // scramble or abort a sort.
     return lhs.score.CompareTo(rhs.score);
 }
Пример #7
0
        /// <summary>
        /// Starts the training identified by <c>model.TrainingID</c> for the logged-in user.
        /// </summary>
        /// <remarks>
        /// The request is rejected (JSON <c>Result</c> with one error message) when the training is
        /// not one of the user's active trainings, when the user already has an unfinished run of
        /// it, or when the user has reached <c>MaxActiveTrainings</c> from the default
        /// <c>AppSettings</c> row (no such row means no limit). Otherwise a new
        /// <c>TrainingResult</c> is stored and <c>Succeeded</c> is set.
        /// </remarks>
        public ActionResult TrainingList(StartTrainingModel model)
        {
            var result = new Result()
            {
                Errors = new List <string>()
            };
            var loggedPerson = Person.GetLoggedPerson(User, _db);

            var myTrainings = getMyTrainings(loggedPerson);
            var training    = myTrainings.FirstOrDefault(x => x.TrainingID == model.TrainingID);

            if (training == null || !training.IsActive)
            {
                // "You are not allowed to start this training."
                result.Errors.Add("Nie masz uprawnień by uruchomić to szkolenie");
                return(Json(result));
            }

            // A run without an EndDate is still in progress; Any() avoids loading the entity.
            bool alreadyRunning = _db.TrainingResults.Any(x => x.PersonID == loggedPerson.Id &&
                                                               x.TrainingID == model.TrainingID &&
                                                               !x.EndDate.HasValue);

            if (alreadyRunning)
            {
                // "This training is already active :) You will find it on the list of your active trainings."
                result.Errors.Add("To szkolenie jest już aktywne :) Znajdziesz je na liście swoich aktywnych szkoleń");
                return(Json(result));
            }

            var settings = _db.AppSettings.FirstOrDefault(x => x.IsDefault);

            // The count query only runs when a limit is configured.
            if (settings != null &&
                settings.MaxActiveTrainings <= _db.TrainingResults.Count(x => x.PersonID == loggedPerson.Id &&
                                                                              !x.EndDate.HasValue))
            {
                // "You have the maximum number of courses running; to start another one, finish one of your active courses first."
                result.Errors.Add("Masz uruchomioną maksymalną ilość kursów, jeżeli chcesz uruchomić kolejny zakończ wcześniej któryś z aktywowanych kursów");
                return(Json(result));
            }

            // One timestamp so StartDate and LastActivationDate are exactly equal.
            DateTime now = DateTime.Now;

            var trainingResult = new TrainingResult
            {
                PersonID   = loggedPerson.Id,
                StartDate  = now,
                TrainingID = model.TrainingID
            };

            loggedPerson.LastActivationDate = now;
            _db.TrainingResults.Add(trainingResult);
            _db.SaveChanges();

            result.Succeeded = true;

            return(Json(result));
        }
Пример #8
0
    /// <summary>
    /// Sets <paramref name="result"/>'s score to the distance between the trainee and <c>target</c>.
    /// Offsets whose squared length is at most 0.001 are reported as a distance of 0.
    /// </summary>
    public void CalculateResult(NeuralBurst trainee, TrainingResult result)
    {
        var offset = trainee.transform.position - target.transform.position;

        result.score = offset.sqrMagnitude > 0.001f ? offset.magnitude : 0.0f;
    }
 /// <summary>
 /// Stores the posted training result for the user whose id is in <c>Session["UserID"]</c>,
 /// assigning it a fresh id, then redirects to the client's training plan. An invalid model
 /// is not stored, but the same redirect is returned.
 /// </summary>
 public ActionResult AddTrainingResult(AddResult AddResult)
 {
     if (!ModelState.IsValid)
     {
         return RedirectToAction("TrainingPlan", "Client");
     }

     using (var db = new Model1())
     {
         var entry = AddResult.result;
         entry.ResultId = TrainingResult.GenerateId();
         entry.UserId   = int.Parse(Session["UserID"].ToString());

         db.TrainingResult.Add(entry);
         db.SaveChanges();
     }

     return RedirectToAction("TrainingPlan", "Client");
 }
Пример #10
0
        /// <summary>
        /// Applies the preprocessing options, then loads the data and trains the analyzer in the
        /// background. Statistics and the "TRAINED" status are shown when training completes; if it
        /// fails, the error message and a "FAILED" status are shown instead.
        /// </summary>
        /// <remarks>
        /// The work used to be fire-and-forget: an exception inside Task.Run was never observed and
        /// the status stayed at "TRAINING". Awaiting it (async void is acceptable for a top-level
        /// event handler) surfaces failures and resumes on the UI context for the status updates.
        /// </remarks>
        private async void cmdTrain_Click(object sender, EventArgs e)
        {
            UpdateStatus("TRAINING", DisplayPart.Status);
            _preprocessor.RemoveStopWords = chkStopWords.Checked;
            _preprocessor.StemWords       = chkStemWords.Checked;

            try
            {
                // CPU-bound loading and training stay off the UI thread.
                TrainingResult result = await Task.Run(() =>
                {
                    _preprocessor.Load();
                    return _analyzer.Train();
                });

                UpdateStatus(result.GetDescription(), DisplayPart.Statistics);
                UpdateStatus("TRAINED", DisplayPart.Status);
            }
            catch (Exception ex)
            {
                // Top-level UI handler: report the failure instead of leaving the status stuck.
                UpdateStatus(ex.Message, DisplayPart.Statistics);
                UpdateStatus("FAILED", DisplayPart.Status);
            }
        }
Пример #11
0
    /// <summary>
    /// Applies a training session's outcome to this boxer.
    /// </summary>
    /// <remarks>
    /// Each attribute gain is applied only if the resulting total would not exceed
    /// <c>MaxAttribute</c>; fatigue is always added, and a failed session increments maturity.
    /// </remarks>
    /// <param name="results">Outcome of the training session.</param>
    public void applyTrainingResults(TrainingResult results)
    {
        // Upper bound for every trainable attribute (previously a magic 999 repeated five times).
        const int MaxAttribute = 999;

        // Parenthesised so the intent is explicit: the test is on the prospective total, and the
        // conditional selects how much '+=' adds.
        // NOTE(review): a gain that would overshoot is discarded entirely rather than clamped to
        // MaxAttribute (e.g. 995 + 5 stays 995) — confirm this is the intended game rule.
        accuracy  += (accuracy + results.Accuracy > MaxAttribute) ? 0 : results.Accuracy;
        endurance += (endurance + results.Endurance > MaxAttribute) ? 0 : results.Endurance;
        health    += (health + results.Health > MaxAttribute) ? 0 : results.Health;
        speed     += (speed + results.Speed > MaxAttribute) ? 0 : results.Speed;
        strength  += (strength + results.Strength > MaxAttribute) ? 0 : results.Strength;

        fatigue += results.Fatigue;

        if (results.Result.Equals(TrainingResult.Outcome.Failure))
        {
            maturity++;
        }
    }
Пример #12
0
 // Matches "$^<id>$t<series>"; built once instead of on every call.
 private static readonly Regex TrainingResultRegex = new Regex(@"\$\^([\d]+)\$t([\d]+)", RegexOptions.Compiled);

 /// <summary>
 /// Tries to parse a line of the form <c>$^&lt;id&gt;$t&lt;series&gt;</c> into a <c>TrainingResult</c>.
 /// </summary>
 /// <param name="line">Line to parse; may be null or empty.</param>
 /// <param name="trainingResult">The parsed result, or null when parsing fails.</param>
 /// <returns>true when the line matches and both numbers fit in an <see cref="int"/>; otherwise false.</returns>
 private static bool TryGetTrainingResult(this string line, out TrainingResult trainingResult)
 {
     trainingResult = null;

     if (string.IsNullOrEmpty(line) || !line.StartsWith("$^", StringComparison.Ordinal))
     {
         return(false);
     }

     // Groups.Count is fixed by the pattern (always 3), so it never indicated success; a
     // non-matching line such as "$^x" used to reach Int32.Parse("") and throw FormatException.
     Match match = TrainingResultRegex.Match(line);
     if (!match.Success)
     {
         return(false);
     }

     // TryParse turns overflowing values (and non-ASCII digits, which \d also matches) into a
     // clean 'false' instead of an exception.
     int id;
     int series;
     if (!int.TryParse(match.Groups[1].Value, out id) || !int.TryParse(match.Groups[2].Value, out series))
     {
         return(false);
     }

     trainingResult = new TrainingResult()
     {
         Id     = id,
         Series = series
     };
     return(true);
 }
Пример #13
0
 /// <summary>
 /// Trains the network on the digits in the inclusive index range
 /// [<paramref name="startingIndex"/>, <paramref name="endingIndex"/>] and shows the mean error.
 /// </summary>
 /// <param name="startingIndex">Starting index in the digit set (inclusive).</param>
 /// <param name="endingIndex">Ending index in the digit set (inclusive); must be below the digit set size.</param>
 /// <param name="learnRate">Learning rate passed to the <c>Trainer</c>.</param>
 /// <returns>true if training completed; false on errors, with the reason stored in <c>LastErrorMessage</c>.</returns>
 public bool TrainNetwork(uint startingIndex, uint endingIndex, float learnRate)
 {
     try
     {
         if (perceptron == null)
         {
             LastErrorMessage = "Load or create network";
             return(false);
         }

         if (digitSet.Size == 0)
         {
             LastErrorMessage = "Load digit set";
             return(false);
         }

         if (endingIndex >= digitSet.Size)
         {
             LastErrorMessage = "Ending index exceeds digit set size";
             return(false);
         }

         // Without this check (endingIndex - startingIndex) wraps around as uint.
         if (startingIndex > endingIndex)
         {
             LastErrorMessage = "Starting index exceeds ending index";
             return(false);
         }

         float   totalError = 0;
         Trainer trainer    = new Trainer(perceptron, learnRate);
         for (uint i = startingIndex; i <= endingIndex; i++)
         {
             TrainingResult result = trainer.Train(digitSet[i]);
             mainForm.ShowIteration(i);
             totalError += result.MeanSquareError;
         }

         // The range is inclusive, so it contains (end - start + 1) samples. The old divisor
         // (end - start) overstated the mean and divided by zero for a single-sample range.
         uint sampleCount = endingIndex - startingIndex + 1;
         totalError /= sampleCount;
         mainForm.ShowTotalError(totalError);
         return(true);
     }
     catch (Exception exc)
     {
         // Any failure (e.g. from the trainer or a missing digit set) is reported via LastErrorMessage.
         LastErrorMessage = exc.Message;
         return(false);
     }
 }
 /// <summary>
 /// Persists <paramref name="result"/>: it is inserted when its <c>ResultID</c> is 0; otherwise the
 /// stored row with that id (if any) has its editable fields overwritten. Changes are saved either way.
 /// </summary>
 public void SaveResult(TrainingResult result)
 {
     if (result.ResultID != 0)
     {
         TrainingResult stored = db.TrainingResults.Find(result.ResultID);

         // An id with no stored row is ignored; only the user-editable columns are copied.
         if (stored != null)
         {
             stored.Repetitions  = result.Repetitions;
             stored.NumberSeries = result.NumberSeries;
             stored.Comments     = result.Comments;
             stored.Weigth       = result.Weigth;
         }
     }
     else
     {
         db.TrainingResults.Add(result);
     }

     db.SaveChanges();
 }
Пример #15
0
        /// <summary>
        /// Plots the per-epoch error of <paramref name="trainingResult"/> to "<paramref name="path"/>.png",
        /// titled "<paramref name="title"/> - error", with the epoch index on the x-axis.
        /// </summary>
        public static void GenerateErrorPlot(
            TrainingResult trainingResult,
            string path,
            string title)
        {
            var errors = trainingResult.EpochErrors;
            var points = new List <DataPoint>(errors.Length);

            for (var epoch = 0; epoch < errors.Length; epoch++)
            {
                points.Add(new DataPoint(epoch, errors[epoch]));
            }

            IList <DataPoint>[] series = { points };

            Charter.Charter.GeneratePlot(series, path + ".png", title + " - error");
        }
Пример #16
0
        /// <summary>
        /// Plots the per-epoch evaluation percentage of <paramref name="trainingResult"/> to
        /// "<paramref name="path"/>.png", titled "<paramref name="title"/> - evaluation".
        /// The extra arguments 0, 100, 10 are passed to <c>GeneratePlot</c> (presumably the y-axis
        /// range and step — confirm against Charter).
        /// </summary>
        public static void GenerateEvaluationPlot(
            TrainingResult trainingResult,
            string path,
            string title)
        {
            var evaluations = trainingResult.Evaluations;
            var points      = new List <DataPoint>(evaluations.Length);

            for (var epoch = 0; epoch < evaluations.Length; epoch++)
            {
                points.Add(new DataPoint(epoch, evaluations[epoch].Percentage));
            }

            IList <DataPoint>[] series = { points };

            Charter.Charter.GeneratePlot(series, path + ".png", title + " - evaluation", 0, 100, 10);
        }
Пример #17
0
        /// <summary>
        /// Trains an LSTM classification network ("1x1x4~80-80-80-16LSTM", 16 classes) on sequences
        /// built from a running XOR of random symbols, printing the cost loss of every epoch, then
        /// prints the network output, expected labels and loss for a fresh test sequence.
        /// NOTE(review): the test asserts nothing (the only assertion is commented out at the end).
        /// </summary>
        public void XorTest1()
        {
            const int AlphabetSize = 16;
            const int VectorSize   = 4;

            const int BatchSize     = 3000;
            const int Epochs        = 200;
            const int TestBatchSize = 3000;
            Random    random        = new Random(0);

            string[] classes = Enumerable.Range(0, AlphabetSize).Select(v => v.ToString(CultureInfo.InvariantCulture)).ToArray();
            ClassificationNetwork network = ClassificationNetwork.FromArchitecture("1x1x4~80-80-80-16LSTM", classes);

            // One random VectorSize-long embedding per symbol, stored back to back.
            float[] vectors = new RandomGeneratorF().Generate(AlphabetSize * VectorSize);

            // Builds one sequence of `size` steps. Step i holds the embedding of v, the running XOR of
            // the random symbols drawn so far (presumably Vectors.Copy copies VectorSize floats from
            // vectors[v * VectorSize] to input.Weights[i * VectorSize] — confirm).
            // The label of step i - 1 is the v encoded at step i.
            // NOTE(review): that v includes the symbol drawn at step i, which the network has not seen
            // at step i - 1, so the target looks unpredictable; confirm whether truth[i] = v was meant.
            // NOTE(review): truth[size - 1] is never assigned and stays 0.
            (Tensor, int[]) createSample(int size)
            {
                Tensor input = new Tensor(null, new[] { size, 1, 1, VectorSize });

                int[] truth = new int[size];

                int v = 0;

                for (int i = 0; i < size; i++)
                {
                    v ^= random.Next(0, AlphabetSize);
                    Vectors.Copy(VectorSize, vectors, v * VectorSize, input.Weights, i * VectorSize);

                    if (i > 0)
                    {
                        truth[i - 1] = v;
                    }
                }

                return(input, truth);
            }

            // train the network
            Trainer <int[]> trainer = new Trainer <int[]>()
            {
                ClipValue = 2.0f
            };

            SGD           sgd  = new SGD();
            ILoss <int[]> loss = new LogLikelihoodLoss();

            // Each epoch trains on a single freshly generated sequence.
            for (int epoch = 0; epoch < Epochs; epoch++)
            {
                (Tensor, int[])sample = createSample(BatchSize);

                TrainingResult result = trainer.RunEpoch(
                    network,
                    Enumerable.Repeat(sample, 1),
                    epoch,
                    sgd,
                    loss,
                    CancellationToken.None);
                Console.WriteLine(result.CostLoss);
            }

            // test the network
            (Tensor x, int[] expected) = createSample(TestBatchSize);
            Tensor y = network.Forward(null, x);
            ////y.Reshape(testBatchSize - 1);
            ////expected.Reshape(testBatchSize - 1);
            float error = loss.Loss(y, expected, false);

            Console.WriteLine(y);
            // NOTE(review): int[].ToString() prints the type name ("System.Int32[]"), not the labels.
            Console.WriteLine(expected);
            ////Console.WriteLine(y.Axes[1]);
            Console.WriteLine(error);

            // NOTE(review): this disabled assertion refers to errorL1, which is not defined in this test.
            ////Assert.IsTrue(errorL1 < 0.01, errorL1.ToString(CultureInfo.InvariantCulture));
        }
Пример #18
0
        /// <summary>
        /// Trains a small LSTM ("1x1x1~10-10-1LSTM") to predict the next sample of a sine wave
        /// scaled to [0, 1], then checks the mean absolute error and the root-mean-square error of
        /// the last prediction over <c>testCount</c> short test sequences.
        /// </summary>
        public void SinTest1()
        {
            const int    batchSize     = 10;
            const double batchStep     = 0.1;
            const int    epochs        = 10000;
            const int    testBatchSize = 5;
            const int    testCount     = 100;
            Random       random        = new Random(0);
            Network      network       = Network.FromArchitecture("1x1x1~10-10-1LSTM");

            // Builds `size` consecutive samples starting at a random phase; expected[b] is the
            // sample that follows input[b] (size + 1 samples are generated for that reason).
            (Tensor, Tensor) createSample(int size)
            {
                Tensor input    = new Tensor(null, new[] { size, 1, 1, 1 });
                Tensor expected = new Tensor(null, input.Shape);

                // NOTE(review): the (float) cast drops precision before storing into a double; kept so
                // the generated data (and thus the trained network) stays exactly as before.
                double rv = (float)(random.NextDouble() * Math.PI * 2);

                for (int b = 0; b <= size; b++, rv += batchStep)
                {
                    float value = (float)((Math.Sin(rv) / 2.0) + 0.5);
                    if (b < size)
                    {
                        input.Weights[b] = value;
                    }

                    if (b - 1 >= 0)
                    {
                        expected.Weights[b - 1] = value;
                    }
                }

                return(input, expected);
            }

            // train the network
            SquareLoss       loss    = new SquareLoss();
            Trainer <Tensor> trainer = new Trainer <Tensor>();
            SGD sgd = new SGD();

            for (int epoch = 0; epoch < epochs; epoch++)
            {
                (Tensor, Tensor)sample = createSample(batchSize);
                TrainingResult result = trainer.RunEpoch(network, Enumerable.Repeat(sample, 1), epoch, sgd, loss, CancellationToken.None);
                Console.WriteLine(result.CostLoss);
            }

            // test the network: only the last prediction of each sequence is scored
            double errorL1 = 0.0;
            double errorL2 = 0.0;

            for (int test = 0; test < testCount; test++)
            {
                (Tensor x, Tensor ye) = createSample(testBatchSize);
                Tensor output = network.Forward(null, x);

                float diff = Math.Abs(ye.Weights[testBatchSize - 1] - output.Weights[testBatchSize - 1]);
                Console.WriteLine(diff);
                errorL1 += diff;                 // diff is already non-negative
                errorL2 += Math.Pow(diff, 2.0);
            }

            // Mean absolute error and a true root-mean-square error. The old formula
            // sqrt(sum) / 100 was RMS / 10, not an RMS.
            errorL1 /= testCount;
            errorL2  = Math.Sqrt(errorL2 / testCount);
            Console.WriteLine(errorL1);
            Console.WriteLine(errorL2);

            Assert.IsTrue(errorL1 < 0.01, errorL1.ToString(CultureInfo.InvariantCulture));
            // RMS < 0.01 is exactly the bound the old check (sqrt(sum) / 100 < 0.001) enforced.
            Assert.IsTrue(errorL2 < 0.01, errorL2.ToString(CultureInfo.InvariantCulture));
        }
Пример #19
0
    /// <summary>
    /// Builds the result for the current session: timestamps, the activeness log, and a
    /// four-bucket breakdown of that log (activeness &lt;= 0.2, &lt;= 0.5, &lt;= 0.8, and above).
    /// </summary>
    /// <remarks>
    /// For each bucket the result holds its share of the session's seconds (<c>TotalSeconds</c>),
    /// a rounded percentage (<c>Rates</c>) and a bar-segment width (<c>Width</c>); the widths are
    /// adjusted to add up to exactly 980. With an empty log the defaults (25 / 245 / 0 per bucket)
    /// are returned and <c>score</c> is <c>calcActiveness()</c>; otherwise <c>score</c> is the
    /// combined percentage of the two upper buckets.
    /// </remarks>
    public static TrainingResult GetResultData()
    {
        // Total width shared by the four bar segments.
        const int BarWidth = 980;

        TrainingResult data = new TrainingResult();

        data.id           = null;
        data.trained      = DateTime.Now;
        data.trainingTime = DateTime.Now;
        data.score        = calcActiveness();
        data.activenesses = Hot2gApplication.Instance.GetActivenessesLog();
        data.start        = Hot2gApplication.Instance.GetStartTime();
        data.end          = Hot2gApplication.Instance.GetEndTime();

        float sessionSeconds = (float)((data.end - data.start).TotalSeconds);

        data.Rates        = new int[] { 25, 25, 25, 25 };
        data.Width        = new int[] { 245, 245, 245, 245 };
        data.TotalSeconds = new int[] { 0, 0, 0, 0 };

        if (data.activenesses.Length == 0)
        {
            return(data);
        }

        // Bucket every sample. Each failed "<=" test already implies the lower bound of the next
        // bucket, so no explicit lower bounds are needed (NaN still ends up in the last bucket).
        int[] counts = new int[] { 0, 0, 0, 0 };
        foreach (double it in data.activenesses)
        {
            if (it <= 0.2)
            {
                counts[0]++;
            }
            else if (it <= 0.5)
            {
                counts[1]++;
            }
            else if (it <= 0.8)
            {
                counts[2]++;
            }
            else
            {
                counts[3]++;
            }
        }

        for (int i = 0; i < 4; i++)
        {
            float rate = (float)counts[i] / (float)data.activenesses.Length;

            data.TotalSeconds[i] = Mathf.RoundToInt(sessionSeconds * rate);
            data.Rates[i]        = Mathf.RoundToInt(rate * 100);
            data.Width[i]        = Mathf.RoundToInt(BarWidth * rate);
        }

        data.score = data.Rates[2] + data.Rates[3];

        // Width adjustment: rounding can make the segments miss BarWidth by a few units. The drift
        // goes into segment 0 (the blue one) by design, but when that segment is (nearly) empty it
        // would turn negative, e.g. three equal buckets each round 326.67 up to 327 (drift -1).
        // In that case the widest segment absorbs it instead.
        int drift    = BarWidth - (data.Width[0] + data.Width[1] + data.Width[2] + data.Width[3]);
        int absorber = 0;
        if (data.Width[0] + drift < 0)
        {
            for (int i = 1; i < data.Width.Length; i++)
            {
                if (data.Width[i] > data.Width[absorber])
                {
                    absorber = i;
                }
            }
        }

        data.Width[absorber] += drift;

        return(data);
    }
Пример #20
0
            private void Learn(int taskIndex, LearningTask task, CancellationToken cancellationToken)
            {
                using (StreamWriter logFile = File.CreateText(task.LogFileName))
                {
                    logFile.AutoFlush = true;

                    try
                    {
                        // report starting time
                        DateTime dateStarted = DateTime.Now;
                        this.WriteLine(logFile, string.Format(CultureInfo.InvariantCulture, "Started: {0}", dateStarted.ToString("G", CultureInfo.InvariantCulture)));

                        ClassificationNetwork net = File.Exists(task.Architecture) ?
                                                    ClassificationNetwork.FromFile(task.Architecture) :
                                                    ClassificationNetwork.FromArchitecture(task.Architecture, task.Classes, task.Classes, task.BlankClass);

                        // learning
                        Learn();
                        net.SaveToFile(task.OutputFileName, NetworkFileFormat.JSON);

                        // report finish time and processing interval
                        DateTime dateFinished = DateTime.Now;
                        this.WriteLine(logFile, string.Empty);
                        this.WriteLine(logFile, string.Format(CultureInfo.InvariantCulture, "Finished: {0:G}", dateFinished));
                        this.WriteLine(logFile, string.Format(CultureInfo.InvariantCulture, "Total time: {0:g}", TimeSpan.FromSeconds((dateFinished - dateStarted).TotalSeconds)));

                        void Learn()
                        {
                            this.WriteLine(logFile, "Learning...");

                            ImageDistortion filter = new ImageDistortion();
                            Stopwatch       timer  = new Stopwatch();

                            this.WriteLine(logFile, "  Epochs: {0}", task.Epochs);

                            this.WriteTrainerParameters(logFile, task.Trainer, task.Algorithm, task.Loss);

                            this.WriteLine(logFile, "Image distortion:");
                            this.WriteLine(logFile, "  Shift: {0}", task.Shift);
                            this.WriteLine(logFile, "  Rotate: {0}", task.Rotate);
                            this.WriteLine(logFile, "  Scale: {0}", task.Scale);
                            this.WriteLine(logFile, "  Crop: {0}", task.Crop);

                            Shape shape = net.InputShape;

                            using (TestImageProvider <string> dataProvider = task.CreateDataProvider(net))
                            {
                                using (TestImageProvider <string> testDataProvider = task.CreateTestDataProvider(net))
                                {
                                    ////int n = 0;
                                    for (int epoch = 0; epoch < task.Epochs; epoch++)
                                    {
                                        // run learning
                                        timer.Restart();

                                        TrainingResult result = task.Trainer.RunEpoch(
                                            epoch,
                                            net,
                                            GenerateLearnSamples(dataProvider, epoch),
                                            task.Algorithm,
                                            task.Loss,
                                            cancellationToken);

                                        timer.Stop();

                                        lock (this.logLocker)
                                        {
                                            string s = string.Format(
                                                CultureInfo.InvariantCulture,
                                                "Net: {0}, Epoch: {1}, Time: {2} ms, {3}",
                                                taskIndex,
                                                epoch,
                                                timer.ElapsedMilliseconds,
                                                result);

                                            this.Write(logFile, s);
                                            ////this.WriteDebugInformation(logFile);
                                            this.WriteLine(logFile, string.Empty);
                                        }

                                        // run testing
                                        string epochOutputFileName = string.Format(CultureInfo.InvariantCulture, task.EpochFileNameTemplate, epoch);

                                        // save network
                                        net.SaveToFile(epochOutputFileName, NetworkFileFormat.JSON);

                                        // run testing
                                        List <ClassificationResult <string> > results = new List <ClassificationResult <string> >();
                                        if (task.Loss is CTCLoss)
                                        {
                                            Context model = Context.FromRegex(@"\d", CultureInfo.InvariantCulture);

                                            foreach ((TestImage image, string[] labels) in GenerateTestSamples(testDataProvider))
                                            {
                                                if (image.Image.IsAllWhite())
                                                {
                                                    results.Add(new ClassificationResult <string>(
                                                                    image.SourceId,
                                                                    "0",
                                                                    string.Concat(labels),
                                                                    1.0f,
                                                                    true));
                                                }
                                                else
                                                {
                                                    Tensor x = ImageExtensions.FromImage(image.Image, null, Shape.BWHC, shape.GetAxis(Axis.X), shape.GetAxis(Axis.Y));
                                                    (string text, float prob) = net.ExecuteSequence(x, model).Answers.FirstOrDefault();

                                                    results.Add(new ClassificationResult <string>(
                                                                    image.SourceId,
                                                                    text,
                                                                    string.Concat(labels),
                                                                    prob,
                                                                    prob >= 0.38f));
                                                }
                                            }
                                        }
                                        else
                                        {
                                            foreach ((TestImage image, string[] labels) in GenerateTestSamples(testDataProvider))
                                            {
                                                if (image.Image.IsAllWhite())
                                                {
                                                    results.Add(new ClassificationResult <string>(
                                                                    image.SourceId,
                                                                    "0",
                                                                    string.Concat(labels),
                                                                    1.0f,
                                                                    true));
                                                }
                                                else
                                                {
                                                    Tensor x = ImageExtensions.FromImage(image.Image, null, Shape.BWHC, shape.GetAxis(Axis.X), shape.GetAxis(Axis.Y));

                                                    foreach (IList <(string answer, float probability)> answer in net.Execute(x).Answers)
                                                    {
                                                        string text = answer.FirstOrDefault().answer;
                                                        float  prob = answer.FirstOrDefault().probability;

                                                        results.Add(new ClassificationResult <string>(
                                                                        image.SourceId,
                                                                        text,
                                                                        string.Concat(labels),
                                                                        prob,
                                                                        prob >= 0.38f));
                                                    }
                                                }
                                            }
                                        }

                                        // write report
                                        ClassificationReport <string> testReport = new ClassificationReport <string>(results);
                                        this.Write(logFile, ClassificationReportWriter <string> .WriteReport(testReport, ClassificationReportMode.Summary));

                                        using (StreamWriter outputFile = File.CreateText(Path.ChangeExtension(epochOutputFileName, ".res")))
                                        {
                                            ClassificationReportWriter <string> .WriteReport(outputFile, testReport, ClassificationReportMode.All);
                                        }
                                    }
                                }

                                IEnumerable <(Tensor x, string[] labels)> GenerateLearnSamples(TestImageProvider <string> provider, int epoch)
                                {
                                    // Builds the training stream:
                                    //  - images that are entirely white are skipped;
                                    //  - each remaining image is expanded into the variants returned by filter.Distort
                                    //    (shift/rotate/scale/crop flags come from the current task);
                                    //  - each variant is converted to a tensor of width/height taken from 'shape'
                                    //    and paired with the source sample's labels.
                                    // NOTE: 'epoch' is currently unused. It only gated commented-out debug image
                                    // dumps, which have been removed. It is kept so existing call sites compile unchanged.
                                    return GenerateSamples(provider)
                                        .Where(sample => !sample.image.Image.IsAllWhite())
                                        .SelectMany(sample => filter
                                            .Distort(
                                                sample.image.Image,
                                                shape.GetAxis(Axis.X),
                                                shape.GetAxis(Axis.Y),
                                                task.Shift,
                                                // Rotation is disabled for italic images (presumably because they are already slanted).
                                                task.Rotate && sample.image.FontStyle != FontStyle.Italic,
                                                task.Scale,
                                                task.Crop)
                                            .Select(bitmap => (
                                                ImageExtensions.FromImage(bitmap, null, Shape.BWHC, shape.GetAxis(Axis.X), shape.GetAxis(Axis.Y)),
                                                sample.labels)));
                                }

                                IEnumerable <(TestImage image, string[] labels)> GenerateTestSamples(TestImageProvider <string> provider)
                                {
                                    // Test samples are produced by a parallel query. AsOrdered keeps results in the
                                    // provider's sequence, and the query observes the enclosing cancellation token.
                                    var parallelSamples = GenerateSamples(provider).AsParallel().AsOrdered();

                                    return parallelSamples
                                        .WithCancellation(cancellationToken)
                                        .WithMergeOptions(ParallelMergeOptions.AutoBuffered);
                                }

                                IEnumerable <(TestImage image, string[] labels)> GenerateSamples(TestImageProvider <string> provider)
                                {
                                    // Pairs every image the provider generates for the network's allowed classes
                                    // with its (possibly expanded) label array.
                                    return provider
                                        .Generate(net.AllowedClasses)
                                        .Select(sample => (sample, ExpandLabels(sample.Labels)));

                                    // For non-CTC losses, a single label is repeated along the batch axis (Axis.B)
                                    // of the network's first output shape when that axis is larger than 1.
                                    // CTC losses keep the labels unchanged.
                                    string[] ExpandLabels(string[] labels)
                                    {
                                        if (task.Loss is CTCLoss)
                                        {
                                            return labels;
                                        }

                                        int batch = net.OutputShapes.First().GetAxis(Axis.B);
                                        return labels.Length == 1 && batch > 1
                                            ? Enumerable.Repeat(labels[0], batch).ToArray()
                                            : labels;
                                    }
                                }
                            }
                        }
                    }
                    finally
                    {
                        logFile.Flush();
                    }
                }
            }
        /// <summary>
        /// Predicts a result for <paramref name="teamName"/> from the summed statistics of its last
        /// <paramref name="matchesToAnalyzeCount"/> matches, using the supplied trained model.
        /// </summary>
        /// <param name="teamName">Team whose recent statistics are aggregated.</param>
        /// <param name="matchesToAnalyzeCount">Number of recent matches requested from the stats service.</param>
        /// <param name="trainingResult">Provides the trained model and the ML context used to build the prediction engine.</param>
        /// <returns>The team name paired with the model's predicted <c>Result</c>.</returns>
        private TeamResultPrediction Predict(string teamName, int matchesToAnalyzeCount, TrainingResult trainingResult)
        {
            var recentMatches = _soccerTeamLastStatsDataService.GetTeamLastStats(teamName, matchesToAnalyzeCount);

            // Collapse the recent matches into one feature row: every field is the sum over those matches.
            var aggregated = new StatsResult
            {
                BallPossession = recentMatches.Sum(match => match.BallPossession),
                AttacksOnGoal  = recentMatches.Sum(match => match.AttacksOnGoal),
                ShotsOnGoal    = recentMatches.Sum(match => match.ShotsOnGoal),
                ShotsOutGoal   = recentMatches.Sum(match => match.ShotsOutGoal),
                Corners        = recentMatches.Sum(match => match.Corners),
                Passes         = recentMatches.Sum(match => match.Passes),
                AccuratePasses = recentMatches.Sum(match => match.AccuratePasses),
                Blocks         = recentMatches.Sum(match => match.Blocks),
                Points         = recentMatches.Sum(match => match.ResultPoints),
            };

            // NOTE(review): a new prediction engine is built on every call; consider caching it if this is hot.
            var engine = trainingResult.Model
                .CreatePredictionEngine<StatsResult, StatsResultPrediction>(trainingResult.MlContext);

            StatsResultPrediction predicted = engine.Predict(aggregated);

            return new TeamResultPrediction
            {
                TeamName = teamName,
                Result   = predicted.Result,
            };
        }
Пример #22
0
        /// <summary>
        /// Runs an MNIST training experiment for each layer-size configuration in <paramref name="sizesArray"/>,
        /// repeating every configuration <paramref name="repetitions"/> times. For each configuration it writes
        /// per-run network JSON, the options used, a '|'-separated CSV of per-epoch evaluation/error values, and
        /// evaluation/error plots (the evaluation plot is skipped for encoders).
        /// </summary>
        /// <param name="sizesArray">Layer-size configurations; one experiment is run per entry.</param>
        /// <param name="options">
        /// Base training options. Its <c>ErrorThreshold</c> is set to 0 on the caller's instance to disable
        /// early learning end.
        /// </param>
        /// <param name="repetitions">Number of training runs per configuration; must be positive.</param>
        /// <param name="logPath">
        /// Output path prefix using '/' separators. Its first segment is the output root directory, which is
        /// passed to <c>ClearDirectory</c> when it already exists and is then (re)created.
        /// </param>
        /// <exception cref="ArgumentNullException">A reference argument is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="repetitions"/> is not positive.</exception>
        public static void RunSizeExperiment(
            int[][] sizesArray,
            NeuralNetworkOptions options,
            int repetitions,
            string logPath
            )
        {
            if (sizesArray == null)
            {
                throw new ArgumentNullException(nameof(sizesArray));
            }
            if (options == null)
            {
                throw new ArgumentNullException(nameof(options));
            }
            if (logPath == null)
            {
                throw new ArgumentNullException(nameof(logPath));
            }
            // Validate before touching the file system, so an invalid call never clears the output directory.
            // Without at least one run there is no epoch data to log (trainingResponses[0] is read below).
            if (repetitions <= 0)
            {
                throw new ArgumentOutOfRangeException(nameof(repetitions), repetitions, "At least one repetition is required.");
            }

            // Disable early learning end.
            options.ErrorThreshold = 0;

            var mainDir = logPath.Split('/')[0];
            if (Directory.Exists(mainDir))
            {
                ClearDirectory(mainDir);
            }
            Directory.CreateDirectory(mainDir);

            // The base options are identical for every configuration, so dump them once.
            File.WriteAllText(logPath + ".log", JsonConvert.SerializeObject(options, Formatting.Indented));

            foreach (var sizes in sizesArray)
            {
                var serializedSizes = JsonConvert.SerializeObject(sizes);

                Console.WriteLine($"Running experiment for {serializedSizes}");

                var trainingOptions = new NeuralNetworkOptions(
                    options.LearningRate,
                    options.Momentum,
                    options.ErrorThreshold,
                    sizes,
                    options.TrainingPath,
                    options.ValidationPath,
                    options.TestPath,
                    options.MaxEpochs,
                    options.IsVerbose,
                    options.BatchSize,
                    options.ActivationFunction,
                    options.InitialWeightsRange,
                    true,
                    options.NormalizeInput,
                    options.IsEncoder,
                    options.Lambda, options.TakeBest
                    );

                var fileName = logPath + "_" + serializedSizes;

                // Record the options actually used for this configuration (they differ from the base options).
                File.WriteAllText(fileName + ".log", JsonConvert.SerializeObject(trainingOptions, Formatting.Indented));

                var runLogPath = logPath + "/" + serializedSizes;
                Directory.CreateDirectory(runLogPath);

                // Gather data: train 'repetitions' independent networks and persist each one.
                var trainingResponses = new TrainingResult[repetitions];
                for (var run = 0; run < repetitions; run++)
                {
                    var trainingResponse = MnistTrainer.TrainOnMnist(trainingOptions);
                    trainingResponses[run] = trainingResponse;

                    File.WriteAllText($"{runLogPath}/{serializedSizes}_{run}.json", trainingResponse.NeuralNetwork.ToJson());
                }

                // Log data: the "sep=|" first line tells spreadsheet tools which delimiter the file uses.
                var csv = new StringBuilder("sep=|");
                csv.AppendLine();
                csv.Append("epoch");
                for (var run = 0; run < trainingResponses.Length; run++)
                {
                    csv.Append($"|evaluation_{run}|error_{run}");
                }
                csv.AppendLine();

                // NOTE(review): assumes every run reports the same epoch count as the first one.
                for (var epoch = 0; epoch < trainingResponses[0].Epochs; epoch++)
                {
                    csv.Append(epoch);
                    foreach (var result in trainingResponses)
                    {
                        csv.Append($"|{result.Evaluations[epoch].Percentage}|{result.EpochErrors[epoch]}");
                    }
                    csv.AppendLine();
                }
                File.WriteAllText(fileName + ".csv", csv.ToString());

                // Plots: an evaluation plot is not generated for encoders.
                if (!options.IsEncoder)
                {
                    ExperimentVisualization.GenerateEvaluationPlot(trainingResponses, fileName, fileName);
                }
                ExperimentVisualization.GenerateErrorPlot(trainingResponses, fileName, fileName);
            }
        }