import tensorflow as tf
import numpy as np
import gym
def mlp(x, hidden_sizes=(32, 32), activation=tf.tanh):
    # Stack dense hidden layers; the caller adds its own output layer.
    for size in hidden_sizes:
        x = tf.layers.dense(x, units=size, activation=activation)
    return x
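# Illustrative shape check (an addition, not in the original notebook): mlp
# returns only the last hidden layer; the output head is added by the caller,
# as Agent does below with a final linear layer of size n_acts.
_x_ph = tf.placeholder(tf.float32, shape=(None, 4))
print(mlp(_x_ph, hidden_sizes=(32, 32)).shape)  # (?, 32)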
def discount_cumsum(x, gamma):
    # Discounted reward-to-go: z[j] = sum_{k >= j} gamma**(k - j) * x[k].
    n = len(x)
    x = np.array(x)
    y = gamma ** np.arange(n)
    z = np.zeros_like(x, dtype=np.float32)
    for j in range(n):
        z[j] = np.sum(x[j:] * y[:n - j])
    return z
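# The quadratic loop above is fine for CartPole-length episodes; an equivalent
# O(n) backward recursion (a sketch, not part of the original notebook) is:
def discount_cumsum_fast(x, gamma):
    z = np.zeros(len(x), dtype=np.float32)
    running = 0.0
    for j in reversed(range(len(x))):
        running = x[j] + gamma * running  # z[j] = x[j] + gamma * z[j+1]
        z[j] = running
    return z

assert np.allclose(discount_cumsum([1.0, 2.0, 3.0], 0.5),
                   discount_cumsum_fast([1.0, 2.0, 3.0], 0.5))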
class Agent:
    def __init__(self,
                 env_name='CartPole-v0',
                 hidden_dim=32,
                 n_layers=1,
                 lr=1e-2,
                 beta=1):
        # create environment
        self.env = gym.make(env_name)
        obs_dim = self.env.observation_space.shape[0]
        n_acts = self.env.action_space.n

        # make model: MLP policy that maps observations to action logits
        with tf.variable_scope('model'):
            self.obs_ph = tf.placeholder(shape=(None, obs_dim),
                                         dtype=tf.float32)
            self.net = mlp(self.obs_ph,
                           hidden_sizes=[hidden_dim] * n_layers,
                           activation=tf.nn.relu)
            self.logits = tf.layers.dense(self.net,
                                          units=n_acts,
                                          activation=None)
            # sample an action from the categorical distribution over logits
            self.actions = tf.squeeze(tf.multinomial(logits=self.logits,
                                                     num_samples=1),
                                      axis=1)

        # make loss: the REINFORCE surrogate -E[adv * log pi(a|s)];
        # beta is an inverse-temperature applied to the logits in the loss
        self.adv_ph = tf.placeholder(shape=(None,), dtype=tf.float32)
        self.act_ph = tf.placeholder(shape=(None,), dtype=tf.int32)
        self.action_one_hots = tf.one_hot(self.act_ph, n_acts)
        self.log_probs = tf.reduce_sum(
            self.action_one_hots * tf.nn.log_softmax(beta * self.logits),
            axis=1)
        self.loss = -tf.reduce_mean(self.adv_ph * self.log_probs)

        # make train op
        self.train_op = tf.train.AdamOptimizer(learning_rate=lr).minimize(self.loss)
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
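    # Hypothetical helper (an addition, not part of the original notebook):
    # inspect the policy's softmax distribution for a single observation.
    # Note that tf.nn.softmax adds a new op to the graph on every call,
    # which is acceptable for occasional debugging only.
    def action_probs(self, obs):
        probs = tf.nn.softmax(self.logits)
        return self.sess.run(probs, {self.obs_ph: obs.reshape(1, -1)})[0]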
    def train_one_iteration(self, gamma=0.99, batch_size=5000, render=False):
        batch_obs, batch_acts, batch_rtgs, batch_rets, batch_lens = [], [], [], [], []
        obs, rew, done, ep_rews = self.env.reset(), 0, False, []
        while True:
            if render:
                self.env.render()
            batch_obs.append(obs.copy())
            # sample an action from the current policy and step the env
            act = self.sess.run(self.actions, {self.obs_ph: obs.reshape(1, -1)})[0]
            obs, rew, done, _ = self.env.step(act)
            batch_acts.append(act)
            ep_rews.append(rew)
            if done:
                batch_rets.append(sum(ep_rews))
                batch_lens.append(len(ep_rews))
                # use discounted reward-to-go as the advantage signal
                batch_rtgs += list(discount_cumsum(ep_rews, gamma))
                obs, rew, done, ep_rews = self.env.reset(), 0, False, []
                # only stop at an episode boundary so rtgs stay aligned with obs
                if len(batch_obs) > batch_size:
                    break
        # normalize advantages (variance-reduction trick)
        batch_advs = np.array(batch_rtgs)
        batch_advs = (batch_advs - np.mean(batch_advs)) / (np.std(batch_advs) + 1e-8)
        batch_loss, _ = self.sess.run(
            [self.loss, self.train_op],
            feed_dict={self.obs_ph: np.array(batch_obs),
                       self.act_ph: np.array(batch_acts),
                       self.adv_ph: batch_advs})
        return batch_loss, batch_rets, batch_lens
    def train(self, gamma=0.99, n_iters=50, batch_size=5000):
        for i in range(n_iters):
            batch_loss, batch_rets, batch_lens = self.train_one_iteration(
                gamma=gamma, batch_size=batch_size)
            print('itr: %d \t loss: %.3f \t return: %.3f \t ep_len: %.3f' %
                  (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))
agent = Agent(env_name='CartPole-v1', lr=1e-2)
agent.train(gamma=1, n_iters=50, batch_size=5000)
WARNING: Logging before flag parsing goes to stderr.
W0123 00:29:07.697847 139945602238272 deprecation.py:323] From <ipython-input-3-af4f9e7bfb85>:3: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating: Use keras.layers.dense instead.
W0123 00:29:07.700635 139945602238272 deprecation.py:506] From /usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor
W0123 00:29:07.993772 139945602238272 deprecation.py:323] From <ipython-input-5-0277a5173c3c>:25: multinomial (from tensorflow.python.ops.random_ops) is deprecated and will be removed in a future version.
Instructions for updating: Use `tf.random.categorical` instead.
itr: 0   loss: -0.007   return: 25.851   ep_len: 25.851
itr: 1   loss: -0.009   return: 28.096   ep_len: 28.096
itr: 2   loss: -0.012   return: 31.344   ep_len: 31.344
itr: 3   loss: -0.014   return: 34.122   ep_len: 34.122
itr: 4   loss: -0.011   return: 38.397   ep_len: 38.397
itr: 5   loss: -0.015   return: 43.921   ep_len: 43.921
itr: 6   loss: -0.016   return: 42.585   ep_len: 42.585
itr: 7   loss: -0.010   return: 50.100   ep_len: 50.100
itr: 8   loss: -0.009   return: 51.755   ep_len: 51.755
itr: 9   loss: -0.004   return: 52.021   ep_len: 52.021
itr: 10   loss: -0.008   return: 57.602   ep_len: 57.602
itr: 11   loss: -0.007   return: 60.928   ep_len: 60.928
itr: 12   loss: -0.006   return: 56.955   ep_len: 56.955
itr: 13   loss: -0.009   return: 61.975   ep_len: 61.975
itr: 14   loss: -0.004   return: 58.802   ep_len: 58.802
itr: 15   loss: -0.004   return: 64.500   ep_len: 64.500
itr: 16   loss: -0.006   return: 59.353   ep_len: 59.353
itr: 17   loss: -0.005   return: 64.551   ep_len: 64.551
itr: 18   loss: -0.007   return: 65.078   ep_len: 65.078
itr: 19   loss: -0.005   return: 69.493   ep_len: 69.493
itr: 20   loss: -0.009   return: 78.828   ep_len: 78.828
itr: 21   loss: -0.006   return: 74.721   ep_len: 74.721
itr: 22   loss: -0.007   return: 77.769   ep_len: 77.769
itr: 23   loss: -0.010   return: 87.414   ep_len: 87.414
itr: 24   loss: -0.007   return: 95.057   ep_len: 95.057
itr: 25   loss: -0.013   return: 98.412   ep_len: 98.412
itr: 26   loss: -0.010   return: 110.043   ep_len: 110.043
itr: 27   loss: -0.011   return: 127.200   ep_len: 127.200
itr: 28   loss: -0.009   return: 125.900   ep_len: 125.900
itr: 29   loss: -0.018   return: 152.364   ep_len: 152.364
itr: 30   loss: -0.021   return: 163.484   ep_len: 163.484
itr: 31   loss: -0.016   return: 169.067   ep_len: 169.067
itr: 32   loss: -0.013   return: 185.519   ep_len: 185.519
itr: 33   loss: -0.010   return: 186.370   ep_len: 186.370
itr: 34   loss: -0.016   return: 209.200   ep_len: 209.200
itr: 35   loss: -0.014   return: 234.864   ep_len: 234.864
itr: 36   loss: -0.019   return: 200.320   ep_len: 200.320
itr: 37   loss: -0.014   return: 298.471   ep_len: 298.471
itr: 38   loss: -0.009   return: 306.176   ep_len: 306.176
itr: 39   loss: -0.011   return: 383.357   ep_len: 383.357
itr: 40   loss: -0.013   return: 418.154   ep_len: 418.154
itr: 41   loss: -0.009   return: 447.750   ep_len: 447.750
itr: 42   loss: -0.011   return: 419.308   ep_len: 419.308
itr: 43   loss: -0.008   return: 417.083   ep_len: 417.083
itr: 44   loss: -0.003   return: 489.091   ep_len: 489.091
itr: 45   loss: -0.008   return: 432.750   ep_len: 432.750
itr: 46   loss: -0.008   return: 476.636   ep_len: 476.636
itr: 47   loss: -0.005   return: 497.636   ep_len: 497.636
itr: 48   loss: -0.010   return: 435.417   ep_len: 435.417
itr: 49   loss: -0.011   return: 447.000   ep_len: 447.000
agent.train_one_iteration(render=True)
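# A minimal evaluation sketch (an addition, not part of the original
# notebook): roll out one episode with the trained policy and return its
# score. In CartPole the reward is +1 per surviving step, so return equals
# episode length, which is why the two columns in the log above are
# identical; CartPole-v1 caps episodes at 500 steps.
def run_episode(agent, render=False):
    obs, done, total = agent.env.reset(), False, 0.0
    while not done:
        if render:
            agent.env.render()
        act = agent.sess.run(agent.actions, {agent.obs_ph: obs.reshape(1, -1)})[0]
        obs, rew, done, _ = agent.env.step(act)
        total += rew
    return total

print(run_episode(agent))
agent.env.close()  # release the render window opened by the call above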