Example No. 1
0
 /// <summary>
 /// Builds a factory that, given an input variable, assembles the full network:
 /// numLstmLayer stacked (Stabilize -> LSTM) layers followed by a final Dense
 /// projection to numOutputDimension. Construction is deferred until the
 /// returned delegate is invoked with the actual input.
 /// </summary>
 /// <param name="numOutputDimension">Dimension of the final Dense output layer</param>
 /// <param name="numLstmLayer">Number of stacked stabilized LSTM layers</param>
 /// <param name="numHiddenDimension">Hidden dimension of each LSTM layer</param>
 /// <returns>A delegate mapping the input variable to the assembled model</returns>
 Func <Variable, Function> CreateModel(int numOutputDimension, int numLstmLayer, int numHiddenDimension)
 {
     return input =>
     {
         Function network = input;
         int layer = 0;
         //  Stack the recurrent layers; each one is preceded by a stabilizer
         //  (uses the enclosing type's `device` field — not visible in this chunk).
         while (layer < numLstmLayer)
         {
             network = Stabilize.Build(network, device);
             network = LSTM.Build(network, numHiddenDimension, device);
             layer++;
         }
         //  Final readout projection.
         return Dense.Build(network, numOutputDimension, device);
     };
 }
Example No. 2
0
        /// <summary>
        /// This function creates an LSTM block that implements one step of recurrence.
        /// It accepts the previous state and outputs its new state as a two-valued tuple (output, cell state)
        /// </summary>
        /// <typeparam name="TElementType">The data type of the values. May be set to float or double</typeparam>
        /// <param name="input">The input to the LSTM</param>
        /// <param name="prevOutput">The output of the previous step of the LSTM</param>
        /// <param name="prevCellState">The cell state of the previous step of the LSTM</param>
        /// <param name="enableSelfStabilization">If True, then all state-related projection will contain a Stabilizer()</param>
        /// <param name="device">Device used for the computation of this cell</param>
        /// <returns>A function (prev_h, prev_c, input) -> (h, c) that implements one step of a recurrent LSTM layer</returns>
        public static Tuple <Function, Function> LSTMCell <TElementType>(Variable input, Variable prevOutput,
                                                                         Variable prevCellState, bool enableSelfStabilization, DeviceDescriptor device)
        {
            //  Dimensions are taken from the recurrent inputs' first shape axis
            //  (assumes prevOutput/prevCellState are effectively rank-1 here — TODO confirm with callers).
            int lstmOutputDimension = prevOutput.Shape[0];
            int lstmCellDimension   = prevCellState.Shape[0];

            //  TElementType selects the CNTK DataType for all created parameters;
            //  anything other than float is treated as double.
            bool     isFloatType = typeof(TElementType) == typeof(float);
            DataType dataType    = isFloatType ? DataType.Float : DataType.Double;

            //  This is done according to the example in CNTK. The problem with this is that because it is called multiple times,
            //  the number of parameter tensor increases. Fix may be needed to make it identical to the one on Python.
            //  NOTE: both factory lambdas below create a BRAND-NEW Parameter on every
            //  invocation; each gate therefore owns distinct weights/biases, and the
            //  order of invocations determines the parameter layout. Do not reorder.
            Func <int, Parameter> createBiasParameters;

            //  Two branches so the literal 0.01 matches the element type exactly
            //  (0.01f for float, 0.01 for double); both biases are named "Bias".
            if (isFloatType)
            {
                createBiasParameters = (dimension) => new Parameter(new [] { dimension }, 0.01f, device, "Bias");
            }
            else
            {
                createBiasParameters = (dimension) => new Parameter(new[] { dimension }, 0.01, device, "Bias");
            }

            //  seed++ gives every created weight tensor a distinct (but deterministic)
            //  Glorot-uniform initialization stream; the second shape axis is left for
            //  CNTK to infer from whatever the weight is multiplied with.
            uint seed = 1;
            Func <int, Parameter> createWeightParameters = (outputDimension) => new Parameter(new [] { outputDimension, NDShape.InferredDimension },
                                                                                              dataType, CNTKLib.GlorotUniformInitializer(1.0, 1, 0, seed++), device, "Weight");

            Function stabilizedPrevOutput;

            //  Optionally wrap the recurrent output in a self-stabilizer
            //  (see Stabilize.Build elsewhere in the project for semantics).
            if (enableSelfStabilization)
            {
                stabilizedPrevOutput = Stabilize.Build <TElementType>(prevOutput, device, "StabilizedPrevOutput");
            }
            else
            {
                stabilizedPrevOutput = prevOutput;
            }

            //  W*x + b for the current input. Every call allocates a fresh W and b
            //  (see note above), so each gate gets its own input projection.
            Func <Variable> InputLinearCombinationPlusBias = () =>
                                                             CNTKLib.Plus(createBiasParameters(lstmCellDimension),
                                                                          (createWeightParameters(lstmCellDimension) * input), "LinearCombinationPlusBias");

            //  R*h for the (possibly stabilized) previous output; likewise a fresh R per call.
            Func <Variable, Variable> PrevOutputLinearCombination = (previousOutput) =>
                                                                    CNTKLib.Times(createWeightParameters(lstmCellDimension), previousOutput);

            //  Forget Gate: f_t = sigmoid(W_f*x + b_f + R_f*h_{t-1})
            Function ft =
                CNTKLib.Sigmoid(
                    InputLinearCombinationPlusBias() + PrevOutputLinearCombination(stabilizedPrevOutput),
                    "ForgetGate");

            //  Input Gate: i_t = sigmoid(...), candidate: c~_t = tanh(...)
            Function it =
                CNTKLib.Sigmoid(
                    InputLinearCombinationPlusBias() + PrevOutputLinearCombination(stabilizedPrevOutput),
                    "InputGate");
            Function ctt =
                CNTKLib.Tanh(
                    InputLinearCombinationPlusBias() + PrevOutputLinearCombination(stabilizedPrevOutput),
                    "CandidateValue");

            //  New Cell State: c_t = f_t (.) c_{t-1} + i_t (.) c~_t  (element-wise)
            Function ct =
                CNTKLib.Plus(CNTKLib.ElementTimes(ft, prevCellState), CNTKLib.ElementTimes(it, ctt));

            //  Output Gate: o_t = sigmoid(...); h_t = o_t (.) tanh(c_t)
            Function ot =
                CNTKLib.Sigmoid(
                    InputLinearCombinationPlusBias() + PrevOutputLinearCombination(stabilizedPrevOutput),
                    "OutputGate");
            Function ht =
                CNTKLib.ElementTimes(ot, CNTKLib.Tanh(ct), "Output");
            Function stabilizedHt;

            //  Optionally stabilize the new output before the (possible) final projection.
            if (enableSelfStabilization)
            {
                stabilizedHt = Stabilize.Build <TElementType>(ht, device, "OutputStabilized");
            }
            else
            {
                stabilizedHt = ht;
            }

            //  Prepare output: when the requested output dimension differs from the
            //  cell dimension, project h down/up with one more (fresh) weight matrix.
            Function c = ct;
            Function h = (lstmOutputDimension != lstmCellDimension)
                ? CNTKLib.Times(createWeightParameters(lstmOutputDimension), stabilizedHt, "Output")
                : stabilizedHt;

            return(new Tuple <Function, Function>(h, c));
        }