Computation Graph: why a forward pass and a backward pass

We use a computation graph to explain why a forward pass (forward propagation) is followed by a backward pass (back-propagation).

Example: J(a, b, c) = 3(a + bc), where a = 5, b = 3, and c = 2

Calculating J takes 3 steps:
Step 1: calculate u = bc = 3 x 2 = 6
Step 2: calculate v = a + u = 5 + 6 = 11
Step 3: calculate J = 3(a + bc) = 3(a + u) = 3v = 3 x 11 = 33
This left-to-right computation is the forward pass.
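The three steps map directly onto a small Python function. This is a minimal sketch; the function name `forward` is our own choice. It returns the intermediate values u and v as well as J, because the backward pass needs values computed on the way forward.

```python
def forward(a, b, c):
    """Forward pass for J = 3(a + bc), returning J and the intermediates."""
    u = b * c   # Step 1
    v = a + u   # Step 2
    J = 3 * v   # Step 3
    return J, u, v

J, u, v = forward(5, 3, 2)
print(u, v, J)  # 6 11 33
```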


To compute the derivatives, we do a backward pass from right to left.

If we nudge v from 11 to 11.001, then J = 3 x 11.001 = 33.003, so \(\frac{\partial J} {\partial v} = \frac{33.003 - 33}{11.001 - 11} = 3\)
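The "nudge and measure" idea is a finite-difference estimate of the derivative. A minimal sketch, treating only Step 3 (J = 3v) as a function of v; the helper name `J_of_v` and the step size `eps = 0.001` are illustrative choices. Floating-point rounding makes the result only approximately 3.

```python
def J_of_v(v):
    # Step 3 on its own: J as a function of v
    return 3 * v

eps = 0.001
v = 11.0
dJ_dv = (J_of_v(v + eps) - J_of_v(v)) / eps  # (33.003 - 33) / 0.001
print(round(dJ_dv, 6))  # 3.0
```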

If a = 5.001, then v = 5.001 + 6 = 11.001 and J = 3 x 11.001 = 33.003,
so \(\frac{\partial J} {\partial a} = \frac{33.003 - 33}{5.001 - 5} = 3 = 3 * 1 = \frac{\partial J} {\partial v} * \frac{\partial v} {\partial a} \). This is the chain rule of calculus: changing a changes v, which in turn changes J.
Similarly, \(\frac{\partial J} {\partial u} = \frac{\partial J} {\partial v} * \frac{\partial v} {\partial u}  = 3 * 1 = 3 \)
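The chain rule multiplies local derivatives along the path from J back to the variable. The sketch below hard-codes the local derivatives of each step (dJ/dv = 3, dv/da = dv/du = 1) and checks dJ/da against a finite-difference estimate on the full function. The variable names are illustrative.

```python
def J(a, b, c):
    # Full function J = 3(a + bc)
    return 3 * (a + b * c)

a, b, c = 5.0, 3.0, 2.0
eps = 0.001

# Local derivatives of each step
dJ_dv = 3.0  # J = 3v
dv_da = 1.0  # v = a + u
dv_du = 1.0  # v = a + u

# Chain rule: multiply local derivatives along the path
dJ_da = dJ_dv * dv_da
dJ_du = dJ_dv * dv_du

# Numerical check: nudge a and measure the change in J
numeric_dJ_da = (J(a + eps, b, c) - J(a, b, c)) / eps
print(dJ_da, dJ_du, round(numeric_dJ_da, 6))  # 3.0 3.0 3.0
```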

If b = 3.001, then u = 3.001 x 2 = 6.002, so \(\frac{\partial u} {\partial b} = \frac{6.002 - 6}{3.001 - 3} = 2\). Therefore,
\(\frac{\partial J} {\partial b} = \frac{\partial J} {\partial u} * \frac{\partial u} {\partial b} = 3 * 2 = 6 = \frac{\partial J} {\partial v} * \frac{\partial v} {\partial u} * \frac{\partial u} {\partial b} = 3 * 1 * 2 \)
Likewise, since u = bc, \(\frac{\partial u} {\partial c} = b = 3\), so
\(\frac{\partial J} {\partial c} = \frac{\partial J} {\partial v} * \frac{\partial v} {\partial u} * \frac{\partial u} {\partial c} = 3 * 1 * 3 = 9 \)
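Putting it together, a backward pass walks the graph from J back to the inputs, reusing values stored during the forward pass (here b and c, which the multiplication step's local derivatives need). This is a hand-written sketch for this one graph, not a general autodiff library; the names `forward`, `backward`, and `cache` are illustrative.

```python
def forward(a, b, c):
    """Forward pass; caches the values the backward pass will need."""
    u = b * c
    v = a + u
    J = 3 * v
    return J, (b, c)

def backward(cache):
    """Backward pass: apply the chain rule right to left, from J to a, b, c."""
    b, c = cache
    dJ_dv = 3.0            # J = 3v
    dJ_da = dJ_dv * 1.0    # v = a + u, so dv/da = 1
    dJ_du = dJ_dv * 1.0    # dv/du = 1
    dJ_db = dJ_du * c      # u = bc, so du/db = c
    dJ_dc = dJ_du * b      # du/dc = b
    return {"a": dJ_da, "b": dJ_db, "c": dJ_dc}

J, cache = forward(5.0, 3.0, 2.0)
grads = backward(cache)
print(grads)  # {'a': 3.0, 'b': 6.0, 'c': 9.0}

# Compare each gradient with a finite-difference estimate
eps = 1e-3
for i, name in enumerate("abc"):
    x = [5.0, 3.0, 2.0]
    x[i] += eps
    numeric = (forward(*x)[0] - J) / eps
    print(name, grads[name], round(numeric, 3))
```

One backward pass yields all three derivatives at once, which is why back-propagation follows a single forward pass rather than re-running the function once per input.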
