コード例 #1
0
        /// <summary>
        /// Creates a new square identity matrix of the given size.
        /// </summary>
        /// <param name="device">The device on which the new NdArray is allocated.</param>
        /// <param name="size">The number of rows (and columns) of the identity matrix.</param>
        /// <returns>
        /// A new <paramref name="size"/> x <paramref name="size"/> NdArray that starts as all zeros
        /// and has its diagonal filled with one.
        /// </returns>
        public static NdArray<T> Identity(IDevice device, int size)
        {
            int[] shape = { size, size };
            var identity = NdArray<T>.Zeros(device, shape);

            // NOTE(review): assumes Diag returns a view sharing storage with `identity`,
            // so filling the view writes the ones into the returned array — confirm.
            NdArrayOperator<T>.Diag(identity).FillConst(Primitives.One<T>());

            return identity;
        }
コード例 #2
0
        /// <summary>
        /// Calculates the trace (the sum of the diagonal elements) of <paramref name="source"/>
        /// along the diagonal spanned by <paramref name="axis1"/> and <paramref name="axis2"/>.
        /// </summary>
        /// <param name="axis1">The first axis of the diagonal to compute the trace along.</param>
        /// <param name="axis2">The second axis of the diagonal to compute the trace along.</param>
        /// <param name="source">The NdArray containing the source values.</param>
        /// <returns>A new NdArray containing the summed diagonal values.</returns>
        public static NdArray<T> TraceAxis(int axis1, int axis2, NdArray<T> source)
        {
            // NOTE(review): assumes DiagAxis leaves the diagonal at axis1 and removes axis2.
            // When axis2 precedes axis1, removing it shifts the diagonal down by one position,
            // so the axis to sum over is axis1 - 1 in that case.
            int diagonalAxis;
            if (axis1 < axis2)
            {
                diagonalAxis = axis1;
            }
            else
            {
                diagonalAxis = axis1 - 1;
            }

            var diagonal = NdArrayOperator<T>.DiagAxis(axis1, axis2, source);
            return SumAxis(diagonalAxis, diagonal);
        }