Recurrence

์ˆœํ™˜ ๋ชจ๋ธ(Recurrent Models)

๊ธฐ์–ต ์„ธํฌ(Recurrent Cells, ์ˆœํ™˜ ์…€, ๋‡Œ๋ฅผ ๋ชจ๋ฐฉํ•œ ๊ฑฐ)

๋‹จ์ˆœํ•œ ํ”ผ๋“œํฌ์›Œ๋“œ(feedforward, ์‚ฌ์ดํด(cycle)์ด๋‚˜ ๋ฃจํ”„(loop)๊ฐ€ ์—†๋Š” ๋„คํŠธ์›Œํฌ) ๊ฒฝ์šฐ, ๋ชจ๋ธ m์€ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ์ž…๋ ฅ xแตข์— ๋Œ€ํ•œ yแตข๋ฅผ ์˜ˆ์ธกํ•˜๋Š” ๊ฐ„๋‹จํ•œ ํ•จ์ˆ˜๋‹ค. (์˜ˆ๋ฅผ ๋“ค์–ด, x๋ฅผ MNIST ์ˆซ์ž๋ผ ์น˜๋ฉด y๋Š” ๊ทธ๊ฒƒ์„ ๋ถ„๋ฅ˜ํ•œ ์ˆซ์ž.) ์˜ˆ์ธก์€ ์„œ๋กœ ์™„์ „ํžˆ ๋…๋ฆฝ์ ์ด๋ฉฐ x๊ฐ€ ๊ฐ™์œผ๋ฉด y๋„ ์–ธ์ œ๋‚˜ ๋™์ผํ•˜๋‹ค.

yโ‚ = f(xโ‚)
yโ‚‚ = f(xโ‚‚)
yโ‚ƒ = f(xโ‚ƒ)
# ...
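The stateless case can be made concrete with a minimal plain-Julia sketch. It assumes a hypothetical single dense layer `f` with fixed random weights. Because `f` reads nothing except its argument, calling it again on the same input reproduces the earlier output exactly.

```julia
using Random

# Hypothetical stateless model: one dense layer with fixed weights.
rng = MersenneTwister(42)
W = randn(rng, 3, 4)
b = randn(rng, 3)
f(x) = tanh.(W * x .+ b)   # the output depends only on the current x

x1 = rand(rng, 4)
x2 = rand(rng, 4)

y1 = f(x1)
y2 = f(x2)   # this call leaves no trace that could affect later calls

# Re-running on x1 gives exactly the same prediction as before.
y1_again = f(x1)
```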

์ˆœํ™˜ ๋„คํŠธ์›Œํฌ๋Š” ํžˆ๋“  ์ƒํƒœ(hidden state, ์ˆจ๊ฒจ๋…ผ ์ƒํƒœ)๊ฐ€ ์กด์žฌํ•˜๋ฉฐ ๋ชจ๋ธ์„ ๋Œ๋ฆด ๋•Œ ๋งค๋ฒˆ ๊ทธ ์ƒํƒœ๋ฅผ ๋‹ค์Œ์œผ๋กœ ๋„˜๊ธด๋‹ค. ๊ทธ๋ž˜์„œ ์ด ๋ชจ๋ธ์€ ๊ทธ๋•Œ๋งˆ๋‹ค ์ด์ „ h๋ฅผ ์ž…๋ ฅ์œผ๋กœ ๋ฐ›๊ณ , ์ƒˆ๋กœ์šด h๋ฅผ ์ถœ๋ ฅ์œผ๋กœ ๋‚ด ๋†“๋Š”๋‹ค.

h = # ... initial state ...
h, y₁ = f(h, x₁)
h, y₂ = f(h, x₂)
h, y₃ = f(h, x₃)
# ...

h์— ์ €์žฅํ•œ ์ •๋ณด๋Š” ๋‹ค์Œ๋ฒˆ ์˜ˆ์ธก์„ ์œ„ํ•ด ์œ ์ง€ํ•˜๊ณ , ๊ทธ๋ž˜์„œ ๋งˆ์น˜ ํ•จ์ˆ˜์˜ ๋ฉ”๋ชจ๋ฆฌ ๊ฐ™์€ ์—ญํ• ์„ ํ•˜๊ฒŒ ํ•œ๋‹ค. ์ด๊ฒƒ์€ x์— ๋Œ€ํ•œ ์˜ˆ์ธก์ด ์ด์ „๊นŒ์ง€ ๋ชจ๋ธ์— ์ฃผ์ž…ํ•œ ๋ชจ๋“  ์ž…๋ ฅ์œผ๋กœ๋ถ€ํ„ฐ ์˜ํ–ฅ์„ ๋ฐ›๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค.

(This matters. Suppose each x is one word of a sentence. Given the English word "bank", the model should read it as a riverbank if the previous input was "river", and as a financial bank if the previous input was "investment".)
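The "bank" example can be mimicked with a toy sketch. It assumes a hypothetical three-word vocabulary encoded as one-hot vectors, random untrained weights, and a hypothetical `rnn_step` function. The network is not trained to disambiguate anything. The sketch only shows that the state after reading "bank" differs depending on the word before it, which is the information a trained model would need.

```julia
using Random

# Hypothetical toy vocabulary; each word becomes a one-hot vector.
vocab = ["river", "investment", "bank"]
onehot(w) = Float64.(vocab .== w)

# Hypothetical untrained weights: 3-dimensional inputs, 4-dimensional state.
rng = MersenneTwister(7)
Wxh, Whh, b = randn(rng, 4, 3), randn(rng, 4, 4), randn(rng, 4)
rnn_step(h, x) = tanh.(Wxh * x .+ Whh * h .+ b)

# Read a sentence word by word and return the final hidden state.
function encode(words)
    h = zeros(4)
    for w in words
        h = rnn_step(h, onehot(w))
    end
    return h
end

h_river = encode(["river", "bank"])
h_money = encode(["investment", "bank"])
# The last word is "bank" in both cases, but the states differ because
# they still carry the preceding word.
```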

Flux์˜ RNN ์ง€์›์€ ์ˆ˜ํ•™์  ๊ด€์ ์„ ์ง€๋‹ˆ๊ณ  ์žˆ๋‹ค. ๊ฐ€์žฅ ๊ธฐ๋ณธ์ด ๋˜๋Š” RNN์€ ํ‘œ์ค€ "Dense" ๋ ˆ์ด์–ด๋ฅผ ๋”ฐ๋ฅด๊ณ , ๊ทธ ์ถœ๋ ฅ์€ ํžˆ๋“  ์ƒํƒœ์ด๋‹ค.

julia> using Flux

julia> Wxh = randn(5, 10)
5×10 Array{Float64,2}:
 -0.197167   0.0931036  -1.13283   …   0.426711   1.5678      0.488363
 -1.19948   -1.05618     1.057        -1.85708    2.05188    -0.732148
 -0.848823   0.147774    1.66139      -0.777346  -0.0650354   0.36015
 -0.380701   0.737349    0.426964      0.694122  -1.46597    -1.00572
 -0.789044  -0.374745   -0.996698      0.505453  -0.117276    1.35148

julia> Whh = randn(5, 5)
5×5 Array{Float64,2}:
 -1.12946    -0.523065   0.0547692  -0.305124  -0.105809
 -0.195351    0.588007   0.616959    0.779213  -0.145329
 -0.265139   -0.535485  -0.300887    2.13263   -1.53089
 -0.0537235  -1.47912   -0.883858    0.993426  -0.354738
  0.486817    0.170843   0.0440353   0.177502   0.730423

julia> b   = randn(5)
5-element Array{Float64,1}:
  0.982592
 -0.724775
  0.118081
  0.140369
 -1.07578

julia> function rnn(h, x)
         h = tanh.(Wxh * x .+ Whh * h .+ b)
         return h, h
       end
rnn (generic function with 1 method)

julia> x = rand(10) # dummy data
10-element Array{Float64,1}:
 0.312436
 0.384043
 0.972045
 0.194086
 0.496317
 0.654925
 0.0311892
 0.494105
 0.338846
 0.204689

julia> h = rand(5)  # initial hidden state
5-element Array{Float64,1}:
 0.861124
 0.994686
 0.560054
 0.371721
 0.159454

julia> h, y = rnn(h, x)
([-0.963817, -0.198195, 0.903936, -0.686608, -0.839093], [-0.963817, -0.198195, 0.903936, -0.686608, -0.839093])

๋งˆ์ง€๋ง‰ rnn์„ ์ข€ ๋” ๋Œ๋ ค๋ณด๋ฉด, ์ถœ๋ ฅ y๋Š” ์ž…๋ ฅ x๊ฐ€ ๊ฐ™์€๋ฐ๋„ ์กฐ๊ธˆ์”ฉ ๋ฐ”๋€Œ๋Š” ๊ฒƒ์„ ์•Œ ์ˆ˜ ์žˆ๋‹ค.

julia> h, y = rnn(h, x)
([0.812906, -0.767065, 0.945139, 0.0198447, -0.996763], [0.812906, -0.767065, 0.945139, 0.0198447, -0.996763])

julia> h, y = rnn(h, x)
([-0.647084, -0.799032, 0.997557, 0.902798, -0.984697], [-0.647084, -0.799032, 0.997557, 0.902798, -0.984697])

A function like rnn above, which explicitly manages its state, is called a recurrent cell. Flux provides several recurrent cells, which are documented in the layer reference. The hand-written example above can be replaced with:

julia> using Flux

julia> rnn2 = Flux.RNNCell(10, 5)
RNNCell(10, 5, tanh)

julia> x = rand(10) # dummy data
10-element Array{Float64,1}:
 0.142406
 0.944597
 0.973233
 0.434782
 0.715639
 0.763562
 0.280661
 0.293604
 0.496457
 0.173372

julia> h = rand(5)  # initial hidden state
5-element Array{Float64,1}:
 0.602545
 0.998396
 0.558707
 0.637564
 0.0313308

julia> h, y = rnn2(h, x)
(param([-0.160217, -0.741263, 0.048164, 0.963063, 0.0301785]), param([-0.160217, -0.741263, 0.048164, 0.963063, 0.0301785]))
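Both `rnn` and `rnn2` share one interface: a callable that takes `(h, x)` and returns `(new h, y)`. The sketch below shows how such a cell could be written in plain Julia as a callable struct that owns its parameters but not its state. `MyRNNCell` is hypothetical and is not Flux's actual `RNNCell` source.

```julia
using Random

# Hypothetical recurrent cell: holds parameters, receives state from the caller.
struct MyRNNCell
    Wxh::Matrix{Float64}
    Whh::Matrix{Float64}
    b::Vector{Float64}
end

# Convenience constructor mirroring the (inputs, outputs) argument order.
MyRNNCell(nin::Int, nout::Int; rng = Random.default_rng()) =
    MyRNNCell(randn(rng, nout, nin), randn(rng, nout, nout), randn(rng, nout))

# Calling the cell performs one step: (old state, input) -> (new state, output).
function (c::MyRNNCell)(h, x)
    h = tanh.(c.Wxh * x .+ c.Whh * h .+ c.b)
    return h, h
end

cell = MyRNNCell(10, 5; rng = MersenneTwister(3))
h, y = cell(rand(5), rand(10))
```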

์ƒํƒœ๋ฅผ ๊ฐ–๋Š” ๋ชจ๋ธ

๋Œ€๋ถ€๋ถ„์˜ ๊ฒฝ์šฐ, ํžˆ๋“  ์ƒํƒœ๋ฅผ ์ง์ ‘ ๊ด€๋ฆฌํ•˜๋Š” ๊ฑฐ๋Š” ๊ท€์ฐฎ์œผ๋‹ˆ๊นŒ ๋ชจ๋ธ์ด ์ƒํƒœ๋ฅผ ๊ฐ–๊ฒŒ๋” ์ฒ˜๋ฆฌํ•  ์ˆ˜ ์žˆ๋‹ค. Flux๋Š” Recur ๋ž˜ํผ๋ฅผ ์ œ๊ณตํ•œ๋‹ค.

julia> x = rand(10)
10-element Array{Float64,1}:
 0.165593
 0.502313
 0.120926
 0.505827
 0.917068
 0.557163
 0.688472
 0.791826
 0.0838632
 0.709302

julia> h = rand(5)
5-element Array{Float64,1}:
 0.40008
 0.48858
 0.551568
 0.0688404
 0.0583865

julia> m = Flux.Recur(rnn, h)
Recur(rnn)

julia> y = m(x)
5-element Array{Float64,1}:
  0.963414
 -0.999974
  0.739107
  0.976241
  0.986023

Recur ๋ž˜ํผ๋Š” m.state ํ•„๋“œ์— ์ƒํƒœ๋ฅผ ์ €์žฅํ•œ๋‹ค.

RNN(10, 5) ์ƒ์„ฑ์ž๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด - RNNCell๊ณผ ๋Œ€์‘ํ•˜๋Š” - ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ด๊ฑฐ๋Š” ๋‹จ์ˆœํžˆ ๋ž˜ํผ ์…€์ด๋‹ค.

julia> RNN(10, 5)
Recur(RNNCell(10, 5, tanh))

Sequences

์ข…์ข… ๊ฐœ๋ณ„์ ์ธ x ๋ณด๋‹ค๋Š” ์—ฐ์†๋˜๋Š” ์ž…๋ ฅ์„ ๋‹ค๋ฃจ๊ธธ ์›ํ•œ๋‹ค.

julia> seq = [rand(10) for i = 1:10]
10-element Array{Array{Float64,1},1}:
 [0.443911, 0.955247, 0.980153, 0.313181, 0.0426581, 0.354755, 0.113961, 0.222873, 0.865114, 0.14094]
 [0.50466, 0.0204917, 0.890547, 0.574102, 0.301098, 0.944295, 0.95414, 0.36809, 0.341546, 0.474998]
 [0.474114, 0.152628, 0.364967, 0.601978, 0.212361, 0.66016, 0.12101, 0.944988, 0.417781, 0.715282]
 [0.0776375, 0.843099, 0.000618674, 0.352273, 0.977611, 0.801756, 0.550702, 0.311638, 0.285711, 0.0856441]
 [0.603498, 0.863035, 0.89494, 0.506224, 0.840984, 0.13453, 0.43549, 0.216554, 0.361081, 0.0965758]
 [0.236062, 0.407028, 0.357854, 0.875694, 0.0468227, 0.786622, 0.616748, 0.791976, 0.800668, 0.147169]
 [0.739452, 0.38329, 0.961215, 0.113691, 0.381309, 0.57526, 0.0170709, 0.403656, 0.445509, 0.051497]
 [0.956629, 0.624735, 0.14811, 0.202354, 0.484018, 0.250409, 0.0352729, 0.809209, 0.831828, 0.826355]
 [0.388553, 0.42596, 0.736068, 0.454156, 0.626974, 0.641246, 0.444018, 0.768584, 0.118879, 0.416568]
 [0.307721, 0.176393, 0.371934, 0.714272, 0.886859, 0.333667, 0.721609, 0.975586, 0.59609, 0.771424]

Recur๋กœ ๋ชจ๋ธ์„ ์‹œํ€€์Šค์˜ ๊ฐ ํ•ญ๋ชฉ๋งˆ๋‹ค ์‰ฝ๊ฒŒ ์ ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค:

julia> m.(seq) # returns a list of 5-element vectors
10-element Array{Array{Float64,1},1}:
 [0.958516, -0.996974, 0.640934, -0.440203, 0.991754]
 [0.998417, -0.998238, 0.988128, 0.924522, 0.999099]
 [0.943455, -0.999939, 0.94332, 0.638572, 0.999795]
 [0.997841, -0.999912, 0.414106, 0.705974, 0.999871]
 [0.9896, -0.96634, 0.903348, 0.805409, 0.949429]
 [0.990047, -0.999849, 0.991448, 0.950895, 0.999938]
 [0.980617, -0.988072, 0.978565, -0.785643, 0.985682]
 [0.98617, -0.99938, -0.791134, 0.603178, 0.0937938]
 [0.946547, -0.893022, 0.914559, 0.999905, 0.984556]
 [0.989439, -0.999979, 0.964896, 0.978421, 0.999834]
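Because the model is stateful, the order in which it visits the sequence matters. Julia's documentation does not appear to promise an evaluation order for broadcasting, so the sketch below uses an explicit comprehension, which walks `seq` front to back. The model here is a hypothetical closure over a mutable hidden state, written in plain Julia without Flux. It also shows that a second pass over the same sequence gives different outputs, because it starts from the state left by the first pass.

```julia
using Random

# Hypothetical stateful model: a closure that updates its captured state.
function make_model(rng)
    Wxh, Whh, b = randn(rng, 5, 10), randn(rng, 5, 5), randn(rng, 5)
    state = zeros(5)
    return x -> (state = tanh.(Wxh * x .+ Whh * state .+ b))
end

rng = MersenneTwister(11)
m = make_model(rng)
seq = [rand(rng, 10) for _ in 1:10]

# The comprehension visits seq in order, so each call sees the state left
# behind by the previous element.
ys = [m(x) for x in seq]

# A second pass starts from the final state of the first pass.
ys_again = [m(x) for x in seq]
```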

This also works when recurrent layers are chained into a larger model.

julia> m = Chain(LSTM(10, 15), Dense(15, 5))
Chain(Recur(LSTMCell(10, 60)), Dense(15, 5))

julia> m.(seq)
10-element Array{TrackedArray{…,Array{Float64,1}},1}:
 param([0.0779735, 0.0534096, -0.0245852, -0.0699291, -0.00650743])
 param([0.203825, -0.0307184, -0.0940759, -0.100437, 0.0523315])
 param([0.21071, -0.19635, -0.106985, -0.185204, 0.132647])
 param([0.314643, -0.205525, -0.00144219, -0.165195, 0.197256])
 param([0.351024, -0.116196, 0.00489051, -0.255343, 0.209503])
 param([0.370406, -0.125797, -0.0506301, -0.253045, 0.179001])
 param([0.349787, -0.091392, -0.0699977, -0.249944, 0.197391])
 param([0.370064, -0.21158, -0.00144108, -0.337597, 0.24153])
 param([0.396285, -0.240793, -0.0263459, -0.358695, 0.260678])
 param([0.464372, -0.316526, -0.0295575, -0.352548, 0.251627])
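Chaining works because the stateful layer updates its own state while the stateless layer simply transforms whatever it receives. The plain-Julia sketch below composes layers with `∘` to play the role of Chain. It substitutes a simple tanh recurrence for the LSTM, and `make_recurrent` and `make_dense` are hypothetical helpers.

```julia
using Random

# Hypothetical stateful recurrent layer: a closure over its hidden state.
function make_recurrent(rng, nin, nout)
    Wxh, Whh, b = randn(rng, nout, nin), randn(rng, nout, nout), randn(rng, nout)
    h = zeros(nout)
    return x -> (h = tanh.(Wxh * x .+ Whh * h .+ b))
end

# Hypothetical stateless dense layer.
function make_dense(rng, nin, nout)
    W, b = randn(rng, nout, nin), randn(rng, nout)
    return x -> W * x .+ b
end

rng = MersenneTwister(2)
# Like Chain(recurrent, dense): (f ∘ g)(x) == f(g(x)).
model = make_dense(rng, 15, 5) ∘ make_recurrent(rng, 10, 15)

seq = [rand(rng, 10) for _ in 1:4]
ys = [model(x) for x in seq]

# Only the inner layer carries state, yet the whole chain now answers
# differently for the first input than it did at step one.
y_repeat = model(seq[1])
```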

Truncating Gradients

๊ธฐ๋ณธ์ ์œผ๋กœ, ์ˆœํ™˜ ๋ ˆ์ด์–ด์˜ ๊ธฐ์šธ๊ธฐ๋ฅผ ๊ณ„์‚ฐํ•˜๋Š” ๊ฒƒ์€ ์ „์ฒด ๊ธฐ๋ก(history)์„ ๋‚ดํฌํ•œ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด 100๊ฐœ์˜ ์ž…๋ ฅ์„ ๊ฐ€์ง„ ๋ชจ๋ธ์„ ์‹คํ–‰ํ•  ๋•Œ, back!์„ ํ•˜๋ฉด 100๊ฐœ์— ๋Œ€ํ•œ ๊ธฐ์šธ๊ธฐ๋ฅผ ๊ณ„์‚ฐํ•œ๋‹ค. ๊ทธ๋Ÿฌ๊ณ  ๋‹ค๋ฅธ 10๊ฐœ์˜ ์ž…๋ ฅ์„ ๋” ๊ณ„์‚ฐํ•œ๋‹ค๋ฉด 110๊ฐœ์˜ ๊ธฐ์šธ๊ธฐ๋ฅผ ๊ณ„์‚ฐํ•ด์•ผ ํ•œ๋‹ค - ์ด๊ฑฐ๋Š” ๋ˆ„์ ๋˜๋ฏ€๋กœ ๋น ๋ฅด๊ฒŒ ์—ฐ์‚ฐ ๋น„์šฉ์ด ์ฆ๊ฐ€ํ•œ๋‹ค.

To avoid this, we can truncate the gradient calculation, forgetting the history:

julia> Flux.truncate!(m)

truncate!์„ ํ˜ธ์ถœํ•˜๋ฉด ๊น”๋”์ด ์ฒญ์†Œํ•ด ์ค€๋‹ค. ๊ทธ๋ž˜์„œ ๋” ๋งŽ์€ ์ž…๋ ฅ์˜ ๋ชจ๋ธ์„ ์‹คํ–‰ํ•ด๋„ ๋น„์‹ผ ๊ธฐ์šธ๊ธฐ ์—ฐ์‚ฐ ์—†์ด ํ•ด๋‚ผ ์ˆ˜ ์žˆ๋‹ค.

truncate! makes sense when you are working with multiple chunks of one large sequence, but sometimes you want to work with a set of independent sequences instead. In that case the hidden state should be completely reset to its original value, throwing away any accumulated information. reset! does this:

julia> Flux.reset!(m)
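The difference between truncate! and reset! can be modeled with a toy sketch. It assumes a hypothetical `TrackedRecur` type whose `history` vector is a simple stand-in for the recorded steps that a gradient would have to flow back through. This is bookkeeping for illustration only, not how Flux tracks gradients. `my_truncate!` forgets the history but keeps the current state. `my_reset!` forgets both and restores the initial state. The 100 and 110 step counts match the example in the text.

```julia
using Random

# Hypothetical stateful layer that also records its "history" of steps.
mutable struct TrackedRecur
    Wxh::Matrix{Float64}
    Whh::Matrix{Float64}
    b::Vector{Float64}
    init::Vector{Float64}              # initial state
    state::Vector{Float64}             # current state
    history::Vector{Vector{Float64}}   # inputs since the last truncation
end

TrackedRecur(rng, nin, nout) = TrackedRecur(randn(rng, nout, nin), randn(rng, nout, nout),
                                            randn(rng, nout), zeros(nout), zeros(nout),
                                            Vector{Float64}[])

function (m::TrackedRecur)(x)
    push!(m.history, x)   # one more step a backward pass would traverse
    m.state = tanh.(m.Wxh * x .+ m.Whh * m.state .+ m.b)
    return m.state
end

# Keep the state's value but forget how it was computed.
my_truncate!(m::TrackedRecur) = (empty!(m.history); m)

# Forget the history and return the state to its initial value.
function my_reset!(m::TrackedRecur)
    empty!(m.history)
    m.state = copy(m.init)
    return m
end

rng = MersenneTwister(4)
m = TrackedRecur(rng, 10, 5)
for _ in 1:100; m(rand(rng, 10)); end
steps_after_100 = length(m.history)
for _ in 1:10; m(rand(rng, 10)); end
steps_after_110 = length(m.history)

h_before = copy(m.state)
my_truncate!(m)
state_kept = m.state == h_before

my_reset!(m)
```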