In torchdyn/numerics/sensitivity.py, function _gather_odefunc_adjoint(), line 71:
dμ = torch.cat([el.flatten() if el is not None else torch.zeros(1) for el in dμ], dim=-1)
should be changed to:
param_shapes = [p.shape for p in vf.parameters()]
dμ = torch.cat([el.flatten() if el is not None else torch.zeros(param_shapes[i]).to(t.device).flatten() for i, el in enumerate(dμ)], dim=-1)
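Otherwise, for any parameter that receives no gradient (el is None), the placeholder torch.zeros(1) contributes a single element instead of the parameter's full number of elements, so the concatenated gradient does not match the shape of the vector field's parameters.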
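A minimal sketch of the mismatch, assuming a recent PyTorch in which torch.autograd.grad(..., allow_unused=True) returns None for parameters that do not affect the output. PartiallyUsedField and flatten_grads are hypothetical names used only for illustration; flatten_grads mimics the line-71 concatenation rather than reproducing torchdyn's actual code, and it uses p.device in place of t.device.

```python
import torch
import torch.nn as nn


class PartiallyUsedField(nn.Module):
    """Hypothetical vector field: `unused` is registered but never used in forward."""

    def __init__(self):
        super().__init__()
        self.used = nn.Linear(2, 2)    # 4 weights + 2 biases = 6 elements
        self.unused = nn.Linear(2, 3)  # 6 weights + 3 biases = 9 elements

    def forward(self, x):
        return self.used(x)


def flatten_grads(grads, params, pad_with_correct_shape):
    """Concatenate per-parameter gradients, filling in None entries (hypothetical helper)."""
    pieces = []
    for g, p in zip(grads, params):
        if g is not None:
            pieces.append(g.flatten())
        elif pad_with_correct_shape:
            # Proposed fix: a zero placeholder with the parameter's own shape.
            pieces.append(torch.zeros(p.shape, device=p.device).flatten())
        else:
            # Current behavior: a single-element placeholder.
            pieces.append(torch.zeros(1))
    return torch.cat(pieces, dim=-1)


torch.manual_seed(0)
vf = PartiallyUsedField()
params = list(vf.parameters())
out = vf(torch.randn(4, 2)).sum()
# Gradients w.r.t. the unused layer's weight and bias come back as None.
grads = torch.autograd.grad(out, params, allow_unused=True)

n_params = sum(p.numel() for p in params)
old = flatten_grads(grads, params, pad_with_correct_shape=False)
new = flatten_grads(grads, params, pad_with_correct_shape=True)
print(n_params, old.numel(), new.numel())  # → 15 8 15
```

With one Linear(2, 2) used and one Linear(2, 3) unused, the vector field has 15 parameter elements. The single-element placeholder yields a vector of length 8, while the shape-matched zeros yield 15, the same length as torch.nn.utils.parameters_to_vector(params).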