Hello,
I'm quite new to this package (I think it is really cool!) and am wondering whether I can copy the parameters of a Flux model into a SimpleChains model. For reference, I'm developing a machine-learned advection solver that uses a neural network to estimate the numerical flux at each time step. Here is the code I'm working on, as an example.
Flux.jl model
using Flux
# Three convolution layers with (3, 1) kernels and no padding
flux_Flux = Chain(
    Flux.Conv((3, 1), 2 => 10, Flux.relu; pad = 0),
    Flux.Conv((3, 1), 10 => 10, Flux.relu; pad = 0),
    Flux.Conv((3, 1), 10 => 2, Flux.identity; pad = 0),
)
loss(x, y) = Flux.Losses.mae(flux_Flux(x), y)
ps = Flux.params(flux_Flux)
SimpleChains.jl model
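using SimpleChains
# Same three-layer structure as the Flux model above
flux_estimator = SimpleChain(
    SimpleChains.Conv(SimpleChains.relu, (3, 1), 10),
    SimpleChains.Conv(SimpleChains.relu, (3, 1), 10),
    SimpleChains.Conv(SimpleChains.identity, (3, 1), 2)
)
# input is the input data array, defined elsewhere
p = SimpleChains.init_params(flux_estimator, size(input))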
I confirmed that both models have the same structure and the same number of parameters. Below are the Flux model summary and the type of the SimpleChains parameter vector p:
Chain(
Conv((3, 1), 2 => 10, relu), # 70 parameters
Conv((3, 1), 10 => 10, relu), # 310 parameters
Conv((3, 1), 10 => 2), # 62 parameters
) # Total: 6 arrays, 442 parameters, 2.781 KiB.
442-element StrideArray{Float32, 1, (1,), Tuple{Int64}, Tuple{Nothing}, Tuple{Static.StaticInt{1}}, Vector{Float32}}:
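As a sanity check on the counts above, here is a minimal sketch in plain Julia that reproduces them by hand. It makes no Flux or SimpleChains calls; `conv_param_count` and `layers` are names made up for this illustration, and it assumes each layer has one bias per output channel.

```julia
# Parameters in a convolution layer: one weight per kernel element for each
# (input channel, output channel) pair, plus one bias per output channel.
conv_param_count(kernel, cin, cout) = prod(kernel) * cin * cout + cout

# (kernel size, input channels, output channels) for the three layers
layers = [((3, 1), 2, 10), ((3, 1), 10, 10), ((3, 1), 10, 2)]

counts = [conv_param_count(k, cin, cout) for (k, cin, cout) in layers]
total = sum(counts)

println(counts)  # [70, 310, 62]
println(total)   # 442
```

Matching totals only show that the two models have the same size. They say nothing about the order in which each package stores those 442 numbers.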
After training the Flux model, I copied its parameters into the SimpleChains parameter vector as follows:
p[1:60] = Flux.params(flux_Flux)[1][1:60]
p[61:70] = Flux.params(flux_Flux)[2][1:10]
p[71:370] = Flux.params(flux_Flux)[3][1:300]
p[371:380] = Flux.params(flux_Flux)[4][1:10]
p[381:440] = Flux.params(flux_Flux)[5][1:60]
p[441:442] = Flux.params(flux_Flux)[6][1:2]
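The slice boundaries above can be derived rather than typed by hand. The sketch below is plain Julia and uses random stand-in arrays, not the trained model. It assumes the six Flux arrays have the shapes (3, 1, 2, 10), (10,), (3, 1, 10, 10), (10,), (3, 1, 10, 2) and (2,); `copy_flat!` and `arrays` are hypothetical names. It copies each array in Julia's column-major order, exactly as the linear indexing above does. Whether SimpleChains expects the weights in that same order is an assumption I have not verified.

```julia
# Hypothetical helper: copy a list of arrays into one flat vector, in order.
function copy_flat!(dest::AbstractVector, arrays)
    offset = 0
    for a in arrays
        n = length(a)
        # vec flattens in column-major order, same as a[1:n]
        dest[offset+1:offset+n] .= vec(a)
        offset += n
    end
    offset == length(dest) || error("length mismatch: filled $offset of $(length(dest))")
    return dest
end

# Stand-ins with the assumed shapes of the six Flux parameter arrays
arrays = [
    rand(Float32, 3, 1, 2, 10),  rand(Float32, 10),
    rand(Float32, 3, 1, 10, 10), rand(Float32, 10),
    rand(Float32, 3, 1, 10, 2),  rand(Float32, 2),
]

# Index range occupied by each array in the flat vector
lens = length.(arrays)
stops = cumsum(lens)
starts = stops .- lens .+ 1
ranges = [s:e for (s, e) in zip(starts, stops)]

dest = copy_flat!(zeros(Float32, sum(lens)), arrays)
println(ranges)
```

The computed ranges agree with the hand-written ones (1:60, 61:70, 71:370, 371:380, 381:440, 441:442), so the offsets themselves are consistent with the array sizes.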
However, the two models give completely different results when I feed them the same input dataset.
[Figure: Flux estimation over several time steps using the Flux.jl model]
[Figure: Flux estimation over several time steps using the SimpleChains.jl model with the same parameters as the Flux.jl model]
Do you have any idea why the SimpleChains.jl model gives such different results? I originally tried to train the model with SimpleChains.jl directly, but that training was not successful either, so I decided to pass in the parameters from Flux.jl instead, and so far that has not helped. Any comments would help me out. Thank you so much!