GEOS
TableFunction.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_FUNCTIONS_TABLEFUNCTION_HPP_
21 #define GEOS_FUNCTIONS_TABLEFUNCTION_HPP_
22 
23 #include "FunctionBase.hpp"
24 
26 #include "LvArray/src/tensorOps.hpp"
28 #include "common/Units.hpp"
29 
30 namespace geos
31 {
32 
39 {
40 public:
41 
44  {
45  Linear,
46  Nearest,
47  Upper,
48  Lower
49  };
50 
53  {
55  bool writeCSV;
57  bool writeInLog;
58  };
59 
61  static constexpr integer maxDimensions = 4;
62 
69  {
70 public:
71 
75 
76  KernelWrapper() = default;
77  KernelWrapper( KernelWrapper const & ) = default;
78  KernelWrapper( KernelWrapper && ) = default;
79  KernelWrapper & operator=( KernelWrapper const & ) = default;
80 
84  {
85  m_coordinates = std::move( other.m_coordinates );
86  m_values = std::move( other.m_values );
87  m_interpolationMethod = other.m_interpolationMethod;
88  return *this;
89  }
90 
92 
99  template< typename IN_ARRAY >
101  real64 compute( IN_ARRAY const & input ) const;
102 
109  template< typename IN_ARRAY, typename OUT_ARRAY >
111  real64 compute( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const;
112 
120  void move( LvArray::MemorySpace const space, bool const touch )
121  {
122  m_coordinates.move( space, touch );
123  m_values.move( space, touch );
124  }
125 
126 private:
127 
128  friend class TableFunction; // Allow only parent class to construct the wrapper
129 
141  KernelWrapper( InterpolationType interpolationMethod,
142  ArrayOfArraysView< real64 const > const & coordinates,
143  arrayView1d< real64 const > const & values );
144 
150  template< typename IN_ARRAY >
152  real64
153  interpolateLinear( IN_ARRAY const & input ) const;
154 
161  template< typename IN_ARRAY, typename OUT_ARRAY >
163  real64
164  interpolateLinear( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const;
165 
171  template< typename IN_ARRAY >
173  real64
174  interpolateRound( IN_ARRAY const & input ) const;
175 
183  template< typename IN_ARRAY >
185  real64
186  getCoord( IN_ARRAY const & input, localIndex dim, InterpolationType interpolationMethod ) const;
187 
194  template< typename IN_ARRAY, typename OUT_ARRAY >
196  real64
197  interpolateRound( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const;
198 
200  TableFunction::InterpolationType m_interpolationMethod = InterpolationType::Linear;
201 
203  ArrayOfArraysView< real64 const > m_coordinates;
204 
207  };
208 
214  TableFunction( const string & name,
215  dataRepository::Group * const parent );
216 
221  static string catalogName() { return "TableFunction"; }
222 
226  virtual void initializeFunction() override;
227 
232 
233  void initializePostSubGroups() override;
234 
242  virtual void evaluate( dataRepository::Group const & group,
243  real64 const time,
245  arrayView1d< real64 > const & result ) const override final
246  {
247  FunctionBase::evaluateT< TableFunction, parallelHostPolicy >( group, time, set, result );
248  }
249 
255  virtual real64 evaluate( real64 const * const input ) const override final;
256 
264  real64 getCoord( real64 const * const input, localIndex dim, InterpolationType interpolationMethod ) const;
265 
273  void checkCoord( real64 coord, localIndex dim ) const;
274 
278  integer numDimensions() const { return LvArray::integerConversion< integer >( m_coordinates.size() ); }
279 
284  ArrayOfArraysView< real64 const > getCoordinates() const { return m_coordinates.toViewConst(); }
285 
289  ArrayOfArraysView< real64 > getCoordinates() { return m_coordinates.toView(); }
290 
295  arrayView1d< real64 const > getValues() const { return m_values.toViewConst(); }
296 
300  array1d< real64 > & getValues() { return m_values; }
301 
306  InterpolationType getInterpolationMethod() const { return m_interpolationMethod; }
307 
312  units::Unit getDimUnit( localIndex const dim ) const
313  {
314  return size_t(dim) < m_dimUnits.size() ? m_dimUnits[dim] : units::Unknown;
315  }
316 
322 
328  void setTableCoordinates( array1d< real64_array > const & coordinates,
329  std::vector< units::Unit > const & dimUnits = {} );
330 
335  void setDimUnits( std::vector< units::Unit > const & dimUnits )
336  {
337  m_dimUnits = dimUnits;
338  }
339 
345  void setTableValues( real64_array values, units::Unit unit = units::Unknown );
346 
352  {
353  m_valueUnit = unit;
354  }
355 
359  units::Unit getValueUnit() const { return m_valueUnit; }
360 
365  string getTableDescription() const;
366 
373  string getCoordsDescription( integer dimId, bool shortUnitsToVariables ) const;
374 
379  string getValuesDescription() const;
380 
385  void outputTableData( OutputOptions const outputOpts ) const;
386 
392 
395  {
397  static constexpr char const * coordinatesString() { return "coordinates"; }
399  static constexpr char const * valuesString() { return "values"; }
401  static constexpr char const * interpolationString() { return "interpolation"; }
403  static constexpr char const * coordinateFilesString() { return "coordinateFiles"; }
405  static constexpr char const * voxelFileString() { return "voxelFile"; }
407  static constexpr char const * writeCSVFlagString() { return "writeCSV"; }
408  };
409 
410 private:
411 
418  void readFile( string const & filename, array1d< real64 > & target );
419 
420 
422  array1d< real64 > m_tableCoordinates1D;
423 
425  path_array m_coordinateFiles;
426 
428  Path m_voxelFile;
429 
431  InterpolationType m_interpolationMethod;
432 
434  ArrayOfArrays< real64 > m_coordinates;
435 
437  array1d< real64 > m_values;
438 
440  std::vector< units::Unit > m_dimUnits;
441 
443  units::Unit m_valueUnit;
444 
446  KernelWrapper m_kernelWrapper;
447 
449  integer m_writeCSV;
450 };
452 template< typename IN_ARRAY >
455 real64
456 TableFunction::KernelWrapper::compute( IN_ARRAY const & input ) const
457 {
458  if( m_interpolationMethod == TableFunction::InterpolationType::Linear )
459  {
460  return interpolateLinear( input );
461  }
462  else // Nearest, Upper, Lower interpolation methods
463  {
464  return interpolateRound( input );
465  }
466 }
467 
468 template< typename IN_ARRAY >
471 real64
472 TableFunction::KernelWrapper::interpolateLinear( IN_ARRAY const & input ) const
473 {
474  integer const numDimensions = LvArray::integerConversion< integer >( m_coordinates.size() );
475  localIndex bounds[maxDimensions][2]{};
476  real64 weights[maxDimensions][2]{};
477 
478  // Determine position, weights
479  for( localIndex dim = 0; dim < numDimensions; ++dim )
480  {
481  arraySlice1d< real64 const > const coords = m_coordinates[dim];
482  if( input[dim] <= coords[0] )
483  {
484  // Coordinate is to the left of this axis
485  bounds[dim][0] = 0;
486  bounds[dim][1] = 0;
487  weights[dim][0] = 0.0;
488  weights[dim][1] = 1.0;
489  }
490  else if( input[dim] >= coords[coords.size() - 1] )
491  {
492  // Coordinate is to the right of this axis
493  bounds[dim][0] = coords.size() - 1;
494  bounds[dim][1] = bounds[dim][0];
495  weights[dim][0] = 1.0;
496  weights[dim][1] = 0.0;
497  }
498  else
499  {
500  // Find the coordinate index
502  // Sergey's note: find uses a binary search... If we assume coordinates are
503  // evenly spaced, we can speed things up considerably
504  // Mel's note: As we cannot be sure coords are evenly spaced,
505  // - Either we insert coords to get even spacing ( /!\ memory consumption ),
506  // - Or we can use an interpolation search with an hint array which would be linearly interpolated ( benchmark ).
507  auto const lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
508  bounds[dim][1] = LvArray::integerConversion< localIndex >( lower );
509  bounds[dim][0] = bounds[dim][1] - 1;
510 
511  real64 const dx = coords[bounds[dim][1]] - coords[bounds[dim][0]];
512  weights[dim][0] = 1.0 - ( input[dim] - coords[bounds[dim][0]]) / dx;
513  weights[dim][1] = 1.0 - weights[dim][0];
514  }
515  }
516 
517  // Calculate the result
518  real64 value = 0.0;
519  integer const numCorners = 1 << numDimensions;
520  for( integer point = 0; point < numCorners; ++point )
521  {
522  // Find array index
523  localIndex tableIndex = 0;
524  localIndex stride = 1;
525  for( integer dim = 0; dim < numDimensions; ++dim )
526  {
527  integer const corner = (point >> dim) & 1;
528  tableIndex += bounds[dim][corner] * stride;
529  stride *= m_coordinates.sizeOfArray( dim );
530  }
531 
532  // Determine weighted value
533  real64 cornerValue = m_values[tableIndex];
534  for( integer dim = 0; dim < numDimensions; ++dim )
535  {
536  integer const corner = (point >> dim) & 1;
537  cornerValue *= weights[dim][corner];
538  }
539  value += cornerValue;
540  }
541  return value;
542 }
543 
544 template< typename IN_ARRAY >
547 real64
548 TableFunction::KernelWrapper::interpolateRound( IN_ARRAY const & input ) const
549 {
550  integer const numDimensions = LvArray::integerConversion< integer >( m_coordinates.size() );
551 
552  // Determine the index to the nearest table entry
553  localIndex tableIndex = 0;
554  localIndex stride = 1;
555  for( integer dim = 0; dim < numDimensions; ++dim )
556  {
557  arraySlice1d< real64 const > const coords = m_coordinates[dim];
558  // Determine the index along each table axis
559  localIndex subIndex;
560  if( input[dim] <= coords[0] )
561  {
562  // Coordinate is to the left of the table axis
563  subIndex = 0;
564  }
565  else if( input[dim] >= coords[coords.size() - 1] )
566  {
567  // Coordinate is to the right of the table axis
568  subIndex = coords.size() - 1;
569  }
570  else
571  {
572  // Coordinate is within the table axis
573  // Note: find() will return the index of the upper table vertex
574  auto const lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
575  subIndex = LvArray::integerConversion< localIndex >( lower );
576 
577  // Interpolation types:
578  // - Nearest returns the value of the closest table vertex
579  // - Upper returns the value of the next table vertex
580  // - Lower returns the value of the previous table vertex
581  if( m_interpolationMethod == TableFunction::InterpolationType::Nearest )
582  {
583  if( ( input[dim] - coords[subIndex - 1]) <= ( coords[subIndex] - input[dim]) )
584  {
585  --subIndex;
586  }
587  }
588  else if( m_interpolationMethod == TableFunction::InterpolationType::Lower )
589  {
590  if( subIndex > 0 )
591  {
592  --subIndex;
593  }
594  }
595  }
596 
597  // Increment the global table index
598  tableIndex += subIndex * stride;
599  stride *= coords.size();
600  }
601 
602  // Retrieve the nearest value
603  return m_values[tableIndex];
604 }
605 
606 template< typename IN_ARRAY >
609 real64
610 TableFunction::KernelWrapper::getCoord( IN_ARRAY const & input, localIndex const dim, InterpolationType interpolationMethod ) const
611 {
612  // Determine the index to the nearest table entry
613  localIndex subIndex;
614  arraySlice1d< real64 const > const coords = m_coordinates[dim];
615  // Determine the index along each table axis
616  if( input[dim] <= coords[0] )
617  {
618  // Coordinate is to the left of the table axis
619  subIndex = 0;
620  }
621  else if( input[dim] >= coords[coords.size() - 1] )
622  {
623  // Coordinate is to the right of the table axis
624  subIndex = coords.size() - 1;
625  }
626  else
627  {
628  // Coordinate is within the table axis
629  // Note: find() will return the index of the upper table vertex
630  auto const lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
631  subIndex = LvArray::integerConversion< localIndex >( lower );
632 
633  // Interpolation types:
634  // - Nearest returns the value of the closest table vertex
635  // - Upper returns the value of the next table vertex
636  // - Lower returns the value of the previous table vertex
637  if( interpolationMethod == TableFunction::InterpolationType::Nearest )
638  {
639  if( ( input[dim] - coords[subIndex - 1]) <= ( coords[subIndex] - input[dim]) )
640  {
641  --subIndex;
642  }
643  }
644  else if( interpolationMethod == TableFunction::InterpolationType::Lower )
645  {
646  if( subIndex > 0 )
647  {
648  --subIndex;
649  }
650  }
651  }
652 
653  // Retrieve the nearest coordinate
654  return coords[subIndex];
655 }
656 
657 template< typename IN_ARRAY, typename OUT_ARRAY >
660 real64
661 TableFunction::KernelWrapper::compute( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const
662 {
663  // Linear interpolation
664  if( m_interpolationMethod == TableFunction::InterpolationType::Linear )
665  {
666  return interpolateLinear( input, derivatives );
667  }
668  // Nearest, Upper, Lower interpolation methods
669  else
670  {
671  return interpolateRound( input, derivatives );
672  }
673 }
674 
675 template< typename IN_ARRAY, typename OUT_ARRAY >
678 real64
679 TableFunction::KernelWrapper::interpolateLinear( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const
680 {
681  integer const numDimensions = LvArray::integerConversion< integer >( m_coordinates.size() );
682 
683  localIndex bounds[maxDimensions][2]{};
684  real64 weights[maxDimensions][2]{};
685  real64 dWeights_dInput[maxDimensions][2]{};
686 
687  // Determine position, weights
688  for( integer dim = 0; dim < numDimensions; ++dim )
689  {
690  arraySlice1d< real64 const > const coords = m_coordinates[dim];
691  if( input[dim] <= coords[0] )
692  {
693  // Coordinate is to the left of this axis
694  bounds[dim][0] = 0;
695  bounds[dim][1] = 0;
696  weights[dim][0] = 0;
697  weights[dim][1] = 1;
698  dWeights_dInput[dim][0] = 0;
699  dWeights_dInput[dim][1] = 0;
700  }
701  else if( input[dim] >= coords[coords.size() - 1] )
702  {
703  // Coordinate is to the right of this axis
704  bounds[dim][0] = coords.size() - 1;
705  bounds[dim][1] = bounds[dim][0];
706  weights[dim][0] = 1;
707  weights[dim][1] = 0;
708  dWeights_dInput[dim][0] = 0;
709  dWeights_dInput[dim][1] = 0;
710  }
711  else
712  {
713  // Find the coordinate index
715  // Note: find uses a binary search... If we assume coordinates are
716  // evenly spaced, we can speed things up considerably
717  auto lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
718  bounds[dim][1] = LvArray::integerConversion< localIndex >( lower );
719  bounds[dim][0] = bounds[dim][1] - 1;
720 
721  real64 const dx = coords[bounds[dim][1]] - coords[bounds[dim][0]];
722  weights[dim][0] = 1.0 - ( input[dim] - coords[bounds[dim][0]]) / dx;
723  weights[dim][1] = 1.0 - weights[dim][0];
724  dWeights_dInput[dim][0] = -1.0 / dx;
725  dWeights_dInput[dim][1] = -dWeights_dInput[dim][0];
726  }
727  }
728 
729  // Calculate the result
730  real64 value = 0.0;
731  for( integer dim = 0; dim < numDimensions; ++dim )
732  {
733  derivatives[dim] = 0.0;
734  }
735 
736  integer const numCorners = 1 << numDimensions;
737  for( integer point = 0; point < numCorners; ++point )
738  {
739  // Find array index
740  localIndex tableIndex = 0;
741  localIndex stride = 1;
742  for( integer dim = 0; dim < numDimensions; ++dim )
743  {
744  integer const corner = (point >> dim) & 1;
745  tableIndex += bounds[dim][corner] * stride;
746  stride *= m_coordinates.sizeOfArray( dim );
747  }
748 
749  // Determine weighted value
750  real64 cornerValue = m_values[tableIndex];
751  real64 dCornerValue_dInput[maxDimensions]{};
752  for( integer dim = 0; dim < numDimensions; ++dim )
753  {
754  dCornerValue_dInput[dim] = cornerValue;
755  }
756 
757  for( integer dim = 0; dim < numDimensions; ++dim )
758  {
759  integer const corner = (point >> dim) & 1;
760  cornerValue *= weights[dim][corner];
761  for( integer kk = 0; kk < numDimensions; ++kk )
762  {
763  dCornerValue_dInput[kk] *= ( dim == kk ) ? dWeights_dInput[dim][corner] : weights[dim][corner];
764  }
765  }
766 
767  for( integer dim = 0; dim < numDimensions; ++dim )
768  {
769  derivatives[dim] += dCornerValue_dInput[dim];
770  }
771  value += cornerValue;
772  }
773  return value;
774 }
775 
776 template< typename IN_ARRAY, typename OUT_ARRAY >
779 real64
780 TableFunction::KernelWrapper::interpolateRound( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const
781 {
782  GEOS_UNUSED_VAR( input, derivatives );
783  GEOS_ERROR( "Rounding interpolation with derivatives not implemented" );
784  return 0.0;
785 }
786 
788 
791  "linear",
792  "nearest",
793  "upper",
794  "lower" );
795 
801 template<>
802 string TableTextFormatter::toString< TableFunction >( TableFunction const & tableData ) const;
803 
809 template<>
810 string TableCSVFormatter::toString< TableFunction >( TableFunction const & tableData ) const;
811 
812 } /* namespace geos */
813 
814 #endif /* GEOS_FUNCTIONS_TABLEFUNCTION_HPP_ */
#define GEOS_HOST_DEVICE
Marks a host-device function.
Definition: GeosxMacros.hpp:49
#define GEOS_UNUSED_VAR(...)
Mark an unused variable and silence compiler warnings.
Definition: GeosxMacros.hpp:84
#define GEOS_FORCE_INLINE
Marks a function or lambda for inlining.
Definition: GeosxMacros.hpp:51
#define GEOS_ERROR(msg)
Raise a hard error and terminate the program.
Definition: Logger.hpp:157
Enumerates the Units that are in use in GEOS and regroups useful conversion and formatting functions.
Unit
Enumerator of available unit types for given physical scales. Units are in SI by default.
Definition: Units.hpp:59
Class describing a file Path.
Definition: Path.hpp:33
GEOS_HOST_DEVICE real64 compute(IN_ARRAY const &input, OUT_ARRAY &&derivatives) const
Interpolate in the table with derivatives.
void move(LvArray::MemorySpace const space, bool const touch)
Move the KernelWrapper to the given execution space, optionally touching it.
GEOS_HOST_DEVICE real64 compute(IN_ARRAY const &input) const
Interpolate in the table.
string getCoordsDescription(integer dimId, bool shortUnitsToVariables) const
InterpolationType
Enumerator of available interpolation types.
void setTableValues(real64_array values, units::Unit unit=units::Unknown)
Set the table values.
array1d< real64 > & getValues()
Get the table values.
arrayView1d< real64 const > getValues() const
Get the table values.
void checkCoord(real64 coord, localIndex dim) const
Check if the given coordinate is in the bounds of the table coordinates in the specified dimension,...
virtual void evaluate(dataRepository::Group const &group, real64 const time, SortedArrayView< localIndex const > const &set, arrayView1d< real64 > const &result) const override final
Method to evaluate a function on a target object.
ArrayOfArraysView< real64 > getCoordinates()
Get the table axes definitions.
real64 getCoord(real64 const *const input, localIndex dim, InterpolationType interpolationMethod) const
Method to get coordinates.
void setTableCoordinates(array1d< real64_array > const &coordinates, std::vector< units::Unit > const &dimUnits={})
Set the table coordinates.
void setValueUnits(units::Unit unit)
Set the table value units.
void setDimUnits(std::vector< units::Unit > const &dimUnits)
Set the units of each dimension.
string getTableDescription() const
void reInitializeFunction()
Build the maps used to evaluate the table function.
KernelWrapper createKernelWrapper() const
Create an instance of the kernel wrapper.
void initializePostSubGroups() override
Called by Initialize() after initializing sub-Groups.
TableFunction(const string &name, dataRepository::Group *const parent)
The constructor.
units::Unit getValueUnit() const
void setInterpolationMethod(InterpolationType const method)
Set the interpolation method.
static constexpr integer maxDimensions
maximum dimensions for the coordinates in the table
virtual real64 evaluate(real64 const *const input) const override final
Method to evaluate a function.
static string catalogName()
The catalog name interface.
virtual void initializeFunction() override
Initialize the table function.
ArrayOfArraysView< real64 const > getCoordinates() const
Get the table axes definitions.
integer numDimensions() const
units::Unit getDimUnit(localIndex const dim) const
InterpolationType getInterpolationMethod() const
Get the interpolation method.
string getValuesDescription() const
void outputTableData(OutputOptions const outputOpts) const
Print the table(s) in the log and/or CSV files when requested by the user.
Group & operator=(Group const &)=delete
Deleted copy assignment operator.
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:180
ENUM_STRINGS(LinearSolverParameters::SolverType, "direct", "cg", "gmres", "fgmres", "bicgstab", "preconditioner")
Declare strings associated with enumeration values.
LvArray::ArrayOfArraysView< T, INDEX_TYPE const, CONST_SIZES, LvArray::ChaiBuffer > ArrayOfArraysView
View of array of variable-sized arrays. See LvArray::ArrayOfArraysView for details.
Definition: DataTypes.hpp:286
array1d< Path > path_array
A 1-dimensional array of geos::Path types.
Definition: DataTypes.hpp:396
std::set< T > set
A set of local indices.
Definition: DataTypes.hpp:263
double real64
64-bit floating point type.
Definition: DataTypes.hpp:99
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:85
std::int32_t integer
Signed integer type.
Definition: DataTypes.hpp:82
array1d< real64 > real64_array
A 1-dimensional array of geos::real64 types.
Definition: DataTypes.hpp:389
std::size_t size_t
Unsigned size type.
Definition: DataTypes.hpp:79
LvArray::SortedArrayView< T, localIndex, LvArray::ChaiBuffer > SortedArrayView
A sorted array view of local indices.
Definition: DataTypes.hpp:271
Array< T, 1 > array1d
Alias for 1D array.
Definition: DataTypes.hpp:176
LvArray::ArrayOfArrays< T, INDEX_TYPE, LvArray::ChaiBuffer > ArrayOfArrays
Array of variable-sized arrays. See LvArray::ArrayOfArrays for details.
Definition: DataTypes.hpp:282
Struct containing output options.
bool writeInLog
Request table output in log.
bool writeCSV
Request table output in CSV file.
Struct containing lookup keys for data repository wrappers.
static constexpr char const * coordinatesString()
static constexpr char const * interpolationString()
static constexpr char const * valuesString()
static constexpr char const * voxelFileString()
static constexpr char const * coordinateFilesString()
static constexpr char const * writeCSVFlagString()