To experiment with different timescales for different parameters, we can pass a separate inverse timescale `tauinv` for each parameter. In the following example the biases change on a much slower timescale than the weights: their `tauinv` is set to `1e-4`, while all other entries are `1`.

using MLPGradientFlow, Random

# Fix the global RNG seed so that the random draws below (inputs, teacher
# parameters, student initialization) are reproducible.
Random.seed!(14)

# 5_000 random inputs of dimension 2, one sample per column (the first-layer
# weight matrices below have 2 columns).
# Teacher network: 5 hidden `sigmoid2` units and one linear (`identity`) output.
# Its parameters are given explicitly to `params` as (weight, bias) pairs per layer.
# The `true` in each layer tuple presumably switches the bias on; TODO confirm.
# NOTE: the order of the `randn` calls matters for reproducing the output shown.
input = randn(2, 5_000)
teacher = TeacherNet(; layers = ((5, sigmoid2, true), (1, identity, true)),
                       p = params((randn(5, 2), randn(5)), (randn(1, 5), randn(1))),
                       input)

# Student network with only 4 hidden units (one fewer than the teacher), fitted
# to the teacher's outputs on the same inputs.
student = Net(; layers = ((4, sigmoid2, true), (1, identity, true)),
                input, target = teacher(input))

# Random initial student parameters (presumably drawn from the seeded global RNG).
p = random_params(student)

# Teacher and student wrapped together; `train` is called on this object.
neti = NetI(teacher, student)

# Baseline run: no `tauinv` is passed, so presumably all parameters share the same
# default timescale. `maxiterations_optim = 0` skips the optimizer phase
# ("optim_iterations" => 0 below), so only the ODE gradient flow runs, for at
# most `maxtime_ode` seconds.
res_standard = train(neti, p; maxtime_ode = 10, maxiterations_optim = 0)
Dict{String, Any} with 17 entries:
  "gnorm"             => 4.3992e-15
  "init"              => Dict("w1"=>[0.13725 1.29647 0.0; -0.644227 0.487307 0.…
  "x"                 => Dict("w1"=>[-0.0933173 1.18814 0.226924; -0.425007 0.5…
  "loss_curve"        => [1.12621, 0.00147223, 2.27038e-6, 2.27038e-6, 2.27038e…
  "optim_iterations"  => 0
  "ode_stopped_by"    => "t > 1e300"
  "ode_iterations"    => 2707
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 1.12226
  "loss"              => 2.27038e-6
  "trajectory"        => OrderedDict(0.0=>Dict("w1"=>[0.13725 1.29647 0.0; -0.6…
  "ode_x"             => Dict("w1"=>[-0.0933173 1.18814 0.226924; -0.425007 0.5…
  "total_time"        => 6.91658
  "ode_loss"          => 2.27038e-6
  "layerspec"         => ((4, "sigmoid2", true), (1, "identity", true))
  "gnorm_regularized" => 4.3992e-15
# Per-parameter inverse timescales, with the same layout as `p`.
# Every entry starts at 1. Then the last column of each layer's matrix (the bias
# column; cf. the student's `init` above, whose last column starts at 0.0) is set
# to 1e-4, so the biases evolve on a ~10^4 times longer (slower) timescale.
tauinv = fill!(zero(p), 1)
for layer in (tauinv.w1, tauinv.w2)
    layer[:, end] .= 1e-4
end

res = train(neti, p; tauinv = tauinv, maxtime_ode = 20, maxiterations_optim = 0)
Dict{String, Any} with 17 entries:
  "gnorm"             => 1.08463e-8
  "init"              => Dict("w1"=>[0.13725 1.29647 0.0; -0.644227 0.487307 0.…
  "x"                 => Dict("w1"=>[-0.264197 1.17991 0.180489; 0.0029566 -0.0…
  "loss_curve"        => [1.12621, 1.11164, 1.10364, 1.09962, 1.09768, 1.09668,…
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime > 20.0s"
  "ode_iterations"    => 52552
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 20.0001
  "loss"              => 0.000334693
  "trajectory"        => OrderedDict(0.0=>Dict("w1"=>[0.13725 1.29647 0.0; -0.6…
  "ode_x"             => Dict("w1"=>[-0.264197 1.17991 0.180489; 0.0029566 -0.0…
  "total_time"        => 22.5813
  "ode_loss"          => 0.000334693
  "layerspec"         => ((4, "sigmoid2", true), (1, "identity", true))
  "gnorm_regularized" => 7.60673e-5

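We see that the two dynamics end up at different solutions (compare the final parameters `"x"`). With the default timescales the loss drops to about 2.3e-6. With slow biases, the flow is still at a loss of about 3.3e-4 when it hits the 20 s time limit, so it has not converged yet.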