Example #1
0
        /// <summary>
        /// Loads the parameter arrays stored in the archive at <paramref name="path"/> into the
        /// functions of <paramref name="model"/>, passing each function to <c>SetParams</c> in turn.
        /// </summary>
        /// <param name="path">File path handed to the <c>NpzDictionary</c> constructor.</param>
        /// <param name="model">Model whose <c>Functions</c> receive the loaded parameters.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="path"/> or <paramref name="model"/> is null.
        /// </exception>
        public static void ModelLoad(string path, FunctionStack model)
        {
            // Validate before opening the archive, so that a null model does not open the file
            // and then fail later with a NullReferenceException.
            if (path is null)
            {
                throw new ArgumentNullException(nameof(path));
            }

            if (model is null)
            {
                throw new ArgumentNullException(nameof(model));
            }

            // NOTE(review): modelData is never disposed here. If NpzDictionary keeps the file
            // open (and implements IDisposable), consider wrapping it in a using block.
            var modelData = new NpzDictionary(path);

            foreach (var function in model.Functions)
            {
                SetParams(function, modelData);
            }
        }
Example #2
0
        /// <summary>
        /// Loads the parameter arrays stored in the archive at <paramref name="path"/> into the
        /// functions of <paramref name="model"/>, passing each function to <c>SetParams</c> in turn.
        /// </summary>
        /// <typeparam name="T">Element type of the model's parameters.</typeparam>
        /// <param name="path">File path handed to the <c>NpzDictionary</c> constructor.</param>
        /// <param name="model">Model whose <c>Functions</c> receive the loaded parameters.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="path"/> or <paramref name="model"/> is null.
        /// </exception>
        public static void ModelLoad<T>(string path, FunctionStack<T> model) where T : unmanaged, IComparable<T>
        {
            // Validate before opening the archive, so that a null model does not open the file
            // and then fail later with a NullReferenceException.
            if (path is null)
            {
                throw new ArgumentNullException(nameof(path));
            }

            if (model is null)
            {
                throw new ArgumentNullException(nameof(model));
            }

            // NOTE(review): modelData is never disposed here. If NpzDictionary keeps the file
            // open (and implements IDisposable), consider wrapping it in a using block.
            var modelData = new NpzDictionary(path);

            foreach (var function in model.Functions)
            {
                SetParams(function, modelData);
            }
        }
Example #3
0
        /// <summary>
        /// Copies the arrays stored for <paramref name="func"/> into that function's parameter
        /// buffers. Each entry is looked up under the key <c>func.Name + "/" + file</c>
        /// (for example <c>"W.npy"</c> or <c>"b.npy"</c>).
        /// </summary>
        /// <param name="func">
        /// Function to fill. Only <c>Linear</c>, <c>Convolution2D</c>, <c>Deconvolution2D</c>,
        /// <c>EmbedID</c>, <c>BatchNormalization</c> and <c>MultiplyScale</c> are handled. Any other
        /// function (including null) is left untouched.
        /// </param>
        /// <param name="modelData">Archive holding the stored parameter arrays.</param>
        /// <remarks>
        /// Each copy writes exactly <c>destination.Length</c> elements. A stored array shorter than
        /// the destination makes <see cref="Array.Copy(Array, Array, int)"/> throw, and a longer one
        /// is silently truncated. A missing required key surfaces as whatever the
        /// <c>NpzDictionary</c> indexer throws.
        /// </remarks>
        static void SetParams(Function func, NpzDictionary modelData)
        {
            // Builds the archive key for one of this function's parameter files.
            string Key(string paramFile) => func.Name + "/" + paramFile;

            // Converts the stored array and copies it over the whole of destination.
            void CopyParam(string paramFile, Array destination)
            {
                Array.Copy(Real.ToRealArray(modelData[Key(paramFile)]), destination, destination.Length);
            }

            if (func is Linear linear)
            {
                CopyParam("W.npy", linear.Weight.Data);

                if (!linear.NoBias)
                {
                    CopyParam("b.npy", linear.Bias.Data);
                }
            }
            else if (func is Convolution2D conv2D)
            {
                CopyParam("W.npy", conv2D.Weight.Data);

                if (!conv2D.NoBias)
                {
                    CopyParam("b.npy", conv2D.Bias.Data);
                }
            }
            else if (func is Deconvolution2D deconv2D)
            {
                CopyParam("W.npy", deconv2D.Weight.Data);

                if (!deconv2D.NoBias)
                {
                    CopyParam("b.npy", deconv2D.Bias.Data);
                }
            }
            else if (func is EmbedID embed)
            {
                CopyParam("W.npy", embed.Weight.Data);
            }
            else if (func is BatchNormalization bn)
            {
                CopyParam("beta.npy", bn.Beta.Data);
                CopyParam("gamma.npy", bn.Gamma.Data);

                // The running statistics are optional in the archive and are restored only when
                // IsTrain is set.
                // NOTE(review): presumably non-training mode is where these statistics are used;
                // confirm this gating is intended and not inverted.
                if (bn.IsTrain)
                {
                    if (modelData.ContainsKey(Key("avg_mean.npy")))
                    {
                        CopyParam("avg_mean.npy", bn.AvgMean.Data);
                    }
                    if (modelData.ContainsKey(Key("avg_var.npy")))
                    {
                        CopyParam("avg_var.npy", bn.AvgVar.Data);
                    }
                }
            }
            else if (func is MultiplyScale scale)
            {
                CopyParam("W.npy", scale.Weight.Data);

                if (scale.BiasTerm)
                {
                    // The scale's bias is stored one level deeper, under "bias/".
                    CopyParam("bias/b.npy", scale.Bias.Data);
                }
            }
        }
Example #4
0
        /// <summary>
        /// Assigns the arrays stored for <paramref name="func"/> to that function's parameter
        /// <c>Data</c>. Each entry is looked up under the key <c>func.Name + "/" + file</c> and
        /// converted with <c>FlattenEx&lt;T&gt;()</c>.
        /// </summary>
        /// <typeparam name="T">Element type of the function's parameters.</typeparam>
        /// <param name="func">
        /// Function to fill. Only <c>Linear</c>, <c>Convolution2D</c>, <c>Deconvolution2D</c>,
        /// <c>EmbedID</c>, <c>BatchNormalization</c>, <c>MultiplyScale</c> and <c>LSTM</c> are handled.
        /// Any other function (including null) is left untouched.
        /// </param>
        /// <param name="modelData">Archive holding the stored parameter arrays.</param>
        /// <remarks>
        /// This method performs no length or shape check: whatever <c>FlattenEx&lt;T&gt;()</c> returns
        /// is assigned as-is. A missing required key surfaces as whatever the <c>NpzDictionary</c>
        /// indexer throws.
        /// </remarks>
        static void SetParams<T>(Function<T> func, NpzDictionary modelData) where T : unmanaged, IComparable<T>
        {
            // Builds the archive key for one of this function's parameter files.
            string Key(string paramFile) => func.Name + "/" + paramFile;

            if (func is Linear<T> linear)
            {
                linear.Weight.Data = modelData[Key("W.npy")].FlattenEx<T>();

                if (linear.Bias != null)
                {
                    linear.Bias.Data = modelData[Key("b.npy")].FlattenEx<T>();
                }
            }
            else if (func is Convolution2D<T> conv2D)
            {
                conv2D.Weight.Data = modelData[Key("W.npy")].FlattenEx<T>();

                if (conv2D.Bias != null)
                {
                    conv2D.Bias.Data = modelData[Key("b.npy")].FlattenEx<T>();
                }
            }
            else if (func is Deconvolution2D<T> deconv2D)
            {
                deconv2D.Weight.Data = modelData[Key("W.npy")].FlattenEx<T>();

                if (deconv2D.Bias != null)
                {
                    deconv2D.Bias.Data = modelData[Key("b.npy")].FlattenEx<T>();
                }
            }
            else if (func is EmbedID<T> embed)
            {
                embed.Weight.Data = modelData[Key("W.npy")].FlattenEx<T>();
            }
            else if (func is BatchNormalization<T> bn)
            {
                bn.Beta.Data = modelData[Key("beta.npy")].FlattenEx<T>();
                bn.Gamma.Data = modelData[Key("gamma.npy")].FlattenEx<T>();

                // The running statistics are optional in the archive and are restored only when
                // Train is set.
                // NOTE(review): presumably non-training mode is where these statistics are used;
                // confirm this gating is intended and not inverted.
                if (bn.Train)
                {
                    if (modelData.ContainsKey(Key("avg_mean.npy")))
                    {
                        bn.AvgMean.Data = modelData[Key("avg_mean.npy")].FlattenEx<T>();
                    }
                    if (modelData.ContainsKey(Key("avg_var.npy")))
                    {
                        bn.AvgVar.Data = modelData[Key("avg_var.npy")].FlattenEx<T>();
                    }
                }
            }
            else if (func is MultiplyScale<T> scale)
            {
                scale.Weight.Data = modelData[Key("W.npy")].FlattenEx<T>();

                if (scale.BiasTerm)
                {
                    // The scale's bias is stored one level deeper, under "bias/".
                    scale.Bias.Data = modelData[Key("bias/b.npy")].FlattenEx<T>();
                }
            }
            else if (func is LSTM<T> lstm)
            {
                // Only the weight is read for `lateral`. `upward` gets both weight and bias, and
                // unlike the branches above its bias is loaded without a null check.
                lstm.lateral.Weight.Data = modelData[Key("lateral/W.npy")].FlattenEx<T>();
                lstm.upward.Weight.Data = modelData[Key("upward/W.npy")].FlattenEx<T>();
                lstm.upward.Bias.Data = modelData[Key("upward/b.npy")].FlattenEx<T>();
            }
        }
 /// <summary>
 /// Reads an <c>NpzDictionary&lt;T&gt;</c> from <paramref name="stream"/> using the single-argument
 /// <c>Load&lt;T&gt;(Stream)</c> overload. The result is also handed back through <paramref name="value"/>.
 /// </summary>
 /// <typeparam name="T">Array type of the dictionary's values.</typeparam>
 /// <param name="stream">Stream the dictionary is read from.</param>
 /// <param name="value">Receives the same instance that is returned.</param>
 /// <returns>The loaded dictionary; identical to <paramref name="value"/>.</returns>
 public static NpzDictionary<T> Load<T>(Stream stream, out NpzDictionary<T> value)
     where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
 {
     value = Load<T>(stream);
     return value;
 }
 /// <summary>
 /// Opens the file at <paramref name="path"/> for reading and loads it through the
 /// <c>Load&lt;T&gt;(Stream)</c> overload. The result is also handed back through <paramref name="value"/>.
 /// </summary>
 /// <typeparam name="T">Array type of the dictionary's values.</typeparam>
 /// <param name="path">Path of the file to open.</param>
 /// <param name="value">Receives the same instance that is returned.</param>
 /// <returns>The loaded dictionary; identical to <paramref name="value"/>.</returns>
 /// <remarks>
 /// On success the opened stream is passed to the loader and is not disposed here.
 /// NOTE(review): presumably the returned dictionary keeps reading from it, so it owns the
 /// stream; confirm. If loading throws, the stream is disposed before the exception
 /// propagates, so the file handle is not leaked.
 /// </remarks>
 public static NpzDictionary<T> Load<T>(string path, out NpzDictionary<T> value)
     where T : class, ICloneable, IList, ICollection, IEnumerable, IStructuralComparable, IStructuralEquatable
 {
     // Request read access only. FileStream(path, FileMode.Open) defaults to
     // FileAccess.ReadWrite, which fails on read-only files and locations.
     FileStream stream = File.OpenRead(path);
     try
     {
         value = Load<T>(stream);
         return value;
     }
     catch
     {
         stream.Dispose();
         throw;
     }
 }