
Einops support #5

Open
@jrichterpowell

Description

Hey guys,

Love this tool!

It's extremely useful for me as someone much more comfortable with JAX than torch. I've been using it for a current project and have extended some of the API surface coverage (I have another issue I might open re: type promotion), but the application I'm targeting also uses einops extensively. I hacked together preliminary support for this by subclassing the einops array backend as below and sticking the class at the bottom of your library's __init__.py. This seems more or less like the correct approach (I still need to fix the layers implementation, of course), but I'd be curious to hear feedback if you have a better idea for accomplishing this. It's almost identical to einops' built-in PyTorch backend, just changed in a few places to use the Torchish class where needed.

If it's acceptable, I'll try to open a PR soon :)

from einops._backends import AbstractBackend
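
# einops discovers backends lazily: get_backend() walks AbstractBackend's
# subclasses and instantiates any whose framework_name module is already in
# sys.modules, so defining this subclass at the bottom of torch2jax's
# __init__.py is enough for einops to pick it up on the first call that
# sees a Torchish tensor.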
class TorchishBackend(AbstractBackend):
    framework_name = "torch2jax"

    def __init__(self):
        import torch
        import torch2jax

        self.torch = torch
        self.t2j = torch2jax

    def is_appropriate_type(self, tensor):
        # torch2jax is imported in __init__, so we can check the class directly
        return isinstance(tensor, self.t2j.Torchish)

    def from_numpy(self, x):
        return self.t2j.Torchish(x)

    def to_numpy(self, x):
        # x.value is a JAX array, not a torch tensor, so convert via NumPy
        # rather than .cpu().numpy()
        import numpy as np

        return np.asarray(x.value)

    def arange(self, start, stop):
        return self.torch.arange(start, stop)

    def reduce(self, x, operation, reduced_axes):
        if operation == "min":
            return x.amin(dim=reduced_axes)
        elif operation == "max":
            return x.amax(dim=reduced_axes)
        elif operation == "sum":
            return x.sum(dim=reduced_axes)
        elif operation == "mean":
            return x.mean(dim=reduced_axes)
        elif operation in ("any", "all", "prod"):
            # torch's any/all/prod reduce one dim at a time, so peel axes
            # off from the back to keep the remaining indices valid
            for i in sorted(reduced_axes, reverse=True):
                x = getattr(x, operation)(dim=i)
            return x
        else:
            raise NotImplementedError(f"Unknown reduction: {operation}")

    def transpose(self, x, axes):
        return x.permute(*axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.torch.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def concat(self, tensors, axis: int):
        return self.torch.cat(tensors, dim=axis)

    def add_axis(self, x, new_position):
        return self.torch.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]

    # def layers(self):
    #     from .layers import torch

    #     return torch

    def einsum(self, pattern, *x):
        return self.torch.einsum(pattern, *x)
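
A minimal sketch of the resulting usage, assuming Torchish is exported at torch2jax's top level (as from_numpy above relies on) and accepts a JAX array directly:

import jax.numpy as jnp
from einops import rearrange, reduce
from torch2jax import Torchish  # assumed top-level export

x = Torchish(jnp.ones((2, 3, 4)))
y = rearrange(x, "b c d -> b (c d)")  # dispatches through TorchishBackend
z = reduce(x, "b c d -> b", "sum")    # exercises the reduce() above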

