GEOS
HypreKernels.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
22 
23 #include "codingUtilities/Utilities.hpp"
24 #include "codingUtilities/traits.hpp"
25 #include "common/DataTypes.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
29 
30 #include <_hypre_parcsr_mv.h>
31 
32 namespace geos
33 {
34 namespace hypre
35 {
36 
38 
39 template< bool CONST >
40 struct CSRData
41 {
42  HYPRE_Int const * rowptr;
43  HYPRE_Int const * colind;
44  add_const_if_t< HYPRE_Real, CONST > * values;
45  HYPRE_Int nrow;
46  HYPRE_Int ncol;
47  HYPRE_Int nnz;
48 
49  explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > * const mat )
50  : rowptr( hypre_CSRMatrixI( mat ) ),
51  colind( hypre_CSRMatrixJ( mat ) ),
52  values( hypre_CSRMatrixData( mat ) ),
53  nrow( hypre_CSRMatrixNumRows( mat ) ),
54  ncol( hypre_CSRMatrixNumCols( mat ) ),
55  nnz( hypre_CSRMatrixNumNonzeros( mat ) )
56  {}
57 };
58 
59 HYPRE_BigInt const * getOffdColumnMap( hypre_ParCSRMatrix const * const mat );
60 
61 void scaleMatrixValues( hypre_CSRMatrix * const mat,
62  real64 const factor );
63 
64 void scaleMatrixRows( hypre_CSRMatrix * const mat,
65  hypre_Vector const * const vec );
66 
67 void clampMatrixEntries( hypre_CSRMatrix * const mat,
68  real64 const lo,
69  real64 const hi,
70  bool const skip_diag );
71 
72 real64 computeMaxNorm( hypre_CSRMatrix const * const mat );
73 
74 real64 computeMaxNorm( hypre_CSRMatrix const * const mat,
75  arrayView1d< globalIndex const > const & rowIndices,
76  globalIndex const firstLocalRow );
77 
78 template< typename F, typename R >
79 void rescaleMatrixRows( hypre_ParCSRMatrix * const mat,
80  arrayView1d< globalIndex const > const & rowIndices,
81  F transform,
82  R reduce )
83 {
84  CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };
85  CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };
86  HYPRE_BigInt const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );
87 
88  forAll< execPolicy >( rowIndices.size(), [diag, offd, transform, reduce, rowIndices, firstLocalRow] GEOS_HYPRE_DEVICE ( localIndex const i )
89  {
90  HYPRE_Int const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );
91  GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );
92 
93  HYPRE_Real scale = 0.0;
94  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
95  {
96  scale = reduce( scale, transform( diag.values[k] ) );
97  }
98  if( offd.ncol > 0 )
99  {
100  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
101  {
102  scale = reduce( scale, transform( offd.values[k] ) );
103  }
104  }
105 
106  GEOS_ASSERT_MSG( !isZero( scale ), "Zero row sum in row " << rowIndices[i] );
107  scale = 1.0 / scale;
108  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
109  {
110  diag.values[k] *= scale;
111  }
112  if( offd.ncol > 0 )
113  {
114  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
115  {
116  offd.values[k] *= scale;
117  }
118  }
119  } );
120 }
121 
122 template< typename F, typename R >
123 void computeRowsSums( hypre_ParCSRMatrix const * const mat,
124  hypre_ParVector * const vec,
125  F transform,
126  R reduce )
127 {
128  CSRData< true > const diag{ hypre_ParCSRMatrixDiag( mat ) };
129  CSRData< true > const offd{ hypre_ParCSRMatrixOffd( mat ) };
130  HYPRE_Real * const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );
131 
132  forAll< execPolicy >( diag.nrow, [diag, offd, transform, reduce, values] GEOS_HYPRE_DEVICE ( HYPRE_Int const localRow )
133  {
134  HYPRE_Real sum = 0.0;
135  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
136  {
137  sum = reduce( sum, transform( diag.values[k] ) );
138  }
139  if( offd.ncol )
140  {
141  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
142  {
143  sum = reduce( sum, transform( offd.values[k] ) );
144  }
145  }
146  values[localRow] = sum;
147  } );
148 }
149 
150 namespace internal
151 {
152 
153 template< typename MAP >
155 makeSortedPermutation( HYPRE_Int const * const indices,
156  HYPRE_Int const size,
157  HYPRE_Int * const perm,
158  MAP map )
159 {
160  for( HYPRE_Int i = 0; i < size; ++i )
161  {
162  perm[i] = i; // std::iota
163  }
164  auto const comp = [indices, map] GEOS_HYPRE_DEVICE ( HYPRE_Int i, HYPRE_Int j )
165  {
166  return map( indices[i] ) < map( indices[j] );
167  };
168  LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
169 }
170 
171 } // namespace internal
172 
173 template< typename KERNEL >
174 void addMatrixEntries( hypre_ParCSRMatrix const * const src,
175  hypre_ParCSRMatrix * const dst,
176  real64 const scale )
177 {
178  GEOS_LAI_ASSERT( src != nullptr );
179  GEOS_LAI_ASSERT( dst != nullptr );
180  KERNEL::launch( hypre_ParCSRMatrixDiag( src ),
181  [] GEOS_HYPRE_DEVICE ( auto x ){ return x; },
182  hypre_ParCSRMatrixDiag( dst ),
183  [] GEOS_HYPRE_DEVICE ( auto x ){ return x; },
184  scale );
185  if( hypre_CSRMatrixNumCols( hypre_ParCSRMatrixOffd( dst ) ) > 0 )
186  {
187  HYPRE_BigInt const * const src_colmap = hypre::getOffdColumnMap( src );
188  HYPRE_BigInt const * const dst_colmap = hypre::getOffdColumnMap( dst );
189  KERNEL::launch( hypre_ParCSRMatrixOffd( src ),
190  [src_colmap] GEOS_HYPRE_DEVICE ( auto i ){ return src_colmap[i]; },
191  hypre_ParCSRMatrixOffd( dst ),
192  [dst_colmap] GEOS_HYPRE_DEVICE ( auto i ){ return dst_colmap[i]; },
193  scale );
194  }
195 }
196 
197 struct AddEntriesRestrictedKernel
198 {
199  template< typename SRC_COLMAP, typename DST_COLMAP >
200  static void
201  launch( hypre_CSRMatrix const * const src_mat,
202  SRC_COLMAP const src_colmap,
203  hypre_CSRMatrix * const dst_mat,
204  DST_COLMAP const dst_colmap,
205  real64 const scale )
206  {
207  GEOS_LAI_ASSERT( src_mat != nullptr );
208  GEOS_LAI_ASSERT( dst_mat != nullptr );
209 
210  CSRData< true > src{ src_mat };
211  CSRData< false > dst{ dst_mat };
212  GEOS_LAI_ASSERT_EQ( src.nrow, dst.nrow );
213 
214  if( src.ncol == 0 || isZero( scale ) )
215  {
216  return;
217  }
218 
219  // Allocate contiguous memory to store sorted column permutations of each row
220  array1d< HYPRE_Int > const src_permutation_arr( src.nnz );
221  array1d< HYPRE_Int > const dst_permutation_arr( dst.nnz );
222 
223  arrayView1d< HYPRE_Int > const src_permutation = src_permutation_arr.toView();
224  arrayView1d< HYPRE_Int > const dst_permutation = dst_permutation_arr.toView();
225  // Each thread adds one row of src into dst
226  forAll< hypre::execPolicy >( dst.nrow,
227  [src,
228  src_colmap,
229  dst,
230  dst_colmap,
231  scale,
232  src_permutation,
233  dst_permutation ] GEOS_HYPRE_DEVICE ( HYPRE_Int const localRow )
234  {
235  HYPRE_Int const src_offset = src.rowptr[localRow];
236  HYPRE_Int const src_length = src.rowptr[localRow + 1] - src_offset;
237  HYPRE_Int const * const src_indices = src.colind + src_offset;
238  HYPRE_Real const * const src_values = src.values + src_offset;
239  HYPRE_Int * const src_perm = src_permutation.data() + src_offset;
240 
241  HYPRE_Int const dst_offset = dst.rowptr[localRow];
242  HYPRE_Int const dst_length = dst.rowptr[localRow + 1] - dst_offset;
243  HYPRE_Int const * const dst_indices = dst.colind + dst_offset;
244  HYPRE_Real * const dst_values = dst.values + dst_offset;
245  HYPRE_Int * const dst_perm = dst_permutation.data() + dst_offset;
246 
247  // Since hypre does not store columns in sorted order, create a sorted "view" of src and dst rows
248  // TODO: it would be nice to cache the permutation arrays somewhere to avoid recomputing
249  internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );
250  internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );
251 
252  // Add entries looping through them in sorted column order, skipping src entries not in dst
253  for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )
254  {
255  while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )
256  {
257  ++j;
258  }
259  if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )
260  {
261  dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
262  }
263  }
264  } );
265  }
266 };
267 
268 struct AddEntriesSamePatternKernel
269 {
270  template< typename SRC_COLMAP, typename DST_COLMAP >
271  static void
272  launch( hypre_CSRMatrix const * const src_mat,
273  SRC_COLMAP const src_colmap,
274  hypre_CSRMatrix * const dst_mat,
275  DST_COLMAP const dst_colmap,
276  real64 const scale )
277  {
278  GEOS_LAI_ASSERT( src_mat != nullptr );
279  GEOS_LAI_ASSERT( dst_mat != nullptr );
280 
281  CSRData< true > src{ src_mat };
282  CSRData< false > dst{ dst_mat };
283  GEOS_LAI_ASSERT_EQ( src.nrow, dst.nrow );
284 
285  if( src.ncol == 0 || isZero( scale ) )
286  {
287  return;
288  }
289 
290  // Each thread adds one row of src into dst
291  forAll< hypre::execPolicy >( dst.nrow,
292  [src, src_colmap, dst, dst_colmap, scale] GEOS_HYPRE_DEVICE ( HYPRE_Int const localRow )
293  {
294  HYPRE_Int const src_offset = src.rowptr[localRow];
295  HYPRE_Int const src_length = src.rowptr[localRow + 1] - src_offset;
296  HYPRE_Int const * const src_indices = src.colind + src_offset;
297  HYPRE_Real const * const src_values = src.values + src_offset;
298 
299  HYPRE_Int const dst_offset = dst.rowptr[localRow];
300  HYPRE_Int const dst_length = dst.rowptr[localRow + 1] - dst_offset;
301  HYPRE_Int const * const dst_indices = dst.colind + dst_offset;
302  HYPRE_Real * const dst_values = dst.values + dst_offset;
303 
304  GEOS_ASSERT_EQ( src_offset, dst_offset );
305  GEOS_ASSERT_EQ( src_length, dst_length );
306  GEOS_DEBUG_VAR( src_length, dst_length, src_indices, dst_indices, src_colmap, dst_colmap );
307 
308  // NOTE: this assumes that entries are in the exact same order, to avoid creating a sorted view
309  for( HYPRE_Int i = 0; i < dst_length; ++i )
310  {
311  GEOS_ASSERT_EQ( src_colmap( src_indices[i] ), dst_colmap( dst_indices[i] ) );
312  dst_values[i] += scale * src_values[i];
313  }
314  } );
315  }
316 };
317 
319 
320 } // namespace hypre
321 } // namespace geos
322 
323 #endif //GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
#define GEOS_DEBUG_VAR(...)
Mark a debug variable and silence compiler warnings.
Definition: GeosxMacros.hpp:87
#define GEOS_HYPRE_DEVICE
Device marker for custom hypre kernels.
Definition: HypreUtils.hpp:34
#define GEOS_ASSERT(EXP)
Assert a condition in debug builds.
Definition: Logger.hpp:177
#define GEOS_ASSERT_MSG(EXP, msg)
Assert a condition in debug builds.
Definition: Logger.hpp:171
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
Definition: Logger.hpp:410
Base template for ordered and unordered maps.
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
Definition: common.hpp:49
#define GEOS_LAI_ASSERT(expr)
Definition: common.hpp:35
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:179
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
Definition: DataTypes.hpp:87
double real64
64-bit floating point type.
Definition: DataTypes.hpp:98
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:84
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.
Definition: DataTypes.hpp:339