Example #1
0
        /// <summary>
        /// Builds an instance from an existing model file: checks the arguments, initializes
        /// the tensor devices when the CUDA architecture is requested, loads the model from
        /// <paramref name="modelFilePath"/> and finally applies the batch size.
        /// </summary>
        /// <param name="modelFilePath">Path handed to <c>Load</c>.</param>
        /// <param name="batchSize">Batch size handed to <c>CheckParameters</c> and <c>SetBatchSize</c>.</param>
        /// <param name="archType">Architecture selector; device initialization only runs for <c>GPU_CUDA</c>.</param>
        /// <param name="deviceIds">Device ids; stored in <c>m_deviceIds</c> and, for CUDA, passed to <c>TensorAllocator.InitDevices</c>.</param>
        public AttentionSeq2Seq(string modelFilePath, int batchSize, ArchTypeEnums archType, int[] deviceIds)
        {
            // Runs before any device or model state is touched
            // (presumably validates the arguments and throws on bad input -- confirm in CheckParameters).
            CheckParameters(batchSize, archType, deviceIds);

            if (archType == ArchTypeEnums.GPU_CUDA)
            {
                // Devices are only initialized for the CUDA architecture; other architectures skip this step.
                TensorAllocator.InitDevices(deviceIds);
                // Note: receives the NUMBER of devices, not the ids themselves.
                SetDefaultDeviceIds(deviceIds.Length);
            }

            // Stored before Load() runs
            // NOTE(review): Load/InitWeightsFactory may rely on these fields being set -- keep this ordering.
            m_archType  = archType;
            m_deviceIds = deviceIds;

            Load(modelFilePath);
            InitWeightsFactory();

            // Applied last, after the model has been loaded and the weights factory initialized.
            SetBatchSize(batchSize);
        }
Example #2
0
        /// <summary>
        /// Applies one optimizer step to every weight tensor in <paramref name="model"/>.
        /// Weights are grouped by <c>DeviceId</c>; the groups are updated in parallel while
        /// the weights inside a group are updated sequentially.
        /// </summary>
        /// <remarks>
        /// The first time a weight name is seen, two zero-filled Float32 tensors with the
        /// weight's sizes are allocated on the weight's device and stored in
        /// <c>m_cacheName2V</c> and <c>m_cacheName2M</c>
        /// (presumably the optimizer's running statistics -- confirm in UpdateWeightsTensor).
        /// </remarks>
        /// <param name="model">Weights to update. Names must be unique within the list.</param>
        /// <param name="batchSize">Forwarded unchanged to <c>UpdateWeightsTensor</c>.</param>
        /// <param name="step_size">Forwarded unchanged to <c>UpdateWeightsTensor</c>.</param>
        /// <param name="regc">Forwarded unchanged to <c>UpdateWeightsTensor</c>.</param>
        /// <param name="iter">Forwarded unchanged to <c>UpdateWeightsTensor</c>.</param>
        /// <exception cref="ArgumentException">
        /// Two weights share the same name. A weight that is not a <c>WeightTensor</c> also
        /// raises <see cref="ArgumentException"/>, but from inside the parallel loop, so it
        /// reaches the caller wrapped in an <see cref="AggregateException"/>.
        /// </exception>
        public void UpdateWeights(List <IWeightTensor> model, int batchSize, float step_size, float regc, int iter)
        {
            Dictionary <int, List <IWeightTensor> > id2Models = new Dictionary <int, List <IWeightTensor> >();
            HashSet <string> setWeightsName = new HashSet <string>();

            foreach (var item in model)
            {
                // HashSet.Add returns false when the name is already present,
                // so one call both records the name and detects duplicates.
                if (setWeightsName.Add(item.Name) == false)
                {
                    throw new ArgumentException($"Found duplicated weights name '{item.Name}'");
                }

                // Group the weights by device with a single dictionary lookup.
                List <IWeightTensor> deviceWeights;
                if (id2Models.TryGetValue(item.DeviceId, out deviceWeights) == false)
                {
                    deviceWeights = new List <IWeightTensor>();
                    id2Models.Add(item.DeviceId, deviceWeights);
                }
                deviceWeights.Add(item);

                // Lazily create the per-weight cache tensors on the weight's own device.
                if (m_cacheName2V.ContainsKey(item.Name) == false)
                {
                    var allocator = TensorAllocator.Allocator(item.DeviceId);

                    var cacheV = new Tensor(allocator, DType.Float32, item.Sizes);
                    m_cacheName2V[item.Name] = cacheV;
                    Ops.Fill(cacheV, 0.0f);

                    var cacheM = new Tensor(allocator, DType.Float32, item.Sizes);
                    m_cacheName2M[item.Name] = cacheM;
                    Ops.Fill(cacheM, 0.0f);

                    Logger.WriteLine($"Added weight '{item.Name}' to optimizer.");
                }
            }

            // One parallel worker per device group; weights of the same device are processed in order.
            Parallel.ForEach(id2Models, kv =>
            {
                foreach (var item in kv.Value)
                {
                    var m = item as WeightTensor;
                    if (m == null)
                    {
                        // Fail with a clear message instead of handing null to UpdateWeightsTensor.
                        throw new ArgumentException($"Weight '{item.Name}' is not a WeightTensor and cannot be updated.");
                    }

                    UpdateWeightsTensor(m, batchSize, step_size, m_clipval, regc, iter);
                }
            });
        }
        /// <summary>
        /// Runs one training epoch over <c>TrainCorpus</c>. Sentence pairs are accumulated until
        /// <c>TrainCorpus.BatchSize</c> pairs are available; each full batch is then split into
        /// slices of <c>m_batchSize</c> pairs, one slice per entry of <c>m_deviceIds</c>, and the
        /// slices are processed in parallel (encode, decode, backward). Afterwards gradients are
        /// synchronized, parameters are updated and gradients are cleared.
        /// </summary>
        /// <remarks>
        /// <para>Pairs left over at the end of the corpus that do not fill a whole batch are not trained on.</para>
        /// <para>
        /// A batch whose summed cost is NaN or infinite is logged and not counted in
        /// <c>processedLine</c> or the running cost; the parameter update still runs for it.
        /// Progress events and periodic checkpoints are only triggered by batches with a valid
        /// cost, so a stale <c>processedLine</c> value cannot re-raise <c>IterationDone</c> or
        /// re-save the model right after an invalid batch.
        /// </para>
        /// <para>The model is saved every 1000 counted batches and once more when the epoch finishes.</para>
        /// </remarks>
        /// <param name="ep">Epoch number; only used for logging and for the progress event.</param>
        /// <param name="learningRate">Base learning rate handed to <c>UpdateParameters</c>.</param>
        private void TrainEp(int ep, float learningRate)
        {
            int      processedLine = 0;
            DateTime startDateTime = DateTime.Now;

            double         costInTotal           = 0.0;
            long           srcWordCnts           = 0;
            long           tgtWordCnts           = 0;
            double         avgCostPerWordInTotal = 0.0;
            List <SntPair> sntPairs = new List <SntPair>();

            TensorAllocator.FreeMemoryAllDevices();

            Logger.WriteLine($"Base learning rate is '{learningRate}' at epoch '{ep}'");

            //Clean caches of parameter optimization
            Logger.WriteLine($"Cleaning cache of weights optmiazation.'");
            CleanWeightCache();

            Logger.WriteLine($"Start to process training corpus.");
            foreach (var sntPair in TrainCorpus)
            {
                sntPairs.Add(sntPair);

                //Keep accumulating until a full batch is available
                if (sntPairs.Count != TrainCorpus.BatchSize)
                {
                    continue;
                }

                List <List <string> > srcSnts = new List <List <string> >();
                List <List <string> > tgtSnts = new List <List <string> >();

                var slen = 0;
                var tlen = 0;
                for (int j = 0; j < TrainCorpus.BatchSize; j++)
                {
                    List <string> srcSnt = new List <string>();

                    //Add BOS and EOS tags to source sentences
                    srcSnt.Add(m_START);
                    srcSnt.AddRange(sntPairs[j].SrcSnt);
                    srcSnt.Add(m_END);

                    srcSnts.Add(srcSnt);
                    tgtSnts.Add(sntPairs[j].TgtSnt.ToList());

                    //Source length includes the two added tags; target length is the raw token count
                    slen += srcSnt.Count;
                    tlen += sntPairs[j].TgtSnt.Length;
                }
                srcWordCnts += slen;
                tgtWordCnts += tlen;

                Reset();

                //Copy weights from weights kept in default device to all other devices
                SyncWeights();

                float cost = 0.0f;
                Parallel.For(0, m_deviceIds.Length, i =>
                {
                    IComputeGraph computeGraph = CreateComputGraph(i);

                    //Bi-directional encoding input source sentences (device i gets the i-th slice of the batch)
                    IWeightMatrix encodedWeightMatrix = Encode(computeGraph, srcSnts.GetRange(i * m_batchSize, m_batchSize), m_biEncoder[i], m_srcEmbedding[i]);

                    //Generate output decoder sentences
                    List <List <string> > predictSentence;
                    float lcost = Decode(tgtSnts.GetRange(i * m_batchSize, m_batchSize), computeGraph, encodedWeightMatrix, m_decoder[i], m_decoderFFLayer[i],
                                         m_tgtEmbedding[i], out predictSentence);

                    //cost is shared by all workers, so the accumulation must be serialized
                    lock (locker)
                    {
                        cost += lcost;
                    }
                    //Calculate gradients
                    computeGraph.Backward();
                });

                //Sum up gradients in all devices, and keep them in default device for parameters optimization
                SyncGradient();

                bool isValidCost = float.IsInfinity(cost) == false && float.IsNaN(cost) == false;
                if (isValidCost)
                {
                    processedLine        += TrainCorpus.BatchSize;
                    costInTotal          += cost;
                    avgCostPerWordInTotal = costInTotal / tgtWordCnts;
                }
                else
                {
                    Logger.WriteLine($"Invalid cost value.");
                }

                //Optimize parameters
                float avgAllLR = UpdateParameters(learningRate, TrainCorpus.BatchSize);

                //Clear gradient over all devices
                ClearGradient();

                //processedLine only advances on valid batches. Without this guard an invalid batch would
                //re-test the previous value (including 0) and could report or checkpoint a second time.
                if (isValidCost)
                {
                    //Copy the delegate so a concurrent unsubscribe cannot null it between check and call
                    var iterationDone = IterationDone;
                    if (iterationDone != null && processedLine % (100 * TrainCorpus.BatchSize) == 0)
                    {
                        iterationDone(this, new CostEventArg()
                        {
                            AvgLearningRate           = avgAllLR,
                            CostPerWord               = cost / tlen,
                            avgCostInTotal            = avgCostPerWordInTotal,
                            Epoch                     = ep,
                            ProcessedSentencesInTotal = processedLine,
                            ProcessedWordsInTotal     = srcWordCnts * 2 + tgtWordCnts,
                            StartDateTime             = startDateTime
                        });
                    }

                    //Save model every 1000 batches
                    if (processedLine % (TrainCorpus.BatchSize * 1000) == 0)
                    {
                        Save();
                        TensorAllocator.FreeMemoryAllDevices();
                    }
                }

                sntPairs.Clear();
            }

            Logger.WriteLine($"Epoch '{ep}' took '{DateTime.Now - startDateTime}' time to finish.");

            Save();
        }