Description
Hey guys,
Love this tool!
Extremely useful for me as someone much more comfortable with JAX than torch. I've been using it on a current project and have extended some of the API surface coverage (I have another issue I might open re: type promotion), but the application I'm targeting also uses einops extensively. I hacked together preliminary support for this by subclassing the einops array backend as below, then sticking the class at the bottom of your library's __init__.py. As far as I can tell, einops discovers AbstractBackend subclasses lazily and only instantiates the ones whose framework_name is already in sys.modules, so defining the subclass at import time should be enough for it to get picked up. This seems more or less like the correct approach (I still need to fix the layers implementation, of course), but I'd be curious to hear feedback if you have a better idea. The class is almost identical to the pytorch backend from einops, just changed in a few places to use the Torchish class where needed; there's a quick usage sketch below it as well.

If it's acceptable, I'll try to open a PR soon :)
from einops._backends import AbstractBackend


class TorchishBackend(AbstractBackend):
    framework_name = "torch2jax"

    def __init__(self):
        import torch
        import torch2jax

        self.torch = torch
        self.t2j = torch2jax

    def is_appropriate_type(self, tensor):
        # isinstance is a bit more robust than matching on type(tensor).__name__.
        return isinstance(tensor, self.t2j.Torchish)

    def from_numpy(self, x):
        return self.t2j.Torchish(x)

    def to_numpy(self, x):
        import numpy as np

        # Torchish holds a JAX array in .value, which has no .cpu() method,
        # so pull it to the host via numpy instead.
        return np.asarray(x.value)

    def arange(self, start, stop):
        # NOTE: this returns a plain torch.Tensor rather than a Torchish and
        # may need wrapping.
        return self.torch.arange(start, stop)

    def reduce(self, x, operation, reduced_axes):
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation in ("any", "all", "prod"):
            # pytorch can reduce these over only one axis at a time, so peel
            # the axes off back-to-front to keep the remaining indices valid.
            for i in sorted(reduced_axes, reverse=True):
                x = getattr(x, operation)(dim=i)
            return x
        else:
            raise NotImplementedError(f"Unknown reduction: {operation}")

    def transpose(self, x, axes):
        return x.permute(*axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        # Calling the real torch functions works here because Torchish inputs
        # are intercepted by torch2jax, the same as in user code.
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]

    # TODO: fix the layers implementation before opening a PR.
    # def layers(self):
    #     from .layers import torch
    #     return torch

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
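For context, here's a minimal sketch of how this gets exercised (the shapes and the rearrange pattern are arbitrary; the t2j call follows the torch2jax README):

import einops
import jax.numpy as jnp
from torch2jax import t2j  # importing torch2jax also defines/registers TorchishBackend

def f(x):
    # Inside a t2j-converted function, x is a Torchish, so einops dispatches
    # to TorchishBackend via is_appropriate_type.
    return einops.rearrange(x, "b (h w) -> b h w", h=4)

out = t2j(f)(jnp.zeros((2, 16)))  # -> shape (2, 4, 4)

einops.reduce and einops.repeat route through the same backend methods (reduce, tile, add_axes, and friends), so this covers the patterns my application needs.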