Description
The KeOps library (https://www.kernel-operations.io/) has NumPy and PyTorch bindings. However, unless I am doing something wrong, I don't think torch2jax (as it stands) supports it, because the PyTorch bindings are built on LazyTensors. Is there a way around this? Here is a minimal example:
import torch
import jax
import jax.numpy as jnp
import jax.random as jr
from pykeops.torch import LazyTensor
from torch2jax import j2t, t2j
# jax.config is used directly; the jax.config submodule import is gone in recent JAX releases
jax.config.update("jax_enable_x64", True)
M, N, D = 1000, 2000, 3
x = torch.randn(M, D, requires_grad=True).cuda()
y = torch.randn(N, D).cuda()
def test_func(x, y):
    x_i = LazyTensor(x.view(M, 1, D))
    y_j = LazyTensor(y.view(1, N, D))
    # We can now perform large-scale computations, without memory overflows:
    D_ij = ((x_i - y_j) ** 2).sum(dim=2)
    K_ij = (-D_ij).exp()
    a_i = K_ij.sum(dim=1)
    return a_i
a_i = test_func(x, y)
# Use separate keys so that x_jax and y_jax are independent samples
key_x, key_y = jr.split(jr.PRNGKey(0))
x_jax = jr.normal(key_x, (M, D))
y_jax = jr.normal(key_y, (N, D))
jax_test_func = t2j(test_func)
a_i = jax_test_func(x_jax, y_jax)
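
For comparison, the reduction that `test_func` performs, a_i = sum_j exp(-|x_i - y_j|^2), can be written densely in plain JAX. This is only a sketch for checking outputs: it materialises the full M x N matrix, which is exactly what KeOps avoids, and it is not a torch2jax workaround. The names `dense_gaussian_sum`, `x_ref`, `y_ref` and `a_ref` are made up for illustration.

```python
import jax.numpy as jnp
import jax.random as jr


def dense_gaussian_sum(x, y):
    """Dense reference: a_i = sum_j exp(-|x_i - y_j|^2)."""
    # Pairwise squared distances via broadcasting, shape (M, N).
    d_ij = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    # Gaussian kernel matrix, held fully in memory.
    k_ij = jnp.exp(-d_ij)
    # Reduce over the j axis, giving one value per row of x.
    return jnp.sum(k_ij, axis=1)


key_x, key_y = jr.split(jr.PRNGKey(0))
x_ref = jr.normal(key_x, (1000, 3))
y_ref = jr.normal(key_y, (2000, 3))
a_ref = dense_gaussian_sum(x_ref, y_ref)  # shape (1000,)
```

At M = 1000 and N = 2000 the dense matrix is small enough to fit in memory, so this could serve as a baseline for whatever a converted function returns. Note that the KeOps result may carry an extra trailing axis compared with this version.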