
Argument mismatch and hard-coded return_all_eval #195

Open
@cantabile-kwok

Description

There are mismatched arguments in problems.ODEProblem.odeint.
My torchdyn version is 1.0.3.

Steps to Reproduce

I wanted to see how many steps the adaptive dopri5 solver takes, so I looked for the return_all_eval argument following issue #131. I then found that the NeuralODE class does not expose such a keyword argument, so after digging into the source code a bit I decided to pass args={'return_all_eval': True}. However, this still does not give the desired result. The code snippet is:

from torchdyn.core import NeuralODE
import torch
import torch.nn as nn


class VectorField(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, x):
        # print every solver evaluation time so the steps can be counted
        print(f"In VectorField, t is fed as {t}")
        return self.net(t + x)


vf = VectorField()
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
time = torch.linspace(0, 1, 10)
initial = torch.randn(16, 20, 2)
# attempt to switch on return_all_eval via the args dict
eval_time, sol = ode(initial, time, args={'return_all_eval': True})
print(sol.shape)

Then I found that the return_all_eval keyword never actually reaches the numerics.odeint.odeint function. The signature of that function is

def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, nn.Module], atol:float=1e-3, rtol:float=1e-3,
           t_stops:Union[List, Tensor, None]=None, verbose:bool=False, interpolator:Union[str, Callable, None]=None, return_all_eval:bool=False,
           save_at:Union[List, Tensor]=(), args:Dict={}, seminorm:Tuple[bool, Union[int, None]]=(False, None)) -> Tuple[Tensor, Tensor]:

so you can see return_all_eval is an explicit parameter there, separate from args. But in numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward it is hard-coded to False:

def forward(ctx, vf_params, x, t_span, B=None, save_at=()):
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                False, maxiter, fine_steps, save_at)
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol
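
A minimal sketch of a fix (untested, assuming the closure variables problem_type, vf, solver, atol, rtol, interpolator, maxiter and fine_steps stay as in the quoted code) would accept the flag and thread it through instead of hard-coding it:

def forward(ctx, vf_params, x, t_span, B=None, save_at=(), return_all_eval=False):
    # forward the user-supplied flag instead of a literal False
    t_sol, sol = generic_odeint(problem_type, vf, x, t_span, solver, atol, rtol, interpolator, B,
                                return_all_eval, maxiter, fine_steps, save_at)
    ctx.save_for_backward(sol, t_sol)
    return t_sol, sol

The call sites in ODEProblem.odeint would of course also need to pass the flag down.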

So there is currently no way to switch it on without editing the source code.
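
As a workaround that avoids touching the library, the lower-level numerics.odeint.odeint quoted above can be called directly with return_all_eval=True, bypassing NeuralODE entirely. Note this only inspects the forward solve; no adjoint gradients are involved:

from torchdyn.numerics.odeint import odeint

# vf, initial and time are the objects from the snippet above
t_eval, sol = odeint(vf, initial, time, solver="dopri5",
                     atol=1e-4, rtol=1e-4, return_all_eval=True)
# t_eval should now contain every accepted step of the adaptive solver
print(t_eval.shape, sol.shape)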

Another issue is the argument mismatch in the numerics.sensitivity._gather_odefunc_adjoint._ODEProblemFunc.forward function. When it is called from odeint like

return self._autograd_func()(self.vf_params, x, t_span, save_at, args)

the positional arguments do not match the forward signature shown above: save_at is bound to the B parameter and the args dict is bound to save_at. In other words, save_at ends up holding a dict, while B (whose purpose I do not understand) receives the actual save_at value. This has not caused any problems in my code so far, but I do not believe it is the intended behavior. I suggest someone take a deep debugging pass through this code path.
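
One possible fix (my own suggestion, untested) is to align the call with the signature by passing B explicitly, so that save_at binds to the correct parameter; whether args should instead be added to the forward signature is a separate design decision:

# pass None for B explicitly so save_at lands in its own slot;
# `args` is dropped here, since the current forward does not accept it
return self._autograd_func()(self.vf_params, x, t_span, None, save_at)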

Screenshots
There is a traceback that shows the problem.

Expected behavior

The return_all_eval option should be exposed to the user and control whether the ODE solver returns the solution at every evaluation time point.
Also, the meaning of these arguments and the functionality they provide is badly under-documented; for example, it was not until I found that GitHub issue that I realized there is a way to return all the evaluation time stamps.
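
For example, a user-facing interface along these lines (hypothetical; NeuralODE does not currently accept this keyword) would cover the use case:

# hypothetical API sketch, not the current torchdyn interface
ode = NeuralODE(vf, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
eval_time, sol = ode(initial, time, return_all_eval=True)
# eval_time would then contain every accepted solver step, not just t_span
print(eval_time.shape, sol.shape)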
