Fitting a Q-Learner
We define a Q-Learner that explores an environment with 10 states and 3 actions using a softmax policy.
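For orientation (this summary is not part of the original code, but it matches the computations in the logp method defined below): at each step the model evaluates the log-probability of the chosen action under a softmax policy and then applies the tabular Q-learning update,

$$\log \pi(a \mid s) = \beta\, q(s,a) - \log \sum_{a'} \exp\bigl(\beta\, q(s,a')\bigr), \qquad q(s,a) \leftarrow q(s,a) + \eta \Bigl(r + \gamma \max_{a'} q(s',a') - q(s,a)\Bigr),$$

where β is the inverse temperature, η the learning rate and γ the discount factor.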
using LaplacianExpectationMaximization
using ConcreteStructs
import LogExpFunctions: softplus, logistic
@concrete struct QLearner
    q    # table of Q-values with one row per state and one column per action
end
QLearner() = QLearner(zeros(10, 3))
function LaplacianExpectationMaximization.initialize!(m::QLearner, parameters)
    m.q .= parameters.q₀     # reset the Q-values to their initial values q₀
end
LaplacianExpectationMaximization.parameters(::QLearner) = (; q₀ = zeros(10, 3), η = 0., β_real = 1., γ_logit = 10.)
function LaplacianExpectationMaximization.logp(data, m::QLearner, parameters)
    initialize!(m, parameters)
    (; η, β_real, γ_logit) = parameters
    β = softplus(β_real)     # inverse temperature, constrained to be positive
    γ = logistic(γ_logit)    # discount factor, constrained to (0, 1)
    q = m.q
    logp = 0.
    for (; s, a, s′, r, done) in data
        logp += logsoftmax(β, q, s, a)      # log-probability of the chosen action
        td_error = r + γ * findmax(view(q, s′, :))[1] - q[s, a]
        q[s, a] += η * td_error             # Q-learning update with learning rate η
    end
    logp
end
function LaplacianExpectationMaximization.sample(rng, data, m::QLearner, parameters; environment)
    (; η, β_real) = parameters
    q = m.q
    (; s′, done) = data[end]      # continue from the last visited state
    if done                       # start a new episode if the last one ended
        s′ = initial_state(rng, environment)
    end
    a′ = randsoftmax(rng, softplus(β_real), q, s′)   # sample an action from the softmax policy
    s′′ = transition(rng, environment, s′, a′)
    r′ = reward(rng, environment, s′, a′, s′′)
    (s = s′, a = a′, s′ = s′′, r = r′, done = isdone(rng, environment, s′′))
end

Let us define the helper functions logsoftmax and randsoftmax, as well as the environment functions initial_state, transition, reward and isdone.
using Distributions, LinearAlgebra
logsoftmax(β, q, s, a) = β * q[s, a] - logsumexp(β, view(q, s, :))
function logsumexp(β, v)
    m = β * maximum(v)            # subtract the maximum for numerical stability
    sumexp = zero(eltype(v))
    for vᵢ in v
        sumexp += exp(β * vᵢ - m)
    end
    m + log(sumexp)
end
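A quick hedged sanity check (not part of the original tutorial) that this numerically stable formulation agrees with the naive computation for moderate values:

v = randn(3)
logsumexp(2.0, v) ≈ log(sum(exp.(2.0 .* v)))   # should hold up to floating-point error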
function randsoftmax(rng, β, q, s)
    m = maximum(view(q, s, :))
    p = exp.(β .* (q[s, :] .- m))   # softmax probabilities with inverse temperature β
    p ./= sum(p)
    rand(rng, Categorical(p))
end
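Similarly, one can check (again a hedged aside; the names rng_check, q_check and counts are only used here) that the empirical action frequencies produced by randsoftmax match the probabilities implied by logsoftmax:

using Random: Xoshiro                    # Random is also loaded further below
rng_check = Xoshiro(1)
q_check = randn(rng_check, 10, 3)
counts = zeros(3)
for _ in 1:100_000
    counts[randsoftmax(rng_check, 2.0, q_check, 1)] += 1
end
counts ./ sum(counts)                                   # empirical action frequencies in state 1
exp.([logsoftmax(2.0, q_check, 1, a) for a in 1:3])     # probabilities implied by logsoftmax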
@concrete struct Environment
    t    # t[s, a] is the transition distribution over next states
    r    # r[s, a] is the reward for taking action a in state s
end
function Environment(; rng = Random.default_rng())
    Environment([normalize(rand(rng, 10), 1) for s in 1:10, a in 1:3],
                randn(rng, 10, 3))
end
initial_state(rng, ::Environment) = rand(rng, 1:10)
transition(rng, e::Environment, s, a) = rand(rng, Categorical(e.t[s, a]))
reward(rng, e::Environment, s, a, s′) = e.r[s, a]
isdone(rng, e::Environment, s) = s == 10

With this we can generate some artificial data and fit it.
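Before doing so, a small hedged sanity check of the environment definition (not part of the original tutorial; the name env_check is only used here):

using Random                                    # also loaded below
env_check = Environment()
size(env_check.r)                               # (10, 3): one reward per state-action pair
all(probs -> sum(probs) ≈ 1, env_check.t)       # every transition distribution sums to one
isdone(Random.default_rng(), env_check, 10)     # state 10 is the terminal state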
using ComponentArrays, Random
model = QLearner()
p = ComponentArray(parameters(model))
p.η = .15
p.β_real = 1.6
p.γ_logit = 1.8
rng = Xoshiro(17)
tmp = [simulate(model, p;
                n_steps = 200, rng,
                init = [(s = 1, a = 1, s′ = 1, r = 0., done = false)],
                environment = Environment(; rng))
       for _ in 1:100]
data = first.(tmp)
data_logp = sum(last.(tmp))
-21206.661361122406

This is the log-probability with which the data was generated. Let us check the log-probability of the data under the default parameters.
population_model = PopulationModel(model, shared = (:q₀, :η, :β_real, :γ_logit))
p0 = ComponentArray(parameters(population_model))
mc_marginal_logp(data, population_model, p0)
1-element Vector{Float64}:
 -22082.10700222894

We see that the log-probability of the data under the default parameters is lower than under the parameters with which the data was generated. Now we maximize the log-likelihood to find the optimal parameters for this data.
result = maximize_logp(data, population_model)
result.logp
-20473.929115426105

The resulting log-probability is indeed the highest of the three. The fitted parameters are, however, not particularly close to p:
result.population_parameters
ComponentVector{Float64}(q₀ = [1.2995145048085743 0.9846083140546489 1.0258508448414068; 1.0225508041218634 1.0606936792848884 1.0361697126420548; … ; 1.14261369870024 1.2753002345956057 1.1095232835440607; 1.0402172471341526 1.1870825594265575 1.2897445609067633], η = 0.14377293880177086, β_real = 0.5318712596281107, γ_logit = 2.4971664587164897, population_parameters = (μ = Float64[], σ = Float64[]))

We can try fixing q₀:
result2 = maximize_logp(data, population_model, fixed = (; q₀ = zeros(10, 3)))
result2.logp
-20494.891628022964

The log-probability of the data is a bit lower, as expected.
result2.population_parameters
ComponentVector{Float64}(η = 0.13525533745553608, β_real = 0.5730574455049349, γ_logit = 2.1517662240627065, q₀ = [0.0 0.0 0.0; 0.0 0.0 0.0; … ; 0.0 0.0 0.0; 0.0 0.0 0.0], population_parameters = (μ = Float64[], σ = Float64[]))

The fitted parameters, however, are closer to p.
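To compare the fitted values with the generating parameters p on their natural scales, one can apply the same transformations that logp uses internally; a small hedged sketch (the name fitted is only used here):

fitted = result2.population_parameters
(; η = fitted.η, β = softplus(fitted.β_real), γ = logistic(fitted.γ_logit))   # fitted values on the natural scale
(; η = p.η, β = softplus(p.β_real), γ = logistic(p.γ_logit))                  # generating values on the natural scale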