How to use the model parameters trained from Flux.jl in the SimpleChains.jl model?? #159

Closed
@manozzing

Description

Hello,

I'm quite new to this package (I think it's really cool!) and I'm wondering whether I can copy the parameters of a Flux model into a SimpleChains model. For reference, I'm developing a machine-learned advection solver that uses a neural network to estimate the numerical flux at each time step. Here is the code I'm working on, as an example:

Flux.jl model

using Flux

# Three (3, 1) convolutions without padding: 2 -> 10 -> 10 -> 2 channels
flux_Flux = Chain(
    Flux.Conv((3, 1), 2 => 10, Flux.relu; pad = 0),
    Flux.Conv((3, 1), 10 => 10, Flux.relu; pad = 0),
    Flux.Conv((3, 1), 10 => 2, identity; pad = 0),
)
loss(x, y) = Flux.Losses.mae(flux_Flux(x), y)
ps = Flux.params(flux_Flux)

SimpleChains.jl model

flux_estimator = SimpleChain(
    SimpleChains.Conv(SimpleChains.relu, (3, 1), 10),
    SimpleChains.Conv(SimpleChains.relu, (3, 1), 10),
    SimpleChains.Conv(SimpleChains.identity, (3, 1), 2)
)
p = SimpleChains.init_params(flux_estimator, size(input))

I confirmed that both models have the same structure and the same number of parameters:

Chain(
  Conv((3, 1), 2 => 10, relu),          # 70 parameters
  Conv((3, 1), 10 => 10, relu),         # 310 parameters
  Conv((3, 1), 10 => 2),                # 62 parameters
)                   # Total: 6 arrays, 442 parameters, 2.781 KiB.
442-element StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}}:
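
As a sanity check on the counts above, here is a minimal sketch of the arithmetic, assuming each conv layer holds its kernel weights plus one bias per output channel. The helper name `conv_params` is made up for illustration and is not part of either package.

```julia
# Parameter count of a conv layer: kernel weights plus one bias per output channel.
conv_params(kernel::Tuple, cin::Int, cout::Int) = prod(kernel) * cin * cout + cout

# (kernel size, input channels, output channels) for the three layers above
layers = [((3, 1), 2, 10), ((3, 1), 10, 10), ((3, 1), 10, 2)]
counts = [conv_params(k, cin, cout) for (k, cin, cout) in layers]

println(counts)       # [70, 310, 62]
println(sum(counts))  # 442
```

Matching counts only show that the sizes agree. They say nothing about whether the two libraries order or interpret the weights in the same way.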

After training the Flux model, I copied its parameters into the SimpleChains parameter vector as follows:

p[1:60] = Flux.params(flux_Flux)[1][1:60]
p[61:70] = Flux.params(flux_Flux)[2][1:10]
p[71:370] = Flux.params(flux_Flux)[3][1:300]
p[371:380] = Flux.params(flux_Flux)[4][1:10]
p[381:440] = Flux.params(flux_Flux)[5][1:60]
p[441:442] = Flux.params(flux_Flux)[6][1:2]
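
The same layer-by-layer copy can be written without hard-coded index ranges. The following is a minimal sketch in plain Julia: `copy_flat!` is a hypothetical helper, and the random arrays are stand-ins whose shapes are assumed to match the six Flux parameter arrays (weights as kernel width × kernel height × input channels × output channels, followed by the bias). It only reproduces the copy above and does nothing about possible differences in weight ordering or convolution convention between the two packages.

```julia
# Copy a sequence of arrays into a flat vector back to back, in column-major order.
function copy_flat!(dest::AbstractVector, arrays)
    offset = 0
    for a in arrays
        n = length(a)
        dest[offset+1:offset+n] .= vec(a)
        offset += n
    end
    # Fail loudly if the arrays did not fill the destination exactly.
    offset == length(dest) ||
        error("copied $offset values into a vector of length $(length(dest))")
    return dest
end

# Stand-in arrays (assumed shapes): weight and bias for each of the three layers.
shapes = [(3, 1, 2, 10), (10,), (3, 1, 10, 10), (10,), (3, 1, 10, 2), (2,)]
arrays = [rand(Float32, s...) for s in shapes]

p_demo = zeros(Float32, 442)
copy_flat!(p_demo, arrays)
```

With the real models, the call would presumably be `copy_flat!(p, Flux.params(flux_Flux))`, although that call is untested here.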

However, the two models give totally different results when I feed them the same input dataset.

Flux estimation over several time steps using the Flux.jl model:
[Figure: flux_scattered_Flux]

Flux estimation over several time steps using the SimpleChains.jl model with the same parameters as the Flux.jl model:
[Figure: flux_scattered_SC]

Do you have any idea why the SimpleChains.jl model gives such different results? I originally tried to train the model with SimpleChains.jl directly, but that training was not successful either, so I chose to pass in the parameters from Flux.jl instead, and so far that has not helped. Any comments would help me out. Thank you so much!
