SAM Algorithm Explained
Step 1: Perturb in the “bad” direction
Add perturbation \(\varepsilon = \rho \cdot \frac{\nabla L(\theta)}{\|\nabla L(\theta)\|}\) to current parameters:
- This moves you uphill toward higher loss
- Like “undoing” part of your last optimization step
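This perturbation step can be sketched in a few lines of NumPy. The function name `sam_perturbation` and the guard against a zero gradient are illustrative choices, not part of the original algorithm statement:

```python
import numpy as np

def sam_perturbation(grad, rho=0.05):
    """Scale the gradient to length rho, pointing uphill (toward higher loss)."""
    norm = np.linalg.norm(grad)
    if norm == 0:  # at an exact stationary point there is no uphill direction
        return np.zeros_like(grad)
    return rho * grad / norm
```

The result always has norm `rho` (except at a stationary point), so `rho` directly controls how far uphill we look.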
Step 2: Compute gradient from this worse position
Calculate \(\nabla L(\theta + \varepsilon)\):
- This shows how to descend from the perturbed (worse) point
Step 3: Apply that gradient to original \(\theta\)
Update rule:
\[\theta \leftarrow \theta - \eta \cdot \nabla L(\theta + \varepsilon)\]
where \(\eta\) is the learning rate.
Why This Works Intuitively
Think of it as asking: “If I accidentally took a step in the wrong direction (uphill), what gradient would get me back on track?”
- In a sharp valley: Small uphill perturbation puts you on a steep slope; the gradient there points strongly away from the sharpness
- In a flat basin: Small perturbation barely changes the gradient, so you get similar direction to normal SGD
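The sharp-vs-flat contrast can be checked numerically on a 1-D quadratic \(L(\theta) = a\theta^2/2\), whose gradient is \(a\theta\). The curvature values below are chosen purely for illustration:

```python
import numpy as np

def grad(theta, a):
    # Gradient of the 1-D quadratic L(theta) = a * theta**2 / 2
    return a * theta

theta, rho = 1.0, 0.1
diffs = {}
for a in (100.0, 0.1):            # sharp vs. flat curvature
    g = grad(theta, a)
    eps = rho * np.sign(g)        # 1-D version of rho * g / ||g||
    diffs[a] = grad(theta + eps, a) - g
print(diffs)  # the sharp case shows a far larger shift in the gradient
```

In the sharp case the perturbed gradient differs from the original by 10; in the flat case by only 0.01, matching the intuition above.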
Complete SAM Algorithm
\[\begin{align} g &= \nabla L(\theta) \\ \varepsilon &= \rho \cdot \frac{g}{\|g\|} \\ g_{\text{SAM}} &= \nabla L(\theta + \varepsilon) \\ \theta &\leftarrow \theta - \eta \cdot g_{\text{SAM}} \end{align}\]
Analogy
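The four lines of the algorithm translate directly into one update function. Here is a minimal runnable sketch on a toy quadratic loss; the values of `rho` and `eta` and the small epsilon added for numerical safety are illustrative assumptions:

```python
import numpy as np

def sam_step(theta, grad_fn, rho=0.05, eta=0.1):
    """One SAM update: perturb uphill, then descend using the perturbed gradient."""
    g = grad_fn(theta)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent direction of length rho
    g_sam = grad_fn(theta + eps)                 # gradient at the worse point
    return theta - eta * g_sam

# Toy loss L(theta) = ||theta||^2 / 2, so grad L(theta) = theta
grad_fn = lambda th: th
theta = np.array([1.0, -2.0])
for _ in range(100):
    theta = sam_step(theta, grad_fn)
print(theta)  # converges toward the minimum at the origin
```

Note that each iteration calls `grad_fn` twice (once at \(\theta\), once at \(\theta + \varepsilon\)), which is why SAM roughly doubles the gradient cost of plain SGD.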
It’s like training for worst-case scenarios:
- Don’t just optimize from where you are
- Optimize assuming you might be slightly off (perturbed)
- This makes you robust to being near sharp regions
Summary: Perturb uphill (make loss worse), then descend from there. This two-step process implicitly biases you toward flatter minima.