Source code for geos.processing.tools.ProfileExtractor
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright 2023-2026 TotalEnergies.
# SPDX-FileContributor: Nicolas Pillardou, Paloma Martinez
import logging
import numpy as np
import numpy.typing as npt
from typing_extensions import Self, Union
from enum import Enum
from vtkmodules.vtkCommonDataModel import vtkCellData
from vtkmodules.util.numpy_support import vtk_to_numpy
from geos.utils.Logger import ( Logger, getLogger )
loggerTitle = "Profile Extractor"
[docs]
class ProfileExtractor:
def __init__( self: Self, logger: Union[ Logger, None ] = None ) -> None:
"""Utility class for extracting profiles along fault surfaces.
Args:
logger (Union[Logger, None], optional): A logger to manage the output messages.
Defaults to None, an internal logger is used.
"""
# Logger
self.logger: Logger
if logger is None:
self.logger = getLogger( loggerTitle, True )
else:
self.logger = logging.getLogger( f"{logger.name}" )
self.logger.setLevel( logging.INFO )
self.logger.propagate = False
[docs]
def extractAdaptiveProfile(
self: Self,
centers: npt.NDArray[ np.float64 ],
values: npt.NDArray[ np.float64 ],
xStart: float,
yStart: float,
zStart: float | None = None,
stepSize: float = 20.0,
maxSteps: int = 500,
cellData: vtkCellData | None = None
) -> tuple[ npt.NDArray[ np.float64 ], npt.NDArray[ np.float64 ], npt.NDArray[ np.float64 ],
npt.NDArray[ np.float64 ] ]:
"""Extract a vertical depth profile with automatic fault detection.
The algorithm adaptively follows a vertical sampling strategy guided by
detected fault membership inside the provided cell data. It performs:
1. Finding the closest starting point to the provided (xStart, yStart, zStart).
2. Automatically detecting the target fault using the provided ``cellData``
(e.g., fields like ``FaultMask`` or any other fault-identifying attribute).
3. Filtering the dataset to keep **only cells belonging to that fault**.
4. Splitting the remaining dataset into successive Z-depth slices.
5. For each slice, selecting the nearest cell in the XY plane to build the
final vertical profile.
Args:
centers (np.ndarray): Array of cell centers with shape ``(nCells, 3)``.
values (np.ndarray): Scalar values associated with each cell (shape ``(nCells,)``).
xStart (float or ndarray): Starting X coordinate.
yStart (float or ndarray): Starting Y coordinate.
zStart (float or ndarray | None): Starting Z coordinate. If ``None``, the method uses
the highest point near the provided XY start position.
stepSize (float): Vertical step size used when scanning depth layers.
Default is 20.0.
maxSteps (int): Maximum number of vertical layers to traverse.
Default is 500.
cellData (vtkCellData | None): VTK cell data object containing fields such as
``attribute``, ``FaultMask``, or other identifiers used to detect and filter
the target fault.
Returns:
tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
depth, profile values, X and Y coordinates of the profile path.
"""
# Convert to np arrays
centers = np.asarray( centers )
values = np.asarray( values )
if len( centers ) == 0:
raise ValueError( "No cell provided." )
# ===================================================================
# STEP 1 : Find starting point
# ===================================================================
if zStart is None:
# Look in 2D (XY), take the highest values
self.logger.info( f" Searching near ({xStart:.1f}, {yStart:.1f})" )
dXY = np.sqrt( ( centers[ :, 0 ] - xStart )**2 + ( centers[ :, 1 ] - yStart )**2 )
closestIndices = np.argsort( dXY )[ :20 ]
if len( closestIndices ) == 0:
raise ValueError( "No cells found near start point" )
# Take the highest Z value
closestDepths = centers[ closestIndices, 2 ]
startIdx = closestIndices[ np.argmax( closestDepths ) ]
else:
# Look in 3D
self.logger.info( f" Searching near ({xStart:.1f}, {yStart:.1f}, {zStart:.1f})" )
d3D = np.sqrt( ( centers[ :, 0 ] - xStart )**2 + ( centers[ :, 1 ] - yStart )**2 +
( centers[ :, 2 ] - zStart )**2 )
startIdx = np.argmin( d3D )
startPoint = centers[ startIdx ]
self.logger.info( f" Starting point: ({startPoint[0]:.1f}, {startPoint[1]:.1f}, {startPoint[2]:.1f}) -"
f" Cell index: {startIdx}" )
# ===================================================================
# STEP 2: Auto-detection of target fault
# ===================================================================
faultIds = None
targetFaultId = None
if cellData is not None:
faultFieldNames = [ 'attribute', 'FaultMask', 'faultId', 'region' ]
for fieldName in faultFieldNames:
if cellData.HasArray( fieldName ):
faultIds = vtk_to_numpy( cellData.GetArray( fieldName ) )
if len( faultIds ) != len( centers ):
self.logger.warning( f" Field '{fieldName}' length mismatch, skipping" )
continue
# Get starting ID
targetFaultId = faultIds[ startIdx ]
uniqueIds = np.unique( faultIds )
self.logger.info( f" Found fault field: '{fieldName}'" )
self.logger.info( f" Available fault IDs: {uniqueIds}" )
self.logger.info( f" Target fault ID at start point: {targetFaultId}" )
break
# ===================================================================
# STEP 3: Filter dataset to keep only fault cells
# ===================================================================
if targetFaultId is not None:
maskSameFault = ( faultIds == targetFaultId )
nTotal = len( centers )
nOnFault = np.sum( maskSameFault )
self.logger.info(
f" Filtering to fault ID={targetFaultId}: {nOnFault}/{nTotal} cells ({nOnFault/nTotal*100:.1f}%)"
)
if nOnFault == 0:
raise ValueError( "No cells found on target fault." )
# Replace centers and values by the filtered subset
centers = centers[ maskSameFault ].copy()
values = values[ maskSameFault ].copy()
# Find new starting index in the subset
dToStart = np.sqrt( np.sum( ( centers - startPoint )**2, axis=1 ) )
startIdx = np.argmin( dToStart )
self.logger.info( f" Profile will stay on fault ID={targetFaultId}" )
else:
self.logger.warning( " No fault identification field found" )
# ===================================================================
# STEP 4: Z-slicing of the fault
# ===================================================================
refX = centers[ startIdx, 0 ]
refY = centers[ startIdx, 1 ]
self.logger.info( f" Reference XY: ({refX:.1f}, {refY:.1f})" )
# ===================================================================
# STEP 5: Fault geometry
# ===================================================================
xRange = np.max( centers[ :, 0 ] ) - np.min( centers[ :, 0 ] )
yRange = np.max( centers[ :, 1 ] ) - np.min( centers[ :, 1 ] )
zRange = np.max( centers[ :, 2 ] ) - np.min( centers[ :, 2 ] )
if zRange <= 0:
raise ValueError( f"Invalid zRange: {zRange}" )
lateralExtent = max( xRange, yRange )
xyTolerance = max( lateralExtent * 0.3, 100.0 )
self.logger.info( f" Fault extent: X={xRange:.1f}m, Y={yRange:.1f}m, Z={zRange:.1f}m" )
self.logger.info( f" XY tolerance: {xyTolerance:.1f}m" )
# ===================================================================
# STEP 6: Slice computation
# ===================================================================
zCoordsSorted = np.sort( centers[ :, 2 ] )
zDiffs = np.diff( zCoordsSorted )
zDiffsPositive = zDiffs[ zDiffs > 1e-6 ]
if len( zDiffsPositive ) == 0:
self.logger.warning( " All cells at same Z" )
dXY = np.sqrt( ( centers[ :, 0 ] - refX )**2 + ( centers[ :, 1 ] - refY )**2 )
sortedIndices = np.argsort( dXY )
return ( centers[ sortedIndices, 2 ], values[ sortedIndices ], centers[ sortedIndices,
0 ], centers[ sortedIndices, 1 ] )
medianZSpacing = np.median( zDiffsPositive )
# Check that medianZSpacing is reasonable
if medianZSpacing <= 0 or medianZSpacing > zRange:
medianZSpacing = zRange / 100 # Fallback
sliceThickness = medianZSpacing
zMin = np.min( centers[ :, 2 ] )
zMax = np.max( centers[ :, 2 ] )
nSlices = int( np.ceil( zRange / sliceThickness ) )
nSlices = min( nSlices, 10000 ) # Limit to 10k slices max
if nSlices <= 0:
raise ValueError( f" Invalid nSlices: {nSlices}" )
self.logger.info( f" Median Z spacing: {medianZSpacing:.1f}m" )
self.logger.info( f" Creating {nSlices} slices" )
zSlices = np.linspace( zMax, zMin, nSlices + 1 )
# ===================================================================
# STEP 7: Slice extraction
# ===================================================================
profileIndices = []
for i in range( len( zSlices ) - 1 ):
zTop = zSlices[ i ]
zBottom = zSlices[ i + 1 ]
# Cells in this slice
maskInSlice = ( centers[ :, 2 ] <= zTop ) & ( centers[ :, 2 ] >= zBottom )
indicesInSlice = np.where( maskInSlice )[ 0 ]
if len( indicesInSlice ) == 0:
continue
# XY distance from reference
dXYInSlice = np.sqrt( ( centers[ indicesInSlice, 0 ] - refX )**2 +
( centers[ indicesInSlice, 1 ] - refY )**2 )
validMask = dXYInSlice < xyTolerance
if not np.any( validMask ):
closestInSlice = indicesInSlice[ np.argmin( dXYInSlice ) ]
else:
validIndices = indicesInSlice[ validMask ]
dXYValid = dXYInSlice[ validMask ]
closestInSlice = validIndices[ np.argmin( dXYValid ) ]
profileIndices.append( closestInSlice )
# ===================================================================
# STEP 8: Delete duplicate and sort
# ===================================================================
seen = set()
uniqueIndices = []
for idx in profileIndices:
if idx not in seen:
seen.add( idx )
uniqueIndices.append( idx )
if len( uniqueIndices ) == 0:
raise ValueError( "No points extracted for profile {idx}" )
profileIndicesArr = np.array( uniqueIndices )
# Sort by decreasing depth
sortOrder = np.argsort( -centers[ profileIndicesArr, 2 ] )
profileIndicesArr = profileIndicesArr[ sortOrder ]
# Extract results
depths = centers[ profileIndicesArr, 2 ]
profileValues = values[ profileIndicesArr ]
pathX = centers[ profileIndicesArr, 0 ]
pathY = centers[ profileIndicesArr, 1 ]
# ===================================================================
# STATISTICS
# ===================================================================
depthCoverage = ( depths.max() - depths.min() ) / zRange * 100 if zRange > 0 else 0
xyDisplacement = np.sqrt( ( pathX[ -1 ] - pathX[ 0 ] )**2 + ( pathY[ -1 ] - pathY[ 0 ] )**2 )
self.logger.info( f" Extracted {len(profileIndices)} points" )
self.logger.info( f" Depth range: [{depths.max():.1f}, {depths.min():.1f}]m" )
self.logger.info( f" Coverage: {depthCoverage:.1f}% of fault depth" )
self.logger.info( f" XY displacement: {xyDisplacement:.1f}m" )
return ( depths, profileValues, pathX, pathY )
[docs]
class ProfileExtractorMethod( str, Enum ):
"""String Enum of profile extraction method."""
VERTICAL_TOPO_BASED = "VerticalProfileTopologyBased"
ADAPTATIVE = "AdaptativeProfile"