Exemplo n.º 1
0
        /// <summary>
        /// Backpropagates <paramref name="dout"/> through the batch-normalization layer.
        /// </summary>
        /// <remarks>
        /// Reads the state cached by the most recent training-mode Forward call
        /// (<c>xn</c>, <c>xc</c>, <c>std</c>, <c>batchSize</c>) together with <c>gamma</c>.
        /// Writes the parameter gradients into <c>dbeta</c> and <c>dgamma</c>.
        /// The cached forward state is not modified.
        /// </remarks>
        /// <param name="compute">Compute shader used by the matrix helpers.</param>
        /// <param name="dout">Upstream gradient (assumed to have the shape of the forward output).</param>
        /// <returns>A newly allocated signal with the input gradient; this method does not dispose it.</returns>
        public override Signal Backward(ComputeShader compute, Signal dout)
        {
            // dbeta = column sums of dout.
            dbeta = Refresh(1, dout.Columns, dbeta);
            MatOperations.SumMV(compute, dout, dbeta);

            // dgamma = column sums of (dout * xn). The element-wise product goes into a
            // temporary so the cached xn from Forward is left intact.
            dgamma = Refresh(1, dout.Columns, dgamma);
            var dout_x_xn = new Signal(dout);
            MatOperations.MulMMM(compute, dout, xn, dout_x_xn);
            MatOperations.SumMV(compute, dout_x_xn, dgamma);
            dout_x_xn.Dispose();

            // dxn = dout * gamma (MulMVM: matrix by row vector).
            var dxn = new Signal(dout);
            MatOperations.MulMVM(compute, dout, gamma, dxn);

            // First term of dxc: dxn / std.
            var dxc = new Signal(dxn);
            MatOperations.DivMVM(compute, dxn, std, dxc);

            // dstd = column sums of (dxn * xc) / (std * std).
            var dxn_x_xc = new Signal(dxn);
            MatOperations.MulMMM(compute, dxn, xc, dxn_x_xc);
            dxn.Dispose();

            var std_x_std = new Signal(std);
            MatOperations.MulMMM(compute, std, std, std_x_std);

            var dxn_x_xc_div_std_x_std = new Signal(dxn_x_xc);
            MatOperations.DivMVM(compute, dxn_x_xc, std_x_std, dxn_x_xc_div_std_x_std);
            dxn_x_xc.Dispose();
            std_x_std.Dispose();

            var dstd = new Signal(std);
            MatOperations.SumMV(compute, dxn_x_xc_div_std_x_std, dstd);
            dxn_x_xc_div_std_x_std.Dispose();

            // NOTE(review): dstd is summed here without a minus sign. DVar is presumably
            // expected to apply the -0.5 / std factor (dvar = -0.5 * dstd / std). Confirm in DVar.
            var dvar = new Signal(dstd);
            DVar(compute, dstd, std, dvar);
            dstd.Dispose();

            // Adds the variance contribution to dxc, scaled by 2 / batchSize (DXc is defined elsewhere).
            DXc(compute, xc, dvar, dxc, 2f / batchSize);
            dvar.Dispose();

            // dmu = column sums of dxc.
            var dmu = new Signal(1, dxc.Columns);
            MatOperations.SumMV(compute, dxc, dmu);

            // dx is built from dxc and dmu with a 1 / batchSize factor (DX is defined elsewhere).
            var dx = new Signal(dout);
            DX(compute, dxc, dmu, dx, 1f / batchSize);
            dxc.Dispose();
            dmu.Dispose();

            return(dx);
        }
Exemplo n.º 2
0
        /// <summary>
        /// Batch-normalizes <paramref name="x"/> and runs the <c>BNForward</c> kernel with
        /// <c>gamma</c> and <c>beta</c>.
        /// </summary>
        /// <remarks>
        /// Training mode normalizes with per-column statistics of the current batch. It caches
        /// <c>xc</c>, <c>std</c>, <c>xn</c> and <c>batchSize</c>, and passes the batch mean and
        /// variance to <c>Momentum</c> together with <c>runningMean</c> / <c>runningVar</c>.
        /// Inference mode normalizes with <c>runningMean</c> / <c>runningVar</c> instead.
        /// NOTE(review): the training path uses std = sqrt(variance) with no epsilon. The
        /// inference path is documented as adding epsilon inside Xn. A zero-variance column
        /// in training would therefore divide by zero. Confirm and consider adding epsilon.
        /// </remarks>
        /// <param name="compute">Compute shader providing the kernels.</param>
        /// <param name="x">Input batch.</param>
        /// <param name="train">True to use and update batch statistics; false to use running statistics.</param>
        /// <returns>A newly allocated output signal.</returns>
        public override Signal Forward(ComputeShader compute, Signal x, bool train)
        {
            if (!train)
            {
                // Inference: center with the running mean...
                xc = Refresh(x, xc);
                MatOperations.SubMVM(compute, x, runningMean, xc);

                // ...then scale: xn = xc / sqrt(runningVar + epsilon).
                xn = Refresh(xc, xn);
                Xn(compute, xc, runningVar, xn);
            }
            else
            {
                // Training: per-column mean of the batch.
                var batchMean = new Signal(1, x.Columns);
                MatOperations.MeanMV(compute, x, batchMean);

                xc = Refresh(x, xc);                   // xc = x - mean
                MatOperations.SubMVM(compute, x, batchMean, xc);

                var batchVariance = new Signal(1, xc.Columns);
                MatOperations.VarianceMV(compute, xc, batchVariance);

                std = Refresh(batchVariance, std);     // std = sqrt(variance)
                MatOperations.SqrtMM(compute, batchVariance, std);

                xn = Refresh(xc, xn);                  // xn = xc / std
                MatOperations.DivMVM(compute, xc, std, xn);

                batchSize = x.Rows;

                // Feed batch statistics into the running statistics (update rule lives in Momentum).
                Momentum(compute, batchMean, runningMean);
                Momentum(compute, batchVariance, runningVar);

                batchMean.Dispose();
                batchVariance.Dispose();
            }

            // Scale/shift step, presumably y = gamma * xn + beta, done by the BNForward kernel.
            var y = new Signal(xn);
            var bnKernel = compute.FindKernel("BNForward");
            compute.SetBuffer(bnKernel, "_X", xn.Buffer);
            compute.SetBuffer(bnKernel, "_Gamma", gamma.Buffer);
            compute.SetBuffer(bnKernel, "_Beta", beta.Buffer);
            compute.SetBuffer(bnKernel, "_Y", y.Buffer);
            Dispatch(compute, bnKernel, y.Rows, y.Columns);

            return y;
        }
Exemplo n.º 3
0
        /// <summary>
        /// Runs the <c>ReLU</c> kernel on <paramref name="x"/>.
        /// </summary>
        /// <remarks>
        /// The kernel result is kept in the cached <c>mask</c> signal. The caller receives a
        /// separately allocated copy rather than the cached signal itself.
        /// <paramref name="train"/> is not used.
        /// </remarks>
        /// <returns>A newly allocated signal holding a copy of the kernel result.</returns>
        public override Signal Forward(ComputeShader compute, Signal x, bool train)
        {
            mask = Refresh(x, mask);

            var reluKernel = compute.FindKernel("ReLU");
            compute.SetBuffer(reluKernel, "_X", x.Buffer);
            compute.SetBuffer(reluKernel, "_Y", mask.Buffer);
            Dispatch(compute, reluKernel, mask.Rows, mask.Columns);

            // Return a copy so the caller does not hold the cached mask.
            var result = new Signal(mask);
            MatOperations.CopyMM(compute, mask, result);

            return result;
        }
Exemplo n.º 4
0
        /// <summary>
        /// Fully connected forward pass: output = x * weights + biases.
        /// </summary>
        /// <remarks>
        /// A copy of the input is cached in <c>this.x</c> (presumably for the backward pass).
        /// <paramref name="train"/> is not used.
        /// </remarks>
        /// <returns>A newly allocated signal of shape (x.Rows, weights.Columns).</returns>
        public override Signal Forward(ComputeShader compute, Signal x, bool train)
        {
            // Keep a private copy of the input.
            this.x = Refresh(x, this.x);
            MatOperations.CopyMM(compute, x, this.x);

            var y = new Signal(x.Rows, weights.Columns);
            MatOperations.Multiply(compute, x, weights, y);   // y = x * W
            MatOperations.AddVM(compute, biases, y);          // y += b

            return y;
        }
Exemplo n.º 5
0
        /// <summary>
        /// Fully connected backward pass.
        /// </summary>
        /// <remarks>
        /// Writes the weight gradient into <c>dW</c> and the bias gradient into <c>dB</c>,
        /// using the input cached in <c>x</c>.
        /// </remarks>
        /// <param name="compute">Compute shader used by the matrix helpers.</param>
        /// <param name="dout">Upstream gradient.</param>
        /// <returns>A newly allocated input gradient of shape (dout.Rows, weights.Rows).</returns>
        public override Signal Backward(ComputeShader compute, Signal dout)
        {
            // Input gradient, shape (dout.Rows, weights.Rows): dout times transposed weights.
            var gradInput = new Signal(dout.Rows, weights.Rows);
            MatOperations.MultiplyMT(compute, dout, weights, gradInput);

            // Weight gradient, shape (x.Columns, dout.Columns): transposed x times dout.
            dW = Refresh(x.Columns, dout.Columns, dW);
            MatOperations.MultiplyTM(compute, x, dout, dW);

            // Bias gradient, shape (1, dout.Columns): column sums of dout.
            dB = Refresh(1, dout.Columns, dB);
            MatOperations.SumMV(compute, dout, dB);

            return gradInput;
        }