private void CreateCheckPoint(ParallelCorpus validCorpus, List <IMetric> metrics, IModelMetaData modelMetaData, Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice, double avgCostPerWordInTotal)
        {
            if (validCorpus != null)
            {
                // The valid corpus is provided, so evaluate the model.
                if (RunValid(validCorpus, ForwardOnSingleDevice, metrics))
                {
                    SaveModel(modelMetaData);
                }
            }
            else if (m_avgCostPerWordInTotalInLastEpoch > avgCostPerWordInTotal)
            {
                // No valid corpus is available, so save the model whenever the average cost per word decreases
                SaveModel(modelMetaData);
            }

            TensorAllocator.FreeMemoryAllDevices();
        }
Example #2
        // Builds a mask tensor whose elements are 1.0 with probability 'prob'
        // and 0.0 otherwise (e.g. dropout-style noise).
        private Tensor BuildRandomTensor(int rows, int columns, double prob)
        {
            float[] weights = new float[rows * columns];
            for (int i = 0; i < weights.Length; i++)
            {
                double r = rnd.NextDouble();
                if (r < prob)
                {
                    weights[i] = 1.0f;
                }
            }

            Tensor noise = new Tensor(TensorAllocator.Allocator(deviceId), DType.Float32, rows, columns);

            noise.SetElementsAsFloat(weights);

            return noise;
        }
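A hedged usage sketch: applying the generated mask as dropout-style noise. The tensor names and the keep probability below are illustrative, and the 1/keepProb rescaling of inverted dropout is an assumption not shown in the original:

        // Hypothetical usage (all names here are illustrative):
        double keepProb = 0.8;
        Tensor dropoutMask = BuildRandomTensor(rows: 128, columns: 256, prob: keepProb);
        // Elementwise multiply, following the Ops.Mul(result, a, b) pattern in Example #5;
        // 'activations' and 'maskedActivations' are assumed pre-allocated 128 x 256 tensors.
        Ops.Mul(maskedActivations, activations, dropoutMask);
        // For inverted dropout one would also scale by 1/keepProb (assumption).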
Example #3
        public static (Tile tile, Tensor <int> input, Tensor <int> res) TransposeTile()
        {
            var input = new Symbol("input", 0);

            input.Shape = new Shape(ElementKind.Int32, new[] { 16, 16 });
            var res = new Symbol("res", 1);

            res.Shape = new Shape(ElementKind.Int32, new[] { 16, 16 });

            var i = new Index("i");

            i.Ranges = new[] { new Range(0, 16) };

            var j = new Index("j");

            j.Ranges = new[] { new Range(0, 16) };

            var statement = new Statement(
                StatementKind.Assign,
                // left
                new Element(res,
                            new[] { new IndexExpression(j), new IndexExpression(i) }),
                // right
                new ElementExpression(
                    new Element(input, new[] { new IndexExpression(i), new IndexExpression(j), })));

            var tile = new Tile("transpose", new[] { statement });

            var allocator = new TensorAllocator();
            var tinput    = (Tensor <int>)allocator.Create(input.Shape, "input");
            var s         = tinput.Buffer.Span;

            for (int m = 0; m < 16; m++)
            {
                for (int n = 0; n < 16; n++)
                {
                    s[m * 16 + n] = n - m;
                }
            }

            var tres = (Tensor <int>)allocator.Create(res.Shape, "res");

            return (tile, tinput, tres);
        }
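A hedged check of what the tile should produce: the statement assigns res[j, i] = input[i, j], and the input was filled with s[m * 16 + n] = n - m, so after the tile runs the output should hold m - n at [m, n]. The runner that executes the tile is library-specific and not shown; the loop below only computes the expected values directly in C#:

            // Expected output, computed directly for comparison:
            // res[m, n] == input[n, m] == m - n.
            var expected = new int[16 * 16];
            for (int m = 0; m < 16; m++)
            {
                for (int n = 0; n < 16; n++)
                {
                    expected[m * 16 + n] = m - n;
                }
            }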
Example #4
    public void TestAddTensorTensor()
    {
        TensorAllocator.InitDevices(ProcessorTypeEnums.CPU, new int[] { 0 });

        var graph = new ComputeGraphTensor(new WeightTensorFactory(), 0, true);

        var tensorA = new WeightTensor(new long[2] {
            2, 2
        }, 1, 0, name: "tensorA", isTrainable: true);
        var tensorB = new WeightTensor(new long[2] {
            2, 2
        }, 2, 0, name: "tensorB", isTrainable: true);

        var tensorSum = graph.Add(tensorA, tensorB);

        float v = tensorSum.GetWeightAt(new long[] { 1, 1 });

        Assert.IsTrue(v == 3.0f);
    }
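A hedged extension of the same check over every element; it assumes the second constructor argument (1 for tensorA, 2 for tensorB) is a constant fill value, by analogy with the (rows, columns, c, deviceId) constructor in Example #8:

        // Check all four elements, not just [1, 1] (assumes constant fill):
        for (long r = 0; r < 2; r++)
        {
            for (long c = 0; c < 2; c++)
            {
                Assert.IsTrue(tensorSum.GetWeightAt(new long[] { r, c }) == 3.0f);
            }
        }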
Example #5
 // Accumulates the elementwise product w * g into this tensor's gradient,
 // allocating the gradient lazily on first use. With 'inPlace', the existing
 // gradient is overwritten instead of accumulated into.
 public void AddMulGradient(Tensor w, Tensor g, bool inPlace = false)
 {
     if (m_TGradient == null)
     {
         m_allocator = TensorAllocator.Allocator(DeviceId);
         m_TGradient = new Tensor(m_allocator, DType.Float32, w.Sizes);
         Ops.Mul(m_TGradient, w, g);
     }
     else
     {
         if (inPlace)
         {
             Ops.Mul(m_TGradient, w, g);
         }
         else
         {
             Ops.AddMul(m_TGradient, m_TGradient, w, g);
         }
     }
 }
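A hedged sketch of a typical call site, inside a backward closure for an elementwise product y = a * b; the tensor names are illustrative, and resTensor.TGradient is assumed to already hold the upstream gradient dL/dy:

 // dL/da += b * dL/dy and dL/db += a * dL/dy (names are illustrative):
 aTensor.AddMulGradient(bTensor.TWeight, resTensor.TGradient);
 bTensor.AddMulGradient(aTensor.TWeight, resTensor.TGradient);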
Example #6
    public void TestAddSubGradients()
    {
        int batchSize = 5;
        int vocabSize = 20;

        TensorAllocator.InitDevices(ProcessorTypeEnums.CPU, new int[] { 0 });

        var graph = new ComputeGraphTensor(new WeightTensorFactory(), 0, true);

        var tensorA = new WeightTensor(new long[2] {
            batchSize, vocabSize
        }, 1, 0, name: "tensorA", isTrainable: true);
        var tensorB = new WeightTensor(new long[2] {
            batchSize, vocabSize
        }, 1, 0, name: "tensorB", isTrainable: true);
        var tensorIdx = BuildRandomLabelTensor(batchSize, vocabSize, "tensorIdx");

        var tensorANeg    = graph.Mul(tensorA, -1.0f);
        var tensorANegSum = graph.Add(tensorANeg, 100.0f);
        var tensorSub     = graph.Sub(100.0f, tensorB);

        float v1 = tensorANegSum.GetWeightAt(new long[] { 1, 1 });
        float v2 = tensorSub.GetWeightAt(new long[] { 1, 1 });

        Assert.IsTrue(v1 == v2);

        var softmax1 = graph.Softmax(tensorANegSum);
        var softmax2 = graph.Softmax(tensorSub);

        graph.CrossEntropyLoss(softmax1, tensorIdx);
        graph.CrossEntropyLoss(softmax2, tensorIdx);

        graph.Backward();

        float gA = tensorA.GetGradientAt(new long[] { 1, 1 });
        float gB = tensorB.GetGradientAt(new long[] { 1, 1 });

        Assert.IsTrue(gA == gB);
    }
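Both branches compute 100 - x elementwise (once as Mul(-1) followed by Add(100), once as Sub(100, .)), so their forward values agree; and since both feed identical softmax and cross-entropy layers with the same label tensor, each path contributes the same factor of -1 on the way back, making the gradients on tensorA and tensorB equal.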
Example #7
    public void TestCrossEntropyLoss()
    {
        int batchSize = 5;
        int vocabSize = 20;

        TensorAllocator.InitDevices(ProcessorTypeEnums.CPU, new int[] { 0 });
        var graph = new ComputeGraphTensor(new WeightTensorFactory(), 0, true);

        var tensorA = BuildRandomTensor(shape: new long[2] {
            batchSize, vocabSize
        }, name: "tensorA", isTrainable: true);
        var tensorIdx = BuildRandomLabelTensor(batchSize, vocabSize, "tensorIdx");

        var probs = graph.Softmax(tensorA);

        float[] softmaxWeights = probs.ToWeightArray();
        graph.CrossEntropyLoss(probs, tensorIdx);

        graph.Backward();

        // Check if the gradients are correct
        for (int i = 0; i < batchSize; i++)
        {
            for (int j = 0; j < vocabSize; j++)
            {
                float softmaxWeight = softmaxWeights[i * vocabSize + j];
                float tensorAGrad   = tensorA.GetGradientAt(new long[] { i, j });

                if (tensorIdx.GetWeightAt(new long[] { i, 0 }) != j)
                {
                    Assert.IsTrue(Math.Round(tensorAGrad, 5) == Math.Round(softmaxWeight, 5));
                }
                else
                {
                    Assert.IsTrue(Math.Round(tensorAGrad, 5) == Math.Round(softmaxWeight - 1.0f, 5));
                }
            }
        }
    }
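This loop verifies the standard softmax-plus-cross-entropy gradient identity: for logits z with probabilities p = softmax(z) and gold label y, the gradient is dL/dz_j = p_j - 1[j = y], i.e. the softmax output itself, minus one at the gold index.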
Example #8
        public WeightTensor(int rows, int columns, float c, int deviceId)
        {
            DeviceId = deviceId;
            Rows     = rows;
            Columns  = columns;

            var allocator = TensorAllocator.Allocator(deviceId);

            TGradient = new Tensor(allocator, DType.Float32, Rows, Columns);
            Ops.Fill(TGradient, 0.0f);

            TCash = new Tensor(allocator, DType.Float32, Rows, Columns);
            Ops.Fill(TCash, 0.0f);

            TLrW = new Tensor(allocator, DType.Float32, Rows, Columns);
            Ops.Fill(TLrW, 0.0f);

            TWeight = new Tensor(allocator, DType.Float32, Rows, Columns);
            Ops.Fill(TWeight, c);
        }
Example #9
        public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, bool normal = false)
        {
            Name        = name;
            DeviceId    = deviceId;
            IsTrainable = isTrainable;
            m_allocator = TensorAllocator.Allocator(DeviceId);
            Sizes       = sizes;

            if (normal)
            {
                var     n      = Rows * Columns;
                float[] weight = new float[n];

                // He-style scale, sqrt(2 / fan_in), with Rows as the fan-in
                var scale = (float)Math.Sqrt(2.0 / Rows);
                for (int i = 0; i < n; i++)
                {
                    weight[i] = RandomGenerator.NormalRandom(0.0f, scale);
                }

                TWeight = Tensor.FromArray(m_allocator, weight).View(Sizes);
            }
        }
Example #10
        public WeightTensor(int rows, int columns, int deviceId, bool keepCache = true, bool normal = false)
        {
            DeviceId = deviceId;
            Rows     = rows;
            Columns  = columns;
            var n = rows * columns;

            float[] weight = new float[n];

            // Small scale derived from the total element count; 'normal'
            // overrides it with a fixed 0.08 instead.
            var scale = (float)Math.Sqrt(1.0 / (rows * columns));

            if (normal)
            {
                scale = 0.08f;
            }
            for (int i = 0; i < n; i++)
            {
                weight[i] = RandomGenerator.NormalRandom(0.0f, scale);
            }

            var allocator = TensorAllocator.Allocator(deviceId);

            TGradient = new Tensor(allocator, DType.Float32, Rows, Columns);
            Ops.Fill(TGradient, 0.0f);

            if (keepCache)
            {
                TCache = new Tensor(allocator, DType.Float32, Rows, Columns);
                Ops.Fill(TCache, 0.0f);

                TLrW = new Tensor(allocator, DType.Float32, Rows, Columns);
                Ops.Fill(TLrW, 0.0f);
            }

            TWeight = Tensor.FromArray(allocator, weight).View(Rows, Columns);
        }
Example #11
        public static (Tile tile, Tensor <int> input, Tensor <int> res) SimpleMaxTile()
        {
            var input = new Symbol("input", 0);

            input.Shape = new Shape(ElementKind.Int32, new[] { 16 });
            var res = new Symbol("res", 1);

            res.Shape = new Shape(ElementKind.Int32, new[] { 1 });

            var i = new Index("i");

            i.Ranges = new[] { new Range(0, 16) };

            var statement = new Statement(
                StatementKind.Max,
                // left
                new Element(res,
                            new[] { new IndexExpression(0) }),
                // right
                new ElementExpression(
                    new Element(input, new[] { new IndexExpression(i) })));

            var tile = new Tile("max", new[] { statement });

            var allocator = new TensorAllocator();
            var tinput    = (Tensor <int>)allocator.Create(input.Shape, "input");
            var s         = tinput.Buffer.Span;

            for (int n = 0; n < 16; n++)
            {
                s[n] = n + 2;
            }

            var tres = (Tensor <int>)allocator.Create(res.Shape, "res");

            return (tile, tinput, tres);
        }
Example #12
 public BaseSeq2SeqFramework(int[] deviceIds, ProcessorTypeEnums processorType, string modelFilePath, float memoryUsageRatio = 0.9f, string[] compilerOptions = null)
 {
     m_deviceIds = deviceIds;
     m_modelFilePath = modelFilePath;
     TensorAllocator.InitDevices(processorType, m_deviceIds, memoryUsageRatio, compilerOptions);
 }
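A hedged construction sketch; the device ids, file path, and memory ratio below are illustrative, and if BaseSeq2SeqFramework is abstract in a given version, a derived trainer would be constructed instead:

 // Two data-parallel replicas on CPU, 80% of device memory for tensors.
 // If the class is abstract in your version, instantiate a subclass instead.
 var framework = new BaseSeq2SeqFramework(
     new int[] { 0, 1 }, ProcessorTypeEnums.CPU, "model.bin", 0.8f);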
Example #13
        public IWeightMatrix Softmax(IWeightMatrix w, bool bp = true)
        {
            WeightTensor m   = w as WeightTensor;
            var          res = weightTensorFactory.CreateWeightTensor(m.Rows, m.Columns, deviceId, new Tensor(TensorAllocator.Allocator(deviceId), DType.Float32, m.Rows, m.Columns), bp);

            Ops.Softmax(res.TWeight, m.TWeight);

            if (this.needs_backprop && bp)
            {
                Action backward = () =>
                {
                    Ops.SoftmaxGrad(m.TGradient, res.TGradient, res.TWeight);
                    res.Dispose();
                };
                this.backprop.Add(backward);
            }

            return res;
        }
Example #14
 public BaseSeq2SeqFramework(int[] deviceIds, ProcessorTypeEnums processorType, string modelFilePath)
 {
     m_deviceIds     = deviceIds;
     m_modelFilePath = modelFilePath;
     TensorAllocator.InitDevices(processorType, m_deviceIds);
 }
Example #15
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int      processedLineInTotal   = 0;
            DateTime startDateTime          = DateTime.Now;
            DateTime lastCheckPointDateTime = DateTime.Now;
            double   costInTotal            = 0.0;
            long     srcWordCntsInTotal     = 0;
            long     tgtWordCntsInTotal     = 0;
            double   avgCostPerWordInTotal  = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
                    // Copy weights from the default device to all other devices
                    CopyWeightsFromDefaultDeviceToAllOtherDevices();

                    int  batchSplitFactor    = 1;
                    bool runNetworkSucceeded = false;

                    while (!runNetworkSucceeded)
                    {
                        try
                        {
                            (float cost, int sWordCnt, int tWordCnt, int processedLine) = RunNetwork(ForwardOnSingleDevice, sntPairBatchs, batchSplitFactor);
                            processedLineInTotal += processedLine;
                            srcWordCntsInTotal   += sWordCnt;
                            tgtWordCntsInTotal   += tWordCnt;

                            // Sum gradients from all devices into the default device for parameter optimization
                            SumGradientsToTensorsInDefaultDevice();

                            // Optimize parameters
                            float lr = learningRate.GetCurrentLearningRate();
                            List <IWeightTensor> models = GetParametersFromDefaultDevice();
                            solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);


                            costInTotal          += cost;
                            avgCostPerWordInTotal = costInTotal / tgtWordCntsInTotal;
                            m_weightsUpdateCount++;
                            if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                            {
                                IterationDone(this, new CostEventArg()
                                {
                                    LearningRate              = lr,
                                    CostPerWord               = cost / tWordCnt,
                                    AvgCostInTotal            = avgCostPerWordInTotal,
                                    Epoch                     = ep,
                                    Update                    = m_weightsUpdateCount,
                                    ProcessedSentencesInTotal = processedLineInTotal,
                                    ProcessedWordsInTotal     = srcWordCntsInTotal + tgtWordCntsInTotal,
                                    StartDateTime             = startDateTime
                                });
                            }

                            runNetworkSucceeded = true;
                        }
                        catch (Exception)
                        {
                            batchSplitFactor *= 2;
                            Logger.WriteLine($"Increasing batch split factor to {batchSplitFactor} and retrying.");

                            if (batchSplitFactor >= sntPairBatchs[0].BatchSize)
                            {
                                Logger.WriteLine("Batch split factor exceeds the batch size; giving up.");
                                throw; // rethrow, preserving the original stack trace
                            }
                        }
                    }

                    // Evaluate the model every hour and save it if it improves.
                    TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
        }
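The while/catch pair above implements adaptive batch splitting: when RunNetwork throws (typically an out-of-memory condition), the split factor is doubled and the same batch is retried with smaller sub-batches, until the factor would reach the batch size, at which point the exception is rethrown.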
Example #16
        public static (Tile, Tensor <float> a, Tensor <float> x, Tensor <float> b, Tensor <float> res) LinearTile()
        {
            // linear 2D layer: res[i] = sum(a[i,j] * x[j] + b[i])

            var a = new Symbol("A", 0);

            a.Shape = new Shape(ElementKind.Float32, new[] { 16, 32 });

            var x = new Symbol("x", 1);

            x.Shape = new Shape(ElementKind.Float32, new[] { 32 });

            var b = new Symbol("b", 2);

            b.Shape = new Shape(ElementKind.Float32, new[] { 16 });

            var res = new Symbol("res", 3);

            res.Shape = new Shape(ElementKind.Float32, new[] { 16 });

            var i = new Index("i");

            i.Ranges = new[] { new Range(0, 16) };

            var j = new Index("j");

            j.Ranges = new[] { new Range(0, 32) };

            var statement = new Statement(StatementKind.AddSum,
                                          // left
                                          new Element(res, new[] { new IndexExpression(i) }),
                                          // right
                                          new ElementExpression(BinaryExpressionKind.Add,
                                                                new ElementExpression(BinaryExpressionKind.Multiply,
                                                                                      new ElementExpression(new Element(a, new[] { new IndexExpression(i), new IndexExpression(j) })),
                                                                                      new ElementExpression(new Element(x, new[] { new IndexExpression(j) }))),
                                                                new ElementExpression(new Element(b, new[] { new IndexExpression(i) }))));

            var tile = new Tile("linear", new[] { statement });

            var allocator = new TensorAllocator();
            var ta        = (Tensor <float>)allocator.Create(a.Shape, "a");
            var sa        = ta.Buffer.Span;

            for (int m = 0; m < 16; m++)
            {
                for (int n = 0; n < 32; n++)
                {
                    sa[m * 32 + n] = (n + m) % 2;
                }
            }

            var tx = (Tensor <float>)allocator.Create(x.Shape, "x");
            var sx = tx.Buffer.Span;

            for (int n = 0; n < 32; n++)
            {
                sx[n] = n % 3;
            }

            var tb = (Tensor <float>)allocator.Create(b.Shape, "b");
            var sb = tb.Buffer.Span;

            for (int m = 0; m < 16; m++)
            {
                sb[m] = m % 5;
            }

            var tres = (Tensor <float>)allocator.Create(res.Shape, "res");

            return (tile, ta, tx, tb, tres);
        }
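As with the transpose example, the expected output can be stated directly in C# (the library runner that executes the tile is not shown). Note that the statement sums a[i,j] * x[j] + b[i] over j, so, read literally, the bias term is accumulated once per j:

            // Expected result, computed directly for comparison (hedged: this
            // mirrors the statement literally, with b[i] added once per j):
            var expectedRes = new float[16];
            for (int m = 0; m < 16; m++)
            {
                float acc = 0f;
                for (int n = 0; n < 32; n++)
                {
                    acc += ((n + m) % 2) * (n % 3) + (m % 5);
                }
                expectedRes[m] = acc;
            }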
Example #17
        public IWeightMatrix SoftmaxM(IWeightMatrix w, bool bp = true)
        {
            WeightTensor m   = w as WeightTensor;
            var          res = weightTensorFactory.CreateWeightTensor(m.Rows, m.Columns, deviceId, new Tensor(TensorAllocator.Allocator(deviceId), DType.Float32, m.Rows, m.Columns), bp);

            Tensor tTmp = new Tensor(TensorAllocator.Allocator(deviceId), DType.Float32, m.Rows, m.Columns);

            // Numerically stable softmax: subtract the row-wise max before exponentiating
            var maxval  = Ops.Max(null, m.TWeight, 1);
            var maxvalM = maxval.Expand(m.Rows, m.Columns);

            Ops.ExpSub2(tTmp, m.TWeight, maxvalM);

            var sumV = Ops.Sum(null, tTmp, 1);
            var sumM = sumV.Expand(m.Rows, m.Columns);

            Ops.Div(res.TWeight, tTmp, sumM);

            maxval.Dispose();
            maxvalM.Dispose();
            sumV.Dispose();
            sumM.Dispose();

            if (this.needs_backprop && bp)
            {
                Action backward = () =>
                {
                    Ops.Mul(tTmp, res.TGradient, res.TWeight);
                    Ops.Add(m.TGradient, m.TGradient, tTmp);

                    var ss  = Ops.Sum(null, tTmp, 1);
                    var ssN = Ops.Neg(null, ss);

                    var ssM = ssN.Expand(m.Rows, m.Columns);
                    Ops.AddMul(m.TGradient, m.TGradient, res.TWeight, ssM);


                    tTmp.Dispose();
                    ss.Dispose();
                    ssM.Dispose();
                    ssN.Dispose();
                };
                this.backprop.Add(backward);
            }
            else
            {
                tTmp.Dispose();
            }

            return res;
        }
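Unlike the Softmax in Example #13, which delegates to Ops.Softmax, SoftmaxM builds the softmax explicitly with max-subtraction for numerical stability: it computes exp(x - rowmax(x)) and divides by the row sum, then hand-implements the standard softmax backward rule grad_x = p * (grad_p - rowsum(p * grad_p)), where * is elementwise.
Example #18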
        internal void TrainOneEpoch(int ep, ParallelCorpus trainCorpus, ParallelCorpus validCorpus, ILearningRate learningRate, AdamOptimizer solver, List <IMetric> metrics, IModelMetaData modelMetaData,
                                    Func <IComputeGraph, List <List <string> >, List <List <string> >, int, bool, float> ForwardOnSingleDevice)
        {
            int      processedLineInTotal   = 0;
            DateTime startDateTime          = DateTime.Now;
            DateTime lastCheckPointDateTime = DateTime.Now;
            double   costInTotal            = 0.0;
            long     srcWordCnts            = 0;
            long     tgtWordCnts            = 0;
            double   avgCostPerWordInTotal  = 0.0;

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Start to process training corpus.");
            List <SntPairBatch> sntPairBatchs = new List <SntPairBatch>();

            foreach (SntPairBatch sntPairBatch in trainCorpus)
            {
                sntPairBatchs.Add(sntPairBatch);
                if (sntPairBatchs.Count == m_deviceIds.Length)
                {
                    float cost          = 0.0f;
                    int   tlen          = 0;
                    int   processedLine = 0;

                    // Copy weights from the default device to all other devices
                    CopyWeightsFromDefaultDeviceToAllOtherDevices();

                    // Run forward and backward on all available processors
                    Parallel.For(0, m_deviceIds.Length, i =>
                    {
                        SntPairBatch sntPairBatch_i = sntPairBatchs[i];
                        // Construct sentences for encoding and decoding
                        List <List <string> > srcTkns = new List <List <string> >();
                        List <List <string> > tgtTkns = new List <List <string> >();
                        int sLenInBatch = 0;
                        int tLenInBatch = 0;
                        for (int j = 0; j < sntPairBatch_i.BatchSize; j++)
                        {
                            srcTkns.Add(sntPairBatch_i.SntPairs[j].SrcSnt.ToList());
                            sLenInBatch += sntPairBatch_i.SntPairs[j].SrcSnt.Length;

                            tgtTkns.Add(sntPairBatch_i.SntPairs[j].TgtSnt.ToList());
                            tLenInBatch += sntPairBatch_i.SntPairs[j].TgtSnt.Length;
                        }

                        float lcost = 0.0f;
                        // Create a new computing graph instance
                        using (IComputeGraph computeGraph_i = CreateComputGraph(i))
                        {
                            // Run forward part
                            lcost = ForwardOnSingleDevice(computeGraph_i, srcTkns, tgtTkns, i, true);
                            // Run backward part and compute gradients
                            computeGraph_i.Backward();
                        }

                        lock (locker)
                        {
                            cost                 += lcost;
                            srcWordCnts          += sLenInBatch;
                            tgtWordCnts          += tLenInBatch;
                            tlen                 += tLenInBatch;
                            processedLineInTotal += sntPairBatch_i.BatchSize;
                            processedLine        += sntPairBatch_i.BatchSize;
                        }
                    });

                    // Sum gradients from all devices into the default device for parameter optimization
                    SumGradientsToTensorsInDefaultDevice();

                    // Optimize parameters
                    float lr = learningRate.GetCurrentLearningRate();
                    List <IWeightTensor> models = GetParametersFromDefaultDevice();
                    solver.UpdateWeights(models, processedLine, lr, m_regc, m_weightsUpdateCount + 1);

                    // Clear gradients on all devices
                    ZeroGradientOnAllDevices();

                    costInTotal          += cost;
                    avgCostPerWordInTotal = costInTotal / tgtWordCnts;
                    m_weightsUpdateCount++;
                    if (IterationDone != null && m_weightsUpdateCount % 100 == 0)
                    {
                        IterationDone(this, new CostEventArg()
                        {
                            LearningRate              = lr,
                            CostPerWord               = cost / tlen,
                            AvgCostInTotal            = avgCostPerWordInTotal,
                            Epoch                     = ep,
                            Update                    = m_weightsUpdateCount,
                            ProcessedSentencesInTotal = processedLineInTotal,
                            ProcessedWordsInTotal     = srcWordCnts + tgtWordCnts,
                            StartDateTime             = startDateTime
                        });
                    }

                    // Evaluate the model every hour and save it if it improves.
                    TimeSpan ts = DateTime.Now - lastCheckPointDateTime;
                    if (ts.TotalHours > 1.0)
                    {
                        CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
                        lastCheckPointDateTime = DateTime.Now;
                    }

                    sntPairBatchs.Clear();
                }
            }

            Logger.WriteLine(Logger.Level.info, ConsoleColor.Green, $"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish. AvgCost = {avgCostPerWordInTotal.ToString("F6")}, AvgCostInLastEpoch = {m_avgCostPerWordInTotalInLastEpoch.ToString("F6")}");

            CreateCheckPoint(validCorpus, metrics, modelMetaData, ForwardOnSingleDevice, avgCostPerWordInTotal);
            m_avgCostPerWordInTotalInLastEpoch = avgCostPerWordInTotal;
        }
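Note the difference from the TrainOneEpoch in Example #15: this variant runs the per-device forward and backward passes inline via Parallel.For and clears gradients with ZeroGradientOnAllDevices, but it has no batch-splitting retry, so an out-of-memory failure on a large batch is not recovered from.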