20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
23 #include "codingUtilities/Utilities.hpp"
24 #include "codingUtilities/traits.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
30 #include <_hypre_parcsr_mv.h>
/**
 * @brief Diagnostic payload identifying a row whose accumulated row value is zero.
 *
 * The member declarations are not visible in this listing. The stream operator reads
 * `msg.row`, and instances are brace-initialized from a `long long` row index.
 */
42 struct ZeroRowSumMessage
/**
 * @brief Stream a ZeroRowSumMessage as "Zero row sum in row <row>".
 * @param os  destination stream
 * @param msg message carrying the offending row index
 * @return @p os, so that output can be chained
 */
47 inline std::ostream &
operator<<( std::ostream & os, ZeroRowSumMessage
const & msg )
49 return os <<
"Zero row sum in row " << msg.row;
// NOTE(review): this listing omits the enclosing function's signature and the matching
// #endif. These lines appear to be the body of internal::zeroRowSumMessage, which is used
// in the GEOS_ASSERT_MSG of rescaleMatrixRows — confirm against the complete header.
// In CUDA/HIP device compilation passes, the offending row is also printed with printf,
// presumably because the stream-based assertion message cannot be emitted from device code.
54 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
55 printf(
"***** Zero row sum in row %lld\n",
static_cast< long long >( row ) );
// Return the row index widened to long long (presumably as a ZeroRowSumMessage).
57 return {
static_cast< long long >( row ) };
/**
 * @brief Non-owning view of the raw arrays and sizes of a hypre_CSRMatrix.
 *
 * The constructor only copies pointers and sizes out of the matrix; no matrix data is
 * copied or owned. This is why the kernels in this file can capture instances by value
 * in their lambdas.
 *
 * NOTE(review): the `nrow`, `ncol` and `nnz` member declarations are not visible in this listing.
 *
 * @tparam CONST if true, the wrapped matrix is const and `values` points to const data
 */
62 template<
bool CONST >
// Row offsets (hypre_CSRMatrixI): the entries of row r are stored at [rowptr[r], rowptr[r+1]).
65 HYPRE_Int
const * rowptr;
// Column index of each stored entry (hypre_CSRMatrixJ).
66 HYPRE_Int
const * colind;
// Value of each stored entry (hypre_CSRMatrixData); writable only when CONST is false.
67 add_const_if_t< HYPRE_Real, CONST > * values;
/// @brief Capture pointers and sizes from @p mat (no data is copied).
72 explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > *
const mat )
73 : rowptr( hypre_CSRMatrixI( mat ) ),
74 colind( hypre_CSRMatrixJ( mat ) ),
75 values( hypre_CSRMatrixData( mat ) ),
76 nrow( hypre_CSRMatrixNumRows( mat ) ),
77 ncol( hypre_CSRMatrixNumCols( mat ) ),
78 nnz( hypre_CSRMatrixNumNonzeros( mat ) )
/**
 * @brief Return the off-diagonal column map of @p mat (defined elsewhere).
 *
 * In addMatrixEntries the returned array is indexed by local off-diagonal column index.
 * The values it yields are used as column keys to compare columns of different matrices
 * (presumably global column ids — confirm).
 * @param mat ParCSR matrix
 * @return pointer to the column map array
 */
82 HYPRE_BigInt
const * getOffdColumnMap( hypre_ParCSRMatrix
const *
const mat );
/**
 * @brief Scale the values of @p mat (defined elsewhere).
 * NOTE(review): the remaining parameter(s) are not visible in this listing; presumably a
 * scaling factor — confirm.
 * @param mat CSR matrix modified in place
 */
84 void scaleMatrixValues( hypre_CSRMatrix *
const mat,
/**
 * @brief Scale the rows of @p mat using the entries of @p vec (defined elsewhere).
 * NOTE(review): presumably row i is multiplied by vec[i] — confirm against the definition.
 * @param mat CSR matrix modified in place
 * @param vec vector of row scaling factors
 */
87 void scaleMatrixRows( hypre_CSRMatrix *
const mat,
88 hypre_Vector
const *
const vec );
/**
 * @brief Clamp the entries of @p mat (defined elsewhere).
 * NOTE(review): the middle parameters are not visible in this listing; presumably the lower
 * and upper bounds — confirm.
 * @param mat       CSR matrix modified in place
 * @param skip_diag presumably excludes diagonal entries from clamping when true — confirm
 */
90 void clampMatrixEntries( hypre_CSRMatrix *
const mat,
93 bool const skip_diag );
/**
 * @brief Compute a max norm of @p mat (defined elsewhere).
 * NOTE(review): the exact norm (e.g. largest absolute entry vs. largest row sum) is not
 * visible here — confirm against the definition.
 * @param mat CSR matrix (not modified)
 * @return the norm value
 */
95 real64 computeMaxNorm( hypre_CSRMatrix
const *
const mat );
/**
 * @brief Overload of computeMaxNorm restricted to the rows listed in @p rowIndices (defined elsewhere).
 * NOTE(review): the remaining parameter(s) are not visible in this listing; presumably they
 * relate global row indices to local rows — confirm.
 * @param mat        CSR matrix (not modified)
 * @param rowIndices global indices of the rows to consider
 */
97 real64 computeMaxNorm( hypre_CSRMatrix
const *
const mat,
98 arrayView1d< globalIndex const >
const & rowIndices,
/**
 * @brief Rescale selected locally-owned rows of a ParCSR matrix in place.
 *
 * For every global row index in @p rowIndices:
 * - a per-row value `scale` starts at 0.0;
 * - it is updated as `scale = reduce( scale, transform( a ) )` over every stored entry `a`
 *   of the row, first in the diagonal block and then in the off-diagonal block;
 * - every stored entry of the row, in both blocks, is then multiplied in place.
 *
 * Rows are processed independently in a forAll< execPolicy > kernel.
 *
 * NOTE(review): this listing omits the `transform`/`reduce` parameters (they are captured
 * by the lambda), the braces, and the line between the zero-check and the scaling loops.
 * That line presumably replaces `scale` with its reciprocal — confirm against the full header.
 *
 * @tparam F callable applied to each matrix value, as `transform( value )`
 * @tparam R callable combining the running value with a transformed entry, as `reduce( acc, x )`
 * @param mat        matrix whose rows are modified
 * @param rowIndices global indices of the rows to rescale; each must correspond to a local
 *                   row of the diagonal block (checked by GEOS_ASSERT in debug builds only)
 */
101 template<
typename F,
typename R >
102 void rescaleMatrixRows( hypre_ParCSRMatrix *
const mat,
103 arrayView1d< globalIndex const >
const & rowIndices,
107 CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };
108 CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };
109 HYPRE_BigInt
const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );
111 forAll< execPolicy >( rowIndices.size(), [diag, offd, transform, reduce, rowIndices, firstLocalRow]
GEOS_HOST_DEVICE (
localIndex const i )
// Convert the global row index to a local row, relative to the matrix's first local row.
113 HYPRE_Int
const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );
114 GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );
// Accumulate over the diagonal block, then over the off-diagonal block.
116 HYPRE_Real scale = 0.0;
117 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
119 scale = reduce( scale, transform( diag.values[k] ) );
123 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
125 scale = reduce( scale, transform( offd.values[k] ) );
// A zero accumulated value is reported with the global row index (debug builds only).
129 GEOS_ASSERT_MSG( !isZero( scale ), internal::zeroRowSumMessage( rowIndices[i] ) );
// Multiply every stored entry of the row, in both blocks, by the scaling factor.
131 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
133 diag.values[k] *= scale;
137 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
139 offd.values[k] *= scale;
/**
 * @brief Compute one reduced value per local row of a ParCSR matrix.
 *
 * For each local row, `sum` starts at 0.0 and is updated as
 * `sum = reduce( sum, transform( a ) )` over every stored entry `a` of the row, first in
 * the diagonal block and then in the off-diagonal block. The result is written to the
 * corresponding entry of the local part of @p vec. @p mat is not modified.
 *
 * NOTE(review): the `transform`/`reduce` parameters are not visible in this listing (they
 * are captured by the lambda). The code assumes the local part of @p vec holds at least
 * diag.nrow entries — confirm.
 *
 * @tparam F callable applied to each matrix value
 * @tparam R binary reduction callable
 * @param mat source matrix
 * @param vec output vector; entries [0, diag.nrow) of its local vector are overwritten
 */
145 template<
typename F,
typename R >
146 void computeRowsSums( hypre_ParCSRMatrix
const *
const mat,
147 hypre_ParVector *
const vec,
151 CSRData< true >
const diag{ hypre_ParCSRMatrixDiag( mat ) };
152 CSRData< true >
const offd{ hypre_ParCSRMatrixOffd( mat ) };
// Raw pointer to the local portion of the output vector.
153 HYPRE_Real *
const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );
155 forAll< execPolicy >( diag.nrow, [diag, offd, transform, reduce, values]
GEOS_HOST_DEVICE ( HYPRE_Int
const localRow )
// Reduce over the diagonal block, then over the off-diagonal block.
157 HYPRE_Real sum = 0.0;
158 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
160 sum = reduce( sum, transform( diag.values[k] ) );
164 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
166 sum = reduce( sum, transform( offd.values[k] ) );
169 values[localRow] = sum;
/**
 * @brief Build a permutation that orders a row's column indices by mapped column id.
 *
 * @p perm is seeded over positions [0, size) and then sorted with a comparator on
 * map( indices[...] ). @p indices itself is not modified.
 *
 * NOTE(review): the following are missing from this listing:
 * - the return type;
 * - the `map` parameter declaration (call sites pass the column-map functor as the 4th argument);
 * - the body of the seeding loop;
 * - the comparator's declaration.
 * Confirm against the full header.
 *
 * @tparam MAP callable mapping a column index to a comparable key
 * @param indices column indices of one row
 * @param size    number of entries in the row
 * @param perm    output buffer of at least @p size entries
 */
176 template<
typename MAP >
178 makeSortedPermutation( HYPRE_Int
const *
const indices,
179 HYPRE_Int
const size,
180 HYPRE_Int *
const perm,
// Seed the permutation (loop body not shown; presumably perm[i] = i).
183 for( HYPRE_Int i = 0; i < size; ++i )
// Comparator: order two positions by the mapped ids of their column indices.
189 return map( indices[i] ) <
map( indices[j] );
// Sort the permutation in place; the index array is left untouched.
191 LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
/**
 * @brief Add the entries of @p src into @p dst using the given KERNEL policy.
 *
 * KERNEL::launch is always invoked on the diagonal blocks of the two matrices. It is
 * invoked on the off-diagonal blocks only when the off-diagonal block of @p dst has at
 * least one column. Consequently, off-diagonal entries of @p src are not added when
 * @p dst has no off-diagonal columns.
 *
 * For the off-diagonal blocks, each matrix's column map (getOffdColumnMap) is wrapped in a
 * functor, so the kernel compares columns by mapped id rather than by local off-diagonal
 * column index (presumably because local indices are not comparable across matrices).
 *
 * NOTE(review): not visible in this listing:
 * - the remaining parameter(s), presumably a scale factor forwarded to the kernel;
 * - the functors passed for the diagonal blocks.
 *
 * @tparam KERNEL type providing `launch( src_block, src_map, dst_block, dst_map, ... )`,
 *                e.g. AddEntriesRestrictedKernel or AddEntriesSamePatternKernel
 * @param src matrix whose entries are added (not modified)
 * @param dst matrix receiving the entries
 */
196 template<
typename KERNEL >
197 void addMatrixEntries( hypre_ParCSRMatrix
const *
const src,
198 hypre_ParCSRMatrix *
const dst,
203 KERNEL::launch( hypre_ParCSRMatrixDiag( src ),
205 hypre_ParCSRMatrixDiag( dst ),
// Off-diagonal blocks are processed only when dst has off-diagonal columns.
208 if( hypre_CSRMatrixNumCols( hypre_ParCSRMatrixOffd( dst ) ) > 0 )
210 HYPRE_BigInt
const *
const src_colmap = hypre::getOffdColumnMap( src );
211 HYPRE_BigInt
const *
const dst_colmap = hypre::getOffdColumnMap( dst );
// Each functor maps a local off-diagonal column index to its column-map entry.
212 KERNEL::launch( hypre_ParCSRMatrixOffd( src ),
213 [src_colmap]
GEOS_HOST_DEVICE ( HYPRE_Int
const i ){
return src_colmap[i]; },
214 hypre_ParCSRMatrixOffd( dst ),
215 [dst_colmap]
GEOS_HOST_DEVICE ( HYPRE_Int
const i ){
return dst_colmap[i]; },
/**
 * @brief addMatrixEntries kernel computing dst += scale * src, restricted to dst's existing pattern.
 *
 * For each row:
 * - the column indices of src and dst are ordered by mapped id through temporary
 *   permutation arrays, leaving the matrices' own storage order unchanged;
 * - the two ordered rows are merged in one pass. A src entry is added only when a dst
 *   entry in the same row has the same mapped column id; src entries without such a
 *   match are skipped.
 *
 * Nothing is done when src has no columns or scale is zero. This is an early exit; the
 * return statement itself is not visible in this listing.
 *
 * NOTE(review): the return type and the `scale` parameter declaration are not visible here.
 * Rows are iterated over dst.nrow and the same local row index is used for src, so src is
 * assumed to have at least as many local rows as dst — confirm.
 */
220 struct AddEntriesRestrictedKernel
222 template<
typename SRC_COLMAP,
typename DST_COLMAP >
224 launch( hypre_CSRMatrix
const *
const src_mat,
225 SRC_COLMAP
const src_colmap,
226 hypre_CSRMatrix *
const dst_mat,
227 DST_COLMAP
const dst_colmap,
233 CSRData< true > src{ src_mat };
234 CSRData< false > dst{ dst_mat };
237 if( src.ncol == 0 || isZero( scale ) )
// Scratch permutations sized to the nonzero counts; row r uses the slice starting at its rowptr offset.
243 array1d< HYPRE_Int >
const src_permutation_arr( src.nnz );
244 array1d< HYPRE_Int >
const dst_permutation_arr( dst.nnz );
246 arrayView1d< HYPRE_Int >
const src_permutation = src_permutation_arr.toView();
247 arrayView1d< HYPRE_Int >
const dst_permutation = dst_permutation_arr.toView();
// One independent task per dst row (the lambda capture list is not visible here).
249 forAll< hypre::execPolicy >( dst.nrow,
// Views of this row's entries in src and dst, plus their permutation slices.
258 HYPRE_Int
const src_offset = src.rowptr[localRow];
259 HYPRE_Int
const src_length = src.rowptr[localRow + 1] - src_offset;
260 HYPRE_Int
const *
const src_indices = src.colind + src_offset;
261 HYPRE_Real
const *
const src_values = src.values + src_offset;
262 HYPRE_Int *
const src_perm = src_permutation.data() + src_offset;
264 HYPRE_Int
const dst_offset = dst.rowptr[localRow];
265 HYPRE_Int
const dst_length = dst.rowptr[localRow + 1] - dst_offset;
266 HYPRE_Int
const *
const dst_indices = dst.colind + dst_offset;
267 HYPRE_Real *
const dst_values = dst.values + dst_offset;
268 HYPRE_Int *
const dst_perm = dst_permutation.data() + dst_offset;
// Order each row's entries by mapped column id.
272 internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );
273 internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );
// Two-pointer merge over the sorted orders: skip smaller src columns, add on equal ids.
276 for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )
278 while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )
282 if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )
284 dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
/**
 * @brief addMatrixEntries kernel computing dst += scale * src for matrices with the same pattern.
 *
 * Assumes that, in every row, src and dst store the same columns (by mapped id) in the
 * same order. Entries are therefore combined position by position, without any sorting.
 * The column match is checked only by GEOS_ASSERT_EQ, i.e. in debug builds.
 *
 * Nothing is done when src has no columns or scale is zero. This is an early exit; the
 * return statement itself is not visible in this listing.
 *
 * NOTE(review): the return type and the `scale` parameter declaration are not visible here.
 * The inner loop is bounded by dst's row length only, and no release-mode check on src's
 * row length is visible — confirm that callers guarantee identical patterns.
 */
291 struct AddEntriesSamePatternKernel
293 template<
typename SRC_COLMAP,
typename DST_COLMAP >
295 launch( hypre_CSRMatrix
const *
const src_mat,
296 SRC_COLMAP
const src_colmap,
297 hypre_CSRMatrix *
const dst_mat,
298 DST_COLMAP
const dst_colmap,
304 CSRData< true > src{ src_mat };
305 CSRData< false > dst{ dst_mat };
308 if( src.ncol == 0 || isZero( scale ) )
// One independent task per dst row.
314 forAll< hypre::execPolicy >( dst.nrow,
315 [src, src_colmap, dst, dst_colmap, scale]
GEOS_HOST_DEVICE ( HYPRE_Int
const localRow )
// Views of this row's entries in src and dst.
317 HYPRE_Int
const src_offset = src.rowptr[localRow];
318 HYPRE_Int
const src_length = src.rowptr[localRow + 1] - src_offset;
319 HYPRE_Int
const *
const src_indices = src.colind + src_offset;
320 HYPRE_Real
const *
const src_values = src.values + src_offset;
322 HYPRE_Int
const dst_offset = dst.rowptr[localRow];
323 HYPRE_Int
const dst_length = dst.rowptr[localRow + 1] - dst_offset;
324 HYPRE_Int
const *
const dst_indices = dst.colind + dst_offset;
325 HYPRE_Real *
const dst_values = dst.values + dst_offset;
// These variables are only used in debug assertions; silence unused-variable warnings otherwise.
329 GEOS_DEBUG_VAR( src_length, dst_length, src_indices, dst_indices, src_colmap, dst_colmap );
// Position-wise accumulation; mapped column ids are compared only in debug builds.
332 for( HYPRE_Int i = 0; i < dst_length; ++i )
334 GEOS_ASSERT_EQ( src_colmap( src_indices[i] ), dst_colmap( dst_indices[i] ) );
335 dst_values[i] += scale * src_values[i];
#define GEOS_HOST_DEVICE
Marks a host-device function.
#define GEOS_DEBUG_VAR(...)
Mark a debug variable and silence compiler warnings.
#define GEOS_ASSERT_MSG(COND,...)
Abort execution if COND is false, but only when NDEBUG is not defined.
#define GEOS_ASSERT(COND)
Assert a condition in debug builds.
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
#define GEOS_LAI_ASSERT(expr)
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
double real64
64-bit floating point type.
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
std::ostream & operator<<(std::ostream &stream, mapType< K, V, SORTED > const &map)
Stream output operator for map types.
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.