The package markovDP (Hahsler 2024) implements episodic semi-gradient Sarsa with linear state-action value function approximation following Sutton and Barto (2018). The state-action value construction uses the approach described in Geramifard et al. (2013). First, we will use the state features directly, and then we will use a Fourier basis (see Konidaris, Osband, and Thomas (2011), “Value Function Approximation in Reinforcement Learning Using the Fourier Basis”).
The state-action value function is approximated by

$$\hat{q}(s, a, \mathbf{w}) = \mathbf{w}^\top \mathbf{x}(s, a),$$

where $\mathbf{w}$ is a weight vector and $\mathbf{x}(s, a)$ is a feature function that maps each state-action pair to a feature vector. The gradient of the state-action value function is

$$\nabla \hat{q}(s, a, \mathbf{w}) = \mathbf{x}(s, a).$$
For a small number of actions, we can follow the construction described by Geramifard et al. (2013), which uses a state feature function $\phi(s)$ to construct the complete state-action feature vector. Here, we also add an intercept term. The state-action feature vector has length $|A|\,(|\phi(s)| + 1)$: it contains, for each action, an intercept component followed by one component per state feature. All components are set to zero and only the components for the active action are set to $(1, \phi(s))$, where $s$ is the current state. For example, for the state feature vector $\phi(s) = (\phi_1, \phi_2)$ and action $a_1$ out of three possible actions $\{a_1, a_2, a_3\}$, the complete state-action feature vector is $\mathbf{x}(s, a_1) = (1, \phi_1, \phi_2, 0, 0, 0, 0, 0, 0)$. The leading 1 is for the intercept and the zeros represent the two not chosen actions.
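This construction can be illustrated with a few lines of plain R (a sketch that is independent of the package, using the example above):
# state features phi = (phi1, phi2) = (2, 5), 3 actions, action 1 is active
phi <- c(2, 5)
n_A <- 3
a <- 1
block <- c(1, phi)                                # intercept + state features
x <- numeric(n_A * length(block))                 # all-zero state-action vector
x[(a - 1) * length(block) + seq_along(block)] <- block
x
#> [1] 1 2 5 0 0 0 0 0 0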
This construction is implemented in add_linear_approx_Q_function().
The following helper functions for using approximation are available:

- approx_Q_value() calculates approximate Q values given the weights in the model or specified weights.
- approx_greedy_action() uses approximate Q values given the weights in the model or specified weights to find the greedy action for a state.
- approx_greedy_policy() calculates the greedy policy for the approximate Q values given the weights in the model or specified weights.

These helpers are demonstrated on the solved model below. The implementation follows the algorithm given in Sutton and Barto (2018).
The step cost is 1. The start is top-left and the goal (+100 reward) is bottom-right. This is the ideal problem for a linear approximation of the Q-function using the x/y location as state features.
We start by defining an MDP for a small maze without walls.
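A definition could look like the following sketch (the gw_maze_MDP() call and its default reward structure are assumptions here; the rest of this section only requires a 5x5 gridworld with the state labels shown below):
# Sketch (assumption): a 5x5 gridworld without walls, start in the top-left
# state s(1,1), goal with the +100 reward in the bottom-right state s(5,5).
m <- gw_maze_MDP(dim = c(5, 5), start = "s(1,1)", goal = "s(5,5)")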
We construct state features as the x/y coordinates in the gridworld.
state_features <- gw_s2rc(S(m))
state_features
#> [,1] [,2]
#> s(1,1) 1 1
#> s(2,1) 2 1
#> s(3,1) 3 1
#> s(4,1) 4 1
#> s(5,1) 5 1
#> s(1,2) 1 2
#> s(2,2) 2 2
#> s(3,2) 3 2
#> s(4,2) 4 2
#> s(5,2) 5 2
#> s(1,3) 1 3
#> s(2,3) 2 3
#> s(3,3) 3 3
#> s(4,3) 4 3
#> s(5,3) 5 3
#> s(1,4) 1 4
#> s(2,4) 2 4
#> s(3,4) 3 4
#> s(4,4) 4 4
#> s(5,4) 5 4
#> s(1,5) 1 5
#> s(2,5) 2 5
#> s(3,5) 3 5
#> s(4,5) 4 5
#> s(5,5) 5 5
We add the state features with the linear approximation function to the model. add_linear_approx_Q_function() constructs state-action features and adds them, together with an approximate Q function and a gradient function, to the model. The state-action features are constructed as a vector with weights for an intercept and the state features for each action.
Below, we see that the initial weight vector has names organized by action. x0 represents the intercept, and x1 and x2 represent the x and y coordinates in the maze, respectively.
m <- add_linear_approx_Q_function(m, state_features)
m$approx_Q_function
#> $x
#> function (s, a)
#> {
#> s <- transformation(s)
#> a <- normalize_action_id(a, model)
#> x <- numeric(n_A * dim_s)
#> a_pos <- 1L + (a - 1L) * dim_s
#> x[a_pos:(a_pos + dim_s - 1L)] <- s
#> x
#> }
#> <bytecode: 0x56235f99bdc8>
#> <environment: 0x56235f98cae8>
#>
#> $f
#> function (s, a, w)
#> sum(w * x(s, a))
#> <bytecode: 0x56235f98eea0>
#> <environment: 0x56235f98cae8>
#>
#> $gradient
#> function (s, a, w)
#> x(s, a)
#> <bytecode: 0x56235f98b498>
#> <environment: 0x56235f98cae8>
#>
#> $transformation
#> function (x)
#> {
#> x <- (x - min)/(max - min)
#> if (intercept)
#> x <- c(x0 = 1, x)
#> x
#> }
#> <bytecode: 0x56235f984308>
#> <environment: 0x56235f984b58>
#>
#> $w_init
#> up.x0 up.x1 up.x2 right.x0 right.x1 right.x2 down.x0 down.x1
#> 0 0 0 0 0 0 0 0
#> down.x2 left.x0 left.x1 left.x2
#> 0 0 0 0
Now, we can solve the model. Both alpha and epsilon follow an exponential schedule. Epsilon, used for the epsilon-greedy behavior policy, starts at 1 and decays by a factor of 0.1 per episode. The step size alpha starts at 0.2 and also decays by 0.1 per episode.
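The following sketch shows the shape of these schedules, assuming schedule_exp(start, rate) corresponds to start * exp(-rate * (episode - 1)) (the exact formula used by the package is an assumption here):
# assumed exponential decay: value_n = start * exp(-rate * (n - 1))
episode <- 1:10
round(1 * exp(-0.1 * (episode - 1)), 3)    # epsilon per episode
round(0.2 * exp(-0.1 * (episode - 1)), 3)  # alpha per episode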
set.seed(1000)
sol <- solve_MDP_APPROX(m, horizon = 1000, n = 100,
alpha = schedule_exp(0.2, 0.1),
epsilon = schedule_exp(1, 0.1))
gw_plot(sol)
gw_matrix(sol, what = "value")
#> [,1] [,2] [,3] [,4] [,5]
#> [1,] 35.55476 45.87542 56.19608 66.51674 76.83740
#> [2,] 42.17727 51.80411 62.12477 72.44543 82.76609
#> [3,] 55.72449 59.97481 68.05346 78.37411 88.69477
#> [4,] 69.27171 73.52203 77.77235 84.30280 94.62346
#> [5,] 82.81893 87.06925 91.31957 95.56989 100.55215
Here are the learned weights.
sol$solution$w
#> up.x0 up.x1 up.x2 right.x0 right.x1 right.x2 down.x0 down.x1
#> 14.65528 12.30324 11.06122 28.63005 54.18888 17.00129 35.55476 23.71475
#> down.x2 left.x0 left.x1 left.x2
#> 41.28264 13.77898 12.14905 9.10107
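The helper functions listed earlier can be used to query the learned approximation directly. The following is a sketch; the exact argument forms are assumptions based on the descriptions of these helpers:
# greedy action for the start state under the learned weights (sketch)
approx_greedy_action(sol, "s(1,1)")
# approximate Q values for the start state (sketch)
approx_Q_value(sol, "s(1,1)")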
The approximate value function is continuous and can also be displayed using matrix shading and contours.
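One way to do this is with base R graphics applied to the extracted value matrix (a sketch):
V <- gw_matrix(sol, what = "value")
# shade the value surface and overlay contour lines
image(seq_len(nrow(V)), seq_len(ncol(V)), V, xlab = "row", ylab = "col")
contour(seq_len(nrow(V)), seq_len(ncol(V)), V, add = TRUE)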
The wall and the -1 absorbing state make linear approximation using just the position more difficult.
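Below we use the Maze gridworld. The assumption here is that this is the 3x4 maze data set that ships with the package and can be loaded with data():
# load the 3x4 maze with a center wall and a -1 absorbing state
# (assumed to be available as a data set in markovDP)
data(Maze)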
Adding the linear approximation translates state names of the format s(feature list) automatically.
Maze_approx <- add_linear_approx_Q_function(Maze)
sol <- solve_MDP_APPROX(Maze_approx, horizon = 100, n = 100,
alpha = schedule_exp(0.2, 0.01),
epsilon = schedule_exp(1, 0.1))
gw_plot(sol)
gw_matrix(sol, what = "value")
#> [,1] [,2] [,3] [,4]
#> [1,] 0.6526606 0.7443948 0.8361289 0.92786308
#> [2,] 0.5362074 NA 0.2450085 0.29572528
#> [3,] 0.5155252 0.3699258 0.2243263 0.07872689
The linear approximation cannot deal with the center wall and the -1 absorbing state.
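Next, we try a polynomial basis for the state features. The sketch below illustrates the general idea of an order-1 polynomial basis for a 2-dimensional state: all products of the normalized state features with exponents 0 or 1 (the exact feature set and ordering produced by transformation_polynomial_basis is an assumption here).
# order-1 polynomial basis for a normalized 2D state s = (s1, s2):
# features s1^c1 * s2^c2 for all exponents c1, c2 in {0, 1}
poly_basis_order1 <- function(s)
  c(1, s[1], s[2], s[1] * s[2])
poly_basis_order1(c(0.25, 0.5))
#> [1] 1.000 0.250 0.500 0.125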
Maze_approx <- add_linear_approx_Q_function(Maze, transformation = transformation_polynomial_basis, order = 1)
set.seed(2000)
sol <- solve_MDP_APPROX(Maze_approx, horizon = 100, n = 100,
alpha = schedule_exp(0.2, 0.01),
epsilon = schedule_exp(1, 0.1))
gw_matrix(sol, what = "value")
#> [,1] [,2] [,3] [,4]
#> [1,] 0.7362799 0.8628479 0.9894160 1.1159840
#> [2,] 0.6413143 NA 0.5437256 0.4949312
#> [3,] 0.6301629 0.5506522 0.4711415 0.3916308
set.seed(2000)
sol <- solve_MDP_APPROX(Maze_approx, horizon = 100, n = 100,
alpha = schedule_exp(0.2, 0.01),
epsilon = schedule_exp(1, 0.1))
gw_matrix(sol, what = "value")
#> [,1] [,2] [,3] [,4]
#> [1,] 0.8277573 0.9706198 0.9379159 0.7476378
#> [2,] 0.8072311 NA 0.8633556 0.6662076
#> [3,] 0.6912557 0.7816229 0.7252808 0.5532505
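Next, we use a Fourier basis. Following the Fourier basis construction referenced in the introduction, the order-$n$ basis maps the normalized state $\mathbf{s} \in [0, 1]^d$ to features $\cos(\pi\, \mathbf{c}^\top \mathbf{s})$ for all integer vectors $\mathbf{c} \in \{0, \dots, n\}^d$. The sketch below shows the order-1 features for a 2-dimensional state (how transformation_fourier_basis orders these features internally is an assumption):
# order-1 Fourier basis for a normalized 2D state s = (s1, s2):
# cos(pi * (c1 * s1 + c2 * s2)) for all c1, c2 in {0, 1}
fourier_basis_order1 <- function(s) {
  C <- expand.grid(c1 = 0:1, c2 = 0:1)
  cos(pi * (C$c1 * s[1] + C$c2 * s[2]))
}
round(fourier_basis_order1(c(0.25, 0.5)), 3)
#> [1]  1.000  0.707  0.000 -0.707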
Maze_approx <- add_linear_approx_Q_function(Maze, transformation = transformation_fourier_basis, order = 1)
set.seed(2000)
sol <- solve_MDP_APPROX(Maze_approx, horizon = 100, n = 500,
alpha = schedule_exp(0.2, 0.01),
epsilon = schedule_exp(1, 0.1))
gw_matrix(sol, what = "value")
#> [,1] [,2] [,3] [,4]
#> [1,] 0.7991024 0.8382023 0.9164022 0.9555022
#> [2,] 0.7336637 NA 0.5667932 0.1794023
#> [3,] 0.6698697 0.5077707 0.1998846 0.0459415
Maze <- gw_read_maze(
textConnection(c("XXXXXXXXXXXX",
"X X",
"X S X",
"X X",
"XXXXXXXXX X",
"X X",
"X G X",
"X X",
"XXXXXXXXXXXX"
)))
We use a Fourier basis of order 2.
Maze_approx <- add_linear_approx_Q_function(Maze, transformation = transformation_fourier_basis, order = 2)
set.seed(2000)
sol <- solve_MDP_APPROX(Maze_approx, horizon = 100, n = 100,
alpha = schedule_exp(0.2, 0.01),
epsilon = schedule_exp(1, 0.01))
gw_matrix(sol, what = "value")
#> [,1] [,2] [,3] [,4] [,5] [,6] [,7]
#> [1,] NA NA NA NA NA NA NA
#> [2,] NA -34.67819 -36.56455 -40.93651 -44.58366 -43.81634 -32.723621
#> [3,] NA -39.11605 -40.55373 -44.80256 -35.15548 -27.92717 -23.070036
#> [4,] NA -40.05288 -26.51104 -17.23813 -12.38626 -10.04010 -7.617928
#> [5,] NA NA NA NA NA NA NA
#> [6,] NA 35.07558 39.35036 39.95272 39.66480 45.03300 54.405511
#> [7,] NA 73.08151 75.25035 67.83979 70.78166 70.53766 70.061263
#> [8,] NA 97.46271 91.42248 76.07045 75.61095 71.66126 63.360543
#> [9,] NA NA NA NA NA NA NA
#> [,8] [,9] [,10] [,11] [,12]
#> [1,] NA NA NA NA NA
#> [2,] -24.526856 -11.347759 -0.2685837 3.947177 NA
#> [3,] -19.230784 -7.257869 6.0132145 11.058506 NA
#> [4,] -3.523448 14.378799 24.7563633 22.869067 NA
#> [5,] NA NA 38.0529495 20.978329 NA
#> [6,] 57.058410 51.366899 44.7776043 27.681048 NA
#> [7,] 68.406845 59.674567 44.8505262 27.681288 NA
#> [8,] 51.709973 39.452029 30.1139787 26.623795 NA
#> [9,] NA NA NA NA NA