Description
There are mismatched arguments in `problems.ODEProblem.odeint`.
My torchdyn version is 1.0.3.
Step to Reproduce
I want to see how many steps the adaptive `dopri5` solver took, so I looked for the `return_all_eval` argument mentioned in issue #131. The `NeuralODE` class does not expose such a keyword argument, so after digging into the source code a little I decided to pass `args={'return_all_eval': True}`. However, this still does not give the desired result. The code snippet is:
```python
from torchdyn.core import NeuralODE
import torch
import torch.nn as nn


class VectorField(nn.Module):
    def __init__(self):
        super(VectorField, self).__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)


vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
time = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)
eval_time, sol = ode(initial, time, args={'return_all_eval': True})
print(sol.shape)
```
Then I found that the `return_all_eval` keyword is never actually passed into the `numerics.odeint.odeint` function. The signature of that function is:
```python
def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[List, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:
```
So `return_all_eval` is an explicit parameter of `odeint`, but in `numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward` it is hard-coded as `False`:
```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                False, maxiter, fine_steps, save_at)
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol
```
So I have no way to switch it on short of changing the source code.
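A stripped-down, pure-Python sketch of why the option cannot get through. `fake_generic_odeint` and `gather_problem_func` are hypothetical stand-ins for torchdyn's `generic_odeint` and `_gather_odefunc_adjoint`, not their real signatures; the only detail carried over from the `forward` excerpt is that the inner solver call passes a literal `False`.

```python
def fake_generic_odeint(x, t_span, return_all_eval):
    # Hypothetical solver stand-in: echo back the flag it actually received,
    # so we can see what reached the solver.
    return return_all_eval


def gather_problem_func(solver_options):
    # Hypothetical factory loosely mirroring _gather_odefunc_adjoint: it closes
    # over solver settings and returns the function the adjoint code calls.
    def forward(x, t_span):
        # The flag is a literal here, so nothing in solver_options
        # (or anywhere else) can change it.
        return fake_generic_odeint(x, t_span, False)
    return forward


forward = gather_problem_func({"return_all_eval": True})
reached_solver = forward(1.0, [0.0, 0.5, 1.0])
print(reached_solver)  # False: the user's True never reaches the solver
```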
Another problem is an argument mismatch in `numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward`. When it is called from `odeint` in `torchdyn/core/problems.py` (line 85 at commit a0d0fc5), the `save_at` parameter actually receives a dict, and the `B` parameter (whose purpose I do not understand) receives the real `save_at`. This has not caused any problems in my code so far, but I don't believe it is the intended behavior. I suggest someone debug this part of the code in depth.
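The mismatch is a plain positional-binding problem. The sketch below reuses the parameter list of `_ODEProblemFunc.forward` with a dummy body that just reports the bindings. The call site is an assumption, since the code at `problems.py` line 85 is not reproduced in this issue: it is modeled as passing `save_at` and then an `args` dict positionally after `t_span`, which matches the behavior described above.

```python
def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    # Same parameters as _ODEProblemFunc.forward; dummy body that reports
    # how B and save_at ended up bound.
    return {"B": B, "save_at": save_at}


save_at = [0.25, 0.75]
args = {"return_all_eval": True}

# Hypothetical call site: save_at and args passed positionally after t_span.
bound = forward(None, "params", "x0", [0.0, 1.0], save_at, args)
print(bound)  # {'B': [0.25, 0.75], 'save_at': {'return_all_eval': True}}

# Passing an explicit None for B puts save_at where it belongs.
fixed = forward(None, "params", "x0", [0.0, 1.0], None, save_at)
print(fixed)  # {'B': None, 'save_at': [0.25, 0.75]}
```

With the shifted call, `save_at` silently becomes the dict and `B` becomes the list of save points, with no error raised. That explains why the bug can go unnoticed. Whether `args` should then reach the solver through its own parameter is a separate question.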
Screenshots
There is a traceback that shows the problem.
Expected behavior
The `return_all_eval` option should be controllable by the user and should determine whether the ODE solver returns all of its evaluation time points.
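One way to get this behavior, sketched in pure Python under stated assumptions: capture the user's choice when the problem function is built and forward it, instead of passing a literal. `HypotheticalNeuralODE`, `gather_problem_func`, and `fake_generic_odeint` are illustrative names, not torchdyn's API, and the extra time points returned by the fake solver are made up.

```python
def fake_generic_odeint(x, t_span, return_all_eval):
    # Hypothetical solver stand-in: with return_all_eval=True, pretend the
    # adaptive solver also reports some internal evaluation times.
    if return_all_eval:
        internal = [0.1, 0.3, 0.7]  # made-up internal steps, illustration only
        return sorted(set(t_span) | set(internal))
    return list(t_span)


def gather_problem_func(return_all_eval=False):
    # Capture the user's choice in the closure instead of hard-coding False.
    def forward(x, t_span):
        return fake_generic_odeint(x, t_span, return_all_eval)
    return forward


class HypotheticalNeuralODE:
    # Exposes the option on the user-facing object and threads it through.
    def __init__(self, return_all_eval=False):
        self._forward = gather_problem_func(return_all_eval)

    def __call__(self, x, t_span):
        return self._forward(x, t_span)


default_t = HypotheticalNeuralODE()(1.0, [0.0, 0.5, 1.0])
all_t = HypotheticalNeuralODE(return_all_eval=True)(1.0, [0.0, 0.5, 1.0])
print(default_t)  # [0.0, 0.5, 1.0]
print(all_t)      # [0.0, 0.1, 0.3, 0.5, 0.7, 1.0]
```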
Also, the meaning of these arguments and the functionality they provide are badly under-documented. For example, it was not until I found that GitHub issue that I realized there is a way to return all the evaluation timestamps.