예제 #1
0
파일: Matmul.cs 프로젝트: chrinide/bohrium
        /// <summary>
        /// Performs matrix multiplication on the two operands, using the supplied methods.
        /// This is the main entry point for detecting valid inputs, instantiating the output,
        /// and determining if the operation should be lazy evaluated.
        /// </summary>
        /// <remarks>
        /// <paramref name="in1"/> must be 2D. <paramref name="in2"/> may be 1D or 2D; a 1D
        /// <paramref name="in2"/> gets a new axis inserted after its only dimension (presumably
        /// yielding an (n, 1) column view — confirm against <c>Subview</c>), so the result is always 2D.
        /// If the output's data accessor is lazy, the operation is queued on it; otherwise it is
        /// executed immediately through <c>ApplyManager.ApplyMatmul</c>.
        /// </remarks>
        /// <typeparam name="T">The type of data to operate on</typeparam>
        /// <typeparam name="CADD">The typed add operator</typeparam>
        /// <typeparam name="CMUL">The typed multiply operator</typeparam>
        /// <param name="addop">The add operator</param>
        /// <param name="mulop">The multiply operator</param>
        /// <param name="in1">The left-hand-side argument, must be 2D</param>
        /// <param name="in2">The right-hand-side argument, must be 1D or 2D</param>
        /// <param name="out">An optional output argument, used for in-place operations. When supplied it must be 2D,
        /// with as many rows as <paramref name="in1"/> and as many columns as (the 2D view of) <paramref name="in2"/></param>
        /// <returns>An array with the matrix multiplication result</returns>
        /// <exception cref="ArgumentNullException"><paramref name="in1"/> or <paramref name="in2"/> is null</exception>
        /// <exception cref="ArgumentException">An input has an unsupported number of dimensions, the inner dimensions
        /// do not match, or <paramref name="out"/> is not correctly shaped</exception>
        private static NdArray <T> Matmul_Entry <T, CADD, CMUL>(CADD addop, CMUL mulop, NdArray <T> in1, NdArray <T> in2, NdArray <T> @out = null)
            where CADD : struct, IBinaryOp <T>
            where CMUL : struct, IBinaryOp <T>
        {
            if (in1 == null)
            {
                throw new ArgumentNullException("in1");
            }
            if (in2 == null)
            {
                throw new ArgumentNullException("in2");
            }

            long in1Rank = in1.Shape.Dimensions.LongLength;
            if (in1Rank != 2)
            {
                throw new ArgumentException("Input elements must be 2D", "in1");
            }

            // A 0-D in2 must be rejected here; otherwise Dimensions[0] below throws IndexOutOfRangeException.
            long in2Rank = in2.Shape.Dimensions.LongLength;
            if (in2Rank < 1 || in2Rank > 2)
            {
                throw new ArgumentException("Input elements must be 1D or 2D", "in2");
            }

            long in1Columns = in1.Shape.Dimensions[1].Length;
            long in2Rows = in2.Shape.Dimensions[0].Length;
            if (in1Columns != in2Rows)
            {
                throw new ArgumentException(string.Format("Input elements shape size must match for matrix multiplication (in1 has {0} columns, in2 has {1} rows)", in1Columns, in2Rows));
            }

            if (in2Rank == 1)
            {
                // Insert an axis after the single dimension so the vector takes part as a 2D operand.
                in2 = in2.Subview(Range.NewAxis, in2Rank);
            }

            long[] newDims = new long[] { in1.Shape.Dimensions[0].Length, in2.Shape.Dimensions[1].Length };
            if (@out == null)
            {
                @out = new NdArray <T>(new Shape(newDims));
            }
            else if (@out.Shape.Dimensions.LongLength != 2 || @out.Shape.Dimensions[0].Length != newDims[0] || @out.Shape.Dimensions[1].Length != newDims[1])
            {
                throw new ArgumentException("The output array for matrix multiplication is not correctly shaped", "out");
            }

            // Single type test instead of an `is` check followed by a second cast.
            ILazyAccessor <T> lazyAccessor = @out.DataAccessor as ILazyAccessor <T>;
            if (lazyAccessor != null)
            {
                // Lazy output: record the operation on the accessor instead of computing it now.
                lazyAccessor.AddOperation(new LazyMatmulOperation <T>(addop, mulop), @out, in1, in2);
            }
            else
            {
                ApplyManager.ApplyMatmul <T, CADD, CMUL>(addop, mulop, in1, in2, @out);
            }

            return @out;
        }