Exemplo n.º 1
0
 /// <summary>
 /// Verifies that AXPY on whole device vectors computes y = alpha * x + y element-wise.
 /// </summary>
 public void TestAXPYInVectorWhole()
 {
     const float alpha = 10.0f;

     CreateRandomData(_hostInput1);
     CreateRandomData(_hostInput2);
     _gpu.CopyToDevice(_hostInput2, _devPtr2);
     _gpu.CopyToDevice(_hostInput1, _devPtr1);

     // _devPtr2 <- alpha * _devPtr1 + _devPtr2
     _blas.AXPY(alpha, _devPtr1, _devPtr2);
     _gpu.CopyFromDevice(_devPtr2, _hostOutput1);

     for (int i = 0; i < ciN; i++)
     {
         float expected = alpha * _hostInput1[i] + _hostInput2[i];

         // The device may round differently from the host (e.g. a fused multiply-add),
         // so compare with a tolerance proportional to the operand magnitudes rather
         // than requiring bit-exact float equality.
         float scale = System.Math.Abs(alpha * _hostInput1[i]) + System.Math.Abs(_hostInput2[i]);
         float delta = 1e-5f * scale;

         Assert.AreEqual(expected, _hostOutput1[i], delta);
     }
 }
Exemplo n.º 2
0
        /// <summary>
        /// Solves the symmetric (positive-definite) linear system A * x = b with the
        /// unpreconditioned conjugate gradient method, where A is stored in CSR format.
        /// </summary>
        /// <remarks>
        /// <paramref name="dx"/> is used as the initial guess and receives the solution.
        /// <paramref name="db"/> is overwritten: on return it holds the final residual b - A * x.
        /// <paramref name="dp"/> and <paramref name="dAx"/> are scratch vectors; their contents are clobbered.
        /// Convergence is declared when the squared residual norm (r . r) is at most
        /// <paramref name="tolerence"/> squared, i.e. when ||r||_2 &lt;= tolerence.
        /// </remarks>
        /// <param name="n">number of rows and columns of matrix A.</param>
        /// <param name="nnz">number of non-zero elements of A; equals csrRowA[n] - csrRowA[0].</param>
        /// <param name="csrValA">array of nnz non-zero values of A.</param>
        /// <param name="csrRowA">array of n+1 row offsets.</param>
        /// <param name="csrColA">array of nnz column indices.</param>
        /// <param name="dx">vector of n elements: initial guess on input, solution on output.</param>
        /// <param name="db">vector of n elements: right-hand side b on input, final residual on output.</param>
        /// <param name="dp">vector of n elements (temporary: search direction).</param>
        /// <param name="dAx">vector of n elements (temporary: product of A with a vector).</param>
        /// <param name="tolerence">absolute tolerance on the residual 2-norm (parameter name kept as-is for compatibility).</param>
        /// <param name="maxIterate">maximum number of CG iterations; at least one iteration runs unless the initial guess already meets the tolerance.</param>
        /// <returns>
        /// A <c>SolveResult</c>. If b is zero, dx is set to zero and IsSuccess is true (other fields are left at their defaults).
        /// Otherwise IsSuccess is true when the tolerance is met, and false when maxIterate iterations are exhausted or
        /// the method breaks down (p . Ap is zero or not finite, e.g. A is singular or indefinite).
        /// LastError is the final squared residual norm and IterateCount is the number of CG iterations performed.
        /// </returns>
        public SolveResult CG(
            int n, int nnz, float[] csrValA, int[] csrRowA, int[] csrColA,
            float[] dx, float[] db, float[] dp, float[] dAx, float tolerence = 0.00001f, int maxIterate = 300)
        {
            SolveResult result = new SolveResult();

            // CSRMV takes its scalar coefficients by reference, so they must live in locals.
            float zero = 0.0f;
            float one  = 1.0f;

            float toleranceSquared = tolerence * tolerence;

            // Zero right-hand side: x = 0 is the exact solution.
            if (blas.DOT(db, db) == 0)
            {
                SetValue(n, dx, 0);
                result.IsSuccess = true;

                return(result);
            }

            // Initial residual r = b - A * x0, computed in place in db.
            sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dx, ref zero, dAx);
            blas.AXPY(-1.0f, dAx, db);

            float r1 = blas.DOT(db, db); // current r . r
            float r0 = 0.0f;             // r . r from the previous iteration

            // Stop before iterating if the initial guess is already good enough. This also
            // protects an exact initial guess (r = 0) from the 0/0 step size below, which
            // would otherwise overwrite dx with NaN.
            if (r1 <= toleranceSquared)
            {
                result.IsSuccess    = true;
                result.LastError    = r1;
                result.IterateCount = 0;

                return(result);
            }

            int iterations = 0;

            while (true)
            {
                // Search direction: p = r on the first iteration, p = r + (r1 / r0) * p afterwards.
                if (iterations == 0)
                {
                    blas.COPY(db, dp);
                }
                else
                {
                    float beta = r1 / r0;
                    blas.SCAL(beta, dp);
                    blas.AXPY(1.0f, db, dp);
                }

                // dAx = A * p
                sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dp, ref zero, dAx);

                float pAp = blas.DOT(dp, dAx);

                // Breakdown: a zero or non-finite curvature would make the step size Inf/NaN
                // and destroy dx. Report failure, leaving dx/db at the last good iterate.
                if (pAp == 0.0f || float.IsNaN(pAp) || float.IsInfinity(pAp))
                {
                    result.IsSuccess    = false;
                    result.LastError    = r1;
                    result.IterateCount = iterations;
                    break;
                }

                float alpha = r1 / pAp;
                blas.AXPY(alpha, dp, dx);   // x <- x + alpha * p
                blas.AXPY(-alpha, dAx, db); // r <- r - alpha * A p

                r0 = r1;
                r1 = blas.DOT(db, db);

                iterations++;

                if (r1 <= toleranceSquared)
                {
                    result.IsSuccess    = true;
                    result.LastError    = r1;
                    result.IterateCount = iterations;
                    break;
                }

                if (iterations >= maxIterate)
                {
                    result.IsSuccess    = false;
                    result.LastError    = r1;
                    result.IterateCount = iterations;
                    break;
                }
            }

            return(result);
        }