예제 #1
0
        /// <summary>
        /// Assigns the <c>SingleInputForward</c> and <c>SingleOutputBackward</c> delegates of a
        /// <c>MaskedLinear&lt;float&gt;</c> or <c>MaskedLinear&lt;double&gt;</c> instance.
        /// When <c>IsParallel</c> is true the delegates call the kernel-based
        /// <c>LinearF</c>/<c>LinearD</c> routines; otherwise they call
        /// <c>CPU.LinearF</c>/<c>CPU.LinearD</c>. Any other runtime type is left untouched.
        /// </summary>
        /// <remarks>
        /// The layer's members (<c>Mask</c>, <c>Weight</c>, <c>Bias</c>, kernels, <c>Activation</c>)
        /// are read inside the lambdas, i.e. each time a delegate runs, not when it is bound.
        /// </remarks>
        /// <param name="sc">
        /// Streaming context; not read by this method (NOTE(review): presumably present so the
        /// method can serve as a serialization callback — confirm).
        /// </param>
        protected void InitMaskedFunc(StreamingContext sc)
        {
            // Read the execution mode once, before inspecting the runtime element type.
            bool useParallel = IsParallel;

            switch (this)
            {
                case MaskedLinear<float> maskedF:
                    if (useParallel)
                    {
                        // Parallel forward receives ForwardKernel rather than Activation
                        // (NOTE(review): presumably the activation is built into the kernel — confirm).
                        maskedF.SingleInputForward = x => LinearF.SingleInputForward(x, maskedF.Mask, maskedF.Weight, maskedF.Bias, maskedF.ForwardKernel, maskedF);
                        maskedF.SingleOutputBackward = (y, x) => LinearF.SingleOutputBackward(y, x, maskedF.Mask, maskedF.Weight, maskedF.Bias, maskedF.BackwardgWKernel, maskedF.BackwardgXKernel, maskedF.Activation);
                    }
                    else
                    {
                        maskedF.SingleInputForward = x => CPU.LinearF.SingleInputForward(x, maskedF.Mask, maskedF.Weight, maskedF.Bias, maskedF.Activation, maskedF);
                        maskedF.SingleOutputBackward = (y, x) => CPU.LinearF.SingleOutputBackward(y, x, maskedF.Mask, maskedF.Weight, maskedF.Bias, maskedF.Activation);
                    }
                    break;

                case MaskedLinear<double> maskedD:
                    if (useParallel)
                    {
                        maskedD.SingleInputForward = x => LinearD.SingleInputForward(x, maskedD.Mask, maskedD.Weight, maskedD.Bias, maskedD.ForwardKernel, maskedD);
                        maskedD.SingleOutputBackward = (y, x) => LinearD.SingleOutputBackward(y, x, maskedD.Mask, maskedD.Weight, maskedD.Bias, maskedD.BackwardgWKernel, maskedD.BackwardgXKernel, maskedD.Activation);
                    }
                    else
                    {
                        maskedD.SingleInputForward = x => CPU.LinearD.SingleInputForward(x, maskedD.Mask, maskedD.Weight, maskedD.Bias, maskedD.Activation, maskedD);
                        maskedD.SingleOutputBackward = (y, x) => CPU.LinearD.SingleOutputBackward(y, x, maskedD.Mask, maskedD.Weight, maskedD.Bias, maskedD.Activation);
                    }
                    break;
            }
        }
예제 #2
0
        /// <summary>
        /// Assigns the <c>SingleInputForward</c> and <c>SingleOutputBackward</c> delegates of a
        /// <c>Linear&lt;float&gt;</c> or <c>Linear&lt;double&gt;</c> instance to the matching
        /// <c>LinearF</c>/<c>LinearD</c> routines. Any other runtime type is left untouched.
        /// Declared virtual, so derived types may override it.
        /// </summary>
        /// <remarks>
        /// <c>Weight</c>, <c>Bias</c> and <c>Activation</c> are read inside the lambdas, i.e. each
        /// time a delegate runs, not when it is bound.
        /// </remarks>
        /// <param name="sc">
        /// Streaming context; not read by this method (NOTE(review): presumably present so the
        /// method can serve as a serialization callback — confirm).
        /// </param>
        protected virtual void InitFunc(StreamingContext sc)
        {
            if (this is Linear<float> floatLayer)
            {
                floatLayer.SingleInputForward = x => LinearF.SingleInputForward(x, floatLayer.Weight, floatLayer.Bias, floatLayer.Activation, floatLayer);
                floatLayer.SingleOutputBackward = (y, x) => LinearF.SingleOutputBackward(y, x, floatLayer.Weight, floatLayer.Bias, floatLayer.Activation);
            }
            else if (this is Linear<double> doubleLayer)
            {
                doubleLayer.SingleInputForward = x => LinearD.SingleInputForward(x, doubleLayer.Weight, doubleLayer.Bias, doubleLayer.Activation, doubleLayer);
                doubleLayer.SingleOutputBackward = (y, x) => LinearD.SingleOutputBackward(y, x, doubleLayer.Weight, doubleLayer.Bias, doubleLayer.Activation);
            }
        }