Example #1
0
    /// <summary>
    /// Builds a fully connected (dense) layer inside a TensorFlow variable scope named
    /// <paramref name="name"/> and applies the requested activation to its linear output.
    /// </summary>
    /// <param name="input">
    /// Tensor fed to <c>FullyConnectedLinearLayer</c>. It is passed through as-is (no reshaping).
    /// NOTE(review): presumably the caller must supply an already-flattened feature tensor — confirm.
    /// </param>
    /// <param name="outputDim">Number of output units, forwarded to <c>FullyConnectedLinearLayer</c>.</param>
    /// <param name="activation">
    /// Activation applied to the linear output; <see cref="NeuralNetworkActivation.None"/> returns it unchanged.
    /// </param>
    /// <param name="initializer">Weight initializer, forwarded to <c>FullyConnectedLinearLayer</c>.</param>
    /// <param name="seed">Random seed, forwarded to <c>FullyConnectedLinearLayer</c>.</param>
    /// <param name="name">Name of the variable scope that wraps the layer's operations.</param>
    /// <returns>The linear output tensor, optionally passed through ReLU, Sigmoid or Tanh.</returns>
    /// <exception cref="InvalidOperationException">
    /// Thrown when <paramref name="activation"/> is not None, ReLU, Sigmoid or Tanh.
    /// </exception>
    public static Tensor DenseLayer(Tensor input, int outputDim, NeuralNetworkActivation activation, NeuralNetworkInitializer initializer, int seed, string name)
    {
        return tf_with(tf.variable_scope(name), delegate
        {
            Tensor fullyConnected = FullyConnectedLinearLayer(input, outputDim, initializer, seed);

            // Activation op names are short ("ReLU", ...); presumably the enclosing variable
            // scope qualifies them with `name` — confirm against the TensorFlow.NET version in use.
            switch (activation)
            {
                case NeuralNetworkActivation.None:
                    return fullyConnected;

                case NeuralNetworkActivation.ReLU:
                    return tf.nn.relu(fullyConnected, "ReLU");

                case NeuralNetworkActivation.Sigmoid:
                    return tf.nn.sigmoid(fullyConnected, "Sigmoid");

                case NeuralNetworkActivation.Tanh:
                    return tf.nn.tanh(fullyConnected, "Tanh");

                default:
                    throw new InvalidOperationException("Unexpected activation " + activation);
            }
        });
    }
Example #2
0
        /// <summary>
        /// Builds a fully connected (dense) CNTK layer named <paramref name="name"/> and applies the
        /// requested activation. Inputs whose shape is not rank 1 are first reshaped to a single
        /// dimension holding the product of all their dimensions.
        /// </summary>
        /// <param name="input">Input variable; flattened to rank 1 when its rank differs from 1.</param>
        /// <param name="outputDim">Number of output units, forwarded to <c>FullyConnectedLinearLayer</c>.</param>
        /// <param name="device">Device forwarded to <c>FullyConnectedLinearLayer</c>.</param>
        /// <param name="activation">
        /// Activation applied to the linear output; <see cref="NeuralNetworkActivation.None"/> returns it unchanged.
        /// </param>
        /// <param name="initializer">Weight initializer, forwarded to <c>FullyConnectedLinearLayer</c>.</param>
        /// <param name="seed">Random seed, forwarded to <c>FullyConnectedLinearLayer</c>.</param>
        /// <param name="name">
        /// Name given to the linear function; the activation function, if any, is named
        /// <paramref name="name"/> followed by "ReLU", "Sigmoid" or "Tanh".
        /// </param>
        /// <returns>The linear function, optionally wrapped in ReLU, Sigmoid or Tanh.</returns>
        /// <exception cref="InvalidOperationException">
        /// Thrown when <paramref name="activation"/> is not None, ReLU, Sigmoid or Tanh.
        /// </exception>
        public static Function DenseLayer(Variable input, int outputDim, DeviceDescriptor device, NeuralNetworkActivation activation, NeuralNetworkInitializer initializer, int seed, string name)
        {
            if (input.Shape.Rank != 1)
            {
                // Collapse all dimensions into one. Seeding the product with 1 keeps a rank-0
                // (scalar) shape, whose Dimensions list is empty, from throwing inside Aggregate;
                // it is reshaped to { 1 } instead. For rank >= 2 the product is unchanged.
                // NOTE(review): assumes every dimension is a known, positive size — confirm that
                // inputs never carry placeholder (e.g. inferred/free) dimensions here.
                int newDim = input.Shape.Dimensions.Aggregate(1, (d1, d2) => d1 * d2);
                input = CNTKLib.Reshape(input, new int[] { newDim });
            }

            Function fullyConnected = FullyConnectedLinearLayer(input, outputDim, initializer, seed, device);
            fullyConnected.SetName(name);

            switch (activation)
            {
                case NeuralNetworkActivation.None:
                    return fullyConnected;

                case NeuralNetworkActivation.ReLU:
                    return CNTKLib.ReLU(fullyConnected, name + "ReLU");

                case NeuralNetworkActivation.Sigmoid:
                    return CNTKLib.Sigmoid(fullyConnected, name + "Sigmoid");

                case NeuralNetworkActivation.Tanh:
                    return CNTKLib.Tanh(fullyConnected, name + "Tanh");

                default:
                    throw new InvalidOperationException("Unexpected activation " + activation);
            }
        }