예제 #1
0
        /// <summary>
        /// Builds one residual node: two dense layers, each optionally followed by batch or
        /// layer normalization, whose output is added to the (possibly padded) input and then
        /// passed through a ReLU and optional dropout.
        /// </summary>
        /// <param name="input">Input variable; its first shape dimension is compared with <c>HiddenSize</c>.</param>
        /// <param name="device">Device handed to every layer helper called here.</param>
        /// <param name="name">Prefix used for the names of the layers created here.</param>
        /// <returns>The output function of the residual node.</returns>
        protected override Function BuildNetwork(Variable input, DeviceDescriptor device, string name)
        {
            // Skip-connection operand. It starts as the raw input and is widened below when the
            // input is narrower than the hidden size.
            Variable toAdd = input;

            int inputSize = input.Shape[0];
            if (inputSize < HiddenSize)
            {
                // Widen the skip connection to HiddenSize (the original comment says this is done
                // by adding 0s) so it can be summed with the second layer's output.
                toAdd = Layers.AddDummy(input, HiddenSize - inputSize, device);
            }
            else if (inputSize > HiddenSize)
            {
                // NOTE(review): only an error is logged here; construction continues with a skip
                // connection wider than HiddenSize, so the addition below presumably fails on a
                // shape mismatch - confirm whether an early exit is wanted.
                Debug.LogError("Can not have a hidden size that is smaller than the input size in resnetnode");
            }

            // First layer: dense -> optional normalization -> ReLU -> optional dropout.
            // The dense layer consumes the unpadded input; only the skip connection is padded.
            var c1 = UnityCNTK.Layers.Dense(input, HiddenSize, device, true, name + ".Res1", InitialWeightScale);

            if (normalizationMethod == NormalizationMethod.BatchNormalizatoin)
            {
                c1 = Layers.BatchNormalization(c1, InitialNormalizationBias, InitialNormalizationScale, BNTimeConst, BNSpatial, device, name + ".BN");
            }
            else if (normalizationMethod == NormalizationMethod.LayerNormalization)
            {
                c1 = Layers.LayerNormalization(c1, device, InitialNormalizationBias, InitialNormalizationScale, name + ".LN");
            }

            c1 = CNTKLib.ReLU(c1);
            if (DropoutRate > 0)
            {
                c1 = CNTKLib.Dropout(c1, DropoutRate);
            }

            // Second layer: dense -> optional normalization. Its activation is applied after the
            // residual addition. The width comes from HiddenSize, the same member used for the
            // padding above, so the sum below lines up with the skip connection.
            var c2 = UnityCNTK.Layers.Dense(c1, HiddenSize, device, true, name + ".Res2", InitialWeightScale);

            // NOTE(review): the normalization layers of both dense layers receive the same name
            // (name + ".BN" / name + ".LN"). Kept as-is because renaming could affect name-based
            // lookups elsewhere - confirm whether distinct names are wanted.
            if (normalizationMethod == NormalizationMethod.BatchNormalizatoin)
            {
                c2 = Layers.BatchNormalization(c2, InitialNormalizationBias, InitialNormalizationScale, BNTimeConst, BNSpatial, device, name + ".BN");
            }
            else if (normalizationMethod == NormalizationMethod.LayerNormalization)
            {
                c2 = Layers.LayerNormalization(c2, device, InitialNormalizationBias, InitialNormalizationScale, name + ".LN");
            }

            // Residual connection: sum the second layer's output with the skip connection, then
            // apply the activation and optional dropout.
            var p = CNTKLib.Plus(c2, toAdd);

            p = CNTKLib.ReLU(p);
            if (DropoutRate > 0)
            {
                p = CNTKLib.Dropout(p, DropoutRate);
            }

            // Record the names of the bias and weight parameters of both layers.
            ParameterNames.Add(ParamTypeToName(ResNodeParamType.Bias_LayerOne));
            ParameterNames.Add(ParamTypeToName(ResNodeParamType.Bias_LayerTwo));
            ParameterNames.Add(ParamTypeToName(ResNodeParamType.Weight_LayerOne));
            ParameterNames.Add(ParamTypeToName(ResNodeParamType.Weight_LayerTwo));

            return p;
        }