public void TestSCALInVectorWhole()
{
    // Arrange: fill the host buffer with random values and upload it to the device.
    CreateRandomData(_hostInput1);
    _gpu.CopyToDevice(_hostInput1, _devPtr1);

    // Act: scale the whole device vector by 10, then download the result.
    _blas.SCAL(10.0f, _devPtr1);
    _gpu.CopyFromDevice(_devPtr1, _hostOutput1);

    // Assert: every downloaded element equals the matching input element times 10.
    int position = 0;
    foreach (float source in _hostInput1)
    {
        float expected = source * 10.0f;
        Assert.AreEqual(expected, _hostOutput1[position]);
        position++;
    }
}
/// <summary>
/// Solves a symmetric linear system with the conjugate gradient method.
/// A * x = b, where A is given in CSR (compressed sparse row) form.
/// </summary>
/// <remarks>
/// The comments in this method assume that <c>blas</c> and <c>sparse</c> follow the usual
/// BLAS / cuSPARSE conventions (AXPY: y = alpha * x + y, SCAL: x = alpha * x,
/// COPY: y = x, CSRMV: y = alpha * A * x + beta * y) - confirm against the wrapper classes.
/// <para>
/// <paramref name="db"/> is modified in place: it is used as the residual vector.
/// </para>
/// <para>
/// The reported iterate count starts at 1 and is incremented after every pass, so when the
/// loop stops it is one more than the number of completed passes. When the initial guess
/// already satisfies the tolerance, no pass is run and the count is reported as 0.
/// </para>
/// </remarks>
/// <param name="n">number of rows and columns of matrix A.</param>
/// <param name="nnz">number of non-zero elements of matrix A; can be obtained from csrRowA[n] - csrRowA[0].</param>
/// <param name="csrValA">array of nnz non-zero values of A.</param>
/// <param name="csrRowA">array of n+1 index elements.</param>
/// <param name="csrColA">array of nnz column indices.</param>
/// <param name="dx">vector of n elements. Initial guess on entry, updated towards the solution.</param>
/// <param name="db">vector of n elements. Right-hand side on entry; overwritten with the residual.</param>
/// <param name="dp">vector of n elements. (temporary vector, search direction)</param>
/// <param name="dAx">vector of n elements. (temporary vector, matrix-vector product)</param>
/// <param name="tolerence">iteration tolerance; the solver stops once the squared residual norm is at most tolerence * tolerence.</param>
/// <param name="maxIterate">max iterate count of conjugate gradient solver.</param>
/// <returns>
/// A <c>SolveResult</c> whose <c>IsSuccess</c> is true when the tolerance was reached (or b is zero),
/// and false when the iteration broke down (zero or non-finite curvature, e.g. a singular A)
/// or the max iterate count was exceeded.
/// </returns>
public SolveResult CG(
    int n, int nnz, float[] csrValA, int[] csrRowA, int[] csrColA,
    float[] dx, float[] db, float[] dp, float[] dAx,
    float tolerence = 0.00001f, int maxIterate = 300)
{
    SolveResult result = new SolveResult();
    int k; // Iterate count.
    float a, b, r0, r1;
    float zero = 0.0f;
    float one = 1.0f;
    float threshold = tolerence * tolerence; // compared against the squared residual norm

    // Trivial case: b == 0, so x = 0 is the solution.
    if (blas.DOT(db, db) == 0)
    {
        SetValue(n, dx, 0);
        result.IsSuccess = true;
        return (result);
    }

    // Initial residual: db = b - A * x0.
    sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dx, ref zero, dAx);
    blas.AXPY(-1.0f, dAx, db);
    r1 = blas.DOT(db, db);

    // The initial guess may already be good enough. Without this check a zero residual
    // leads to a 0 / 0 step length and fills dx with NaN.
    if (r1 <= threshold)
    {
        result.IsSuccess = true;
        result.LastError = r1;
        result.IterateCount = 0;
        return (result);
    }

    k = 1;
    r0 = 0;

    while (true)
    {
        if (k > 1)
        {
            // New search direction: dp = residual + (r1 / r0) * dp.
            b = r1 / r0;
            blas.SCAL(b, dp);
            blas.AXPY(1.0f, db, dp);
        }
        else
        {
            // First pass: the search direction is the residual itself.
            blas.COPY(db, dp);
        }

        // dAx = A * dp.
        sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dp, ref zero, dAx);

        // Curvature along the search direction. A zero or non-finite value means the
        // step length cannot be computed (for example A is singular); stop with a
        // failure instead of propagating Inf/NaN into dx and db.
        float curvature = blas.DOT(dp, dAx);
        if (curvature == 0.0f || float.IsNaN(curvature) || float.IsInfinity(curvature))
        {
            result.IsSuccess = false;
            result.LastError = r1;
            result.IterateCount = k;
            break;
        }

        a = r1 / curvature;
        blas.AXPY(a, dp, dx);   // x = x + a * dp
        blas.AXPY(-a, dAx, db); // residual = residual - a * A * dp

        r0 = r1;
        r1 = blas.DOT(db, db);
        k++;

        if (r1 <= threshold)
        {
            result.IsSuccess = true;
            result.LastError = r1;
            result.IterateCount = k;
            break;
        }

        if (k > maxIterate)
        {
            result.IsSuccess = false;
            result.LastError = r1;
            result.IterateCount = k;
            break;
        }
    }

    return (result);
}