An array class for holding an MPI-distributed array.
This example performs a transform from time-freq to lag-m space. This involves Fourier transforming each of these two axes of the distributed array:
import numpy as np
from mpi4py import MPI
from caput.mpiarray import MPIArray
nfreq = 32
nprod = 2
ntime = 32
# Initialise array with (nfreq, nprod, ntime) global shape
darr1 = MPIArray((nfreq, nprod, ntime), dtype=np.float64)
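# Load data into the parallel array; load_freq_data stands in for a
# user-supplied loader. enumerate yields (local index, global index) pairs.
for lfi, fi in darr1.enumerate(axis=0):
    darr1[lfi] = load_freq_data(fi)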
# Perform m-transform (i.e. FFT along the time axis)
darr2 = MPIArray.wrap(np.fft.fft(darr1, axis=2), axis=0)
# Redistribute to get all frequencies onto each process. This performs the
# global transpose using MPI to make axis=1 the distributed axis, and make
# axis=0 completely local.
darr3 = darr2.redistribute(axis=1)
# Perform the lag transform along the frequency direction.
darr4 = MPIArray.wrap(np.fft.irfft(darr3, axis=0), axis=1)
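The data flow of the example above can be followed in a single process, without MPI. The sketch below is only an illustration: it simulates ranks by splitting a NumPy array with `np.array_split`, which is an assumption made for demonstration and not how caput partitions or communicates data. The number of simulated ranks (`nproc`) and the random input are likewise hypothetical. It applies the FFT along the last axis, which is time in the (nfreq, nprod, ntime) layout.

```python
import numpy as np

nfreq, nprod, ntime = 32, 2, 32
nproc = 2  # hypothetical number of ranks, simulated here in one process

rng = np.random.default_rng(0)
data = rng.standard_normal((nfreq, nprod, ntime))

# Distribute over axis 0: each simulated rank holds a block of frequencies.
freq_blocks = np.array_split(data, nproc, axis=0)

# The FFT over time only needs local data, because time is not distributed.
m_blocks = [np.fft.fft(block, axis=2) for block in freq_blocks]

# Stand-in for redistribute(axis=1): reassemble the global array, then
# split it over axis 1 so that every block holds all frequencies.
m_global = np.concatenate(m_blocks, axis=0)
prod_blocks = np.array_split(m_global, nproc, axis=1)

# The inverse real FFT over frequency is now local to each block.
lag_blocks = [np.fft.irfft(block, axis=0) for block in prod_blocks]
lag_m = np.concatenate(lag_blocks, axis=1)

# Reference: the same two transforms applied to the undistributed array.
reference = np.fft.irfft(np.fft.fft(data, axis=2), axis=0)

print(lag_m.shape)
print(np.allclose(lag_m, reference))
```

The block-wise result matches the reference because each transform acts along an axis that is fully local at the time it is applied. This is why the redistribution step sits between the two transforms.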
The MPIArray also supports slicing with the global index using the :attr:`MPIArray.global_slice` property. This can be used for both fetching and assignment with global indices, and supports NumPy's basic slicing notation.
Its behaviour changes depending on the exact slice it gets:
It’s important to note that it never communicates data between ranks. It only ever operates on data held on the current rank.
Here is an example of this in action:
import numpy as np
from caput import mpiarray, mpiutil
arr = mpiarray.MPIArray((mpiutil.size, 3), dtype=np.float64)
arr[:] = 0.0
# Print the local section on each rank in turn
for ri in range(mpiutil.size):
    if ri == mpiutil.rank:
        print(ri, arr)
    mpiutil.barrier()
# Use a global index to assign to the array
arr.global_slice[3] = 17
# Fetch a view of the whole array with a full slice
arr2 = arr.global_slice[:, 2]
# This should be the third column of the array
for ri in range(mpiutil.size):
    if ri == mpiutil.rank:
        print(ri, arr2)
    mpiutil.barrier()
# Fetch a view of the whole array with a partial slice
arr3 = arr.global_slice[:2, 2]
# Ranks that hold none of the requested rows get None
# (the final two ranks when run on four processes)
for ri in range(mpiutil.size):
    if ri == mpiutil.rank:
        print(ri, arr3)
    mpiutil.barrier()
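Because `global_slice` never communicates, each rank can only return the part of a global slice that overlaps its own block. The pure-Python sketch below illustrates that bookkeeping. The helper names `local_bounds` and `localise_slice` are hypothetical and are not part of caput. The partitioning rule (blocks as even as possible, with earlier ranks taking any remainder) is an assumption made for the illustration.

```python
def local_bounds(n, size, rank):
    """Return (start, end) of the block of a length-n axis held by `rank`.

    Assumes blocks are as even as possible, with the remainder spread
    over the lowest-numbered ranks (an assumption for this sketch).
    """
    base, rem = divmod(n, size)
    start = rank * base + min(rank, rem)
    end = start + base + (1 if rank < rem else 0)
    return start, end


def localise_slice(sl, n, size, rank):
    """Map a global slice on the distributed axis to a local slice.

    Returns None when the slice does not overlap the block held by
    `rank`. Only unit-step slices are handled in this sketch.
    """
    start, stop, step = sl.indices(n)
    if step != 1:
        raise NotImplementedError("only unit-step slices are handled")
    lo, hi = local_bounds(n, size, rank)
    first, last = max(start, lo), min(stop, hi)
    if first >= last:
        return None
    return slice(first - lo, last - lo)


# A global axis of length 4 spread over 4 ranks, sliced with [:2]
for rank in range(4):
    print(rank, localise_slice(slice(None, 2), n=4, size=4, rank=rank))
```

With four ranks holding one row each, only the first two ranks get a local slice and the remaining ranks get `None`, mirroring the partial-slice case in the example above.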