Beispiel #1
0
        /// <summary>
        /// Checks that the BLAS DOT routine, run over the complete device vectors,
        /// agrees with a dot product computed on the host from the same data.
        /// </summary>
        public void TestDOTInVectorWhole()
        {
            // Fill both host buffers with fresh random values.
            CreateRandomData(_hostInput1);
            CreateRandomData(_hostInput2);

            // Upload the host data to the matching device buffers.
            _gpu.CopyToDevice(_hostInput2, _devPtr2);
            _gpu.CopyToDevice(_hostInput1, _devPtr1);

            // Value under test: DOT over the whole vectors (no offsets or strides supplied).
            float actual = _blas.DOT(_devPtr2, _devPtr1);

            // Reference value: sum of the element-wise products of the host arrays.
            var products = _hostInput1.Zip(_hostInput2, (x, y) => x * y);
            float expected = products.Sum();

            // The two sums are accumulated separately, so compare with the
            // same absolute delta the test has always used.
            Assert.AreEqual(expected, actual, 16);
        }
Beispiel #2
0
        /// <summary>
        /// Solves the symmetric linear system A * x = b with the conjugate gradient method.
        /// The matrix A is passed in CSR form (values, row index array, column index array).
        /// </summary>
        /// <remarks>
        /// <para>
        /// <paramref name="dx"/> is read as the initial guess and updated in place with the solution.
        /// <paramref name="db"/> is overwritten: after the first step it holds the current residual
        /// (b - A * x), not the original right-hand side. <paramref name="dp"/> and
        /// <paramref name="dAx"/> are scratch vectors whose previous contents are lost.
        /// </para>
        /// <para>
        /// Iteration stops successfully once the squared residual norm is
        /// &lt;= <paramref name="tolerence"/> squared. If the initial guess already satisfies that
        /// criterion, no iteration is run and <paramref name="dx"/> is left untouched.
        /// </para>
        /// </remarks>
        /// <param name="n">number of rows and columns of matrix A.</param>
        /// <param name="nnz">number of non-zero elements of matrix A (can be obtained from csrRowA[n] - csrRowA[0]).</param>
        /// <param name="csrValA">array of nnz non-zero values of matrix A.</param>
        /// <param name="csrRowA">array of n+1 index elements.</param>
        /// <param name="csrColA">array of nnz column indices.</param>
        /// <param name="dx">vector of n elements. Initial guess on entry, solution on exit.</param>
        /// <param name="db">vector of n elements. Right-hand side on entry; overwritten with the residual.</param>
        /// <param name="dp">vector of n elements. (temporary vector)</param>
        /// <param name="dAx">vector of n elements. (temporary vector)</param>
        /// <param name="tolerence">tolerance on the residual norm used as the stopping criterion.</param>
        /// <param name="maxIterate">maximum number of conjugate gradient iterations.</param>
        /// <returns>
        /// A <see cref="SolveResult"/>. <c>IsSuccess</c> is true when the right-hand side is zero or
        /// the residual dropped below the tolerance; it is false when the iteration limit was reached
        /// or the method broke down (zero or NaN curvature term, e.g. a singular matrix).
        /// <c>LastError</c> is the squared residual norm at exit. <c>IterateCount</c> is the loop
        /// counter at exit, which starts at 1 and is therefore one more than the number of completed
        /// iterations; both fields are left at their initial values when no residual had to be
        /// computed (zero right-hand side), and <c>IterateCount</c> is also left unset when the
        /// initial guess is already converged.
        /// </returns>
        public SolveResult CG(
            int n, int nnz, float[] csrValA, int[] csrRowA, int[] csrColA,
            float[] dx, float[] db, float[] dp, float[] dAx, float tolerence = 0.00001f, int maxIterate = 300)
        {
            SolveResult result = new SolveResult();
            int         k; // Loop counter: starts at 1, incremented after every iteration.
            float       a, b, r0, r1, pAp;
            float       zero = 0.0f;
            float       one  = 1.0f;

            // b == 0  =>  x = 0 is the solution; nothing to iterate on.
            if (blas.DOT(db, db) == 0)
            {
                SetValue(n, dx, 0);
                result.IsSuccess = true;

                return(result);
            }

            // Initial residual: dAx = A * x0, then db = db - dAx
            // (assuming the usual BLAS meaning of AXPY: y += alpha * x).
            sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dx, ref zero, dAx);
            blas.AXPY(-1.0f, dAx, db);

            r1 = blas.DOT(db, db);

            // The initial guess may already be good enough. Iterating anyway is harmful when the
            // residual is exactly zero: the search direction is then zero, the step length becomes
            // 0 / 0 = NaN and the (correct) dx would be overwritten with NaN.
            if (r1 <= tolerence * tolerence)
            {
                result.IsSuccess = true;
                result.LastError = r1;

                return(result);
            }

            k  = 1;
            r0 = 0;

            while (true)
            {
                if (k > 1)
                {
                    // New search direction: dp = db + (r1 / r0) * dp.
                    b = r1 / r0;
                    blas.SCAL(b, dp);
                    blas.AXPY(1.0f, db, dp);
                }
                else
                {
                    // First iteration: search along the residual itself.
                    blas.COPY(db, dp);
                }

                sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dp, ref zero, dAx);
                pAp = blas.DOT(dp, dAx);

                // Breakdown: a zero (or NaN) curvature term would turn the step length into
                // Inf/NaN and poison dx. Stop here and report failure instead of running
                // the remaining iterations on garbage.
                if (pAp == 0.0f || float.IsNaN(pAp))
                {
                    result.IsSuccess    = false;
                    result.LastError    = r1;
                    result.IterateCount = k;
                    break;
                }

                // Step along dp, then update the residual accordingly.
                a = r1 / pAp;
                blas.AXPY(a, dp, dx);
                blas.AXPY(-a, dAx, db);

                r0 = r1;
                r1 = blas.DOT(db, db);

                k++;

                // r1 is a squared norm, so compare against the squared tolerance.
                if (r1 <= tolerence * tolerence)
                {
                    result.IsSuccess    = true;
                    result.LastError    = r1;
                    result.IterateCount = k;
                    break;
                }

                if (k > maxIterate)
                {
                    result.IsSuccess    = false;
                    result.LastError    = r1;
                    result.IterateCount = k;
                    break;
                }
            }

            return(result);
        }