GEOS
KrylovSolver.hpp
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
16 #ifndef GEOS_LINEARALGEBRA_SOLVERS_KRYLOVSOLVER_HPP_
17 #define GEOS_LINEARALGEBRA_SOLVERS_KRYLOVSOLVER_HPP_
18 
24 
25 namespace geos
26 {
27 
32 template< typename VECTOR >
33 class KrylovSolver : public LinearOperator< VECTOR >
34 {
35 public:
36 
39 
41  using Vector = typename Base::Vector;
42 
50  static std::unique_ptr< KrylovSolver< VECTOR > > create( LinearSolverParameters const & parameters,
51  LinearOperator< VECTOR > const & matrix,
52  LinearOperator< VECTOR > const & precond );
53 
61  LinearOperator< Vector > const & matrix,
62  LinearOperator< Vector > const & precond );
63 
67  virtual ~KrylovSolver() override = default;
68 
74  virtual void solve( Vector const & b, Vector & x ) const = 0;
75 
76 
83  virtual void apply( Vector const & src, Vector & dst ) const override final
84  {
85  solve( src, dst );
86  }
87 
88  virtual globalIndex numGlobalRows() const override final
89  {
90  return m_operator.numGlobalRows();
91  }
92 
93  virtual globalIndex numGlobalCols() const override final
94  {
95  return m_operator.numGlobalCols();
96  }
97 
98  virtual localIndex numLocalRows() const override final
99  {
100  return m_operator.numLocalRows();
101  }
102 
103  virtual localIndex numLocalCols() const override final
104  {
105  return m_operator.numLocalCols();
106  }
107 
108  virtual MPI_Comm comm() const override final
109  {
110  return m_operator.comm();
111  }
112 
117  {
118  return m_params;
119  }
120 
124  LinearSolverResult const & result() const
125  {
126  return m_result;
127  }
128 
134  {
135  return m_residualNorms;
136  }
137 
142  virtual string methodName() const = 0;
143 
144 private:
145 
147 
148  template< typename VEC >
149  struct VectorStorageHelper
150  {
151  using type = VEC;
152 
153  static VEC createFrom( VEC const & src )
154  {
155  VEC v;
156  v.create( src.localSize(), src.comm() );
157  return v;
158  }
159  };
160 
161  template< typename VEC >
162  struct VectorStorageHelper< BlockVectorView< VEC > >
163  {
164  using type = BlockVector< VEC >;
165 
166  static BlockVector< VEC > createFrom( BlockVectorView< VEC > const & src )
167  {
168  BlockVector< VEC > v( src.blockSize() );
169  for( localIndex i = 0; i < src.blockSize(); ++i )
170  {
171  v.block( i ).create( src.block( i ).localSize(), src.block( i ).comm() );
172  }
173  return v;
174  }
175  };
176 
178 
179 protected:
180 
182  using VectorTemp = typename VectorStorageHelper< VECTOR >::type;
183 
191  static VectorTemp createTempVector( Vector const & src )
192  {
193  return VectorStorageHelper< VECTOR >::createFrom( src );
194  }
195 
200  void logProgress() const;
201 
205  void logResult() const;
206 
209 
212 
215 
218 
221 };
222 
223 } //namespace geos
224 
225 #endif //GEOS_LINEARALGEBRA_SOLVERS_KRYLOVSOLVER_HPP_
Concrete representation of a block vector.
Definition: BlockVector.hpp:37
Abstract view of a block vector.
VECTOR const & block(localIndex const blockIndex) const
Get a reference to the vector corresponding to block blockIndex.
localIndex blockSize() const
Get block size.
Base class for Krylov solvers.
LinearSolverResult m_result
results of a solve
virtual void solve(Vector const &b, Vector &x) const =0
Solve preconditioned system.
virtual globalIndex numGlobalCols() const override final
Get the number of global columns.
virtual localIndex numLocalRows() const override final
Get the number of local rows.
LinearOperator< Vector > const & m_operator
reference to the operator to be solved
virtual localIndex numLocalCols() const override final
Get the number of local columns.
static VectorTemp createTempVector(Vector const &src)
Helper function to create temporary vectors based on a source vector.
void logResult() const
Output convergence result (called by implementations).
typename Base::Vector Vector
Alias for template parameter.
LinearSolverParameters const & parameters() const
Get parameters of the solver.
typename VectorStorageHelper< VECTOR >::type VectorTemp
Alias for vector type that can be used for temporaries.
LinearOperator< Vector > const & m_precond
reference to the preconditioning operator
virtual ~KrylovSolver() override=default
Virtual destructor.
virtual void apply(Vector const &src, Vector &dst) const override final
Apply operator to a vector.
void logProgress() const
Output iteration progress (called by implementations).
virtual string methodName() const =0
Get name of the Krylov subspace method.
static std::unique_ptr< KrylovSolver< VECTOR > > create(LinearSolverParameters const &parameters, LinearOperator< VECTOR > const &matrix, LinearOperator< VECTOR > const &precond)
Factory method for instantiating Krylov solver objects.
LinearSolverResult const & result() const
Get result of a linear solve.
virtual MPI_Comm comm() const override final
Get the MPI communicator the matrix was created with.
arrayView1d< real64 const > history() const
Get convergence history of a linear solve.
virtual globalIndex numGlobalRows() const override final
Get the number of global rows.
KrylovSolver(LinearSolverParameters params, LinearOperator< Vector > const &matrix, LinearOperator< Vector > const &precond)
Constructor.
array1d< real64 > m_residualNorms
Absolute residual norms at each iteration (if available)
LinearSolverParameters m_params
parameters of the solver
Abstract base class for linear operators.
virtual globalIndex numGlobalCols() const =0
Get the number of global columns.
VECTOR Vector
Alias for template parameter.
virtual globalIndex numGlobalRows() const =0
Get the number of global rows.
virtual localIndex numLocalRows() const =0
Get the number of local rows.
virtual localIndex numLocalCols() const =0
Get the number of local columns.
virtual MPI_Comm comm() const =0
Get the MPI communicator the matrix was created with.
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:180
GEOS_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
Definition: DataTypes.hpp:88
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:85
Array< T, 1 > array1d
Alias for 1D array.
Definition: DataTypes.hpp:176
Set of parameters for a linear solver or preconditioner.
Results/stats of a linear solve.