본문 바로가기

강화 학습

강화 학습 - DQN

https://wikibook.co.kr/reinforcement-learning/

 

파이썬과 케라스로 배우는 강화학습: 내 손으로 직접 구현하는 게임 인공지능

“강화학습을 쉽게 이해하고 코드로 구현하기” 강화학습의 기초부터 최근 알고리즘까지 친절하게 설명한다! ‘알파고’로부터 받은 신선한 충격으로 많은 사람들이 강화학습에 관심을 가지

wikibook.co.kr

해당 포스팅은 위의 책을 보고 정리한 내용이 주를 이룹니다.

또한, 이 포스팅에서 Neural Network(Deep Learning)에 대해서는 따로 설명하지 않겠습니다.


이번에 소개할 DQN(Deep Q-Network)은 딥마인드(구글)이 "Playing Atari with Deep Reinforcement Learning"라는 제목으로 2013년에 NIPS에 발표한 논문입니다. 개인적으로 강화학습의 시대를 알리는 신호탄 정도라고 생각합니다. 알파고는 대중들에게 알리는 신호탄정도 아닐까요? 

 

그런거 치고는 DQN 논문이 엄청 혁신적인 방법론이 제시된 것은 아닙니다.

앞서 포스팅한 TD, Q-Learning, Deep SARSA 정도 이해하셨으면 쉽게 따라올 수 있습니다. DQN은 쉽게보면 Q-Learning방식을 Neural Network를 사용한 것이죠. Deep SARSA에서 SARSA가 Q-Learning으로 바뀐 것입니다.

2022.01.09 - [강화 학습] - 강화 학습 기본 - 시간차 학습(Temporal-Difference Learning) part 1. 살사(SARSA)

2022.01.10 - [강화 학습] - 강화 학습 기본 - 시간차 학습(Temporal-Difference Learning) part 2. 큐 학습(Q-Leaning)

2022.01.10 - [강화 학습] - 강화 학습 - 딥 살사(Deep SARSA) 알고리즘

위 포스팅에 이미 언급한 부분은 생략되거나 짦게 언급할 예정입니다.

 

 

하지만 당연히 그게 다는 아니고, 학습에서 DQN이 좋은 퍼포먼스를 낼 수 있는 두 포인트가 있습니다.

  • 리플레이 메모리(Replay Memory)를 이용한 경험 리플레이(Experience Replay)
  • 학습이 되는 네트워크와 정답의 기준이되는 타겟 네트워크(Target Network)

따라서, 이번 포스팅은

-전체적인 개괄

-DQN의 중요한 두 포인트

-그리고 예시

순으로 진행됩니다.


 

  • Q-Learning with Neural Network

먼저 Q-Learning에 대해 짧게 다시보겠습니다. 우선 [S_t, A_t, R_t+1, S_t+1, A_t+1]를 고려하는 SARSA(On-Policy)와 달리, Q-Learning(Off-Policy)은 [S_t, A_t, R_t+1, S_t+1]를 하나의 샘플로 사용합니다. 이때 아래 업데이트 Q-함수 수식에서 a'은 S_t+1에서 가장 높은(max) 행동가치함수 값을 가지는 행동을 뜻하죠. 따라서, SARSA는 안 좋은 상황에 빠져버리면 그 상황에 빠져버리는데 비해 Q-Learning은 a'를 보고 학습을 하기 때문에 그 문제에서 자유로워 질 수 있다는 장점이 있죠.

Q-Leaning 방식의 Q함수 업데이트 수식

그리고 Q-함수를 계산하는 방식은 Deep SARSA와 동일합니다. (*아래 그림은 편의상 그림이고 책에 나온 코드는 상태 S_t을 입력으로 받고 모든 행동에 대한 Q-함수 값을 출력합니다.) Q-함수의 값을 Neural Network을 이용해 근사하는 것이죠. 이때 Decaying ε-greedy 정책을 사용합니다.

Deep Neural Network를 이용한 Q 함수 추정


 

  • 리플레이 메모리(Replay Memory)를 이용한 경험 리플레이(Experience Replay)

다음으로는, DQN에서 제안된 리플레이 메모리(Replay Memory)의 사용입니다. 어려운 개념은 아니고 단지 한 에피소드 중 상태 S_t에서 행동 A_t를 취했을 때 수집할 수 있는 [S_t, A_t, R_t+1, S_t+1] 또는 [s, a, R_t+1, s'] 에 해당하는 샘플들을 저장하는 것이죠. 이 샘플들이 저장된 리플레이 메모리(Replay Memory)에서 몇가지 샘플을 무작위로 추출해서 배치(Batch)단위로 학습에 이용하는 것이 경험 리플레이(Experience Replay)가 되는 것이죠. 

 

이렇게 저장을 해놓고 추출해서 방식을 취했을 때의 장점은 샘플 간 Correlation을 줄일 수 있다는 것인데요. Neural Network를 사용하기 전에는 각 상태에 대한 Q-함수를 따로 업데이트 했지만 Neural Network는 모든 상태(입력)에 대한 Q-함수를 학습해야 합니다.

 

에피소드의 과정 중에 학습하게 되면(On-Policy) 비슷한 상태(입력) 그리고 비슷한 큐 함수 값(출력)이 학습되기 때문에 편향(Biased)될 여지가 있습니다. 따라서 모여있는 샘플이 아닌 다양한 상태에서의 추출된 샘플을 학습하게 되면 이런 문제를 완화할 수 있죠. 아래 그림을 보면 이해에 도움이 되실 겁니다.(https://mathmakeworld.tistory.com/70?category=423585 해당 블로그 포스팅을 참고했습니다. 감사합니다.)

무작위로 추출 했을 때(오른쪽)가 전체(왼쪽)를 근사하는데 긍정적이다.


 

  • 학습이 되는 네트워크와 정답의 기준이되는 타겟 네트워크(Target Network)와 학습 네트워크 분리

다음으로 DQN은 타겟 네트워크(Target Network)와 학습 네트워크를 분리하였습니다. 여기서 타겟 네트워크란 정답의 기준이 되는 값을 출력하는 네트워크입니다. 이렇게 분리해서 사용하는 이유는 DQN에서 사용하는 Loss Function 때문입니다. Deep SARSA와 같이 Q-Learning의 업데이트 식에서 정답과 예측항을 나눠 사용합니다.

Q-Leaning 방식의 Q함수 업데이트 수식2: Neural Network을 사용했기 때문에 파라미터 𝜃가 붙었다. 
Q-Learning으로 Neural Network 학습 시 고안된 Loss Function

하지만 위의 수식으로 학습을 진행하면 생기는 문제점이 있습니다. 바로 정답(Target)에 해당하는 부분과 예측에 해당하는 부분이 하나의 Network 파라미터 θ에 의해 결정되기 때문에죠. θ는 Loss Function에서 나온 Error에 대한 Gradient를 받아 매 스텝 학습을 진행하여 업데이트 됩니다. 그렇게 되면 θ를 통해 나온 정답(Target)은 매 스텝마다 바뀌게 되는 것이죠. 이는 학습의 불안정을 유발합니다.

 

이 문제를 해결하기 위해서는 DQN에서 학습(예측)을 하는 네트워크와 정답이 되는 Target Network를 분리합니다. 그 수식은 아래와 같은데 학습(예측)을 하는 네트워크 파라미터는 θ 그리고 Target Network의 파라미터는 θ^-로 구분합니다. 그리고 일정 스텝 동안은 Target Network에 변화를 주지 않고 학습을 진행하고, 일정 스텝마다 학습이 진행된 네트워크로 Target Network를 업데이트합니다.

DQN에 사용되는 Loss Function

Target Network의 값이 진동하지 않아 학습 Network가 어느정도 안정된다는 것은 이해가 됐습니다만, 그 안정되게 수렴하는게 Target Network로 수렴하는 거라서 약간의 의문이 듭니다.. 학습 초반에는 탐험을 많이 하기 때문에 또 괜찮나 싶기도 하고 Reward가 있기 때문에 가능한가도 생각이 드는데 지금 시점에서는 좀 아리송한 부분입니다.


 

  • 카트폴(Cartpole)

카트폴은 강화학습에서 유명한 예제 중 하나인데, 움직이는 블록이 그 위에 긴 막대기가 쓰러지지 않게 중심을 잡는 것이 목표입니다. 책에서는 이 카트폴을 코드로 구현하여 설명합니다. 이때, 카트에 일정한 크기의 힘을 가해 움직일 수 있는데, 이 힘의 크기는 정해져 있습니다. 목표는 막대(폴)을 5초 동안 넘어지지 않게 세우는 것입니다. 따라서 폴이 일정 각도 이상 떨어지지 않도록 그리고 화면 밖으로 벗어나지 않도록 학습해야 할 것입니다.

카트폴(Cartpole)

먼저 MDP를 설정 해야겠죠. 이 문제를 풀기 위해 책에서는 에이전트 상태 S를 아래 수식같이 정의합니다.

x는 카트의 위치, x'은 카트의 속도, θ는 카트의 수직선으로 부터 기운 각도 그리고 θ'는 각속도를 나타냅니다. 그리고 이 성분은 모두 float 자료형입니다.(책에서 OpenAI gym을 사용해서 환경에 대한 정보는 그냥 받아오기만 합니다.)

에이전트 상태 S

책에서 카트폴에 대한 보상 설정은 쓰러지지않는 것이 핵심이기 때문에, 매 타임 스텝마다 +0.1이 주어지는 것은 물론이고 , 목표인 500 타임스텝(5초)를 채우지 못했을 때(넘어지거나 화면 밖으로 나갈 때) -1의 보상을 주었다고 합니다.

 

또한 책에 카트폴에 강화학습 적용 시 상호 작용은 아래와 같이 나와있습니다. 여기서 구분되어야 할 것은 에피소드의 진행과 학습이 반드시 동시에 진행하는 것이 아니라는 겁니다. Off-policy 방식이기 때문에 상관이 없는 것이죠.  

1.상태에 따른 행동 선택
2.선택한 행동으로 환경에서 한 타임스텝을 진행
3.환경으로부터 다음 상태와 보상을 받음
4.샘플(s,a,r,s')을 리플레이 메모리에 저장
5.리플레이 메모리에서 무작위 추출한 샘플로 학습
6.에피소드마다 타겟 모델 업데이트

마지막으로 이해를 돕기 위해 전체적인 학습도를 그려보았습니다.

전체 DQN 학습도


  • 브레이크 아웃(Additional Technics)

또한 책에서는 브레이크 아웃이라는 DQN을 적용시켰을때에 대해서도 설명합니다, 브레이크 아웃은 밑에 그림과 같이 바로 공을 쳐내 위의 블록을 부수는 게임인데요. 전체적인 학습 구조는 위의 예시를 통해 설명했으니 브레이크 아웃이라는 게임을 DQN에서 어떻게 풀었는지, 그 실제 구현 상의 특징들만 언급하려고 합니다.

브레이크 아웃 게임의 한 장면 (출처: DQN 논문)

우선 브레이크 아웃은 공의 위치, 바의 위치를 상태 S로 정의해 통해 문제를 푼것이 아니라 사람이 볼 수 있는 이미지 자체를 입력으로 취급했습니다. 즉 이미지가 입력인 셈이죠. 그렇기 때문에 이미지 처리에 효율적인 네트워크인 Convolution Neural Network(이하 CNN)을 사용합니다. 상태 S를 정의하는 특징 추출 단계까지 Neural Network을 이용하여 End-to-End 학습이 되도록 한 것이죠. 

 

또한 입력 이미지를 RGB 3채널 이미지가 아닌 Gray 1채널 이미지를 사용하였습니다. 색상이 주는 정보가 딱히 없기 때문에 입력의 사이즈를 효율적으로 줄인 것입니다. 또한 같은 이유로 입력이미지의 사이즈도 84x84로 리사이즈 합니다.

 

그리고 연속된 프레임은 중첩된 정보가 너무 많으므로(같은 이미지를 사용하는거 같음) 4프레임당 1개의 이미지를 선택하였고 이 선택된 이미지 4개를 묶어 CNN의 입력으로 주었습니다. 따라서 입력 사이즈는 Batch x 4 x 84 x 84(BCHW)가 되는 것이죠.

 

 또 다른 특징으로는 후버로스(Huber Loss)입니다. 후버로스는 [-1, 1] 구간에서는 2차함수지만, 나머지 구간은 1차함수로 표현되는 오류 함수입니다. 아래 그림의 초록색 선을 보면 알 수 있죠. 2차 함수로 표현된 파란 선과 달리 [-1, 1] 구간 밖에서는 그 큰 값이 크더라도 제곱으로 가중되지 않습니다. 후버로스를 MSE 오류함수 대신 사용해 큰 오류가 났을 때에도 Network의 파라미터가 크게 바뀌어 학습이 불안정해 지는 것을 방지합니다. 

후버로스(Huber Loss)

마지막으로 그레디언트 클리핑(Gradient Clipping)입니다. 말 그대로 학습시 일정 그레디언트가 넘어가지 못하게 막는 것을 의미합니다. 이 또한 그레디언트가 너무 크면 모델의 파라미터가 크게 변하면서 학습이 불안정해질 수 있기 때문이죠.

 

이처럼 단지 강화 학습의 수식만 아는 것이 아니라 여러가지 테크닉들이 같이 동원된다는 것을 알면 좋을거 같습니다.


이렇게 DQN의 개념을 살펴 보았습니다. DQN은 많은 강화 학습 논문의 초석인데 리플레이 메모리와 타겟 네트워크를 통한 학습 안정화가 특징이지 않나 싶습니다.

 

포스팅이 길어지고 지저분해질까봐 코드 설명은 피하려고 하는데,

또 책에서 코드 상 나와있는 디테일이 있어 풀어나가기가 쉽지는 않네요.

아무쪼록, 도움이 되셨으면 좋겠습니다.

감사합니다.