/// <summary>
/// Loads an existing model from a file previously written by <c>Save</c>.
/// </summary>
/// <param name="modelname">Path of the serialized model file (default "Model.lys2s").</param>
/// <returns>
/// A new <see cref="Seq2Seq"/> populated from the stored <c>ModelData</c>, with
/// <c>wordToIndex</c>/<c>vocab</c> rebuilt from <c>indexToWord</c> and <c>newType</c> set to "retrain".
/// </returns>
/// <exception cref="InvalidDataException">The file does not deserialize to a <c>ModelData</c> object.</exception>
public static Seq2Seq Load(string modelname = "Model.lys2s")
{
    ModelData mdl;
    // Binary (de)serialization: the file can only be loaded by the same assembly that saved it.
    // NOTE(review): BinaryFormatter is unsafe on untrusted input and removed in .NET 9 —
    // only load model files produced by this application.
    using (FileStream fs = new FileStream(modelname, FileMode.Open, FileAccess.Read))
    {
        BinaryFormatter bf = new BinaryFormatter();
        mdl = bf.Deserialize(fs) as ModelData;
    }

    // Fail with a clear error instead of a NullReferenceException further down.
    if (mdl == null)
    {
        throw new InvalidDataException("File '" + modelname + "' does not contain a ModelData object.");
    }

    var s2s = new Seq2Seq();
    s2s.bd = mdl.bd;
    s2s.Whd = mdl.Whd;
    s2s.Embedding = mdl.Wil;
    s2s.decoder = mdl.decoder;
    s2s.encoder = mdl.encoder;
    s2s.ReversEncoder = mdl.ReversEncoder;
    s2s.hidden_size = mdl.hidden_sizes;
    s2s.word_size = mdl.letter_size;
    s2s.Depth = mdl.Depth;
    s2s.clipval = mdl.clipval;
    s2s.learning_rate = mdl.learning_rate;
    // NOTE(review): Save persists max_word as mdl.max_chars_gen, but it is not restored here;
    // a fixed 300 is used instead. Kept for compatibility — confirm whether that is intended.
    s2s.max_word = 300;
    s2s.regc = mdl.regc;
    s2s.UseDropout = mdl.UseDropout;

    // wordToIndex is not persisted; rebuild it (and the vocabulary) from indexToWord.
    s2s.indexToWord = mdl.indexToWord;
    Dictionary<string, WordInfoV1_1> w2i = new Dictionary<string, WordInfoV1_1>();
    foreach (var kv in s2s.indexToWord)
    {
        w2i.Add(kv.Value.w, kv.Value);
    }
    s2s.wordToIndex = w2i;
    s2s.vocab = new List<string>();
    s2s.vocab.AddRange(w2i.Keys);

    s2s.newType = "retrain";
    return s2s;
}
/// <summary>
/// Saves the model to a file using BinaryFormatter.
/// </summary>
/// <param name="s2s">Model to save.</param>
/// <param name="hadinback">
/// When true and another save is already in progress, this call returns immediately without saving;
/// when false, it waits for the in-progress save to finish and then saves.
/// </param>
/// <param name="ModelName">Target file path; when null or empty, <c>s2s.ModelName</c> is used.</param>
public static void Save(this Seq2Seq s2s, bool hadinback = true, string ModelName = null)
{
    if (inSave && hadinback)
    {
        return;
    }
    lock (lockSave)
    {
        inSave = true;
        try
        {
            var mname = string.IsNullOrEmpty(ModelName) ? s2s.ModelName : ModelName;

            // wordToIndex is not persisted; Load rebuilds it from indexToWord.
            ModelData tosave = new ModelData();
            tosave.indexToWord = s2s.indexToWord;
            tosave.bd = s2s.bd;
            tosave.clipval = s2s.clipval;
            tosave.decoder = s2s.decoder;
            tosave.Depth = s2s.Depth;
            tosave.encoder = s2s.encoder;
            tosave.hidden_sizes = s2s.hidden_size;
            tosave.learning_rate = s2s.learning_rate;
            tosave.letter_size = s2s.word_size;
            tosave.max_chars_gen = s2s.max_word;
            tosave.regc = s2s.regc;
            tosave.ReversEncoder = s2s.ReversEncoder;
            tosave.UseDropout = s2s.UseDropout;
            tosave.Whd = s2s.Whd;
            tosave.Wil = s2s.Embedding;

            BinaryFormatter bf = new BinaryFormatter();
            // FileMode.Create truncates an existing file; OpenOrCreate would leave stale trailing
            // bytes behind whenever the new model is smaller than the one already on disk.
            using (FileStream fs = new FileStream(mname, FileMode.Create, FileAccess.Write))
            {
                bf.Serialize(fs, tosave);
            }
        }
        finally
        {
            // Always clear the flag; otherwise one failed save would make every later
            // hadinback save return without doing anything.
            inSave = false;
        }
    }
}
/// <summary>
/// Builds the training corpus from "human_text.txt" (input sentences) and "robot_text.txt"
/// (output sentences), pairing line i of one file with line i of the other, and passes the
/// tokenized pairs to <c>OneHotEncoding</c>.
/// </summary>
/// <param name="s2s">Model that receives the encoded corpus.</param>
/// <param name="maxPairs">
/// Maximum number of line pairs to use (default 1000, the previously hard-coded value).
/// Fewer are used if either file is shorter.
/// </param>
/// <exception cref="System.ArgumentOutOfRangeException"><paramref name="maxPairs"/> is negative.</exception>
internal static void Preprocess(this Seq2Seq s2s, int maxPairs = 1000)
{
    if (maxPairs < 0)
    {
        throw new System.ArgumentOutOfRangeException("maxPairs", "maxPairs must be non-negative.");
    }

    var HumanTextRaw = File.ReadAllLines("human_text.txt");
    var RobotTextRaw = File.ReadAllLines("robot_text.txt");

    // Only the common prefix of both files forms complete pairs. Indexing up to a fixed count
    // would throw IndexOutOfRangeException on shorter or mismatched files.
    int count = System.Math.Min(maxPairs, System.Math.Min(HumanTextRaw.Length, RobotTextRaw.Length));

    // Input sentences, one token list per line.
    List<List<string>> input = new List<List<string>>(count);
    // Output sentences, aligned with input by index.
    List<List<string>> output = new List<List<string>>(count);

    for (int i = 0; i < count; i++)
    {
        // Accent removal is only applied to lines that are actually used.
        string human = RemoveAccentMark(HumanTextRaw[i]);
        string robot = RemoveAccentMark(RobotTextRaw[i]);
        input.Add(human.ToLower().Trim().Split(' ').ToList());
        output.Add(robot.ToLower().Trim().Split(' ').ToList());
    }

    s2s.OneHotEncoding(input, output);
}