```julia
using MLPGradientFlow, Random

Random.seed!(71)
teacher = TeacherNet(; layers = ((5, g, false), (1, identity, false)), Din = 2,
                     p = params((randn(5, 2), nothing), (randn(1, 5), nothing)))
input = randn(2, 10_000)
target = teacher(input)
student = Net(; layers = ((5, g, false), (1, identity, false)),
              input, target)
p = random_params(student)
res = train(student, p,
            maxtime_ode = 10, maxiterations_optim = 0,
            n_samples_trajectory = 10^3)
```
```
Dict{String, Any} with 19 entries:
  "gnorm"             => 1.99552e-12
  "init"              => Dict("w1"=>[-0.398123 0.531146; -0.550592 0.0325326; ……
  "x"                 => Dict("w1"=>[1.58006 1.36655; -0.448955 -0.242849; … ; …
  "loss_curve"        => [18.3373, 8.62024e-5, 1.51932e-7, 3.16209e-9, 1.1967e-…
  "target"            => [-2.8022 -2.71996 … -6.5585 -2.46518]
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime > 10.0s"
  "ode_iterations"    => 934
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 10.0144
  "loss"              => 1.23711e-24
  "input"             => [-1.49619 -0.342379 … 0.569387 0.138405; -0.2845 0.075…
  "trajectory"        => OrderedDict(0.0=>Dict("w1"=>[-0.398123 0.531146; -0.55…
  "ode_x"             => Dict("w1"=>[1.58006 1.36655; -0.448955 -0.242849; … ; …
  "total_time"        => 23.4217
  "ode_loss"          => 1.23711e-24
  "layerspec"         => ((5, "g", false), (1, "identity", false))
  "gnorm_regularized" => 1.99552e-12
```
Let us compare the solution found by the student to the teacher parameters:
```julia
p_res = params(res["x"])
p_res.w1
```
```
5×2 reshape(view(::Vector{Float64}, 1:10), 5, 2) with eltype Float64:
  1.58006    1.36655
 -0.448955  -0.242849
 -1.59789    0.358099
  0.199348   0.0451142
 -0.404374  -1.80821
```
```julia
teacher.p.w1
```
```
5×2 reshape(view(::Vector{Float64}, 1:10), 5, 2) with eltype Float64:
  0.199348   0.0451142
  1.58006    1.36655
 -0.448955  -0.242849
 -0.404374  -1.80821
 -1.59789    0.358099
```
```julia
p_res.w2
```
```
1×5 reshape(view(::Vector{Float64}, 11:15), 1, 5) with eltype Float64:
 -1.35351  0.0760491  -0.661599  -0.502888  -0.103821
```
```julia
teacher.p.w2
```
```
1×5 reshape(view(::Vector{Float64}, 11:15), 1, 5) with eltype Float64:
 -0.502888  -1.35351  0.0760491  -0.103821  -0.661599
```
We see that the student perfectly reproduces the teacher, up to a permutation of the hidden neurons (compare, e.g., the first row of `p_res.w1` with the second row of `teacher.p.w1`); the check below makes this precise.
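The permutation equivalence can also be verified programmatically. A minimal sketch in plain Julia (only `LinearAlgebra` from the standard library), reusing the variables from the session above and the hard-coded width of 5 hidden neurons: it matches each student neuron to the nearest teacher neuron by the distance between incoming weight vectors, and confirms the matching is a bijection under which both layers agree.

```julia
using LinearAlgebra

# Match each student hidden neuron to the teacher neuron whose
# incoming weight vector is closest in Euclidean distance.
perm = [argmin([norm(p_res.w1[i, :] .- teacher.p.w1[j, :]) for j in 1:5])
        for i in 1:5]

@assert isperm(perm)  # every student neuron maps to a distinct teacher neuron

# After reordering the teacher's hidden neurons accordingly, both the
# input weights and the output weights agree up to numerical precision.
maximum(abs, p_res.w1 .- teacher.p.w1[perm, :])
maximum(abs, vec(p_res.w2) .- vec(teacher.p.w2)[perm])
```

For the run above this yields `perm = [2, 3, 5, 1, 4]`, and both maxima are at the level of the displayed rounding.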