/// <summary>
/// Fills both host input vectors with random data, uploads them to the device,
/// computes their dot product with <c>_blas.DOT</c> on the device and compares it
/// against a host-side reference computed with LINQ.
/// </summary>
public void TestDOTInVectorWhole()
{
    // Absolute difference tolerated between the host and device results.
    const int AllowedDelta = 16;

    // Arrange
    CreateRandomData(_hostInput1);
    CreateRandomData(_hostInput2);
    _gpu.CopyToDevice(_hostInput2, _devPtr2);
    _gpu.CopyToDevice(_hostInput1, _devPtr1);

    // Act
    float deviceDot = _blas.DOT(_devPtr2, _devPtr1);

    // Assert
    float expectedDot = _hostInput1
        .Zip(_hostInput2, (left, right) => left * right)
        .Sum();
    Assert.AreEqual(expectedDot, deviceDot, AllowedDelta);
}
/// <summary>
/// Solves the symmetric linear system A * x = b with the (unpreconditioned)
/// conjugate gradient method, where A is an n x n matrix stored in CSR format.
/// </summary>
/// <param name="n">number of rows and columns of matrix A.</param>
/// <param name="nnz">number of non-zero elements of A (csrRowA[n] - csrRowA[0]).</param>
/// <param name="csrValA">array of nnz non-zero values of A.</param>
/// <param name="csrRowA">array of n+1 row offsets.</param>
/// <param name="csrColA">array of nnz column indices.</param>
/// <param name="dx">vector of n elements: the initial guess on input, the solution on output.</param>
/// <param name="db">vector of n elements: the right-hand side b on input. It is reused as the
/// residual buffer, so its contents are overwritten by the solver.</param>
/// <param name="dp">vector of n elements (temporary: search direction).</param>
/// <param name="dAx">vector of n elements (temporary: matrix-vector product).</param>
/// <param name="tolerence">convergence tolerance on the residual 2-norm; iteration stops once
/// the squared residual norm is at most tolerence * tolerence.</param>
/// <param name="maxIterate">maximum number of conjugate gradient iterations.</param>
/// <returns>
/// A <see cref="SolveResult"/>. <c>IsSuccess</c> is false when the solver did not converge within
/// <paramref name="maxIterate"/> iterations, or when it broke down because p^T * A * p was zero
/// or not finite (e.g. A is singular). <c>LastError</c> is the final squared residual norm and
/// <c>IterateCount</c> is the number of iterations actually performed.
/// </returns>
public SolveResult CG(
    int n, int nnz, float[] csrValA, int[] csrRowA, int[] csrColA,
    float[] dx, float[] db, float[] dp, float[] dAx,
    float tolerence = 0.00001f, int maxIterate = 300)
{
    SolveResult result = new SolveResult();
    // CSRMV takes its scalar coefficients by ref, so they must be locals.
    float zero = 0.0f;
    float one = 1.0f;
    float toleranceSquared = tolerence * tolerence;

    // b == 0 has the trivial solution x == 0.
    if (blas.DOT(db, db) == 0)
    {
        SetValue(n, dx, 0);
        result.IsSuccess = true;
        return result;
    }

    // Initial residual r = b - A * x, stored in db.
    sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dx, ref zero, dAx);
    blas.AXPY(-1.0f, dAx, db);

    float r1 = blas.DOT(db, db); // squared norm of the current residual
    float r0 = 0.0f;             // squared norm of the previous residual
    int iterations = 0;

    // Checking the residual before the first step avoids a 0/0 step size
    // when the initial guess is already a solution.
    while (r1 > toleranceSquared && iterations < maxIterate)
    {
        if (iterations > 0)
        {
            // p = r + (r1 / r0) * p
            blas.SCAL(r1 / r0, dp);
            blas.AXPY(1.0f, db, dp);
        }
        else
        {
            blas.COPY(db, dp);
        }

        sparse.CSRMV(n, n, nnz, ref one, csrValA, csrRowA, csrColA, dp, ref zero, dAx);
        float pAp = blas.DOT(dp, dAx);

        // Breakdown (e.g. singular A): continuing would write NaN/Inf into dx.
        if (pAp == 0.0f || float.IsNaN(pAp) || float.IsInfinity(pAp))
        {
            result.IsSuccess = false;
            result.LastError = r1;
            result.IterateCount = iterations;
            return result;
        }

        float alpha = r1 / pAp;
        blas.AXPY(alpha, dp, dx);   // x += alpha * p
        blas.AXPY(-alpha, dAx, db); // r -= alpha * A * p
        r0 = r1;
        r1 = blas.DOT(db, db);
        iterations++;
    }

    result.IsSuccess = r1 <= toleranceSquared;
    result.LastError = r1;
    result.IterateCount = iterations;
    return result;
}