using MLPGradientFlow, Random, Optimisers
Random.seed!(1)
input = randn(2, 10_000)
teacher = TeacherNet(; layers = ((5, softplus, false), (1, identity, false)), input)
target = teacher(input)
net = Net(; layers = ((4, softplus, false), (1, identity, false)),
          input, target)
p = random_params(net)
res_gradientflow = train(net, p, maxT = 30,
                         maxiterations_optim = 0,
                         n_samples_trajectory = 10^4)
Dict{String, Any} with 19 entries:
"gnorm" => 1.19526e-6
"init" => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
"x" => Dict("w1"=>[-0.552286 0.648202; 0.833958 0.0197933; -0…
"loss_curve" => [1.70629, 0.038037, 0.0345656, 0.032486, 0.030712, 0.0…
"target" => [-0.533436 -0.33788 … 0.00158328 -0.354491]
"optim_iterations" => 0
"ode_stopped_by" => ""
"ode_iterations" => 135
"optim_time_run" => 0
"converged" => false
"ode_time_run" => 2.87843
"loss" => 1.15865e-6
"input" => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
"trajectory" => OrderedDict(0.0=>Dict("w1"=>[-1.10008 0.418848; -0.278…
"ode_x" => Dict("w1"=>[-0.552286 0.648202; 0.833958 0.0197933; -0…
"total_time" => 14.7717
"ode_loss" => 1.15865e-6
"layerspec" => ((4, "softplus", false), (1, "identity", false))
"gnorm_regularized" => 1.19526e-6
Let us compare this to training with gradient descent, both in full-batch mode and with minibatches.
res_descent_fullbatch = train(net, p, alg = Descent(eta = 1e-1),
                              maxiterations_ode = 10^8,
                              maxtime_ode = 30, maxiterations_optim = 0,
                              n_samples_trajectory = 10^3)
Dict{String, Any} with 19 entries:
"gnorm" => 2.34935e-5
"init" => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
"x" => Dict("w1"=>[-0.83211 0.590325; 1.11798 0.174144; -0.62…
"loss_curve" => [1.70629, 0.0340358, 0.0301425, 0.0265193, 0.0219961, …
"target" => [-0.533436 -0.33788 … 0.00158328 -0.354491]
"optim_iterations" => 0
"ode_stopped_by" => "maxtime_ode"
"ode_iterations" => 74232
"optim_time_run" => 0
"converged" => false
"ode_time_run" => 30.0
"loss" => 1.42086e-5
"input" => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
"trajectory" => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
"ode_x" => Dict("w1"=>[-0.832109 0.590324; 1.11798 0.174144; -0.6…
"total_time" => 30.0685
"ode_loss" => 1.42084e-5
"layerspec" => ((4, "softplus", false), (1, "identity", false))
"gnorm_regularized" => 2.34935e-5
res_descent = train(net, p, alg = Descent(eta = 1e-1), batchsize = 100,
                    maxiterations_ode = 10^8,
                    maxtime_ode = 20, maxiterations_optim = 0,
                    n_samples_trajectory = 10^3)
Dict{String, Any} with 19 entries:
"gnorm" => 5.33634e-5
"init" => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
"x" => Dict("w1"=>[-0.652126 0.597922; 0.883571 0.0688886; -0…
"loss_curve" => [1.70629, 0.000493841, 0.000355521, 0.000272728, 0.000…
"target" => [-0.533436 -0.33788 … 0.00158328 -0.354491]
"optim_iterations" => 0
"ode_stopped_by" => "maxtime_ode"
"ode_iterations" => 1574619
"optim_time_run" => 0
"converged" => false
"ode_time_run" => 20.0
"loss" => 1.9186e-6
"input" => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
"trajectory" => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
"ode_x" => Dict("w1"=>[-0.652044 0.597949; 0.883681 0.0690056; -0…
"total_time" => 20.0016
"ode_loss" => 1.96983e-6
"layerspec" => ((4, "softplus", false), (1, "identity", false))
"gnorm_regularized" => 5.33634e-5
Unsurprisingly, gradient descent needs more time than gradient flow (whose integrator uses second-order information), and within the given time budget it therefore does not reach a point with as low a loss and gradient norm as gradient flow.
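To make this concrete, we can compare the final losses and gradient norms directly (keys as printed above):
(loss_flow = res_gradientflow["loss"], gnorm_flow = res_gradientflow["gnorm"],
 loss_gd = res_descent_fullbatch["loss"], gnorm_gd = res_descent_fullbatch["gnorm"],
 loss_sgd = res_descent["loss"], gnorm_sgd = res_descent["gnorm"])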
tdb, ttb, _ = MLPGradientFlow.trajectory_distance(res_descent_fullbatch, res_gradientflow)
td, tt, _ = MLPGradientFlow.trajectory_distance(res_descent, res_gradientflow)
using CairoMakie
f = Figure()
ax = Axis(f[1, 1], ylabel = "distance", yscale = Makie.pseudolog10, xscale = Makie.pseudolog10, xlabel = "time")
lines!(ax, ttb, tdb, label = "full batch")
lines!(ax, tt, td, label = "batchsize = 100")
axislegend(ax)
f
Gradient descent stays close to gradient flow, both in full batch mode and with minibatches of size 100.
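The same comparison can be summarized numerically by the largest distance to the gradient-flow trajectory, using the distance vectors computed above:
maximum(tdb), maximum(td)   # full batch vs. batchsize = 100
Let us now run the same comparison with Adam.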
res_adam_fullbatch = train(net, p, alg = Adam(),
                           maxtime_ode = 20, maxiterations_optim = 0,
                           n_samples_trajectory = 10^3)
Dict{String, Any} with 19 entries:
"gnorm" => 7.60045e-6
"init" => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
"x" => Dict("w1"=>[0.234266 -0.263646; -0.539106 0.509594; -0…
"loss_curve" => [1.70629, 1.29335, 0.991448, 0.770121, 0.61352, 0.4979…
"target" => [-0.533436 -0.33788 … 0.00158328 -0.354491]
"optim_iterations" => 0
"ode_stopped_by" => "maxtime_ode"
"ode_iterations" => 48297
"optim_time_run" => 0
"converged" => false
"ode_time_run" => 20.0001
"loss" => 8.43275e-7
"input" => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
"trajectory" => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
"ode_x" => Dict("w1"=>[0.234275 -0.263645; -0.539112 0.509596; -0…
"total_time" => 20.103
"ode_loss" => 8.43425e-7
"layerspec" => ((4, "softplus", false), (1, "identity", false))
"gnorm_regularized" => 7.60045e-6
res_adam = train(net, p, alg = Adam(), batchsize = 100, maxiterations_ode = 10^8,
                 maxtime_ode = 20, maxiterations_optim = 0,
                 n_samples_trajectory = 10^3)
Dict{String, Any} with 19 entries:
"gnorm" => 0.000676705
"init" => Dict("w1"=>[-1.10008 0.418848; -0.278256 0.234034; -1.…
"x" => Dict("w1"=>[0.198439 -0.261251; -0.519234 0.502803; -0…
"loss_curve" => [1.70629, 0.0259082, 0.0003525, 0.000235537, 0.0001839…
"target" => [-0.533436 -0.33788 … 0.00158328 -0.354491]
"optim_iterations" => 0
"ode_stopped_by" => "maxtime_ode"
"ode_iterations" => 1827508
"optim_time_run" => 0
"converged" => false
"ode_time_run" => 20.0
"loss" => 1.08468e-6
"input" => [0.0619327 -0.595824 … 0.388148 -0.713525; 0.278406 0.…
"trajectory" => OrderedDict(0=>Dict("w1"=>[-1.10008 0.418848; -0.27825…
"ode_x" => Dict("w1"=>[0.394122 -0.243279; -0.525369 0.505516; -0…
"total_time" => 20.0017
"ode_loss" => 8.52197e-7
"layerspec" => ((4, "softplus", false), (1, "identity", false))
"gnorm_regularized" => 0.000676705
tdb, ttb, _ = MLPGradientFlow.trajectory_distance(res_adam_fullbatch, res_gradientflow)
td, tt, _ = MLPGradientFlow.trajectory_distance(res_adam, res_gradientflow)
using CairoMakie
f = Figure()
ax = Axis(f[1, 1], ylabel = "distance", yscale = Makie.pseudolog10, xlabel = "trajectory steps")
lines!(ax, 1:length(tdb), tdb, label = "full batch")
lines!(ax, 1:length(td), td, label = "batchsize = 100")
axislegend(ax, position = :rb)
f
This is not the case for Adam, which effectively uses different timescales for the different parameters.
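A minimal, self-contained sketch of a plain Adam step (not MLPGradientFlow code) illustrates this: each coordinate is rescaled by its own second-moment estimate, so the effective step size differs across parameters and the path is no longer a time-reparametrization of the gradient-flow ODE dθ/dt = -∇L(θ).
# toy gradient with components of very different magnitude
grad = [1.0, 0.01]
m, v = zero(grad), zero(grad)            # first- and second-moment estimates
η, β1, β2, ϵ = 1e-3, 0.9, 0.999, 1e-8
m .= β1 .* m .+ (1 - β1) .* grad
v .= β2 .* v .+ (1 - β2) .* grad .^ 2
m̂, v̂ = m ./ (1 - β1), v ./ (1 - β2)     # bias correction after one step
Δ = -η .* m̂ ./ (sqrt.(v̂) .+ ϵ)          # Adam update
Δ ./ (-grad)                             # effective step size per coordinate: ≈ 0.001 vs. ≈ 0.1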