jax2torch

Use Jax functions in Pytorch with DLPack, as outlined in a gist by @mattjj. This repository was created to make this differentiable alignment work interoperable with Pytorch projects.
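
Under the hood, the conversion is a zero-copy round trip through DLPack capsules; the library additionally has to wire gradients through Torch's autograd (the gist does this with torch.autograd.Function together with jax.vjp), which the sketch below omits. A minimal sketch of just the tensor exchange, assuming a Jax version that ships the jax.dlpack module:

import jax.dlpack
import torch.utils.dlpack

def torch_to_jax(t):
  # hand the tensor's buffer to Jax through a DLPack capsule (no copy)
  return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(t))

def jax_to_torch(x):
  # and back: wrap a Jax array's buffer as a Torch tensor
  return torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x))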

Install

$ pip install jax2torch

Memory management

By default, Jax preallocates 90% of VRAM, which leaves Pytorch very little to work with. To prevent this behavior, set the XLA_PYTHON_CLIENT_PREALLOCATE environment variable to false before running any Jax code:

import os

# must be set before any Jax computation initializes the GPU backend
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
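
If you would rather let Jax preallocate but claim a smaller share, XLA also reads the XLA_PYTHON_CLIENT_MEM_FRACTION variable (a general XLA setting, not specific to this repository):

import os

# preallocate roughly 30% of VRAM instead of the default 90%
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".3"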

Usage

Quick test

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import torch
from jax2torch import jax2torch

# Jax function

@jax.jit
def jax_pow(x, y = 2):
  return x ** y

# convert to Torch function

torch_pow = jax2torch(jax_pow)

# run it on Torch data!

x = torch.tensor([1., 2., 3.])
y = torch_pow(x, y = 3)
print(y)  # tensor([1., 8., 27.])

# And differentiate!

x = torch.tensor([2., 3.], requires_grad = True)
y = torch.sum(torch_pow(x, y = 3))
y.backward()
print(x.grad) # tensor([12., 27.])
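
Functions of several tensor arguments should work the same way. The two-argument check below is our own illustration rather than an example from the repository, but it uses the same jax2torch API:

import jax
import torch
from jax2torch import jax2torch

@jax.jit
def jax_mul(a, b):
  return a * b

torch_mul = jax2torch(jax_mul)

a = torch.tensor([1., 2.], requires_grad = True)
b = torch.tensor([3., 4.], requires_grad = True)

out = torch.sum(torch_mul(a, b))
out.backward()

print(a.grad)  # tensor([3., 4.])
print(b.grad)  # tensor([1., 2.])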
