import h5py # type: ignore[import]
from mpi4py import MPI # type: ignore[import]
import numpy as np # type: ignore[import]
import sys
import os
import re
import argparse
import logging
import time
import string
from pathlib import Path
try:
from geos.ats.helpers.permute_array import permuteArray # type: ignore[import]
except ImportError:
# Fallback method to be used if geos-ats isn't found
from permute_array import permuteArray # type: ignore[import]
# Default tolerances for floating point comparisons: 0.0 means exact match is required.
RTOL_DEFAULT = 0.0
ATOL_DEFAULT = 0.0
# Regex patterns for HDF5 paths excluded from comparison by default
# (command line echo, schema, index maps, and time-history restart groups).
EXCLUDE_DEFAULT = [ ".*/commandLine", ".*/schema$", ".*/globalToLocalMap", ".*/timeHistoryOutput.*/restart" ]
logger = logging.getLogger( 'geos-ats' )
[docs]
def write( output, msg ):
    """
    Echo MSG to standard output and also append it to OUTPUT.
    OUTPUT [in/out]: File stream to write to.
    MSG [in]: Message to write; converted to str first.
    """
    text = str( msg )
    sys.stdout.write( text )
    sys.stdout.flush()
    output.write( text )
def is_lfs_pointer( fname ):
    """
    Return True iff the first bytes of FNAME identify it as an uninitialized
    Git LFS pointer file.
    FNAME [in]: Path of the file to inspect.
    Returns False for missing or unreadable files.
    """
    try:
        # Use a context manager so the file handle is always closed
        # (the previous implementation leaked the open handle), and catch
        # only OSError instead of swallowing every exception.
        with open( fname, 'rb' ) as f:
            header = f.read( 16 )
    except OSError:
        return False
    # Same textual test as before, applied to the raw bytes instead of repr().
    return b'Git LFS pointer' in header
def load_hdf5( fname, max_wait_time=10, mode='r' ):
    """
    Attempt to open an HDF5 file, retrying once per second to allow for IO lag.
    FNAME [in]: Path of the file to open.
    MAX_WAIT_TIME [in]: Maximum number of one-second attempts.
    MODE [in]: h5py file mode.
    Returns the open h5py.File, or None if the file could not be opened.
    Raises if the file turns out to be an uninitialized Git LFS pointer.
    """
    for attempt in range( max_wait_time ):
        if os.path.isfile( fname ):
            try:
                handle = h5py.File( fname, mode )
            except IOError:
                logger.warning( f'Failed to open file: {fname} (attempt {attempt+1}/{max_wait_time})' )
                if is_lfs_pointer( fname ):
                    raise Exception( f'Target LFS object is not initialized: {fname}' )
            else:
                logger.debug( f'Opened file: {fname}' )
                return handle
        time.sleep( 1 )
    return None
def h5PathJoin( p1, p2 ):
    """Join two HDF5 paths, treating "" as no prefix and "/" as the root group."""
    if p1 == "":
        return p2
    separator = "" if p1 == "/" else "/"
    return p1 + separator + p2
[docs]
class FileComparison( object ):
    """
    Class that compares two hdf5 files.
    """

    def __init__( self,
                  file_path,
                  baseline_path,
                  rtol,
                  atol,
                  regex_expressions,
                  output,
                  warnings_are_errors,
                  skip_missing,
                  diff_file=None ):
        """
        FILE_PATH [in]: The path of the first file to compare.
        BASELINE_PATH [in]: The path of the baseline file to compare against.
        RTOL [in]: The relative tolerance used in comparing floating point numbers.
        ATOL [in]: The absolute tolerance used in comparing floating point numbers.
        REGEX_EXPRESSIONS [in]: A list of compiled regex expressions that match hdf5 groups and datasets to exclude.
        OUTPUT [in/out]: The file stream to write output to.
        WARNINGS_ARE_ERRORS [in]: Boolean specifying whether warnings are to be treated as errors.
        SKIP_MISSING [in]: Boolean specifying whether children missing from one file are ignored.
        DIFF_FILE [in/out]: Optional hdf5 file in which differing datasets are recorded.
        Raises ValueError if either tolerance is negative.
        """
        # Validate with explicit exceptions instead of `assert`, which is
        # silently stripped when Python runs with optimizations (-O).
        if rtol < 0.0:
            raise ValueError( "The relative tolerance cannot be negative." )
        if atol < 0.0:
            raise ValueError( "The absolute tolerance cannot be negative." )

        self.file_path = file_path
        self.baseline_path = baseline_path
        self.rtol = rtol
        self.atol = atol
        self.regex_expressions = regex_expressions
        self.output = output
        self.warnings_are_errors = warnings_are_errors
        self.skip_missing = skip_missing
        self.diff_file = diff_file
        # Set to True by errorMsg() as soon as any difference is found.
        self.different = False
def filesDiffer( self ):
# Check to see if the file is on the disk, and wait in case there is any lag in IO
file = load_hdf5( self.file_path )
base_file = load_hdf5( self.baseline_path )
rank = MPI.COMM_WORLD.Get_rank()
self.output.write( f"\nRank {rank} is comparing {self.file_path} with {self.baseline_path} \n" )
# Compare the files
if ( file is not None ) and ( base_file is not None ):
self.file_path = file.filename
self.baseline_path = base_file.filename
self.compareGroups( file, base_file )
else:
if file is None:
self.output.write( f"\nRank {rank} failed to load target file: {self.file_path}\n" )
if base_file is None:
self.output.write( f"\nRank {rank} failed to load baseline file: {self.baseline_path}\n" )
self.different = True
return self.different
def add_links( self, path, message ):
# When comparing the root groups self.diff_file is none.
if self.diff_file is None:
return
base_name = os.path.basename( self.file_path )
diff_group_name = base_name + "/" + path
diff_group = self.diff_file.create_group( diff_group_name )
diff_group.create_dataset( "message", data=message )
diff_group[ "run" ] = h5py.ExternalLink( self.file_path, path )
diff_group[ "baseline" ] = h5py.ExternalLink( self.baseline_path, path )
[docs]
def errorMsg( self, path, message, add_to_diff=False ):
"""
Issue an error which occurred at PATH in the files with the contents of MESSAGE.
Sets self.different to True and rites the error to both stdout and OUTPUT.
PATH [in]: The path in the files at which the error occurred.
MESSAGE [in]: The error message.
"""
self.different = True
msg = '*' * 80 + "\n"
msg += "Error: %s\n" % path
msg += "\t" + "\n\t".join( message.split( "\n" ) )[ :-1 ]
msg += '*' * 80 + "\n"
self.output.write( msg )
if add_to_diff:
self.add_links( path, message )
[docs]
def warningMsg( self, path, message ):
"""
Issue a warning which occurred at PATH in the files with the contents of MESSAGE.
Writes the warning to both stdout and OUTPUT. If WARNINGS_ARE_ERRORS then this
is a wrapper around errorMsg.
PATH [in]: The path in the files at which the warning occurred.
MESSAGE [in]: The warning message.
"""
if self.warnings_are_errors:
return self.errorMsg( path, message )
msg = '*' * 80 + "\n"
msg += "Warning: %s\n" % path
msg += "\t" + "\n\t".join( message.split( "\n" ) )[ :-1 ]
msg += '*' * 80 + "\n"
self.output.write( msg )
[docs]
def isExcluded( self, path ):
"""
Return True iff path matches any of the regex expressions in self.regex_expressions.
PATH [in]: The path to match.
"""
for regex in self.regex_expressions:
if regex.match( path ) is not None:
return True
return False
[docs]
def compareFloatScalars( self, path, val, base_val ):
"""
Compare floating point scalars.
PATH [in]: The path at which the comparison occurs.
VAL [in]: The value to compare.
BASE_VAL [in]: The baseline value to compare against.
"""
dif = abs( val - base_val )
if dif > self.atol and dif > self.rtol * abs( base_val ):
msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( val.dtype, base_val.dtype, val, base_val )
self.errorMsg( path, msg, True )
[docs]
def compareIntScalars( self, path, val, base_val ):
"""
Compare integer scalars.
PATH [in]: The path at which the comparison occurs.
VAL [in]: The value to compare.
BASE_VAL [in]: The baseline value to compare against.
"""
if val != base_val:
msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( val.dtype, base_val.dtype, val, base_val )
self.errorMsg( path, msg, True )
[docs]
def compareStringScalars( self, path, val, base_val ):
"""
Compare string scalars.
PATH [in]: The path at which the comparison occurs.
VAL [in]: The value to compare.
BASE_VAL [in]: The baseline value to compare against.
"""
if val != base_val:
msg = "Scalar values of types %s and %s differ: %s, %s.\n" % ( val.dtype, base_val.dtype, val, base_val )
self.errorMsg( path, msg, True )
[docs]
    def compareFloatArrays( self, path, arr, base_arr ):
        """
        Compares two arrays ARR and BASEARR of floating point values.
        Entries x1 and x2 are considered equal iff:
            abs(x1 - x2) <= ATOL * ( 1 + max(abs(x2)) )
        or
            abs(x1 - x2) <= RTOL * abs(x2).
        To measure the degree of difference a scaling factor q
        is introduced. The goal is now to minimize q such that:
            abs(x1 - x2) <= ATOL * ( 1 + max(abs(x2)) ) * q
        or
            abs(x1 - x2) <= RTOL * abs(x2) * q.
        If RTOL * abs(x2) > ATOL * ( 1 + max(abs(x2)) )
            q = abs(x1 - x2) / (RTOL * abs(x2))
        else
            q = abs(x1 - x2) / ( ATOL * ( 1 + max(abs(x2)) ) ).
        If the maximum value of q over all the
        entries is greater than 1.0 then the arrays are considered
        different and an error message is produced.
        PATH [in]: The path at which the comparison takes place.
        ARR [in]: The hdf5 Dataset to compare.
        BASE_ARR [in]: The hdf5 Dataset to compare against.
        """
        # If we have zero tolerance then just call the compareIntArrays function.
        if self.rtol == 0.0 and self.atol == 0.0:
            return self.compareIntArrays( path, arr, base_arr )

        # If the shapes are different they can't be compared.
        if arr.shape != base_arr.shape:
            msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % ( arr.shape,
                                                                                                  base_arr.shape )
            self.errorMsg( path, msg, True )
            return

        # First create a copy of the data in the datasets.
        arr_cpy = np.copy( arr )
        base_arr_cpy = np.copy( base_arr )

        # Now compute the difference and store the result in ARR1_CPY
        # which is appropriately renamed DIFFERENCE.
        # NOTE: DIFFERENCE aliases ARR_CPY; all subsequent in-place ops rely on this.
        difference = np.subtract( arr, base_arr, out=arr_cpy )
        np.abs( difference, out=difference )

        # Take the absolute value of BASE_ARR_CPY and rename it to ABS_BASE_ARR
        abs_base_arr = np.abs( base_arr_cpy, out=base_arr_cpy )

        # Previous (disabled) scaling of the absolute tolerance by the global max:
        # max_abs_base_arr = np.max( abs_base_arr )
        # comm = MPI.COMM_WORLD
        # size = comm.Get_size()
        # if size > 1:
        #     max_abs_base_arr = comm.allreduce(max_abs_base_arr, op=MPI.MAX)
        # absTol = (1.0 + max_abs_base_arr) * self.atol
        absTol = self.atol

        # Get the indices of the max absolute and relative error
        max_absolute_index = np.unravel_index( np.argmax( difference ), difference.shape )

        # The small epsilon guards against division by zero where the baseline is 0.
        relative_difference = difference / ( abs_base_arr + 1e-20 )

        # If the absolute tolerance is not zero, replace all nan's with zero.
        # NOTE(review): the positional 0 is nan_to_num's `copy` argument (i.e.
        # copy=False, modify in place); NaNs are replaced by 0.0 by default.
        if self.atol != 0:
            relative_difference = np.nan_to_num( relative_difference, 0 )

        max_relative_index = np.unravel_index( np.argmax( relative_difference ), relative_difference.shape )

        # Rescale so that q <= 1 means "within tolerance".
        if self.rtol != 0.0:
            relative_difference /= self.rtol

        if self.rtol == 0.0:
            # Only the absolute tolerance applies.
            difference /= absTol
            q = difference
            absolute_limited = np.ones( q.shape, dtype=bool )
        elif self.atol == 0.0:
            # Only the relative tolerance applies.
            q = relative_difference
            absolute_limited = np.zeros( q.shape, dtype=bool )
        else:
            # Multiply ABS_BASE_ARR by RTOL and rename it to RTOL_ABS_BASE
            rtol_abs_base = np.multiply( self.rtol, abs_base_arr, out=abs_base_arr )

            # Calculate which entries are limited by the relative tolerance.
            relative_limited = rtol_abs_base > absTol

            # Rename DIFFERENCE to Q where we will store the scaling parameter q.
            q = difference
            q[ relative_limited ] = relative_difference[ relative_limited ]

            # Compute q for the entries which are limited by the absolute tolerance.
            # NOTE: ABSOLUTE_LIMITED aliases RELATIVE_LIMITED after this in-place negation.
            absolute_limited = np.logical_not( relative_limited, out=relative_limited )
            q[ absolute_limited ] /= absTol

        # If the maximum q value is greater than 1.0 than issue an error.
        if np.max( q ) > 1.0:
            offenders = np.greater( q, 1.0 )
            n_offenders = np.sum( offenders )

            # Statistics over the offenders limited by the absolute tolerance.
            # NOTE: OFFENDERS is overwritten in place here and recomputed below.
            absolute_offenders = np.logical_and( offenders, absolute_limited, out=offenders )
            q_num_absolute = np.sum( absolute_offenders )
            if q_num_absolute > 0:
                absolute_qs = q * absolute_offenders
                q_max_absolute = np.max( absolute_qs )
                q_max_absolute_index = np.unravel_index( np.argmax( absolute_qs ), absolute_qs.shape )
                q_mean_absolute = np.mean( absolute_qs )
                q_std_absolute = np.std( absolute_qs )

            # Recompute OFFENDERS, then gather statistics over the relatively-limited ones.
            offenders = np.greater( q, 1.0, out=offenders )
            relative_limited = np.logical_not( absolute_limited, out=absolute_limited )
            relative_offenders = np.logical_and( offenders, relative_limited, out=offenders )
            q_num_relative = np.sum( relative_offenders )
            if q_num_relative > 0:
                relative_qs = q * relative_offenders
                q_max_relative = np.max( relative_qs )
                q_max_relative_index = np.unravel_index( np.argmax( relative_qs ), q.shape )
                q_mean_relative = np.mean( relative_qs )
                q_std_relative = np.std( relative_qs )

            # Build the full diagnostic message and record the difference.
            message = "Arrays of types %s and %s have %d values of which %d fail both the relative and absolute tests.\n" % (
                arr.dtype, base_arr.dtype, offenders.size, n_offenders )
            message += "\tMax absolute difference is at index %s: value = %s, base_value = %s\n" % (
                max_absolute_index, arr[ max_absolute_index ], base_arr[ max_absolute_index ] )
            message += "\tMax relative difference is at index %s: value = %s, base_value = %s\n" % (
                max_relative_index, arr[ max_relative_index ], base_arr[ max_relative_index ] )
            message += "Statistics of the q values greater than 1.0 defined by absolute tolerance: N = %d\n" % q_num_absolute
            if q_num_absolute > 0:
                message += "\tmax = %s, mean = %s, std = %s\n" % ( q_max_absolute, q_mean_absolute, q_std_absolute )
                message += "\tmax is at index %s, value = %s, base_value = %s\n" % (
                    q_max_absolute_index, arr[ q_max_absolute_index ], base_arr[ q_max_absolute_index ] )
            message += "Statistics of the q values greater than 1.0 defined by relative tolerance: N = %d\n" % q_num_relative
            if q_num_relative > 0:
                message += "\tmax = %s, mean = %s, std = %s\n" % ( q_max_relative, q_mean_relative, q_std_relative )
                message += "\tmax is at index %s, value = %s, base_value = %s\n" % (
                    q_max_relative_index, arr[ q_max_relative_index ], base_arr[ q_max_relative_index ] )
            self.errorMsg( path, message, True )
[docs]
def compareIntArrays( self, path, arr, base_arr ):
"""
Compare two integer datasets. Exact equality is used as the acceptance criteria.
PATH [in]: The path at which the comparison takes place.
ARR [in]: The hdf5 Dataset to compare.
BASE_ARR [in]: The hdf5 Dataset to compare against.
"""
message = ""
if arr.shape != base_arr.shape:
message = "Datasets have different shapes and therefore can't be compared statistically: %s, %s.\n" % (
arr.shape, base_arr.shape )
else:
# Calculate the absolute difference.
difference = np.subtract( arr, base_arr )
np.abs( difference, out=difference )
offenders = difference != 0.0
n_offenders = np.sum( offenders )
if n_offenders != 0:
max_index = np.unravel_index( np.argmax( difference ), difference.shape )
max_difference = difference[ max_index ]
offenders_mean = np.mean( difference[ offenders ] )
offenders_std = np.std( difference[ offenders ] )
message = "Arrays of types %s and %s have %s values of which %d have differing values.\n" % (
arr.dtype, base_arr.dtype, offenders.size, n_offenders )
message += "Statistics of the differences greater than 0:\n"
message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % ( max_index, max_difference,
offenders_mean, offenders_std )
# actually, int8 arrays are almost always char arrays, so we sould add a character comparison.
if arr.dtype == np.int8 and base_arr.dtype == np.int8:
message += self.compareCharArrays( arr, base_arr )
if message != "":
self.errorMsg( path, message, True )
[docs]
    def compareCharArrays( self, comp_arr, base_arr ):
        """
        Compare the valid characters of two arrays and return a formatted string showing differences.
        COMP_ARR [in]: The hdf5 Dataset to compare.
        BASE_ARR [in]: The hdf5 Dataset to compare against.
        Returns a formatted string highlighting the differing characters
        (empty string when no difference is found).
        """
        comp_ndarr = np.array( comp_arr ).flatten()
        base_ndarr = np.array( base_arr ).flatten()

        # Replace invalid characters by group-separator characters ('\x1D')
        valid_chars = set( string.printable )
        invalid_char = '\x1D'
        comp_str = "".join(
            [ chr( x ) if ( x >= 0 and chr( x ) in valid_chars ) else invalid_char for x in comp_ndarr ] )
        base_str = "".join(
            [ chr( x ) if ( x >= 0 and chr( x ) in valid_chars ) else invalid_char for x in base_ndarr ] )

        # Replace whitespace sequences by only one space
        # (prevents detecting indentation / spacing-only changes).
        whitespace_pattern = r"[ \t\n\r\v\f]+"
        comp_str = re.sub( whitespace_pattern, " ", comp_str )
        base_str = re.sub( whitespace_pattern, " ", base_str )

        # Replace invalid-character sequences by a space for a cleaner display.
        invalid_char_pattern = r"\x1D+"
        comp_str_display = re.sub( invalid_char_pattern, " ", comp_str )
        base_str_display = re.sub( invalid_char_pattern, " ", base_str )

        message = ""

        def limited_display( n, string ):
            # Truncate STRING to N characters, noting how many were omitted.
            # NOTE: the parameter name shadows the module-level `string` import;
            # inside this helper `string` is the text being truncated.
            return string[ :n ] + f"... ({len(string)-n} omitted chars)" if len( string ) > n else string

        if len( comp_str ) != len( base_str ):
            max_display = 250
            message = f"Character arrays have different sizes: {len( comp_str )}, {len( base_str )}.\n"
            message += f" {limited_display( max_display, comp_str_display )}\n"
            message += f" {limited_display( max_display, base_str_display )}\n"
        else:
            # We need to trim arrays to the length of the shortest one for the comparisons
            min_length = min( len( comp_str_display ), len( base_str_display ) )
            comp_str_trim = comp_str_display[ :min_length ]
            base_str_trim = base_str_display[ :min_length ]

            differing_indices = np.where( np.array( list( comp_str_trim ) ) != np.array( list( base_str_trim ) ) )[ 0 ]
            if differing_indices.size != 0:
                # Check for reordering: same substrings present, different order.
                arr_set = sorted( set( comp_str.split( invalid_char ) ) )
                base_arr_set = sorted( set( base_str.split( invalid_char ) ) )
                reordering_detected = arr_set == base_arr_set

                max_display = 110 if reordering_detected else 250
                message = "Differing valid characters"
                message += " (substrings reordering detected):\n" if reordering_detected else ":\n"
                message += f" {limited_display( max_display, comp_str_display )}\n"
                message += f" {limited_display( max_display, base_str_display )}\n"
                # Caret line marking each differing character position.
                message += " " + "".join(
                    [ "^" if i in differing_indices else " " for i in range( min( max_display, min_length ) ) ] ) + "\n"

        return message
[docs]
def compareStringArrays( self, path, arr, base_arr ):
"""
Compare two string datasets. Exact equality is used as the acceptance criteria.
PATH [in]: The path at which the comparison takes place.
ARR [in]: The hdf5 Dataset to compare.
BASE_ARR [in]: The hdf5 Dataset to compare against.
"""
if arr.shape != base_arr.shape or np.any( arr[ : ] != base_arr[ : ] ):
message = "String arrays differ.\n"
message += "String to compare: %s\n" % "".join( arr[ : ] )
message += "Baseline string : %s\n" % "".join( base_arr[ : ] )
self.errorMsg( path, message, True )
[docs]
    def compareData( self, path, arr, base_arr ):
        """
        Compare the numerical portion of two datasets, dispatching to the
        scalar/array comparison method appropriate for the dtypes involved.
        PATH [in]: The path at which the comparison takes place.
        ARR [in]: The hdf5 Dataset to compare.
        BASE_ARR [in]: The hdf5 Dataset to compare against.
        """
        # Get the type of comparison to do, keyed on numpy dtype.kind codes.
        np_floats = set( [ 'f', 'c' ] )
        np_ints = set( [ '?', 'b', 'B', 'i', 'u', 'm', 'M', 'V' ] )
        np_numeric = np_floats | np_ints
        np_strings = set( [ 'S', 'a', 'U' ] )
        # Mixed int/float combinations fall through to the float comparison.
        int_compare = arr.dtype.kind in np_ints and base_arr.dtype.kind in np_ints
        float_compare = not int_compare and ( arr.dtype.kind in np_numeric and base_arr.dtype.kind in np_numeric )
        string_compare = arr.dtype.kind in np_strings and base_arr.dtype.kind in np_strings

        # If the datasets have different types issue a warning.
        if arr.dtype != base_arr.dtype:
            msg = "Datasets have different types: %s, %s.\n" % ( arr.dtype, base_arr.dtype )
            self.warningMsg( path, msg )

        # Handle empty datasets.
        # NOTE(review): these branches treat shape/size of None as "empty" —
        # presumably matching how h5py reports empty datasets; TODO confirm.
        if arr.shape is None and base_arr.shape is None:
            return
        if arr.size is None and base_arr.size is None:
            return
        if arr.size == 0 and base_arr.size == 0:
            return
        elif arr.size is None and base_arr.size is not None:
            self.errorMsg( path, "File to compare has an empty dataset where the baseline's dataset is not empty.\n" )
        elif base_arr.size is None and arr.size is not None:
            self.warningMsg( path, "Baseline has an empty dataset where the file to compare's dataset is not empty.\n" )

        # If either of the datasets is a scalar convert it to an array.
        if arr.shape == ():
            arr = np.array( [ arr ] )
        if base_arr.shape == ():
            base_arr = np.array( [ base_arr ] )

        # If the datasets only contain one value call the compare scalar functions.
        if arr.size == 1 and base_arr.size == 1:
            val = arr[ : ].flat[ 0 ]
            base_val = base_arr[ : ].flat[ 0 ]
            if float_compare:
                return self.compareFloatScalars( path, val, base_val )
            elif int_compare:
                return self.compareIntScalars( path, val, base_val )
            elif string_compare:
                return self.compareStringScalars( path, val, base_val )
            else:
                return self.warningMsg( path,
                                        "Unrecognized type combination: %s %s.\n" % ( arr.dtype, base_arr.dtype ) )

        # Do the actual comparison.
        if float_compare:
            return self.compareFloatArrays( path, arr, base_arr )
        elif int_compare:
            return self.compareIntArrays( path, arr, base_arr )
        elif string_compare:
            return self.compareStringArrays( path, arr, base_arr )
        else:
            return self.warningMsg( path, "Unrecognized type combination: %s %s.\n" % ( arr.dtype, base_arr.dtype ) )
[docs]
def compareAttributes( self, path, attrs, base_attrs ):
"""
Compare two sets of attributes.
PATH [in]: The path at which the comparison takes place.
ATTRS [in]: The hdf5 AttributeManager to compare.
BASE_ATTRS [in]: The hdf5 AttributeManager to compare against.
"""
for attrName in set( list( attrs.keys() ) + list( base_attrs.keys() ) ):
if attrName not in attrs:
msg = "Attribute %s is in the baseline file but not the file to compare.\n" % attrName
self.errorMsg( path, msg )
continue
if attrName not in base_attrs:
msg = "Attribute %s is in the file to compare but not the baseline file.\n" % attrName
self.warningMsg( path, msg )
continue
attrsPath = path + ".attrs[" + attrName + "]"
self.compareData( attrsPath, attrs[ attrName ], base_attrs[ attrName ] )
[docs]
def compareDatasets( self, dset, base_dset ):
"""
Compare two datasets.
DSET [in]: The Dataset to compare.
BASE_DSET [in]: The Dataset to compare against.
"""
assert isinstance( dset, h5py.Dataset )
assert isinstance( base_dset, h5py.Dataset )
path = dset.name
self.compareAttributes( path, dset.attrs, base_dset.attrs )
self.compareData( path, dset, base_dset )
def canCompare( self, group, base_group, name ):
name_in_group = name in group
name_in_base_group = name in base_group
if not name_in_group and not name_in_base_group:
return False
elif self.isExcluded( h5PathJoin( group.name, name ) ):
return False
if not name_in_group:
msg = "Group has a child '%s' in the baseline file but not the file to compare.\n" % name
if not self.skip_missing:
self.errorMsg( base_group.name, msg )
return False
if not name_in_base_group:
msg = "Group has a child '%s' in the file to compare but not the baseline file.\n" % name
if not self.skip_missing:
self.errorMsg( group.name, msg )
return False
return True
    def compareLvArrays( self, group, base_group, other_children_to_check ):
        """
        Detect and compare LvArray-formatted groups, i.e. groups holding the
        triplet of children __dimensions__/__permutation__/__values__.
        GROUP [in]: The Group to compare.
        BASE_GROUP [in]: The Group to compare against.
        OTHER_CHILDREN_TO_CHECK [in/out]: Set of child names still to be
            compared by the caller; the LvArray triplet is removed from it here.
        Returns True if the group was handled as an LvArray, False otherwise.
        """
        if self.canCompare( group, base_group, "__dimensions__" ) and self.canCompare(
                group, base_group, "__permutation__" ) and self.canCompare( group, base_group, "__values__" ):
            # These three children are consumed here; the caller must not revisit them.
            other_children_to_check.remove( "__dimensions__" )
            other_children_to_check.remove( "__permutation__" )
            other_children_to_check.remove( "__values__" )

            dimensions = group[ "__dimensions__" ][ : ]
            base_dimensions = base_group[ "__dimensions__" ][ : ]
            if len( dimensions.shape ) != 1:
                msg = "The dimensions of an LvArray must itself be a 1D array not %s\n" % len( dimensions.shape )
                self.errorMsg( group.name, msg )

            if dimensions.shape != base_dimensions.shape or np.any( dimensions != base_dimensions ):
                msg = "Cannot compare LvArrays because they have different dimensions. Dimensions = %s, base dimensions = %s\n" % (
                    dimensions, base_dimensions )
                self.errorMsg( group.name, msg )
                return True

            permutation = group[ "__permutation__" ][ : ]
            base_permutation = base_group[ "__permutation__" ][ : ]
            if len( permutation.shape ) != 1:
                msg = "The permutation of an LvArray must itself be a 1D array not %s\n" % len( permutation.shape )
                self.errorMsg( group.name, msg )

            # A valid permutation matches the dimensions' shape and sorts to 0..N-1.
            if permutation.shape != dimensions.shape or np.any(
                    np.sort( permutation ) != np.arange( permutation.size ) ):
                msg = "LvArray in the file to compare has an invalid permutation. Dimensions = %s, Permutation = %s\n" % (
                    dimensions, permutation )
                self.errorMsg( group.name, msg )
                return True

            if base_permutation.shape != base_dimensions.shape or np.any(
                    np.sort( base_permutation ) != np.arange( base_permutation.size ) ):
                msg = "LvArray in the baseline has an invalid permutation. Dimensions = %s, Permutation = %s\n" % (
                    base_dimensions, base_permutation )
                self.errorMsg( group.name, msg )
                return True

            values = group[ "__values__" ][ : ]
            base_values = base_group[ "__values__" ][ : ]

            # Undo each array's permutation so the raw values can be compared directly.
            # NOTE: the local `errorMsg` returned by permuteArray shadows the
            # errorMsg method name within this scope.
            values, errorMsg = permuteArray( values, dimensions, permutation )
            if values is None:
                msg = "Failed to permute the LvArray: %s\n" % errorMsg
                self.errorMsg( group.name, msg )
                return True

            base_values, errorMsg = permuteArray( base_values, base_dimensions, base_permutation )
            if base_values is None:
                msg = "Failed to permute the baseline LvArray: %s\n" % errorMsg
                self.errorMsg( group.name, msg )
                return True

            self.compareData( group.name, values, base_values )
            return True

        return False
[docs]
def compareGroups( self, group, base_group ):
"""
Compare hdf5 groups.
GROUP [in]: The Group to compare.
BASE_GROUP [in]: The Group to compare against.
"""
assert ( isinstance( group, ( h5py.Group, h5py.File ) ) )
assert ( isinstance( base_group, ( h5py.Group, h5py.File ) ) )
path = group.name
# Compare the attributes in the two groups.
self.compareAttributes( path, group.attrs, base_group.attrs )
children_to_check = set( list( group.keys() ) + list( base_group.keys() ) )
self.compareLvArrays( group, base_group, children_to_check )
# Compare the sub groups and datasets.
for name in children_to_check:
if self.canCompare( group, base_group, name ):
item1 = group[ name ]
item2 = base_group[ name ]
if not isinstance( item1, type( item2 ) ):
msg = "Child %s has differing types in the file to compare and the baseline: %s, %s.\n" % (
name, type( item1 ), type( item2 ) )
self.errorMsg( path, msg )
continue
if isinstance( item1, h5py.Group ):
self.compareGroups( item1, item2 )
elif isinstance( item1, h5py.Dataset ):
self.compareDatasets( item1, item2 )
else:
self.warningMsg( path, "Child %s has unknown type: %s.\n" % ( name, type( item1 ) ) )
def findFiles( file_pattern, baseline_pattern, comparison_args ):
    """
    Resolve the target and baseline file patterns and determine the pairs of
    data files to compare.
    FILE_PATTERN [in]: Regex pattern locating the file to compare.
    BASELINE_PATTERN [in]: Regex pattern locating the baseline file.
    COMPARISON_ARGS [in/out]: Keyword arguments for FileComparison; the
        "output" entry is set here.
    Returns (output_base_path, files_to_compare) where files_to_compare is a
    list of (target, baseline) path pairs, or None if the root files differ.
    Raises ValueError if either pattern matches nothing.
    """
    # Find the matching files.
    file_path = findMaxMatchingFile( file_pattern )
    if file_path is None:
        raise ValueError( "No files found matching %s." % file_pattern )

    baseline_path = findMaxMatchingFile( baseline_pattern )
    if baseline_path is None:
        raise ValueError( "No files found matching %s." % baseline_pattern )

    # Get the output path.
    output_base_path = os.path.splitext( file_path )[ 0 ]
    output_path = output_base_path + ".restartcheck"

    # Open the output file and diff file
    files_to_compare = None
    with open( output_path, 'w' ) as output_file:
        comparison_args[ "output" ] = output_file
        # NOTE(review): writeHeader is not defined in this file's visible
        # portion — presumably provided elsewhere in the module; verify.
        writeHeader( file_pattern, file_path, baseline_pattern, baseline_path, comparison_args )

        # Check if comparing root files.
        if file_path.endswith( ".root" ) and baseline_path.endswith( ".root" ):
            # The root files themselves must match exactly, except for the
            # data-file pattern and protocol version.
            p = [ re.compile( "/file_pattern" ), re.compile( "/protocol/version" ) ]
            comp = FileComparison( file_path, baseline_path, 0.0, 0.0, p, output_file, True, False )
            if comp.filesDiffer():
                write( output_file, "The root files are different, cannot compare data files.\n" )
                return output_base_path, None
            else:
                write( output_file, "The root files are similar.\n" )

            # Get the number of files and the file patterns.
            # We know the number of files are the same from the above comparison.
            with h5py.File( file_path, "r" ) as f:
                numberOfFiles = f[ "number_of_files" ][ 0 ]
                file_data_pattern = "".join( f[ "file_pattern" ][ : ].tobytes().decode( 'ascii' )[ :-1 ] )

            with h5py.File( baseline_path, "r" ) as f:
                baseline_data_pattern = "".join( f[ "file_pattern" ][ : ].tobytes().decode( 'ascii' )[ :-1 ] )

            # Get the paths to the data files.
            files_to_compare = []
            for i in range( numberOfFiles ):
                path_to_data = os.path.join( os.path.dirname( file_path ), file_data_pattern % i )
                path_to_baseline_data = os.path.join( os.path.dirname( baseline_path ), baseline_data_pattern % i )
                files_to_compare += [ ( path_to_data, path_to_baseline_data ) ]
        else:
            files_to_compare = [ ( file_path, baseline_path ) ]

    return output_base_path, files_to_compare
def gatherOutput( output_file, output_base_path, n_files ):
    """
    Concatenate the per-file '<base>.<i>.restartcheck' reports into OUTPUT_FILE,
    echoing each line to stdout as well via write().
    OUTPUT_FILE [in/out]: The stream to append the gathered output to.
    OUTPUT_BASE_PATH [in]: Base path of the per-file reports.
    N_FILES [in]: Number of per-file reports to gather.
    """
    for index in range( n_files ):
        part_path = "%s.%d.restartcheck" % ( output_base_path, index )
        with open( part_path, "r" ) as part:
            for line in part:
                write( output_file, line )
[docs]
def findMaxMatchingFile( file_path ):
    """
    Given a path FILE_PATH where the base name of FILE_PATH is treated as a regular expression
    find and return the path of the greatest matching file/folder or None if no match is found.
    FILE_PATH [in]: The pattern to match.
    Examples:
        '.*' will return the file/folder with the greatest name in the current directory.
        'test/plot_.*\\.hdf5' will return the file with the greatest name in the ./test
        directory that begins with 'plot' and ends with '.hdf5'.
    """
    directory, base_pattern = os.path.split( file_path )
    directory = directory or "."
    if not os.path.isdir( directory ):
        return None

    matcher = re.compile( base_pattern )
    candidates = [ entry for entry in os.listdir( directory ) if matcher.match( entry ) is not None ]
    if not candidates:
        return None
    return os.path.join( directory, max( candidates ) )
[docs]
def main():
    """
    Parses the command line arguments and executes the proper comparison. Writes output to
    both stdout and a '%s.restartcheck' file where the first part is the path of the file to compare.
    Example:
        The file to compare is ./a/b/c.hdf5 the output will be a ./a/b/c.restartcheck file.
    Returns 1 (or truthy) when the files differ or could not be found, falsy otherwise.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    n_ranks = comm.Get_size()

    parser = argparse.ArgumentParser()
    parser.add_argument( "file_pattern", help="The pattern used to find the file to compare." )
    parser.add_argument( "baseline_pattern", help="The pattern used to find the baseline file." )
    parser.add_argument( "-r",
                         "--relative",
                         type=float,
                         help="The relative tolerance for floating point differences, default is %s." % RTOL_DEFAULT,
                         default=RTOL_DEFAULT )
    parser.add_argument( "-a",
                         "--absolute",
                         type=float,
                         help="The absolute tolerance for floating point differences, default is %s." % ATOL_DEFAULT,
                         default=ATOL_DEFAULT )
    parser.add_argument( "-e",
                         "--exclude",
                         action='append',
                         help="Regular expressions specifying which groups to skip, default is %s." % EXCLUDE_DEFAULT,
                         default=EXCLUDE_DEFAULT )
    parser.add_argument( "-m",
                         "--skip-missing",
                         action="store_true",
                         help="Ignore values that are missing from either the baseline or target file.",
                         default=False )
    parser.add_argument( "-w",
                         "--Werror",
                         action="store_true",
                         help="Force all warnings to be errors, default is False.",
                         default=False )
    args = parser.parse_args()

    # Check the command line arguments
    if args.relative < 0.0:
        raise ValueError( "Relative tolerance cannot be less than 0.0." )
    if args.absolute < 0.0:
        raise ValueError( "Absolute tolerance cannot be less than 0.0." )

    # Extract the command line arguments.
    file_pattern = args.file_pattern
    baseline_pattern = args.baseline_pattern
    comparison_args = {}
    comparison_args[ "rtol" ] = args.relative
    comparison_args[ "atol" ] = args.absolute
    comparison_args[ "regex_expressions" ] = list( map( re.compile, args.exclude ) )
    comparison_args[ "warnings_are_errors" ] = args.Werror
    comparison_args[ "skip_missing" ] = args.skip_missing

    # Rank 0 resolves the file pairs, then broadcasts them to all ranks.
    if rank == 0:
        output_base_path, files_to_compare = findFiles( file_pattern, baseline_pattern, comparison_args )
    else:
        output_base_path, files_to_compare = None, None
    files_to_compare = comm.bcast( files_to_compare, root=0 )
    output_base_path = comm.bcast( output_base_path, root=0 )
    # None means the root files already differed; nothing more to compare.
    if files_to_compare is None:
        return 1

    # Each rank compares a strided subset of the file pairs, writing its own
    # per-pair .restartcheck and .diff.hdf5 outputs.
    differing_files = []
    for i in range( rank, len( files_to_compare ), n_ranks ):
        output_path = "%s.%d.restartcheck" % ( output_base_path, i )
        diff_path = "%s.%d.diff.hdf5" % ( output_base_path, i )
        with open( output_path, 'w' ) as output_file, h5py.File( diff_path, "w" ) as diff_file:
            comparison_args[ "output" ] = output_file
            comparison_args[ "diff_file" ] = diff_file
            file_path, baseline_path = files_to_compare[ i ]
            logger.info( f"About to compare {file_path} and {baseline_path}" )
            if FileComparison( file_path, baseline_path, **comparison_args ).filesDiffer():
                differing_files += [ files_to_compare[ i ] ]
                output_file.write( "The files are different.\n" )
            else:
                output_file.write( "The files are similar.\n" )

    # Gather the per-rank lists of differing pairs and flatten them.
    differing_files = comm.allgather( differing_files )
    all_differing_files = []
    for file_list in differing_files:
        all_differing_files += file_list
    difference_found = len( all_differing_files ) > 0

    # Rank 0 concatenates every per-pair report into the main .restartcheck file.
    if rank == 0:
        output_path = output_base_path + ".restartcheck"
        with open( output_path, 'a' ) as output_file:
            gatherOutput( output_file, output_base_path, len( files_to_compare ) )

            if difference_found:
                write(
                    output_file, "\nCompared %d pairs of files of which %d are different.\n" %
                    ( len( files_to_compare ), len( all_differing_files ) ) )
                for file_path, base_path in all_differing_files:
                    write( output_file, "\t" + file_path + " and " + base_path + "\n" )
                return 1
            else:
                write( output_file,
                       "\nThe root files and the %d pairs of files compared are similar.\n" % len( files_to_compare ) )
    return difference_found
# Run the comparison only when executed as a script (not in interactive sessions).
if __name__ == "__main__" and not sys.flags.interactive:
    sys.exit( main() )