Exemplo n.º 1
0
        /// <summary>
        /// Builds the fully connected layer on top of <paramref name="data"/>.
        /// </summary>
        /// <remarks>
        /// A weight variable is always registered. A bias variable is registered only when
        /// <c>UseBias</c> is set. Each registered variable gets an initializer, a constraint and a
        /// regularizer entry. When <c>Activation</c> is not <c>ActivationType.Linear</c>, the
        /// activation from <c>ActivationRegistry</c> is applied to the output.
        /// </remarks>
        /// <param name="data">Input symbol of the layer.</param>
        /// <returns>The layer output symbol, activated unless the activation is linear.</returns>
        public override Symbol Build(Symbol data)
        {
            // IDs are requested in the same order as before: weight first, then bias.
            var kernelId = UUID.GetID(ID + "_w");
            var biasId   = UUID.GetID(ID + "_b");

            Symbol biasSymbol = null;
            if (UseBias)
            {
                biasSymbol = Symbol.Variable(biasId);
            }

            InitParams.Add(kernelId, KernalInitializer);
            if (UseBias)
            {
                InitParams.Add(biasId, BiasInitializer);
            }

            ConstraintParams.Add(kernelId, KernalConstraint);
            if (UseBias)
            {
                ConstraintParams.Add(biasId, BiasConstraint);
            }

            RegularizerParams.Add(kernelId, KernalRegularizer);
            if (UseBias)
            {
                RegularizerParams.Add(biasId, BiasRegularizer);
            }

            var kernelSymbol = Symbol.Variable(kernelId);
            var output = sym.FullyConnected(data, kernelSymbol, Dim, biasSymbol, !UseBias, true, ID);

            if (Activation == ActivationType.Linear)
            {
                return output;
            }

            return ActivationRegistry.Get(Activation).Build(output);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Builds the embedding lookup on top of <paramref name="x"/>.
        /// </summary>
        /// <remarks>
        /// Registers a single lookup-table variable. It gets the embeddings initializer,
        /// constraint and regularizer.
        /// </remarks>
        /// <param name="x">Input symbol of the layer.</param>
        /// <returns>The symbol produced by <c>sym.Embedding</c>, named after the layer ID.</returns>
        public override Symbol Build(Symbol x)
        {
            var embeddingId = UUID.GetID(ID + "_w");

            InitParams.Add(embeddingId, EmbeddingsInitializer);
            ConstraintParams.Add(embeddingId, EmbeddingsConstraint);
            RegularizerParams.Add(embeddingId, EmbeddingsRegularizer);

            var lookupTable = Symbol.Variable(embeddingId);
            return sym.Embedding(x, lookupTable, InputDim, OutputDim, symbol_name: ID);
        }
Exemplo n.º 3
0
        /// <summary>
        /// Builds the 3-D deconvolution layer on top of <paramref name="x"/>.
        /// </summary>
        /// <remarks>
        /// A weight variable is always registered. A bias variable is registered only when
        /// <c>UseBias</c> is set. When <c>Padding</c> has a value, it is used for all three
        /// dimensions; otherwise an empty <c>Shape</c> is passed. When <c>Activation</c> is not
        /// <c>ActivationType.Linear</c>, the registered activation is applied to the output.
        /// </remarks>
        /// <param name="x">Input symbol of the layer.</param>
        /// <returns>The layer output symbol, activated unless the activation is linear.</returns>
        public override Symbol Build(Symbol x)
        {
            // IDs are requested in the same order as before: bias first, then weight.
            var biasId   = UUID.GetID(ID + "_b");
            var kernelId = UUID.GetID(ID + "_w");

            Symbol biasSymbol = null;
            if (UseBias)
            {
                biasSymbol = Symbol.Variable(biasId);
            }

            Shape pad = Padding.HasValue
                ? new Shape(Padding.Value, Padding.Value, Padding.Value)
                : new Shape();

            if (UseBias)
            {
                InitParams.Add(biasId, BiasInitializer);
            }
            InitParams.Add(kernelId, KernalInitializer);

            ConstraintParams.Add(kernelId, KernalConstraint);
            if (UseBias)
            {
                ConstraintParams.Add(biasId, BiasConstraint);
            }

            RegularizerParams.Add(kernelId, KernalRegularizer);
            if (UseBias)
            {
                RegularizerParams.Add(biasId, BiasRegularizer);
            }

            var kernelSymbol = Symbol.Variable(kernelId);
            var kernelShape  = new Shape(KernalSize.Item1, KernalSize.Item2, KernalSize.Item3);
            var strideShape  = new Shape(Strides.Item1, Strides.Item2, Strides.Item3);
            var dilateShape  = new Shape(DialationRate.Item1, DialationRate.Item2, DialationRate.Item3);

            var output = sym.Deconvolution(x, kernelSymbol, kernelShape, Filters, strideShape, dilateShape, pad,
                                           new Shape(), new Shape(), biasSymbol, !UseBias, 1, 512, null, false, null, ID);

            if (Activation == ActivationType.Linear)
            {
                return output;
            }

            return ActivationRegistry.Get(Activation).Build(output);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Builds the batch-normalization layer on top of <paramref name="x"/>.
        /// </summary>
        /// <remarks>
        /// Registers four variables: beta, gamma, moving mean and moving variance. All four get
        /// initializers. Only beta and gamma get constraints and regularizers.
        /// </remarks>
        /// <param name="x">Input symbol of the layer.</param>
        /// <returns>The symbol produced by <c>sym.BatchNorm</c>, named after the layer ID.</returns>
        public override Symbol Build(Symbol x)
        {
            var beta       = UUID.GetID(ID + "_beta");
            // Gamma needs its own suffix; reusing "_beta" would give it the same base name as beta.
            var gamma      = UUID.GetID(ID + "_gamma");
            var movingMean = UUID.GetID(ID + "_mean");
            var movingVar  = UUID.GetID(ID + "_var");

            InitParams.Add(beta, BetaInitializer);
            InitParams.Add(gamma, GammaInitializer);
            InitParams.Add(movingMean, MovingMeanInitializer);
            InitParams.Add(movingVar, MovingVarianceInitializer);

            ConstraintParams.Add(beta, BetaConstraint);
            ConstraintParams.Add(gamma, GammaConstraint);

            RegularizerParams.Add(beta, BetaRegularizer);
            RegularizerParams.Add(gamma, GammaRegularizer);

            // NOTE(review): Center and Scale are passed positionally after Momentum. Confirm
            // against the sym.BatchNorm signature which flags (e.g. fix_gamma / use_global_stats)
            // these positions correspond to.
            return(sym.BatchNorm(x, Symbol.Variable(gamma), Symbol.Variable(beta), Symbol.Variable(movingMean), Symbol.Variable(movingVar),
                                 Epsilon, Momentum, Center, Scale, false, Axis, false, ID));
        }