SAM Algorithm Explained

Step 1: Perturb in the “bad” direction

Add perturbation \(\varepsilon = \rho \cdot \frac{\nabla L(\theta)}{\|\nabla L(\theta)\|}\) to the current parameters, where \(\rho\) is a small hyperparameter controlling the perturbation radius:

  • This moves you uphill toward higher loss
  • Like “undoing” part of your last optimization step
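Step 1 can be sketched numerically. Below is a minimal illustration in plain Python; the quadratic loss and the value \(\rho = 0.05\) are assumptions chosen for demonstration, not recommended settings:

```python
import math

# Toy loss L(theta) = 0.5 * (theta_1^2 + theta_2^2), so grad L(theta) = theta.
# Both the loss and rho = 0.05 are illustrative assumptions.
def loss_grad(theta):
    return list(theta)

theta = [3.0, 4.0]
rho = 0.05

g = loss_grad(theta)
g_norm = math.sqrt(sum(x * x for x in g))       # ||grad L|| = 5.0 here
eps = [rho * x / g_norm for x in g]             # length-rho step along the gradient
theta_perturbed = [t + e for t, e in zip(theta, eps)]  # uphill: loss is higher here
```

The perturbation always has length exactly \(\rho\), regardless of how large the gradient is; only its direction comes from the gradient.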

Step 2: Compute gradient from this worse position

Calculate \(\nabla L(\theta + \varepsilon)\):

  • This shows how to descend from the perturbed (worse) point

Step 3: Apply that gradient to the original \(\theta\)

Update rule:

\[\theta \leftarrow \theta - \eta \cdot \nabla L(\theta + \varepsilon)\]

where \(\eta\) is the learning rate.
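The three steps combine into a single update. Here is a hedged walk-through of one SAM step on a toy 1-D quadratic loss (the loss and the values of \(\rho\) and \(\eta\) are illustrative assumptions):

```python
# Toy loss L(theta) = 0.5 * theta^2 in 1-D, so grad L(theta) = theta.
# rho and eta are illustrative choices, not recommended values.
def loss_grad(theta):
    return theta

theta, rho, eta = 2.0, 0.05, 0.1

g = loss_grad(theta)               # gradient at the current point
eps = rho * g / abs(g)             # step 1: length-rho uphill perturbation
g_sam = loss_grad(theta + eps)     # step 2: gradient at the worse point
theta = theta - eta * g_sam        # step 3: apply it to the ORIGINAL theta
# theta is now 2.0 - 0.1 * 2.05 ≈ 1.795
```

Note that the perturbation is only used to pick which gradient to follow; the update itself always starts from the unperturbed \(\theta\).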

Why This Works Intuitively

Think of it as asking: “If I accidentally took a step in the wrong direction (uphill), what gradient would get me back on track?”

  • In a sharp valley: A small uphill perturbation puts you on a steep slope; the gradient there is large and pushes you strongly away from the sharp region
  • In a flat basin: A small perturbation barely changes the gradient, so the update is nearly the same as plain SGD's
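This contrast can be made concrete with two toy 1-D losses (the coefficients are assumptions for illustration): a sharp valley \(L(\theta) = 50\theta^2\) and a flat basin \(L(\theta) = 0.5\theta^2\). The same length-\(\rho\) perturbation moves the gradient far more in the sharp case:

```python
# Two illustrative 1-D losses: sharp L = 50 * theta^2, flat L = 0.5 * theta^2.
def grad_sharp(theta):
    return 100.0 * theta

def grad_flat(theta):
    return theta

theta, rho = 0.1, 0.05

diffs = {}
for name, grad in [("sharp", grad_sharp), ("flat", grad_flat)]:
    g = grad(theta)
    eps = rho * g / abs(g)       # same size perturbation in both landscapes
    g_sam = grad(theta + eps)
    diffs[name] = abs(g_sam - g)  # how much the gradient changed

# diffs["sharp"] is 5.0; diffs["flat"] is about 0.05 (100x smaller)
```

In the sharp valley the perturbed gradient differs substantially from the original, so SAM's update is pulled away from the sharp minimum; in the flat basin the two gradients nearly coincide.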

Complete SAM Algorithm

\[\begin{align} g &= \nabla L(\theta) \\ \varepsilon &= \rho \cdot \frac{g}{\|g\|} \\ g_{\text{SAM}} &= \nabla L(\theta + \varepsilon) \\ \theta &\leftarrow \theta - \eta \cdot g_{\text{SAM}} \end{align}\]
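The four lines above map directly onto code. A minimal sketch of the full loop, again on an assumed toy quadratic loss (note that each SAM step costs two gradient evaluations, one at \(\theta\) and one at \(\theta + \varepsilon\)):

```python
import math

def sam_step(theta, loss_grad, rho=0.05, eta=0.1):
    """One SAM update for a list of parameters. rho/eta defaults are illustrative."""
    g = loss_grad(theta)                                    # g = grad L(theta)
    g_norm = math.sqrt(sum(x * x for x in g)) + 1e-12       # guard against zero gradient
    eps = [rho * x / g_norm for x in g]                     # eps = rho * g / ||g||
    g_sam = loss_grad([t + e for t, e in zip(theta, eps)])  # second gradient pass
    return [t - eta * gs for t, gs in zip(theta, g_sam)]    # descend from original theta

# Toy loss L(theta) = 0.5 * ||theta||^2, so grad L(theta) = theta (demo assumption).
theta = [2.0, -1.0]
for _ in range(100):
    theta = sam_step(theta, lambda th: list(th))
# theta has shrunk toward the minimum at the origin
```

In a real training loop, `loss_grad` would be a backward pass through the model, which is why SAM roughly doubles the per-step compute compared to SGD.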

Analogy

It’s like training for worst-case scenarios:

  • Don’t just optimize from where you are
  • Optimize assuming you might be slightly off (perturbed)
  • This makes you robust to being near sharp regions

Summary: Perturb uphill (make loss worse), then descend from there. This two-step process implicitly biases you toward flatter minima.