
Interoperability with Keops? #4

Open
@adam-hartshorne

Description

The KeOps library (https://www.kernel-operations.io/) has NumPy and PyTorch bindings. However, unless I am doing something wrong, I don't think torch2jax (as it stands) supports this, because the computation goes through KeOps `LazyTensor`s. Is there a way around this?
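For reference, the `LazyTensor` code below computes a Gaussian kernel sum. With `x` of shape M × D and `y` of shape N × D, each output entry is

$$
a_i = \sum_{j=1}^{N} \exp\!\left(-\lVert x_i - y_j \rVert_2^2\right), \qquad i = 1, \dots, M.
$$

KeOps evaluates this reduction without storing the full M × N kernel matrix, which is why the snippet uses `LazyTensor` instead of dense tensors.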

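```python
import torch
import jax
import jax.random as jr

from pykeops.torch import LazyTensor
from torch2jax import t2j

# `from jax.config import config` no longer works in recent JAX releases;
# `jax.config.update` is the supported spelling.
jax.config.update("jax_enable_x64", True)

M, N, D = 1000, 2000, 3
# Create the tensors directly on the GPU so `x` stays a leaf tensor
# (calling `.cuda()` after `requires_grad=True` returns a non-leaf copy).
x = torch.randn(M, D, device="cuda", requires_grad=True)
y = torch.randn(N, D, device="cuda")

def test_func(x, y):
    x_i = LazyTensor(x.view(M, 1, D))  # symbolic "i" variable
    y_j = LazyTensor(y.view(1, N, D))  # symbolic "j" variable

    # We can now perform large-scale computations without memory overflows:
    D_ij = ((x_i - y_j) ** 2).sum(dim=2)  # squared distances, never materialized
    K_ij = (-D_ij).exp()                  # Gaussian kernel
    a_i = K_ij.sum(dim=1)                 # reduction over j, returns a torch.Tensor
    return a_i

# Works: plain PyTorch tensors in, KeOps reduction out.
a_i = test_func(x, y)

# Use independent keys for x and y instead of reusing the same key.
key_x, key_y = jr.split(jr.PRNGKey(0))
x_jax = jr.normal(key_x, (M, D))
y_jax = jr.normal(key_y, (N, D))

# Converting with torch2jax: this is the step that does not appear to work
# once the function uses LazyTensor.
jax_test_func = t2j(test_func)
a_i = jax_test_func(x_jax, y_jax)
```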
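One possible workaround (an untested assumption, not a torch2jax feature) is to keep the KeOps call outside JAX's tracer entirely and invoke it through `jax.pure_callback`. The sketch below runs on CPU without pykeops: `gaussian_sum_host` and `gaussian_sum_jax` are hypothetical names, and the host function is a dense NumPy stand-in for `test_func`, not KeOps itself.

```python
import numpy as np
import jax
import jax.numpy as jnp


def gaussian_sum_host(x, y):
    """Host-side stand-in for `test_func`: a_i = sum_j exp(-||x_i - y_j||^2).

    ASSUMPTION: dense NumPy so the sketch runs without a GPU or pykeops.
    A KeOps version would wrap x, y with torch.from_numpy(...), run the
    LazyTensor reduction from the issue, and return
    `.detach().cpu().numpy().reshape(-1)`.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    d2 = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=2)  # (M, N) squared distances
    return np.exp(-d2).sum(axis=1).astype(x.dtype)  # (M,) row sums, same dtype as x


def gaussian_sum_jax(x, y):
    # Declare the callback's output shape/dtype so JAX can trace around it.
    out_spec = jax.ShapeDtypeStruct((x.shape[0],), x.dtype)
    # pure_callback runs the Python function on host arrays at execution
    # time, so it also works under jax.jit.
    return jax.pure_callback(gaussian_sum_host, out_spec, x, y)


key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x_jax = jax.random.normal(key_x, (1000, 3))
y_jax = jax.random.normal(key_y, (2000, 3))
a_i = jax.jit(gaussian_sum_jax)(x_jax, y_jax)
```

Because the callback receives host NumPy arrays, a KeOps-on-CUDA version would copy data device → host → device on every call. Also, `jax.grad` cannot differentiate through `jax.pure_callback`; gradients with respect to `x` would need something like `jax.custom_vjp` with a second callback for the backward pass.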
