Open
Description
We are developing a package that uses ComponentArrays and Lux to train Neural ODEs through a simple-to-use front end. We found the following bug after one of our developers updated their packages. Upon investigation, it appears to be an issue in ComponentArrays. The following MWE:
using ComponentArrays, Lux, Random, OrdinaryDiffEq, Zygote, SciMLSensitivity
# Random "observations": a 2×10 matrix whose columns `loss` uses as successive states
# (column 1 as the initial condition, column 2 as the prediction target).
df = rand(2, 10)

# A single RNG handle is enough; it is only needed once, for `Lux.setup`.
rng = Random.default_rng()

# Two-layer MLP mapping a 2-dimensional state to a 2-dimensional derivative.
NN = Lux.Chain(Lux.Dense(2, 10, tanh), Lux.Dense(10, 2))
parameters, states = Lux.setup(rng, NN)

# Nest the Lux parameters under the key `NN` so the ODE right-hand side can read them as
# `parameters.NN`; this NamedTuple is later converted to a ComponentArray for `gradient`.
parameters = (NN = parameters,)
"""
    derivs!(du, u, parameters, t)

In-place ODE right-hand side. It evaluates the global Lux model `NN` on the state `u`,
using the weights `parameters.NN` and the global Lux `states`. The first element of the
returned tuple is written into `du`, and `du` is returned. `t` is unused.

`parameters` must support `.NN` property access (e.g. a NamedTuple or ComponentArray).
The reported error ("type Array has no field NN") arises when the solver hands this
function a plain `Vector` of duals instead.
"""
function derivs!(du, u, parameters, t)
    output, _ = NN(u, parameters.NN, states)
    du .= output
    return du
end
# Template problem. `predict` overrides `u0`, `p` and `tspan` on every `solve` call,
# so these values only fix the problem's structure.
u0 = zeros(2)
tspan = (0.0, 0.5)
IVP = ODEProblem(derivs!, u0, tspan, parameters)
"""
    predict(u, t, dt, parameters)

Integrate the global problem `IVP` with `Tsit5` from state `u` over `[t, t + dt]`,
using `parameters` as the problem parameters. Saving happens only at the two endpoints.
Returns the state at `t + dt` (the last column of the solution array).
"""
function predict(u, t, dt, parameters)
    window = (t, t + dt)
    sol = solve(IVP, Tsit5(); u0 = u, p = parameters,
                tspan = window, saveat = window)
    return Array(sol)[:, end]
end
"""
    loss(parameters)

Squared-error loss for one step of length `0.05`. It starts from column 1 of the global
data `df` and compares the prediction against column 2.
"""
function loss(parameters)
    prediction = predict(df[:, 1], 0.0, 0.05, parameters)
    return sum(abs2, prediction .- df[:, 2])
end
# Differentiate the loss with respect to the flattened (ComponentArray) parameters.
Zygote.gradient(loss, ComponentArray(parameters))
fails with the following `]status`:
⌃ [b0b7db55] ComponentArrays v0.15.14
⌃ [b2108857] Lux v0.5.61
[9a3f8284] Random
but works with the following `]status`:
⌃ [b0b7db55] ComponentArrays v0.15.13
⌃ [b2108857] Lux v0.5.61
[9a3f8284] Random
I'm testing on Julia 1.10.4. When it fails, it throws the following stacktrace:
ERROR: type Array has no field NN
Stacktrace:
[1] getproperty
@ .\Base.jl:37 [inlined]
[2] derivs!(du::Vector{ForwardDiff.Dual{…}}, u::Vector{ForwardDiff.Dual{…}}, parameters::Vector{ForwardDiff.Dual{…}}, t::Float64)
@ Main .\REPL[72]:2
[3] (::ODEFunction{…})(::Vector{…}, ::Vararg{…})
@ SciMLBase C:\Users\JArroyo-Esquivel\.julia\packages\SciMLBase\YE50s\src\scimlfunctions.jl:2335
[4] (::ODEFunction{…})(::Vector{…}, ::Vararg{…})
@ SciMLBase C:\Users\JArroyo-Esquivel\.julia\packages\SciMLBase\YE50s\src\scimlfunctions.jl:2335
[5] initialize!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…}, cache::OrdinaryDiffEqTsit5.Tsit5Cache{…})
@ OrdinaryDiffEqTsit5 C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEqTsit5\DHYtz\src\tsit_perform_step.jl:175
[6] __init(prob::ODEProblem{…}, alg::Tsit5{…}, timeseries_init::Tuple{}, ts_init::Tuple{}, ks_init::Tuple{}, recompile::Type{…}; saveat::Vector{…}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_on::Bool, save_start::Bool, save_end::Nothing, callback::Nothing, dense::Bool, calck::Bool, dt::Float64, dtmin::Float64, dtmax::Float64, force_dtmin::Bool, adaptive::Bool, gamma::Rational{…}, abstol::Nothing, reltol::Nothing, qmin::Rational{…}, qmax::Int64, qsteady_min::Int64, qsteady_max::Int64, beta1::Nothing, beta2::Nothing, qoldinit::Rational{…}, controller::Nothing, fullnormalize::Bool, failfactor::Int64, maxiters::Int64, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), internalopnorm::typeof(LinearAlgebra.opnorm), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), progress_id::Symbol, userdata::Nothing, allow_extrapolation::Bool, initialize_integrator::Bool,
alias_u0::Bool, alias_du0::Bool, initializealg::OrdinaryDiffEqCore.DefaultInit, kwargs::@Kwargs{…})
@ OrdinaryDiffEqCore C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEqCore\4A2vD\src\solve.jl:525
[7] __init (repeats 4 times)
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEqCore\4A2vD\src\solve.jl:11 [inlined]
[8] #__solve#61
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEqCore\4A2vD\src\solve.jl:6 [inlined]
[9] __solve
@ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEqCore\4A2vD\src\solve.jl:1 [inlined]
[10] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:612
[11] solve_call
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:569 [inlined]
[12] #solve_up#53
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:1080 [inlined]
[13] solve_up
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:1066 [inlined]
[14] #solve#51
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:1003 [inlined]
[15] (::SciMLSensitivity.var"#327#336"{0, ODESolution{…}, Tuple{…}, @Kwargs{…}, ODEProblem{…}, Tsit5{…}, ForwardDiffSensitivity{…}, Vector{…}, ComponentVector{…}, Tuple{}, Vector{…}})()
@ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\se3y4\src\concrete_solve.jl:894
[16] unthunk
@ C:\Users\JArroyo-Esquivel\.julia\packages\ChainRulesCore\I1EbV\src\tangent_types\thunks.jl:204 [inlined]
[17] wrap_chainrules_output
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:110 [inlined]
[18] map
@ .\tuple.jl:293 [inlined]
[19] map (repeats 3 times)
@ .\tuple.jl:294 [inlined]
[20] wrap_chainrules_output
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:111 [inlined]
[21] ZBack
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211 [inlined]
[22] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:237
[23] #291
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[24] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
[25] #solve#51
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:1003 [inlined]
[26] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[27] #291
@ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
[28] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
[29] solve
@ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\sCsah\src\solve.jl:993 [inlined]
[30] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[31] predict
@ .\REPL[75]:3 [inlined]
[32] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Vector{Float64})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[33] loss
@ .\REPL[76]:2 [inlined]
[34] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
[35] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:91
[36] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
@ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:148
[37] top-level scope
@ REPL[77]:1
Metadata
Assignees
Labels
No labels
Activity