Ejemplo n.º 1
0
        public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
        {
            Contracts.Assert(mat.Size == dst.Size * src.Size);
            Contracts.Assert(crun >= 0);

            if (Avx.IsSupported)
            {
                if (!tran)
                {
                    Contracts.Assert(crun <= dst.Size);
                    AvxIntrinsics.MatMulX(add, mat, src, dst, crun, src.Size);
                }
                else
                {
                    Contracts.Assert(crun <= src.Size);
                    AvxIntrinsics.MatMulTranX(add, mat, src, dst, dst.Size, crun);
                }
            }
            else if (Sse.IsSupported)
            {
                if (!tran)
                {
                    Contracts.Assert(crun <= dst.Size);
                    SseIntrinsics.MatMulA(add, mat, src, dst, crun, src.Size);
                }
                else
                {
                    Contracts.Assert(crun <= src.Size);
                    SseIntrinsics.MatMulTranA(add, mat, src, dst, dst.Size, crun);
                }
            }
            else
            {
                if (!tran)
                {
                    Contracts.Assert(crun <= dst.Size);
                    for (int i = 0; i < crun; i++)
                    {
                        float dotProduct = 0;
                        for (int j = 0; j < src.Size; j++)
                        {
                            dotProduct += mat[i * src.Size + j] * src[j];
                        }

                        if (add)
                        {
                            dst[i] += dotProduct;
                        }
                        else
                        {
                            dst[i] = dotProduct;
                        }
                    }
                }
                else
                {
                    Contracts.Assert(crun <= src.Size);
                    for (int i = 0; i < dst.Size; i++)
                    {
                        float dotProduct = 0;
                        for (int j = 0; j < crun; j++)
                        {
                            dotProduct += mat[j * src.Size + i] * src[j];
                        }

                        if (add)
                        {
                            dst[i] += dotProduct;
                        }
                        else
                        {
                            dst[i] = dotProduct;
                        }
                    }
                }
            }
        }