Skip to content

Allow ReverseDiff Propagation #281

Open
@taylormcd

Description

This package is currently compatible with ForwardDiff, but not ReverseDiff.

using NLsolve, ForwardDiff, ReverseDiff

"""
    residual!(r, y, x)

Write the two residuals of the example system into `r[1]` and `r[2]`,
reading the unknowns from `y[1]`, `y[2]` and the parameters from `x[1]`
through `x[4]`.
"""
function residual!(r, y, x)
    # first equation: product of two shifted terms plus an offset
    lhs = y[1] + x[1]
    rhs = y[2]^3 - x[2]
    r[1] = lhs * rhs + x[3]
    # second equation: scaled sine of the coupled term
    phase = y[2] * exp(y[1]) - 1
    r[2] = sin(phase) * x[4]
end

"""
    solve(x)

Solve the system defined by `residual!` with `nlsolve`, using the first four
entries of `x` as parameters, and return the `zero` field of the result.

The initial guess `[0.1, 1.2]` is built with element type `eltype(x)`, so the
unknowns use the same number type as `x`.
"""
function solve(x)
    TF = eltype(x)
    # Slice once, outside the closure: `x[1:4]` allocates a new array, and the
    # solver may evaluate the residual repeatedly.
    params = x[1:4]
    rwrap(r, y) = residual!(r, y, params)
    res = nlsolve(rwrap, TF[0.1; 1.2], autodiff=:forward)
    return res.zero
end

"""
    program(x)

Form `w = 2.0*x + x.^2`, pass `w` to `solve`, and return `y[1] .+ w*y[2]`,
where `y` is the vector returned by `solve`.
"""
function program(x)
    w = 2.0*x + x.^2
    y = solve(w)
    offset = y[1]
    scale = y[2]
    return offset .+ w*scale
end

# Evaluation point for the Jacobian comparison below.
x = [1.0, 2.0, 3.0, 4.0, 5.0]

# Jacobian of `program` at `x` computed with ForwardDiff; recorded output follows.
ForwardDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879   2.90746e-25   0.0
#   8.55572  14.3307    -3.0819    8.60063e-26   0.0
#  16.8073   12.2672     4.72726  -2.0063e-25    0.0
#  27.4165   20.0105    -9.87583  13.4769        0.0
#  40.3833   29.4746   -14.5467   -1.01959e-24  16.1723

# The same Jacobian requested from ReverseDiff; this call fails with the
# error recorded below.
ReverseDiff.jacobian(program, x)
# ERROR: UndefVarError: rT not defined

The reverse pass fails because the default constructor for `SolverResults` cannot determine the type parameter `rT` from its arguments. The fix is to define an explicit inner constructor that computes the type parameters itself. For example:

"""
    SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
                  xtol, f_converged, ftol, trace, f_calls, g_calls)

Result record of a nonlinear solve.

The inner constructor computes every type parameter explicitly instead of
relying on the default constructor:

- `rT` is the promotion of the real element types of `initial_x` and `zero`
  with the types of `residual_norm`, `xtol` and `ftol`.
- `T` is `Complex{rT}` when the promoted element type of `initial_x` and
  `zero` is complex, and `rT` otherwise.
- `initial_x` and `zero` are converted element-wise to `T` when their element
  type is not already a subtype of `T`.
- `I` and `Z` are the types of `initial_x` and `zero` after that conversion.
"""
mutable struct SolverResults{rT<:Real,T<:Union{rT,Complex{rT}},I<:AbstractArray{T},Z<:AbstractArray{T}}
    method::String
    initial_x::I
    zero::Z
    residual_norm::rT
    iterations::Int
    x_converged::Bool
    xtol::rT
    f_converged::Bool
    ftol::rT
    trace::SolverTrace
    f_calls::Int
    g_calls::Int
    # provide inner constructor (default inner constructor doesn't work for all cases)
    function SolverResults(method, initial_x, zero, residual_norm, iterations, x_converged,
        xtol, f_converged, ftol, trace, f_calls, g_calls)

        # real type shared by the arrays, the residual norm and the tolerances
        rT = promote_type(real(eltype(initial_x)), real(eltype(zero)),
            typeof(residual_norm), typeof(xtol), typeof(ftol))

        # real/complex element type
        if promote_type(eltype(initial_x), eltype(zero)) <: Complex
            T = Complex{rT}
        else
            T = rT
        end

        # correct initial guess element type (if necessary)
        if !(eltype(initial_x) <: T)
            initial_x = T.(initial_x)
        end

        # correct zero element type (if necessary)
        if !(eltype(zero) <: T)
            zero = T.(zero)
        end

        # initial guess type
        I = typeof(initial_x)

        # zero type: must come from `zero` itself, not from `initial_x`, so the
        # two containers are allowed to have different array types
        Z = typeof(zero)

        return new{rT,T,I,Z}(method, initial_x, zero, residual_norm, iterations,
            x_converged, xtol, f_converged, ftol, trace, f_calls, g_calls)
    end
end

Then the ReverseDiff derivatives propagate as expected.

# With the modified `SolverResults` definition in place, the same ReverseDiff
# call completes; recorded output follows.
ReverseDiff.jacobian(program, x)
# 5×5 Matrix{Float64}:
#   8.05247   1.94271   -0.95879  -4.91066e-28   0.0
#   8.55572  14.3307    -3.0819   -2.04315e-27   0.0
#  16.8073   12.2672     4.72726  -2.00371e-27   0.0
#  27.4165   20.0105    -9.87583  13.4769        0.0
#  40.3833   29.4746   -14.5467    0.0          16.1723

Note that this issue involves passing derivatives through the nonlinear solve, rather than defining a custom pullback for the nonlinear solve (as discussed in #205).

Activity

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions