示例#1
0
        /// <summary>
        /// Element-wise subtraction <c>res = a - b</c> with simple tiling broadcast: the operand
        /// with fewer elements is repeated, chunk by chunk, across the output.
        /// </summary>
        /// <param name="res">Output tensor; its first <c>res.Shape.TotalSize</c> floats are written.</param>
        /// <param name="a">Minuend. Tiled when it has no more elements than <paramref name="b"/>.</param>
        /// <param name="b">Subtrahend. Tiled when it has fewer elements than <paramref name="a"/>.</param>
        /// <exception cref="System.ArgumentException">
        /// The operand that would be tiled is empty while <paramref name="res"/> is not.
        /// </exception>
        /// <remarks>
        /// NOTE(review): assumes all three tensors hold contiguous float32 data, that the non-tiled
        /// operand has at least <c>res.Shape.TotalSize</c> elements (not checked here), and that
        /// <c>ElementWiseSubtractAVX(x, y, dst, n)</c> computes <c>dst = x - y</c> — TODO confirm.
        /// When the output length is not a whole multiple of the tiled operand's length, the final
        /// chunk uses only a prefix of the tiled operand.
        /// </remarks>
        public static void SubtractFloat32(Tensor res, Tensor a, Tensor b)
        {
            long total = res.Shape.TotalSize;
            if (total == 0)
            {
                // Nothing to write. Dividing by an empty operand's size used to throw here.
                return;
            }

            // Same branch choice as before: 'b' is tiled only when 'a' is strictly larger;
            // on equal sizes 'a' is the "tiled" side, which degenerates to a single chunk.
            bool tileB = a.Shape.TotalSize > b.Shape.TotalSize;
            long period = tileB ? b.Shape.TotalSize : a.Shape.TotalSize;
            if (period == 0)
            {
                throw new System.ArgumentException(
                    "Cannot broadcast an empty tensor to a non-empty result.",
                    tileB ? nameof(b) : nameof(a));
            }

            float* pa = (float*)a.Base.Array;
            float* pb = (float*)b.Base.Array;
            float* pr = (float*)res.Base.Array;

            for (long i = 0; i < total; i += period)
            {
                // Full chunks use 'period' elements; a trailing partial chunk is clamped.
                long remaining = total - i;
                long count = remaining < period ? remaining : period;

                // The tiled operand always restarts at offset 0; the other one advances with the output.
                VectorizationFloat.ElementWiseSubtractAVX(
                    tileB ? pa + i : pa,
                    tileB ? pb : pb + i,
                    pr + i,
                    count);
            }
        }
示例#2
0
        /// <summary>
        /// Presumably the backward pass of <c>SubtractFloat32</c> with respect to the tiled
        /// subtrahend <c>b</c>. For each element <c>j</c> of the destination it writes the negated sum
        /// of <c>s[k * n + j]</c> over every chunk <c>k</c>, where <c>n = gradienta.Shape.TotalSize</c>.
        /// </summary>
        /// <param name="gradienta">
        /// Destination gradient buffer (for b, despite the name). It is overwritten, not accumulated into.
        /// </param>
        /// <param name="s">Upstream gradient, consumed in chunks of <c>gradienta.Shape.TotalSize</c> floats.</param>
        /// <param name="a">Not used; kept so the signature stays unchanged.</param>
        /// <remarks>
        /// NOTE(review): assumes <c>MakeNegativeAVX(src, dst, n)</c> computes <c>dst = -src</c> and
        /// <c>ElementWiseSubtractAVX(x, y, dst, n)</c> computes <c>dst = x - y</c> — TODO confirm.
        /// A trailing partial chunk of <paramref name="s"/> only contributes to a prefix of the
        /// destination. If <paramref name="s"/> is shorter than the destination, only its first
        /// <c>s.Shape.TotalSize</c> elements are written. If either tensor is empty, nothing is written.
        /// </remarks>
        public static void SubtractFloat32_GetGradientB(Tensor gradienta, Tensor s, Tensor a)
        {
            long period = gradienta.Shape.TotalSize;
            long total = s.Shape.TotalSize;
            if (period == 0 || total == 0)
            {
                // Nothing to write. An empty destination used to cause a divide-by-zero.
                return;
            }

            float* grad = (float*)gradienta.Base.Array;
            float* src = (float*)s.Base.Array;

            // The first (possibly partial) chunk overwrites the destination, so stale contents never
            // leak into the result, even when 's' is shorter than the destination.
            long first = total < period ? total : period;
            VectorizationFloat.MakeNegativeAVX(src, grad, first);

            // Every further chunk, including a trailing partial one, is subtracted in place.
            for (long i = period; i < total; i += period)
            {
                long remaining = total - i;
                long count = remaining < period ? remaining : period;
                VectorizationFloat.ElementWiseSubtractAVX(grad, src + i, grad, count);
            }
        }