20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
23 #include "codingUtilities/Utilities.hpp"
24 #include "codingUtilities/traits.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
30 #include <_hypre_parcsr_mv.h>
/**
 * Function template `identity`: declared constexpr, takes a single value of
 * type T (passed by const value) and returns a T.
 *
 * NOTE(review): the function body is not visible in this excerpt; presumably
 * it returns `v` unchanged -- confirm against the full header.
 */
42 template<
typename T >
44 constexpr T identity( T
const v )
/**
 * Function template `plus`: declared constexpr, takes two values of type T
 * (`lhs` and `rhs`, both by const value) and returns a T.
 *
 * NOTE(review): the function body is not visible in this excerpt; presumably
 * it returns `lhs + rhs` -- confirm against the full header.
 */
49 template<
typename T >
51 constexpr T plus( T
const lhs, T
const rhs )
/**
 * Bundle of the raw arrays and sizes of a hypre_CSRMatrix.
 *
 * The CONST template flag selects, through add_const_if_t, the constness of
 * the matrix pointer accepted by the constructor and of the `values` pointee.
 * `rowptr` and `colind` are always pointers to const. Elsewhere in this file
 * CSRData< true > wraps matrices that are only read, and CSRData< false >
 * wraps matrices whose values are written through `values`.
 *
 * NOTE(review): the type header and the declarations of nrow, ncol and nnz
 * are not visible in this excerpt; only their initialization is.
 */
58 template<
bool CONST >
61 HYPRE_Int
const * rowptr;
62 HYPRE_Int
const * colind;
63 add_const_if_t< HYPRE_Real, CONST > * values;
// Explicit constructor: fills every member from the corresponding
// hypre_CSRMatrix accessor (I -> rowptr, J -> colind, Data -> values,
// NumRows / NumCols / NumNonzeros -> nrow / ncol / nnz).
68 explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > *
const mat )
69 : rowptr( hypre_CSRMatrixI( mat ) ),
70 colind( hypre_CSRMatrixJ( mat ) ),
71 values( hypre_CSRMatrixData( mat ) ),
72 nrow( hypre_CSRMatrixNumRows( mat ) ),
73 ncol( hypre_CSRMatrixNumCols( mat ) ),
74 nnz( hypre_CSRMatrixNumNonzeros( mat ) )
/**
 * Declaration only: takes a read-only hypre_ParCSRMatrix and returns a
 * pointer to const HYPRE_BigInt.
 *
 * NOTE(review): presumably the returned array maps local column indices of
 * the off-diagonal block to global column indices -- confirm against the
 * definition, which is not part of this header excerpt.
 */
78 HYPRE_BigInt
const * getOffdColumnMap( hypre_ParCSRMatrix
const *
const mat );
/**
 * Declaration of scaleMatrixValues, operating on a mutable hypre_CSRMatrix.
 *
 * NOTE(review): the remaining parameter(s) after `mat` are not visible in
 * this excerpt; presumably a scaling factor -- confirm against the full
 * header.
 */
80 void scaleMatrixValues( hypre_CSRMatrix *
const mat,
/**
 * Declaration only: takes a mutable hypre_CSRMatrix `mat` and a read-only
 * hypre_Vector `vec`; returns nothing.
 *
 * NOTE(review): presumably each row of `mat` is multiplied by the matching
 * entry of `vec` -- confirm against the definition.
 */
83 void scaleMatrixRows( hypre_CSRMatrix *
const mat,
84 hypre_Vector
const *
const vec );
/**
 * Declaration of clampMatrixEntries, operating on a mutable hypre_CSRMatrix.
 * The last parameter is a boolean flag `skip_diag`.
 *
 * NOTE(review): the parameters between `mat` and `skip_diag` are not visible
 * in this excerpt; presumably the lower and upper clamping bounds, with
 * `skip_diag` excluding diagonal entries -- confirm against the full header.
 */
86 void clampMatrixEntries( hypre_CSRMatrix *
const mat,
89 bool const skip_diag );
/**
 * Declaration only: takes a read-only hypre_CSRMatrix and returns a real64
 * (64-bit floating point) value.
 *
 * NOTE(review): presumably the largest absolute value among the stored
 * entries -- confirm against the definition.
 */
91 real64 computeMaxNorm( hypre_CSRMatrix
const *
const mat );
/**
 * Overload of computeMaxNorm taking a read-only hypre_CSRMatrix plus further
 * arguments; returns a real64.
 *
 * NOTE(review): the parameters after `mat` are not visible in this excerpt;
 * presumably they restrict the norm to a subset of entries -- confirm
 * against the full header.
 */
93 real64 computeMaxNorm( hypre_CSRMatrix
const *
const mat,
/**
 * Call operator of a helper templated on a transform type F and a reduction
 * type R (used elsewhere in this file as internal::RowReducer< F, R >,
 * brace-initialized from a transform and a reduce callable).
 *
 * Given the running accumulator `acc` and a matrix entry `v`, it applies
 * `transform` to the entry and combines the result with the accumulator via
 * `reduce`. The operator is const-qualified.
 *
 * NOTE(review): the type header, the member declarations and the return type
 * of the operator are not visible in this excerpt.
 */
101 template<
typename F,
typename R >
108 operator()(
double acc,
double v )
const
110 return reduce( acc, transform( v ) );
/**
 * Rescale selected rows of a ParCSR matrix by a per-row factor.
 *
 * For every global row index rowIndices[i], the local row is obtained by
 * subtracting the matrix' first local row index. A factor `scale` is then
 * accumulated, starting from 0.0, by passing each entry of that row (diagonal
 * block first, then off-diagonal block) through a RowReducer built from
 * `transform` and `reduce`. Finally all entries of the row in both blocks are
 * multiplied by `scale`.
 *
 * Both sanity checks (local row within [0, diag.nrow) and non-zero `scale`)
 * use GEOS_ASSERT / GEOS_ASSERT_MSG, i.e. they are debug-build assertions.
 *
 * NOTE(review): the rest of the parameter list, the loop over `i`, and any
 * statement between the zero check and the scaling loops are not visible in
 * this excerpt. As shown, entries are multiplied by `scale` itself; confirm
 * whether the full source inverts `scale` before applying it.
 */
116 template<
typename F,
typename R >
117 void rescaleMatrixRows( hypre_ParCSRMatrix *
const mat,
// Mutable views of the diagonal and off-diagonal blocks of the matrix.
122 CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };
123 CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };
124 HYPRE_BigInt
const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );
125 internal::RowReducer< F, R > reducer{ std::move( transform ), std::move( reduce ) };
// Convert the global row index into a local one (checked integer conversion).
129 HYPRE_Int
const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );
130 GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );
// Accumulate the row's reduced value over both blocks.
132 HYPRE_Real scale = 0.0;
133 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
135 scale = reducer( scale, diag.values[k] );
139 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
141 scale = reducer( scale, offd.values[k] );
145 GEOS_ASSERT_MSG( !isZero( scale ),
"Zero row sum in row " << rowIndices[i] );
// Apply the factor to every stored entry of the row in both blocks.
147 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
149 diag.values[k] *= scale;
153 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
155 offd.values[k] *= scale;
/**
 * Compute one reduced value per local row of a ParCSR matrix and store it in
 * the local part of a ParVector.
 *
 * The matrix is only read (both blocks are wrapped as CSRData< true >). For
 * each local row the accumulator starts at 0.0 and every entry of the row is
 * folded in through a RowReducer built from `transform` and `reduce`:
 * diagonal-block entries first, then off-diagonal-block entries. The result
 * is written to values[localRow], where `values` is the data array of the
 * vector's local hypre_Vector.
 *
 * Rows are processed with forAll< execPolicy > over diag.nrow; the kernel
 * lambda captures diag, offd, reducer and values by value and is marked
 * GEOS_HYPRE_HOST_DEVICE.
 *
 * NOTE(review): the trailing parameters (presumably `transform` and
 * `reduce`) are not visible in this excerpt. The code indexes `values` by
 * local row without a visible size check, so the vector's local size is
 * assumed to match the matrix' local row count -- confirm with callers.
 */
161 template<
typename F,
typename R >
162 void computeRowsSums( hypre_ParCSRMatrix
const *
const mat,
163 hypre_ParVector *
const vec,
167 CSRData< true >
const diag{ hypre_ParCSRMatrixDiag( mat ) };
168 CSRData< true >
const offd{ hypre_ParCSRMatrixOffd( mat ) };
// Raw output array of the vector's local part.
169 HYPRE_Real *
const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );
170 internal::RowReducer< F, R > reducer{ std::move( transform ), std::move( reduce ) };
172 forAll< execPolicy >( diag.nrow, [diag, offd, reducer, values]
GEOS_HYPRE_HOST_DEVICE ( HYPRE_Int
const localRow )
// Fold the diagonal-block entries, then the off-diagonal-block entries.
174 HYPRE_Real sum = 0.0;
175 for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
177 sum = reducer( sum, diag.values[k] );
181 for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
183 sum = reducer( sum, offd.values[k] );
186 values[localRow] = sum;
/**
 * Build, in `perm`, an ordering of the `size` entries of `indices` based on
 * their images under the callable `map`.
 *
 * A loop runs over positions 0 .. size-1, after which the range
 * [perm, perm + size) is passed to LvArray's makeSorted together with a
 * comparator `comp`. The comparator returns whether map( indices[i] ) is less
 * than map( indices[j] ). The `indices` array itself is read-only.
 *
 * NOTE(review): the return type, the final parameter (presumably `map`), the
 * loop body (presumably perm[i] = i) and the comparator header are not
 * visible in this excerpt -- confirm against the full header.
 */
193 template<
typename MAP >
195 makeSortedPermutation( HYPRE_Int
const *
const indices,
196 HYPRE_Int
const size,
197 HYPRE_Int *
const perm,
200 for( HYPRE_Int i = 0; i < size; ++i )
206 return map( indices[i] ) <
map( indices[j] );
208 LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
/**
 * Add scaled entries of a source CSR matrix into a destination CSR matrix,
 * touching only entries already stored in the destination.
 *
 * Columns of the two matrices are compared after mapping them through
 * `src_colmap` and `dst_colmap` respectively. For every destination row:
 *  - sorted permutations of the source and destination row entries are
 *    built, ordered by mapped column (internal::makeSortedPermutation);
 *  - the two sorted sequences are walked together; whenever a source entry
 *    and a destination entry have the same mapped column, `scale` times the
 *    source value is added to the destination value.
 * Source entries with no matching destination column are skipped, and the
 * destination's rowptr/colind arrays are only read.
 *
 * Scratch storage for the permutations is allocated once for all rows, sized
 * by the non-zero count of each matrix, and addressed per row through the row
 * offsets.
 *
 * NOTE(review): the trailing parameters (including `scale`), the action of
 * the early-exit check, the kernel lambda header and the body of the inner
 * while loop are not visible in this excerpt. Source rows are indexed with
 * the destination row index, so the source is assumed to have at least
 * dst.nrow rows -- confirm with callers.
 */
213 template<
typename SRC_COLMAP,
typename DST_COLMAP >
214 void addEntriesRestricted( hypre_CSRMatrix
const *
const src_mat,
215 SRC_COLMAP
const src_colmap,
216 hypre_CSRMatrix *
const dst_mat,
217 DST_COLMAP
const dst_colmap,
// Read-only view of the source, mutable view of the destination.
223 CSRData< true > src{ src_mat };
224 CSRData< false > dst{ dst_mat };
// Nothing to add when the source has no columns or the scale is zero
// (the action taken is not visible here; presumably an early return).
227 if( src.ncol == 0 || isZero( scale ) )
// Scratch arrays holding, for every row, a sorted permutation of its entries.
233 array1d< HYPRE_Int >
const src_permutation_arr( hypre_CSRMatrixNumNonzeros( src_mat ) );
234 array1d< HYPRE_Int >
const dst_permutation_arr( hypre_CSRMatrixNumNonzeros( dst_mat ) );
236 arrayView1d< HYPRE_Int >
const src_permutation = src_permutation_arr.toView();
237 arrayView1d< HYPRE_Int >
const dst_permutation = dst_permutation_arr.toView();
239 forAll< hypre::execPolicy >( dst.nrow,
// Per-row slices of the source arrays and of its permutation scratch space.
248 HYPRE_Int
const src_offset = src.rowptr[localRow];
249 HYPRE_Int
const src_length = src.rowptr[localRow + 1] - src_offset;
250 HYPRE_Int
const *
const src_indices = src.colind + src_offset;
251 HYPRE_Real
const *
const src_values = src.values + src_offset;
252 HYPRE_Int *
const src_perm = src_permutation.data() + src_offset;
// Per-row slices of the destination arrays and of its permutation scratch space.
254 HYPRE_Int
const dst_offset = dst.rowptr[localRow];
255 HYPRE_Int
const dst_length = dst.rowptr[localRow + 1] - dst_offset;
256 HYPRE_Int
const *
const dst_indices = dst.colind + dst_offset;
257 HYPRE_Real *
const dst_values = dst.values + dst_offset;
258 HYPRE_Int *
const dst_perm = dst_permutation.data() + dst_offset;
// Order both rows by mapped column so they can be merged in one pass.
262 internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );
263 internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );
// Merge: i walks destination entries, j walks source entries. Source entries
// with a smaller mapped column than the current destination entry are passed
// over; on equal mapped columns the scaled source value is accumulated and j
// advances.
266 for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )
268 while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )
272 if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )
274 dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
#define GEOS_HYPRE_DEVICE
Host-device marker for custom hypre kernels.
#define GEOS_HYPRE_HOST_DEVICE
Host-device marker for custom hypre kernels.
#define GEOS_ASSERT(EXP)
Assert a condition in debug builds.
#define GEOS_ASSERT_MSG(EXP, msg)
Assert a condition in debug builds.
Base template for ordered and unordered maps.
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
#define GEOS_LAI_ASSERT(expr)
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
double real64
64-bit floating point type.
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.