-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathflake.nix
60 lines (56 loc) · 1.26 KB
/
flake.nix
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Nix flake for torch2jax: builds the Python package and provides a
# development shell for every default system supported by flake-utils.
{
  inputs = {
    nixpkgs.url = "github:nixos/nixpkgs";
    flake-utils.url = "github:numtide/flake-utils";
  };

  outputs =
    {
      self,
      nixpkgs,
      flake-utils,
    }:
    flake-utils.lib.eachDefaultSystem (
      system:
      let
        pkgs = nixpkgs.legacyPackages.${system};

        # Python package set that every Python dependency below is taken from.
        py = pkgs.python3.pkgs;

        # Use jaxlib-bin whenever nixpkgs marks jaxlib as broken.
        jaxlib' = if py.jaxlib.meta.broken then py.jaxlib-bin else py.jaxlib;

        # The torch2jax package, built from this source tree.
        torch2jax = py.buildPythonPackage {
          pname = "torch2jax";
          # Don't forget to also update the version in pyproject.toml!
          version = "0.1.0";
          pyproject = true;
          src = ./.;
          dependencies = with py; [
            jax
            torch
          ];
          nativeCheckInputs = with py; [
            jaxlib'
            pytestCheckHook
            torchvision
            pkgs.writableTmpDirAsHomeHook # torchvision downloads models into HOME.
          ];
          pythonImportsCheck = [ "torch2jax" ];
        };

        # Development shell with linting, packaging, and test tooling.
        shell = pkgs.mkShell {
          buildInputs = with py; [
            pkgs.act
            pkgs.ruff
            build
            ipython
            jax
            jaxlib'
            pytest
            torch
            torchvision
            twine
          ];
        };
      in
      {
        # Current flake output schema (Nix >= 2.7): `nix build` / `nix develop`
        # look these up first.
        packages.default = torch2jax;
        devShells.default = shell;

        # Legacy output names, kept as aliases so existing references to
        # `.#defaultPackage.<system>` / `.#devShell.<system>` keep working.
        defaultPackage = torch2jax;
        devShell = shell;
      }
    );
}