示例#1
0
        /// <summary>
        /// Sums <paramref name="x"/> over the requested axes by chaining one
        /// <c>Functions.Aggregation.Sum</c> application per axis.
        /// </summary>
        /// <param name="x">Node to reduce.</param>
        /// <param name="axes">
        /// Axes to sum over; <c>null</c> selects every axis of <paramref name="x"/>.
        /// Repeated entries are treated as a single axis.
        /// </param>
        /// <param name="keepdims">
        /// When <c>false</c>, the entries for the summed axes are removed from the shape
        /// and the result is reshaped accordingly. When <c>true</c>, the result of the
        /// per-axis reductions is returned as is (presumably with each summed axis kept
        /// at length 1 - confirm against <c>Functions.Aggregation.Sum</c>).
        /// </param>
        /// <returns>The reduced node.</returns>
        public static VariableNode Sum(VariableNode x, int[] axes = null, bool keepdims = false)
        {
            // Sort once and drop duplicates: a repeated axis would otherwise be removed
            // twice in the !keepdims branch below, deleting an unrelated dimension.
            int[] sortedAxes = (axes == null)
                ? Enumerable.Range(0, x.Shape.Ndim).ToArray()
                : axes.Distinct().OrderBy((val) => val).ToArray();

            foreach (int axis in sortedAxes)
            {
                // An axis of length 1 has nothing to accumulate, so no function is applied.
                if (x.Shape[axis] == 1)
                {
                    continue;
                }

                Function function = new Functions.Aggregation.Sum(axis);

                x = Apply(function, x)[0];
            }

            if (!keepdims)
            {
                List<int> lengths = ((int[])x.Shape).ToList();

                // Walk from the highest axis down so earlier removals do not shift
                // the indices of the axes still waiting to be removed.
                for (int i = sortedAxes.Length - 1; i >= 0; i--)
                {
                    lengths.RemoveAt(sortedAxes[i]);
                }

                x = Reshape(x, new Shape(x.Shape.Type, lengths.ToArray()));
            }

            return x;
        }
示例#2
0
        /// <summary>
        /// Sums <paramref name="x"/> over the requested axes by executing one
        /// <c>Functions.Aggregation.Sum</c> per axis into a freshly allocated tensor.
        /// </summary>
        /// <param name="x">Tensor to reduce.</param>
        /// <param name="axes">
        /// Axes to sum over; <c>null</c> selects every axis of <paramref name="x"/>.
        /// Repeated entries are treated as a single axis.
        /// </param>
        /// <param name="keepdims">
        /// When <c>false</c>, the entries for the summed axes are removed from the shape
        /// and the result is reshaped accordingly. When <c>true</c>, the result of the
        /// per-axis reductions is returned as is (presumably with each summed axis kept
        /// at length 1 - confirm against <c>Functions.Aggregation.Sum</c>).
        /// </param>
        /// <returns>The reduced tensor.</returns>
        public static Tensor Sum(Tensor x, int[] axes = null, bool keepdims = false)
        {
            // Sort once and drop duplicates: a repeated axis would otherwise be removed
            // twice in the !keepdims branch below, deleting an unrelated dimension.
            int[] sortedAxes = (axes == null)
                ? Enumerable.Range(0, x.Shape.Ndim).ToArray()
                : axes.Distinct().OrderBy((val) => val).ToArray();

            foreach (int axis in sortedAxes)
            {
                // An axis of length 1 has nothing to accumulate, so no function is executed.
                if (x.Shape[axis] == 1)
                {
                    continue;
                }

                Function function = new Functions.Aggregation.Sum(axis);

                // The function reports its output shape; allocate the destination from it.
                Tensor y = new Tensor(function.OutputShapes(x.Shape)[0]);

                function.Execute(new Tensor[] { x }, new Tensor[] { y });

                x = y;
            }

            if (!keepdims)
            {
                List<int> lengths = ((int[])x.Shape).ToList();

                // Walk from the highest axis down so earlier removals do not shift
                // the indices of the axes still waiting to be removed.
                for (int i = sortedAxes.Length - 1; i >= 0; i--)
                {
                    lengths.RemoveAt(sortedAxes[i]);
                }

                x = Reshape(x, new Shape(x.Shape.Type, lengths.ToArray()));
            }

            return x;
        }