카테고리 없음
[강화학습] A3C 코드 공부
yennle
2023. 7. 25. 21:19
728x90
A3C의 실행하는 과정의 전반적인 흐름을 코드로 공부해보자.
① 환경 초기화
# 상태변수, 행동, 보상을 저장할 배치는 초기화한다.
batch_state, batch_action, batch_reward = [], [], []
# 환경을 초기화하고 환경으로부터 첫번째 상태변수 x0를 측정한다.
step, episode_reward, done = 0, 0, False
state = self.env_reset()
②-1. 행동 선택
# 워커의 액터 신경망을 이용해 행동을 샘플링한다.
action = self.get_action(tf.convert_to_tensor([state], dtype=tf.float32))
# 행동이 범위 [-2, 2]를 벗어나지 않도록 제한
action = np.clip(action, -self.action_bound, self.action_bound)
def get_action(self, state):
mu_a, std_a = self.worker_actor(state)
mu_a = mu_a.numpy()[0]
std_a = std_a.numpy()[0]
std_a = np.clip(std_a, self.std_bound[0], self.std_bound[1])
action = np.random.normal(mu_a, std_a, size=self.action_dim)
return action
②-2. 다음 상태변수 측정
# 행동 u0를 실행해 보상 r(x0, u0)와 다음 상태변수 x1을 얻는다.
# 여기서 done=1이면 에피소드 종료
next_state, reward, done, _ = self.env.step(action)
②-3. 샘플 저장
# Gym 환경과 학습환경에서 사용하는 변수의 배열 모양이 다름을 고려하여 상태변수, 행동, 보상, 다음 상태변수 등의 배열 모양을 바꿔준다.
state = np.reshape(state, [1, self.state_dim])
action = np.reshape(action, [1, self.action_dim])
reward = np.reshape(reward, [1, 1])
# 학습용으로 사용할 보상의 범위를 식 rt = (r+8)/8을 이용해 [-16, 0]에서 [-1, 1]로 조정
train_reward = (reward + 8) / 8
# 상태변수, 행동, 보상을 배치에 저장
batch_state.append(state)
batch_action.append(action)
batch_reward.append(train_reward)
③ 2번 과정 t_MAX 시간동안 반복 후, 학습 준비
# 다시 상태변수 xi를 이용해 행동 ui를 계산하는 과정을 되풀이한다.
state = next_state
episode_reward += reward[0]
step += 1
# 배치가 t_max개만큼 쌓이거나 에피소드가 종료되면 학습을 시작한다.
if len(batch_state) == self.t_MAX or done:
pass
# 학습이 시작되면 배치에서 각각t_MAX개의 상태변수, 행동, 보상을 추출한다. 그리고 배치를 비운다.
states = self.unpack_batch(batch_state)
actions = self.unpack_batch(batch_action)
rewards = self.unpack_batch(batch_reward)
# 배치비움
batch_state, batch_action, batch_reward = [], [], []
# 넘파이 어레이를 요소로 하는 파이썬 리스트로 구성된 배치 데이터를 array
def unpack_batch(self, batch):
unpack = batch[0]
for idx in range(len(batch)-1):
unpack = np.append(unpack, batch[idx+1], axis=0)
return unpack
④ 학습
n-스텝 시간차 타깃, 어드밴티지 계산 & 글로벌 신경망 업데이트
# 워커의 크리틱 신경망을 이용해 n-스텝 시간차 타깃과 어드밴티지를 계산
next_state = np.reshape(next_state, [1, self.state_dim])
next_v_value = self.worker_critic(tf.convert_to_tensor(next_state, dtype=tf.float32))
n_step_td_targets = self.n_step_td_target(rewards, next_v_value.numpy(), done)
v_values = self.worker_critic(tf.convert_to_tensor(states, dtype=tf.float32))
advantages = n_step_td_targets - v_values
# 워커의 그래디언트를 계산해 글로벌 신경망을 업데이트 한다.
self.critic_learn(states, n_step_td_targets)
self.actor_learn(states, actions, advantages)
# n_step_td_target은 워커의 시간차 타깃
def n_step_td_target(self, rewards, next_v_value, done):
y_i = np.zeros(rewards.shape)
cumulative = 0
if not done:
cumulative = next_v_value
for k in reversed(range(0, len(rewards))):
cumulative = self.GAMMA * cumulative + rewards[k]
y_i[k] = cumulative
return y_i
# 크리틱 신경망 학습
def critic_learn(self, states, n_step_td_targets):
with tf.GradientTape() as tape:
# 워커의 손실함수 계산
td_hat = self.worker_critic(states, training=True)
loss = tf.reduce_mean(tf.square(n_step_td_targets-td_hat))
# 워커의 그래디언트 계산
grads = tape.gradient(loss, self.worker_critic.trainable_variables)
# 그래디언트 클리핑
grads, _ = tf.clip_by_global_norm(grads, 20)
# 워커의 그래디언트를 이용해 글로벌 신경망 업데이트
self.actor_opt.apply_gradients(zip(grads, self.global_critic.trainable_variables))
def actor_learn(self, states, actions, advantages):
with tf.GradientTape() as tape:
# 정책 확률 밀도 함수
mu_a, std_a = self.worker_actor(states, training=True)
log_policy_pdf = self.log_pdf(mu_a, std_a, actions)
# 워커의 손실함수 계산
loss_policy = log_policy_pdf * advantages
loss = tf.reduce_sum(-loss_policy)
# 워커의 그래디언트 계산
grads = tape.gradient(loss, self.worker_actor.trainable_variables)
# 그래디언트 클리핑
grads, _ = tf.clip_by_global_norm(grads, 20)
# 워커의 그래디언트를 이용해 글로벌 신경망 업데이트
self.actor_opt.apply_gradients(zip(grads, self.global_action.trainable_variables))
def log_pdf(self, mu, std, action):
std = tf.clip_by_value(std, self.std_bound[0], self.std_bound[1])
var = std ** 2
log_policy_pdf = -0.5 * (action - mu) ** 2 / var -0.5 * tf.math.of(var*2*np.pi)
return tf.reduce_sum(log_policy_pdf, 1, keppdims=True)
시간차 타깃과 어드밴티지 식은 아래 링크 참고
⑤ 글로벌 신경망을 워커 신경망으로 복사
# 글로벌 신경망 파라미터를 워커 신경망으로 복사한다.
self.worker_actor.set_weights(self.global_actor.get_weights())
self.worker_critic.set_weights(self.global_critic.get_weights())
[참고]
박성수, 「수학으로 풀어보는 강화학습 원리와 알고리즘」, 위키북스(2020)
https://github.com/pasus/Reinforcement-Learning-Book/tree/master/Chap5/A3CData
728x90