티스토리 뷰

딥러닝

오차역전파법 - Affine, Softmax 구현

4567은 소수 2021. 1. 14. 19:03

1. Affine 구현

 

Affine 함수는 신경망의 순전파 연산에서 행렬연산에 해당하는 것이다. 

입력 값이 행렬임을 이용하면 역전파 식은 다음과 같다.

 

( X, W->dot -> X*W, B -> + -> Y 인 경우, 자세한 내용 책 참고)

 

역전파, dL/dX
역전파, dL/dW

(행렬에 대해 미분한 걸 생각하면 어렵지 않게 구할 수 있다.)

 

구현 코드

class Affine:
    def __init__(self, W, b):
        self.W = W
        self.b = b
        
        self.x = None
        self.original_x_shape = None

        self.dW = None
        self.db = None

    def forward(self, x):
        self.original_x_shape = x.shape # 일반적인 행렬 계산 위해
        x = x.reshape(x.shape[0], -1)
        self.x = x

        out = np.dot(self.x, self.W) + self.b

        return out

    def backward(self, dout):
        dx = np.dot(dout, self.W.T)
        self.dW = np.dot(self.x.T, dout)
        self.db = np.sum(dout, axis=0) # 미분 연산 결과에 꼴 맞춤
        
        dx = dx.reshape(*self.original_x_shape)  # 입력 데이터 모양 변경
        return dx

 

2. Softmax 구현

 

여기서는 마지막 출력으로 softmax 함수를 이용하고, 손실함수를 교차 엔트로피 (cross entropy error)를 이용한다.

그림은 복잡하여 책을 참고하길 바란다.

 

전체 결과를 간소화하면 다음과 같다.

softmax, cross entropy error 역전파

softmax의 손실 함수로 cross entropy error를 사용한 것은 결과가 y-t 꼴로 깔끔하게 나오기 때문이다. 우연히 이렇게 된 것이 아닌 cross entropy error가 그렇게 설계된 것이다.

마찬가지로 항등함수의 손실하수로 오차제곱합을 이용하면 같은 결과를 얻는다. (y-t 꼴)

 

softmax와 cross entropy error를 합친 코드는 다음과 같다.

class SoftmaxWithLoss:
    def __init__(self):
        self.loss = None # 손실함수
        self.y = None    # softmax의 출력
        self.t = None    # 정답 레이블(원-핫 인코딩 형태)
        
    def forward(self, x, t):
        self.t = t
        self.y = softmax(x)
        self.loss = cross_entropy_error(self.y, self.t)
        
        return self.loss

    def backward(self, dout=1):
        batch_size = self.t.shape[0]
        if self.t.size == self.y.size: # 정답 레이블이 원-핫 인코딩 형태일 때
            dx = (self.y - self.t) / batch_size
        else:
            dx = self.y.copy()
            dx[np.arange(batch_size), self.t] -= 1
            dx = dx / batch_size
        
        return dx

 

'딥러닝' 카테고리의 다른 글

매개변수 갱신  (0) 2021.01.15
오차역전파법 구현  (0) 2021.01.14
오차역전파법 - 활성화 함수 계층 구현하기  (0) 2021.01.14
오차역전파법 - 계산 그래프  (0) 2021.01.13
신경망 학습 알고리즘 구현  (0) 2021.01.12
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
TAG
more
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
글 보관함