From 93ed7065c582393384e958ea96f3ad6221b09626 Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Fri, 22 Nov 2024 13:41:05 -0500 Subject: [PATCH] Add fix and test for the unsqueeze issue in https://github.com/samuela/torch2jax/issues/7 --- tests/test_all_the_things.py | 5 +++++ torch2jax/__init__.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_all_the_things.py b/tests/test_all_the_things.py index 1a344d0..cd8de31 100644 --- a/tests/test_all_the_things.py +++ b/tests/test_all_the_things.py @@ -26,6 +26,11 @@ def poop(self): return rng_key +def test_t2j_array(): + # See https://github.com/samuela/torch2jax/issues/7 + aac(t2j(torch.eye(3).unsqueeze(0)), jnp.eye(3)[jnp.newaxis, ...]) + + def t2j_function_test(f, input_shapes, rng=random.PRNGKey(123), num_tests=5, **assert_kwargs): for test_rng in random.split(rng, num_tests): inputs = [random.normal(rng, shape) for rng, shape in zip(random.split(test_rng, len(input_shapes)), input_shapes)] diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index a0a5e1b..7c81167 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -14,9 +14,19 @@ def t2j_array(torch_array): # `torch.func.functionalize` in `t2j_function`. For now, we're avoiding `torch.func.functionalize`, but something to # be wary of in the future. + # RuntimeError: Can't export tensors that require gradient, use tensor.detach() + torch_array = torch_array.detach() + # See https://github.com/google/jax/issues/8082. torch_array = torch_array.contiguous() - return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_array)) + + # At some point between 0.4.28 and 0.4.33 from_dlpack introduced a new + # deprecation notice: + # + # DeprecationWarning: Calling from_dlpack with a DLPack tensor is deprecated. The argument to from_dlpack should be an array from another framework that implements the __dlpack__ protocol. + # + # Very well, PyTorch arrays implement the __dlpack__ protocol, so no need to convert them to dlpack first. + return jax.dlpack.from_dlpack(torch_array) # Alternative, but copying implementation: # Note FunctionalTensor.numpy() returns incorrect results, preventing us from using torch.func.functionalize.