Code Example #1
File: TestEmbedID.cs  Project: ziichuan/KelpNet
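This test builds an EmbedID layer with randomly chosen input, output, and batch sizes, feeds the same input and upstream gradient through Chainer (via the NChainer wrapper types) and through KelpNet using identical weights, and then asserts that the forward output and the weight gradient of the two implementations agree within a tolerance of 1e-5.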
        public void EmbedIDRandomTest()
        {
            Python.Initialize();
            Chainer.Initialize();

            int inputCount  = Mother.Dice.Next(2, 30);
            int outputCount = Mother.Dice.Next(1, 30);
            int batchCount  = Mother.Dice.Next(1, 5);

            //Input IDs: all zeros except input[0, 0], which is set to 1
            int[,] input = (int[,])Enumerable.Repeat(0, batchCount * inputCount).ToNdArray(batchCount, inputCount);
            input[0, 0]  = 1;

            Real[,,] dummyGy = Initializer.GetRandomValues<Real[,,]>(batchCount, inputCount, outputCount);
            Real[,] w        = Initializer.GetRandomValues<Real[,]>(inputCount, outputCount);

            //Chainer
            NChainer.EmbedID<Real> cEmbedId = new NChainer.EmbedID<Real>(inputCount, outputCount, w);

            Variable<int> cX = new Variable<int>(input);

            Variable<Real> cY = cEmbedId.Forward(cX);

            cY.Grad = dummyGy;

            cY.Backward();


            //KelpNet
            EmbedID<Real> embedId = new EmbedID<Real>(inputCount, outputCount, w);

            NdArray<Real> x = new NdArray<Real>(input, asBatch: true);

            NdArray<Real> y = embedId.Forward(x)[0];

            y.Grad = dummyGy.Flatten();

            y.Backward();


            Real[] cYdata = ((Real[,,])cY.Data).Flatten();

            Real[] cWgrad = ((Real[,])cEmbedId.W.Grad).Flatten();

            //Compute the allowed tolerance
            Real delta = 0.00001f;

            //y
            Assert.AreEqual(cYdata.Length, y.Data.Length);
            for (int i = 0; i < y.Data.Length; i++)
            {
                Assert.AreEqual(cYdata[i], y.Data[i], delta);
            }

            //W.grad
            Assert.AreEqual(cWgrad.Length, embedId.Weight.Grad.Length);
            for (int i = 0; i < embedId.Weight.Grad.Length; i++)
            {
                Assert.AreEqual(cWgrad[i], embedId.Weight.Grad[i], delta);
            }
        }
Code Example #2
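SetParams copies pretrained parameters from an NpzDictionary (the arrays of a Chainer model saved as .npz) into the corresponding KelpNet layer, dispatching on the concrete Function type. Linear, Convolution2D, Deconvolution2D, EmbedID, BatchNormalization, and MultiplyScale each receive their weight, bias, or normalization arrays, looked up under the function's name.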
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        /// <summary>   Sets the parameters. </summary>
        ///
        /// <param name="func">         The function. </param>
        /// <param name="modelData">    Information describing the model. </param>
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        static void SetParams(Function func, NpzDictionary modelData)
        {
            if (func is Linear)
            {
                Linear linear = (Linear)func;

                Array.Copy(Real.GetArray(modelData[func.Name + "/W.npy"]), linear.Weight.Data, linear.Weight.Data.Length);

                if (!linear.NoBias)
                {
                    Array.Copy(Real.GetArray(modelData[func.Name + "/b.npy"]), linear.Bias.Data, linear.Bias.Data.Length);
                }
            }
            else if (func is Convolution2D)
            {
                Convolution2D conv2D = (Convolution2D)func;

                Array.Copy(Real.GetArray(modelData[func.Name + "/W.npy"]), conv2D.Weight.Data, conv2D.Weight.Data.Length);

                if (!conv2D.NoBias)
                {
                    Array.Copy(Real.GetArray(modelData[func.Name + "/b.npy"]), conv2D.Bias.Data, conv2D.Bias.Data.Length);
                }
            }
            else if (func is Deconvolution2D)
            {
                Deconvolution2D deconv2D = (Deconvolution2D)func;

                Array.Copy(Real.GetArray(modelData[func.Name + "/W.npy"]), deconv2D.Weight.Data, deconv2D.Weight.Data.Length);

                if (!deconv2D.NoBias)
                {
                    Array.Copy(Real.GetArray(modelData[func.Name + "/b.npy"]), deconv2D.Bias.Data, deconv2D.Bias.Data.Length);
                }
            }
            else if (func is EmbedID)
            {
                EmbedID embed = (EmbedID)func;

                Array.Copy(Real.GetArray(modelData[func.Name + "/W.npy"]), embed.Weight.Data, embed.Weight.Data.Length);
            }
            else if (func is BatchNormalization)
            {
                BatchNormalization bn = (BatchNormalization)func;

                Array.Copy(Real.GetArray(modelData[func.Name + "/beta.npy"]), bn.Beta.Data, bn.Beta.Data.Length);
                Array.Copy(Real.GetArray(modelData[func.Name + "/gamma.npy"]), bn.Gamma.Data, bn.Gamma.Data.Length);

                if (bn.IsTrain)
                {
                    if (modelData.ContainsKey(func.Name + "/avg_mean.npy"))
                    {
                        Array.Copy(Real.GetArray(modelData[func.Name + "/avg_mean.npy"]), bn.AvgMean.Data, bn.AvgMean.Data.Length);
                    }
                    if (modelData.ContainsKey(func.Name + "/avg_var.npy"))
                    {
                        Array.Copy(Real.GetArray(modelData[func.Name + "/avg_var.npy"]), bn.AvgVar.Data, bn.AvgVar.Data.Length);
                    }
                }
            }
            else if (func is MultiplyScale)
            {
                MultiplyScale scale = (MultiplyScale)func;

                Array.Copy(Real.GetArray(modelData[func.Name + "/W.npy"]), scale.Weight.Data, scale.Weight.Data.Length);

                if (scale.BiasTerm)
                {
                    Array.Copy(Real.GetArray(modelData[func.Name + "/bias/b.npy"]), scale.Bias.Data, scale.Bias.Data.Length);
                }
            }
        }
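For context, here is a minimal sketch of how SetParams above might be driven when importing a whole model. It is an illustration only: the ApplyModelData name is hypothetical, and the FunctionStack type with its Functions collection is assumed to be available as in KelpNet; only SetParams and NpzDictionary come from the snippet itself.

        //Hypothetical driver: walk every layer of a network and copy the matching
        //parameters out of an already-loaded .npz dictionary. Parameters are matched
        //by layer name, e.g. func.Name + "/W.npy", exactly as SetParams expects.
        static void ApplyModelData(FunctionStack nn, NpzDictionary modelData)
        {
            foreach (Function func in nn.Functions)
            {
                SetParams(func, modelData);
            }
        }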
Code Example #3
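A generic variant of SetParams for KelpNet's Function&lt;T&gt; layers. Instead of Array.Copy it assigns the flattened .npy arrays directly via FlattenEx&lt;T&gt;(), tests Bias != null rather than a NoBias flag, and additionally restores the lateral and upward sub-layers of an LSTM&lt;T&gt;.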
        static void SetParams<T>(Function<T> func, NpzDictionary modelData) where T : unmanaged, IComparable<T>
        {
            if (func is Linear<T>)
            {
                Linear<T> linear = (Linear<T>)func;

                linear.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();

                if (linear.Bias != null)
                {
                    linear.Bias.Data = modelData[func.Name + "/b.npy"].FlattenEx<T>();
                }
            }
            else if (func is Convolution2D<T>)
            {
                Convolution2D<T> conv2D = (Convolution2D<T>)func;

                conv2D.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();

                if (conv2D.Bias != null)
                {
                    conv2D.Bias.Data = modelData[func.Name + "/b.npy"].FlattenEx<T>();
                }
            }
            else if (func is Deconvolution2D<T>)
            {
                Deconvolution2D<T> deconv2D = (Deconvolution2D<T>)func;

                deconv2D.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();

                if (deconv2D.Bias != null)
                {
                    deconv2D.Bias.Data = modelData[func.Name + "/b.npy"].FlattenEx<T>();
                }
            }
            else if (func is EmbedID<T>)
            {
                EmbedID<T> embed = (EmbedID<T>)func;
                embed.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();
            }
            else if (func is BatchNormalization<T>)
            {
                BatchNormalization<T> bn = (BatchNormalization<T>)func;

                bn.Beta.Data  = modelData[func.Name + "/beta.npy"].FlattenEx<T>();
                bn.Gamma.Data = modelData[func.Name + "/gamma.npy"].FlattenEx<T>();

                if (bn.Train)
                {
                    if (modelData.ContainsKey(func.Name + "/avg_mean.npy"))
                    {
                        bn.AvgMean.Data = modelData[func.Name + "/avg_mean.npy"].FlattenEx<T>();
                    }
                    if (modelData.ContainsKey(func.Name + "/avg_var.npy"))
                    {
                        bn.AvgVar.Data = modelData[func.Name + "/avg_var.npy"].FlattenEx<T>();
                    }
                }
            }
            else if (func is MultiplyScale<T>)
            {
                MultiplyScale<T> scale = (MultiplyScale<T>)func;

                scale.Weight.Data = modelData[func.Name + "/W.npy"].FlattenEx<T>();

                if (scale.BiasTerm)
                {
                    scale.Bias.Data = modelData[func.Name + "/bias/b.npy"].FlattenEx<T>();
                }
            }
            else if (func is LSTM<T>)
            {
                LSTM<T> lstm = (LSTM<T>)func;

                lstm.lateral.Weight.Data = modelData[func.Name + "/lateral/W.npy"].FlattenEx<T>();
                lstm.upward.Weight.Data  = modelData[func.Name + "/upward/W.npy"].FlattenEx<T>();
                lstm.upward.Bias.Data    = modelData[func.Name + "/upward/b.npy"].FlattenEx<T>();
            }
        }