GEOS
HypreKernels.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
21 #define GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
22 
23 #include "codingUtilities/Utilities.hpp"
24 #include "codingUtilities/traits.hpp"
25 #include "common/DataTypes.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
29 
30 #include <_hypre_parcsr_mv.h>
31 
32 namespace geos
33 {
34 namespace hypre
35 {
36 
38 
39 namespace ops
40 {
41 
/**
 * @brief Identity mapping: hands back its argument unchanged.
 * @tparam T type of the value
 * @param v the input value
 * @return a copy of @p v
 *
 * Used by addMatrixEntries() as the column map for blocks whose
 * column indices need no translation.
 */
template< typename T >
constexpr T identity( T const v )
{
  return v;
}
48 
/**
 * @brief Binary addition usable as a reduction operator.
 * @tparam T type of both operands and of the result
 * @param lhs left operand
 * @param rhs right operand
 * @return the value of lhs + rhs converted to @p T
 */
template< typename T >
constexpr T plus( T const lhs, T const rhs )
{
  T const sum = lhs + rhs;
  return sum;
}
55 
56 }
57 
58 template< bool CONST >
59 struct CSRData
60 {
61  HYPRE_Int const * rowptr;
62  HYPRE_Int const * colind;
63  add_const_if_t< HYPRE_Real, CONST > * values;
64  HYPRE_Int nrow;
65  HYPRE_Int ncol;
66  HYPRE_Int nnz;
67 
68  explicit CSRData( add_const_if_t< hypre_CSRMatrix, CONST > * const mat )
69  : rowptr( hypre_CSRMatrixI( mat ) ),
70  colind( hypre_CSRMatrixJ( mat ) ),
71  values( hypre_CSRMatrixData( mat ) ),
72  nrow( hypre_CSRMatrixNumRows( mat ) ),
73  ncol( hypre_CSRMatrixNumCols( mat ) ),
74  nnz( hypre_CSRMatrixNumNonzeros( mat ) )
75  {}
76 };
77 
78 HYPRE_BigInt const * getOffdColumnMap( hypre_ParCSRMatrix const * const mat );
79 
80 void scaleMatrixValues( hypre_CSRMatrix * const mat,
81  real64 const factor );
82 
83 void scaleMatrixRows( hypre_CSRMatrix * const mat,
84  hypre_Vector const * const vec );
85 
86 void clampMatrixEntries( hypre_CSRMatrix * const mat,
87  real64 const lo,
88  real64 const hi,
89  bool const skip_diag );
90 
91 real64 computeMaxNorm( hypre_CSRMatrix const * const mat );
92 
93 real64 computeMaxNorm( hypre_CSRMatrix const * const mat,
94  arrayView1d< globalIndex const > const & rowIndices,
95  globalIndex const firstLocalRow );
96 
97 namespace internal
98 {
99 
101 template< typename F, typename R >
102 struct RowReducer
103 {
104  F transform;
105  R reduce;
106 
108  operator()( double acc, double v ) const
109  {
110  return reduce( acc, transform( v ) );
111  }
112 };
113 
114 } // namespace internal
115 
116 template< typename F, typename R >
117 void rescaleMatrixRows( hypre_ParCSRMatrix * const mat,
118  arrayView1d< globalIndex const > const & rowIndices,
119  F transform,
120  R reduce )
121 {
122  CSRData< false > diag{ hypre_ParCSRMatrixDiag( mat ) };
123  CSRData< false > offd{ hypre_ParCSRMatrixOffd( mat ) };
124  HYPRE_BigInt const firstLocalRow = hypre_ParCSRMatrixFirstRowIndex( mat );
125  internal::RowReducer< F, R > reducer{ std::move( transform ), std::move( reduce ) };
126 
127  forAll< execPolicy >( rowIndices.size(), [diag, offd, reducer, rowIndices, firstLocalRow] GEOS_HYPRE_HOST_DEVICE ( localIndex const i )
128  {
129  HYPRE_Int const localRow = LvArray::integerConversion< HYPRE_Int >( rowIndices[i] - firstLocalRow );
130  GEOS_ASSERT( 0 <= localRow && localRow < diag.nrow );
131 
132  HYPRE_Real scale = 0.0;
133  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
134  {
135  scale = reducer( scale, diag.values[k] );
136  }
137  if( offd.ncol > 0 )
138  {
139  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
140  {
141  scale = reducer( scale, offd.values[k] );
142  }
143  }
144 
145  GEOS_ASSERT_MSG( !isZero( scale ), "Zero row sum in row " << rowIndices[i] );
146  scale = 1.0 / scale;
147  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
148  {
149  diag.values[k] *= scale;
150  }
151  if( offd.ncol > 0 )
152  {
153  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
154  {
155  offd.values[k] *= scale;
156  }
157  }
158  } );
159 }
160 
161 template< typename F, typename R >
162 void computeRowsSums( hypre_ParCSRMatrix const * const mat,
163  hypre_ParVector * const vec,
164  F transform,
165  R reduce )
166 {
167  CSRData< true > const diag{ hypre_ParCSRMatrixDiag( mat ) };
168  CSRData< true > const offd{ hypre_ParCSRMatrixOffd( mat ) };
169  HYPRE_Real * const values = hypre_VectorData( hypre_ParVectorLocalVector( vec ) );
170  internal::RowReducer< F, R > reducer{ std::move( transform ), std::move( reduce ) };
171 
172  forAll< execPolicy >( diag.nrow, [diag, offd, reducer, values] GEOS_HYPRE_HOST_DEVICE ( HYPRE_Int const localRow )
173  {
174  HYPRE_Real sum = 0.0;
175  for( HYPRE_Int k = diag.rowptr[localRow]; k < diag.rowptr[localRow + 1]; ++k )
176  {
177  sum = reducer( sum, diag.values[k] );
178  }
179  if( offd.ncol )
180  {
181  for( HYPRE_Int k = offd.rowptr[localRow]; k < offd.rowptr[localRow + 1]; ++k )
182  {
183  sum = reducer( sum, offd.values[k] );
184  }
185  }
186  values[localRow] = sum;
187  } );
188 }
189 
190 namespace internal
191 {
192 
193 template< typename MAP >
195 makeSortedPermutation( HYPRE_Int const * const indices,
196  HYPRE_Int const size,
197  HYPRE_Int * const perm,
198  MAP map )
199 {
200  for( HYPRE_Int i = 0; i < size; ++i )
201  {
202  perm[i] = i; // std::iota
203  }
204  auto const comp = [indices, map] GEOS_HYPRE_HOST_DEVICE ( HYPRE_Int i, HYPRE_Int j )
205  {
206  return map( indices[i] ) < map( indices[j] );
207  };
208  LvArray::sortedArrayManipulation::makeSorted( perm, perm + size, comp );
209 }
210 
211 } // namespace internal
212 
213 template< typename KERNEL >
214 void addMatrixEntries( hypre_ParCSRMatrix const * const src,
215  hypre_ParCSRMatrix * const dst,
216  real64 const scale )
217 {
218  GEOS_LAI_ASSERT( src != nullptr );
219  GEOS_LAI_ASSERT( dst != nullptr );
220  KERNEL::launch( hypre_ParCSRMatrixDiag( src ),
221  hypre::ops::identity< HYPRE_Int >,
222  hypre_ParCSRMatrixDiag( dst ),
223  hypre::ops::identity< HYPRE_Int >,
224  scale );
225  if( hypre_CSRMatrixNumCols( hypre_ParCSRMatrixOffd( dst ) ) > 0 )
226  {
227  HYPRE_BigInt const * const src_colmap = hypre::getOffdColumnMap( src );
228  HYPRE_BigInt const * const dst_colmap = hypre::getOffdColumnMap( dst );
229  KERNEL::launch( hypre_ParCSRMatrixOffd( src ),
230  [src_colmap] GEOS_HYPRE_DEVICE ( auto i ){ return src_colmap[i]; },
231  hypre_ParCSRMatrixOffd( dst ),
232  [dst_colmap] GEOS_HYPRE_DEVICE ( auto i ){ return dst_colmap[i]; },
233  scale );
234  }
235 }
236 
237 struct AddEntriesRestrictedKernel
238 {
239  template< typename SRC_COLMAP, typename DST_COLMAP >
240  static void
241  launch( hypre_CSRMatrix const * const src_mat,
242  SRC_COLMAP const src_colmap,
243  hypre_CSRMatrix * const dst_mat,
244  DST_COLMAP const dst_colmap,
245  real64 const scale )
246  {
247  GEOS_LAI_ASSERT( src_mat != nullptr );
248  GEOS_LAI_ASSERT( dst_mat != nullptr );
249 
250  CSRData< true > src{ src_mat };
251  CSRData< false > dst{ dst_mat };
252  GEOS_LAI_ASSERT_EQ( src.nrow, dst.nrow );
253 
254  if( src.ncol == 0 || isZero( scale ) )
255  {
256  return;
257  }
258 
259  // Allocate contiguous memory to store sorted column permutations of each row
260  array1d< HYPRE_Int > const src_permutation_arr( src.nnz );
261  array1d< HYPRE_Int > const dst_permutation_arr( dst.nnz );
262 
263  arrayView1d< HYPRE_Int > const src_permutation = src_permutation_arr.toView();
264  arrayView1d< HYPRE_Int > const dst_permutation = dst_permutation_arr.toView();
265  // Each thread adds one row of src into dst
266  forAll< hypre::execPolicy >( dst.nrow,
267  [src,
268  src_colmap,
269  dst,
270  dst_colmap,
271  scale,
272  src_permutation,
273  dst_permutation ] GEOS_HYPRE_DEVICE ( HYPRE_Int const localRow )
274  {
275  HYPRE_Int const src_offset = src.rowptr[localRow];
276  HYPRE_Int const src_length = src.rowptr[localRow + 1] - src_offset;
277  HYPRE_Int const * const src_indices = src.colind + src_offset;
278  HYPRE_Real const * const src_values = src.values + src_offset;
279  HYPRE_Int * const src_perm = src_permutation.data() + src_offset;
280 
281  HYPRE_Int const dst_offset = dst.rowptr[localRow];
282  HYPRE_Int const dst_length = dst.rowptr[localRow + 1] - dst_offset;
283  HYPRE_Int const * const dst_indices = dst.colind + dst_offset;
284  HYPRE_Real * const dst_values = dst.values + dst_offset;
285  HYPRE_Int * const dst_perm = dst_permutation.data() + dst_offset;
286 
287  // Since hypre does not store columns in sorted order, create a sorted "view" of src and dst rows
288  // TODO: it would be nice to cache the permutation arrays somewhere to avoid recomputing
289  internal::makeSortedPermutation( src_indices, src_length, src_perm, src_colmap );
290  internal::makeSortedPermutation( dst_indices, dst_length, dst_perm, dst_colmap );
291 
292  // Add entries looping through them in sorted column order, skipping src entries not in dst
293  for( HYPRE_Int i = 0, j = 0; i < dst_length && j < src_length; ++i )
294  {
295  while( j < src_length && src_colmap( src_indices[src_perm[j]] ) < dst_colmap( dst_indices[dst_perm[i]] ) )
296  {
297  ++j;
298  }
299  if( j < src_length && src_colmap( src_indices[src_perm[j]] ) == dst_colmap( dst_indices[dst_perm[i]] ) )
300  {
301  dst_values[dst_perm[i]] += scale * src_values[src_perm[j++]];
302  }
303  }
304  } );
305  }
306 };
307 
308 struct AddEntriesSamePatternKernel
309 {
310  template< typename SRC_COLMAP, typename DST_COLMAP >
311  static void
312  launch( hypre_CSRMatrix const * const src_mat,
313  SRC_COLMAP const src_colmap,
314  hypre_CSRMatrix * const dst_mat,
315  DST_COLMAP const dst_colmap,
316  real64 const scale )
317  {
318  GEOS_LAI_ASSERT( src_mat != nullptr );
319  GEOS_LAI_ASSERT( dst_mat != nullptr );
320 
321  CSRData< true > src{ src_mat };
322  CSRData< false > dst{ dst_mat };
323  GEOS_LAI_ASSERT_EQ( src.nrow, dst.nrow );
324 
325  if( src.ncol == 0 || isZero( scale ) )
326  {
327  return;
328  }
329 
330  // Each thread adds one row of src into dst
331  forAll< hypre::execPolicy >( dst.nrow,
332  [src, src_colmap, dst, dst_colmap, scale] GEOS_HYPRE_DEVICE ( HYPRE_Int const localRow )
333  {
334  HYPRE_Int const src_offset = src.rowptr[localRow];
335  HYPRE_Int const src_length = src.rowptr[localRow + 1] - src_offset;
336  HYPRE_Int const * const src_indices = src.colind + src_offset;
337  HYPRE_Real const * const src_values = src.values + src_offset;
338 
339  HYPRE_Int const dst_offset = dst.rowptr[localRow];
340  HYPRE_Int const dst_length = dst.rowptr[localRow + 1] - dst_offset;
341  HYPRE_Int const * const dst_indices = dst.colind + dst_offset;
342  HYPRE_Real * const dst_values = dst.values + dst_offset;
343 
344  GEOS_ASSERT_EQ( src_offset, dst_offset );
345  GEOS_ASSERT_EQ( src_length, dst_length );
346  GEOS_DEBUG_VAR( src_length, dst_length, src_indices, dst_indices, src_colmap, dst_colmap );
347 
348  // NOTE: this assumes that entries are in the exact same order, to avoid creating a sorted view
349  for( HYPRE_Int i = 0; i < dst_length; ++i )
350  {
351  GEOS_ASSERT_EQ( src_colmap( src_indices[i] ), dst_colmap( dst_indices[i] ) );
352  dst_values[i] += scale * src_values[i];
353  }
354  } );
355  }
356 };
357 
359 
360 } // namespace hypre
361 } // namespace geos
362 
363 #endif //GEOS_LINEARALGEBRA_INTERFACES_HYPREKERNELS_HPP_
#define GEOS_DEBUG_VAR(...)
Mark a debug variable and silence compiler warnings.
Definition: GeosxMacros.hpp:87
#define GEOS_HYPRE_DEVICE
Host-device marker for custom hypre kernels.
Definition: HypreUtils.hpp:34
#define GEOS_HYPRE_HOST_DEVICE
Host-device marker for custom hypre kernels.
Definition: HypreUtils.hpp:36
#define GEOS_ASSERT(EXP)
Assert a condition in debug builds.
Definition: Logger.hpp:177
#define GEOS_ASSERT_MSG(EXP, msg)
Assert a condition in debug builds.
Definition: Logger.hpp:171
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
Definition: Logger.hpp:410
Base template for ordered and unordered maps.
#define GEOS_LAI_ASSERT_EQ(lhs, rhs)
Definition: common.hpp:49
#define GEOS_LAI_ASSERT(expr)
Definition: common.hpp:35
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:179
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
Definition: DataTypes.hpp:87
double real64
64-bit floating point type.
Definition: DataTypes.hpp:98
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:84
mapBase< TKEY, TVAL, std::integral_constant< bool, true > > map
Ordered map type.
Definition: DataTypes.hpp:339