์ตœ์ ํ™”

์ตœ์ ํ™” ํ•จ์ˆ˜(Optimisers)

Consider a simple linear regression: we create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters W and b.

julia> using Flux

julia> W = param(rand(2, 5))
Tracked 2ร—5 Array{Float64,2}:
 0.215021  0.22422   0.352664  0.11115   0.040711
 0.180933  0.769257  0.361652  0.783197  0.545495

julia> b = param(rand(2))
Tracked 2-element Array{Float64,1}:
 0.205216
 0.150938

julia> predict(x) = W*x .+ b
predict (generic function with 1 method)

julia> loss(x, y) = sum((predict(x) .- y).^2)
loss (generic function with 1 method)

julia> x, y = rand(5), rand(2) # Dummy data
([0.153473, 0.927019, 0.40597, 0.783872, 0.392236], [0.261727, 0.00917161])

julia> l = loss(x, y) # ~ 3
3.6352060699201565 (tracked)

julia> Flux.back!(l)

We want to update each parameter, using its gradient, in order to reduce the loss. Here's one way to do that:

function update()
  ฮท = 0.1 # Learning rate
  for p in (W, b)
    p.data .-= ฮท .* p.grad # Apply the update
    p.grad .= 0            # Clear the gradient
  end
end

If we call update, the parameters W and b will change and the loss should go down.
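As a quick sanity check, we can repeat the loss/backpropagate/update cycle a few times and watch the loss fall. A minimal sketch using only the definitions above:

for i in 1:5
  l = loss(x, y)    # forward pass on the dummy data
  Flux.back!(l)     # accumulate gradients into W.grad and b.grad
  update()          # apply the gradient step and clear the gradients
  @show loss(x, y)  # the loss should generally shrink each iteration
end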

There are two pieces to note here: we need a list of the model's trainable parameters ([W, b] in this case), and we need a rule for the update step. Here the update is simple gradient descent (x .-= ฮท .* ฮ”), but we might want to do something fancier, like adding momentum; a hand-rolled momentum version is sketched below.
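For instance, a momentum variant of update might look like the following sketch. The velocity buffers vW and vb, the momentum term ฯ, and update_momentum are illustrative names only, not part of Flux:

vW, vb = zeros(size(W.data)), zeros(size(b.data))  # one velocity buffer per parameter

function update_momentum()
  ฮท, ฯ = 0.1, 0.9                    # learning rate and momentum
  for (p, v) in zip((W, b), (vW, vb))
    v .= ฯ .* v .+ ฮท .* p.grad       # running velocity built from past gradients
    p.data .-= v                     # apply the update
    p.grad .= 0                      # clear the gradient
  end
end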

In this case getting the variables is trivial, but you can imagine it would be more of a pain with a complex stack of layers.

julia> m = Chain(
         Dense(10, 5, ฯƒ),
         Dense(5, 2), softmax)
Chain(Dense(10, 5, NNlib.ฯƒ), Dense(5, 2), NNlib.softmax)

Instead of having to write [m[1].W, m[1].b, ...] by hand, Flux provides params(m), which returns a list of all parameters in the model for you.

julia> opt = SGD([W, b], 0.1) # Gradient descent with learning rate 0.1
(::#71) (generic function with 1 method)

julia> opt() # Carry out the update, modifying `W` and `b`
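The same pattern applies to the Chain above: collect its parameters with params and hand them to the optimiser. A sketch (the exact return type of params varies between Flux versions):

ps  = Flux.params(m)   # all trainable parameters: the weights and biases of both Dense layers
opt = SGD(ps, 0.1)     # gradient descent over every parameter in the model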

์ตœ์ ํ™” ํ•จ์ˆ˜๋Š” ํŒŒ๋ผ๋ฏธํ„ฐ ๋ชฉ๋ก์„ ๋ฐ›์•„ ์œ„์˜ update์™€ ๊ฐ™์€ ํ•จ์ˆ˜๋ฅผ ๋Œ๋ ค์ค€๋‹ค. opt๋‚˜ update๋ฅผ ํ›ˆ๋ จ ๋ฃจํ”„(training loop)์— ๋„˜๊ฒจ์ค„ ์ˆ˜ ์žˆ๋Š”๋ฐ, ๋งค๋ฒˆ ๋ฐ์ดํ„ฐ์˜ ๋ฏธ๋‹ˆ-๋ฐฐ์น˜(mini-batch)๋ฅผ ํ•œ ํ›„์— ์ตœ์ ํ™”๋ฅผ ์ˆ˜ํ–‰ํ•  ๊ฒƒ์ด๋‹ค.

์ตœ์ ํ™” ํ•จ์ˆ˜ ์ฐธ๊ณ 

๋ชจ๋“  ์ตœ์ ํ™” ํ•จ์ˆ˜๋Š” ๋„˜๊ฒจ๋ฐ›์€ ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ์—…๋ฐ์ดํŠธ ํ•˜๋Š” ํ•จ์ˆ˜๋ฅผ ๋Œ๋ ค์ค€๋‹ค.

Flux.Optimise.Descent โ€” Type.
Descent(ฮท)

Classic gradient descent optimiser with learning rate ฮท. For each parameter p and its gradient ฮดp, this runs p -= ฮท*ฮดp.
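As a plain-Julia illustration of that rule (just the arithmetic, not the Flux API):

ฮท  = 0.1
p  = [1.0, 2.0]
ฮดp = [0.5, -1.0]
p .-= ฮท .* ฮดp   # p is now [0.95, 2.1]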

Flux.Optimise.Momentum — Type.
Momentum(params, ฮท = 0.01; ฯ = 0.9)

Gradient descent with learning rate ฮท and momentum ฯ.

Flux.Optimise.Nesterov — Type.
Nesterov(eta, ฯ = 0.9)

Gradient descent with learning rate ฮท and Nesterov momentum ฯ.

Flux.Optimise.RMSProp โ€” Type.
RMSProp(ฮท = 0.001, ฯ = 0.9)

RMSProp optimiser. Parameters other than learning rate don't need tuning. Often a good choice for recurrent networks.

Flux.Optimise.ADAM โ€” Type.
ADAM(ฮท = 0.001, ฮฒ = (0.9, 0.999))

ADAM optimiser.

Flux.Optimise.AdaMax โ€” Type.
AdaMax(params, ฮท = 0.001; ฮฒ1 = 0.9, ฮฒ2 = 0.999, ฯต = 1e-08)

AdaMax optimiser. Variant of ADAM based on the โˆž-norm.

Flux.Optimise.ADAGrad โ€” Type.
ADAGrad(ฮท = 0.1; ฯต = 1e-8)

ADAGrad optimiser. Parameters don't need tuning.

Flux.Optimise.ADADelta — Type.
ADADelta(ฯ = 0.9, ฯต = 1e-8)

ADADelta optimiser. Parameters don't need tuning.

Flux.Optimise.AMSGrad โ€” Type.
AMSGrad(ฮท = 0.001, ฮฒ = (0.9, 0.999))

AMSGrad optimiser. Parameters don't need tuning.

Flux.Optimise.NADAM โ€” Type.
NADAM(ฮท = 0.001, ฮฒ = (0.9, 0.999))

NADAM optimiser. Parameters don't need tuning.

Flux.Optimise.ADAMW โ€” Function.
ADAMW(ฮท = 0.001, ฮฒ = (0.9, 0.999), decay = 0)

ADAMW optimiser, a variant of ADAM that fixes weight decay regularization.
