# Source code for geos.hdf5_wrapper.wrapper

import h5py  # type: ignore[import]
import numpy as np
from numpy.core.defchararray import encode, decode
from typing_extensions import Self, Union, Dict, Any, Iterable, Optional, Tuple

# NOTE: Ideally the Any members below would be narrowed to concrete value types
#       (str, float, int, np.ndarray, ...). With this heterogeneous pattern that
#       currently breaks mypy's indexing checks, so Any is used instead.

# Plain-dict form of an hdf5 hierarchy (as produced by hdf5_wrapper.get_copy):
# group names map to nested dicts, dataset names map to their stored values.
nested_dict_type = Dict[ str, Any ]

# Values returned by hdf5_wrapper.__getitem__: a nested wrapper for groups,
# otherwise the (decoded) stored value.
hdf5_get_types = Union[ 'hdf5_wrapper', Any ]

# Values accepted by hdf5_wrapper.__setitem__: a wrapper or dict (written as a
# group, recursively), or anything else (written as a dataset).
hdf5_set_types = Union[ 'hdf5_wrapper', nested_dict_type, Any ]


class hdf5_wrapper():
    """A class for reading/writing hdf5 files, which behaves similar to a native dict."""

    # h5py file modes that permit modifying the database ('r' is the only read-only mode).
    _write_modes: Tuple[ str, ...] = ( 'w', 'a', 'r+', 'w-', 'x' )

    def __init__( self: Self, fname: str = '', target: Optional[ h5py.File ] = None, mode: str = 'r' ) -> None:
        """Initialize the hdf5_wrapper class.

        If the fname is supplied (either by a positional or keyword argument),
        the wrapper will open a hdf5 database from the filesystem.
        The recommended options for the mode flag include 'r' for read-only
        and 'a' for read/write access.
        If write mode is enabled, and the fname does not point
        to an existing file, a new database will be created.

        If the target is supplied, then a new instance of the wrapper will
        be created using an existing database handle.

        Args:
            fname (str): the filename of a new or existing hdf5 database. Defaults to ''.
            target (h5py.File or h5py.Group): the handle of an existing hdf5 file or group. Defaults to None.
            mode (str): the read/write behavior of the database. Defaults to 'r'.
        """
        self.mode: str = mode
        # Holds an h5py.File for a root wrapper, or an h5py.Group for the nested
        # wrappers created by __getitem__.
        self.target: h5py.File = target
        if fname:
            self.target = h5py.File( fname, self.mode )

    def _is_writable( self: Self ) -> bool:
        """Return True if the wrapper's mode allows modifying the database."""
        return self.mode in self._write_modes

    def __contains__( self: Self, k: str ) -> bool:
        """Check whether an entry named k exists at the current level.

        Unlike __getitem__, this never creates a group, even in write mode.

        Args:
            k (str): name of target group or array

        Returns:
            bool: True if the entry exists
        """
        return k in self.target

    def __getitem__( self: Self, k: str ) -> hdf5_get_types:
        """Get a target from the database.

        If the target is not present in the datastructure and the
        database is open in a writable mode, the wrapper will create a
        new group and return an hdf5_wrapper. Otherwise it will throw an error.

        Args:
            k (str): name of target group or array

        Returns:
            hdf5_wrapper/np.ndarray: a wrapper for groups, otherwise the dataset contents

        Raises:
            ValueError: if the entry is missing and the database is read-only
        """
        if k not in self.target:
            if self._is_writable():
                self.target.create_group( k )
            else:
                raise ValueError( 'Entry does not exist in database: %s' % ( k ) )

        tmp = self.target[ k ]
        if isinstance( tmp, h5py.Group ):
            return hdf5_wrapper( target=tmp, mode=self.mode )

        if isinstance( tmp, h5py.Dataset ):
            value = np.array( tmp )
            # String data is encoded to bytes on write (see __setitem__); decode it back
            if value.dtype.kind in ( 'S', 'U', 'O' ):
                value = decode( value )
            # Convert any 0-dimensional arrays to scalars
            if not value.shape:
                value = value[ () ]
            return value

        return tmp

    def __setitem__( self: Self, k: str, value: hdf5_set_types ) -> None:
        """Write an object to the database if a writable mode is enabled.

        Dicts and hdf5_wrapper objects are written recursively as groups
        (merged into an existing group of the same name); any other value
        replaces the entry with a dataset built from np.array(value).

        Args:
            k (str): the name of the object
            value (dict, hdf5_wrapper, np.ndarray, float, int, str): the object to be written

        Raises:
            ValueError: if the database was opened in read-only mode
        """
        if not self._is_writable():
            raise ValueError(
                'Cannot write to an hdf5 opened in read-only mode! This can be changed by overriding the default mode argument for the wrapper.'
            )

        if isinstance( value, ( dict, hdf5_wrapper ) ):
            # A dataset cannot hold children, so replace it with a group
            if k in self.target and not isinstance( self.target[ k ], h5py.Group ):
                del self.target[ k ]
            # Recursively add groups and their children
            if k not in self.target:
                self.target.create_group( k )
            new_group = self[ k ]
            for kb, x in value.items():
                new_group[ kb ] = x
        else:
            # Delete the old copy if necessary
            if k in self.target:
                del self.target[ k ]
            # Add everything else as an ndarray; strings are stored as bytes
            tmp = np.array( value )
            if tmp.dtype.kind in ( 'S', 'U', 'O' ):
                tmp = encode( tmp )
            self.target[ k ] = tmp

    def keys( self: Self ) -> Iterable[ str ]:
        """Get a list of groups and arrays located at the current level.

        Returns:
            list: a list of key names pointing to objects at the current level

        Raises:
            ValueError: if the target is not a group
        """
        if isinstance( self.target, h5py.Group ):
            return list( self.target )
        raise ValueError( 'Object not a group!' )

    def values( self: Self ) -> Iterable[ hdf5_get_types ]:
        """Get a list of values located on the current level.

        Returns:
            list: the value of every entry, in the same order as keys()
        """
        return [ self[ k ] for k in self.keys() ]

    def items( self: Self ) -> Iterable[ Tuple[ str, hdf5_get_types ] ]:
        """Get the (key, value) pairs located at the current level.

        Returns:
            Iterable[ Tuple[ str, hdf5_get_types ] ]: iterator of (key, value) pairs
        """
        names = self.keys()
        return zip( names, [ self[ k ] for k in names ] )

    def __enter__( self: Self ) -> Self:
        """Enter a runtime context (``with`` statement).

        Returns:
            hdf5_wrapper: the wrapper itself
        """
        return self

    def __exit__( self: Self,
                  exc_type: Optional[ type ] = None,
                  exc_value: Optional[ BaseException ] = None,
                  traceback: Any = None ) -> None:
        """Exit the runtime context, closing the database.

        Exceptions raised inside the ``with`` block are not suppressed.

        Args:
            exc_type: type of the exception raised in the block, if any
            exc_value: the exception raised in the block, if any
            traceback: traceback of the exception, if any
        """
        self.close()

    def __del__( self: Self ) -> None:
        """Close the database on wrapper deletion (best effort)."""
        try:
            self.close()
        except Exception:  # noqa: BLE001 - a finalizer must never raise
            pass

    def close( self: Self ) -> None:
        """Close the database if the target is an h5py.File.

        Wrappers around sub-groups are left untouched.
        """
        if isinstance( self.target, h5py.File ):
            self.target.close()

    def get_copy( self: Self ) -> nested_dict_type:
        """Copy the entire database into memory.

        Returns:
            dict: a dictionary holding the database contents (groups become nested dicts)
        """
        result: nested_dict_type = {}
        for k in self.keys():
            tmp = self[ k ]
            result[ k ] = tmp.get_copy() if isinstance( tmp, hdf5_wrapper ) else tmp
        return result

    def copy( self: Self ) -> nested_dict_type:
        """Copy the entire database into memory.

        Returns:
            dict: a dictionary holding the database contents
        """
        return self.get_copy()

    def insert( self: Self, x: Union[ nested_dict_type, 'hdf5_wrapper' ] ) -> None:
        """Insert the contents of the target object to the current location.

        Args:
            x (dict, hdf5_wrapper): the dictionary to insert
        """
        for k, v in x.items():
            self[ k ] = v