GEOS
HypreKernels.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
22 
23 #include "codingUtilities/Utilities.hpp"
24 #include "codingUtilities/traits.hpp"
25 #include "common/DataTypes.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
29 
30 #include <_hypre_parcsr_mv.h>
31 
32 namespace geos
33 {
34 namespace hypre
35 {
36 
38 
39 namespace internal
40 {
41 
/**
 * @brief Payload identifying a row that was reported as having a zero row sum.
 *
 * Kept as a plain aggregate so that it can be brace-initialized from a single
 * value and written to a stream via the operator<< overload taking this type.
 */
struct ZeroRowSumMessage
{
  /// Index of the offending row, stored as long long.
  long long row;
};
46 
47 inline std::ostream & operator<<( std::ostream & os, ZeroRowSumMessage const & msg )
48 {
49  return os << "Zero row sum in row " << msg.row;
50 }
51 
52 GEOS_HOST_DEVICE inline ZeroRowSumMessage zeroRowSumMessage( globalIndex const row )
53 {
54 #if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
55  printf( "***** Zero row sum in row %lld\n", static_cast< long long >( row ) );
56 #endif
57  return { static_cast< long long >( row ) };
58 }
59 
60 } // namespace internal
61 
62 template< bool CONST >
63 struct CSRData
64 {
65  HYPRE_Int const * rowptr;
66  HYPRE_Int const * colind;
67  add_const_if_t< HYPRE_Real, CONST > * values;
68  HYPRE_Int nrow;
69  HYPRE_Int ncol;
70  HYPRE_Int nnz;
71 
72  explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > * const mat )
73  : rowptr( hypre_CSRMatrixI( mat ) ),
74  colind( hypre_CSRMatrixJ( mat ) ),
75  values( hypre_CSRMatrixData( mat ) ),
76  nrow( hypre_CSRMatrixNumRows( mat ) ),
77  ncol( hypre_CSRMatrixNumCols( mat ) ),
78  nnz( hypre_CSRMatrixNumNonzeros( mat ) )
79  {}
80 };
81 
82 HYPRE_BigInt const * getOffdColumnMap( hypre_ParCSRMatrix const * const mat );
83 
84 void scaleMatrixValues( hypre_CSRMatrix * const mat,
85  real64 const factor );
86 
87 void scaleMatrixRows( hypre_CSRMatrix * const mat,
88  hypre_Vector const * const vec );
89 
90 void clampMatrixEntries( hypre_CSRMatrix * const mat,
91  real64 const lo,
92  real64 const hi,
93  bool const skip_diag );
94 
95 real64 computeMaxNorm( hypre_CSRMatrix const * const mat );
96 
97 real64 computeMaxNorm( hypre_CSRMatrix const * const mat,
98  arrayView1d< globalIndex const > const & rowIndices,
99  globalIndex const firstLocalRow );
100 
101 template< typename F, typename R >
102 void rescaleMatrixRows( hypre_ParCSRMatrix * const mat,
103  arrayView1d< globalIndex const > const & rowIndices,
104  F transform,
105  R reduce )
106 {
107  CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };
108  CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };
109  HYPRE_BigInt const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );
110 
111  forAll< execPolicy >( rowIndices.size(), [diag, offd, transform, reduce, rowIndices, firstLocalRow] GEOS_HOST_DEVICE ( localIndex const i )
112  {
113  HYPRE_Int const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );
114  GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );
115 
116  HYPRE_Real scale = 0.0;
117  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
118  {
119  scale = reduce( scale, transform( diag.values[k] ) );
120  }
121  if( offd.ncol > 0 )
122  {
123  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
124  {
125  scale = reduce( scale, transform( offd.values[k] ) );
126  }
127  }
128 
129  GEOS_ASSERT_MSG( !isZero( scale ), internal::zeroRowSumMessage( rowIndices[i] ) );
130  scale = 1.0 / scale;
131  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
132  {
133  diag.values[k] *= scale;
134  }
135  if( offd.ncol > 0 )
136  {
137  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
138  {
139  offd.values[k] *= scale;
140  }
141  }
142  } );
143 }
144 
145 template< typename F, typename R >
146 void computeRowsSums( hypre_ParCSRMatrix const * const mat,
147  hypre_ParVector * const vec,
148  F transform,
149  R reduce )
150 {
151  CSRData< true > const diag{ hypre_ParCSRMatrixDiag( mat ) };
152  CSRData< true > const offd{ hypre_ParCSRMatrixOffd( mat ) };
153  HYPRE_Real * const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );
154 
155  forAll< execPolicy >( diag.nrow, [diag, offd, transform, reduce, values] GEOS_HOST_DEVICE ( HYPRE_Int const localRow )
156  {
157  HYPRE_Real sum = 0.0;
158  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
159  {
160  sum = reduce( sum, transform( diag.values[k] ) );
161  }
162  if( offd.ncol )
163  {
164  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
165  {
166  sum = reduce( sum, transform( offd.values[k] ) );
167  }
168  }
169  values[localRow] = sum;
170  } );
171 }
172 
173 namespace internal
174 {
175 
176 template< typename MAP >
177 void GEOS_HOST_DEVICE
178 makeSortedPermutation( HYPRE_Int const * const indices,
179  HYPRE_Int const size,
180  HYPRE_Int * const perm,
181  MAP map )
182 {
183  for( HYPRE_Int i = 0; i < size; ++i )
184  {
185  perm[i] = i; // std::iota
186  }
187  auto const comp = [indices, map] GEOS_HOST_DEVICE ( HYPRE_Int i, HYPRE_Int j )
188  {
189  return map( indices[i] ) < map( indices[j] );
190  };
191  LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
192 }
193 
194 } // namespace internal
195 
196 template< typename KERNEL >
197 void addMatrixEntries( hypre_ParCSRMatrix const * const src,
198  hypre_ParCSRMatrix * const dst,
199  real64 const scale )
200 {
201  GEOS_LAI_ASSERT( src != nullptr );
202  GEOS_LAI_ASSERT( dst != nullptr );
203  KERNEL::launch( hypre_ParCSRMatrixDiag( src ),
204  [] GEOS_HOST_DEVICE ( HYPRE_Int const x ){ return x; },
205  hypre_ParCSRMatrixDiag( dst ),
206  [] GEOS_HOST_DEVICE ( HYPRE_Int const x ){ return x; },
207  scale );
208  if( hypre_CSRMatrixNumCols( hypre_ParCSRMatrixOffd( dst ) ) > 0 )
209  {
210  HYPRE_BigInt const * const src_colmap = hypre::getOffdColumnMap( src );
211  HYPRE_BigInt const * const dst_colmap = hypre::getOffdColumnMap( dst );
212  KERNEL::launch( hypre_ParCSRMatrixOffd( src ),
213  [src_colmap] GEOS_HOST_DEVICE ( HYPRE_Int const i ){ return src_colmap[i]; },
214  hypre_ParCSRMatrixOffd( dst ),
215  [dst_colmap] GEOS_HOST_DEVICE ( HYPRE_Int const i ){ return dst_colmap[i]; },
216  scale );
217  }
218 }
219 
220 struct AddEntriesRestrictedKernel
221 {
222  template< typename SRC_COLMAP, typename DST_COLMAP >
223  static void
224  launch( hypre_CSRMatrix const * const src_mat,
225  SRC_COLMAP const src_colmap,
226  hypre_CSRMatrix * const dst_mat,
227  DST_COLMAP const dst_colmap,
228  real64 const scale )
229  {
230  GEOS_LAI_ASSERT( src_mat != nullptr );
231  GEOS_LAI_ASSERT( dst_mat != nullptr );
232 
233  CSRData< true > src{ src_mat };
234  CSRData< false > dst{ dst_mat };
235  GEOS_LAI_ASSERT_EQ( src.nrow, dst.nrow );
236 
237  if( src.ncol == 0 || isZero( scale ) )
238  {
239  return;
240  }
241 
242  // Allocate contiguous memory to store sorted column permutations of each row
243  array1d< HYPRE_Int > const src_permutation_arr( src.nnz );
244  array1d< HYPRE_Int > const dst_permutation_arr( dst.nnz );
245 
246  arrayView1d< HYPRE_Int > const src_permutation = src_permutation_arr.toView();
247  arrayView1d< HYPRE_Int > const dst_permutation = dst_permutation_arr.toView();
248  // Each thread adds one row of src into dst
249  forAll< hypre::execPolicy >( dst.nrow,
250  [src,
251  src_colmap,
252  dst,
253  dst_colmap,
254  scale,
255  src_permutation,
256  dst_permutation ] GEOS_HOST_DEVICE ( HYPRE_Int const localRow )
257  {
258  HYPRE_Int const src_offset = src.rowptr[localRow];
259  HYPRE_Int const src_length = src.rowptr[localRow + 1] - src_offset;
260  HYPRE_Int const * const src_indices = src.colind + src_offset;
261  HYPRE_Real const * const src_values = src.values + src_offset;
262  HYPRE_Int * const src_perm = src_permutation.data() + src_offset;
263 
264  HYPRE_Int const dst_offset = dst.rowptr[localRow];
265  HYPRE_Int const dst_length = dst.rowptr[localRow + 1] - dst_offset;
266  HYPRE_Int const * const dst_indices = dst.colind + dst_offset;
267  HYPRE_Real * const dst_values = dst.values + dst_offset;
268  HYPRE_Int * const dst_perm = dst_permutation.data() + dst_offset;
269 
270  // Since hypre does not store columns in sorted order, create a sorted "view" of src and dst rows
271  // TODO: it would be nice to cache the permutation arrays somewhere to avoid recomputing
272  internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );
273  internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );
274 
275  // Add entries looping through them in sorted column order, skipping src entries not in dst
276  for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )
277  {
278  while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )
279  {
280  ++j;
281  }
282  if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )
283  {
284  dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
285  }
286  }
287  } );
288  }
289 };
290 
291 struct AddEntriesSamePatternKernel
292 {
293  template< typename SRC_COLMAP, typename DST_COLMAP >
294  static void
295  launch( hypre_CSRMatrix const * const src_mat,
296  SRC_COLMAP const src_colmap,
297  hypre_CSRMatrix * const dst_mat,
298  DST_COLMAP const dst_colmap,
299  real64 const scale )
300  {
301  GEOS_LAI_ASSERT( src_mat != nullptr );
302  GEOS_LAI_ASSERT( dst_mat != nullptr );
303 
304  CSRData< true > src{ src_mat };
305  CSRData< false > dst{ dst_mat };
306  GEOS_LAI_ASSERT_EQ( src.nrow, dst.nrow );
307 
308  if( src.ncol == 0 || isZero( scale ) )
309  {
310  return;
311  }
312 
313  // Each thread adds one row of src into dst
314  forAll< hypre::execPolicy >( dst.nrow,
315  [src, src_colmap, dst, dst_colmap, scale] GEOS_HOST_DEVICE ( HYPRE_Int const localRow )
316  {
317  HYPRE_Int const src_offset = src.rowptr[localRow];
318  HYPRE_Int const src_length = src.rowptr[localRow + 1] - src_offset;
319  HYPRE_Int const * const src_indices = src.colind + src_offset;
320  HYPRE_Real const * const src_values = src.values + src_offset;
321 
322  HYPRE_Int const dst_offset = dst.rowptr[localRow];
323  HYPRE_Int const dst_length = dst.rowptr[localRow + 1] - dst_offset;
324  HYPRE_Int const * const dst_indices = dst.colind + dst_offset;
325  HYPRE_Real * const dst_values = dst.values + dst_offset;
326 
327  GEOS_ASSERT_EQ( src_offset, dst_offset );
328  GEOS_ASSERT_EQ( src_length, dst_length );
329  GEOS_DEBUG_VAR( src_length, dst_length, src_indices, dst_indices, src_colmap, dst_colmap );
330 
331  // NOTE: this assumes that entries are in the exact same order, to avoid creating a sorted view
332  for( HYPRE_Int i = 0; i < dst_length; ++i )
333  {
334  GEOS_ASSERT_EQ( src_colmap( src_indices[i] ), dst_colmap( dst_indices[i] ) );
335  dst_values[i] += scale * src_values[i];
336  }
337  } );
338  }
339 };
340 
342 
343 } // namespace hypre
344 } // namespace geos
345 
346 #endif //GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
#define GEOS_HOST_DEVICE
Marks a host-device function.
Definition: GeosxMacros.hpp:49
#define GEOS_DEBUG_VAR(...)
Mark a debug variable and silence compiler warnings.
#define GEOS_ASSERT_MSG(COND,...)
Abort execution if COND is false, but only when NDEBUG is not defined.
Definition: Logger.hpp:837
#define GEOS_ASSERT(COND)
Assert a condition in debug builds.
Definition: Logger.hpp:868
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
Definition: Logger.hpp:885
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
Definition: common.hpp:49
#define GEOS_LAI_ASSERT(expr)
Definition: common.hpp:35
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
Definition: DataTypes.hpp:87
double real64
64-bit floating point type.
Definition: DataTypes.hpp:98
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:84
std::ostream & operator<<(std::ostream &stream, mapType< K, V, SORTED > const &map)
Stream output operator for map types.
Definition: DataTypes.hpp:326
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.
Definition: DataTypes.hpp:339