Ejemplo n.º 1
0
        /// <summary>
        /// Backward pass for a function with several outputs and a single input.
        /// Concatenates the outputs along <c>Axis</c> (in array order) and adds the
        /// <c>Grad</c> of the concatenated array into <paramref name="x"/>'s <c>Grad</c>.
        /// </summary>
        /// <param name="ys">Outputs of the forward pass; at least one element is required.</param>
        /// <param name="x">Input whose <c>Grad</c> is accumulated into (added to, not overwritten).</param>
        /// <remarks>
        /// NOTE(review): relies on <c>NdArray.Concatenate</c> carrying <c>Grad</c> into its result — confirm.
        /// </remarks>
        protected override void MultiOutputBackward(NdArray[] ys, NdArray x)
        {
            // Begin from a clone of the first output, then append the remaining outputs one by one.
            NdArray joined = ys[0].Clone();

            int next = 1;
            while (next < ys.Length)
            {
                joined = NdArray.Concatenate(joined, ys[next], Axis);
                next++;
            }

            // Element-wise accumulation over the full length of x.Grad.
            for (int pos = 0; pos < x.Grad.Length; ++pos)
            {
                x.Grad[pos] += joined.Grad[pos];
            }
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Static backward helper for a function with several outputs and a single input.
        /// Concatenates <paramref name="ys"/> along <paramref name="axis"/> (in array order) and
        /// adds the <c>Grad</c> of the concatenated array into <paramref name="x"/>'s <c>Grad</c>.
        /// </summary>
        /// <param name="ys">Outputs of the forward pass; at least one element is required.</param>
        /// <param name="x">Input whose <c>Grad</c> is accumulated into (added to, not overwritten).</param>
        /// <param name="axis">Axis along which the outputs are concatenated.</param>
        /// <remarks>
        /// NOTE(review): relies on <c>NdArray.Concatenate</c> carrying <c>Grad</c> into its result — confirm.
        /// </remarks>
        public static void MultiOutputBackward(NdArray<Real>[] ys, NdArray<Real> x, int axis)
        {
            NdArray<Real> stitched = ys[0].Clone();

            for (int cursor = 1; cursor < ys.Length; ++cursor)
            {
                stitched = NdArray.Concatenate(stitched, ys[cursor], axis);
            }

            // Add the stitched gradient into the input gradient, one element at a time.
            int gradLength = x.Grad.Length;
            for (int k = 0; k < gradLength; k++)
            {
                x.Grad[k] += stitched.Grad[k];
            }
        }
Ejemplo n.º 3
0
        /// <summary>
        /// CPU backward pass: concatenates the outputs <paramref name="ys"/> along <c>Axis</c>
        /// (in array order) and adds the <c>Grad</c> of the concatenated array into
        /// <paramref name="x"/>'s <c>Grad</c>.
        /// </summary>
        /// <param name="ys">Outputs of the forward pass; at least one element is required.</param>
        /// <param name="x">Input whose <c>Grad</c> is accumulated into (added to, not overwritten).</param>
        /// <remarks>
        /// NOTE(review): relies on <c>NdArray.Concatenate</c> carrying <c>Grad</c> into its result — confirm.
        /// </remarks>
        private void BackwardCpu(NdArray[] ys, NdArray x)
        {
            NdArray combined = ys[0].Clone();

            for (int index = 1, count = ys.Length; index < count; index++)
            {
                combined = NdArray.Concatenate(combined, ys[index], Axis);
            }

            // Accumulate rather than assign, so gradients from other paths into x are preserved.
            for (int j = 0; j < x.Grad.Length; j++)
            {
                x.Grad[j] += combined.Grad[j];
            }
        }
Ejemplo n.º 4
0
        /// <summary>
        /// CPU forward pass: concatenates the inputs along <c>Axis</c> in array order and records
        /// the split boundaries (used by the Split in the backward pass) in <c>_prevInputSections</c>.
        /// </summary>
        /// <param name="xs">Inputs to concatenate; at least one is required.</param>
        /// <returns>The concatenated array, with <c>ParentFunc</c> set to this function.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="xs"/> is null.</exception>
        /// <exception cref="System.ArgumentException"><paramref name="xs"/> is empty.</exception>
        private NdArray ForwardCpu(params NdArray[] xs)
        {
            // Fail with a clear argument error instead of the NullReferenceException /
            // OverflowException that "new int[xs.Length - 1]" would otherwise raise.
            if (xs == null)
            {
                throw new System.ArgumentNullException("xs");
            }

            if (xs.Length == 0)
            {
                throw new System.ArgumentException("At least one input is required.", "xs");
            }

            // sections[k] = total size along Axis of xs[0..k], i.e. the offset at which xs[k + 1] starts.
            int[] sections   = new int[xs.Length - 1];
            int   sizeOffset = xs[0].Shape[Axis];

            // Clone so that ParentFunc (set below) is not written onto the caller's input,
            // at least in the single-input case where no Concatenate happens.
            NdArray resultNdArray = xs[0].Clone();

            for (int i = 1; i < xs.Length; i++)
            {
                // The end offset of the last input is deliberately not stored:
                // the Split used in Backward does not need it.
                sections[i - 1] = sizeOffset;
                sizeOffset     += xs[i].Shape[Axis];

                resultNdArray = NdArray.Concatenate(resultNdArray, xs[i], Axis);
            }

            resultNdArray.ParentFunc = this;

            // One entry is appended per forward call; presumably consumed by the matching backward call.
            _prevInputSections.Add(sections);

            return resultNdArray;
        }