GEOS
TableFunction.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_FUNCTIONS_TABLEFUNCTION_HPP_
21 #define GEOS_FUNCTIONS_TABLEFUNCTION_HPP_
22 
23 #include "FunctionBase.hpp"
24 
26 #include "LvArray/src/tensorOps.hpp"
28 #include "common/Units.hpp"
29 
30 namespace geos
31 {
32 
39 {
40 public:
41
44  {
45  Linear, ///< Multilinear interpolation between the vertices bracketing the input; inputs outside an axis are clamped to its end vertex
46  Nearest, ///< Value of the closest table vertex (an exact midpoint tie picks the lower vertex)
47  Upper, ///< Value of the next (upper) table vertex
48  Lower ///< Value of the previous (lower) table vertex
49  };
50
53  {
55  bool writeCSV; ///< Request table output in CSV file
57  bool writeInLog; ///< Request table output in log
58  };
59
61  static constexpr integer maxDimensions = 4; ///< Maximum number of dimensions (coordinate axes) supported by the table
62
69  {
70 public:
71
75
76  KernelWrapper() = default;
77  KernelWrapper( KernelWrapper const & ) = default;
78  KernelWrapper( KernelWrapper && ) = default;
79  KernelWrapper & operator=( KernelWrapper const & ) = default;
80
84  { // move assignment: takes over the coordinate and value views of `other` and copies its interpolation method
85  m_coordinates = std::move( other.m_coordinates );
86  m_values = std::move( other.m_values );
87  m_interpolationMethod = other.m_interpolationMethod;
88  return *this;
89  }
90
92
99  template< typename IN_ARRAY >
101  real64 compute( IN_ARRAY const & input ) const; ///< Interpolate in the table at @p input, dispatching on the interpolation method
102
109  template< typename IN_ARRAY, typename OUT_ARRAY >
111  real64 compute( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const; ///< Interpolate in the table and fill @p derivatives with the partial derivative w.r.t. each input coordinate
112
120  void move( LvArray::MemorySpace const space, bool const touch ) ///< Move the coordinate and value views to the given memory space, optionally touching them
121  {
122  m_coordinates.move( space, touch );
123  m_values.move( space, touch );
124  }
125
126 private:
127
128  friend class TableFunction; // Allow only parent class to construct the wrapper
129
141  KernelWrapper( InterpolationType interpolationMethod,
142  ArrayOfArraysView< real64 const > const & coordinates,
143  arrayView1d< real64 const > const & values ); ///< Build a wrapper from an interpolation method, the table axes and the table values (private: only TableFunction may call it)
144
150  template< typename IN_ARRAY >
152  real64
153  interpolateLinear( IN_ARRAY const & input ) const; ///< Multilinear interpolation without derivatives
154
161  template< typename IN_ARRAY, typename OUT_ARRAY >
163  real64
164  interpolateLinear( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const; ///< Multilinear interpolation with derivatives
165
171  template< typename IN_ARRAY >
173  real64
174  interpolateRound( IN_ARRAY const & input ) const; ///< Nearest/Upper/Lower lookup of a single table vertex
175
183  template< typename IN_ARRAY >
185  real64
186  getCoord( IN_ARRAY const & input, localIndex dim, InterpolationType interpolationMethod ) const; ///< Coordinate of the vertex selected along @p dim by @p interpolationMethod
187
194  template< typename IN_ARRAY, typename OUT_ARRAY >
196  real64
197  interpolateRound( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const; ///< Not implemented: raises a hard error
198
200  TableFunction::InterpolationType m_interpolationMethod = InterpolationType::Linear; ///< Interpolation method used by compute(); defaults to Linear
201
203  ArrayOfArraysView< real64 const > m_coordinates; ///< Table axes, one coordinate array per dimension (binary-searched, so assumed sorted)
204
207  };
208
214  TableFunction( const string & name,
215  dataRepository::Group * const parent ); ///< The constructor
216
221  static string catalogName() { return "TableFunction"; } ///< The catalog name interface
222
226  virtual void initializeFunction() override; ///< Initialize the table function
227
232
233  void initializePostSubGroups() override; ///< Called by Initialize() after initializing sub-Groups
234
242  virtual void evaluate( dataRepository::Group const & group,
243  real64 const time,
245  arrayView1d< real64 > const & result ) const override final ///< Evaluate the function on a target object (forwards to FunctionBase::evaluateT)
246  {
247  FunctionBase::evaluateT< TableFunction, parallelHostPolicy >( group, time, set, result );
248  }
249
255  virtual real64 evaluate( real64 const * const input ) const override final; ///< Evaluate the function at a single input point
256
264  real64 getCoord( real64 const * const input, localIndex dim, InterpolationType interpolationMethod ) const; ///< Method to get coordinates
265
273  void checkCoord( real64 coord, localIndex dim ) const; ///< Check if @p coord is in the bounds of the table coordinates along @p dim
274
278  integer numDimensions() const { return LvArray::integerConversion< integer >( m_coordinates.size() ); } ///< Number of coordinate axes of the table
279
284  ArrayOfArraysView< real64 const > getCoordinates() const { return m_coordinates.toViewConst(); } ///< Read-only view of the table axes
285
289  ArrayOfArraysView< real64 > getCoordinates() { return m_coordinates.toView(); } ///< Mutable view of the table axes
290
295  arrayView1d< real64 const > getValues() const { return m_values.toViewConst(); } ///< Read-only view of the table values
296
300  array1d< real64 > & getValues() { return m_values; } ///< Mutable access to the table values
301
306  InterpolationType getInterpolationMethod() const { return m_interpolationMethod; } ///< Get the interpolation method
307
312  units::Unit getDimUnit( localIndex const dim ) const ///< Unit of dimension @p dim, or units::Unknown when no unit was set for it
313  {
314  return size_t(dim) < m_dimUnits.size() ? m_dimUnits[dim] : units::Unknown;
315  }
316
322
328  void setTableCoordinates( array1d< real64_array > const & coordinates,
329  stdVector< units::Unit > const & dimUnits = {} ); ///< Set the table coordinates and, optionally, the unit of each dimension
330
335  void setDimUnits( stdVector< units::Unit > const & dimUnits ) ///< Set the units of each dimension
336  {
337  m_dimUnits = dimUnits;
338  }
339
345  void setTableValues( real64_array values, units::Unit unit = units::Unknown ); ///< Set the table values and their unit
346
352  { // setValueUnits: store the unit of the table values
353  m_valueUnit = unit;
354  }
355
359  units::Unit getValueUnit() const { return m_valueUnit; } ///< Unit of the table values
360
365  string getTableDescription() const;
366
373  string getCoordsDescription( integer dimId, bool shortUnitsToVariables ) const;
374
379  string getValuesDescription() const;
380
385  void outputTableData( OutputOptions const outputOpts ) const; ///< Print the table(s) in the log and/or CSV files when requested by the user
386
392
395  { // viewKeyStruct: lookup keys for data repository wrappers
397  static constexpr char const * coordinatesString() { return "coordinates"; }
399  static constexpr char const * valuesString() { return "values"; }
401  static constexpr char const * interpolationString() { return "interpolation"; }
403  static constexpr char const * coordinateFilesString() { return "coordinateFiles"; }
405  static constexpr char const * voxelFileString() { return "voxelFile"; }
407  static constexpr char const * writeCSVFlagString() { return "writeCSV"; }
408  };
409
410 private:
411
418  void readFile( string const & filename, array1d< real64 > & target ); // NOTE(review): presumably reads numeric data from filename into target; definition not visible, confirm
419
420
422  array1d< real64 > m_tableCoordinates1D;
423
425  path_array m_coordinateFiles; ///< Array of file paths (one geos::Path per entry)
426
428  Path m_voxelFile; ///< File path held by the table
429
431  InterpolationType m_interpolationMethod; ///< Interpolation method
432
434  ArrayOfArrays< real64 > m_coordinates; ///< Table axes, one coordinate array per dimension
435
437  array1d< real64 > m_values; ///< Table values
438
440  stdVector< units::Unit > m_dimUnits; ///< Unit of each dimension (see getDimUnit)
441
443  units::Unit m_valueUnit; ///< Unit of the table values
444
446  KernelWrapper m_kernelWrapper; ///< Kernel wrapper instance held by the table
447
449  integer m_writeCSV; ///< NOTE(review): integer flag, presumably bound to the "writeCSV" input key; confirm
450 };
452 
453 template< typename IN_ARRAY >
456 real64
457 TableFunction::KernelWrapper::compute( IN_ARRAY const & input ) const
458 {
459  if( m_interpolationMethod == TableFunction::InterpolationType::Linear )
460  {
461  return interpolateLinear( input );
462  }
463  else // Nearest, Upper, Lower interpolation methods
464  {
465  return interpolateRound( input );
466  }
467 }
468 
469 template< typename IN_ARRAY >
472 real64
473 TableFunction::KernelWrapper::interpolateLinear( IN_ARRAY const & input ) const
474 {
475  integer const numDimensions = LvArray::integerConversion< integer >( m_coordinates.size() );
476  localIndex bounds[maxDimensions][2]{};
477  real64 weights[maxDimensions][2]{};
478 
479  // Determine position, weights
480  for( localIndex dim = 0; dim < numDimensions; ++dim )
481  {
482  arraySlice1d< real64 const > const coords = m_coordinates[dim];
483  if( input[dim] <= coords[0] )
484  {
485  // Coordinate is to the left of this axis
486  bounds[dim][0] = 0;
487  bounds[dim][1] = 0;
488  weights[dim][0] = 0.0;
489  weights[dim][1] = 1.0;
490  }
491  else if( input[dim] >= coords[coords.size() - 1] )
492  {
493  // Coordinate is to the right of this axis
494  bounds[dim][0] = coords.size() - 1;
495  bounds[dim][1] = bounds[dim][0];
496  weights[dim][0] = 1.0;
497  weights[dim][1] = 0.0;
498  }
499  else
500  {
501  // Find the coordinate index
503  // Sergey's note: find uses a binary search... If we assume coordinates are
504  // evenly spaced, we can speed things up considerably
505  // Mel's note: As we cannot be sure coords are evenly spaced,
506  // - Either we insert coords to get even spacing ( /!\ memory consumption ),
507  // - Or we can use an interpolation search with an hint array which would be linearly interpolated ( benchmark ).
508  auto const lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
509  bounds[dim][1] = LvArray::integerConversion< localIndex >( lower );
510  bounds[dim][0] = bounds[dim][1] - 1;
511 
512  real64 const dx = coords[bounds[dim][1]] - coords[bounds[dim][0]];
513  weights[dim][0] = 1.0 - ( input[dim] - coords[bounds[dim][0]]) / dx;
514  weights[dim][1] = 1.0 - weights[dim][0];
515  }
516  }
517 
518  // Calculate the result
519  real64 value = 0.0;
520  integer const numCorners = 1 << numDimensions;
521  for( integer point = 0; point < numCorners; ++point )
522  {
523  // Find array index
524  localIndex tableIndex = 0;
525  localIndex stride = 1;
526  for( integer dim = 0; dim < numDimensions; ++dim )
527  {
528  integer const corner = (point >> dim) & 1;
529  tableIndex += bounds[dim][corner] * stride;
530  stride *= m_coordinates.sizeOfArray( dim );
531  }
532 
533  // Determine weighted value
534  real64 cornerValue = m_values[tableIndex];
535  for( integer dim = 0; dim < numDimensions; ++dim )
536  {
537  integer const corner = (point >> dim) & 1;
538  cornerValue *= weights[dim][corner];
539  }
540  value += cornerValue;
541  }
542  return value;
543 }
544 
545 template< typename IN_ARRAY >
548 real64
549 TableFunction::KernelWrapper::interpolateRound( IN_ARRAY const & input ) const
550 {
551  integer const numDimensions = LvArray::integerConversion< integer >( m_coordinates.size() );
552 
553  // Determine the index to the nearest table entry
554  localIndex tableIndex = 0;
555  localIndex stride = 1;
556  for( integer dim = 0; dim < numDimensions; ++dim )
557  {
558  arraySlice1d< real64 const > const coords = m_coordinates[dim];
559  // Determine the index along each table axis
560  localIndex subIndex;
561  if( input[dim] <= coords[0] )
562  {
563  // Coordinate is to the left of the table axis
564  subIndex = 0;
565  }
566  else if( input[dim] >= coords[coords.size() - 1] )
567  {
568  // Coordinate is to the right of the table axis
569  subIndex = coords.size() - 1;
570  }
571  else
572  {
573  // Coordinate is within the table axis
574  // Note: find() will return the index of the upper table vertex
575  auto const lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
576  subIndex = LvArray::integerConversion< localIndex >( lower );
577 
578  // Interpolation types:
579  // - Nearest returns the value of the closest table vertex
580  // - Upper returns the value of the next table vertex
581  // - Lower returns the value of the previous table vertex
582  if( m_interpolationMethod == TableFunction::InterpolationType::Nearest )
583  {
584  if( ( input[dim] - coords[subIndex - 1]) <= ( coords[subIndex] - input[dim]) )
585  {
586  --subIndex;
587  }
588  }
589  else if( m_interpolationMethod == TableFunction::InterpolationType::Lower )
590  {
591  if( subIndex > 0 )
592  {
593  --subIndex;
594  }
595  }
596  }
597 
598  // Increment the global table index
599  tableIndex += subIndex * stride;
600  stride *= coords.size();
601  }
602 
603  // Retrieve the nearest value
604  return m_values[tableIndex];
605 }
606 
607 template< typename IN_ARRAY >
610 real64
611 TableFunction::KernelWrapper::getCoord( IN_ARRAY const & input, localIndex const dim, InterpolationType interpolationMethod ) const
612 {
613  // Determine the index to the nearest table entry
614  localIndex subIndex;
615  arraySlice1d< real64 const > const coords = m_coordinates[dim];
616  // Determine the index along each table axis
617  if( input[dim] <= coords[0] )
618  {
619  // Coordinate is to the left of the table axis
620  subIndex = 0;
621  }
622  else if( input[dim] >= coords[coords.size() - 1] )
623  {
624  // Coordinate is to the right of the table axis
625  subIndex = coords.size() - 1;
626  }
627  else
628  {
629  // Coordinate is within the table axis
630  // Note: find() will return the index of the upper table vertex
631  auto const lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
632  subIndex = LvArray::integerConversion< localIndex >( lower );
633 
634  // Interpolation types:
635  // - Nearest returns the value of the closest table vertex
636  // - Upper returns the value of the next table vertex
637  // - Lower returns the value of the previous table vertex
638  if( interpolationMethod == TableFunction::InterpolationType::Nearest )
639  {
640  if( ( input[dim] - coords[subIndex - 1]) <= ( coords[subIndex] - input[dim]) )
641  {
642  --subIndex;
643  }
644  }
645  else if( interpolationMethod == TableFunction::InterpolationType::Lower )
646  {
647  if( subIndex > 0 )
648  {
649  --subIndex;
650  }
651  }
652  }
653 
654  // Retrieve the nearest coordinate
655  return coords[subIndex];
656 }
657 
658 template< typename IN_ARRAY, typename OUT_ARRAY >
661 real64
662 TableFunction::KernelWrapper::compute( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const
663 {
664  // Linear interpolation
665  if( m_interpolationMethod == TableFunction::InterpolationType::Linear )
666  {
667  return interpolateLinear( input, derivatives );
668  }
669  // Nearest, Upper, Lower interpolation methods
670  else
671  {
672  return interpolateRound( input, derivatives );
673  }
674 }
675 
676 template< typename IN_ARRAY, typename OUT_ARRAY >
679 real64
680 TableFunction::KernelWrapper::interpolateLinear( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const
681 {
682  integer const numDimensions = LvArray::integerConversion< integer >( m_coordinates.size() );
683 
684  localIndex bounds[maxDimensions][2]{};
685  real64 weights[maxDimensions][2]{};
686  real64 dWeights_dInput[maxDimensions][2]{};
687 
688  // Determine position, weights
689  for( integer dim = 0; dim < numDimensions; ++dim )
690  {
691  arraySlice1d< real64 const > const coords = m_coordinates[dim];
692  if( input[dim] <= coords[0] )
693  {
694  // Coordinate is to the left of this axis
695  bounds[dim][0] = 0;
696  bounds[dim][1] = 0;
697  weights[dim][0] = 0;
698  weights[dim][1] = 1;
699  dWeights_dInput[dim][0] = 0;
700  dWeights_dInput[dim][1] = 0;
701  }
702  else if( input[dim] >= coords[coords.size() - 1] )
703  {
704  // Coordinate is to the right of this axis
705  bounds[dim][0] = coords.size() - 1;
706  bounds[dim][1] = bounds[dim][0];
707  weights[dim][0] = 1;
708  weights[dim][1] = 0;
709  dWeights_dInput[dim][0] = 0;
710  dWeights_dInput[dim][1] = 0;
711  }
712  else
713  {
714  // Find the coordinate index
716  // Note: find uses a binary search... If we assume coordinates are
717  // evenly spaced, we can speed things up considerably
718  auto lower = LvArray::sortedArrayManipulation::find( coords.begin(), coords.size(), input[dim] );
719  bounds[dim][1] = LvArray::integerConversion< localIndex >( lower );
720  bounds[dim][0] = bounds[dim][1] - 1;
721 
722  real64 const dx = coords[bounds[dim][1]] - coords[bounds[dim][0]];
723  weights[dim][0] = 1.0 - ( input[dim] - coords[bounds[dim][0]]) / dx;
724  weights[dim][1] = 1.0 - weights[dim][0];
725  dWeights_dInput[dim][0] = -1.0 / dx;
726  dWeights_dInput[dim][1] = -dWeights_dInput[dim][0];
727  }
728  }
729 
730  // Calculate the result
731  real64 value = 0.0;
732  for( integer dim = 0; dim < numDimensions; ++dim )
733  {
734  derivatives[dim] = 0.0;
735  }
736 
737  integer const numCorners = 1 << numDimensions;
738  for( integer point = 0; point < numCorners; ++point )
739  {
740  // Find array index
741  localIndex tableIndex = 0;
742  localIndex stride = 1;
743  for( integer dim = 0; dim < numDimensions; ++dim )
744  {
745  integer const corner = (point >> dim) & 1;
746  tableIndex += bounds[dim][corner] * stride;
747  stride *= m_coordinates.sizeOfArray( dim );
748  }
749 
750  // Determine weighted value
751  real64 cornerValue = m_values[tableIndex];
752  real64 dCornerValue_dInput[maxDimensions]{};
753  for( integer dim = 0; dim < numDimensions; ++dim )
754  {
755  dCornerValue_dInput[dim] = cornerValue;
756  }
757 
758  for( integer dim = 0; dim < numDimensions; ++dim )
759  {
760  integer const corner = (point >> dim) & 1;
761  cornerValue *= weights[dim][corner];
762  for( integer kk = 0; kk < numDimensions; ++kk )
763  {
764  dCornerValue_dInput[kk] *= ( dim == kk ) ? dWeights_dInput[dim][corner] : weights[dim][corner];
765  }
766  }
767 
768  for( integer dim = 0; dim < numDimensions; ++dim )
769  {
770  derivatives[dim] += dCornerValue_dInput[dim];
771  }
772  value += cornerValue;
773  }
774  return value;
775 }
776 
777 template< typename IN_ARRAY, typename OUT_ARRAY >
780 real64
781 TableFunction::KernelWrapper::interpolateRound( IN_ARRAY const & input, OUT_ARRAY && derivatives ) const
782 {
783  GEOS_UNUSED_VAR( input, derivatives );
784  GEOS_ERROR( "Rounding interpolation with derivatives not implemented" );
785  return 0.0;
786 }
787 
789 
792  "linear", // TableFunction::InterpolationType::Linear
793  "nearest", // TableFunction::InterpolationType::Nearest
794  "upper", // TableFunction::InterpolationType::Upper
795  "lower" ); // TableFunction::InterpolationType::Lower (strings listed in enumerator order)
796 
802 template<>
803 string TableTextFormatter::toString< TableFunction >( TableFunction const & tableData ) const; // explicit specialization of TableTextFormatter::toString for TableFunction (definition not in this header)
804 
810 template<>
811 string TableCSVFormatter::toString< TableFunction >( TableFunction const & tableData ) const; // explicit specialization of TableCSVFormatter::toString for TableFunction (definition not in this header)
812 
813 } /* namespace geos */
814 
815 #endif /* GEOS_FUNCTIONS_TABLEFUNCTION_HPP_ */
#define GEOS_HOST_DEVICE
Marks a host-device function.
Definition: GeosxMacros.hpp:49
#define GEOS_UNUSED_VAR(...)
Mark an unused variable and silence compiler warnings.
Definition: GeosxMacros.hpp:84
#define GEOS_FORCE_INLINE
Marks a function or lambda for inlining.
Definition: GeosxMacros.hpp:51
#define GEOS_ERROR(msg)
Raise a hard error and terminate the program.
Definition: Logger.hpp:157
Enumerates the Units that are in use in GEOS and regroups useful conversion and formatting functions.
Unit
Enumerator of available unit types for given physical scales. Units are in SI by default.
Definition: Units.hpp:59
Class describing a file Path.
Definition: Path.hpp:35
GEOS_HOST_DEVICE real64 compute(IN_ARRAY const &input, OUT_ARRAY &&derivatives) const
Interpolate in the table with derivatives.
void move(LvArray::MemorySpace const space, bool const touch)
Move the KernelWrapper to the given execution space, optionally touching it.
GEOS_HOST_DEVICE real64 compute(IN_ARRAY const &input) const
Interpolate in the table.
string getCoordsDescription(integer dimId, bool shortUnitsToVariables) const
InterpolationType
Enumerator of available interpolation types.
void setTableValues(real64_array values, units::Unit unit=units::Unknown)
Set the table values.
array1d< real64 > & getValues()
Get the table values.
arrayView1d< real64 const > getValues() const
Get the table values.
void checkCoord(real64 coord, localIndex dim) const
Check if the given coordinate is in the bounds of the table coordinates in the specified dimension,...
virtual void evaluate(dataRepository::Group const &group, real64 const time, SortedArrayView< localIndex const > const &set, arrayView1d< real64 > const &result) const override final
Method to evaluate a function on a target object.
ArrayOfArraysView< real64 > getCoordinates()
Get the table axes definitions.
real64 getCoord(real64 const *const input, localIndex dim, InterpolationType interpolationMethod) const
Method to get coordinates.
void setTableCoordinates(array1d< real64_array > const &coordinates, stdVector< units::Unit > const &dimUnits={})
Set the table coordinates.
void setValueUnits(units::Unit unit)
Set the table value units.
string getTableDescription() const
void reInitializeFunction()
Build the maps used to evaluate the table function.
KernelWrapper createKernelWrapper() const
Create an instance of the kernel wrapper.
void setDimUnits(stdVector< units::Unit > const &dimUnits)
Set the units of each dimension.
void initializePostSubGroups() override
Called by Initialize() after initializing sub-Groups.
TableFunction(const string &name, dataRepository::Group *const parent)
The constructor.
units::Unit getValueUnit() const
void setInterpolationMethod(InterpolationType const method)
Set the interpolation method.
static constexpr integer maxDimensions
maximum dimensions for the coordinates in the table
virtual real64 evaluate(real64 const *const input) const override final
Method to evaluate a function.
static string catalogName()
The catalog name interface.
virtual void initializeFunction() override
Initialize the table function.
ArrayOfArraysView< real64 const > getCoordinates() const
Get the table axes definitions.
integer numDimensions() const
units::Unit getDimUnit(localIndex const dim) const
InterpolationType getInterpolationMethod() const
Get the interpolation method.
string getValuesDescription() const
void outputTableData(OutputOptions const outputOpts) const
Print the table(s) in the log and/or CSV files when requested by the user.
Group & operator=(Group const &)=delete
Deleted copy assignment operator.
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:179
LvArray::ArrayOfArraysView< T, INDEX_TYPE const, CONST_SIZES, LvArray::ChaiBuffer > ArrayOfArraysView
View of array of variable-sized arrays. See LvArray::ArrayOfArraysView for details.
Definition: DataTypes.hpp:285
array1d< Path > path_array
A 1-dimensional array of geos::Path types.
Definition: DataTypes.hpp:364
std::set< T > set
A set of local indices.
Definition: DataTypes.hpp:262
double real64
64-bit floating point type.
Definition: DataTypes.hpp:98
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:84
array1d< real64 > real64_array
A 1-dimensional array of geos::real64 types.
Definition: DataTypes.hpp:357
std::size_t size_t
Unsigned size type.
Definition: DataTypes.hpp:78
LvArray::SortedArrayView< T, localIndex, LvArray::ChaiBuffer > SortedArrayView
A sorted array view of local indices.
Definition: DataTypes.hpp:270
int integer
Signed integer type.
Definition: DataTypes.hpp:81
Array< T, 1 > array1d
Alias for 1D array.
Definition: DataTypes.hpp:175
LvArray::ArrayOfArrays< T, INDEX_TYPE, LvArray::ChaiBuffer > ArrayOfArrays
Array of variable-sized arrays. See LvArray::ArrayOfArrays for details.
Definition: DataTypes.hpp:281
ENUM_STRINGS(LinearSolverParameters::SolverType, "direct", "cg", "gmres", "fgmres", "bicgstab", "richardson", "preconditioner")
Declare strings associated with enumeration values.
internal::StdVectorWrapper< T, Allocator, USE_STD_CONTAINER_BOUNDS_CHECKING > stdVector
Struct containing output options.
bool writeInLog
Request table output in log.
bool writeCSV
Request table output in CSV file.
Struct containing lookup keys for data repository wrappers.
static constexpr char const * coordinatesString()
static constexpr char const * interpolationString()
static constexpr char const * valuesString()
static constexpr char const * voxelFileString()
static constexpr char const * coordinateFilesString()
static constexpr char const * writeCSVFlagString()