GEOSX
KrylovSolver.hpp
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2018-2020 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2020 The Board of Trustees of the Leland Stanford Junior University
7  * Copyright (c) 2018-2020 TotalEnergies
8  * Copyright (c) 2019- GEOSX Contributors
9  * All rights reserved
10  *
11  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
12  * ------------------------------------------------------------------------------------------------------------
13  */
14 
15 #ifndef GEOS_LINEARALGEBRA_SOLVERS_KRYLOVSOLVER_HPP_
16 #define GEOS_LINEARALGEBRA_SOLVERS_KRYLOVSOLVER_HPP_
17 
23 
24 namespace geos
25 {
26 
31 template< typename VECTOR >
32 class KrylovSolver : public LinearOperator< VECTOR >
33 {
34 public:
35 
38 
40  using Vector = typename Base::Vector;
41 
49  static std::unique_ptr< KrylovSolver< VECTOR > > create( LinearSolverParameters const & parameters,
50  LinearOperator< VECTOR > const & matrix,
51  LinearOperator< VECTOR > const & precond );
52 
60  LinearOperator< Vector > const & matrix,
61  LinearOperator< Vector > const & precond );
62 
66  virtual ~KrylovSolver() override = default;
67 
73  virtual void solve( Vector const & b, Vector & x ) const = 0;
74 
75 
82  virtual void apply( Vector const & src, Vector & dst ) const override final
83  {
84  solve( src, dst );
85  }
86 
87  virtual globalIndex numGlobalRows() const override final
88  {
89  return m_operator.numGlobalRows();
90  }
91 
92  virtual globalIndex numGlobalCols() const override final
93  {
94  return m_operator.numGlobalCols();
95  }
96 
97  virtual localIndex numLocalRows() const override final
98  {
99  return m_operator.numLocalRows();
100  }
101 
102  virtual localIndex numLocalCols() const override final
103  {
104  return m_operator.numLocalCols();
105  }
106 
107  virtual MPI_Comm comm() const override final
108  {
109  return m_operator.comm();
110  }
111 
116  {
117  return m_params;
118  }
119 
123  LinearSolverResult const & result() const
124  {
125  return m_result;
126  }
127 
133  {
134  return m_residualNorms;
135  }
136 
141  virtual string methodName() const = 0;
142 
143 private:
144 
146 
147  template< typename VEC >
148  struct VectorStorageHelper
149  {
150  using type = VEC;
151 
152  static VEC createFrom( VEC const & src )
153  {
154  VEC v;
155  v.create( src.localSize(), src.comm() );
156  return v;
157  }
158  };
159 
160  template< typename VEC >
161  struct VectorStorageHelper< BlockVectorView< VEC > >
162  {
163  using type = BlockVector< VEC >;
164 
165  static BlockVector< VEC > createFrom( BlockVectorView< VEC > const & src )
166  {
167  BlockVector< VEC > v( src.blockSize() );
168  for( localIndex i = 0; i < src.blockSize(); ++i )
169  {
170  v.block( i ).create( src.block( i ).localSize(), src.block( i ).comm() );
171  }
172  return v;
173  }
174  };
175 
177 
178 protected:
179 
181  using VectorTemp = typename VectorStorageHelper< VECTOR >::type;
182 
190  static VectorTemp createTempVector( Vector const & src )
191  {
192  return VectorStorageHelper< VECTOR >::createFrom( src );
193  }
194 
199  void logProgress() const;
200 
204  void logResult() const;
205 
208 
211 
214 
217 
220 };
221 
222 } //namespace geos
223 
224 #endif //GEOS_LINEARALGEBRA_SOLVERS_KRYLOVSOLVER_HPP_
Concrete representation of a block vector.
Definition: BlockVector.hpp:36
Abstract view of a block vector.
VECTOR const & block(localIndex const blockIndex) const
Get a reference to the vector corresponding to block blockIndex.
localIndex blockSize() const
Get block size.
Base class for Krylov solvers.
LinearSolverResult m_result
results of a solve
virtual void solve(Vector const &b, Vector &x) const =0
Solve preconditioned system.
virtual globalIndex numGlobalCols() const override final
Get the number of global columns.
virtual localIndex numLocalRows() const override final
Get the number of local rows.
LinearOperator< Vector > const & m_operator
reference to the operator to be solved
virtual localIndex numLocalCols() const override final
Get the number of local columns.
static VectorTemp createTempVector(Vector const &src)
Helper function to create temporary vectors based on a source vector.
void logResult() const
Output convergence result (called by implementations).
typename Base::Vector Vector
Alias for template parameter.
LinearSolverParameters const & parameters() const
Get the parameters of the solver.
typename VectorStorageHelper< VECTOR >::type VectorTemp
Alias for vector type that can be used for temporaries.
LinearOperator< Vector > const & m_precond
reference to the preconditioning operator
virtual ~KrylovSolver() override=default
Virtual destructor.
virtual void apply(Vector const &src, Vector &dst) const override final
Apply operator to a vector.
void logProgress() const
Output iteration progress (called by implementations).
virtual string methodName() const =0
Get name of the Krylov subspace method.
static std::unique_ptr< KrylovSolver< VECTOR > > create(LinearSolverParameters const &parameters, LinearOperator< VECTOR > const &matrix, LinearOperator< VECTOR > const &precond)
Factory method for instantiating Krylov solver objects.
LinearSolverResult const & result() const
Get the results of a solve.
virtual MPI_Comm comm() const override final
Get the MPI communicator the matrix was created with.
arrayView1d< real64 const > history() const
Get convergence history of a linear solve.
virtual globalIndex numGlobalRows() const override final
Get the number of global rows.
KrylovSolver(LinearSolverParameters params, LinearOperator< Vector > const &matrix, LinearOperator< Vector > const &precond)
Constructor.
array1d< real64 > m_residualNorms
Absolute residual norms at each iteration (if available)
LinearSolverParameters m_params
parameters of the solver
Abstract base class for linear operators.
virtual globalIndex numGlobalCols() const =0
Get the number of global columns.
VECTOR Vector
Alias for template parameter.
virtual globalIndex numGlobalRows() const =0
Get the number of global rows.
virtual localIndex numLocalRows() const =0
Get the number of local rows.
virtual localIndex numLocalCols() const =0
Get the number of local columns.
virtual MPI_Comm comm() const =0
Get the MPI communicator the matrix was created with.
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:220
GEOSX_GLOBALINDEX_TYPE globalIndex
Global index type (for indexing objects across MPI partitions).
Definition: DataTypes.hpp:128
GEOSX_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:125
Array< T, 1 > array1d
Alias for 1D array.
Definition: DataTypes.hpp:216
Set of parameters for a linear solver or preconditioner.
Results/stats of a linear solve.