using MLPGradientFlow, Random, Optimisers

# Seed the global RNG so that the sampled inputs (and, presumably, the random
# initial parameters drawn below) are reproducible.
Random.seed!(1)

# 10_000 samples of a 2-dimensional standard-normal input.
# NOTE(review): one sample per column is inferred from `randn(2, 10_000)`; confirm.
input = randn(2, 10_000)

# The teacher network produces the regression targets. Each layer tuple is
# presumably `(width, activation, bias)`; confirm against the MLPGradientFlow docs.
teacher = TeacherNet(; input,
                     layers = ((5, softplus, false), (1, identity, false)))
target = teacher(input)

# Student network trained to fit the teacher. It uses the same layer-tuple format,
# with 4 instead of 5 in the hidden layer's first slot.
net = Net(; input = input, target = target,
          layers = ((4, softplus, false), (1, identity, false)))
p = random_params(net)

# Reference run: gradient flow only. With `maxiterations_optim = 0` the displayed
# result reports `optim_iterations => 0`, i.e. only the ODE phase runs.
# `maxT` presumably is the final ODE time; confirm in the `train` docstring.
res_gradientflow = train(net, p;
                         n_samples_trajectory = 10_000,
                         maxiterations_optim = 0,
                         maxT = 30)
Dict{String, Any} with 19 entries:
  "gnorm"             => 1.19526e-6
  "init"              => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
  "x"                 => Dict("w1"=>[-0.552286 0.648202; 0.833958 0.0197933; -0…
  "loss_curve"        => [1.70629, 0.038037, 0.0345656, 0.032486, 0.030712, 0.0…
  "target"            => [-0.533436 -0.33788 … 0.00158328 -0.354491]
  "optim_iterations"  => 0
  "ode_stopped_by"    => ""
  "ode_iterations"    => 135
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 2.87843
  "loss"              => 1.15865e-6
  "input"             => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
  "trajectory"        => OrderedDict(0.0=>Dict("w1"=>[-1.10008 0.418848; -0.278…
  "ode_x"             => Dict("w1"=>[-0.552286 0.648202; 0.833958 0.0197933; -0…
  "total_time"        => 14.7717
  "ode_loss"          => 1.15865e-6
  "layerspec"         => ((4, "softplus", false), (1, "identity", false))
  "gnorm_regularized" => 1.19526e-6

Let us compare this to training with gradient descent, first on the full batch and then with minibatches.

# Full-batch gradient descent with step size 0.1. The very large iteration cap
# presumably leaves `maxtime_ode` as the binding limit; the displayed result reports
# `ode_stopped_by => "maxtime_ode"`.
res_descent_fullbatch = train(net, p;
                              alg = Descent(eta = 0.1),
                              maxtime_ode = 30,
                              maxiterations_ode = 100_000_000,
                              maxiterations_optim = 0,
                              n_samples_trajectory = 1_000)
Dict{String, Any} with 19 entries:
  "gnorm"             => 2.34935e-5
  "init"              => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
  "x"                 => Dict("w1"=>[-0.83211 0.590325; 1.11798 0.174144; -0.62…
  "loss_curve"        => [1.70629, 0.0340358, 0.0301425, 0.0265193, 0.0219961, …
  "target"            => [-0.533436 -0.33788 … 0.00158328 -0.354491]
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime_ode"
  "ode_iterations"    => 74232
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 30.0
  "loss"              => 1.42086e-5
  "input"             => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
  "trajectory"        => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
  "ode_x"             => Dict("w1"=>[-0.832109 0.590324; 1.11798 0.174144; -0.6…
  "total_time"        => 30.0685
  "ode_loss"          => 1.42084e-5
  "layerspec"         => ((4, "softplus", false), (1, "identity", false))
  "gnorm_regularized" => 2.34935e-5
# Minibatch gradient descent (100 samples per batch), otherwise configured like the
# full-batch run apart from a shorter `maxtime_ode` budget.
res_descent = train(net, p;
                    alg = Descent(eta = 0.1),
                    batchsize = 100,
                    maxtime_ode = 20,
                    maxiterations_ode = 100_000_000,
                    maxiterations_optim = 0,
                    n_samples_trajectory = 1_000)
Dict{String, Any} with 19 entries:
  "gnorm"             => 5.33634e-5
  "init"              => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
  "x"                 => Dict("w1"=>[-0.652126 0.597922; 0.883571 0.0688886; -0…
  "loss_curve"        => [1.70629, 0.000493841, 0.000355521, 0.000272728, 0.000…
  "target"            => [-0.533436 -0.33788 … 0.00158328 -0.354491]
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime_ode"
  "ode_iterations"    => 1574619
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 20.0
  "loss"              => 1.9186e-6
  "input"             => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
  "trajectory"        => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
  "ode_x"             => Dict("w1"=>[-0.652044 0.597949; 0.883681 0.0690056; -0…
  "total_time"        => 20.0016
  "ode_loss"          => 1.96983e-6
  "layerspec"         => ((4, "softplus", false), (1, "identity", false))
  "gnorm_regularized" => 5.33634e-5

Not surprisingly, gradient descent is slower than gradient flow (which uses second-order information). Within its time budget, it therefore does not reach a point whose loss and gradient norm are as low as those found by gradient flow.

# Distance of each gradient-descent trajectory from the gradient-flow reference.
# The third return value is not needed here. The first two are plotted below as
# (x = times, y = distances). NOTE(review): that meaning is inferred from the plot; confirm.
tdb, ttb, _ = MLPGradientFlow.trajectory_distance(res_descent_fullbatch,
                                                  res_gradientflow)
td, tt, _ = MLPGradientFlow.trajectory_distance(res_descent, res_gradientflow)

using CairoMakie

f = Figure()
# pseudolog10 (unlike log10) is defined at 0, so a zero time or distance still plots.
ax = Axis(f[1, 1];
          xlabel = "time", xscale = Makie.pseudolog10,
          ylabel = "distance", yscale = Makie.pseudolog10)
for (times, dists, lbl) in ((ttb, tdb, "full batch"), (tt, td, "batchsize = 100"))
    lines!(ax, times, dists; label = lbl)
end
axislegend(ax)
f

The gradient descent trajectories stay close to the gradient flow trajectory, both in full-batch mode and with minibatches of size 100.

# Full-batch Adam with default hyperparameters.
# `maxiterations_ode = 10^8` is passed explicitly, as in the other optimiser runs on
# this page, so that the `maxtime_ode` budget (not whatever default iteration cap
# `train` uses) decides when the run stops. This keeps the comparison independent
# of machine speed.
res_adam_fullbatch = train(net, p, alg = Adam(),
                           maxiterations_ode = 10^8,
                           maxtime_ode = 20, maxiterations_optim = 0,
                           n_samples_trajectory = 10^3)
Dict{String, Any} with 19 entries:
  "gnorm"             => 7.60045e-6
  "init"              => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
  "x"                 => Dict("w1"=>[0.234266 -0.263646; -0.539106 0.509594; -0…
  "loss_curve"        => [1.70629, 1.29335, 0.991448, 0.770121, 0.61352, 0.4979…
  "target"            => [-0.533436 -0.33788 … 0.00158328 -0.354491]
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime_ode"
  "ode_iterations"    => 48297
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 20.0001
  "loss"              => 8.43275e-7
  "input"             => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
  "trajectory"        => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
  "ode_x"             => Dict("w1"=>[0.234275 -0.263645; -0.539112 0.509596; -0…
  "total_time"        => 20.103
  "ode_loss"          => 8.43425e-7
  "layerspec"         => ((4, "softplus", false), (1, "identity", false))
  "gnorm_regularized" => 7.60045e-6
# Minibatch Adam (100 samples per batch) under the same time budget as the
# full-batch Adam run.
res_adam = train(net, p;
                 alg = Adam(),
                 batchsize = 100,
                 maxtime_ode = 20,
                 maxiterations_ode = 100_000_000,
                 maxiterations_optim = 0,
                 n_samples_trajectory = 1_000)
Dict{String, Any} with 19 entries:
  "gnorm"             => 0.000676705
  "init"              => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
  "x"                 => Dict("w1"=>[0.198439 -0.261251; -0.519234 0.502803; -0…
  "loss_curve"        => [1.70629, 0.0259082, 0.0003525, 0.000235537, 0.0001839…
  "target"            => [-0.533436 -0.33788 … 0.00158328 -0.354491]
  "optim_iterations"  => 0
  "ode_stopped_by"    => "maxtime_ode"
  "ode_iterations"    => 1827508
  "optim_time_run"    => 0
  "converged"         => false
  "ode_time_run"      => 20.0
  "loss"              => 1.08468e-6
  "input"             => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
  "trajectory"        => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
  "ode_x"             => Dict("w1"=>[0.394122 -0.243279; -0.525369 0.505516; -0…
  "total_time"        => 20.0017
  "ode_loss"          => 8.52197e-7
  "layerspec"         => ((4, "softplus", false), (1, "identity", false))
  "gnorm_regularized" => 0.000676705
# Distance of each Adam trajectory from the gradient-flow reference. Only the
# distances are plotted; they are drawn against their position in the trajectory
# rather than against time. The second and third return values are kept/ignored
# exactly as before.
tdb, ttb, _ = MLPGradientFlow.trajectory_distance(res_adam_fullbatch,
                                                  res_gradientflow)
td, tt, _ = MLPGradientFlow.trajectory_distance(res_adam, res_gradientflow)

using CairoMakie

f = Figure()
ax = Axis(f[1, 1];
          xlabel = "trajectory steps",
          ylabel = "distance", yscale = Makie.pseudolog10)
for (dists, lbl) in ((tdb, "full batch"), (td, "batchsize = 100"))
    lines!(ax, 1:length(dists), dists; label = lbl)
end
axislegend(ax; position = :rb)
f

This is not the case for Adam, which effectively uses different timescales for different parameters. Its trajectory therefore departs from the path followed by gradient flow.