Many to One Structure
- RNN의 Many to One 구조는 중간 출력들을 사용하지 않고, 마지막 출력만 사용하는 구조입니다.
Many to Many Structure
Data Strcture
- Many to One 에서는 Output이 2D 였지만, Many to Many 에서는 3D가 됩니다.
Many to Many
- 중간 출력이 나오는 구조입니다.
- 학습단계 시 중간 출력마다 Back propagation이 수행됩니다.
- 추론단계 시 맨 마지막 값만 사용합니다.
BPTT (Back Propagation Thorugh Time)
- RNN에서는 Input layer의 Weight인 W_x, Output layer의 Weight인 W_o를 제외하고, Hidden layer의 W_h에 대해서만 정리하겠습니다. 물론, 이어서 알아보려는 구조의 경우는 Many to Many 입니다.
- RNN에서는 Recurrent하게 Data Flow가 작동하듯이, Back propagation도 time axis를 따라 recurrent하게 업데이트 됩니다.
- 위 그림은 t=3에서의 Back propgation 예시입니다. t=3에서 첫 번째로 W_h update, t=2에서 그 다음으로 update , t=1 에서 update가 됩니다. 즉, W_h가 과거의 시퀀스로 지나가면서 t번 Update 됩니다.
- 일반적인 2개의 layer로 구성된 Network의 Back propgation 수식을 먼저 관찰해봅니다. w_2와 w_1은 위 처럼 수식이 표현될 수 있습니다.
- 그런데 Weight(W_h)를 공유하는 RNN이라고 생각해보면, w_1 == w_2로 보고, 최종적인 Weight update 수식을 정리해볼 수 있습니다.
- 즉, w_2일때와 w_1일 때의 수식이 합해져야 총 2번의 Update 에 대한 수식이 정리될 수 있습니다. 이 과정을 통해 이해한 것은 RNN 처럼 Backpropagation 과정 중에 여러 Update 가 발생하는 수식이 이와 비슷할 것이라는 것입니다.
- Unfold 상태의 RNN 구조를 이 전에 언급한 2개의 layer 처럼 Sequence하게 구조를 변경해서 생각해봤습니다.
- t=3일 때의 BPTT입니다. 진행 중에 처음 t=3일 때 W_h update, t=2일 때 W_h 까지만 보겠습니다. 이렇게 되면 이전 예실처럼 2개 layer 입니다. 여기서는 정확하게는 layer의 명칭이 RNN이기 때문에 hidden layer 입니다.
- 빨간색으로 표시한 부분이 BPTT 과정중에 t=3의 weight update, 파란색이 t=2의 weight update가 됩니다. 이는 이전에 봤던 예시와 매우 똑같은 형태입니다.
- 이를 통해 알 수 있는 것은 RNN의 Hidden layer's Weight W_h가 특정 시점에서의 여러번의 Update가 저런 식으로 표현될 수 있다는 것입니다.
- 위는 Loss에서 부터의 BPTT를 알아본 수식입니다.
- 먼저, Loss와 t 시점에서의 hidden layer인 h_t 의 정의를 첫 번째, 두 번째 line의 수식으로 나타낼 수 있습니다.
- Loss는 각 시점 마다의 Loss 들의 평균입니다.
- 세 번째 line 을 봤을 때, h_t는 RNN이기 때문에 이전 time step을 포함한 채로 표현되기 때문에 h_t-1에 W_h가 포함되어있습니다. 이 때문에 미분 시, 이전에 수식처럼 Chain Rule이 적용되면서 h_t의 BPTT에 대한 수식이 작성됩니다.
- 맨 마지막 line을 봅시다. RNN의 임의 시점에 대한 W_h Update 의 괄호 부분과 같은 형태가 나타납니다. 의미적으로는 맨 앞이 첫 번째 Update, 그 뒤가 t-1 번의 Update가 됩니다. 총 t 번의 Update 에 대한 식이 마지막 line 입니다.
- t-1 번의 Update는 앞에서 말한 것 처럼 h_t-1에 W_h가 있듯이 과거의 h 들이 W_h를 갖는 것으로 수식이 전개됩니다. 이는 Recursive를 만들기 때문에 결국 전개 후에는 t-1 번의 Update 식이 될 것입니다.
- 해당 Update에 대한 점화식을 나타내면 첫 번째 전개한 Line의 수식으로 나타낼 수 있습니다.
- 점화식 꼴대로 W_h에 대한 h_t 미분식이 위 처럼 정리가 될 수 있습니다.
- 이렇게 정리한 식에서 몇가지 의미를 추출해낼 수 있습니다.
- 지금 까지 전개하며 정리해본 W_h에 대한 h_t 미분식은 Chain Rule에 의해 Loss에 대한 W_h의 미분식에 영향을 끼칩니다.
- 마지막 빨간색 밑줄친 부분이 현재 time step에 대한 과거 time step에대한 미분들의 곱입니다.
- 만약, W_h의 값이 1보다 크다면 그 값이 1.1이 더라도 time step (t)가 길면, Gradient가 Exploding되는 문제가 발생합니다.
- 도함수 f_h 나 W_h의 값이 1보다 작을 경우 time step (t)가 길면 Gradient Vanishing 문제 또한 발생합니다.
- Simple RNN은 t가 길면 길수록 나타나는 문제들이 있어서 학습에 부정적인 영향이 분명히 존재합니다. 특정 시점 이후 gradient를 버리는 truncated BPTT와 같은 것들이 소개될 정도입니다.
'RNN Part' 카테고리의 다른 글
RNN (1) - Simple RNN (0) | 2024.02.16 |
---|