์ํ ๋ชจ๋ธ(Recurrent Models)
๊ธฐ์ต ์ธํฌ(Recurrent Cells, ์ํ ์ , ๋๋ฅผ ๋ชจ๋ฐฉํ ๊ฑฐ)
๋จ์ํ ํผ๋ํฌ์๋(feedforward, ์ฌ์ดํด(cycle)์ด๋ ๋ฃจํ(loop)๊ฐ ์๋ ๋คํธ์ํฌ) ๊ฒฝ์ฐ, ๋ชจ๋ธ m
์ ์ฌ๋ฌ ๊ฐ์ ์
๋ ฅ xแตข
์ ๋ํ yแตข
๋ฅผ ์์ธกํ๋ ๊ฐ๋จํ ํจ์๋ค. (์๋ฅผ ๋ค์ด, x
๋ฅผ MNIST ์ซ์๋ผ ์น๋ฉด y
๋ ๊ทธ๊ฒ์ ๋ถ๋ฅํ ์ซ์.) ์์ธก์ ์๋ก ์์ ํ ๋
๋ฆฝ์ ์ด๋ฉฐ x
๊ฐ ๊ฐ์ผ๋ฉด y
๋ ์ธ์ ๋ ๋์ผํ๋ค.
yโ = f(xโ)
yโ = f(xโ)
yโ = f(xโ)
# ...
์ํ ๋คํธ์ํฌ๋ ํ๋ ์ํ(hidden state, ์จ๊ฒจ๋
ผ ์ํ)๊ฐ ์กด์ฌํ๋ฉฐ ๋ชจ๋ธ์ ๋๋ฆด ๋ ๋งค๋ฒ ๊ทธ ์ํ๋ฅผ ๋ค์์ผ๋ก ๋๊ธด๋ค. ๊ทธ๋์ ์ด ๋ชจ๋ธ์ ๊ทธ๋๋ง๋ค ์ด์ h
๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ๊ณ , ์๋ก์ด h
๋ฅผ ์ถ๋ ฅ์ผ๋ก ๋ด ๋๋๋ค.
h = # ... ์ด๊ธฐ ์ํ ...
h, yโ = f(h, xโ)
h, yโ = f(h, xโ)
h, yโ = f(h, xโ)
# ...
h
์ ์ ์ฅํ ์ ๋ณด๋ ๋ค์๋ฒ ์์ธก์ ์ํด ์ ์งํ๊ณ , ๊ทธ๋์ ๋ง์น ํจ์์ ๋ฉ๋ชจ๋ฆฌ ๊ฐ์ ์ญํ ์ ํ๊ฒ ํ๋ค. ์ด๊ฒ์ x
์ ๋ํ ์์ธก์ด ์ด์ ๊น์ง ๋ชจ๋ธ์ ์ฃผ์
ํ ๋ชจ๋ ์
๋ ฅ์ผ๋ก๋ถํฐ ์ํฅ์ ๋ฐ๋ ๊ฒ์ ์๋ฏธํ๋ค.
(์๊ฑฐ๋ ์ค์ํ ๊ฑฐ๋๊น ์๋ฅผ ๋ค์ด, x
๋ฅผ ๋ฌธ์ฅ์์์ ํ ๋จ์ด๋ผ ๋ณด์; ๋ง์ฝ์ "bank"๋ผ๋ ์์ด ๋จ์ด๊ฐ ์ฃผ์ด์ง๋ฉด ๋ชจ๋ธ์ ์ด์ ์
๋ ฅ์ด "๊ฐ river" ์ด๋ฉด ๊ฐ๋์ผ๋ก, "ํฌ์ investment"๋ฉด ์ํ์ผ๋ก ํด์ํด์ผ ํ๋ค.)
Flux์ RNN ์ง์์ ์ํ์ ๊ด์ ์ ์ง๋๊ณ ์๋ค. ๊ฐ์ฅ ๊ธฐ๋ณธ์ด ๋๋ RNN์ ํ์ค "Dense" ๋ ์ด์ด๋ฅผ ๋ฐ๋ฅด๊ณ , ๊ทธ ์ถ๋ ฅ์ ํ๋ ์ํ์ด๋ค.
julia> using Flux
julia> Wxh = randn(5, 10)
5ร10 Array{Float64,2}:
-0.197167 0.0931036 -1.13283 โฆ 0.426711 1.5678 0.488363
-1.19948 -1.05618 1.057 -1.85708 2.05188 -0.732148
-0.848823 0.147774 1.66139 -0.777346 -0.0650354 0.36015
-0.380701 0.737349 0.426964 0.694122 -1.46597 -1.00572
-0.789044 -0.374745 -0.996698 0.505453 -0.117276 1.35148
julia> Whh = randn(5, 5)
5ร5 Array{Float64,2}:
-1.12946 -0.523065 0.0547692 -0.305124 -0.105809
-0.195351 0.588007 0.616959 0.779213 -0.145329
-0.265139 -0.535485 -0.300887 2.13263 -1.53089
-0.0537235 -1.47912 -0.883858 0.993426 -0.354738
0.486817 0.170843 0.0440353 0.177502 0.730423
julia> b = randn(5)
5-element Array{Float64,1}:
0.982592
-0.724775
0.118081
0.140369
-1.07578
julia> function rnn(h, x)
h = tanh.(Wxh * x .+ Whh * h .+ b)
return h, h
end
rnn (generic function with 1 method)
julia> x = rand(10) # ๋๋ฏธ ๋ฐ์ดํฐ
10-element Array{Float64,1}:
0.312436
0.384043
0.972045
0.194086
0.496317
0.654925
0.0311892
0.494105
0.338846
0.204689
julia> h = rand(5) # ์ด๊ธฐ ํ๋ ์ํ
5-element Array{Float64,1}:
0.861124
0.994686
0.560054
0.371721
0.159454
julia> h, y = rnn(h, x)
([-0.963817, -0.198195, 0.903936, -0.686608, -0.839093], [-0.963817, -0.198195, 0.903936, -0.686608, -0.839093])
๋ง์ง๋ง rnn
์ ์ข ๋ ๋๋ ค๋ณด๋ฉด, ์ถ๋ ฅ y
๋ ์
๋ ฅ x
๊ฐ ๊ฐ์๋ฐ๋ ์กฐ๊ธ์ฉ ๋ฐ๋๋ ๊ฒ์ ์ ์ ์๋ค.
julia> h, y = rnn(h, x)
([0.812906, -0.767065, 0.945139, 0.0198447, -0.996763], [0.812906, -0.767065, 0.945139, 0.0198447, -0.996763])
julia> h, y = rnn(h, x)
([-0.647084, -0.799032, 0.997557, 0.902798, -0.984697], [-0.647084, -0.799032, 0.997557, 0.902798, -0.984697])
์์ ์ธ๊ธํ rnn
ํจ์๋ ๋ช
์์ ์ผ๋ก ์ํ๋ฅผ ๊ด๋ฆฌํ๋ ๊ธฐ์ต ์ธํฌ(cells) ์ด๋ค. ๋ค์ํ ๊ธฐ์ต ์ธํฌ๊ฐ ์กด์ฌํ๋ฉฐ ๋ ์ด์ด ์ฐธ์กฐ์ ๊ด๋ จ ๋ด์ฉ์ด ์๋ค. ์์ ์์ ๋ ๋ค์๊ณผ ๊ฐ์ด ๋ฐ๊ฟ ์ ์๋ค:
julia> using Flux
julia> rnn2 = Flux.RNNCell(10, 5)
RNNCell(10, 5, tanh)
julia> x = rand(10) # ๋๋ฏธ ๋ฐ์ดํฐ
10-element Array{Float64,1}:
0.142406
0.944597
0.973233
0.434782
0.715639
0.763562
0.280661
0.293604
0.496457
0.173372
julia> h = rand(5) # ์ด๊ธฐ ํ๋ ์ํ
5-element Array{Float64,1}:
0.602545
0.998396
0.558707
0.637564
0.0313308
julia> h, y = rnn2(h, x)
(param([-0.160217, -0.741263, 0.048164, 0.963063, 0.0301785]), param([-0.160217, -0.741263, 0.048164, 0.963063, 0.0301785]))
์ํ๋ฅผ ๊ฐ๋ ๋ชจ๋ธ
๋๋ถ๋ถ์ ๊ฒฝ์ฐ, ํ๋ ์ํ๋ฅผ ์ง์ ๊ด๋ฆฌํ๋ ๊ฑฐ๋ ๊ท์ฐฎ์ผ๋๊น ๋ชจ๋ธ์ด ์ํ๋ฅผ ๊ฐ๊ฒ๋ ์ฒ๋ฆฌํ ์ ์๋ค. Flux๋ Recur
๋ํผ๋ฅผ ์ ๊ณตํ๋ค.
julia> x = rand(10)
10-element Array{Float64,1}:
0.165593
0.502313
0.120926
0.505827
0.917068
0.557163
0.688472
0.791826
0.0838632
0.709302
julia> h = rand(5)
5-element Array{Float64,1}:
0.40008
0.48858
0.551568
0.0688404
0.0583865
julia> m = Flux.Recur(rnn, h)
Recur(rnn)
julia> y = m(x)
5-element Array{Float64,1}:
0.963414
-0.999974
0.739107
0.976241
0.986023
Recur
๋ํผ๋ m.state
ํ๋์ ์ํ๋ฅผ ์ ์ฅํ๋ค.
RNN(10, 5)
์์ฑ์๋ฅผ ์ฌ์ฉํ๋ฉด - RNNCell
๊ณผ ๋์ํ๋ - ๋ค์๊ณผ ๊ฐ์ด ์ด๊ฑฐ๋ ๋จ์ํ ๋ํผ ์
์ด๋ค.
julia> RNN(10, 5)
Recur(RNNCell(10, 5, tanh))
์ํ์ค(Sequences, ์ฐ์๋๋ ๊ฐ)
์ข
์ข
๊ฐ๋ณ์ ์ธ x
๋ณด๋ค๋ ์ฐ์๋๋ ์
๋ ฅ์ ๋ค๋ฃจ๊ธธ ์ํ๋ค.
julia> seq = [rand(10) for i = 1:10]
10-element Array{Array{Float64,1},1}:
[0.443911, 0.955247, 0.980153, 0.313181, 0.0426581, 0.354755, 0.113961, 0.222873, 0.865114, 0.14094]
[0.50466, 0.0204917, 0.890547, 0.574102, 0.301098, 0.944295, 0.95414, 0.36809, 0.341546, 0.474998]
[0.474114, 0.152628, 0.364967, 0.601978, 0.212361, 0.66016, 0.12101, 0.944988, 0.417781, 0.715282]
[0.0776375, 0.843099, 0.000618674, 0.352273, 0.977611, 0.801756, 0.550702, 0.311638, 0.285711, 0.0856441]
[0.603498, 0.863035, 0.89494, 0.506224, 0.840984, 0.13453, 0.43549, 0.216554, 0.361081, 0.0965758]
[0.236062, 0.407028, 0.357854, 0.875694, 0.0468227, 0.786622, 0.616748, 0.791976, 0.800668, 0.147169]
[0.739452, 0.38329, 0.961215, 0.113691, 0.381309, 0.57526, 0.0170709, 0.403656, 0.445509, 0.051497]
[0.956629, 0.624735, 0.14811, 0.202354, 0.484018, 0.250409, 0.0352729, 0.809209, 0.831828, 0.826355]
[0.388553, 0.42596, 0.736068, 0.454156, 0.626974, 0.641246, 0.444018, 0.768584, 0.118879, 0.416568]
[0.307721, 0.176393, 0.371934, 0.714272, 0.886859, 0.333667, 0.721609, 0.975586, 0.59609, 0.771424]
Recur
๋ก ๋ชจ๋ธ์ ์ํ์ค์ ๊ฐ ํญ๋ชฉ๋ง๋ค ์ฝ๊ฒ ์ ์ฉํ ์ ์๋ค:
julia> m.(seq) # 5-์๋ฌ๋จผํธ ๋ฒกํฐ์ ๋ฆฌ์คํธ๋ฅผ ๋๋ ค์ค๋ค
10-element Array{Array{Float64,1},1}:
[0.958516, -0.996974, 0.640934, -0.440203, 0.991754]
[0.998417, -0.998238, 0.988128, 0.924522, 0.999099]
[0.943455, -0.999939, 0.94332, 0.638572, 0.999795]
[0.997841, -0.999912, 0.414106, 0.705974, 0.999871]
[0.9896, -0.96634, 0.903348, 0.805409, 0.949429]
[0.990047, -0.999849, 0.991448, 0.950895, 0.999938]
[0.980617, -0.988072, 0.978565, -0.785643, 0.985682]
[0.98617, -0.99938, -0.791134, 0.603178, 0.0937938]
[0.946547, -0.893022, 0.914559, 0.999905, 0.984556]
[0.989439, -0.999979, 0.964896, 0.978421, 0.999834]
๋ ์ปค๋ค๋ ๋ชจ๋ธ์ ์ํ ๋ ์ด์ด(recurrent layers)๋ฅผ ์ฐ์์ (chain)์ผ๋ก ์ฐ๊ฒฐ ํ ์ ์๋ค.
julia> m = Chain(LSTM(10, 15), Dense(15, 5))
Chain(Recur(LSTMCell(10, 60)), Dense(15, 5))
julia> m.(seq)
10-element Array{TrackedArray{โฆ,Array{Float64,1}},1}:
param([0.0779735, 0.0534096, -0.0245852, -0.0699291, -0.00650743])
param([0.203825, -0.0307184, -0.0940759, -0.100437, 0.0523315])
param([0.21071, -0.19635, -0.106985, -0.185204, 0.132647])
param([0.314643, -0.205525, -0.00144219, -0.165195, 0.197256])
param([0.351024, -0.116196, 0.00489051, -0.255343, 0.209503])
param([0.370406, -0.125797, -0.0506301, -0.253045, 0.179001])
param([0.349787, -0.091392, -0.0699977, -0.249944, 0.197391])
param([0.370064, -0.21158, -0.00144108, -0.337597, 0.24153])
param([0.396285, -0.240793, -0.0263459, -0.358695, 0.260678])
param([0.464372, -0.316526, -0.0295575, -0.352548, 0.251627])
๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ๊ธฐ๋ก ์๋ผ๋ด๊ธฐ(Truncating Gradients)
๊ธฐ๋ณธ์ ์ผ๋ก, ์ํ ๋ ์ด์ด์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ๋ ๊ฒ์ ์ ์ฒด ๊ธฐ๋ก(history)์ ๋ดํฌํ๋ค. ์๋ฅผ ๋ค์ด 100๊ฐ์ ์
๋ ฅ์ ๊ฐ์ง ๋ชจ๋ธ์ ์คํํ ๋, back!
์ ํ๋ฉด 100๊ฐ์ ๋ํ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ๋ค. ๊ทธ๋ฌ๊ณ ๋ค๋ฅธ 10๊ฐ์ ์
๋ ฅ์ ๋ ๊ณ์ฐํ๋ค๋ฉด 110๊ฐ์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํด์ผ ํ๋ค - ์ด๊ฑฐ๋ ๋์ ๋๋ฏ๋ก ๋น ๋ฅด๊ฒ ์ฐ์ฐ ๋น์ฉ์ด ์ฆ๊ฐํ๋ค.
์ด๊ฑฐ๋ฅผ ๋ง๋ ๋ฐฉ๋ฒ์ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ๊ธฐ๋ก์ ์๋ผ๋ด์ด(truncate) ์ง์์ฃผ๋ ๊ฒ์ด๋ค.
julia> Flux.truncate!(m)
truncate!
์ ํธ์ถํ๋ฉด ๊น๋์ด ์ฒญ์ํด ์ค๋ค. ๊ทธ๋์ ๋ ๋ง์ ์
๋ ฅ์ ๋ชจ๋ธ์ ์คํํด๋ ๋น์ผ ๊ธฐ์ธ๊ธฐ ์ฐ์ฐ ์์ด ํด๋ผ ์ ์๋ค.
truncate!
๋ ์ฌ๋ฌ ๊ฐ์ ์ปค๋ค๋ ์ํ์ค ๋ฉ์ด๋ฆฌ๋ฅผ ๋ค๋ฃฐ ๋ ์ ์ฉํ์ง๋ง, ์๋ก ๋
๋ฆฝ์ ์ธ ์ํ์ค๋ค์ ๋ค๋ฃจ๊ณ ์ถ์ ๋๋ ์๋ค. ๊ทธ ๊ฒฝ์ฐ ํ๋ ์ํ๋ ์๋ ๊ฐ์ผ๋ก ์์ ํ ์ด๊ธฐํ ๋์ด ๋์ ๋ ์ ๋ณด๋ฅผ ๋ฒ๋ฆฐ๋ค. ๊ทธ๋ ๊ฒ ํ๊ณ ์ถ์ผ๋ฉด reset!
์ ํด ์ฃผ์.
julia> Flux.reset!(m)