GEOS
FunctionBase.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 Total, S.A
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_FUNCTIONS_FUNCTIONBASE_HPP_
21 #define GEOS_FUNCTIONS_FUNCTIONBASE_HPP_
22 
23 #include "common/DataTypes.hpp"
24 #include "common/TypeDispatch.hpp"
25 #include "dataRepository/Group.hpp"
26 #include "common/GEOS_RAJA_Interface.hpp"
27 
28 namespace geos
29 {
30 
31 namespace dataRepository
32 {
33 namespace keys
34 {
// Key string for a function's list of input variable names
// (the member holding those names is m_inputVarNames).
39 string const inputVarNames( "inputVarNames" );
40 }
41 }
42 
49 {
50 public:
51 
// Maximum total number of independent variables, counting every component
// of a multi-component input separately (enforced in evaluateT).
53  static constexpr int MAX_VARS = 4;
54 
56  FunctionBase( const string & name,
57  dataRepository::Group * const parent );
58 
62  virtual ~FunctionBase() override = default;
63 
// Function initialization hook; invoked from postInputInitialization() below.
67  virtual void initializeFunction() = 0;
68 
77 
// Evaluates the function on the entries of `group` selected by `set`,
// writing one value per set entry into `result`.
85  virtual void evaluate( dataRepository::Group const & group,
86  real64 const time,
88  arrayView1d< real64 > const & result ) const = 0;
89 
// Evaluates the function at a single point; evaluateT passes a buffer of
// MAX_VARS packed input values here.
95  virtual real64 evaluate( real64 const * const input ) const = 0;
96 
99 
// Body of getCatalog(): returns a reference to a function-local static catalog.
105  {
106  static CatalogInterface::CatalogType catalog;
107  return catalog;
108  }
109 
// Tail of evaluateStats( group, time, set ): generates statistics by applying
// the function to an object.
118  real64 const time,
119  SortedArray< localIndex > const & set ) const;
120 
// Takes the names by value and moves them into m_inputVarNames.
125  void setInputVarNames( string_array inputVarNames ) { m_inputVarNames = std::move( inputVarNames ); }
126 
131  static string const & getOutputDirectory();
132 
137  static void setOutputDirectory( string const & outputDir );
138 
139 protected:
142 
// Gathers the inputs named in m_inputVarNames (the name "time" maps to the
// `time` argument) and evaluates LEAF's point evaluate() for each set entry.
151  template< typename LEAF, typename POLICY = serialPolicy >
152  void evaluateT( dataRepository::Group const & group,
153  real64 const time,
155  arrayView1d< real64 > const & result ) const;
156 
// Group hook override; here it simply forwards to initializeFunction().
157  virtual void postInputInitialization() override { initializeFunction(); }
158 
159 };
160 
161 template< typename LEAF, typename POLICY >
// Out-of-class definition of FunctionBase::evaluateT. Full signature:
//   evaluateT( dataRepository::Group const & group, real64 const time,
//              SortedArrayView< localIndex const > const & set,
//              arrayView1d< real64 > const & result ) const
163  real64 const time,
165  arrayView1d< real64 > const & result ) const
166 {
// Per-input-variable metadata: base pointer, component count, and the
// (index, component) strides used to address one value.
167  real64 const * inputPtrs[MAX_VARS]{};
168  localIndex varSize[MAX_VARS]{};
169  localIndex varStride[MAX_VARS][2]{};
170 
171  integer const numVars = LvArray::integerConversion< integer >( m_inputVarNames.size() );
172  localIndex totalVarSize = 0;
// NOTE(review): varIndex indexes arrays of length MAX_VARS, but numVars is not
// checked against MAX_VARS before this loop. If more than MAX_VARS names are set,
// the writes below go out of bounds before the totalVarSize check fires.
// Suggest checking numVars <= MAX_VARS (GEOS_ERROR_IF_GT_MSG) ahead of the loop.
173  for( integer varIndex = 0; varIndex < numVars; ++varIndex )
174  {
175  string const & varName = m_inputVarNames[varIndex];
176 
// The special name "time" reads the scalar `time` argument; its strides stay
// zero, so every set entry reads the same value.
177  if( varName == "time" )
178  {
179  inputPtrs[varIndex] = &time;
180  varSize[varIndex] = 1;
181  }
182  else
183  {
184  dataRepository::WrapperBase const & wrapper = group.getWrapperBase( varName );
185  varSize[varIndex] = wrapper.numArrayComp();
186 
// Dispatch on the wrapper's concrete array type (the `Types` list is declared on a
// line not shown here). Move the data to host memory, then record the strides and
// the data pointer. The pointer stays valid after `view` goes out of scope because
// the storage is owned by the wrapper. The `false` argument to move() presumably
// requests read-only access -- verify.
188  types::dispatch( Types{}, [&]( auto tupleOfTypes )
189  {
190  using ArrayType = camp::first< decltype( tupleOfTypes ) >;
191  auto const view = dataRepository::Wrapper< ArrayType >::cast( wrapper ).reference().toViewConst();
192  view.move( hostMemorySpace, false );
193  for( int dim = 0; dim < ArrayType::NDIM; ++dim )
194  {
195  varStride[varIndex][dim] = view.strides()[dim];
196  }
197  inputPtrs[varIndex] = view.data();
198  }, wrapper );
199  }
200  totalVarSize += varSize[varIndex];
201  }
202 
203  // Make sure the inputs do not exceed the maximum length
204  GEOS_ERROR_IF_GT_MSG( totalVarSize, MAX_VARS,
205  getDataContext() << ": Function input size exceeded" );
206 
207  // Make sure the result / set size match
208  GEOS_ERROR_IF_NE_MSG( result.size(), set.size(),
209  getDataContext() << ": To apply a function to a set, the size of the result and set must match" );
210 
// For each set entry: pack all components of all inputs, in order, into a
// MAX_VARS buffer (bounded by the totalVarSize check above), then evaluate
// through the LEAF type. The C arrays above are captured by copy.
// NOTE(review): `[=]` also implicitly captures `this` (deprecated in C++20), and
// inputPtrs point at host data. A non-host POLICY would dereference host
// addresses, so confirm only host policies are used here.
211  forAll< POLICY >( set.size(), [=]( localIndex const i )
212  {
213  localIndex const index = set[i];
214  real64 input[MAX_VARS]{};
215  int offset = 0;
216  for( integer varIndex = 0; varIndex < numVars; ++varIndex )
217  {
218  for( localIndex compIndex = 0; compIndex < varSize[varIndex]; ++compIndex )
219  {
220  input[offset++] = inputPtrs[varIndex][index * varStride[varIndex][0] + compIndex * varStride[varIndex][1]];
221  }
222  }
223  result[i] = static_cast< LEAF const * >( this )->evaluate( input );
224  } );
225 }
226 } /* namespace geos */
227 
228 #endif /* GEOS_FUNCTIONS_FUNCTIONBASE_HPP_ */
string const inputVarNames("inputVarNames")
The key for inputVarNames.
#define GEOS_ERROR_IF_GT_MSG(lhs, rhs, msg)
Raise a hard error if one value compares greater than the other.
Definition: Logger.hpp:275
#define GEOS_ERROR_IF_NE_MSG(lhs, rhs, msg)
Raise a hard error if two values are not equal.
Definition: Logger.hpp:243
static string const & getOutputDirectory()
Get the output directory for function output.
void evaluateT(dataRepository::Group const &group, real64 const time, SortedArrayView< localIndex const > const &set, arrayView1d< real64 > const &result) const
Method to apply a function with an arbitrary type of output.
virtual ~FunctionBase() override=default
Destructor.
real64_array evaluateStats(dataRepository::Group const &group, real64 const time, SortedArray< localIndex > const &set) const
This generates statistics by applying a function to an object.
FunctionBase(const string &name, dataRepository::Group *const parent)
Constructor.
static void setOutputDirectory(string const &outputDir)
Set the output directory for function output.
static constexpr int MAX_VARS
Maximum total number of independent variables (including components of multidimensional variables)
static CatalogInterface::CatalogType & getCatalog()
Return the catalog entry for the function.
string_array m_inputVarNames
Names of the input variables.
virtual void initializeFunction()=0
Function initialization.
void setInputVarNames(string_array inputVarNames)
Set the input variable names.
virtual void postInputInitialization() override
integer isFunctionOfTime() const
Test to see if the function is a 1D function of time.
virtual real64 evaluate(real64 const *const input) const =0
Method to evaluate a function.
virtual void evaluate(dataRepository::Group const &group, real64 const time, SortedArrayView< localIndex const > const &set, arrayView1d< real64 > const &result) const =0
Method to evaluate a function on a target object.
This class provides the base class/interface for the catalog value objects.
std::unordered_map< std::string, std::unique_ptr< CatalogInterface< BASETYPE, ARGS... > > > CatalogType
This is the type that will be used for the catalog. The catalog is actually instantiated in the BASET...
DataContext const & getDataContext() const
Definition: Group.hpp:1343
WrapperBase const & getWrapperBase(KEY const &key) const
Return a reference to a WrapperBase stored in this group.
Definition: Group.hpp:1121
Base class for all wrappers containing common operations.
Definition: WrapperBase.hpp:56
virtual localIndex numArrayComp() const =0
Return the number of components in a multidimensional array.
static Wrapper & cast(WrapperBase &wrapper)
Downcast base to a typed wrapper.
Definition: Wrapper.hpp:221
T & reference()
Accessor for m_data.
Definition: Wrapper.hpp:581
DimsRange< 1, N > DimsUpTo
Generate a list of types representing array dimensionalities up to (and including) N.
bool dispatch(LIST const combinations, LAMBDA &&lambda, Ts &&... objects)
Dispatch a generic worker function lambda based on runtime type.
internal::Apply< camp::list, LIST > ListofTypeList
Construct a list of list type.
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:180
array1d< string > string_array
A 1-dimensional array of geos::string types.
Definition: DataTypes.hpp:392
std::set< T > set
A set of local indices.
Definition: DataTypes.hpp:263
double real64
64-bit floating point type.
Definition: DataTypes.hpp:99
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:85
std::int32_t integer
Signed integer type.
Definition: DataTypes.hpp:82
array1d< real64 > real64_array
A 1-dimensional array of geos::real64 types.
Definition: DataTypes.hpp:389
LvArray::SortedArray< T, localIndex, LvArray::ChaiBuffer > SortedArray
A sorted array of local indices.
Definition: DataTypes.hpp:267
LvArray::SortedArrayView< T, localIndex, LvArray::ChaiBuffer > SortedArrayView
A sorted array view of local indices.
Definition: DataTypes.hpp:271