
Finite Element Method Kernel Interface

The finite element method kernel interface (FEMKI) specifies an API for launching computational kernels that solve physics discretized with the finite element method. With this approach, a set of generic element looping patterns and kernel launching functions can be implemented once and reused by any physics solver whose kernels conform to the FEMKI.

The FEMKI has three main components:

  1. A collection of element looping functions that provide various looping patterns and call the launch function.

  2. The kernel interface, which is specified by the finiteElement::KernelBase class. Each physics solver defines a class that contains its kernel functions, most likely deriving from, or otherwise conforming to, the API specified by KernelBase. This class will typically also contain a nested StackVariables class that defines the stack variables used by the various kernel interface functions.

  3. A launch function, which launches the kernel and calls the kernel interface functions defined by KernelBase. This function is a static member function of the kernel class, so a specific physics kernel may provide its own version, allowing complete customization of the interface while still using the looping patterns.
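The three components can be sketched in miniature as follows. This is a host-only sketch under stated assumptions, not GEOS code: the names KernelBase, StackVariables, setup, quadraturePointKernel, complete, and kernelLaunch follow the text, while ToyKernel, applyKernel, serialPolicy, and the type aliases are hypothetical, and a plain serial loop stands in for forAll and the RAJA reduction.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

using real64 = double;
using localIndex = long;

struct serialPolicy {};  // hypothetical policy tag; ignored by this sketch

// (2) A simplified kernel interface, with (3) the launch function as a static member.
template< int NUM_QUADRATURE_POINTS >
struct KernelBase
{
  static constexpr int numQuadraturePointsPerElem = NUM_QUADRATURE_POINTS;

  template< typename POLICY, typename KERNEL_TYPE >
  static real64 kernelLaunch( localIndex const numElems,
                              KERNEL_TYPE const & kernelComponent )
  {
    assert( numElems >= 0 );
    real64 maxResidual = 0;
    for( localIndex k = 0; k < numElems; ++k )  // serial stand-in for forAll< POLICY >
    {
      typename KERNEL_TYPE::StackVariables stack;
      kernelComponent.setup( k, stack );
      for( int q = 0; q < numQuadraturePointsPerElem; ++q )
      {
        kernelComponent.quadraturePointKernel( k, q, stack );
      }
      maxResidual = std::max( maxResidual, kernelComponent.complete( k, stack ) );
    }
    return maxResidual;
  }
};

// A hypothetical physics kernel that derives from KernelBase (2 quadrature points).
struct ToyKernel : KernelBase< 2 >
{
  explicit ToyKernel( std::vector< real64 > const & values ): m_values( values ) {}

  // Per-element scratch data, created on the stack inside kernelLaunch.
  struct StackVariables
  {
    real64 value = 0;
    real64 residual = 0;
  };

  void setup( localIndex const k, StackVariables & stack ) const
  { stack.value = m_values[ k ]; }

  // Two equally weighted quadrature points (weight 0.5 each).
  void quadraturePointKernel( localIndex const, int const, StackVariables & stack ) const
  { stack.residual += 0.5 * stack.value; }

  real64 complete( localIndex const, StackVariables & stack ) const
  { return std::abs( stack.residual ); }

  std::vector< real64 > const & m_values;
};

// (1) A looping function: it relies only on the kernel's kernelLaunch.
template< typename POLICY, typename KERNEL_TYPE >
real64 applyKernel( localIndex const numElems, KERNEL_TYPE & kernel )
{
  return KERNEL_TYPE::template kernelLaunch< POLICY, KERNEL_TYPE >( numElems, kernel );
}
```

For example, `applyKernel< serialPolicy >( 3, kernel )` with element values `{ 1.0, -3.0, 2.0 }` returns the largest per-element residual magnitude, 3.0.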

A Generic Element Looping Pattern

One example of a looping pattern is the regionBasedKernelApplication function.

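Its implementation is shown here. It loops over the target subregions, dispatches on the constitutive model and the finite element type to obtain compile-time types, creates a kernel with the kernel factory, and calls the kernel's kernelLaunch function: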

/**
 * @brief Performs a loop over specific regions (by type and name) and calls a kernel launch on the subregions
 *   with compile time knowledge of sub-loop bounds such as number of nodes and quadrature points per element.
 * @tparam POLICY The RAJA launch policy to pass to the kernel launch.
 * @tparam CONSTITUTIVE_BASE The common base class for constitutive pass-thru/dispatch which gives the kernel
 *   launch compile time knowledge of the constitutive model. This is achieved through a call to the
 *   ConstitutivePassThru function which should have a specialization for CONSTITUTIVE_BASE implemented in
 *   order to perform the compile time dispatch.
 * @tparam SUBREGION_TYPE The type of subregion to loop over. TODO make this a parameter pack?
 * @tparam KERNEL_FACTORY The type of @p kernelFactory, typically an instantiation of @c KernelFactory, and
 *   must adhere to that interface.
 * @param mesh The MeshLevel object.
 * @param targetRegions The names of the target regions (of type @p SUBREGION_TYPE) to which the kernel is applied.
 * @param finiteElementName The name of the finite element.
 * @param constitutiveStringName The key to the constitutive model name found on the Region.
 * @param kernelFactory The object used to construct the kernel.
 * @return The maximum contribution to the residual, which may be used to scale the residual.
 *
 * @details Loops over all target subregions and applies/launches the kernel created by @p kernelFactory through
 * #::geos::finiteElement::KernelBase::kernelLaunch().
 */
template< typename POLICY,
          typename CONSTITUTIVE_BASE,
          typename SUBREGION_TYPE,
          typename KERNEL_FACTORY >
static
real64 regionBasedKernelApplication( MeshLevel & mesh,
                                     arrayView1d< string const > const & targetRegions,
                                     string const & finiteElementName,
                                     string const & constitutiveStringName,
                                     KERNEL_FACTORY & kernelFactory )
{
  GEOS_MARK_FUNCTION;
  // save the maximum residual contribution for scaling residuals for convergence criteria.
  real64 maxResidualContribution = 0;

  NodeManager & nodeManager = mesh.getNodeManager();
  EdgeManager & edgeManager = mesh.getEdgeManager();
  FaceManager & faceManager = mesh.getFaceManager();
  ElementRegionManager & elementRegionManager = mesh.getElemManager();

  // Loop over all sub-regions in regions of type SUBREGION_TYPE, that are listed in the targetRegions array.
  elementRegionManager.forElementSubRegions< SUBREGION_TYPE >( targetRegions,
                                                               [&constitutiveStringName,
                                                                &maxResidualContribution,
                                                                &nodeManager,
                                                                &edgeManager,
                                                                &faceManager,
                                                                &kernelFactory,
                                                                &finiteElementName]
                                                                 ( localIndex const targetRegionIndex, auto & elementSubRegion )
  {
    localIndex const numElems = elementSubRegion.size();

    // Get the constitutive model...and allocate a null constitutive model if required.

    constitutive::ConstitutiveBase * constitutiveRelation = nullptr;
    constitutive::NullModel * nullConstitutiveModel = nullptr;
    if( elementSubRegion.template hasWrapper< string >( constitutiveStringName ) )
    {
      string const & constitutiveName = elementSubRegion.template getReference< string >( constitutiveStringName );
      constitutiveRelation = &elementSubRegion.template getConstitutiveModel( constitutiveName );
    }
    else
    {
      nullConstitutiveModel = &elementSubRegion.template registerGroup< constitutive::NullModel >( "nullModelGroup" );
      constitutiveRelation = nullConstitutiveModel;
    }

    // Call the constitutive dispatch which converts the type of constitutive model into a compile time constant.
    constitutive::ConstitutivePassThru< CONSTITUTIVE_BASE >::execute( *constitutiveRelation,
                                                                      [&maxResidualContribution,
                                                                       &nodeManager,
                                                                       &edgeManager,
                                                                       &faceManager,
                                                                       targetRegionIndex,
                                                                       &kernelFactory,
                                                                       &elementSubRegion,
                                                                       &finiteElementName,
                                                                       numElems]
                                                                        ( auto & castedConstitutiveRelation )
    {
      FiniteElementBase &
      subRegionFE = elementSubRegion.template getReference< FiniteElementBase >( finiteElementName );

      finiteElement::FiniteElementDispatchHandler< SELECTED_FE_TYPES >::dispatch3D( subRegionFE,
                                                                                    [&maxResidualContribution,
                                                                                     &nodeManager,
                                                                                     &edgeManager,
                                                                                     &faceManager,
                                                                                     targetRegionIndex,
                                                                                     &kernelFactory,
                                                                                     &elementSubRegion,
                                                                                     numElems,
                                                                                     &castedConstitutiveRelation] ( auto const finiteElement )
      {
        auto kernel = kernelFactory.createKernel( nodeManager,
                                                  edgeManager,
                                                  faceManager,
                                                  targetRegionIndex,
                                                  elementSubRegion,
                                                  finiteElement,
                                                  castedConstitutiveRelation );

        using KERNEL_TYPE = decltype( kernel );

        // Call the kernelLaunch function, and store the maximum contribution to the residual.
        maxResidualContribution =
          std::max( maxResidualContribution,
                    KERNEL_TYPE::template kernelLaunch< POLICY, KERNEL_TYPE >( numElems, kernel ) );
      } );
    } );

    // Remove the null constitutive model (not required, but cleaner)
    if( nullConstitutiveModel )
    {
      elementSubRegion.deregisterGroup( "nullModelGroup" );
    }

  } );

  return maxResidualContribution;
}
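The nested ConstitutivePassThru and FiniteElementDispatchHandler calls above both turn a runtime type into a compile-time type by invoking a generic lambda with the concrete object. A minimal host-only sketch of that idiom follows; ModelBase, ElasticModel, NullModel, dispatchModel, ScaleKernel, ScaleKernelFactory, and applyToSubRegions are all hypothetical names, and dynamic_cast is assumed here as the dispatch mechanism.

```cpp
#include <algorithm>
#include <cassert>
#include <stdexcept>
#include <utility>
#include <vector>

struct ModelBase { virtual ~ModelBase() = default; };
struct ElasticModel : ModelBase { static constexpr double stiffness = 2.0; };
struct NullModel : ModelBase { static constexpr double stiffness = 0.0; };

// Convert the dynamic type into a static type and call the lambda with it.
// The lambda body is compiled once for each concrete type listed here.
template< typename LAMBDA >
void dispatchModel( ModelBase & model, LAMBDA && lambda )
{
  if( auto * elastic = dynamic_cast< ElasticModel * >( &model ) ) { lambda( *elastic ); }
  else if( auto * nullModel = dynamic_cast< NullModel * >( &model ) ) { lambda( *nullModel ); }
  else { throw std::runtime_error( "unsupported model type" ); }
}

// A kernel whose type depends on the concrete model (compile-time knowledge).
template< typename MODEL >
struct ScaleKernel
{
  double scale;
  double launch( int const numElems ) const { return scale * MODEL::stiffness * numElems; }
};

// A factory that builds a kernel for whatever concrete model it receives.
struct ScaleKernelFactory
{
  double scale;
  template< typename MODEL >
  ScaleKernel< MODEL > createKernel( MODEL const & ) const { return ScaleKernel< MODEL >{ scale }; }
};

// Mirrors the structure above: loop, dispatch, create the kernel, launch, reduce with max.
double applyToSubRegions( std::vector< std::pair< int, ModelBase * > > const & subRegions,
                          ScaleKernelFactory const & kernelFactory )
{
  double maxResidualContribution = 0;
  for( auto const & subRegion : subRegions )
  {
    int const numElems = subRegion.first;
    assert( subRegion.second != nullptr );
    dispatchModel( *subRegion.second, [&]( auto & castedModel )
    {
      auto kernel = kernelFactory.createKernel( castedModel );
      maxResidualContribution = std::max( maxResidualContribution, kernel.launch( numElems ) );
    } );
  }
  return maxResidualContribution;
}
```

In this sketch a subregion paired with a NullModel still goes through the same path, which parallels how the code above registers a null constitutive model when a subregion has none.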

This pattern may be used with any kernel class that either:

  1. Conforms to the KernelBase interface by defining each of the kernel functions in KernelBase.

  2. Defines its own kernelLaunch function that conforms to the signature of KernelBase::kernelLaunch. This option allows for a custom kernel that does not otherwise follow the interface defined by KernelBase and KernelBase::kernelLaunch.
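The two options can be illustrated with a hypothetical host-only sketch. Because kernelLaunch is a static member, a derived kernel that declares its own kernelLaunch hides the base-class version, and the call `KERNEL_TYPE::template kernelLaunch< POLICY, KERNEL_TYPE >` used by the looping function picks up whichever one the kernel declares. KernelBaseSketch, StandardKernel, CustomKernel, launch, and serialPolicy are assumptions, and the default launch omits the quadrature loop for brevity.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

struct serialPolicy {};  // hypothetical policy tag; ignored by this sketch

struct KernelBaseSketch
{
  // Simplified default launch (no quadrature loop): setup, then complete.
  template< typename POLICY, typename KERNEL_TYPE >
  static double kernelLaunch( long const numElems, KERNEL_TYPE const & kernelComponent )
  {
    double maxResidual = 0;
    for( long k = 0; k < numElems; ++k )
    {
      typename KERNEL_TYPE::StackVariables stack;
      kernelComponent.setup( k, stack );
      maxResidual = std::max( maxResidual, kernelComponent.complete( k, stack ) );
    }
    return maxResidual;
  }
};

// Option 1: conforms to the interface and inherits the default launch.
struct StandardKernel : KernelBaseSketch
{
  explicit StandardKernel( std::vector< double > const & v ): values( v ) {}
  struct StackVariables { double value = 0; };
  void setup( long const k, StackVariables & stack ) const { stack.value = values[ k ]; }
  double complete( long const, StackVariables & stack ) const { return stack.value; }
  std::vector< double > const & values;
};

// Option 2: no setup/complete at all, only a kernelLaunch with the same signature.
struct CustomKernel : KernelBaseSketch
{
  explicit CustomKernel( std::vector< double > const & v ): values( v ) {}

  template< typename POLICY, typename KERNEL_TYPE >
  static double kernelLaunch( long const numElems, KERNEL_TYPE const & kernelComponent )
  {
    double maxResidual = 0;
    for( long k = 0; k < numElems; ++k )
    {
      maxResidual = std::max( maxResidual, std::abs( kernelComponent.values[ k ] ) );
    }
    return maxResidual;
  }
  std::vector< double > const & values;
};

// A generic looping function that works with both kernels.
template< typename POLICY, typename KERNEL_TYPE >
double launch( long const numElems, KERNEL_TYPE const & kernel )
{
  assert( numElems >= 0 );
  return KERNEL_TYPE::template kernelLaunch< POLICY, KERNEL_TYPE >( numElems, kernel );
}
```

With values `{ 1.0, -4.0, 2.5 }`, the standard kernel returns 2.5 through setup/complete, while the custom kernel returns 4.0 from its own launch, showing that each kernel's kernelLaunch is the one called.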

The KernelBase::kernelLaunch Interface

The kernelLaunch function is a static member of the kernel class itself. As mentioned above, a physics implementation may use the existing KernelBase interface or define its own. KernelBase::kernelLaunch runs the element loop with the launch policy POLICY, runs an inner loop over the quadrature points of each element, and calls the functions defined by KernelBase, as shown here:

  template< typename POLICY,
            typename KERNEL_TYPE >
  static
  real64
  kernelLaunch( localIndex const numElems,
                KERNEL_TYPE const & kernelComponent )
  {
    GEOS_MARK_FUNCTION;

    // Define a RAJA reduction variable to get the maximum residual contribution.
    RAJA::ReduceMax< ReducePolicy< POLICY >, real64 > maxResidual( 0 );

    forAll< POLICY >( numElems,
                      [=] GEOS_HOST_DEVICE ( localIndex const k )
    {
      typename KERNEL_TYPE::StackVariables stack;

      kernelComponent.setup( k, stack );
      // #pragma unroll
      for( integer q=0; q<numQuadraturePointsPerElem; ++q )
      {
        kernelComponent.quadraturePointKernel( k, q, stack );
      }
      maxResidual.max( kernelComponent.complete( k, stack ) );
    } );
    return maxResidual.get();
  }

Each of the KernelBase functions called in KernelBase::kernelLaunch is intended to give the physics implementations a degree of modularity and flexibility. The general purpose of each function is indicated by its name, and is described further in the KernelBase function documentation.
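As a hedged illustration of how the work might be divided, the sketch below applies the setup / quadraturePointKernel / complete split to a 1D Laplace-type residual r = K u on linear two-node elements with 2-point Gauss quadrature. Laplace1DKernel and launchSerial are hypothetical names, not GEOS code: setup gathers element data into the stack variables, quadraturePointKernel accumulates one quadrature point's contribution, and complete scatters the local residual and returns its largest magnitude.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

struct Laplace1DKernel
{
  static constexpr int numNodesPerElem = 2;
  static constexpr int numQuadraturePointsPerElem = 2;

  std::vector< double > const & x;   // node coordinates
  std::vector< double > const & u;   // nodal solution
  std::vector< double > & residual;  // global residual, assembled into

  struct StackVariables
  {
    double h = 0;  // element length
    double uLocal[ numNodesPerElem ] = { 0, 0 };
    double rLocal[ numNodesPerElem ] = { 0, 0 };
  };

  // Gather element data from the global arrays into stack variables.
  void setup( long const k, StackVariables & stack ) const
  {
    stack.h = x[ k + 1 ] - x[ k ];
    stack.uLocal[ 0 ] = u[ k ];
    stack.uLocal[ 1 ] = u[ k + 1 ];
  }

  // Add one quadrature point's contribution. Both Gauss weights are 1 and
  // detJ = h / 2; for linear elements the integrand is the same at both points.
  void quadraturePointKernel( long const, int const, StackVariables & stack ) const
  {
    double const weightTimesDetJ = 1.0 * ( stack.h / 2 );
    double const dNdX[ numNodesPerElem ] = { -1.0 / stack.h, 1.0 / stack.h };
    double const gradU = dNdX[ 0 ] * stack.uLocal[ 0 ] + dNdX[ 1 ] * stack.uLocal[ 1 ];
    for( int a = 0; a < numNodesPerElem; ++a )
    {
      stack.rLocal[ a ] += dNdX[ a ] * gradU * weightTimesDetJ;
    }
  }

  // Scatter the local residual and report its largest magnitude.
  double complete( long const k, StackVariables & stack ) const
  {
    double maxForce = 0;
    for( int a = 0; a < numNodesPerElem; ++a )
    {
      residual[ k + a ] += stack.rLocal[ a ];
      maxForce = std::max( maxForce, std::abs( stack.rLocal[ a ] ) );
    }
    return maxForce;
  }
};

// Serial stand-in for kernelLaunch: element loop, then quadrature-point loop.
double launchSerial( long const numElems, Laplace1DKernel const & kernel )
{
  assert( numElems >= 0 );
  double maxResidual = 0;
  for( long k = 0; k < numElems; ++k )
  {
    Laplace1DKernel::StackVariables stack;
    kernel.setup( k, stack );
    for( int q = 0; q < Laplace1DKernel::numQuadraturePointsPerElem; ++q )
    {
      kernel.quadraturePointKernel( k, q, stack );
    }
    maxResidual = std::max( maxResidual, kernel.complete( k, stack ) );
  }
  return maxResidual;
}
```

For two elements on [0, 1] with the linear field u = x, the assembled residual is (-1, 0, 1): the interior entry cancels, and launchSerial returns 1.0. In this sketch, keeping gather, per-point work, and scatter in separate functions lets the launch function own the loop structure while the physics lives entirely in the kernel.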