using MLPGradientFlow, Random

# Seed the global RNG so that the teacher weights, the input data and the student
# initialization (and hence the output shown below) are reproducible.
Random.seed!(71)

# Teacher network: Din = 2 inputs, a hidden layer of 5 neurons with activation `g`,
# and a single `identity` (linear) output neuron. The third entry of each layer tuple
# is `false` and the bias slots passed to `params` are `nothing`, i.e. presumably no
# biases in either layer. The weights are drawn from a standard normal distribution.
teacher = TeacherNet(layers = ((5, g, false), (1, identity, false)), Din = 2,
                     p = params((randn(5, 2), nothing), (randn(1, 5), nothing)))

# 10_000 standard-normal input samples (one per column, matching Din = 2) and the
# teacher's outputs on them; these (input, target) pairs are the student's training data.
input = randn(2, 10_000)
target = teacher(input)

# Student network with exactly the same architecture as the teacher.
student = Net(; layers = ((5, g, false), (1, identity, false)),
                input, target)

# Random initial parameters for the student.
p = random_params(student)

# Train with the ODE solver for at most 10 seconds (`maxtime_ode`), running no
# iterations of the optimizer phase (`maxiterations_optim = 0`).
# `n_samples_trajectory` presumably sets how many points are stored in
# res["trajectory"]. TODO: confirm.
res = train(student, p,
            maxtime_ode = 10, maxiterations_optim = 0,
            n_samples_trajectory = 10^3)
Dict{String, Any} with 19 entries:
  "gnorm"             => 1.99552e-12
  "init"              => Dict("w1"=>[-0.398123 0.531146; -0.550592 0.0325326; ……
  "x"                 => Dict("w1"=>[1.58006 1.36655; -0.448955 -0.242849; … ; …
  "loss_curve"        => [18.3373, 8.62024e-5, 1.51932e-7, 3.16209e-9, 1.1967e-…
  "target"            => [-2.8022 -2.71996 … -6.5585 -2.46518]
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime > 10.0s"
  "ode_iterations"    => 934
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 10.0144
  "loss"              => 1.23711e-24
  "input"             => [-1.49619 -0.342379 … 0.569387 0.138405; -0.2845 0.075…
  "trajectory"        => OrderedDict(0.0=>Dict("w1"=>[-0.398123 0.531146; -0.55…
  "ode_x"             => Dict("w1"=>[1.58006 1.36655; -0.448955 -0.242849; … ; …
  "total_time"        => 23.4217
  "ode_loss"          => 1.23711e-24
  "layerspec"         => ((5, "g", false), (1, "identity", false))
  "gnorm_regularized" => 1.99552e-12

Let us compare the parameters found by the student with the teacher's parameters:

# Wrap the final parameter vector res["x"] so that each layer's weights can be accessed by name.
p_res = params(res["x"])
# Student's first-layer (input -> hidden) weights: one row per hidden neuron.
p_res.w1
5×2 reshape(view(::Vector{Float64}, 1:10), 5, 2) with eltype Float64:
  1.58006    1.36655
 -0.448955  -0.242849
 -1.59789    0.358099
  0.199348   0.0451142
 -0.404374  -1.80821
# Teacher's first-layer weights, for comparison.
teacher.p.w1
5×2 reshape(view(::Vector{Float64}, 1:10), 5, 2) with eltype Float64:
  0.199348   0.0451142
  1.58006    1.36655
 -0.448955  -0.242849
 -0.404374  -1.80821
 -1.59789    0.358099
# Student's second-layer (hidden -> output) weights: one entry per hidden neuron.
p_res.w2
1×5 reshape(view(::Vector{Float64}, 11:15), 1, 5) with eltype Float64:
 -1.35351  0.0760491  -0.661599  -0.502888  -0.103821
# Teacher's second-layer weights, for comparison.
teacher.p.w2
1×5 reshape(view(::Vector{Float64}, 11:15), 1, 5) with eltype Float64:
 -0.502888  -1.35351  0.0760491  -0.103821  -0.661599

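The student reproduces the teacher exactly, up to a permutation of the hidden neurons. Every row of the student's `w1`, together with the matching entry of its `w2`, also appears among the teacher's parameters, only in a different order. For example, the student's first hidden neuron (weights `1.58006 1.36655`, output weight `-1.35351`) is the teacher's second hidden neuron.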