Open
Description
This package is currently compatible with ForwardDiff, but not ReverseDiff.
using NLsolve, ForwardDiff, ReverseDiff
"""
    residual!(r, y, x)

Evaluate the two-equation residual in place.

Reads `y[1]`, `y[2]` and `x[1]` through `x[4]`, and overwrites `r[1]` and `r[2]`.
The value stored in `r[2]` is also the return value.
"""
function residual!(r, y, x)
    # First equation: product of a shifted term and a cubic term, plus an offset.
    shifted = y[1] + x[1]
    cubic = y[2]^3 - x[2]
    r[1] = shifted * cubic + x[3]

    # Second equation: scaled sine. `y` is read again only after `r[1]` has been
    # written, matching the original statement order.
    phase = y[2] * exp(y[1]) - 1
    second = sin(phase) * x[4]
    r[2] = second
    return second
end
"""
    solve(x)

Run `nlsolve` on `residual!` using the first four entries of `x` as parameters,
and return the `zero` field of the result.

The initial guess is `[0.1, 1.2]`, converted to `eltype(x)` so that the solver's
working arrays can hold whatever number type `x` carries.
"""
function solve(x)
    TF = eltype(x)
    # Slice the parameters once. Indexing with a range copies, so doing it inside
    # the closure would allocate a fresh array on every residual evaluation.
    params = x[1:4]
    rwrap(r, y) = residual!(r, y, params)
    res = nlsolve(rwrap, TF[0.1; 1.2], autodiff=:forward)
    return res.zero
end
"""
    program(x)

Example computation that passes values through `solve`.

Builds `w = 2.0*x + x.^2`, calls `solve(w)`, and combines the first two entries of
the returned vector with `w` into an output the same length as `x`.
"""
function program(x)
    # Inputs handed to the solve (and reused in the output below).
    w = 2.0 * x + x .^ 2
    root = solve(w)
    return root[1] .+ w * root[2]
end
# Point at which both Jacobians are evaluated.
x = Float64[1, 2, 3, 4, 5]

# Forward mode produces a Jacobian.
ForwardDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879    2.90746e-25   0.0
#   8.55572  14.3307    -3.0819     8.60063e-26   0.0
#  16.8073   12.2672     4.72726   -2.0063e-25    0.0
#  27.4165   20.0105    -9.87583   13.4769        0.0
#  40.3833   29.4746   -14.5467    -1.01959e-24  16.1723

# Reverse mode on the same function and input raises an error.
ReverseDiff.jacobian(program, x)
# ERROR: UndefVarError: rT not defined
The reverse pass fails because the default constructor for `SolverResults` can't figure out the right type for `rT`. The fix is to define a more reliable constructor. For example:
"""
    SolverResults{rT,T,I,Z}

Record of a solver run.

# Type parameters
- `rT<:Real`: real scalar type used for `residual_norm`, `xtol` and `ftol`.
- `T`: element type of the stored arrays, either `rT` or `Complex{rT}`.
- `I<:AbstractArray{T}`: container type of `initial_x`.
- `Z<:AbstractArray{T}`: container type of `zero`.

# Construction
    SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
                  xtol, f_converged, ftol, trace, f_calls, g_calls)

The inner constructor derives every type parameter from its arguments, so callers
never spell them out. `rT` is the promotion of the real parts of both array element
types with the types of `residual_norm`, `xtol` and `ftol`. `initial_x` and `zero`
are converted element-wise to `T` when their element type is not already a subtype
of `T`.
"""
mutable struct SolverResults{rT<:Real,T<:Union{rT,Complex{rT}},I<:AbstractArray{T},Z<:AbstractArray{T}}
    method::String
    initial_x::I
    zero::Z
    residual_norm::rT
    iterations::Int
    x_converged::Bool
    xtol::rT
    f_converged::Bool
    ftol::rT
    trace::SolverTrace
    f_calls::Int
    g_calls::Int

    # Explicit inner constructor: the default one cannot infer `rT` in all cases.
    function SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
                           xtol, f_converged, ftol, trace, f_calls, g_calls)
        # Common real scalar type across the arrays and the scalar diagnostics.
        rT = promote_type(real(eltype(initial_x)), real(eltype(zero)),
                          typeof(residual_norm), typeof(xtol), typeof(ftol))

        # Element type of the stored arrays: complex if either input array is complex.
        T = promote_type(eltype(initial_x), eltype(zero)) <: Complex ? Complex{rT} : rT

        # Bring both arrays to element type `T` only when needed, so arrays that
        # already match are stored without copying.
        if !(eltype(initial_x) <: T)
            initial_x = T.(initial_x)
        end
        if !(eltype(zero) <: T)
            zero = T.(zero)
        end

        # Each array keeps its own container type. `Z` must come from `zero`, not
        # `initial_x`; otherwise `new` would try to convert `zero` to the initial
        # guess's array type whenever the two differ.
        I = typeof(initial_x)
        Z = typeof(zero)

        return new{rT,T,I,Z}(method, initial_x, zero, residual_norm, iterations,
                             x_converged, xtol, f_converged, ftol, trace, f_calls, g_calls)
    end
end
Then the ReverseDiff derivatives propagate as expected.
# With the modified SolverResults implementation, the same reverse-mode call
# returns a Jacobian. As displayed, it matches the ForwardDiff output except for
# the near-zero off-diagonal entries in column 4.
ReverseDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879   -4.91066e-28   0.0
#   8.55572  14.3307    -3.0819    -2.04315e-27   0.0
#  16.8073   12.2672     4.72726   -2.00371e-27   0.0
#  27.4165   20.0105    -9.87583   13.4769        0.0
#  40.3833   29.4746   -14.5467     0.0          16.1723
Note that this issue involves passing derivatives through the nonlinear solve, rather than defining a custom pullback for the nonlinear solve (as discussed in #205).
Metadata
Assignees
Labels
No labels
Activity