コード例 #1
0
        /// <summary>
        /// Writes the model to <c>m_modelFilePath</c>.
        /// The file is written in this order:
        /// 1. A <see cref="ModelAttentionMetaData"/> record (hyper-parameters and source/target vocabularies), serialized with <see cref="BinaryFormatter"/>.
        /// 2. The encoder and decoder weights.
        /// 3. The source and target embedding weights.
        /// 4. The decoder feed-forward layer weights.
        /// Each component is taken from its default device.
        /// Before writing, any existing file is copied to "&lt;path&gt;.bak".
        /// Any exception is caught and logged; it is not rethrown to the caller.
        /// </summary>
        public void Save()
        {
            ModelAttentionMetaData tosave = new ModelAttentionMetaData
            {
                Clipval           = m_clipvalue,
                EncoderLayerDepth = m_encoderLayerDepth,
                DecoderLayerDepth = m_decoderLayerDepth,
                HiddenDim         = m_hiddenDim,
                LearningRate      = m_startLearningRate,
                EmbeddingDim      = m_embeddingDim,
                MultiHeadNum      = m_multiHeadNum,
                EncoderType       = m_encoderType,
                Regc              = m_regc,
                DropoutRatio      = m_dropoutRatio,
                SrcWordToIndex    = m_srcWordToIndex,
                SrcIndexToWord    = m_srcIndexToWord,
                TgtWordToIndex    = m_tgtWordToIndex,
                TgtIndexToWord    = m_tgtIndexToWord
            };

            try
            {
                if (File.Exists(m_modelFilePath))
                {
                    // Keep the previous model around; FileMode.Create below truncates the original.
                    File.Copy(m_modelFilePath, $"{m_modelFilePath}.bak", true);
                }

                // 'using' releases the file handle even when one of the writes below throws.
                // Previously the stream leaked on that path and the file stayed open until finalization.
                using (FileStream fs = new FileStream(m_modelFilePath, FileMode.Create, FileAccess.Write))
                {
                    // NOTE(review): BinaryFormatter is obsolete and unsafe for untrusted input
                    // (removed in .NET 9); consider migrating the metadata format.
                    BinaryFormatter bf = new BinaryFormatter();
                    bf.Serialize(fs, tosave);

                    // All components are appended to the same stream,
                    // so a reader must consume them in exactly this order.
                    m_encoder[m_encoderDefaultDeviceId].Save(fs);
                    m_decoder[m_decoderDefaultDeviceId].Save(fs);

                    m_srcEmbedding[m_srcEmbeddingDefaultDeviceId].Save(fs);
                    m_tgtEmbedding[m_tgtEmbeddingDefaultDeviceId].Save(fs);

                    m_decoderFFLayer[m_DecoderFFLayerDefaultDeviceId].Save(fs);
                }
            }
            catch (Exception err)
            {
                Logger.WriteLine($"Failed to save model to file. Exception = '{err.Message}'");
            }
        }
コード例 #2
0
        /// <summary>
        /// Writes the model to <c>m_modelFilePath</c>.
        /// The file is written in this order:
        /// 1. A <see cref="ModelAttentionMetaData"/> record (clip value, depth, hidden size, learning rate,
        ///    word-vector size, max word count, regularization, dropout ratio and source/target vocabularies),
        ///    serialized with <see cref="BinaryFormatter"/>.
        /// 2. The bi-directional encoder and decoder weights.
        /// 3. The source and target embedding weights.
        /// 4. The decoder feed-forward layer weights.
        /// Before writing, any existing file is copied to "&lt;path&gt;.bak".
        /// Any exception is caught and logged; it is not rethrown to the caller.
        /// </summary>
        public void Save()
        {
            ModelAttentionMetaData tosave = new ModelAttentionMetaData
            {
                clipval       = this.m_clipvalue,
                Depth         = this.Depth,
                hidden_sizes  = this.HiddenSize,
                learning_rate = m_startLearningRate,
                letter_size   = this.WordVectorSize,
                max_chars_gen = this.m_maxWord,
                regc          = this.m_regc,
                DropoutRatio  = m_dropoutRatio,
                s_wordToIndex = m_srcWordToIndex,
                s_indexToWord = m_srcIndexToWord,
                t_wordToIndex = m_tgtWordToIndex,
                t_indexToWord = m_tgtIndexToWord
            };

            try
            {
                if (File.Exists(m_modelFilePath))
                {
                    // Keep the previous model around; FileMode.Create below truncates the original.
                    File.Copy(m_modelFilePath, $"{m_modelFilePath}.bak", true);
                }

                // 'using' releases the file handle even when one of the writes below throws.
                // Previously the stream leaked on that path and the file stayed open until finalization.
                using (FileStream fs = new FileStream(m_modelFilePath, FileMode.Create, FileAccess.Write))
                {
                    // NOTE(review): BinaryFormatter is obsolete and unsafe for untrusted input
                    // (removed in .NET 9); consider migrating the metadata format.
                    BinaryFormatter bf = new BinaryFormatter();
                    bf.Serialize(fs, tosave);

                    // All components are appended to the same stream,
                    // so a reader must consume them in exactly this order.
                    m_biEncoder[m_biEncoderDefaultDeviceId].Save(fs);
                    m_decoder[m_decoderDefaultDeviceId].Save(fs);

                    m_srcEmbedding[m_srcEmbeddingDefaultDeviceId].Save(fs);
                    m_tgtEmbedding[m_tgtEmbeddingDefaultDeviceId].Save(fs);

                    m_decoderFFLayer[m_DecoderFFLayerDeviceId].Save(fs);
                }
            }
            catch (Exception err)
            {
                Logger.WriteLine($"Failed to save model to file. Exception = '{err.Message}'");
            }
        }
コード例 #3
0
        /// <summary>
        /// Builds a model by loading a previously saved file.
        /// Steps:
        /// 1. Initializes the tensor devices for <paramref name="archType"/> and <paramref name="deviceIds"/>.
        /// 2. Reads the <see cref="ModelAttentionMetaData"/> header from <paramref name="modelFilePath"/>
        ///    and restores the hyper-parameters and vocabularies from it.
        /// 3. Creates the encoder/decoder/embedding objects.
        /// 4. Reads their weights from the same stream, in the order they were written.
        /// </summary>
        /// <param name="modelFilePath">Path of the model file to load; also stored as the path for later saves.</param>
        /// <param name="batchSize">Batch size stored on the instance.</param>
        /// <param name="archType">Processor architecture passed to <c>TensorAllocator.InitDevices</c>.</param>
        /// <param name="deviceIds">Device ids to initialize; their count is used to pick default devices.</param>
        /// <exception cref="InvalidDataException">The file does not begin with a <see cref="ModelAttentionMetaData"/> record.</exception>
        public AttentionSeq2Seq(string modelFilePath, int batchSize, ArchTypeEnums archType, int[] deviceIds)
        {
            m_batchSize     = batchSize;
            m_deviceIds     = deviceIds;
            m_modelFilePath = modelFilePath;

            TensorAllocator.InitDevices(archType, deviceIds);
            SetDefaultDeviceIds(deviceIds.Length);

            Logger.WriteLine($"Loading model from '{modelFilePath}'...");

            // 'using' guarantees the file is closed even if deserialization or a weight Load throws.
            using (FileStream fs = new FileStream(m_modelFilePath, FileMode.Open, FileAccess.Read))
            {
                // NOTE(review): BinaryFormatter.Deserialize on a file of unknown origin can execute
                // arbitrary code; only load trusted model files (API removed in .NET 9).
                BinaryFormatter bf = new BinaryFormatter();
                ModelAttentionMetaData modelMetaData = bf.Deserialize(fs) as ModelAttentionMetaData;
                if (modelMetaData == null)
                {
                    // Fail here with a clear message instead of an opaque NullReferenceException below.
                    throw new InvalidDataException($"Model file '{m_modelFilePath}' does not contain valid model metadata.");
                }

                m_clipvalue         = modelMetaData.Clipval;
                m_encoderLayerDepth = modelMetaData.EncoderLayerDepth;
                m_decoderLayerDepth = modelMetaData.DecoderLayerDepth;
                m_hiddenDim         = modelMetaData.HiddenDim;
                m_startLearningRate = modelMetaData.LearningRate;
                m_embeddingDim      = modelMetaData.EmbeddingDim;
                m_multiHeadNum      = modelMetaData.MultiHeadNum;
                m_encoderType       = modelMetaData.EncoderType;
                m_regc              = modelMetaData.Regc;
                m_dropoutRatio      = modelMetaData.DropoutRatio;
                m_srcWordToIndex    = modelMetaData.SrcWordToIndex;
                m_srcIndexToWord    = modelMetaData.SrcIndexToWord;
                m_tgtWordToIndex    = modelMetaData.TgtWordToIndex;
                m_tgtIndexToWord    = modelMetaData.TgtIndexToWord;

                // The components must exist (sized from the metadata above) before weights can be loaded into them.
                CreateEncoderDecoderEmbeddings();

                // Weights follow the metadata sequentially in the stream; this order must match the writer's.
                m_encoder[m_encoderDefaultDeviceId].Load(fs);
                m_decoder[m_decoderDefaultDeviceId].Load(fs);

                m_srcEmbedding[m_srcEmbeddingDefaultDeviceId].Load(fs);
                m_tgtEmbedding[m_tgtEmbeddingDefaultDeviceId].Load(fs);

                m_decoderFFLayer[m_DecoderFFLayerDefaultDeviceId].Load(fs);
            }
        }
コード例 #4
0
        /// <summary>
        /// Loads a model from <paramref name="modelFilePath"/> into this instance.
        /// The path is remembered as <c>m_modelFilePath</c>.
        /// Steps:
        /// 1. Reads the <see cref="ModelAttentionMetaData"/> header and restores the hyper-parameters and vocabularies.
        /// 2. Calls <c>InitWeights()</c>.
        /// 3. Reads the encoder, decoder, embedding and feed-forward weights from the same stream,
        ///    in the order they were written.
        /// </summary>
        /// <param name="modelFilePath">Path of the model file to load.</param>
        /// <exception cref="InvalidDataException">The file does not begin with a <see cref="ModelAttentionMetaData"/> record.</exception>
        public void Load(string modelFilePath)
        {
            Logger.WriteLine($"Loading model from '{modelFilePath}'...");
            m_modelFilePath = modelFilePath;

            // 'using' guarantees the file is closed even if deserialization or a weight Load throws.
            using (FileStream fs = new FileStream(m_modelFilePath, FileMode.Open, FileAccess.Read))
            {
                // NOTE(review): BinaryFormatter.Deserialize on a file of unknown origin can execute
                // arbitrary code; only load trusted model files (API removed in .NET 9).
                BinaryFormatter bf = new BinaryFormatter();
                ModelAttentionMetaData metaData = bf.Deserialize(fs) as ModelAttentionMetaData;
                if (metaData == null)
                {
                    // Fail here with a clear message instead of an opaque NullReferenceException below.
                    throw new InvalidDataException($"Model file '{m_modelFilePath}' does not contain valid model metadata.");
                }

                m_clipvalue         = metaData.clipval;
                Depth               = metaData.Depth;
                HiddenSize          = metaData.hidden_sizes;
                m_startLearningRate = metaData.learning_rate;
                WordVectorSize      = metaData.letter_size;
                // NOTE(review): the metadata type has a max_chars_gen field, but it is not read here;
                // the limit is pinned to 100. Confirm this is intentional.
                m_maxWord           = 100;
                m_regc              = metaData.regc;
                m_dropoutRatio      = metaData.DropoutRatio;
                m_srcWordToIndex    = metaData.s_wordToIndex;
                m_srcIndexToWord    = metaData.s_indexToWord;
                m_tgtWordToIndex    = metaData.t_wordToIndex;
                m_tgtIndexToWord    = metaData.t_indexToWord;

                InitWeights();

                // Weights follow the metadata sequentially in the stream; this order must match the writer's.
                m_biEncoder[m_biEncoderDefaultDeviceId].Load(fs);
                m_decoder[m_decoderDefaultDeviceId].Load(fs);

                m_srcEmbedding[m_srcEmbeddingDefaultDeviceId].Load(fs);
                m_tgtEmbedding[m_tgtEmbeddingDefaultDeviceId].Load(fs);

                m_decoderFFLayer[m_DecoderFFLayerDeviceId].Load(fs);
            }
        }