본문 바로가기

WIL

[WIL] Transformer 구현

https://wikidocs.net/31379

 

16-01 트랜스포머(Transformer)

* 이번 챕터는 앞서 설명한 어텐션 메커니즘 챕터에 대한 사전 이해가 필요합니다. 트랜스포머(Transformer)는 2017년 구글이 발표한 논문인 Attention i…

wikidocs.net

다음 사이트의 내용을 참고하여 transformer를 구현해보았다.


import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, position, d_model):
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pos_encoding', self.positional_encoding(position, d_model))

    def get_angles(self, position, i, d_model):
        angles = 1 / torch.pow(10000, (2 * (i // 2)) / d_model)
        return position * angles

    def positional_encoding(self, position, d_model):
        angle_rads = self.get_angles(
            position=torch.arange(position, dtype=torch.float)[:, None],
            i=torch.arange(d_model, dtype=torch.float)[None, :],
            d_model=d_model)

        # 배열의 짝수 인덱스(2i)에는 사인 함수 적용
        sines = torch.sin(angle_rads[:, 0::2])

        # 배열의 홀수 인덱스(2i+1)에는 코사인 함수 적용
        cosines = torch.cos(angle_rads[:, 1::2])

        pos_encoding = torch.zeros(angle_rads.shape)
        pos_encoding[:, 0::2] = sines
        pos_encoding[:, 1::2] = cosines
        pos_encoding = pos_encoding.unsqueeze(0)

        print(pos_encoding.shape)
        return pos_encoding.float()

    def forward(self, inputs):
        return inputs + self.pos_encoding[:, :inputs.shape[1], :]

각 임베딩 벡터에 positional encoding의 값을 더하면 문장의 위치에 따른 임베딩 벡터의 값이 달라진다.

 

 

def __init__(self, position, d_model)  : PositionalEncoding class의 생성자를 정의한다.

* 여기서 생성자(__init__)란, class로부터 객체를 만들 때, 자동으로 호출되며 객체의 초기 상태를 설정한다.

  예를들어, class가 사람을 나타낸다고 하면 해당 class는 사람의 속성(이름, 나이 등)을 정의할 수 있는데, 생성자는 이런 속성을 초기화하기 때문에, 객체가 만들어질 때마다 이러한 속성들이 초기화된다.

# 생성자 예시 코드

class Person:
    def __init__(self, name, age):
        self.name = name   # 이름 속성 초기화
        self.age = age     # 나이 속성 초기화

# Person 클래스의 객체 생성
person1 = Person("Alice", 30)

super(PositionalEncoding, self).__init__() : 상속 관계에 있는 부모 class의 생성자를 호출한다. 여기서는 nn.Module의 생성자를 호출한다. super()함수는 현재 class의 부모 class를 나타내는 super 객체를 반환한다.

즉, PositionalEncoding class가 nn.Module class를 상속하고 있으므로, super(PositionalEncoding, self)nn.Module을 가리킨다. 따라서 nn.Module class의 생성자를 호출하여 해당 기능을 PositionalEncoding에서도 사용할 수 있게한다.

self.register_buffer('pos_encoding', self.positional_encoding(position, d_model)) : 모델의 산채를 관리하기 위한 버퍼를 생성한다.

* 버퍼(buffer)란, 모델의 매개변수는 아니지만, 학습 중에 순전파 및 역전파를 수행하면서 변하지 않는 상태를 저장하는데 사용하며, 모델과 함께 저장되고 로드된다. 버퍼를 사용하는 주요 이유는 모델의 상태를 관리하고 모델의 유연성과 재사용성을 향상시키는 데 있다. 즉, 모델을 저장하고 다시 불러올 때, 버퍼에 저장된 상태도 함께 저장되므로 모델을 다시 불러올 때 동일한 상태를 사용할 수 있도록 한다.

self.register_buffer() 함수는 모델에 버퍼를 등록한다. 여기서 'pos_encoding'은 버퍼의 이름이고, self.positional_encoding(position, d_model)은 버퍼에 저장할 데이터이다.

 

def get_angles(self, position, i, d_model) : positional embedding을 위해 필요한, 각 토큰의 위치와 각도를 계산하는 함수로, self.position으로 문장의 길이(토큰의 개수)값을 받으며 i로 텐서의 크기(positional encoding의 차원)를 나타낸다. 또한, d_model은 transformer 모델의 hidden state의 차원 수(모델의 임베팅 차원 수)를 나타낸다.

여기서는 각도 계산에 가낭 널리 사용되는 수식인,

다음을 사용하여 angle값을 계산한다.

여기서, 위치 정보와 각도를 결합하여 더 다양한 패턴을 생성하기 위해 position * angles 값을 사용한다.

 

def positional_encoding(self, position, d_model) : Positional Encoding을 계산하는 함수이다.

position=torch.arange(position, dtype=torch.float)[:, None] : 0부터 position - 1까지의 위치 값을 생성하며, 생성된 시퀀스의 데이터 타입을 실수형(float)으로 설정한다. 이는 결과적으로 0부터 시작하여 position - 1까지의 정수값을 포함하는 텐서이다. [:, None]는 생성된 텐서를 열 벡터(column vector) 형태로 변환한다.

i=torch.arange(d_model, dtype=torch.float)[None, :] : 0부터 d_modle - 1까지의 값을 생성하며, [None, :]를 통해 열 벡터 형태로 변환한다.

pos_encoding = torch.zeros(angle_rads.shape) : Positional Encoding을 저장할 텐서를 초기화한다.

pos_encoding[:, 0::2] = sines : 생성된 텐서의 짝수 인덱스 열에 사인 값을 할당한다.

pos_encoding[:, 1::2] = cosines : 생성된 텐서의 홀수 인덱스 열에 코사인 값을 할당한다.

pos_encoding = pos_encoding.unsqueeze(0) : 최종적으로 생성된 Positional Encoding 텐서에 차원을 추가한다. 이를 통해 생성된 Positional Encoding은 모델의 batch dimension을 추가하여 [batch_size, sequence_length, d_model] 형태를 갖게 된다.

 

def forward(self, inputs) : Transformer의 Positional Encoding을 입력 데이터에 추가하는 forward 메서드를 나타낸다.

return inputs + self.pos_encoding[:, :inputs.shape[1], :] : inputs는 Transformer 모델에 입력되는 데이터로, 예를 들어 토큰의 임베딩 행렬이 될 수 있다. inputs.shape[1]은 입력 데이터의 두 번째 차원의 크기를 나타내는 것으로, 이 값은 입력 데이터의 시퀀스 길이를 나타낸다. 즉,  입력 데이터의 시퀀스 길이를 나타내므로, 이 값을 사용하여 Positional Encoding을 입력 데이터와 동일한 길이로 자르거나 인덱싱하는 등의 작업을 수행하며,  입력 시퀀스의 길이를 동적으로 반영하여 모델에 필요한 처리를 한다.