コード例 #1
0
ファイル: PTB.cs プロジェクト: vishalbelsare/AleaTK
            /// <summary>
            /// Runs the model once over <paramref name="data"/>, batch by batch, and returns
            /// <c>exp(costs / iters)</c>: the accumulated per-sequence cost divided by the number
            /// of unrolled steps processed, exponentiated (the perplexity).
            /// </summary>
            /// <param name="data">
            /// Token ids to iterate over. They are split into batches by
            /// <c>Data.Iterator(data, Config.NumSteps, Config.BatchSize)</c>.
            /// </param>
            /// <param name="learningRate">
            /// Passed to <c>Optimizer.Optimize</c>. It is only used when <c>IsTraining</c> is true.
            /// </param>
            /// <param name="verbose">
            /// When true, prints a progress line whenever <c>step % (epochSize / 10) == 10</c>.
            /// Such a line can only appear when <c>epochSize / 10 &gt; 10</c>.
            /// </param>
            /// <returns>
            /// <c>Math.Exp(costs / iters)</c>. If the iterator yields no batch, <c>iters</c> stays 0
            /// and the result is <see cref="double.NaN"/>.
            /// </returns>
            /// <remarks>
            /// When <c>Profiling</c> is set, every step is printed. The loop also stops after the batch
            /// with step index 6, so at most 7 batches are processed.
            /// </remarks>
            public double RunEpoch(int[] data, double learningRate = 1.0, bool verbose = false)
            {
                var cfg        = Config;
                var isTraining = IsTraining;
                var epochSize  = (data.Length / cfg.BatchSize - 1) / cfg.NumSteps;
                // Steps between verbose progress lines. This is 0 (or negative) for short epochs.
                // Using it unguarded as a modulo divisor threw DivideByZeroException.
                var printEvery = epochSize / 10;
                var time       = Stopwatch.StartNew();
                var costs      = 0.0;
                var iters      = 0;
                var step       = 0;

                foreach (var batch in Data.Iterator(data, cfg.NumSteps, cfg.BatchSize))
                {
                    Optimizer.AssignTensor(Inputs, batch.Inputs.AsTensor());
                    Optimizer.AssignTensor(Targets, batch.Targets.AsTensor());

                    // The first batch starts from reset states. Later batches call CopyStates(),
                    // presumably carrying the previous batch's final state forward (verify against
                    // the CopyStates implementation).
                    if (step == 0)
                    {
                        ResetStates();
                    }
                    else
                    {
                        CopyStates();
                    }

                    Optimizer.Forward();

                    if (isTraining)
                    {
                        Optimizer.Backward();
                        Optimizer.Optimize(learningRate);
                    }

                    var loss = Optimizer.GetTensor(Loss.Loss).ToScalar();
                    var cost = loss / cfg.BatchSize;
                    costs += cost;
                    iters += cfg.NumSteps;

                    if (Profiling || (verbose && printEvery > 0 && step % printEvery == 10))
                    {
                        var perplexity = Math.Exp(costs / iters);
                        var wps        = (iters * cfg.BatchSize) / time.Elapsed.TotalSeconds;

                        Console.WriteLine($"{step:D4}: {step * 1.0 / epochSize:F3} perplexity: {perplexity:F3} speed:{wps:F0} wps cost: {cost:F3}");
                    }

                    // When profiling, only a handful of batches is needed for timing.
                    if (Profiling && step > 5)
                    {
                        break;
                    }

                    step++;
                }

                // If no batch was processed, iters == 0 and this evaluates to NaN.
                return Math.Exp(costs / iters);
            }