GEOS
MpiWrapper.hpp
Go to the documentation of this file.
1 /*
2  * ------------------------------------------------------------------------------------------------------------
3  * SPDX-License-Identifier: LGPL-2.1-only
4  *
5  * Copyright (c) 2016-2024 Lawrence Livermore National Security LLC
6  * Copyright (c) 2018-2024 TotalEnergies
7  * Copyright (c) 2018-2024 The Board of Trustees of the Leland Stanford Junior University
8  * Copyright (c) 2023-2024 Chevron
9  * Copyright (c) 2019- GEOS/GEOSX Contributors
10  * All rights reserved
11  *
12  * See top level LICENSE, COPYRIGHT, CONTRIBUTORS, NOTICE, and ACKNOWLEDGEMENTS files for details.
13  * ------------------------------------------------------------------------------------------------------------
14  */
15 
20 #ifndef GEOS_COMMON_MPIWRAPPER_HPP_
21 #define GEOS_COMMON_MPIWRAPPER_HPP_
22 
23 #include "common/DataTypes.hpp"
24 #include "common/Span.hpp"
25 
#if defined( GEOS_USE_MPI )
  #include <mpi.h>
// With MPI enabled, parameters wrapped in MPI_PARAM keep their names.
#define MPI_PARAM( x ) x
#else
// Without MPI, parameters wrapped in MPI_PARAM become unnamed so that
// arguments which are unused in serial builds do not raise warnings.
#define MPI_PARAM( x )

// ---------------------------------------------------------------------------
// Serial stand-ins for the MPI handle types and constants referenced by GEOS.
// Every handle is a plain int. The numeric values appear to mirror MPICH's
// handle encodings (NOTE(review): confirm if that matters anywhere).
// ---------------------------------------------------------------------------

// Communicators.
using MPI_Comm = int;
#define MPI_COMM_NULL ((MPI_Comm)0x04000000)
#define MPI_COMM_WORLD ((MPI_Comm)0x44000000)
#define MPI_COMM_SELF ((MPI_Comm)0x40000000)

// Datatypes.
using MPI_Datatype = int;
#define MPI_CHAR ((MPI_Datatype)0x4c000101)
#define MPI_SIGNED_CHAR ((MPI_Datatype)0x4c000118)
#define MPI_UNSIGNED_CHAR ((MPI_Datatype)0x4c000102)
#define MPI_BYTE ((MPI_Datatype)0x4c00010d)
#define MPI_WCHAR ((MPI_Datatype)0x4c00040e)
#define MPI_SHORT ((MPI_Datatype)0x4c000203)
#define MPI_UNSIGNED_SHORT ((MPI_Datatype)0x4c000204)
#define MPI_INT ((MPI_Datatype)0x4c000405)
#define MPI_UNSIGNED ((MPI_Datatype)0x4c000406)
#define MPI_LONG ((MPI_Datatype)0x4c000807)
#define MPI_UNSIGNED_LONG ((MPI_Datatype)0x4c000808)
#define MPI_FLOAT ((MPI_Datatype)0x4c00040a)
#define MPI_DOUBLE ((MPI_Datatype)0x4c00080b)
#define MPI_LONG_DOUBLE ((MPI_Datatype)0x4c00100c)
#define MPI_LONG_LONG_INT ((MPI_Datatype)0x4c000809)
#define MPI_UNSIGNED_LONG_LONG ((MPI_Datatype)0x4c000819)
#define MPI_LONG_LONG MPI_LONG_LONG_INT

// Reduction operations.
using MPI_Op = int;
#define MPI_MAX (MPI_Op)(0x58000001)
#define MPI_MIN (MPI_Op)(0x58000002)
#define MPI_SUM (MPI_Op)(0x58000003)
#define MPI_PROD (MPI_Op)(0x58000004)
#define MPI_LAND (MPI_Op)(0x58000005)
#define MPI_BAND (MPI_Op)(0x58000006)
#define MPI_LOR (MPI_Op)(0x58000007)
#define MPI_BOR (MPI_Op)(0x58000008)
#define MPI_LXOR (MPI_Op)(0x58000009)
#define MPI_BXOR (MPI_Op)(0x5800000a)
#define MPI_MINLOC (MPI_Op)(0x5800000b)
#define MPI_MAXLOC (MPI_Op)(0x5800000c)
#define MPI_REPLACE (MPI_Op)(0x5800000d)
#define MPI_NO_OP (MPI_Op)(0x5800000e)

// Return codes and status sentinels.
#define MPI_SUCCESS 0 /* Successful return code */
#define MPI_UNDEFINED (-32766)
#define MPI_STATUS_IGNORE (MPI_Status *)1
#define MPI_STATUSES_IGNORE (MPI_Status *)1

// Requests.
using MPI_Request = int;
#define MPI_REQUEST_NULL ((MPI_Request)0x2c000000)

// Info objects.
using MPI_Info = int;
#define MPI_INFO_NULL (MPI_Info)(0x60000000)

// Placeholder status object; its single member carries no meaning in serial builds.
struct MPI_Status
{
  int junk;
};

#endif
90 
#if defined(NDEBUG)
// Release builds: evaluate the MPI return code and discard it.
#define MPI_CHECK_ERROR( error ) ((void)( error ))
#else
// Debug builds: raise an error through GEOS_ERROR_IF_NE when the return code is not MPI_SUCCESS.
// Like the NDEBUG variant, the macro carries no trailing semicolon; the call site supplies it
// (callers must already do so for release builds to compile).
#define MPI_CHECK_ERROR( error ) GEOS_ERROR_IF_NE( error, MPI_SUCCESS )
#endif
96 
97 
98 namespace geos
99 {
100 
// Global communicator used by GEOS; defined out of line.
#if defined( GEOS_USE_MPI )
extern MPI_Comm MPI_COMM_GEOS;
#else
// Serial builds represent communicator handles as plain ints.
extern int MPI_COMM_GEOS;
#endif
107 
{
public:

  /**
   * @brief Strongly typed reduction operations.
   *
   * Translated to the matching MPI_Op by getMpiOp().
   */
  enum class Reduction
  {
    Max,  ///< maps to MPI_MAX
    Min,  ///< maps to MPI_MIN
    Sum,  ///< maps to MPI_SUM
    Prod, ///< maps to MPI_PROD
  };

  /// Default construction is deleted; every member of this class is static.
  MpiWrapper() = delete;

  // --------------------------------------------------------------------------
  // The non-template functions below are only declared in this header; their
  // definitions live elsewhere. By name they appear to wrap the corresponding
  // MPI_* routines (barrier -> MPI_Barrier, commRank -> MPI_Comm_rank, ...).
  // NOTE(review): confirm exact semantics against the implementation file.
  // Parameters wrapped in MPI_PARAM are unnamed in serial builds.
  // --------------------------------------------------------------------------

  static void barrier( MPI_Comm const & MPI_PARAM( comm )=MPI_COMM_GEOS );

  static int cartCoords( MPI_Comm comm, int rank, int maxdims, int coords[] );

  static int cartCreate( MPI_Comm comm_old, int ndims, const int dims[], const int periods[],
                         int reorder, MPI_Comm * comm_cart );

  static int cartRank( MPI_Comm comm, const int coords[] );

  static void commFree( MPI_Comm & comm );

  /// Rank of the calling process in @p comm (presumably via MPI_Comm_rank).
  static int commRank( MPI_Comm const & MPI_PARAM( comm )=MPI_COMM_GEOS );

  /// Number of processes in @p comm (presumably via MPI_Comm_size).
  static int commSize( MPI_Comm const & MPI_PARAM( comm )=MPI_COMM_GEOS );

  static bool commCompare( MPI_Comm const & comm1, MPI_Comm const & comm2 );

  static bool initialized();

  static int init( int * argc, char * * * argv );

  static void finalize();

  static MPI_Comm commDup( MPI_Comm const comm );

  static MPI_Comm commSplit( MPI_Comm const comm, int color, int key );

  static int test( MPI_Request * request, int * flag, MPI_Status * status );

  static int testAny( int count, MPI_Request array_of_requests[], int * idx, int * flags, MPI_Status array_of_statuses[] );

  static int testSome( int count, MPI_Request array_of_requests[], int * outcount, int array_of_indices[], MPI_Status array_of_statuses[] );

  static int testAll( int count, MPI_Request array_of_requests[], int * flags, MPI_Status array_of_statuses[] );

  // check / checkAny / checkAll: non-blocking completion checks.
  // NOTE(review): how they differ from test* is not visible in this header.
  static int check( MPI_Request * request, int * flag, MPI_Status * status );

  static int checkAny( int count, MPI_Request array_of_requests[], int * idx, int * flag, MPI_Status array_of_statuses[] );

  static int checkAll( int count, MPI_Request array_of_requests[], int * flag, MPI_Status array_of_statuses[] );

  static int wait( MPI_Request * request, MPI_Status * status );

  static int waitAny( int count, MPI_Request array_of_requests[], int * indx, MPI_Status array_of_statuses[] );

  static int waitSome( int count, MPI_Request array_of_requests[], int * outcount, int array_of_indices[], MPI_Status array_of_statuses[] );

  static int waitAll( int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[] );

  static double wtime( void );


  // activeWait*: wait on requests while invoking @p func / the phase callbacks.
  // NOTE(review): presumably func(i) is called when request i completes and returns a
  // follow-up request; confirm against the implementation.
  static int activeWaitAny( const int count,
                            MPI_Request array_of_requests[],
                            MPI_Status array_of_statuses[],
                            std::function< MPI_Request ( int ) > func );

  static int activeWaitSome( const int count,
                             MPI_Request array_of_requests[],
                             MPI_Status array_of_statuses[],
                             std::function< MPI_Request ( int ) > func );

  static int activeWaitSomeCompletePhase( const int participants,
                                          std::vector< std::tuple< MPI_Request *, MPI_Status *, std::function< MPI_Request ( int ) > > > const & phases );

  static int activeWaitOrderedCompletePhase( const int participants,
                                             std::vector< std::tuple< MPI_Request *, MPI_Status *, std::function< MPI_Request ( int ) > > > const & phases );

#if !defined(GEOS_USE_MPI)
  /**
   * @brief Serial-build registry that iSend()/iRecv() use to pair buffers by tag.
   *
   * Maps a message tag to {origin flag, buffer pointer}. iRecv() records flag 1 with its
   * receive buffer and iSend() records flag 0 with its send buffer. The second call for
   * the same tag copies the data and erases the entry.
   * @return reference to a function-local static map
   */
  static std::map< int, std::pair< int, void * > > & getTagToPointersMap()
  {
    static std::map< int, std::pair< int, void * > > tagToPointers;
    return tagToPointers;
  }
#endif

  /// Compute the number of ranks allocated on the same node.
  static int nodeCommSize();

  /// Strongly typed wrapper around MPI_Allgather.
  template< typename T_SEND, typename T_RECV >
  static int allgather( T_SEND const * sendbuf,
                        int sendcount,
                        T_RECV * recvbuf,
                        int recvcount,
                        MPI_Comm comm );

  /// Strongly typed wrapper around MPI_Allgatherv.
  template< typename T_SEND, typename T_RECV >
  static int allgatherv( T_SEND const * sendbuf,
                         int sendcount,
                         T_RECV * recvbuf,
                         int * recvcounts,
                         int * displacements,
                         MPI_Comm comm );

  /// Convenience function for MPI_Allgather: one value per rank, gathered into @p allValues.
  template< typename T >
  static void allGather( T const myValue, array1d< T > & allValues, MPI_Comm comm = MPI_COMM_GEOS );

  /// Gathers every rank's @p sendbuf (all the same length) into @p recvbuf.
  template< typename T >
  static int allGather( arrayView1d< T const > const & sendbuf,
                        array1d< T > & recvbuf,
                        MPI_Comm comm = MPI_COMM_GEOS );

  /// Strongly typed wrapper around MPI_Allreduce; aliased buffers request an in-place reduction.
  template< typename T >
  static int allReduce( T const * sendbuf, T * recvbuf, int count, MPI_Op op, MPI_Comm comm = MPI_COMM_GEOS );

  /// All-reduce of a single value with a strongly typed Reduction.
  template< typename T >
  static T allReduce( T const & value, Reduction const op, MPI_Comm comm = MPI_COMM_GEOS );

  /// Element-wise all-reduce of @p src into @p dst (sizes must match).
  template< typename T >
  static void allReduce( Span< T const > src, Span< T > dst, Reduction const op, MPI_Comm comm = MPI_COMM_GEOS );


  /// Strongly typed wrapper around MPI_Reduce.
  template< typename T >
  static int reduce( T const * sendbuf, T * recvbuf, int count, MPI_Op op, int root, MPI_Comm comm = MPI_COMM_GEOS );

  /// Reduction of a single value onto @p root with a strongly typed Reduction.
  template< typename T >
  static T reduce( T const & value, Reduction const op, int root, MPI_Comm comm = MPI_COMM_GEOS );

  /// Element-wise reduction of @p src into @p dst onto @p root (sizes must match).
  template< typename T >
  static void reduce( Span< T const > src, Span< T > dst, Reduction const op, int root, MPI_Comm comm = MPI_COMM_GEOS );


  /// Strongly typed wrapper around MPI_Scan (identity copy in serial builds).
  template< typename T >
  static int scan( T const * sendbuf, T * recvbuf, int count, MPI_Op op, MPI_Comm comm );

  /// Strongly typed wrapper around MPI_Exscan (zero-filled result in serial builds).
  template< typename T >
  static int exscan( T const * sendbuf, T * recvbuf, int count, MPI_Op op, MPI_Comm comm );

  /// Strongly typed wrapper around MPI_Bcast.
  template< typename T >
  static int bcast( T * buffer, int count, int root, MPI_Comm comm );


  /// Convenience function for MPI_Bcast of a single value (specialized for string).
  template< typename T >
  static void broadcast( T & value, int srcRank = 0, MPI_Comm comm = MPI_COMM_GEOS );

  /// Strongly typed wrapper around MPI_Gather().
  template< typename TS, typename TR >
  static int gather( TS const * const sendbuf,
                     int sendcount,
                     TR * const recvbuf,
                     int recvcount,
                     int root,
                     MPI_Comm comm );

  /// Strongly typed wrapper around MPI_Gatherv.
  template< typename TS, typename TR >
  static int gatherv( TS const * const sendbuf,
                      int sendcount,
                      TR * const recvbuf,
                      const int * recvcounts,
                      const int * displs,
                      int root,
                      MPI_Comm comm );

  /// Returns an MPI_Op associated with our strongly typed Reduction enum.
  static MPI_Op getMpiOp( Reduction const op );

  /// Blocking receive of raw bytes into @p buf, resized to the probed message size.
  /// @note the last parameter receives an MPI_Status despite being named `request`.
  template< typename T >
  static int recv( array1d< T > & buf,
                   int MPI_PARAM( source ),
                   int tag,
                   MPI_Comm MPI_PARAM( comm ),
                   MPI_Status * MPI_PARAM( request ) );

  /// Non-blocking send of the contents of @p buf as raw bytes (MPI_CHAR).
  template< typename T >
  static int iSend( arrayView1d< T > const & buf,
                    int MPI_PARAM( dest ),
                    int tag,
                    MPI_Comm MPI_PARAM( comm ),
                    MPI_Request * MPI_PARAM( request ) );

  /// Strongly typed wrapper around MPI_Irecv().
  template< typename T >
  static int iRecv( T * const buf,
                    int count,
                    int source,
                    int tag,
                    MPI_Comm comm,
                    MPI_Request * request );

  /// Strongly typed wrapper around MPI_Isend().
  template< typename T >
  static int iSend( T const * const buf,
                    int count,
                    int dest,
                    int tag,
                    MPI_Comm comm,
                    MPI_Request * request );

  /// Exclusive prefix sum (MPI_Exscan with MPI_SUM) of @p value, accumulated in type U.
  template< typename U, typename T >
  static U prefixSum( T const value, MPI_Comm comm = MPI_COMM_GEOS );

  /// Convenience function for a MPI_Allreduce using a MPI_SUM operation.
  template< typename T >
  static T sum( T const & value, MPI_Comm comm = MPI_COMM_GEOS );

  /// Element-wise MPI_SUM all-reduce of @p src into @p dst.
  template< typename T >
  static void sum( Span< T const > src, Span< T > dst, MPI_Comm comm = MPI_COMM_GEOS );

  /// Convenience function for a MPI_Allreduce using a MPI_MIN operation.
  template< typename T >
  static T min( T const & value, MPI_Comm comm = MPI_COMM_GEOS );

  /// Element-wise MPI_MIN all-reduce of @p src into @p dst.
  template< typename T >
  static void min( Span< T const > src, Span< T > dst, MPI_Comm comm = MPI_COMM_GEOS );

  /// Convenience function for a MPI_Allreduce using a MPI_MAX operation.
  template< typename T >
  static T max( T const & value, MPI_Comm comm = MPI_COMM_GEOS );

  /// Element-wise MPI_MAX all-reduce of @p src into @p dst.
  template< typename T >
  static void max( Span< T const > src, Span< T > dst, MPI_Comm comm = MPI_COMM_GEOS );


  /// Gathers every rank's {value, location} struct and returns the one with the largest `value`.
  template< typename T > static T maxValLoc( T localValueLocation, MPI_Comm comm = MPI_COMM_GEOS );

};
622 
623 namespace internal
624 {
625 
626 template< typename T, typename ENABLE = void >
627 struct MpiTypeImpl {};
628 
629 #define ADD_MPI_TYPE_MAP( T, MPI_T ) \
630  template<> struct MpiTypeImpl< T > { static MPI_Datatype get() { return MPI_T; } }
631 
632 ADD_MPI_TYPE_MAP( float, MPI_FLOAT );
633 ADD_MPI_TYPE_MAP( double, MPI_DOUBLE );
634 
635 ADD_MPI_TYPE_MAP( char, MPI_CHAR );
636 ADD_MPI_TYPE_MAP( signed char, MPI_SIGNED_CHAR );
637 ADD_MPI_TYPE_MAP( unsigned char, MPI_UNSIGNED_CHAR );
638 
639 ADD_MPI_TYPE_MAP( int, MPI_INT );
640 ADD_MPI_TYPE_MAP( long int, MPI_LONG );
641 ADD_MPI_TYPE_MAP( long long int, MPI_LONG_LONG );
642 
643 ADD_MPI_TYPE_MAP( unsigned int, MPI_UNSIGNED );
644 ADD_MPI_TYPE_MAP( unsigned long int, MPI_UNSIGNED_LONG );
645 ADD_MPI_TYPE_MAP( unsigned long long int, MPI_UNSIGNED_LONG_LONG );
646 
647 #undef ADD_MPI_TYPE_MAP
648 
649 template< typename T >
650 struct MpiTypeImpl< T, std::enable_if_t< std::is_enum< T >::value > >
651 {
652  static MPI_Datatype get() { return MpiTypeImpl< std::underlying_type_t< T > >::get(); }
653 };
654 
655 template< typename T >
656 MPI_Datatype getMpiType()
657 {
658  return MpiTypeImpl< T >::get();
659 }
660 
661 }
662 
663 inline MPI_Op MpiWrapper::getMpiOp( Reduction const op )
664 {
665  switch( op )
666  {
667  case Reduction::Sum:
668  {
669  return MPI_SUM;
670  }
671  case Reduction::Min:
672  {
673  return MPI_MIN;
674  }
675  case Reduction::Max:
676  {
677  return MPI_MAX;
678  }
679  case Reduction::Prod:
680  {
681  return MPI_PROD;
682  }
683  default:
684  GEOS_ERROR( "Unsupported reduction operation" );
685  return MPI_NO_OP;
686  }
687 }
688 
689 template< typename T_SEND, typename T_RECV >
690 int MpiWrapper::allgather( T_SEND const * const sendbuf,
691  int sendcount,
692  T_RECV * const recvbuf,
693  int recvcount,
694  MPI_Comm MPI_PARAM( comm ) )
695 {
696 #ifdef GEOS_USE_MPI
697  return MPI_Allgather( sendbuf, sendcount, internal::getMpiType< T_SEND >(),
698  recvbuf, recvcount, internal::getMpiType< T_RECV >(),
699  comm );
700 #else
701  static_assert( std::is_same< T_SEND, T_RECV >::value,
702  "MpiWrapper::allgather() for serial run requires send and receive buffers are of the same type" );
703  GEOS_ERROR_IF_NE_MSG( sendcount, recvcount, "sendcount is not equal to recvcount." );
704  std::copy( sendbuf, sendbuf + sendcount, recvbuf )
705  return 0;
706 #endif
707 }
708 
709 template< typename T_SEND, typename T_RECV >
710 int MpiWrapper::allgatherv( T_SEND const * const sendbuf,
711  int sendcount,
712  T_RECV * const recvbuf,
713  int * recvcounts,
714  int * displacements,
715  MPI_Comm MPI_PARAM( comm ) )
716 {
717 #ifdef GEOS_USE_MPI
718  return MPI_Allgatherv( sendbuf, sendcount, internal::getMpiType< T_SEND >(),
719  recvbuf, recvcounts, displacements, internal::getMpiType< T_RECV >(),
720  comm );
721 #else
722  static_assert( std::is_same< T_SEND, T_RECV >::value,
723  "MpiWrapper::allgatherv() for serial run requires send and receive buffers are of the same type" );
724  GEOS_ERROR_IF_NE_MSG( sendcount, recvcount, "sendcount is not equal to recvcount." );
725  std::copy( sendbuf, sendbuf + sendcount, recvbuf )
726  return 0;
727 #endif
728 }
729 
730 
731 template< typename T >
732 void MpiWrapper::allGather( T const myValue, array1d< T > & allValues, MPI_Comm MPI_PARAM( comm ) )
733 {
734 #ifdef GEOS_USE_MPI
735  int const mpiSize = commSize( comm );
736  allValues.resize( mpiSize );
737 
738  MPI_Datatype const MPI_TYPE = internal::getMpiType< T >();
739 
740  MPI_Allgather( &myValue, 1, MPI_TYPE, allValues.data(), 1, MPI_TYPE, comm );
741 
742 #else
743  allValues.resize( 1 );
744  allValues[0] = myValue;
745 #endif
746 }
747 
748 template< typename T >
749 int MpiWrapper::allGather( arrayView1d< T const > const & sendValues,
750  array1d< T > & allValues,
751  MPI_Comm MPI_PARAM( comm ) )
752 {
753  int const sendSize = LvArray::integerConversion< int >( sendValues.size() );
754 #ifdef GEOS_USE_MPI
755  int const mpiSize = commSize( comm );
756  allValues.resize( mpiSize * sendSize );
757  return MPI_Allgather( sendValues.data(),
758  sendSize,
759  internal::getMpiType< T >(),
760  allValues.data(),
761  sendSize,
762  internal::getMpiType< T >(),
763  comm );
764 
765 #else
766  allValues.resize( sendSize );
767  for( localIndex a=0; a<sendSize; ++a )
768  {
769  allValues[a] = sendValues[a];
770  }
771  return 0;
772 #endif
773 }
774 
775 template< typename T >
776 int MpiWrapper::allReduce( T const * const sendbuf,
777  T * const recvbuf,
778  int const count,
779  MPI_Op const MPI_PARAM( op ),
780  MPI_Comm const MPI_PARAM( comm ) )
781 {
782 #ifdef GEOS_USE_MPI
783  MPI_Datatype const mpiType = internal::getMpiType< T >();
784  return MPI_Allreduce( sendbuf == recvbuf ? MPI_IN_PLACE : sendbuf, recvbuf, count, mpiType, op, comm );
785 #else
786  if( sendbuf != recvbuf )
787  {
788  memcpy( recvbuf, sendbuf, count * sizeof( T ) );
789  }
790  return 0;
791 #endif
792 }
793 
794 template< typename T >
795 int MpiWrapper::reduce( T const * const sendbuf,
796  T * const recvbuf,
797  int const count,
798  MPI_Op const MPI_PARAM( op ),
799  int root,
800  MPI_Comm const MPI_PARAM( comm ) )
801 {
802 #ifdef GEOS_USE_MPI
803  MPI_Datatype const mpiType = internal::getMpiType< T >();
804  return MPI_Reduce( sendbuf == recvbuf ? MPI_IN_PLACE : sendbuf, recvbuf, count, mpiType, op, root, comm );
805 #else
806  if( sendbuf != recvbuf )
807  {
808  memcpy( recvbuf, sendbuf, count * sizeof( T ) );
809  }
810  return 0;
811 #endif
812 }
813 
814 template< typename T >
815 int MpiWrapper::scan( T const * const sendbuf,
816  T * const recvbuf,
817  int count,
818  MPI_Op MPI_PARAM( op ),
819  MPI_Comm MPI_PARAM( comm ) )
820 {
821 #ifdef GEOS_USE_MPI
822  return MPI_Scan( sendbuf, recvbuf, count, internal::getMpiType< T >(), op, comm );
823 #else
824  memcpy( recvbuf, sendbuf, count*sizeof(T) );
825  return 0;
826 #endif
827 }
828 
829 template< typename T >
830 int MpiWrapper::exscan( T const * const MPI_PARAM( sendbuf ),
831  T * const recvbuf,
832  int count,
833  MPI_Op MPI_PARAM( op ),
834  MPI_Comm MPI_PARAM( comm ) )
835 {
836 #ifdef GEOS_USE_MPI
837  return MPI_Exscan( sendbuf, recvbuf, count, internal::getMpiType< T >(), op, comm );
838 #else
839  memset( recvbuf, 0, count*sizeof(T) );
840  return 0;
841 #endif
842 }
843 
844 template< typename T >
845 int MpiWrapper::bcast( T * const MPI_PARAM( buffer ),
846  int MPI_PARAM( count ),
847  int MPI_PARAM( root ),
848  MPI_Comm MPI_PARAM( comm ) )
849 {
850 #ifdef GEOS_USE_MPI
851  return MPI_Bcast( buffer, count, internal::getMpiType< T >(), root, comm );
852 #else
853  return 0;
854 #endif
855 
856 }
857 
858 template< typename T >
859 void MpiWrapper::broadcast( T & MPI_PARAM( value ), int MPI_PARAM( srcRank ), MPI_Comm MPI_PARAM( comm ) )
860 {
861 #ifdef GEOS_USE_MPI
862  MPI_Bcast( &value, 1, internal::getMpiType< T >(), srcRank, comm );
863 #endif
864 }
865 
866 template<>
867 inline
868 void MpiWrapper::broadcast< string >( string & MPI_PARAM( value ),
869  int MPI_PARAM( srcRank ),
870  MPI_Comm MPI_PARAM( comm ) )
871 {
872 #ifdef GEOS_USE_MPI
873  int size = LvArray::integerConversion< int >( value.size() );
874  broadcast( size, srcRank, comm );
875  value.resize( size );
876  MPI_Bcast( const_cast< char * >( value.data() ), size, internal::getMpiType< char >(), srcRank, comm );
877 #endif
878 }
879 
880 template< typename TS, typename TR >
881 int MpiWrapper::gather( TS const * const sendbuf,
882  int sendcount,
883  TR * const recvbuf,
884  int recvcount,
885  int MPI_PARAM( root ),
886  MPI_Comm MPI_PARAM( comm ) )
887 {
888 #ifdef GEOS_USE_MPI
889  return MPI_Gather( sendbuf, sendcount, internal::getMpiType< TS >(),
890  recvbuf, recvcount, internal::getMpiType< TR >(),
891  root, comm );
892 #else
893  static_assert( std::is_same< TS, TR >::value,
894  "MpiWrapper::gather() for serial run requires send and receive buffers are of the same type" );
895  std::size_t const sendBufferSize = sendcount * sizeof(TS);
896  std::size_t const recvBufferSize = recvcount * sizeof(TR);
897  GEOS_ERROR_IF_NE_MSG( sendBufferSize, recvBufferSize, "size of send buffer and receive buffer are not equal" );
898  memcpy( recvbuf, sendbuf, sendBufferSize );
899  return 0;
900 #endif
901 }
902 
903 template< typename TS, typename TR >
904 int MpiWrapper::gatherv( TS const * const sendbuf,
905  int sendcount,
906  TR * const recvbuf,
907  const int * recvcounts,
908  const int * MPI_PARAM( displs ),
909  int MPI_PARAM( root ),
910  MPI_Comm MPI_PARAM( comm ) )
911 {
912 #ifdef GEOS_USE_MPI
913  return MPI_Gatherv( sendbuf, sendcount, internal::getMpiType< TS >(),
914  recvbuf, recvcounts, displs, internal::getMpiType< TR >(),
915  root, comm );
916 #else
917  static_assert( std::is_same< TS, TR >::value,
918  "MpiWrapper::gather() for serial run requires send and receive buffers are of the same type" );
919  std::size_t const sendBufferSize = sendcount * sizeof(TS);
920  std::size_t const recvBufferSize = recvcounts[0] * sizeof(TR);
921  GEOS_ERROR_IF_NE_MSG( sendBufferSize, recvBufferSize, "size of send buffer and receive buffer are not equal" );
922  memcpy( recvbuf, sendbuf, sendBufferSize );
923  return 0;
924 #endif
925 }
926 
927 template< typename T >
928 int MpiWrapper::iRecv( T * const buf,
929  int count,
930  int MPI_PARAM( source ),
931  int tag,
932  MPI_Comm MPI_PARAM( comm ),
933  MPI_Request * MPI_PARAM( request ) )
934 {
935 #ifdef GEOS_USE_MPI
936  GEOS_ERROR_IF( (*request)!=MPI_REQUEST_NULL,
937  "Attempting to use an MPI_Request that is still in use." );
938  return MPI_Irecv( buf, count, internal::getMpiType< T >(), source, tag, comm, request );
939 #else
940  std::map< int, std::pair< int, void * > > & pointerMap = getTagToPointersMap();
941  std::map< int, std::pair< int, void * > >::iterator iPointer = pointerMap.find( tag );
942 
943  if( iPointer==pointerMap.end() )
944  {
945  pointerMap.insert( {tag, {1, buf} } );
946  }
947  else
948  {
949  GEOS_ERROR_IF( iPointer->second.first != 0,
950  "Tag does is assigned, but pointer was not set by iSend." );
951  memcpy( buf, iPointer->second.second, count*sizeof(T) );
952  pointerMap.erase( iPointer );
953  }
954  return 0;
955 #endif
956 }
957 
958 template< typename T >
959 int MpiWrapper::recv( array1d< T > & buf,
960  int MPI_PARAM( source ),
961  int tag,
962  MPI_Comm MPI_PARAM( comm ),
963  MPI_Status * MPI_PARAM( request ) )
964 {
965 #ifdef GEOS_USE_MPI
966  MPI_Status status;
967  int count;
968  MPI_Probe( source, tag, comm, &status );
969  MPI_Get_count( &status, MPI_CHAR, &count );
970 
971  GEOS_ASSERT_EQ( count % sizeof( T ), 0 );
972  buf.resize( count / sizeof( T ) );
973 
974  return MPI_Recv( reinterpret_cast< char * >( buf.data() ),
975  count,
976  MPI_CHAR,
977  source,
978  tag,
979  comm,
980  request );
981 #else
982  GEOS_ERROR( "Not implemented!" );
983  return MPI_SUCCESS;
984 #endif
985 }
986 
987 template< typename T >
988 int MpiWrapper::iSend( arrayView1d< T > const & buf,
989  int MPI_PARAM( dest ),
990  int tag,
991  MPI_Comm MPI_PARAM( comm ),
992  MPI_Request * MPI_PARAM( request ) )
993 {
994 #ifdef GEOS_USE_MPI
995  GEOS_ERROR_IF( (*request)!=MPI_REQUEST_NULL,
996  "Attempting to use an MPI_Request that is still in use." );
997  return MPI_Isend( reinterpret_cast< void const * >( buf.data() ),
998  buf.size() * sizeof( T ),
999  MPI_CHAR,
1000  dest,
1001  tag,
1002  comm,
1003  request );
1004 #else
1005  GEOS_ERROR( "Not implemented." );
1006  return MPI_SUCCESS;
1007 #endif
1008 }
1009 
1010 template< typename T >
1011 int MpiWrapper::iSend( T const * const buf,
1012  int count,
1013  int MPI_PARAM( dest ),
1014  int tag,
1015  MPI_Comm MPI_PARAM( comm ),
1016  MPI_Request * MPI_PARAM( request ) )
1017 {
1018 #ifdef GEOS_USE_MPI
1019  GEOS_ERROR_IF( (*request)!=MPI_REQUEST_NULL,
1020  "Attempting to use an MPI_Request that is still in use." );
1021  return MPI_Isend( buf, count, internal::getMpiType< T >(), dest, tag, comm, request );
1022 #else
1023  std::map< int, std::pair< int, void * > > & pointerMap = getTagToPointersMap();
1024  std::map< int, std::pair< int, void * > >::iterator iPointer = pointerMap.find( tag );
1025 
1026  if( iPointer==pointerMap.end() )
1027  {
1028  pointerMap.insert( {tag, {0, const_cast< T * >(buf)}
1029  } );
1030  }
1031  else
1032  {
1033  GEOS_ERROR_IF( iPointer->second.first != 1,
1034  "Tag does is assigned, but pointer was not set by iRecv." );
1035  memcpy( iPointer->second.second, buf, count*sizeof(T) );
1036  pointerMap.erase( iPointer );
1037  }
1038  return 0;
1039 #endif
1040 }
1041 
1042 template< typename U, typename T >
1043 U MpiWrapper::prefixSum( T const value, MPI_Comm comm )
1044 {
1045  U localResult;
1046 
1047 #ifdef GEOS_USE_MPI
1048  U const convertedValue = value;
1049  int const error = MPI_Exscan( &convertedValue, &localResult, 1, internal::getMpiType< U >(), MPI_SUM, comm );
1050  MPI_CHECK_ERROR( error );
1051 #endif
1052  if( commRank() == 0 )
1053  {
1054  localResult = 0;
1055  }
1056 
1057  return localResult;
1058 }
1059 
1060 
1061 template< typename T >
1062 T MpiWrapper::allReduce( T const & value, Reduction const op, MPI_Comm const comm )
1063 {
1064  T result;
1065  allReduce( &value, &result, 1, getMpiOp( op ), comm );
1066  return result;
1067 }
1068 
1069 template< typename T >
1070 void MpiWrapper::allReduce( Span< T const > const src, Span< T > const dst, Reduction const op, MPI_Comm const comm )
1071 {
1072  GEOS_ASSERT_EQ( src.size(), dst.size() );
1073  allReduce( src.data(), dst.data(), LvArray::integerConversion< int >( src.size() ), getMpiOp( op ), comm );
1074 }
1075 
1076 template< typename T >
1077 T MpiWrapper::sum( T const & value, MPI_Comm comm )
1078 {
1079  return MpiWrapper::allReduce( value, Reduction::Sum, comm );
1080 }
1081 
1082 template< typename T >
1083 void MpiWrapper::sum( Span< T const > src, Span< T > dst, MPI_Comm comm )
1084 {
1085  MpiWrapper::allReduce( src, dst, Reduction::Sum, comm );
1086 }
1087 
1088 template< typename T >
1089 T MpiWrapper::min( T const & value, MPI_Comm comm )
1090 {
1091  return MpiWrapper::allReduce( value, Reduction::Min, comm );
1092 }
1093 
1094 template< typename T >
1095 void MpiWrapper::min( Span< T const > src, Span< T > dst, MPI_Comm comm )
1096 {
1097  MpiWrapper::allReduce( src, dst, Reduction::Min, comm );
1098 }
1099 
1100 template< typename T >
1101 T MpiWrapper::max( T const & value, MPI_Comm comm )
1102 {
1103  return MpiWrapper::allReduce( value, Reduction::Max, comm );
1104 }
1105 
1106 template< typename T >
1107 void MpiWrapper::max( Span< T const > src, Span< T > dst, MPI_Comm comm )
1108 {
1109  MpiWrapper::allReduce( src, dst, Reduction::Max, comm );
1110 }
1111 
1112 
1113 template< typename T >
1114 T MpiWrapper::reduce( T const & value, Reduction const op, int root, MPI_Comm const comm )
1115 {
1116  T result;
1117  reduce( &value, &result, 1, getMpiOp( op ), root, comm );
1118  return result;
1119 }
1120 
1121 template< typename T >
1122 void MpiWrapper::reduce( Span< T const > const src, Span< T > const dst, Reduction const op, int root, MPI_Comm const comm )
1123 {
1124  GEOS_ASSERT_EQ( src.size(), dst.size() );
1125  reduce( src.data(), dst.data(), LvArray::integerConversion< int >( src.size() ), getMpiOp( op ), root, comm );
1126 }
1127 
1128 // Mpi helper function to return struct containing the max value and location across ranks
1129 template< typename T >
1130 T MpiWrapper::maxValLoc( T localValueLocation, MPI_Comm comm )
1131 {
1132  // Ensure T is trivially copyable
1133  static_assert( std::is_trivially_copyable< T >::value, "maxValLoc requires a trivially copyable type" );
1134 
1135  // T to have only 2 data members named value and location
1136  static_assert( (sizeof(T::value)+sizeof(T::location)) == sizeof(T) );
1137 
1138  // Ensure T has value and location members are scalars
1139  static_assert( std::is_scalar_v< decltype(T::value) > || std::is_scalar_v< decltype(T::location) >, "members of struct should be scalar" );
1140  static_assert( !std::is_pointer_v< decltype(T::value) > && !std::is_pointer_v< decltype(T::location) >, "members of struct should not be pointers" );
1141 
1142  // receive "buffer"
1143  int const numProcs = commSize( comm );
1144  std::vector< T > recvValLoc( numProcs );
1145 
1146  MPI_Allgather( &localValueLocation, sizeof(T), MPI_BYTE, recvValLoc.data(), sizeof(T), MPI_BYTE, comm );
1147 
1148  T maxValLoc= *std::max_element( recvValLoc.begin(),
1149  recvValLoc.end(),
1150  []( auto & lhs, auto & rhs ) -> bool {return lhs.value < rhs.value; } );
1151 
1152  return maxValLoc;
1153 }
1154 } /* namespace geos */
1155 
1156 #endif /* GEOS_COMMON_MPIWRAPPER_HPP_ */
#define GEOS_ERROR(msg)
Raise a hard error and terminate the program.
Definition: Logger.hpp:157
#define GEOS_ERROR_IF(EXP, msg)
Conditionally raise a hard error and terminate the program.
Definition: Logger.hpp:142
#define GEOS_ERROR_IF_NE_MSG(lhs, rhs, msg)
Raise a hard error if two values are not equal.
Definition: Logger.hpp:243
#define GEOS_ASSERT_EQ(lhs, rhs)
Assert that two values compare equal in debug builds.
Definition: Logger.hpp:410
Lightweight non-owning wrapper over a contiguous range of elements.
Definition: Span.hpp:42
constexpr T * data() const noexcept
Definition: Span.hpp:131
constexpr size_type size() const noexcept
Definition: Span.hpp:107
ArrayView< T, 1 > arrayView1d
Alias for 1D array view.
Definition: DataTypes.hpp:180
int MPI_COMM_GEOS
Global MPI communicator used by GEOSX.
GEOS_LOCALINDEX_TYPE localIndex
Local index type (for indexing objects within an MPI partition).
Definition: DataTypes.hpp:85
std::size_t size_t
Unsigned size type.
Definition: DataTypes.hpp:79
Array< T, 1 > array1d
Alias for 1D array.
Definition: DataTypes.hpp:176
static int allgatherv(T_SEND const *sendbuf, int sendcount, T_RECV *recvbuf, int *recvcounts, int *displacements, MPI_Comm comm)
Strongly typed wrapper around MPI_Allgatherv.
static int bcast(T *buffer, int count, int root, MPI_Comm comm)
Strongly typed wrapper around MPI_Bcast.
static MPI_Op getMpiOp(Reduction const op)
Returns an MPI_Op associated with our strongly typed Reduction enum.
Definition: MpiWrapper.hpp:663
static int activeWaitSomeCompletePhase(const int participants, std::vector< std::tuple< MPI_Request *, MPI_Status *, std::function< MPI_Request(int) > > > const &phases)
static int checkAll(int count, MPI_Request array_of_requests[], int *flag, MPI_Status array_of_statuses[])
static int activeWaitOrderedCompletePhase(const int participants, std::vector< std::tuple< MPI_Request *, MPI_Status *, std::function< MPI_Request(int) > > > const &phases)
static int gather(TS const *const sendbuf, int sendcount, TR *const recvbuf, int recvcount, int root, MPI_Comm comm)
Strongly typed wrapper around MPI_Gather().
static int gatherv(TS const *const sendbuf, int sendcount, TR *const recvbuf, const int *recvcounts, const int *displs, int root, MPI_Comm comm)
Strongly typed wrapper around MPI_Gatherv.
static int check(MPI_Request *request, int *flag, MPI_Status *status)
static int activeWaitAny(const int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[], std::function< MPI_Request(int) > func)
static int iSend(T const *const buf, int count, int dest, int tag, MPI_Comm comm, MPI_Request *request)
Strongly typed wrapper around MPI_Isend()
static void allGather(T const myValue, array1d< T > &allValues, MPI_Comm comm=MPI_COMM_GEOS)
Convenience function for MPI_Allgather.
static int allgather(T_SEND const *sendbuf, int sendcount, T_RECV *recvbuf, int recvcount, MPI_Comm comm)
Strongly typed wrapper around MPI_Allgather.
static T max(T const &value, MPI_Comm comm=MPI_COMM_GEOS)
Convenience function for a MPI_Allreduce using a MPI_MAX operation.
static int activeWaitSome(const int count, MPI_Request array_of_requests[], MPI_Status array_of_statuses[], std::function< MPI_Request(int) > func)
static int allReduce(T const *sendbuf, T *recvbuf, int count, MPI_Op op, MPI_Comm comm=MPI_COMM_GEOS)
Strongly typed wrapper around MPI_Allreduce.
static U prefixSum(T const value, MPI_Comm comm=MPI_COMM_GEOS)
Compute the exclusive prefix sum of a value across ranks (0 on the first rank).
static T maxValLoc(T localValueLocation, MPI_Comm comm=MPI_COMM_GEOS)
Convenience function that gathers a struct of value and location from all ranks (via MPI_Allgather) and returns the entry with the maximum value.
static int checkAny(int count, MPI_Request array_of_requests[], int *idx, int *flag, MPI_Status array_of_statuses[])
static int iRecv(T *const buf, int count, int source, int tag, MPI_Comm comm, MPI_Request *request)
Strongly typed wrapper around MPI_Irecv()
static T sum(T const &value, MPI_Comm comm=MPI_COMM_GEOS)
Convenience function for a MPI_Allreduce using a MPI_SUM operation.
static T min(T const &value, MPI_Comm comm=MPI_COMM_GEOS)
Convenience function for a MPI_Allreduce using a MPI_MIN operation.
static int nodeCommSize()
Compute the number of ranks allocated on the same node.
static void broadcast(T &value, int srcRank=0, MPI_Comm comm=MPI_COMM_GEOS)
Convenience function for MPI_Bcast.
static int reduce(T const *sendbuf, T *recvbuf, int count, MPI_Op op, int root, MPI_Comm comm=MPI_COMM_GEOS)
Strongly typed wrapper around MPI_Reduce.