20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
23 #include "codingUtilities/Utilities.hpp"
24 #include "codingUtilities/traits.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
30 #include <_hypre_parcsr_mv.h>
/**
 * @brief Lightweight, non-owning view of the raw CSR arrays of a hypre_CSRMatrix.
 *
 * The constructor only copies pointers and sizes out of the matrix (no allocation).
 * This is why the kernels in this header can capture instances by value in their forAll lambdas.
 * @tparam CONST selects, via add_const_if_t, whether the value array and the
 *         constructor's matrix pointer are const-qualified. CSRData< true > is used
 *         for read-only inputs, and CSRData< false > for matrices whose values are modified in place.
 *
 * NOTE(review): this listing omits some original lines of the struct (e.g. its header and
 * the declarations of nrow/ncol/nnz, original lines 45-47). Those members are initialized
 * below but not declared in the visible text.
 */
39 template<
bool CONST >
// Row pointer array (hypre_CSRMatrixI): the entries of row r occupy [rowptr[r], rowptr[r+1]).
42 HYPRE_Int
const * rowptr;
// Column index array (hypre_CSRMatrixJ), one entry per stored nonzero.
43 HYPRE_Int
const * colind;
// Stored nonzero values (hypre_CSRMatrixData); const-qualified when CONST is true.
44 add_const_if_t< HYPRE_Real, CONST > * values;
// Capture pointers and dimensions (rows, columns, stored nonzeros) from @p mat.
49 explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > *
const mat )
50 : rowptr( hypre_CSRMatrixI( mat ) ),
51 colind( hypre_CSRMatrixJ( mat ) ),
52 values( hypre_CSRMatrixData( mat ) ),
53 nrow( hypre_CSRMatrixNumRows( mat ) ),
54 ncol( hypre_CSRMatrixNumCols( mat ) ),
55 nnz( hypre_CSRMatrixNumNonzeros( mat ) )
59 HYPRE_BigInt
const * getOffdColumnMap( hypre_ParCSRMatrix
const *
const mat );
/**
 * @brief Declaration of scaleMatrixValues (defined outside this header).
 *
 * Takes a non-const hypre_CSRMatrix, so it presumably scales the stored values
 * of @p mat in place — confirm against the definition.
 *
 * NOTE(review): original line 62 (the remaining parameter(s)) is missing from this
 * listing, so the declaration is incomplete as shown.
 */
61 void scaleMatrixValues( hypre_CSRMatrix *
const mat,
64 void scaleMatrixRows( hypre_CSRMatrix *
const mat,
65 hypre_Vector
const *
const vec );
/**
 * @brief Declaration of clampMatrixEntries (defined outside this header).
 *
 * Takes a non-const hypre_CSRMatrix, so it presumably modifies entries of @p mat in place.
 * @p skip_diag presumably excludes diagonal entries from clamping — confirm against the definition.
 *
 * NOTE(review): original lines 68-69 are missing from this listing. The parameters
 * between `mat` and `skip_diag` are therefore not visible here.
 */
67 void clampMatrixEntries( hypre_CSRMatrix *
const mat,
70 bool const skip_diag );
72 real64 computeMaxNorm( hypre_CSRMatrix
const *
const mat );
/**
 * @brief Overload of computeMaxNorm that takes further parameters (defined outside this header).
 *
 * NOTE(review): original lines 75-76 (the remaining parameters) are missing from this
 * listing, so the declaration is incomplete as shown.
 */
74 real64 computeMaxNorm( hypre_CSRMatrix
const *
const mat,
/**
 * @brief Rescale selected rows of a hypre_ParCSRMatrix in place.
 *
 * For each global row index in `rowIndices`:
 * - The stored values of that row, first in the diagonal block and then in the
 *   off-diagonal block, are folded into one scalar as `scale = reduce( scale, transform( value ) )`.
 *   The fold starts from 0.0.
 * - Every stored value of the row is then multiplied by that scalar.
 *
 * NOTE(review): this listing omits the remaining parameter declarations (original lines 80-83).
 * `rowIndices`, `transform` and `reduce` are used below. F / R are presumably the types of
 * `transform` / `reduce` — confirm.
 *
 * NOTE(review): some original lines are not shown:
 * - line 107, between the zero-check and the scaling loops;
 * - lines around both off-diagonal loops (e.g. 97-99, 111-113).
 * As visible, each row is multiplied by the reduced value itself. If the intent is
 * normalization, the reciprocal is presumably taken on the missing line — confirm.
 */
78 template<
typename F,
typename R >
79 void rescaleMatrixRows( hypre_ParCSRMatrix *
const mat,
// Mutable views of the local diagonal and off-diagonal blocks; their values are modified below.
84 CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };
85 CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };
// Global index of the first local row; subtracted from rowIndices to obtain local row numbers.
86 HYPRE_BigInt
const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );
// One iteration per requested row. All inputs are captured by value for the (possibly device) kernel.
88 forAll< execPolicy >( rowIndices.size(), [diag, offd, transform, reduce, rowIndices, firstLocalRow]
GEOS_HYPRE_DEVICE (
localIndex const i )
// Convert to a local row of the diagonal block. It must fall within the local rows; this is checked in debug builds only.
90 HYPRE_Int
const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );
91 GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );
// Fold the transformed values of the row: diagonal part first, then off-diagonal part.
93 HYPRE_Real scale = 0.0;
94 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
96 scale = reduce( scale, transform( diag.values[k] ) );
100 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
102 scale = reduce( scale, transform( offd.values[k] ) );
// A zero result is only caught by this debug-build assertion.
106 GEOS_ASSERT_MSG( !isZero( scale ),
"Zero row sum in row " << rowIndices[i] );
// Apply the factor to every stored value of the row in both blocks.
108 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
110 diag.values[k] *= scale;
114 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
116 offd.values[k] *= scale;
/**
 * @brief Compute one reduced value per local row of `mat` and store it in `vec`.
 *
 * For each local row:
 * - The stored values of the diagonal block, then of the off-diagonal block, are folded
 *   as `sum = reduce( sum, transform( value ) )`, starting from 0.0.
 * - The result is written to the local part of `vec` at the same row index.
 *
 * Despite the name, the operation is a general transform/reduce. It is a plain row sum
 * when `transform` is the identity and `reduce` is addition.
 *
 * NOTE(review): the declarations of `transform` and `reduce` (original lines 125-127) and
 * some lines around the off-diagonal loop are not visible in this listing. F / R are
 * presumably their types — confirm.
 *
 * NOTE(review): assumes the local part of `vec` has at least as many entries as the local
 * rows of `mat` — confirm at call sites.
 */
122 template<
typename F,
typename R >
123 void computeRowsSums( hypre_ParCSRMatrix
const *
const mat,
124 hypre_ParVector *
const vec,
// Read-only views of the local diagonal and off-diagonal blocks.
128 CSRData< true >
const diag{ hypre_ParCSRMatrixDiag( mat ) };
129 CSRData< true >
const offd{ hypre_ParCSRMatrixOffd( mat ) };
// Raw pointer to the local values of the output vector.
130 HYPRE_Real *
const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );
// One iteration per local row; each iteration writes only values[localRow].
132 forAll< execPolicy >( diag.nrow, [diag, offd, transform, reduce, values]
GEOS_HYPRE_DEVICE ( HYPRE_Int
const localRow )
134 HYPRE_Real sum = 0.0;
135 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
137 sum = reduce( sum, transform( diag.values[k] ) );
141 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
143 sum = reduce( sum, transform( offd.values[k] ) );
146 values[localRow] = sum;
/**
 * @brief Build, in `perm`, a permutation of `size` entries that orders `indices` by mapped value.
 *
 * The permutation entries are sorted with LvArray::sortedArrayManipulation::makeSorted.
 * The comparator orders two entries i, j by `map( indices[i] ) < map( indices[j] )`, so
 * `indices` itself is never reordered.
 * @tparam MAP callable type used to map an index to its sort key.
 *
 * NOTE(review): this listing omits several original lines (154, 158-159, 161-165, 167):
 * - the return type;
 * - the `map` parameter declaration;
 * - the body of the initialization loop;
 * - the comparator's header.
 * The description assumes the loop sets perm[i] = i — confirm.
 */
153 template<
typename MAP >
155 makeSortedPermutation( HYPRE_Int
const *
const indices,
156 HYPRE_Int
const size,
157 HYPRE_Int *
const perm,
// Initialize the permutation (loop body not visible in this listing).
160 for( HYPRE_Int i = 0; i < size; ++i )
// Compare two permutation entries by the mapped value of the index each one points to.
166 return map( indices[i] ) <
map( indices[j] );
// Sort the permutation entries in [perm, perm + size) using that ordering.
168 LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
/**
 * @brief Apply KERNEL::launch to corresponding blocks of `src` and `dst`.
 *
 * The kernels defined in this header add scaled `src` values into `dst`.
 * - KERNEL::launch is always called for the diagonal blocks.
 * - It is called for the off-diagonal blocks only if dst's off-diagonal block has at least one column.
 *
 * For the off-diagonal pass, the column maps of both matrices are fetched with
 * hypre::getOffdColumnMap so the kernel can compare columns across the two matrices.
 * @tparam KERNEL a type with a static `launch( src_block, src_colmap, dst_block, dst_colmap, ... )`.
 *
 * NOTE(review): the listing omits the remaining parameters (original lines 176-179) and
 * several launch arguments (181, 183-184, 190, 192-194). A scale factor and column-map
 * callables are presumably passed there — confirm.
 */
173 template<
typename KERNEL >
174 void addMatrixEntries( hypre_ParCSRMatrix
const *
const src,
175 hypre_ParCSRMatrix *
const dst,
// Diagonal blocks.
180 KERNEL::launch( hypre_ParCSRMatrixDiag( src ),
182 hypre_ParCSRMatrixDiag( dst ),
// Off-diagonal blocks: skipped when dst has no off-diagonal columns to receive entries.
185 if( hypre_CSRMatrixNumCols( hypre_ParCSRMatrixOffd( dst ) ) > 0 )
187 HYPRE_BigInt
const *
const src_colmap = hypre::getOffdColumnMap( src );
188 HYPRE_BigInt
const *
const dst_colmap = hypre::getOffdColumnMap( dst );
189 KERNEL::launch( hypre_ParCSRMatrixOffd( src ),
191 hypre_ParCSRMatrixOffd( dst ),
/**
 * @brief Kernel for addMatrixEntries that adds `scale * src` into `dst` only where dst already has entries.
 *
 * A src entry is dropped if its column (as given by `src_colmap`) has no match in the
 * corresponding dst row (as given by `dst_colmap`).
 *
 * Column indices within a row need not be ordered or directly comparable. Each row of
 * both matrices is therefore accessed through a permutation sorted by mapped column, and
 * the two sorted sequences are merged.
 *
 * NOTE(review): this listing does not show:
 * - `scale` and the other parameters (original lines 205-209);
 * - the return type (line 200);
 * - the forAll lambda header (lines 227-234);
 * - the loop bodies (e.g. 256-258).
 *
 * NOTE(review): iterates over dst.nrow rows and reads src.rowptr for the same rows. It
 * presumably assumes src has at least as many rows as dst — confirm.
 */
197 struct AddEntriesRestrictedKernel
199 template<
typename SRC_COLMAP,
typename DST_COLMAP >
201 launch( hypre_CSRMatrix
const *
const src_mat,
202 SRC_COLMAP
const src_colmap,
203 hypre_CSRMatrix *
const dst_mat,
204 DST_COLMAP
const dst_colmap,
// src is read-only; dst values are updated in place.
210 CSRData< true > src{ src_mat };
211 CSRData< false > dst{ dst_mat };
// Nothing to add: src block has no columns or the scale factor is zero.
214 if( src.ncol == 0 || isZero( scale ) )
// Scratch permutations with one entry per stored nonzero.
// Each row uses the slice starting at its own rowptr offset, so rows never overlap.
220 array1d< HYPRE_Int >
const src_permutation_arr( src.nnz );
221 array1d< HYPRE_Int >
const dst_permutation_arr( dst.nnz );
223 arrayView1d< HYPRE_Int >
const src_permutation = src_permutation_arr.toView();
224 arrayView1d< HYPRE_Int >
const dst_permutation = dst_permutation_arr.toView();
// One iteration per dst row.
226 forAll< hypre::execPolicy >( dst.nrow,
// Slices of the current row in src.
235 HYPRE_Int
const src_offset = src.rowptr[localRow];
236 HYPRE_Int
const src_length = src.rowptr[localRow + 1] - src_offset;
237 HYPRE_Int
const *
const src_indices = src.colind + src_offset;
238 HYPRE_Real
const *
const src_values = src.values + src_offset;
239 HYPRE_Int *
const src_perm = src_permutation.data() + src_offset;
// Slices of the current row in dst.
241 HYPRE_Int
const dst_offset = dst.rowptr[localRow];
242 HYPRE_Int
const dst_length = dst.rowptr[localRow + 1] - dst_offset;
243 HYPRE_Int
const *
const dst_indices = dst.colind + dst_offset;
244 HYPRE_Real *
const dst_values = dst.values + dst_offset;
245 HYPRE_Int *
const dst_perm = dst_permutation.data() + dst_offset;
// Order each row's entries indirectly by mapped column.
249 internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );
250 internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );
// Merge the two sorted rows: skip src columns smaller than the current dst column, add on equal columns.
// With duplicate columns within a row, only the first matching pair is combined.
253 for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )
255 while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )
259 if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )
261 dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
/**
 * @brief Kernel for addMatrixEntries that adds `scale * src` into `dst` when both blocks share the same sparsity pattern.
 *
 * Entries must also be stored in the same order within each row. The i-th entry of a
 * dst row is paired with the i-th entry of the corresponding src row, without searching.
 *
 * The pattern match is only verified by debug-build assertions (GEOS_ASSERT_EQ). In other
 * builds, a mismatched pattern silently adds values to the wrong entries, or reads past
 * the src row if it is shorter.
 *
 * NOTE(review): `scale` and the other parameters (original lines 276-280) and the return
 * type (line 271) are not visible in this listing.
 */
268 struct AddEntriesSamePatternKernel
270 template<
typename SRC_COLMAP,
typename DST_COLMAP >
272 launch( hypre_CSRMatrix
const *
const src_mat,
273 SRC_COLMAP
const src_colmap,
274 hypre_CSRMatrix *
const dst_mat,
275 DST_COLMAP
const dst_colmap,
// src is read-only; dst values are updated in place.
281 CSRData< true > src{ src_mat };
282 CSRData< false > dst{ dst_mat };
// Nothing to add: src block has no columns or the scale factor is zero.
285 if( src.ncol == 0 || isZero( scale ) )
// One iteration per dst row; views and scale are captured by value.
291 forAll< hypre::execPolicy >( dst.nrow,
292 [src, src_colmap, dst, dst_colmap, scale]
GEOS_HYPRE_DEVICE ( HYPRE_Int
const localRow )
294 HYPRE_Int
const src_offset = src.rowptr[localRow];
295 HYPRE_Int
const src_length = src.rowptr[localRow + 1] - src_offset;
296 HYPRE_Int
const *
const src_indices = src.colind + src_offset;
297 HYPRE_Real
const *
const src_values = src.values + src_offset;
299 HYPRE_Int
const dst_offset = dst.rowptr[localRow];
300 HYPRE_Int
const dst_length = dst.rowptr[localRow + 1] - dst_offset;
301 HYPRE_Int
const *
const dst_indices = dst.colind + dst_offset;
302 HYPRE_Real *
const dst_values = dst.values + dst_offset;
// These locals serve the debug-only pattern checks. Mark them to silence unused-variable warnings in other builds.
306 GEOS_DEBUG_VAR( src_length, dst_length, src_indices, dst_indices, src_colmap, dst_colmap );
// Positional add: relies on identical column order in both rows.
309 for( HYPRE_Int i = 0; i < dst_length; ++i )
311 GEOS_ASSERT_EQ( src_colmap( src_indices[i] ), dst_colmap( dst_indices[i] ) );
312 dst_values[i] += scale * src_values[i];
#define GEOS_DEBUG_VAR(...)
Mark a debug variable and silence compiler warnings.
#define GEOS_HYPRE_DEVICE
Device marker for custom hypre kernels.
#define GEOS_ASSERT(EXP)
Assert a condition in debug builds.
#define GEOS_ASSERT_MSG(EXP, msg)
Assert a condition in debug builds.
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
Base template for ordered and unordered maps.
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
#define GEOS_LAI_ASSERT(expr)
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
double real64
64-bit floating point type.
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.