[Investigate] Custom MPI_Reduce operations #99

Open
PhilipVinc opened this issue Jul 19, 2021 · 0 comments
@PhilipVinc
Member

I was pleasantly surprised to see that mpi4jax automatically supports custom MPI_Reduce operations, i.e. user-defined operations created with `MPI.Op.Create`. For example:

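```python
import numpy as np
from mpi4py import MPI
import mpi4jax
import jax
import jax.numpy as jnp
from functools import partial

rank = MPI.COMM_WORLD.rank

# create the arrays to reduce: row 0 holds the rank, the other rows rank-dependent data
src = (np.arange(8) + rank * 8).reshape(4, 2)
src[0] = rank
src = jnp.array(src)

MPI.COMM_WORLD.barrier()
print("starting")
MPI.COMM_WORLD.barrier()

def myadd(xmem, ymem, dt):
    # MPI hands over raw buffers: xmem is the incoming operand, ymem is the
    # in/out buffer that must contain the combined result on return
    x = np.frombuffer(xmem, dtype=src.dtype)
    y = np.frombuffer(ymem, dtype=src.dtype)

    z = x + y

    print("Rank %d reducing %s (%s) and %s (%s), yielding %s" % (rank, x, type(x), y, type(y), z))

    y[:] = z

op = MPI.Op.Create(myadd, commute=True)

# plain mpi4py equivalent, writing into a preallocated NumPy buffer:
# dst = np.zeros_like(np.asarray(src))
# MPI.COMM_WORLD.Reduce(np.asarray(src), dst, op)

# mpi4jax returns (result, token) instead of writing into a buffer;
# only the root rank receives the reduced array
dst, _ = jax.jit(partial(mpi4jax.reduce, op=op, root=0))(src)

if MPI.COMM_WORLD.rank == 0:
    print("ANSWER: %s" % dst)

MPI.COMM_WORLD.barrier()
print("-------------------------------")
MPI.COMM_WORLD.barrier()

# allreduce: every rank receives the reduced array
dst, _ = jax.jit(partial(mpi4jax.allreduce, op=op))(src)

if MPI.COMM_WORLD.rank == 0:
    print("ANSWER: %s" % dst)
```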

However, this works only because mpi4py calls back into the Python runtime every time MPI invokes the operation, which slows down execution.
Ideally, I'd like to use Numba/CFFI to define custom operations that don't call back into Python.
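For reference, the contract that `myadd` (or any compiled replacement) has to satisfy can be exercised without MPI at all: the operation receives an incoming buffer and an in/out buffer, and must leave the combined result in the in/out buffer. The sketch below is a hypothetical pure-NumPy simulation, not mpi4py or mpi4jax API. `simulate_reduce` and `make_src` are made-up helpers, the dtype is assumed to be `int32` (what `jnp.array` produces for this data with JAX's default 32-bit mode), and the datatype argument is ignored.

```python
import numpy as np

def myadd(xmem, ymem, dt=None):
    # Mirrors the mpi4py user-op contract: xmem is the incoming operand,
    # ymem is the in/out buffer, and the result must be written into ymem.
    # dt (the MPI datatype) is ignored in this simulation.
    x = np.frombuffer(xmem, dtype=np.int32)
    y = np.frombuffer(ymem, dtype=np.int32)
    y[:] = x + y

def simulate_reduce(contributions, user_op):
    # Hypothetical helper: folds one array per simulated "rank" left to right.
    # Real MPI may combine in a different order or along a tree, which is why
    # user ops must be associative (and commutative if created with commute=True).
    acc = np.array(contributions[0], dtype=np.int32).ravel()  # writable in/out buffer
    for contrib in contributions[1:]:
        incoming = np.ascontiguousarray(contrib, dtype=np.int32).ravel()
        user_op(incoming, acc, None)
    return acc.reshape(np.shape(contributions[0]))

def make_src(rank):
    # Hypothetical helper: same data layout as the example above
    src = (np.arange(8) + rank * 8).reshape(4, 2)
    src[0] = rank
    return src

result = simulate_reduce([make_src(r) for r in range(3)], myadd)
print(result)
```

Because the operation only ever sees flat buffers, the fold is independent of the array shape; a compiled replacement would have to follow the same in/out convention.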
