/// <summary>
/// Reads the <c>.npz</c> archive at <paramref name="path"/> and pushes its parameters into
/// each function of <paramref name="model"/>, visiting them in the order the stack enumerates them.
/// </summary>
/// <param name="path">File path of the <c>.npz</c> archive.</param>
/// <param name="model">Model whose functions receive the stored parameters.</param>
public static void ModelLoad(string path, FunctionStack model)
{
    // NOTE(review): the archive object is never disposed here; if NpzDictionary keeps the
    // file open, consider a using block — confirm against NpzDictionary's API.
    NpzDictionary npz = new NpzDictionary(path);

    foreach (Function layer in model.Functions)
    {
        SetParams(layer, npz);
    }
}
/// <summary>
/// Generic counterpart of <c>ModelLoad</c>: reads the <c>.npz</c> archive at <paramref name="path"/>
/// and pushes its parameters into each function of <paramref name="model"/>, in enumeration order.
/// </summary>
/// <typeparam name="T">Element type of the model's parameter arrays.</typeparam>
/// <param name="path">File path of the <c>.npz</c> archive.</param>
/// <param name="model">Model whose functions receive the stored parameters.</param>
public static void ModelLoad<T>(string path, FunctionStack<T> model) where T : unmanaged, IComparable<T>
{
    // NOTE(review): the archive object is never disposed here; if NpzDictionary keeps the
    // file open, consider a using block — confirm against NpzDictionary's API.
    NpzDictionary npz = new NpzDictionary(path);

    foreach (Function<T> layer in model.Functions)
    {
        SetParams(layer, npz);
    }
}
/// <summary>
/// Copies the parameter arrays stored in <paramref name="modelData"/> into <paramref name="func"/>.
/// Entries are looked up by the key <c>func.Name + "/&lt;param&gt;.npy"</c>.
/// </summary>
/// <param name="func">Function whose parameter <c>Data</c> arrays are overwritten in place.</param>
/// <param name="modelData">Loaded <c>.npz</c> archive, indexed by entry name.</param>
/// <remarks>
/// Handles Linear, Convolution2D, Deconvolution2D, EmbedID, BatchNormalization and MultiplyScale.
/// Any other function type is left untouched.
/// Each copy transfers exactly <c>destination.Length</c> elements into the already-allocated
/// target array, so the target layer must already have the expected size.
/// A missing required entry surfaces as whatever the <c>NpzDictionary</c> indexer throws.
/// NOTE(review): unlike the generic <c>SetParams&lt;T&gt;</c> overload, LSTM is not handled
/// here — confirm whether that is intentional.
/// </remarks>
static void SetParams(Function func, NpzDictionary modelData)
{
    // Copies archive entry (func.Name + suffix), converted with Real.ToRealArray, into
    // destination. It takes exactly destination.Length elements.
    void CopyInto(string suffix, Array destination)
    {
        Array.Copy(Real.ToRealArray(modelData[func.Name + suffix]), destination, destination.Length);
    }

    if (func is Linear linear)
    {
        CopyInto("/W.npy", linear.Weight.Data);
        if (!linear.NoBias)
        {
            CopyInto("/b.npy", linear.Bias.Data);
        }
    }
    else if (func is Convolution2D conv2D)
    {
        CopyInto("/W.npy", conv2D.Weight.Data);
        if (!conv2D.NoBias)
        {
            CopyInto("/b.npy", conv2D.Bias.Data);
        }
    }
    else if (func is Deconvolution2D deconv2D)
    {
        CopyInto("/W.npy", deconv2D.Weight.Data);
        if (!deconv2D.NoBias)
        {
            CopyInto("/b.npy", deconv2D.Bias.Data);
        }
    }
    else if (func is EmbedID embed)
    {
        CopyInto("/W.npy", embed.Weight.Data);
    }
    else if (func is BatchNormalization bn)
    {
        CopyInto("/beta.npy", bn.Beta.Data);
        CopyInto("/gamma.npy", bn.Gamma.Data);

        // The running statistics are optional in the archive and are only restored while IsTrain is set.
        // NOTE(review): if inference (IsTrain == false) reads AvgMean/AvgVar, this guard looks
        // inverted or unnecessary — confirm against BatchNormalization before changing it.
        if (bn.IsTrain)
        {
            if (modelData.ContainsKey(func.Name + "/avg_mean.npy"))
            {
                CopyInto("/avg_mean.npy", bn.AvgMean.Data);
            }
            if (modelData.ContainsKey(func.Name + "/avg_var.npy"))
            {
                CopyInto("/avg_var.npy", bn.AvgVar.Data);
            }
        }
    }
    else if (func is MultiplyScale scale)
    {
        CopyInto("/W.npy", scale.Weight.Data);
        if (scale.BiasTerm)
        {
            // The bias entry sits one level deeper, under "<name>/bias/".
            CopyInto("/bias/b.npy", scale.Bias.Data);
        }
    }
}
/// <summary>
/// Generic counterpart of <c>SetParams</c>. It replaces the parameter arrays of
/// <paramref name="func"/> with the entries stored in <paramref name="modelData"/>.
/// Entries are looked up by the key <c>func.Name + "/&lt;param&gt;.npy"</c> and flattened
/// with <c>FlattenEx&lt;T&gt;()</c>.
/// </summary>
/// <typeparam name="T">Element type of the function's parameter arrays.</typeparam>
/// <param name="func">Function whose parameter <c>Data</c> arrays are reassigned.</param>
/// <param name="modelData">Loaded <c>.npz</c> archive, indexed by entry name.</param>
/// <remarks>
/// Unlike the non-generic overload, this assigns new arrays to <c>Data</c> instead of copying
/// into the existing ones.
/// Handles Linear, Convolution2D, Deconvolution2D, EmbedID, BatchNormalization, MultiplyScale
/// and LSTM. Any other function type is left untouched.
/// </remarks>
static void SetParams<T>(Function<T> func, NpzDictionary modelData) where T : unmanaged, IComparable<T>
{
    if (func is Linear<T> linear)
    {
        linear.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();
        if (linear.Bias != null)
        {
            linear.Bias.Data = modelData[func.Name + "/b.npy"].FlattenEx<T>();
        }
    }
    else if (func is Convolution2D<T> conv2D)
    {
        conv2D.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();
        if (conv2D.Bias != null)
        {
            conv2D.Bias.Data = modelData[func.Name + "/b.npy"].FlattenEx<T>();
        }
    }
    else if (func is Deconvolution2D<T> deconv2D)
    {
        deconv2D.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();
        if (deconv2D.Bias != null)
        {
            deconv2D.Bias.Data = modelData[func.Name + "/b.npy"].FlattenEx<T>();
        }
    }
    else if (func is EmbedID<T> embed)
    {
        embed.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();
    }
    else if (func is BatchNormalization<T> bn)
    {
        bn.Beta.Data = modelData[func.Name + "/beta.npy"].FlattenEx<T>();
        bn.Gamma.Data = modelData[func.Name + "/gamma.npy"].FlattenEx<T>();

        // The running statistics are optional in the archive and are only restored while Train is set.
        // NOTE(review): if inference (Train == false) reads AvgMean/AvgVar, this guard looks
        // inverted or unnecessary — confirm against BatchNormalization<T> before changing it.
        if (bn.Train)
        {
            if (modelData.ContainsKey(func.Name + "/avg_mean.npy"))
            {
                bn.AvgMean.Data = modelData[func.Name + "/avg_mean.npy"].FlattenEx<T>();
            }
            if (modelData.ContainsKey(func.Name + "/avg_var.npy"))
            {
                bn.AvgVar.Data = modelData[func.Name + "/avg_var.npy"].FlattenEx<T>();
            }
        }
    }
    else if (func is MultiplyScale<T> scale)
    {
        scale.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();
        if (scale.BiasTerm)
        {
            // The bias entry sits one level deeper, under "<name>/bias/".
            scale.Bias.Data = modelData[func.Name + "/bias/b.npy"].FlattenEx<T>();
        }
    }
    else if (func is LSTM<T> lstm)
    {
        // Only the upward projection's bias is loaded. The lateral projection is read
        // weight-only, and upward.Bias is assumed to exist.
        lstm.lateral.Weight.Data = modelData[func.Name + "/lateral/W.npy"].FlattenEx<T>();
        lstm.upward.Weight.Data = modelData[func.Name + "/upward/W.npy"].FlattenEx<T>();
        lstm.upward.Bias.Data = modelData[func.Name + "/upward/b.npy"].FlattenEx<T>();
    }
}
/// <summary>
/// Loads an <c>.npz</c> archive from <paramref name="stream"/>. The result is returned and
/// also stored in <paramref name="value"/>.
/// </summary>
/// <typeparam name="T">Entry type; the constraints mirror the interfaces implemented by <see cref="System.Array"/>.</typeparam>
/// <param name="stream">Source stream, forwarded unchanged to the single-argument <c>Load&lt;T&gt;</c> overload.</param>
/// <param name="value">Receives the same instance that is returned.</param>
/// <returns>The loaded dictionary.</returns>
public static NpzDictionary<T> Load<T>(Stream stream, out NpzDictionary<T> value)
    where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
    value = Load<T>(stream);
    return value;
}
/// <summary>
/// Opens the <c>.npz</c> file at <paramref name="path"/> for reading and loads it. The result
/// is returned and also stored in <paramref name="value"/>.
/// </summary>
/// <typeparam name="T">Entry type; the constraints mirror the interfaces implemented by <see cref="System.Array"/>.</typeparam>
/// <param name="path">Path of the <c>.npz</c> file.</param>
/// <param name="value">Receives the same instance that is returned.</param>
/// <returns>The loaded dictionary.</returns>
/// <remarks>
/// The file is opened with read-only access, so files or locations without write permission
/// can still be loaded.
/// On success the stream is handed to the single-argument <c>Load&lt;T&gt;</c> overload and
/// left open. Presumably the returned dictionary keeps and owns it — TODO confirm that it is
/// disposed there.
/// If loading throws, the stream is closed before the exception propagates, so the file
/// handle does not leak.
/// </remarks>
public static NpzDictionary<T> Load<T>(string path, out NpzDictionary<T> value)
    where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
{
    FileStream stream = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read);
    try
    {
        value = Load<T>(stream);
    }
    catch
    {
        // Nothing else references the stream yet, so release the file handle here.
        stream.Dispose();
        throw;
    }

    return value;
}