์ต์ ํ ํจ์(Optimisers)
๊ฐ๋จํ ๋ฆฌ๋์ด ๋ฆฌ๊ทธ๋ ์
์์ ์ฐ๋ฆฌ๋ ๋๋ฏธ ๋ฐ์ดํฐ๋ฅผ ๋ง๋ ํ, ์์ค(loss)์ ๊ณ์ฐํ๊ณ ์ญ์ ํ(backpropagate) ํ์ฌ ํ๋ผ๋ฏธํฐ W
์ b
์ ๊ธฐ์ธ๊ธฐ๋ฅผ ๊ณ์ฐํ์๋ค.
julia> using Flux
julia> W = param(rand(2, 5))
Tracked 2ร5 Array{Float64,2}:
0.215021 0.22422 0.352664 0.11115 0.040711
0.180933 0.769257 0.361652 0.783197 0.545495
julia> b = param(rand(2))
Tracked 2-element Array{Float64,1}:
0.205216
0.150938
julia> predict(x) = W*x .+ b
predict (generic function with 1 method)
julia> loss(x, y) = sum((predict(x) .- y).^2)
loss (generic function with 1 method)
julia> x, y = rand(5), rand(2) # ๋๋ฏธ ๋ฐ์ดํฐ
([0.153473, 0.927019, 0.40597, 0.783872, 0.392236], [0.261727, 0.00917161])
julia> l = loss(x, y) # ~ 3
3.6352060699201565 (tracked)
julia> Flux.back!(l)
๊ธฐ์ธ๊ธฐ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธ ํ๊ณ ์ ํ๋ค. ์์ค์ ์ค์ด๋ ค๊ณ ๋ง์ด๋ค. ์ฌ๊ธฐ์ ํ๊ฐ์ง ๋ฐฉ๋ฒ์:
function update()
ฮท = 0.1 # ํ์ตํ๋ ์๋(Learning Rate)
for p in (W, b)
p.data .-= ฮท .* p.grad # ์
๋ฐ์ดํธ ์ ์ฉ
p.grad .= 0 # ๊ธฐ์ธ๊ธฐ 0์ผ๋ก clear
end
end
update
๋ฅผ ํธ์ถํ๋ฉด ํ๋ผ๋ฏธํฐ W
์ b
๋ ๋ฐ๋๊ณ ์์ค(loss)์ ๋ด๋ ค๊ฐ๋ค.
๋๊ฐ์ง๋ ์ง๊ณ ๋์ด๊ฐ์: ๋ชจ๋ธ์์ ํ๋ จํ ํ๋ผ๋ฏธํฐ์ ๋ชฉ๋ก (์ฌ๊ธฐ์๋ [W, b]
), ๊ทธ๋ฆฌ๊ณ ์
๋ฐ์ดํธ ์งํ ์๋. ์ฌ๊ธฐ์์ ์
๋ฐ์ดํธ๋ ๊ฐ๋จํ gradient descent(๊ฒฝ์ฌ ํ๊ฐ, x .-= ฮท .* ฮ
) ์์ง๋ง, ๋ชจ๋ฉํ
(momentum)์ ์ถ๊ฐํ๋ ๊ฒ์ฒ๋ผ ๋ณด๋ค ์ด๋ ค์ด ๊ฒ๋ ํด๋ณด๊ณ ์ถ์ ๊ฒ์ด๋ค.
์ฌ๊ธฐ์ ๋ณ์๋ฅผ ์ป๋ ๊ฒ์ ์๋ฌด๊ฒ๋ ์๋์ง๋ง, ๋ ์ด์ด๋ฅผ ๋ณต์กํ๊ฒ ์๋๋ค๋ฉด ๊ณจ์น ์ข ์ํ ๊ฒ์ด๋ค.
julia> m = Chain(
Dense(10, 5, ฯ),
Dense(5, 2), softmax)
Chain(Dense(10, 5, NNlib.ฯ), Dense(5, 2), NNlib.softmax)
[m[1].W, m[1].b, ...]
์ด๋ ๊ฒ ์์ฑํ๋ ๊ฒ ๋์ , Flux์์ ์ ๊ณตํ๋ params(m)
ํจ์๋ฅผ ์ด์ฉํด ๋ชจ๋ธ์ ๋ชจ๋ ํ๋ผ๋ฏธํฐ์ ๋ชฉ๋ก์ ๊ตฌํ ๊ฒ์ด๋ค.
julia> opt = SGD([W, b], 0.1) # Gradient descent(๊ฒฝ์ฌ ํ๊ฐ)์ learning rate(ํ์ต ์๋) 0.1 ์ผ๋ก ํ๋ค
(::#71) (generic function with 1 method)
julia> opt() # `W`์ `b`๋ฅผ ๋ณ๊ฒฝํ๋ฉฐ ์
๋ฐ์ดํธ๋ฅผ ์ํํ๋ค
์ต์ ํ ํจ์๋ ํ๋ผ๋ฏธํฐ ๋ชฉ๋ก์ ๋ฐ์ ์์ update
์ ๊ฐ์ ํจ์๋ฅผ ๋๋ ค์ค๋ค. opt
๋ update
๋ฅผ ํ๋ จ ๋ฃจํ(training loop)์ ๋๊ฒจ์ค ์ ์๋๋ฐ, ๋งค๋ฒ ๋ฐ์ดํฐ์ ๋ฏธ๋-๋ฐฐ์น(mini-batch)๋ฅผ ํ ํ์ ์ต์ ํ๋ฅผ ์ํํ ๊ฒ์ด๋ค.
์ต์ ํ ํจ์ ์ฐธ๊ณ
๋ชจ๋ ์ต์ ํ ํจ์๋ ๋๊ฒจ๋ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ์ ๋ฐ์ดํธ ํ๋ ํจ์๋ฅผ ๋๋ ค์ค๋ค.
Flux.Optimise.Descent
โ Type.Descent(ฮท)
Classic gradient descent optimiser with learning rate ฮท
. For each parameter p
and its gradient ฮดp
, this runs p -= ฮท*ฮดp
.
Flux.Optimise.Momentum
โ Type.Momentum(params, ฮท = 0.01; ฯ = 0.9)
Gradient descent with learning rate ฮท
and momentum ฯ
.
Flux.Optimise.Nesterov
โ Type.Nesterov(eta, ฯ = 0.9)
Gradient descent with learning rate ฮท
and Nesterov momentum ฯ
.
Flux.Optimise.RMSProp
โ Type.RMSProp(ฮท = 0.001, ฯ = 0.9)
RMSProp optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks.
Flux.Optimise.ADAM
โ Type.ADAM(ฮท = 0.001, ฮฒ = (0.9, 0.999))
ADAM optimiser.
Flux.Optimise.AdaMax
โ Type.AdaMax(params, ฮท = 0.001; ฮฒ1 = 0.9, ฮฒ2 = 0.999, ฯต = 1e-08)
AdaMax optimiser. Variant of ADAM based on the โ-norm.
Flux.Optimise.ADAGrad
โ Type.ADAGrad(ฮท = 0.1; ฯต = 1e-8)
ADAGrad optimiser. Parameters don't need tuning.
Flux.Optimise.ADADelta
โ Type.ADADelta(ฯ = 0.9, ฯต = 1e-8)
ADADelta optimiser. Parameters don't need tuning.
Flux.Optimise.AMSGrad
โ Type.AMSGrad(ฮท = 0.001, ฮฒ = (0.9, 0.999))
AMSGrad optimiser. Parameters don't need tuning.
Flux.Optimise.NADAM
โ Type.NADAM(ฮท = 0.001, ฮฒ = (0.9, 0.999))
NADAM optimiser. Parameters don't need tuning.
Flux.Optimise.ADAMW
โ Function.ADAMW((ฮท = 0.001, ฮฒ = (0.9, 0.999), decay = 0)
ADAMW fixing weight decay regularization in Adam.