Source code for caput.mpiarray

"""
An array class for containing an MPI distributed array.

.. currentmodule:: caput.mpiarray

Classes
=======

.. autosummary::
    :toctree: generated/

    MPIArray

Examples
========

This example performs a transform from time-freq to lag-m space. This involves
Fourier transforming each of these two axes of the distributed array::

    import numpy as np
    from mpi4py import MPI

    from caput.mpiarray import MPIArray

    nfreq = 32
    nprod = 2
    ntime = 32

    # Initialise array with (nfreq, nprod, ntime) global shape
    darr1 = MPIArray((nfreq, nprod, ntime), dtype=np.float64)

    # Load in data into parallel array
    for lfi, gfi in darr1.enumerate(axis=0):
        darr1[lfi] = load_freq_data(gfi)

    # Perform m-transform (i.e. FFT)
    darr2 = MPIArray.wrap(np.fft.fft(darr1, axis=1), axis=0)

    # Redistribute to get all frequencies onto each process, this performs the
    # global transpose using MPI to make axis=1 the distributed axis, and make
    # axis=0 completely local.
    darr3 = darr2.redistribute(axis=1)

    # Perform the lag transform on the frequency direction.
    darr4 = MPIArray.wrap(np.fft.irfft(darr3, axis=0), axis=1)

Global Slicing
==============

The :class:`MPIArray` also supports slicing with the global index using the
:attr:`MPIArray.global_slice` property. This can be used for both fetching
and assignment with global indices, supporting the basic slicing notation of
`numpy`.

Its behaviour changes depending on the exact slice it gets:

- A full slice (`:`) along the parallel axis returns an :class:`MPIArray` on
  fetching, and accepts an :class:`MPIArray` on assignment.
- A partial slice (e.g. `:2`) returns and accepts a numpy array on the rank holding
  the data, and :obj:`None` on other ranks.

It's important to note that it never communicates data between ranks. It only
ever operates on data held on the current rank.

Example
-------

Here is an example of this in action::

    import numpy as np
    from caput import mpiarray, mpiutil

    arr = mpiarray.MPIArray((mpiutil.size, 3), dtype=np.float64)
    arr[:] = 0.0

    for ri in range(mpiutil.size):
        if ri == mpiutil.rank:
            print ri, arr
        mpiutil.barrier()

    # Use a global index to assign to the array
    arr.global_slice[3] = 17

    # Fetch a view of the whole array with a full slice
    arr2 = arr.global_slice[:, 2]

    # This should be the third column of the array
    for ri in range(mpiutil.size):
        if ri == mpiutil.rank:
            print ri, arr2
        mpiutil.barrier()

    # Fetch a view of the whole array with a partial slice
    arr3 = arr.global_slice[:2, 2]

    # The final two ranks should be None
    for ri in range(mpiutil.size):
        if ri == mpiutil.rank:
            print ri, arr3
        mpiutil.barrier()

"""

import os
import time
import numpy as np

from caput import mpiutil


class _global_resolver(object):
    """Private helper implementing global-index slicing for an `MPIArray`.

    An instance wraps a distributed array and translates numpy basic-slicing
    keys expressed in *global* coordinates into keys on the local section.
    It never communicates between ranks; it only touches local data.

    Parameters
    ----------
    array : MPIArray
        The distributed array to index. Its ``axis``, ``local_offset`` and
        ``global_shape`` attributes are read once at construction.
    """

    def __init__(self, array):

        self.array = array
        self.axis = array.axis
        self.offset = array.local_offset[self.axis]
        self.length = array.global_shape[self.axis]

    def _resolve_slice(self, slobj):
        """Translate a global basic-slice key into a local index tuple.

        Parameters
        ----------
        slobj : int, slice, Ellipsis or sequence of these
            Key expressed in global coordinates.

        Returns
        -------
        local_key : tuple
            One entry per axis. The entry for the distributed axis is a
            local integer index, a local slice, or `None` if nothing selected
            by the key lives on this rank.
        fullslice : bool
            `True` only if the distributed axis was selected with a bare
            ``:`` (no start, stop or step), or with something that is
            neither an integer nor a slice.
        """

        ndim = self.array.ndim
        local_length = self.array.shape[self.axis]

        # Promote a lone index (including numpy integers and a bare
        # Ellipsis) to a tuple; anything else is treated as a sequence.
        if isinstance(slobj, (int, np.integer, slice)) or slobj is Ellipsis:
            slobj = (slobj,)
        else:
            slobj = tuple(slobj)

        # Add an ellipsis if the key is too short. Identity is used rather
        # than ``in`` so array-valued entries are never compared with ``==``.
        has_ellipsis = any(sl is Ellipsis for sl in slobj)
        if len(slobj) < ndim and not has_ellipsis:
            slobj = slobj + (Ellipsis,)

        # Expand the ellipsis into the right number of full slices
        n_fill = ndim - len(slobj) + 1
        slice_list = []
        for sl in slobj:
            if sl is Ellipsis:
                slice_list.extend(slice(None, None) for _ in range(n_fill))
            else:
                slice_list.append(sl)

        fullslice = True
        selector = slice_list[self.axis]

        # Integer index on the parallel axis: compute the local index on the
        # containing rank, and set None on all others.
        if isinstance(selector, (int, np.integer)):
            index = int(selector)
            if index < 0:
                # Resolve negative indices against the *global* length,
                # consistent with how slice bounds are treated below.
                index += self.length
            index -= self.offset
            slice_list[self.axis] = index if 0 <= index < local_length else None
            fullslice = False

        # Slice on the parallel axis: shift the bounds into local
        # coordinates. Any explicit start/stop/step means this is no longer
        # a complete slice.
        elif isinstance(selector, slice):
            start, stop, step = selector.start, selector.stop, selector.step
            fullslice = start is None and stop is None and step is None

            if start is None:
                global_start = 0
                local_start = 0
            else:
                global_start = start if start >= 0 else start + self.length
                local_start = global_start - self.offset

            if stop is None:
                local_stop = local_length
            else:
                global_stop = stop if stop >= 0 else stop + self.length
                local_stop = global_stop - self.offset

            if step is None or step > 0:
                if step is not None and step > 1:
                    # Align to the stride: the first selected element on
                    # this rank is the smallest local index >= 0 congruent
                    # to the global start modulo the step.
                    first = max(global_start, 0) - self.offset
                    local_start = first if first >= 0 else first % step

                lo = max(local_start, 0)
                hi = min(local_stop, local_length)

                # An empty local range means no data on this rank
                slice_list[self.axis] = slice(lo, hi, step) if lo < hi else None
            else:
                # NOTE: non-positive steps keep the historical behaviour
                # unchanged (no stride alignment across ranks).
                if local_start >= local_length or local_stop < 0:
                    slice_list[self.axis] = None
                else:
                    local_start = max(min(local_start, local_length), 0)
                    local_stop = max(min(local_stop, local_length), 0)
                    slice_list[self.axis] = slice(local_start, local_stop, step)

        return tuple(slice_list), fullslice

    def __getitem__(self, slobj):
        """Fetch data using global indices.

        Returns an `MPIArray` view for a full slice along the distributed
        axis; otherwise a plain numpy view on ranks holding selected data and
        `None` on ranks that hold none.
        """

        # Resolve the slice object
        slobj, is_fullslice = self._resolve_slice(slobj)

        # If not a full slice, return a numpy array (or None)
        if not is_fullslice:

            # A None on the parallel axis means there is no data on this rank
            if slobj[self.axis] is None:
                return None
            return self.array[slobj].view(np.ndarray)

        # Fix up slobj for axes where there is no data. This must stay a
        # tuple: numpy does not accept a list of slices as a basic index.
        slobj = tuple(slice(None, None, None) if sl is None else sl for sl in slobj)

        # Return an MPIArray view
        arr = self.array[slobj]

        # Figure out which is the distributed axis after the slicing, by
        # removing axes indexed with plain integers from the mapping
        kept_axes = [index for index, sl in enumerate(slobj)
                     if not isinstance(sl, (int, np.integer))]
        dist_axis = kept_axes.index(self.axis)

        return MPIArray.wrap(arr, axis=dist_axis, comm=self.array._comm)

    def __setitem__(self, slobj, value):
        """Assign using global indices; a no-op on ranks holding no selected data."""

        slobj, is_fullslice = self._resolve_slice(slobj)

        if slobj[self.axis] is None:
            return
        self.array[slobj] = value


class MPIArray(np.ndarray):
    """A numpy array like object which is distributed across multiple processes.

    Parameters
    ----------
    global_shape : tuple
        The global array shape. The returned array will be distributed across
        the specified index.
    axis : integer, optional
        The dimension to distribute the array across.
    comm : MPI.Comm, optional
        Communicator to distribute over. If `None` (default),
        `mpiutil.world` is used.

    Attributes
    ----------
    global_shape : tuple
        Global array shape.
    local_shape : tuple
        Shape of local section.
    axis : integer
        Axis we are distributed over.
    local_offset : tuple
        Offset into global array. This is equivalent to the global-index of
        the [0, 0, ...] element of the local section.
    local_array : np.ndarray
        The view of the local numpy array.
    global_slice : object
        Return an objects that presents a view of the array with global
        slicing.

    Methods
    -------
    wrap
    redistribute
    enumerate
    from_hdf5
    to_hdf5
    transpose
    reshape
    """

    @property
    def global_shape(self):
        """Global array shape."""
        return self._global_shape

    @property
    def axis(self):
        """Axis the array is distributed over."""
        return self._axis

    @property
    def local_shape(self):
        """Shape of the local section."""
        return self._local_shape

    @property
    def local_offset(self):
        """Offset of the local section into the global array."""
        return self._local_offset

    @property
    def local_array(self):
        """Plain `np.ndarray` view of the local section."""
        return self.view(np.ndarray)

    @property
    def comm(self):
        """Communicator the array is distributed over."""
        return self._comm

    def __new__(cls, global_shape, axis=0, comm=None, *args, **kwargs):
        """Allocate the local section of a distributed array.

        Extra positional and keyword arguments are forwarded to
        `np.ndarray.__new__` (e.g. ``dtype``).
        """

        if comm is None:
            comm = mpiutil.world

        # Determine local section of distributed axis
        local_num, local_start, local_end = mpiutil.split_local(global_shape[axis], comm=comm)

        # Figure out the local shape and offset
        lshape = list(global_shape)
        lshape[axis] = local_num

        loffset = [0] * len(global_shape)
        loffset[axis] = local_start

        # Create array
        arr = np.ndarray.__new__(cls, lshape, *args, **kwargs)

        # Set attributes of class. The global shape is stored as a tuple so
        # it has the same type as the one set by `wrap`.
        arr._global_shape = tuple(global_shape)
        arr._axis = axis
        arr._local_shape = tuple(lshape)
        arr._local_offset = tuple(loffset)
        arr._comm = comm

        return arr

    @property
    def global_slice(self):
        """Object supporting fetch/assignment with global indices."""
        return _global_resolver(self)

    @classmethod
    def wrap(cls, array, axis, comm=None):
        """Turn a set of numpy arrays into a distributed MPIArray object.

        This is needed for functions such as `np.fft.fft` which always return
        an `np.ndarray`.

        Parameters
        ----------
        array : np.ndarray
            Array to wrap.
        axis : integer
            Axis over which the array is distributed. The lengths are checked
            to try and ensure this is correct.
        comm : MPI.Comm, optional
            The communicator over which the array is distributed. If `None`
            (default), use `MPI.COMM_WORLD`.

        Returns
        -------
        dist_array : MPIArray
            An MPIArray view of the input.

        Raises
        ------
        Exception
            If the local length along `axis` on any rank does not match the
            layout given by `mpiutil.split_local`.
        """

        if comm is None:
            comm = mpiutil.world

        # Get axis length, both locally, and globally
        axlen = array.shape[axis]
        totallen = mpiutil.allreduce(axlen, comm=comm)

        # Figure out what the distributed layout should be
        local_num, local_start, local_end = mpiutil.split_local(totallen, comm=comm)

        # Check the local layout is consistent with what we expect, and send
        # result to all ranks
        layout_issue = mpiutil.allreduce(axlen != local_num, op=mpiutil.MAX, comm=comm)

        if layout_issue:
            raise Exception("Cannot wrap, distributed axis local length is incorrect.")

        # Set shape and offset
        lshape = array.shape
        global_shape = list(lshape)
        global_shape[axis] = totallen

        loffset = [0] * len(lshape)
        loffset[axis] = local_start

        # Setup attributes of class
        dist_arr = array.view(cls)
        dist_arr._global_shape = tuple(global_shape)
        dist_arr._axis = axis
        dist_arr._local_shape = tuple(lshape)
        dist_arr._local_offset = tuple(loffset)
        dist_arr._comm = comm

        return dist_arr

    def redistribute(self, axis):
        """Change the axis that the array is distributed over.

        Parameters
        ----------
        axis : integer
            Axis to distribute over.

        Returns
        -------
        array : MPIArray
            A new copy of the array distributed over the specified axis. Note
            that the local section will have changed. If `axis` is already
            the distributed axis, there is no communicator, or the dtype is
            not known to `mpiutil.typemap`, `self` is returned instead.
        """

        # Check to see if this is the current distributed axis
        if self.axis == axis or self.comm is None:
            return self

        # Test to see if the datatype is one understood by MPI, this can
        # probably be fixed up at some point by creating a datatype of the
        # right number of bytes
        try:
            mpiutil.typemap(self.dtype)
        except KeyError:
            if self.comm is None or self.comm.rank == 0:
                import warnings
                warnings.warn('Cannot redistribute array of compound datatypes. Sorry!!')
            return self

        # Construct the list of the axes to swap around
        axlist_f = list(range(len(self.shape)))

        # Remove the axes we are going to swap around
        axlist_f.remove(self.axis)
        axlist_f.remove(axis)

        # Move the current dist axis to the front, and the new to the end
        axlist_f.insert(0, self.axis)
        axlist_f.append(axis)

        # Perform a local transpose on the array to get the axes in the
        # correct order
        trans_arr = self.view(np.ndarray).transpose(axlist_f).copy()

        # Perform the global transpose
        tmp_gshape = (self.global_shape[self.axis],) + trans_arr.shape[1:]
        trans_arr = mpiutil.transpose_blocks(trans_arr, tmp_gshape, comm=self.comm)

        axlist_b = list(range(len(self.shape)))
        axlist_b.pop(0)
        last = axlist_b.pop(-1)
        if self.axis < axis:  # This has to awkwardly depend on the order of the axes
            axlist_b.insert(self.axis, 0)
            axlist_b.insert(axis, last)
        else:
            axlist_b.insert(axis, last)
            axlist_b.insert(self.axis, 0)

        # Perform the local transpose to get the axes back in the correct order
        trans_arr = trans_arr.transpose(axlist_b)

        # Create a new MPIArray object out of the data
        dist_arr = MPIArray(self.global_shape, axis=axis, dtype=self.dtype, comm=self.comm)
        dist_arr[:] = trans_arr

        return dist_arr

    def enumerate(self, axis):
        """Helper for enumerating over a given axis.

        Parameters
        ----------
        axis : integer
            Which axis to enumerate over.

        Returns
        -------
        iterator : (local_index, global_index)
            An enumerator which returns the local index into the array *and*
            the global index it corresponds to.
        """
        start = self.local_offset[axis]
        end = start + self.local_shape[axis]

        return enumerate(range(start, end))

    @classmethod
    def from_hdf5(cls, f, dataset, comm=None):
        """Read MPIArray from an HDF5 dataset in parallel.

        Parameters
        ----------
        f : filename, or `h5py.File` object
            File to read dataset from.
        dataset : string
            Name of dataset to read from. Must exist.
        comm : MPI.Comm, optional
            MPI communicator to distribute over. If `None` (default), the
            array constructor falls back to `mpiutil.world`.

        Returns
        -------
        array : MPIArray
            The dataset, distributed over axis 0.

        Raises
        ------
        Exception
            If `f` is neither a filename nor an `h5py.File`.
        """
        import h5py

        if isinstance(f, str):
            fh = h5py.File(f, 'r')
            to_close = True
        elif isinstance(f, h5py.File):
            fh = f
            to_close = False
        else:
            raise Exception("Did not receive a h5py.File or filename")

        # Make sure a file we opened ourselves is closed even if the read fails
        try:
            dset = fh[dataset]
            dist_arr = cls(dset.shape, axis=0, comm=comm, dtype=dset.dtype)

            start = dist_arr.local_offset[0]
            end = start + dist_arr.local_shape[0]

            dist_arr[:] = dset[start:end]
        finally:
            if to_close:
                fh.close()

        return dist_arr

    def to_hdf5(self, filename, dataset, create=False):
        """Parallel write into a contiguous HDF5 dataset.

        Parameters
        ----------
        filename : str
            File to write dataset into.
        dataset : string
            Name of dataset to write into. Should not exist.
        create : bool, optional
            If `True` the file is opened in mode ``'a'`` when creating the
            dataset; otherwise mode ``'r+'`` is used.

        Raises
        ------
        Exception
            Raised on the creating rank if `dataset` is already in the file.
        """

        # Naive non-parallel implementation to start
        import h5py

        comm = self.comm

        # Only rank 0 (or the sole process when there is no communicator)
        # creates the dataset
        if comm is None or comm.rank == 0:

            with h5py.File(filename, 'a' if create else 'r+') as fh:
                if dataset in fh:
                    raise Exception("Dataset should not exist.")

                fh.create_dataset(dataset, self.global_shape, dtype=self.dtype)
                fh[dataset][:] = np.array(0.0).astype(self.dtype)

        # wait until all processes see the created file
        while not os.path.exists(filename):
            time.sleep(1)

        mpiutil.barrier(comm=comm)

        # The write below slices the dataset along axis 0
        dist_arr = self if self.axis == 0 else self.redistribute(axis=0)

        size = 1 if comm is None else comm.size
        rank = 0 if comm is None else comm.rank

        # Ranks take turns writing, with a barrier after each turn
        for ri in range(size):

            if ri == rank:
                with h5py.File(filename, 'r+') as fh:

                    start = dist_arr.local_offset[0]
                    end = start + dist_arr.local_shape[0]

                    fh[dataset][start:end] = dist_arr

            mpiutil.barrier(comm=comm)

    def transpose(self, axes):
        """Transpose the array axes.

        Parameters
        ----------
        axes : tuple
            Tuple of axes permutations. Negative axes are accepted, and
            `None` reverses the axis order as in numpy.

        Returns
        -------
        array : MPIArray
            Transposed MPIArray as a view of the original data.
        """
        # Let numpy validate the permutation before we use it ourselves
        tdata = np.ndarray.transpose(self, axes)

        if axes is None:
            axes = list(range(self.ndim))[::-1]
        else:
            # Normalise negative axes so the distributed axis can be located
            axes = [ax % self.ndim for ax in axes]

        tdata._global_shape = tuple([self.global_shape[ax] for ax in axes])
        tdata._local_shape = tuple([self.local_shape[ax] for ax in axes])
        tdata._local_offset = tuple([self.local_offset[ax] for ax in axes])

        tdata._axis = axes.index(self.axis)
        tdata._comm = self._comm

        return tdata

    def reshape(self, *shape):
        """Reshape the array.

        Must not attempt to reshape the distributed axis. That axis must be
        given an input length `None`.

        Parameters
        ----------
        shape : tuple
            Tuple of axis lengths. The distributed must be given `None`.

        Returns
        -------
        array : MPIArray
            Reshaped MPIArray as a view of the original data.

        Raises
        ------
        ValueError
            If no entry of `shape` is `None`.
        Exception
            If the requested local shape has a different number of elements
            from the current local section.
        """

        # Accept both reshape(a, b, c) and reshape((a, b, c))
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])

        # Find which axis is distributed
        list_shape = list(shape)
        new_axis = list_shape.index(None)

        # Fill in the missing value
        local_shape = list_shape[:]
        global_shape = list_shape[:]
        local_offset = [0] * len(list_shape)
        local_shape[new_axis] = self.local_shape[self.axis]
        global_shape[new_axis] = self.global_shape[self.axis]
        local_offset[new_axis] = self.local_offset[self.axis]

        # Check that the array sizes are compatible
        if np.prod(local_shape) != np.prod(self.local_shape):
            raise Exception("Dataset shapes incompatible.")

        rdata = np.ndarray.reshape(self, local_shape)

        rdata._axis = new_axis
        rdata._comm = self._comm
        rdata._local_shape = tuple(local_shape)
        rdata._global_shape = tuple(global_shape)
        rdata._local_offset = tuple(local_offset)

        return rdata

    def copy(self):
        """Return a copy of the MPIArray.

        Returns
        -------
        arr_copy : MPIArray
        """
        return MPIArray.wrap(self.view(np.ndarray).copy(), axis=self.axis, comm=self.comm)