20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_ 
   21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_ 
   23 #include "codingUtilities/Utilities.hpp" 
   24 #include "codingUtilities/traits.hpp" 
   26 #include "common/GEOS_RAJA_Interface.hpp" 
   30 #include <_hypre_parcsr_mv.h> 
   // Lightweight, copyable view of a hypre_CSRMatrix: raw pointers to its row-pointer,
   // column-index and value arrays plus its row/column/nonzero counts, all read from the
   // matrix in the constructor. The kernel lambdas in this file capture such views by value,
   // so only the pointers are copied, never the arrays themselves.
   // CONST selects read-only access: when true, `values` and the constructor argument are
   // const-qualified through add_const_if_t; rowptr/colind are always read-only.
   39 template< 
bool CONST >

   // Row-pointer array: entries [rowptr[r], rowptr[r + 1]) belong to local row r.
   42   HYPRE_Int 
const * rowptr;

   // Column index of each stored entry; kernels in this file translate these through a
   // column map before comparing columns across matrices.
   43   HYPRE_Int 
const * colind;

   // Stored entry values (writable only when CONST is false).
   44   add_const_if_t< HYPRE_Real, CONST > * values;

   // Binds the view to `mat`. Only the pointers returned by the hypre accessors are stored,
   // so the underlying arrays must stay alive (and unreallocated) while the view is in use.
   49   explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > * 
const mat )

   50     : rowptr( hypre_CSRMatrixI( mat ) ),

   51     colind( hypre_CSRMatrixJ( mat ) ),

   52     values( hypre_CSRMatrixData( mat ) ),

   53     nrow( hypre_CSRMatrixNumRows( mat ) ),

   54     ncol( hypre_CSRMatrixNumCols( mat ) ),

   55     nnz( hypre_CSRMatrixNumNonzeros( mat ) )
 
   // Returns the off-diagonal column map of `mat` as a read-only HYPRE_BigInt array.
   // addMatrixEntries passes it to its kernels to translate off-diagonal column indices;
   // presumably the entries are global column indices — confirm against the definition.
   59 HYPRE_BigInt 
const * getOffdColumnMap( hypre_ParCSRMatrix 
const * 
const mat );
 
   // Presumably scales every stored value of the local CSR matrix `mat` (taken non-const,
   // so it is modified in place). The remaining parameter(s) are not visible here —
   // presumably the scaling factor; TODO confirm against the definition.
   61 void scaleMatrixValues( hypre_CSRMatrix * 
const mat,
 
   // Scales the rows of the local CSR matrix `mat` using the entries of the read-only
   // vector `vec` — presumably row i is multiplied by vec[i]; confirm against the definition.
   64 void scaleMatrixRows( hypre_CSRMatrix * 
const mat,

   65                       hypre_Vector 
const * 
const vec );
 
   // Clamps the entries of the local CSR matrix `mat`; the bounds come from parameters not
   // visible here. `skip_diag` presumably leaves diagonal entries untouched — TODO confirm.
   67 void clampMatrixEntries( hypre_CSRMatrix * 
const mat,

   70                          bool const skip_diag );
 
   // Returns a max-norm of the read-only local CSR matrix `mat` as real64 — presumably the
   // largest absolute entry value; confirm against the definition.
   72 real64 computeMaxNorm( hypre_CSRMatrix 
const * 
const mat );
 
   // Overload of computeMaxNorm taking further arguments that are not visible here —
   // presumably restricting the norm to a subset of rows; TODO confirm.
   74 real64 computeMaxNorm( hypre_CSRMatrix 
const * 
const mat,
 
   // Rescales selected rows of the parallel matrix `mat` in place.
   // For each global row index in `rowIndices` (which must map to a local row of the diag
   // block; asserted in debug builds), a per-row value is accumulated as
   //   scale = reduce( scale, transform( a ) ), starting from 0.0,
   // over every stored entry `a` of that row in the diag block and then in the offd block.
   // Every stored entry of the row is then multiplied by `scale`.
   // F: unary callable applied to each entry value before reduction.
   // R: binary callable combining the running value with a transformed entry.
   // NOTE(review): as visible here the entries are multiplied by the reduced value itself. If
   // the intent is normalization (e.g. unit row sums), confirm that `scale` is inverted
   // between the zero-check assert and the first scaling loop; that statement is not shown here.
   78 template< 
typename F, 
typename R >

   79 void rescaleMatrixRows( hypre_ParCSRMatrix * 
const mat,

   // Non-const views: the lambda below writes through `values`.
   84   CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };

   85   CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };

   // Global index of this matrix's first local row; subtracted from rowIndices below.
   86   HYPRE_BigInt 
const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );

   // One iteration per requested row. The lambda carries the GEOS_HYPRE_DEVICE marker and
   // captures the CSR views by value.
   88   forAll< execPolicy >( rowIndices.size(), [diag, offd, transform, reduce, rowIndices, firstLocalRow] 
GEOS_HYPRE_DEVICE ( 
localIndex const i )

   90     HYPRE_Int 
const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );

   91     GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );

   // Reduce the row's transformed entries: diag block first, then offd block.
   93     HYPRE_Real scale = 0.0;

   94     for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )

   96       scale = reduce( scale, transform( diag.values[k] ) );

  100       for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )

  102         scale = reduce( scale, transform( offd.values[k] ) );

   // A zero reduced value is a debug-build error (row index reported in the message).
  106     GEOS_ASSERT_MSG( !isZero( scale ), 
"Zero row sum in row " << rowIndices[i] );

   // Apply the scaling to every stored entry of the row (diag block, then offd block).
  108     for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )

  110       diag.values[k] *= scale;

  114       for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )

  116         offd.values[k] *= scale;
 
  // Computes a generalized row "sum" of `mat` into the local part of `vec`.
  // For every local row r of the diag block:
  //   vec[r] = reduce over all stored entries a of row r (diag block, then offd block)
  //            of transform( a ), starting from 0.0.
  // `mat` is only read (read-only CSR views).
  // F: unary callable applied to each entry value; R: binary combiner of running value and
  // transformed entry.
  // NOTE(review): writes diag.nrow entries of vec's local data; assumes the caller
  // guarantees vec has matching local row partitioning — TODO confirm.
  122 template< 
typename F, 
typename R >

  123 void computeRowsSums( hypre_ParCSRMatrix 
const * 
const mat,

  124                       hypre_ParVector * 
const vec,

  128   CSRData< true > 
const diag{ hypre_ParCSRMatrixDiag( mat ) };

  129   CSRData< true > 
const offd{ hypre_ParCSRMatrixOffd( mat ) };

  // Raw pointer to the local values of `vec`; captured by value in the lambda below.
  130   HYPRE_Real * 
const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );

  // One iteration per local row.
  132   forAll< execPolicy >( diag.nrow, [diag, offd, transform, reduce, values] 
GEOS_HYPRE_DEVICE ( HYPRE_Int 
const localRow )

  134     HYPRE_Real sum = 0.0;

  135     for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )

  137       sum = reduce( sum, transform( diag.values[k] ) );

  141       for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )

  143         sum = reduce( sum, transform( offd.values[k] ) );

  146     values[localRow] = sum;
 
  // Fills `perm` (length `size`) with a permutation of positions into `indices`, ordered so
  // that map( indices[perm[0]] ) <= map( indices[perm[1]] ) <= ... ; `indices` itself is
  // read-only and not reordered.
  // MAP: callable translating an index value before comparison; the AddEntries kernels in
  // this file pass their column maps here.
  // NOTE(review): the body of the initialization loop is not visible here —
  // presumably perm[i] = i.
  153 template< 
typename MAP >

  155 makeSortedPermutation( HYPRE_Int 
const * 
const indices,

  156                        HYPRE_Int 
const size,

  157                        HYPRE_Int * 
const perm,

  160   for( HYPRE_Int i = 0; i < size; ++i )

  // Comparator: order positions i, j by the mapped value of the index each refers to.
  166     return map( indices[i] ) < 
map( indices[j] );

  168   LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
 
  // Adds entries of the parallel matrix `src` into `dst` with KERNEL::launch, called
  // separately for the diag blocks and the offd blocks (e.g. AddEntriesRestrictedKernel or
  // AddEntriesSamePatternKernel). Presumably the kernels scale the added entries by a
  // factor passed through arguments not visible here — TODO confirm.
  // The offd blocks are processed only when dst's offd block has at least one column. Their
  // column indices are translated through each matrix's off-diagonal column map, so the two
  // matrices may number their off-diagonal columns differently.
  173 template< 
typename KERNEL >

  174 void addMatrixEntries( hypre_ParCSRMatrix 
const * 
const src,

  175                        hypre_ParCSRMatrix * 
const dst,

  180   KERNEL::launch( hypre_ParCSRMatrixDiag( src ),

  182                   hypre_ParCSRMatrixDiag( dst ),

  // Skip the offd pass when dst has no off-diagonal columns.
  185   if( hypre_CSRMatrixNumCols( hypre_ParCSRMatrixOffd( dst ) ) > 0 )

  187     HYPRE_BigInt 
const * 
const src_colmap = hypre::getOffdColumnMap( src );

  188     HYPRE_BigInt 
const * 
const dst_colmap = hypre::getOffdColumnMap( dst );

  189     KERNEL::launch( hypre_ParCSRMatrixOffd( src ),

  191                     hypre_ParCSRMatrixOffd( dst ),
 
  // Kernel for addMatrixEntries that adds scale * src into dst only at positions already
  // present in dst's sparsity pattern. A src entry whose mapped column has no match in the
  // corresponding dst row is skipped, and dst's pattern is never changed.
  // Rows are matched by local row index. Columns are matched by the values returned by
  // src_colmap / dst_colmap, so the two matrices may number columns differently and need
  // not store a row's entries in the same order (each row is sorted via a permutation).
  197 struct AddEntriesRestrictedKernel

  199   template< 
typename SRC_COLMAP, 
typename DST_COLMAP >

  201   launch( hypre_CSRMatrix 
const * 
const src_mat,

  202           SRC_COLMAP 
const src_colmap,

  203           hypre_CSRMatrix * 
const dst_mat,

  204           DST_COLMAP 
const dst_colmap,

  210     CSRData< true > src{ src_mat };

  211     CSRData< false > dst{ dst_mat };

    // Nothing to add when src has no columns or the scale is zero (the branch body is not
    // visible here — presumably an early return).
  214     if( src.ncol == 0 || isZero( scale ) )

    // Scratch permutations, one slot per stored entry of each matrix; each row uses the
    // slice starting at its rowptr offset, so rows never overlap.
  220     array1d< HYPRE_Int > 
const src_permutation_arr( src.nnz );

  221     array1d< HYPRE_Int > 
const dst_permutation_arr( dst.nnz );

  223     arrayView1d< HYPRE_Int > 
const src_permutation = src_permutation_arr.toView();

  224     arrayView1d< HYPRE_Int > 
const dst_permutation = dst_permutation_arr.toView();

    // One iteration per dst row.
  226     forAll< hypre::execPolicy >( dst.nrow,

  235       HYPRE_Int 
const src_offset = src.rowptr[localRow];

  236       HYPRE_Int 
const src_length = src.rowptr[localRow + 1] - src_offset;

  237       HYPRE_Int 
const * 
const src_indices = src.colind + src_offset;

  238       HYPRE_Real 
const * 
const src_values = src.values + src_offset;

  239       HYPRE_Int * 
const src_perm = src_permutation.data() + src_offset;

  241       HYPRE_Int 
const dst_offset = dst.rowptr[localRow];

  242       HYPRE_Int 
const dst_length = dst.rowptr[localRow + 1] - dst_offset;

  243       HYPRE_Int 
const * 
const dst_indices = dst.colind + dst_offset;

  244       HYPRE_Real * 
const dst_values = dst.values + dst_offset;

  245       HYPRE_Int * 
const dst_perm = dst_permutation.data() + dst_offset;

      // Sort both rows' entry positions by mapped column so they can be merged.
  249       internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );

  250       internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );

      // Two-pointer merge over the sorted rows: skip src columns smaller than the current
      // dst column, and accumulate scale * src value when the mapped columns are equal.
  253       for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )

  255         while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )

  259         if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )

  261           dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
 
  // Kernel for addMatrixEntries that assumes src and dst have identical sparsity patterns,
  // with entries stored in the same order within each row. It adds scale * src_values[i]
  // to dst_values[i] position by position, without sorting or searching.
  // The pattern assumption is checked only by GEOS_ASSERT_EQ in debug builds, comparing
  // mapped columns entry by entry. In other builds a mismatched pattern is not detected.
  // NOTE(review): the loop runs over dst's row length. If a src row were shorter,
  // src_values[i] would read past that row. Confirm that a length check exists
  // (src_length/dst_length are marked as debug variables, but no comparison is shown here).
  268 struct AddEntriesSamePatternKernel

  270   template< 
typename SRC_COLMAP, 
typename DST_COLMAP >

  272   launch( hypre_CSRMatrix 
const * 
const src_mat,

  273           SRC_COLMAP 
const src_colmap,

  274           hypre_CSRMatrix * 
const dst_mat,

  275           DST_COLMAP 
const dst_colmap,

  281     CSRData< true > src{ src_mat };

  282     CSRData< false > dst{ dst_mat };

    // Nothing to add when src has no columns or the scale is zero (the branch body is not
    // visible here — presumably an early return).
  285     if( src.ncol == 0 || isZero( scale ) )

    // One iteration per dst row; CSR views and column maps are captured by value.
  291     forAll< hypre::execPolicy >( dst.nrow,

  292                                  [src, src_colmap, dst, dst_colmap, scale] 
GEOS_HYPRE_DEVICE ( HYPRE_Int 
const localRow )

  294       HYPRE_Int 
const src_offset = src.rowptr[localRow];

  295       HYPRE_Int 
const src_length = src.rowptr[localRow + 1] - src_offset;

  296       HYPRE_Int 
const * 
const src_indices = src.colind + src_offset;

  297       HYPRE_Real 
const * 
const src_values = src.values + src_offset;

  299       HYPRE_Int 
const dst_offset = dst.rowptr[localRow];

  300       HYPRE_Int 
const dst_length = dst.rowptr[localRow + 1] - dst_offset;

  301       HYPRE_Int 
const * 
const dst_indices = dst.colind + dst_offset;

  302       HYPRE_Real * 
const dst_values = dst.values + dst_offset;

      // Mark values used only by debug assertions, silencing unused-variable warnings.
  306       GEOS_DEBUG_VAR( src_length, dst_length, src_indices, dst_indices, src_colmap, dst_colmap );

      // Position-wise accumulation; the assert checks the mapped columns agree.
  309       for( HYPRE_Int i = 0; i < dst_length; ++i )

  311         GEOS_ASSERT_EQ( src_colmap( src_indices[i] ), dst_colmap( dst_indices[i] ) );

  312         dst_values[i] += scale * src_values[i];
 
#define GEOS_DEBUG_VAR(...)
Mark a debug variable and silence compiler warnings.
 
#define GEOS_HYPRE_DEVICE
Device marker for custom hypre kernels.
 
#define GEOS_ASSERT(EXP)
Assert a condition in debug builds.
 
#define GEOS_ASSERT_MSG(EXP, msg)
Assert a condition in debug builds.
 
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
 
Base template for ordered and unordered maps.
 
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
 
#define GEOS_LAI_ASSERT(expr)
 
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
 
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
 
double real64
64-bit floating point type.
 
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
 
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.