1 분 소요

Drop out 코드 작성

1. Introduce

과적합을 방지하기 위해 몇가지 방식이 있다. 규제를 적용하거나, 데이터의 양을 늘리거나 등등 이번에 배워볼 방식은 Drop Out 방식이다.
Drop Out 특정 feature 과도한 학습을 막아내기 위해 나온 방식이다.

2.forward

def dropout_forward(x, dropout_param):

    p, mode = dropout_param["p"], dropout_param["mode"]
    if "seed" in dropout_param:
        np.random.seed(dropout_param["seed"])

    mask = None
    out = None

    if mode == "train":

        mask = (np.random.rand(*x.shape) < p) / p
        out = x * mask

    elif mode == "test":
        out = x

    cache = (dropout_param, mask)
    out = out.astype(x.dtype, copy=False)

    return out, cache
  1. x.shape 모양만큼 정규분포를 만들고 해당 p 확률 만큼 뉴런을 꺼준다.
  2. “/p” 이유는 테스트 단계에서 스케일을 맞추기 위해 “P” 해야한다. 하지만 테스트 단계에서 “P” 한다는것은 성능에 좋지 않다. 그로므로 스케일을 애초에 train 단계에서 나누어서 간다면 test 단계에서 P를 곱할일이 없다.

3. backward

def dropout_backward(dout, cache):
    
  dropout_param, mask = cache
    
    mode = dropout_param["mode"]
  
  
    dx = None
    if mode == "train":
        #######################################################################
        # TODO: Implement training phase backward pass for inverted dropout   #
        #######################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        dx = dout * mask
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        #######################################################################
        #                          END OF YOUR CODE                           #
        #######################################################################
    elif mode == "test":
        dx = dout
    return dx

4. Question

4.1 Question 1

What happens if we do not divide the values being passed through inverse dropout by p in the dropout layer? Why does that happen?

4.1.1 Answer

테스트 단계에서 스케일을 맞추기 위해 “P” 해야한다. 하지만 테스트 단계에서 “P” 한다는것은 성능에 좋지 않다. 그로므로 스케일을 애초에 train 단계에서 나누어서 간다면 test 단계에서 P를 곱할일이 없다.

  1. 나누지 않을 경우 \(E[\hat y]=p\hat y\)

  2. 나눈 경우 \(E[\hat y]=\frac{p}{p}\hat y = \hat y\)

Additional references

  • https://cs231n.github.io