Softmax Function and Probability Simplex¶
Definition¶
The softmax function maps a vector of \(C\) real-valued logits \(\mathbf{z}=(z_1,\ldots,z_C)\) to a probability distribution:

\[\operatorname{softmax}(\mathbf{z})_c = \frac{e^{z_c}}{\sum_{j=1}^{C} e^{z_j}}, \qquad c = 1,\ldots,C.\]
Properties¶
- Strictly positive: Every output is greater than zero, since \(e^{z_c} > 0\).
- Sums to one: \(\sum_c\operatorname{softmax}(\mathbf{z})_c = 1\).
- Monotone: A larger logit \(z_c\) produces a larger probability.
- Shift invariance: \(\operatorname{softmax}(\mathbf{z}+\alpha\mathbf{1}) = \operatorname{softmax}(\mathbf{z})\) for any scalar \(\alpha\).
The shift-invariance property is exploited for numerical stability: before exponentiation we subtract \(\max_c z_c\), so the largest exponent is \(e^0 = 1\) and overflow cannot occur.
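A quick numerical check of both points, a small sketch assuming NumPy:

```python
import numpy as np

z = np.array([1000.0, 1001.0, 1002.0])  # large logits: naive exponentiation overflows

# Naive softmax: np.exp(1000.0) overflows to inf, and inf/inf gives nan
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(z) / np.sum(np.exp(z))

# Stable softmax: subtract the max first (allowed by shift invariance)
shifted = z - np.max(z)
stable = np.exp(shifted) / np.sum(np.exp(shifted))

print(naive)   # [nan nan nan]
print(stable)  # finite, strictly positive, sums to 1
```

The shift changes none of the probabilities, only the intermediate exponentials.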
The Probability Simplex¶
The output of softmax lies on the probability simplex

\[\Delta^{C-1} = \Bigl\{\mathbf{p}\in\mathbb{R}^{C} : p_c \ge 0 \text{ for all } c,\ \sum_{c=1}^{C} p_c = 1\Bigr\}.\]

Since every softmax output is strictly positive, it in fact lies in the interior of \(\Delta^{C-1}\).
For \(C=3\) this is a triangle embedded in \(\mathbb{R}^3\); for \(C=10\) (MNIST) it is a 9-dimensional simplex.
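To illustrate, one can verify numerically that softmax outputs satisfy the simplex conditions; a sketch assuming NumPy, with a batch of random \(C=10\) logit vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=(5, 10))  # batch of 5 logit vectors, C=10 as for MNIST

# Stable softmax applied row-wise
z_shifted = z - np.max(z, axis=1, keepdims=True)
p = np.exp(z_shifted) / np.sum(np.exp(z_shifted), axis=1, keepdims=True)

# Simplex membership: strictly positive entries, each row sums to 1
print(np.all(p > 0))                    # True
print(np.allclose(p.sum(axis=1), 1.0))  # True
```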
Softmax as a Generalization of the Sigmoid¶
When \(C=2\) with logits \((z_1,z_2)\):

\[\operatorname{softmax}(\mathbf{z})_1 = \frac{e^{z_1}}{e^{z_1}+e^{z_2}} = \frac{1}{1+e^{-(z_1-z_2)}} = \sigma(z_1 - z_2).\]
Thus the binary softmax is exactly the sigmoid applied to the difference of the two logits.
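A one-line numerical check of this identity, with arbitrary example logits:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

z1, z2 = 2.0, -1.0
p1 = np.exp(z1) / (np.exp(z1) + np.exp(z2))  # first component of binary softmax

# Equals the sigmoid of the logit difference
print(np.isclose(p1, sigmoid(z1 - z2)))  # True
```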
Temperature Scaling¶
A common variant introduces a temperature parameter \(\tau>0\):

\[\operatorname{softmax}_\tau(\mathbf{z})_c = \frac{e^{z_c/\tau}}{\sum_{j=1}^{C} e^{z_j/\tau}}.\]
As \(\tau\to 0\) the distribution collapses to a point mass on \(\arg\max_c z_c\) (a hard decision, assuming the maximum is unique); as \(\tau\to\infty\) it approaches the uniform distribution. Temperature scaling is used in model calibration and in generative models (e.g. controlling the "creativity" of language models).
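The two limits can be observed directly; a sketch assuming NumPy, where `softmax_t` is a stand-in name for the tempered softmax above:

```python
import numpy as np

def softmax_t(z, tau):
    """Numerically stable softmax with temperature tau > 0."""
    s = z / tau
    s = s - np.max(s)
    e = np.exp(s)
    return e / e.sum()

z = np.array([2.0, 1.0, 0.5])
print(softmax_t(z, 0.01))   # near one-hot on the argmax (index 0)
print(softmax_t(z, 100.0))  # near uniform, roughly [1/3, 1/3, 1/3]
```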
NumPy Implementation¶
import numpy as np

def softmax(z):
    """Numerically stable softmax along the last axis."""
    z_shifted = z - np.max(z, axis=-1, keepdims=True)
    exp_z = np.exp(z_shifted)
    return exp_z / np.sum(exp_z, axis=-1, keepdims=True)
The np.max subtraction prevents overflow in np.exp without changing the result (shift invariance).