Last active
May 28, 2024 10:27
-
-
Save flacle/a93ff64b80e85d3ee715d201f5f7b7b6 to your computer and use it in GitHub Desktop.
Manim Gradient Descent Intuition (in Papiamento)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Manim Gradient Descent Video Intuition | |
# Author: Francis Laclé | |
# Video: https://www.youtube.com/watch?v=1cCS6uK_NH8 | |
# Github: https://github.com/flacle | |
# Date: 29 Oct, 2020 | |
from manim import * | |
import math | |
class Intro(Scene):
    """Opening title card: write "Gradient Descent", hold it, fade it out."""

    def construct(self):
        title = PangoText('Gradient Descent', gradient=(BLUE, GREEN))
        title.scale(2)

        self.wait(1)
        self.add(title)
        self.play(Write(title))
        # Keep the title on screen for 11 seconds before removing it.
        self.wait(11)
        self.play(FadeOut(title))
        self.wait(1)
class ThreeDSurface(ParametricSurface):
    """Saddle surface z = x**2 - y**2 over the square [-2, 2] x [-2, 2].

    Keyword arguments are forwarded to ``ParametricSurface``. They override
    the defaults below instead of being discarded.
    """

    def __init__(self, **kwargs):
        # Default configuration for this surface. Caller-supplied kwargs take
        # precedence. Previously the incoming kwargs were replaced wholesale,
        # so any customisation passed by a caller was silently ignored.
        config = {
            "u_min": -2,
            "u_max": 2,
            "v_min": -2,
            "v_max": 2,
            "checkerboard_colors": [BLUE_D],
        }
        config.update(kwargs)
        ParametricSurface.__init__(self, self.func, **config)

    def func(self, x, y):
        """Map parameters (u, v) = (x, y) to the 3D point (x, y, x**2 - y**2)."""
        return np.array([x, y, x**2 - y**2])
class ConYPakicoTraha(ThreeDScene):
    """Show the ThreeDSurface saddle under a slowly rotating camera, then move the camera."""

    def construct(self):
        axes = ThreeDAxes(animate=True)
        surface = ThreeDSurface()

        self.set_camera_orientation(
            phi=75 * DEGREES, theta=30 * DEGREES, distance=30)
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)

        # Draw the axes first, then the surface.
        for mob in (axes, surface):
            self.play(ShowCreation(mob))
        self.wait(4)

        # Swing to a second viewpoint: phi = 0.4*pi, theta = -0.45*pi.
        self.move_camera(phi=0.4 * np.pi, theta=-0.45 * np.pi)
        self.wait(4)
        self.stop_ambient_camera_rotation()

        # Remove everything in reverse order of creation.
        for mob in (surface, axes):
            self.play(FadeOut(mob))
class KostFunctie(Scene):
    """Show a generic cost J(...) and morph it into min J(...)."""

    def construct(self):
        cost, min_cost = (
            Tex(r'$J\left(\cdot\cdot\cdot\right)$').scale(3),
            Tex(r'$\min{J\left(\cdot\cdot\cdot\right)}$').scale(3),
        )

        self.wait(1)
        self.play(Write(cost))
        self.wait(4)
        self.play(ReplacementTransform(cost, min_cost))
        self.wait(6)
        self.play(FadeOut(min_cost))
class OnderzoekAruba(GraphScene):
    """Plot ten yearly data points, a straight line, and the residuals.

    The axes are relabelled by ``setup_axes``:
      - the x axis shows the years '09-'18;
      - the y axis shows 100.000-106.000.

    The data points fade in one by one. Then the line is drawn, and finally a
    green segment joins the line to every point.
    """

    CONFIG = {
        "y_axis_label": r"Poblacion di Aruba",
        "x_axis_label": "Aña",
        "y_max": 7,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 9,
        "x_min": 0,
        "axes_color": BLUE,
    }

    @staticmethod
    def _fitted_line(x):
        """Return the hand-chosen line through (0, 1) and (9, 6) at ``x``.

        Shared by the drawn graph and the residual segments, so both always
        use the same formula.
        """
        return (5 / 9 * x) + 1

    def construct(self):
        # One value per year; element i is plotted at x = i.
        data = [1, 1.243735763, 1.673120729, 2.258542141, 2.940774487,
                3.641230068, 4.287015945, 4.891799544, 5.454441913, 6]
        self.setup_axes()
        line = self.get_graph(
            self._fitted_line,
            color=RED,
            x_min=0,
            x_max=9,
            # NOTE(review): get_graph forwards this kwarg; it does not appear
            # to draw a visible label on its own. Confirm against the manim
            # version in use.
            label="$J(x)$",
        )

        # Fade in the data points one at a time.
        dot_collection = VGroup()
        for time, dat in enumerate(data):
            dot = Dot(color=YELLOW).move_to(self.coords_to_point(time, dat))
            dot_collection.add(dot)
            self.play(FadeIn(dot), rate_func=rush_into)
        self.wait(1)
        self.play(ShowCreation(line), run_time=2)
        self.wait(1)

        # Residuals: vertical segments from the line to each data point.
        error_collection = VGroup()
        for time, dot in enumerate(dot_collection):
            error = Line(
                self.coords_to_point(time, self._fitted_line(time)),
                dot.get_center(),
                color=GREEN)
            error_collection.add(error)
            self.play(ShowCreation(error), run_time=1)
        self.wait(2)

        # Clear the whole graph. The trailing animation-less self.play() that
        # used to follow only produced a "no animations" warning.
        self.play(
            FadeOut(error_collection),
            FadeOut(dot_collection),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(self.x_axis_labels),
            FadeOut(self.y_axis_labels))

    def setup_axes(self):
        """Create the axes with custom tick labels and write them on screen.

        Sets two label groups, which ``construct`` fades out later:
          - ``self.x_axis_labels``: "'09".."'18" below x = 0..9;
          - ``self.y_axis_labels``: "100.000".."106.000" left of y = 0..6.
        """
        GraphScene.setup_axes(self)
        self.x_axis.label_direction = UP
        self.y_axis.label_direction = UP

        values_x = [(i, f"'{9 + i:02d}") for i in range(10)]   # "'09" .. "'18"
        values_y = [(i, f"{100 + i}.000") for i in range(7)]   # "100.000" .. "106.000"

        self.x_axis_labels = VGroup()
        self.y_axis_labels = VGroup()
        for x_val, x_tex in values_x:
            tex = PangoText(x_tex).scale(0.6)
            tex.next_to(self.coords_to_point(x_val, 0), DOWN)
            self.x_axis_labels.add(tex)
        for y_val, y_tex in values_y:
            tex = PangoText(y_tex).scale(0.6)
            tex.next_to(self.coords_to_point(0, y_val), LEFT)
            self.y_axis_labels.add(tex)

        self.play(
            Write(self.x_axis_labels),
            Write(self.x_axis),
            Write(self.y_axis_labels),
            Write(self.y_axis),
        )
class Hypothese(Scene):
    """Introduce h_theta(x) = theta_0 + theta_1 x and substitute x = 3."""

    def construct(self):
        # Each formula is split into six Tex parts so that individual parts
        # can be boxed and morphed:
        #   0: left-hand side, 1: "=", 2: theta_0, 3: "+", 4: theta_1, 5: x
        def hypothesis(lhs, x_part):
            return MathTex(lhs, "=", "\\theta_0", "+", "{\\theta_1}", x_part).scale(2)

        general = hypothesis("h_\\theta\\left(x\\right)", "x")
        at_three = hypothesis("h_\\theta\\left(3\\right)", "3")
        equals_nine = hypothesis("9", "3")

        self.wait(1)
        self.play(Write(general))
        self.wait(6)

        # Move a highlight box from theta_1 to theta_0 to the left-hand side.
        box_theta1 = SurroundingRectangle(general[4], buff=.1)
        box_theta0 = SurroundingRectangle(general[2], buff=.1)
        box_lhs = SurroundingRectangle(general[0], buff=.1)
        self.play(ShowCreation(box_theta1))
        self.wait(2)
        self.play(ReplacementTransform(box_theta1, box_theta0))
        self.wait(1)
        self.play(ReplacementTransform(box_theta0, box_lhs))
        self.wait(3)
        self.play(FadeOut(box_lhs))
        self.wait(1)

        self.play(ReplacementTransform(general, at_three))
        self.wait(3)
        self.play(ReplacementTransform(at_three, equals_nine))
        self.wait(1)
        self.play(FadeOut(equals_nine))
class CombinacionJmin(Scene):
    """Cycle min J(theta_0, theta_1) -> 9 through several concrete pairs."""

    def construct(self):
        # The symbolic form first, then candidate (theta_0, theta_1) pairs.
        # Every \left( is paired with \right): an unmatched \left( is a LaTeX
        # error ("Extra }, or forgotten \right").
        exprs = [
            Tex(r'$\min{J\left(\theta_0, \theta_1\right)}\to{9}$').scale(2),
            Tex(r'$\min{J\left(0, 7\right)}\to{9}$').scale(2),
            Tex(r'$\min{J\left(2, 5\right)}\to{9}$').scale(2),
            Tex(r'$\min{J\left(-3, 1\right)}\to{9}$').scale(2),
            Tex(r'$\min{J\left(-2, 2\right)}\to{9}$').scale(2),
            Tex(r'$\min{J\left(-1, 3\right)}\to{9}$').scale(2),
        ]

        self.wait(1)
        self.play(Write(exprs[0]))
        self.wait(3)
        # Morph each expression into the next one, back to back.
        for old, new in zip(exprs, exprs[1:]):
            self.play(ReplacementTransform(old, new))
        self.wait(8)
        self.play(FadeOut(exprs[-1]))
class SomDifferencia(Scene):
    """Show the squared-error cost J(theta_0, theta_1) and box its key parts."""

    def construct(self):
        # Tex parts, by index:
        #   0 J(...)   1 "="   2 1/(2m)   3 sum   4 "("
        #   5 h(x^(i))   6 "-"   7 y^(i)   8 ")^2"
        parts = (
            "J(\\theta_{0}, \\theta_{1})", "=", "\\frac{1}{2m}",
            "\\sum\\limits_{i=1}^m", "(", "h_{\\theta}(x^{(i)})", "-",
            "y^{(i)}", ")^2",
        )
        cost = MathTex(*parts).scale(1)
        # Boxes around the sum, the hypothesis term and the target term.
        box_sum, box_h, box_y = (
            SurroundingRectangle(cost[index], buff=.1) for index in (3, 5, 7)
        )

        self.wait(1)
        self.play(Write(cost))
        self.wait(5)
        self.play(ShowCreation(box_sum))
        self.wait(1)
        self.play(ReplacementTransform(box_sum, box_h))
        self.play(ReplacementTransform(box_h, box_y))
        self.wait(2)
        self.play(FadeOut(box_y))
        self.wait(9)
        self.play(FadeOut(cost))
class GradientDescentDilanti(MovingCameraScene):
    """Title card, then min J(...) growing from two to six parameters and back."""

    def construct(self):
        gd = PangoText('Gradient Descent', gradient=(BLUE, GREEN)).scale(2)
        self.wait(3)
        self.play(Write(gd))
        self.wait(1)
        # Resize the camera frame to 1.2x the title's width. The frame keeps
        # this size for the formulas that follow.
        self.play(self.camera_frame.set_width, gd.get_width() * 1.2)
        self.wait(2)
        self.play(FadeOut(gd))
        self.wait(1)

        # Every \left( is paired with \right): an unmatched \left( is a LaTeX
        # error ("Extra }, or forgotten \right").
        exprs = [
            Tex(r'$\min{J\left(\theta_0, \theta_1\right)}$'),
            Tex(r'$\min{J\left(\theta_0, \theta_1, \theta_2\right)}$'),
            Tex(r'$\min{J\left(\theta_0, \theta_1, \theta_2, \theta_3\right)}$'),
            Tex(r'$\min{J\left(\theta_0, \theta_1, \theta_2, \theta_3, \theta_4\right)}$'),
            Tex(r'$\min{J\left(\theta_0, \theta_1, \theta_2, \theta_3, \theta_4, \theta_5\right)}$'),
            # Back to the two-parameter form.
            Tex(r'$\min{J\left(\theta_0, \theta_1\right)}$'),
        ]

        self.play(Write(exprs[0]))
        self.wait(1)
        # Add one parameter per step: 2 -> 3 -> 4 -> 5 -> 6.
        for old, new in zip(exprs[:4], exprs[1:5]):
            self.play(ReplacementTransform(old, new))
        self.wait(1)
        self.play(ReplacementTransform(exprs[4], exprs[5]))
        self.wait(3)
        self.play(FadeOut(exprs[5]))
class KedaRipitiYGana(MovingCameraScene):
    """Show the update rule with a "repeat until convergence" brace, then reframe."""

    def construct(self):
        lhs = r'$\theta_j := $'
        rhs = r'$\theta_j - \alpha \frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)$'
        update_rule = Tex(lhs, rhs).scale(1)
        note = Tex(r'(update parew pa $j=0$ y $j=1$ !)').scale(0.75)
        note.move_to(2 * DOWN)

        self.wait(1)
        self.play(Write(update_rule), Write(note))

        # Brace above part 1 of the Tex, i.e. the right-hand side of the rule.
        brace = Brace(update_rule[1], UP, buff=SMALL_BUFF)
        caption = brace.get_text("ripiti te ora e converge")
        self.play(GrowFromCenter(brace), FadeIn(caption))
        self.wait(9)
        self.play(FadeOut(caption), FadeOut(brace), FadeOut(note))

        # Resize the camera frame to 1.6x the formula's width.
        self.play(self.camera_frame.set_width, update_rule.get_width() * 1.6)
        self.play(FadeOut(update_rule))
class CordaCalculus(GraphScene):
    """Recall dy/dx, then slide a tangent line along y = (x - 5)**2."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        deriv = Tex(r'$\frac{dy}{dx}$').scale(3)
        self.play(Write(deriv))
        self.wait(5)
        self.play(FadeOut(deriv))
        self.setup_axes(animate=True)

        def graph_to_be_drawn(x):
            return (x - 5) ** 2

        def dx(x):
            # Analytic derivative of graph_to_be_drawn.
            return 2 * (x - 5)

        parabola = self.get_graph(
            graph_to_be_drawn,
            x_min=2,
            x_max=8,
            color=YELLOW,
            stroke_opacity=0.5)

        # x position of the dot/tangent; animated below via vt.set_value.
        vt = ValueTracker(0)

        def moving_dot():
            x = vt.get_value()
            return Dot(color=WHITE).move_to(
                self.coords_to_point(x, graph_to_be_drawn(x)))

        md = always_redraw(moving_dot)

        def get_w_line():
            """Return a red line of length 2 tangent to the parabola at vt."""
            x = vt.get_value()
            y = graph_to_be_drawn(x)
            point = self.coords_to_point(x, y)
            # The x and y axes use different on-screen units, so the slope
            # dx(x) in graph coordinates is not the slope on screen (the cause
            # of the apparent "rounding error" this used to fudge around).
            # Map a second point on the tangent through coords_to_point and
            # use the angle of the resulting screen-space direction.
            ahead = self.coords_to_point(x + 1, y + dx(x))
            direction = ahead - point
            t = Line(LEFT, RIGHT, stroke_opacity=1, color=RED)  # length 2
            t.rotate(math.atan2(direction[1], direction[0]))
            t.move_to(point)
            return t

        vt.set_value(3)
        line = always_redraw(get_w_line)
        self.play(ShowCreation(parabola), FadeIn(md), FadeIn(line))
        self.wait(1)
        self.play(vt.set_value, 7, rate_func=there_and_back, run_time=4)
        self.wait(1)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=4)
        self.wait(6)
        self.play(vt.set_value, 3, rate_func=slow_into, run_time=1)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=6)
        self.wait(1)
        self.play(FadeOut(parabola), FadeOut(md), FadeOut(line), FadeOut(self.axes))
class MinTekenMeiMei(MovingCameraScene):
    """Revisit the update rule, then show the positive / negative derivative cases."""

    def construct(self):
        rule_lhs = r'$\theta_j := $'
        rule_rhs = r'$\theta_j - \alpha \frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right)$'

        update_rule = Tex(rule_lhs, rule_rhs).scale(1)
        note = Tex(r'(update parew pa $j=0$ y $j=1$ !)').scale(0.75).move_to(2 * DOWN)
        # The brace and the box both wrap part 1 of the Tex: the whole
        # right-hand side of the update rule.
        brace = Brace(update_rule[1], UP, buff=SMALL_BUFF)
        rhs_box = SurroundingRectangle(update_rule[1], buff=.1)
        caption = brace.get_text("ripiti te ora e converge")

        self.wait(1)
        self.play(
            Write(update_rule),
            Write(note),
            GrowFromCenter(brace),
            FadeIn(caption))
        self.play(FadeIn(rhs_box))
        self.wait(3)

        positive = Tex(r'$\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right) > 0 \to$ descent').scale(1)
        self.play(
            FadeOut(brace),
            FadeOut(caption),
            FadeOut(note),
            FadeOut(rhs_box),
            ReplacementTransform(update_rule, positive))
        self.wait(4)

        negative = Tex(r'$\frac{\partial}{\partial\theta_j}J\left(\theta_0, \theta_1\right) < 0 \to$ ascent').scale(1)
        self.play(ReplacementTransform(positive, negative))
        self.wait(4)

        # Morph back into a fresh copy of the full update rule.
        restored = Tex(rule_lhs, rule_rhs).scale(1)
        self.play(ReplacementTransform(negative, restored))
        self.wait(4)
        self.play(FadeOut(restored))
class Alpha(GraphScene):
    """Illustrate the learning rate alpha with a tangent moving on y = (x - 5)**2."""

    CONFIG = {
        "y_axis_label": r"$y$",
        "x_axis_label": r"$x$",
        "y_max": 10,
        "y_min": 0,
        "y_tick_frequency": 1,
        "x_max": 10,
        "x_min": 0,
        "axes_color": BLUE,
    }

    def construct(self):
        self.wait(1)
        alpha = Tex(r'$\alpha$').scale(3)
        self.play(Write(alpha))
        self.play(ApplyMethod(alpha.shift, (UP + RIGHT) * PI))
        self.setup_axes(animate=True)

        def graph_to_be_drawn(x):
            return (x - 5) ** 2

        def dx(x):
            # Analytic derivative of graph_to_be_drawn.
            return 2 * (x - 5)

        parabola = self.get_graph(
            graph_to_be_drawn,
            x_min=2,
            x_max=8,
            color=YELLOW,
            stroke_opacity=0.5)

        # x position of the dot/tangent; animated below via vt.set_value.
        vt = ValueTracker(0)

        def moving_dot():
            x = vt.get_value()
            return Dot(color=WHITE).move_to(
                self.coords_to_point(x, graph_to_be_drawn(x)))

        md = always_redraw(moving_dot)

        def get_w_line():
            """Return a red line of length 2 tangent to the parabola at vt."""
            x = vt.get_value()
            y = graph_to_be_drawn(x)
            point = self.coords_to_point(x, y)
            # The x and y axes use different on-screen units, so the slope
            # dx(x) in graph coordinates is not the slope on screen (the cause
            # of the apparent "rounding error" this used to fudge around).
            # Map a second point on the tangent through coords_to_point and
            # use the angle of the resulting screen-space direction.
            ahead = self.coords_to_point(x + 1, y + dx(x))
            direction = ahead - point
            t = Line(LEFT, RIGHT, stroke_opacity=1, color=RED)  # length 2
            t.rotate(math.atan2(direction[1], direction[0]))
            t.move_to(point)
            return t

        vt.set_value(3)
        line = always_redraw(get_w_line)
        self.play(
            ShowCreation(parabola),
            FadeIn(md),
            FadeIn(line),
            ApplyMethod(alpha.scale, (1 / 2)))

        # NOTE(review): alpha2..alpha5 are not positioned explicitly (presumably
        # they sit at the screen centre), so the first ReplacementTransform moves
        # the label away from the corner alpha was shifted to. Confirm this is
        # intended; otherwise add .move_to(alpha) to each.
        alpha2 = Tex(r'$\alpha = 2.0$')
        self.play(ReplacementTransform(alpha, alpha2), run_time=0.5)
        self.play(vt.set_value, 6, rate_func=there_and_back, run_time=2)
        alpha3 = Tex(r'$\alpha = 1.1$')
        self.play(ReplacementTransform(alpha2, alpha3), run_time=0.5)
        self.play(vt.set_value, 4, rate_func=there_and_back, run_time=2)
        self.play(vt.set_value, 5, rate_func=slow_into, run_time=12)
        alpha4 = Tex(r'$\alpha = 0.05$')
        self.play(ReplacementTransform(alpha3, alpha4), run_time=0.5)
        self.play(vt.set_value, 4, rate_func=there_and_back, run_time=6)
        alpha5 = Tex(r'$\alpha = 2.2$')
        self.play(ReplacementTransform(alpha4, alpha5), run_time=0.5)
        self.play(vt.set_value, 3, rate_func=rush_into, run_time=2)
        self.play(vt.set_value, 7, rate_func=rush_into, run_time=3)
        self.wait(3)
        self.play(
            FadeOut(parabola),
            FadeOut(md),
            FadeOut(line),
            FadeOut(self.axes),
            FadeOut(alpha5))
class SaddlePoint(ThreeDScene):
    """Orbit the camera around the ThreeDSurface saddle for a long hold."""

    def construct(self):
        axes, surface = ThreeDAxes(animate=True), ThreeDSurface()

        self.set_camera_orientation(
            phi=75 * DEGREES, theta=30 * DEGREES, distance=30)
        self.begin_ambient_camera_rotation(rate=0.1)
        self.wait(1)

        for mob in (axes, surface):
            self.play(ShowCreation(mob))
        # Let the ambient rotation run for 50 seconds with the surface shown.
        self.wait(50)

        for mob in (surface, axes):
            self.play(FadeOut(mob))
class TipoDiGradientDescent(Scene):
    """Write nine optimizer names in a grid around "SGD", then fade them all."""

    def construct(self):
        names = ['SGD', 'RMSprop', 'Adam', 'Adadelta', 'Adagrad',
                 'Adamax', 'Nadam', 'Ftrl', 'BGD']
        labels = [Tex(name).scale(2) for name in names]

        # (label index, anchor index, direction). Applied in order, because
        # some anchors (Adadelta, Adagrad) are placed before being used.
        layout = [
            (1, 0, DOWN * 1.5),   # RMSprop  below SGD
            (2, 0, UP * 1.5),     # Adam     above SGD
            (3, 0, LEFT * 1.5),   # Adadelta left of SGD
            (4, 0, RIGHT * 1.5),  # Adagrad  right of SGD
            (5, 3, UP * 2),       # Adamax   above Adadelta
            (6, 4, UP * 2),       # Nadam    above Adagrad
            (7, 4, DOWN * 2),     # Ftrl     below Adagrad
            (8, 3, DOWN * 2),     # BGD      below Adadelta
        ]
        for target, anchor, direction in layout:
            labels[target].next_to(labels[anchor], direction)

        self.wait(1)
        for label in labels:
            self.play(Write(label))
        self.wait(12)
        self.play(*(FadeOut(label) for label in reversed(labels)))
class Outro(Scene):
    """Closing card: write "Masha Danki!", hold it, fade it out."""

    def construct(self):
        closing = PangoText('Masha Danki!', gradient=(BLUE, GREEN))
        closing.scale(2)

        self.wait(1)
        self.add(closing)
        self.play(Write(closing))
        # Keep the closing text on screen for 3 seconds.
        self.wait(3)
        self.play(FadeOut(closing))
        self.wait(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment