準備

Googleドライブのマウント

In [1]:
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

sys.pathの設定

以下では,Googleドライブのマイドライブ直下にDNN_codeフォルダを置くことを仮定しています.必要に応じて,パスを変更してください.

In [4]:
import sys
sys.path.append('/content/drive/My Drive/DNN_code_colab_ver200425')

simple RNN

バイナリ加算

In [49]:
import numpy as np
from common import functions
import matplotlib.pyplot as plt

# def d_tanh(x):
#def tanh(x):
# return (np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x))

def d_tanh(x):
  return 4/ (np.exp(x) + np.exp(-x)) ** 2
  #return 1/ (np.cosh(x) ** 2)

# データを用意
# 2進数の桁数
binary_dim = 8
# 最大値 + 1
largest_number = pow(2, binary_dim)
# largest_numberまで2進数を用意
binary = np.unpackbits(np.array([range(largest_number)],dtype=np.uint8).T,axis=1)

input_layer_size = 2
hidden_layer_size = 32
output_layer_size = 1

weight_init_std = 1
learning_rate = 0.2

iters_num = 10000
plot_interval = 100

# ウェイト初期化 (バイアスは簡単のため省略)
#W_in = weight_init_std * np.random.randn(input_layer_size, hidden_layer_size)
#W_out = weight_init_std * np.random.randn(hidden_layer_size, output_layer_size)
#W = weight_init_std * np.random.randn(hidden_layer_size, hidden_layer_size)


# Xavier
W_in = np.random.randn(input_layer_size, hidden_layer_size) / np.sqrt(input_layer_size)
W_out = np.random.randn(hidden_layer_size, output_layer_size, ) / np.sqrt(hidden_layer_size)
W =np.random.randn(hidden_layer_size, hidden_layer_size) / np.sqrt(hidden_layer_size)

# He
#W_in = np.random.randn(input_layer_size, hidden_layer_size) / np.sqrt(input_layer_size) * np.sqrt(2)
#W_out = np.random.randn(hidden_layer_size, output_layer_size) / np.sqrt(hidden_layer_size) * np.sqrt(2)
#W = np.random.randn(hidden_layer_size, hidden_layer_size) / np.sqrt(hidden_layer_size) * np.sqrt(2)

# 勾配
W_in_grad = np.zeros_like(W_in)
W_out_grad = np.zeros_like(W_out)
W_grad = np.zeros_like(W)

u = np.zeros((hidden_layer_size, binary_dim + 1))
z = np.zeros((hidden_layer_size, binary_dim + 1))
y = np.zeros((output_layer_size, binary_dim))

delta_out = np.zeros((output_layer_size, binary_dim))
delta = np.zeros((hidden_layer_size, binary_dim + 1))

all_losses = []

for i in range(iters_num):
    
    # A, B初期化 (a + b = d)
    a_int = np.random.randint(largest_number/2)
    a_bin = binary[a_int] # binary encoding
    b_int = np.random.randint(largest_number/2)
    b_bin = binary[b_int] # binary encoding
    
    # 正解データ
    d_int = a_int + b_int
    d_bin = binary[d_int]
    
    # 出力バイナリ
    out_bin = np.zeros_like(d_bin)
    
    # 時系列全体の誤差
    all_loss = 0    
    
    # 時系列ループ
    for t in range(binary_dim):
        # 入力値
        X = np.array([a_bin[ - t - 1], b_bin[ - t - 1]]).reshape(1, -1)
        # 時刻tにおける正解データ
        dd = np.array([d_bin[binary_dim - t - 1]])
        
        u[:,t+1] = np.dot(X, W_in) + np.dot(z[:,t].reshape(1, -1), W)
        z[:,t+1] = np.tanh(u[:,t+1])

        y[:,t] = functions.sigmoid(np.dot(z[:,t+1].reshape(1, -1), W_out))


        #誤差
        loss = functions.mean_squared_error(dd, y[:,t])
        
        delta_out[:,t] = functions.d_mean_squared_error(dd, y[:,t]) * d_tanh(y[:,t])        
        
        all_loss += loss

        out_bin[binary_dim - t - 1] = np.round(y[:,t])
    
    
    for t in range(binary_dim)[::-1]:
        X = np.array([a_bin[-t-1],b_bin[-t-1]]).reshape(1, -1)        

        delta[:,t] = (np.dot(delta[:,t+1].T, W.T) + np.dot(delta_out[:,t].T, W_out.T)) * functions.d_sigmoid(u[:,t+1])

        # 勾配更新
        W_out_grad += np.dot(z[:,t+1].reshape(-1,1), delta_out[:,t].reshape(-1,1))
        W_grad += np.dot(z[:,t].reshape(-1,1), delta[:,t].reshape(1,-1))
        W_in_grad += np.dot(X.T, delta[:,t].reshape(1,-1))
    
    # 勾配適用
    W_in -= learning_rate * W_in_grad
    W_out -= learning_rate * W_out_grad
    W -= learning_rate * W_grad
    
    W_in_grad *= 0
    W_out_grad *= 0
    W_grad *= 0
    

    if(i % plot_interval == 0):
        all_losses.append(all_loss)        
        print("iters:" + str(i))
        print("Loss:" + str(all_loss))
        print("Pred:" + str(out_bin))
        print("True:" + str(d_bin))
        out_int = 0
        for index,x in enumerate(reversed(out_bin)):
            out_int += x * pow(2, index)
        print(str(a_int) + " + " + str(b_int) + " = " + str(out_int))
        print("------------")

lists = range(0, iters_num, plot_interval)
plt.plot(lists, all_losses, label="loss")
plt.show()
iters:0
Loss:1.1876101161706205
Pred:[1 0 0 1 0 0 0 1]
True:[0 1 0 0 1 0 0 0]
63 + 9 = 145
------------
iters:100
Loss:1.0418099028916483
Pred:[0 1 1 1 0 1 0 1]
True:[0 1 1 0 0 1 1 0]
37 + 65 = 117
------------
iters:200
Loss:0.9897606116921079
Pred:[0 1 0 0 0 1 0 0]
True:[0 0 1 0 0 1 0 0]
16 + 20 = 68
------------
iters:300
Loss:0.6538371827402377
Pred:[1 1 1 1 1 1 0 0]
True:[0 1 1 1 0 1 0 0]
31 + 85 = 252
------------
iters:400
Loss:0.3265224815060553
Pred:[1 0 1 1 1 1 1 0]
True:[1 0 1 1 0 1 1 0]
85 + 97 = 190
------------
iters:500
Loss:0.13501759865164847
Pred:[1 1 1 0 0 0 1 1]
True:[1 1 1 0 0 0 1 1]
124 + 103 = 227
------------
iters:600
Loss:0.014809136911444238
Pred:[0 1 0 1 1 1 0 0]
True:[0 1 0 1 1 1 0 0]
87 + 5 = 92
------------
iters:700
Loss:0.12799111185218962
Pred:[1 0 0 1 0 0 1 0]
True:[1 0 0 1 0 0 1 0]
114 + 32 = 146
------------
iters:800
Loss:0.005987316547127775
Pred:[1 0 0 1 0 1 0 1]
True:[1 0 0 1 0 1 0 1]
58 + 91 = 149
------------
iters:900
Loss:0.00034008606150162033
Pred:[0 1 1 1 1 1 0 0]
True:[0 1 1 1 1 1 0 0]
3 + 121 = 124
------------
iters:1000
Loss:0.0002653015699991977
Pred:[0 1 1 1 1 0 1 1]
True:[0 1 1 1 1 0 1 1]
98 + 25 = 123
------------
iters:1100
Loss:0.00017755475638348402
Pred:[1 0 0 0 1 0 0 1]
True:[1 0 0 0 1 0 0 1]
49 + 88 = 137
------------
iters:1200
Loss:0.0001393110864816808
Pred:[1 1 0 0 1 1 1 0]
True:[1 1 0 0 1 1 1 0]
103 + 103 = 206
------------
iters:1300
Loss:0.1250497505699516
Pred:[0 1 1 1 0 0 1 0]
True:[0 1 1 1 0 0 1 0]
102 + 12 = 114
------------
iters:1400
Loss:4.120040061669113e-05
Pred:[1 0 1 0 0 0 0 1]
True:[1 0 1 0 0 0 0 1]
58 + 103 = 161
------------
iters:1500
Loss:3.799639275028743e-05
Pred:[0 1 0 0 0 0 1 1]
True:[0 1 0 0 0 0 1 1]
59 + 8 = 67
------------
iters:1600
Loss:0.12504642004006833
Pred:[0 1 1 1 0 1 1 0]
True:[0 1 1 1 0 1 1 0]
50 + 68 = 118
------------
iters:1700
Loss:2.38780854581337e-05
Pred:[0 1 1 0 1 0 0 1]
True:[0 1 1 0 1 0 0 1]
62 + 43 = 105
------------
iters:1800
Loss:4.518885192256246e-05
Pred:[0 1 0 1 1 0 1 1]
True:[0 1 0 1 1 0 1 1]
61 + 30 = 91
------------
iters:1900
Loss:6.570231148081363e-06
Pred:[0 0 1 1 1 1 0 0]
True:[0 0 1 1 1 1 0 0]
47 + 13 = 60
------------
iters:2000
Loss:4.134562582002625e-05
Pred:[0 1 1 0 0 0 1 1]
True:[0 1 1 0 0 0 1 1]
18 + 81 = 99
------------
iters:2100
Loss:2.302464815353576e-05
Pred:[1 0 1 1 0 1 0 1]
True:[1 0 1 1 0 1 0 1]
124 + 57 = 181
------------
iters:2200
Loss:1.5532811633129886e-05
Pred:[0 0 1 0 1 0 0 0]
True:[0 0 1 0 1 0 0 0]
17 + 23 = 40
------------
iters:2300
Loss:2.0794708962102368e-05
Pred:[1 1 0 1 0 1 0 1]
True:[1 1 0 1 0 1 0 1]
100 + 113 = 213
------------
iters:2400
Loss:1.0589983334470734e-05
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
19 + 110 = 129
------------
iters:2500
Loss:5.122755777299482e-06
Pred:[0 1 0 0 0 0 0 0]
True:[0 1 0 0 0 0 0 0]
59 + 5 = 64
------------
iters:2600
Loss:7.938882625799165e-06
Pred:[1 0 0 1 0 0 0 0]
True:[1 0 0 1 0 0 0 0]
117 + 27 = 144
------------
iters:2700
Loss:1.5826808086557878e-05
Pred:[0 1 0 1 1 1 0 1]
True:[0 1 0 1 1 1 0 1]
86 + 7 = 93
------------
iters:2800
Loss:0.1250062690255983
Pred:[0 1 0 0 0 0 1 0]
True:[0 1 0 0 0 0 1 0]
14 + 52 = 66
------------
iters:2900
Loss:6.835181672652392e-06
Pred:[1 0 0 1 1 1 1 0]
True:[1 0 0 1 1 1 1 0]
83 + 75 = 158
------------
iters:3000
Loss:3.7607857108030784e-06
Pred:[1 0 1 1 0 0 0 0]
True:[1 0 1 1 0 0 0 0]
61 + 115 = 176
------------
iters:3100
Loss:7.686732119092815e-06
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
90 + 39 = 129
------------
iters:3200
Loss:5.644553443309038e-06
Pred:[0 1 1 0 0 0 1 1]
True:[0 1 1 0 0 0 1 1]
37 + 62 = 99
------------
iters:3300
Loss:7.763048728189088e-06
Pred:[1 0 0 1 0 0 1 0]
True:[1 0 0 1 0 0 1 0]
49 + 97 = 146
------------
iters:3400
Loss:2.318000870543377e-06
Pred:[0 1 1 1 0 0 1 0]
True:[0 1 1 1 0 0 1 0]
83 + 31 = 114
------------
iters:3500
Loss:0.1250010701618784
Pred:[1 1 0 0 0 0 0 0]
True:[1 1 0 0 0 0 0 0]
94 + 98 = 192
------------
iters:3600
Loss:0.12500546170223667
Pred:[0 0 1 0 0 1 1 0]
True:[0 0 1 0 0 1 1 0]
0 + 38 = 38
------------
iters:3700
Loss:1.1758504839469965e-05
Pred:[1 0 0 1 1 1 0 1]
True:[1 0 0 1 1 1 0 1]
113 + 44 = 157
------------
iters:3800
Loss:0.25000237679272863
Pred:[1 0 1 1 0 0 0 0]
True:[1 0 1 1 0 0 0 0]
92 + 84 = 176
------------
iters:3900
Loss:9.877730608751023e-06
Pred:[1 0 0 1 1 1 1 1]
True:[1 0 0 1 1 1 1 1]
114 + 45 = 159
------------
iters:4000
Loss:4.3076533804446094e-06
Pred:[0 1 0 1 0 1 0 1]
True:[0 1 0 1 0 1 0 1]
26 + 59 = 85
------------
iters:4100
Loss:0.12500391907865466
Pred:[0 1 1 0 0 0 1 0]
True:[0 1 1 0 0 0 1 0]
82 + 16 = 98
------------
iters:4200
Loss:0.1250051474634441
Pred:[1 1 0 1 0 0 1 0]
True:[1 1 0 1 0 0 1 0]
98 + 112 = 210
------------
iters:4300
Loss:2.410645545528776e-06
Pred:[1 0 1 1 0 0 0 1]
True:[1 0 1 1 0 0 0 1]
63 + 114 = 177
------------
iters:4400
Loss:3.770010592015752e-06
Pred:[0 1 1 0 0 1 1 1]
True:[0 1 1 0 0 1 1 1]
80 + 23 = 103
------------
iters:4500
Loss:0.1250027056144264
Pred:[0 1 1 1 0 1 1 0]
True:[0 1 1 1 0 1 1 0]
90 + 28 = 118
------------
iters:4600
Loss:0.2500012973723249
Pred:[0 0 0 1 1 0 0 0]
True:[0 0 0 1 1 0 0 0]
20 + 4 = 24
------------
iters:4700
Loss:2.2784349975037117e-06
Pred:[1 0 1 0 1 1 1 0]
True:[1 0 1 0 1 1 1 0]
49 + 125 = 174
------------
iters:4800
Loss:4.8651669168865295e-06
Pred:[1 0 0 0 1 1 1 1]
True:[1 0 0 0 1 1 1 1]
90 + 53 = 143
------------
iters:4900
Loss:2.39029009947123e-06
Pred:[1 1 0 0 1 1 1 1]
True:[1 1 0 0 1 1 1 1]
101 + 106 = 207
------------
iters:5000
Loss:8.988769532898243e-07
Pred:[0 0 1 0 0 0 0 0]
True:[0 0 1 0 0 0 0 0]
13 + 19 = 32
------------
iters:5100
Loss:0.1250034164835524
Pred:[0 1 1 1 0 1 1 0]
True:[0 1 1 1 0 1 1 0]
38 + 80 = 118
------------
iters:5200
Loss:5.274135455457528e-06
Pred:[1 0 0 1 1 1 1 0]
True:[1 0 0 1 1 1 1 0]
31 + 127 = 158
------------
iters:5300
Loss:3.7756005289317355e-06
Pred:[1 0 1 0 1 0 0 1]
True:[1 0 1 0 1 0 0 1]
73 + 96 = 169
------------
iters:5400
Loss:2.8118526253705584e-06
Pred:[1 0 1 0 1 1 1 1]
True:[1 0 1 0 1 1 1 1]
126 + 49 = 175
------------
iters:5500
Loss:9.482410296635913e-07
Pred:[0 0 0 1 1 0 0 0]
True:[0 0 0 1 1 0 0 0]
3 + 21 = 24
------------
iters:5600
Loss:0.12500230266001075
Pred:[1 0 0 0 1 1 1 0]
True:[1 0 0 0 1 1 1 0]
28 + 114 = 142
------------
iters:5700
Loss:0.3750004749367717
Pred:[1 0 1 0 0 0 0 0]
True:[1 0 1 0 0 0 0 0]
120 + 40 = 160
------------
iters:5800
Loss:0.12500196530199495
Pred:[0 1 1 1 1 0 1 0]
True:[0 1 1 1 1 0 1 0]
14 + 108 = 122
------------
iters:5900
Loss:2.528495913696765e-06
Pred:[0 1 0 1 1 1 0 1]
True:[0 1 0 1 1 1 0 1]
90 + 3 = 93
------------
iters:6000
Loss:0.12500332281770538
Pred:[1 0 0 1 1 0 1 0]
True:[1 0 0 1 1 0 1 0]
114 + 40 = 154
------------
iters:6100
Loss:2.5035477292971917e-06
Pred:[1 0 1 1 0 1 0 0]
True:[1 0 1 1 0 1 0 0]
99 + 81 = 180
------------
iters:6200
Loss:2.0483042358204807e-06
Pred:[0 1 1 0 0 1 1 1]
True:[0 1 1 0 0 1 1 1]
18 + 85 = 103
------------
iters:6300
Loss:1.9219285273527308e-06
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
66 + 63 = 129
------------
iters:6400
Loss:0.1250022975404902
Pred:[0 1 1 0 1 0 1 0]
True:[0 1 1 0 1 0 1 0]
88 + 18 = 106
------------
iters:6500
Loss:4.956262735254183e-07
Pred:[1 0 0 0 1 0 0 0]
True:[1 0 0 0 1 0 0 0]
31 + 105 = 136
------------
iters:6600
Loss:1.928456500920741e-07
Pred:[0 0 0 1 1 1 0 0]
True:[0 0 0 1 1 1 0 0]
15 + 13 = 28
------------
iters:6700
Loss:3.241571531665714e-06
Pred:[0 0 0 1 1 1 0 1]
True:[0 0 0 1 1 1 0 1]
13 + 16 = 29
------------
iters:6800
Loss:4.627317133987496e-06
Pred:[0 1 1 1 1 1 1 0]
True:[0 1 1 1 1 1 1 0]
95 + 31 = 126
------------
iters:6900
Loss:1.0592229276210797e-06
Pred:[1 1 0 1 1 0 0 0]
True:[1 1 0 1 1 0 0 0]
113 + 103 = 216
------------
iters:7000
Loss:1.6820397734208046e-06
Pred:[0 0 1 0 1 0 1 1]
True:[0 0 1 0 1 0 1 1]
25 + 18 = 43
------------
iters:7100
Loss:0.12500049884678102
Pred:[0 1 1 1 0 0 0 0]
True:[0 1 1 1 0 0 0 0]
2 + 110 = 112
------------
iters:7200
Loss:9.501157990652455e-07
Pred:[0 1 1 0 0 0 0 1]
True:[0 1 1 0 0 0 0 1]
23 + 74 = 97
------------
iters:7300
Loss:0.12500028290976414
Pred:[0 1 1 1 0 1 0 0]
True:[0 1 1 1 0 1 0 0]
22 + 94 = 116
------------
iters:7400
Loss:1.6004819114081656e-06
Pred:[1 0 1 0 0 1 0 1]
True:[1 0 1 0 0 1 0 1]
86 + 79 = 165
------------
iters:7500
Loss:0.12500284941372178
Pred:[0 1 1 1 1 0 1 0]
True:[0 1 1 1 1 0 1 0]
8 + 114 = 122
------------
iters:7600
Loss:1.3378838077969413e-06
Pred:[0 1 0 1 1 0 0 1]
True:[0 1 0 1 1 0 0 1]
58 + 31 = 89
------------
iters:7700
Loss:1.5634943666136195e-06
Pred:[1 0 0 0 1 1 1 1]
True:[1 0 0 0 1 1 1 1]
26 + 117 = 143
------------
iters:7800
Loss:8.298712160398821e-07
Pred:[1 0 0 0 1 0 0 1]
True:[1 0 0 0 1 0 0 1]
111 + 26 = 137
------------
iters:7900
Loss:9.721998296969126e-07
Pred:[1 0 1 1 0 0 0 1]
True:[1 0 1 1 0 0 0 1]
67 + 110 = 177
------------
iters:8000
Loss:1.1089418668371164e-06
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
69 + 101 = 170
------------
iters:8100
Loss:1.2011066190571116e-06
Pred:[1 0 0 1 1 0 0 1]
True:[1 0 0 1 1 0 0 1]
119 + 34 = 153
------------
iters:8200
Loss:1.1020820008897288e-06
Pred:[0 1 1 0 1 0 1 1]
True:[0 1 1 0 1 0 1 1]
103 + 4 = 107
------------
iters:8300
Loss:9.360294663849577e-07
Pred:[1 1 0 0 0 1 1 1]
True:[1 1 0 0 0 1 1 1]
118 + 81 = 199
------------
iters:8400
Loss:0.125000703826329
Pred:[1 0 1 0 1 1 1 0]
True:[1 0 1 0 1 1 1 0]
126 + 48 = 174
------------
iters:8500
Loss:8.921121456912583e-07
Pred:[0 1 1 1 1 0 1 0]
True:[0 1 1 1 1 0 1 0]
87 + 35 = 122
------------
iters:8600
Loss:1.30893257410703e-06
Pred:[1 1 0 0 1 0 0 1]
True:[1 1 0 0 1 0 0 1]
84 + 117 = 201
------------
iters:8700
Loss:6.220847321166628e-07
Pred:[1 0 1 0 0 1 0 0]
True:[1 0 1 0 0 1 0 0]
81 + 83 = 164
------------
iters:8800
Loss:1.83261700057198e-06
Pred:[0 1 0 1 1 0 1 1]
True:[0 1 0 1 1 0 1 1]
29 + 62 = 91
------------
iters:8900
Loss:9.979971340586316e-07
Pred:[0 1 0 0 0 1 1 1]
True:[0 1 0 0 0 1 1 1]
46 + 25 = 71
------------
iters:9000
Loss:5.960091082607881e-07
Pred:[1 0 1 1 1 0 0 0]
True:[1 0 1 1 1 0 0 0]
119 + 65 = 184
------------
iters:9100
Loss:0.12500087863901013
Pred:[0 0 0 1 1 1 1 0]
True:[0 0 0 1 1 1 1 0]
8 + 22 = 30
------------
iters:9200
Loss:0.12500032820664525
Pred:[0 1 0 0 0 0 0 0]
True:[0 1 0 0 0 0 0 0]
2 + 62 = 64
------------
iters:9300
Loss:6.736849891227599e-07
Pred:[1 0 0 0 0 0 0 1]
True:[1 0 0 0 0 0 0 1]
38 + 91 = 129
------------
iters:9400
Loss:0.12500075162686697
Pred:[1 0 0 0 0 1 1 0]
True:[1 0 0 0 0 1 1 0]
8 + 126 = 134
------------
iters:9500
Loss:0.12500059009366546
Pred:[0 0 1 1 0 1 1 0]
True:[0 0 1 1 0 1 1 0]
10 + 44 = 54
------------
iters:9600
Loss:8.74558182666507e-07
Pred:[0 0 0 1 0 0 1 1]
True:[0 0 0 1 0 0 1 1]
11 + 8 = 19
------------
iters:9700
Loss:5.148342447295406e-07
Pred:[1 0 0 1 1 0 0 0]
True:[1 0 0 1 1 0 0 0]
49 + 103 = 152
------------
iters:9800
Loss:0.12500066096627008
Pred:[1 0 0 1 1 0 1 0]
True:[1 0 0 1 1 0 1 0]
54 + 100 = 154
------------
iters:9900
Loss:2.6674907257287893e-07
Pred:[1 0 1 0 1 0 1 0]
True:[1 0 1 0 1 0 1 0]
91 + 79 = 170
------------

[try] weight_init_stdやlearning_rate, hidden_layer_sizeを変更してみよう

[try] 重みの初期化方法を変更してみよう

Xavier, He

[try] 中間層の活性化関数を変更してみよう

ReLU(勾配爆発を確認しよう)
tanh(numpyにtanhが用意されている。導関数をd_tanhとして作成しよう)


In [ ]: