Basic Scatter Plot¶
Scatter plots display individual data points using markers, revealing relationships, clusters, and patterns between two variables.
Mental Model
Unlike plot(), which connects points with lines, scatter() draws each point independently. This lets you vary marker size and color per point, encoding up to four dimensions in a 2D plot (x, y, size, color). Marker shape is set once per scatter() call, so encoding a category by shape takes one call per group. Use scatter plots when the relationship between points matters more than their ordering.
Visual Encoding Channels
A scatter plot maps data to visual channels — perceptual properties the eye can distinguish:
| Channel | Parameter | Perceptual accuracy |
|---|---|---|
| Position (x, y) | x, y | Most accurate |
| Color | c + cmap | Moderate (use perceptually uniform maps) |
| Size | s | Less accurate (area perception is nonlinear) |
| Shape | marker | Categorical only (no ordering) |
Use at most two or three encodings per plot. Combining all four channels overloads the reader, who cannot decode everything at once.
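As a minimal sketch of the channels in the table, the example below maps two extra variables onto color and size. The data are synthetic, the names `temperature` and `population` are hypothetical placeholders, and `viridis` is picked here as one perceptually uniform colormap.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
n = 60
x = np.random.rand(n)
y = np.random.rand(n)
temperature = np.random.uniform(0, 30, n)    # hypothetical variable shown as color
population = np.random.uniform(10, 200, n)   # hypothetical variable shown as size

fig, ax = plt.subplots()
# c + cmap map values to color; s sets the marker area in points^2
sc = ax.scatter(x, y, c=temperature, cmap='viridis', s=population, alpha=0.8)
fig.colorbar(sc, ax=ax, label='temperature')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
```

The colorbar gives the reader a key for the color channel. Size has no automatic key, which is one reason it is harder to read accurately.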
Correlation is Not Causation
A visible trend or cluster in a scatter plot shows association, not cause
and effect. Adding a regression line (np.polyfit) quantifies the linear
relationship, but a linear fit only makes sense if the relationship is
genuinely linear — always inspect the residuals before interpreting.
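The residual check described above can be sketched as follows. The data are synthetic and linear by construction, so this only illustrates the workflow; with real data, a curve or funnel shape in the right-hand panel would suggest that a straight line is the wrong model.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.rand(80) * 10
y = 2 * x + 1 + np.random.normal(0, 2, 80)   # synthetic, linear by construction

# Degree-1 polyfit returns [slope, intercept]
slope, intercept = np.polyfit(x, y, 1)
residuals = y - (slope * x + intercept)

fig, (ax_fit, ax_res) = plt.subplots(1, 2, figsize=(10, 4))
ax_fit.scatter(x, y, alpha=0.7)
x_line = np.linspace(x.min(), x.max(), 100)
ax_fit.plot(x_line, slope * x_line + intercept, color='red')
ax_fit.set_title('Data and linear fit')

# For a sound linear fit, residuals scatter around zero with no pattern
ax_res.scatter(x, residuals, alpha=0.7)
ax_res.axhline(0, color='gray', linewidth=1)
ax_res.set_title('Residuals')
plt.tight_layout()
plt.show()
```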
Simple Scatter Plot¶
Create a basic scatter plot with ax.scatter().
1. Import and Setup¶
```python
import matplotlib.pyplot as plt
import numpy as np
```
2. Generate Data¶
```python
np.random.seed(42)
x = np.random.rand(50)
y = np.random.rand(50)
```
3. Create Scatter Plot¶
```python
fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Basic Scatter Plot')
plt.show()
```
Correlated Data¶
Visualize relationships between variables.
1. Positive Correlation¶
```python
np.random.seed(42)
x = np.random.rand(100)
y = x + np.random.normal(0, 0.1, 100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title('Positive Correlation')
plt.show()
```
2. Negative Correlation¶
```python
x = np.random.rand(100)
y = 1 - x + np.random.normal(0, 0.1, 100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title('Negative Correlation')
plt.show()
```
3. No Correlation¶
```python
x = np.random.rand(100)
y = np.random.rand(100)

fig, ax = plt.subplots()
ax.scatter(x, y)
ax.set_title('No Correlation')
plt.show()
```
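To put a number on the three patterns above, one option is Pearson's correlation coefficient from `np.corrcoef`. This is a small sketch that regenerates the same kinds of synthetic data against a shared `x`; the variable names `y_pos`, `y_neg`, and `y_none` are introduced only for this illustration.

```python
import numpy as np

np.random.seed(42)
x = np.random.rand(100)
y_pos = x + np.random.normal(0, 0.1, 100)
y_neg = 1 - x + np.random.normal(0, 0.1, 100)
y_none = np.random.rand(100)

# np.corrcoef returns a 2x2 matrix; the off-diagonal entry is Pearson's r
for name, y in [('positive', y_pos), ('negative', y_neg), ('none', y_none)]:
    r = np.corrcoef(x, y)[0, 1]
    print(f'{name:>8}: r = {r:+.2f}')
```

Values near +1 or -1 indicate a strong linear association, and values near 0 indicate little linear association. The coefficient only measures linear trends, so it complements the plot rather than replacing it.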
Multiple Groups¶
Plot multiple data groups on the same axes.
1. Sequential Plotting¶
```python
np.random.seed(42)
x1 = np.random.normal(2, 0.5, 50)
y1 = np.random.normal(2, 0.5, 50)
x2 = np.random.normal(4, 0.5, 50)
y2 = np.random.normal(4, 0.5, 50)

fig, ax = plt.subplots()
ax.scatter(x1, y1, label='Group A')
ax.scatter(x2, y2, label='Group B')
ax.legend()
plt.show()
```
2. Different Colors¶
```python
fig, ax = plt.subplots()
ax.scatter(x1, y1, color='blue', label='Group A')
ax.scatter(x2, y2, color='red', label='Group B')
ax.legend()
plt.show()
```
3. Different Markers¶
```python
fig, ax = plt.subplots()
ax.scatter(x1, y1, marker='o', label='Group A')
ax.scatter(x2, y2, marker='^', label='Group B')
ax.legend()
plt.show()
```
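With more than two groups, the same pattern can be written as a loop. The sketch below assumes a hypothetical `groups` dictionary that maps each label to a cluster center and a marker; it is not part of Matplotlib.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
# Hypothetical groups: label -> (cluster center, marker)
groups = {
    'Group A': (2, 'o'),
    'Group B': (4, '^'),
    'Group C': (6, 's'),
}

fig, ax = plt.subplots()
for name, (center, marker) in groups.items():
    gx = np.random.normal(center, 0.5, 50)
    gy = np.random.normal(center, 0.5, 50)
    # marker applies to the whole call, so each group gets its own scatter()
    ax.scatter(gx, gy, marker=marker, label=name)
ax.legend()
plt.show()
```

Each call also advances the default color cycle, so the groups differ in both color and shape without extra arguments.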
Data Input Types¶
Various ways to provide data to scatter.
1. Lists¶
```python
x = [1, 2, 3, 4, 5]
y = [2, 4, 1, 5, 3]

fig, ax = plt.subplots()
ax.scatter(x, y)
```
2. NumPy Arrays¶
```python
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 1, 5, 3])

fig, ax = plt.subplots()
ax.scatter(x, y)
```
3. Pandas Series¶
```python
import pandas as pd

df = pd.DataFrame({'x': [1, 2, 3, 4, 5], 'y': [2, 4, 1, 5, 3]})

fig, ax = plt.subplots()
ax.scatter(df['x'], df['y'])
```
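Related to the Series example, `scatter` also accepts a `data` keyword, in which case string arguments are looked up as keys in that object. The sketch below uses a plain dict so that it runs without pandas; a DataFrame with the same column names can be passed the same way.

```python
import matplotlib.pyplot as plt

# A plain dict stands in for a DataFrame; the keys 'x' and 'y' are arbitrary
data = {'x': [1, 2, 3, 4, 5], 'y': [2, 4, 1, 5, 3]}

fig, ax = plt.subplots()
# With data=..., the strings 'x' and 'y' are resolved as data['x'] and data['y']
points = ax.scatter('x', 'y', data=data)
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()
```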
scatter vs plot¶
Understanding when to use each method.
1. plot with Markers¶
```python
fig, ax = plt.subplots()
ax.plot(x, y, 'o')  # Circle markers, no line
plt.show()
```
2. scatter Advantages¶
```python
# scatter supports:
# - Individual point sizes (s parameter)
# - Individual point colors (c parameter)
# - Colormaps for continuous color mapping
# - Alpha per point
```
3. Performance Comparison¶
```python
# plot is faster for large datasets with uniform styling
# scatter is preferred when points need individual properties
```
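One way to see where this difference comes from is to inspect what each method returns. The sketch below makes no timing claims. It only shows that `plot` creates a single `Line2D` with one marker style, while `scatter` creates a `PathCollection` that keeps a size for every point.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

np.random.seed(42)
x = np.random.rand(1000)
y = np.random.rand(1000)
sizes = np.random.rand(1000) * 50

fig, (ax_plot, ax_scatter) = plt.subplots(1, 2, figsize=(10, 4))

# plot() returns a list with one Line2D: a single marker style for all points
(line,) = ax_plot.plot(x, y, 'o', markersize=3)

# scatter() returns a PathCollection that stores a size for every point
points = ax_scatter.scatter(x, y, s=sizes)

print(type(line).__name__, type(points).__name__)
print(len(points.get_sizes()))
plt.show()
```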
Adding Trend Lines¶
Overlay regression lines on scatter plots.
1. Linear Fit¶
```python
np.random.seed(42)
x = np.random.rand(50) * 10
y = 2 * x + 1 + np.random.normal(0, 2, 50)

coeffs = np.polyfit(x, y, 1)
trend = np.poly1d(coeffs)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.7)
ax.plot(x, trend(x), color='red', linewidth=2,
        label=f'y = {coeffs[0]:.2f}x + {coeffs[1]:.2f}')
ax.legend()
plt.show()
```
2. Polynomial Fit¶
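```python
coeffs = np.polyfit(x, y, 2)
trend = np.poly1d(coeffs)

# Evaluate the fit on an evenly spaced grid so the curve is smooth
x_line = np.linspace(x.min(), x.max(), 100)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.7)
ax.plot(x_line, trend(x_line), color='red')
plt.show()
```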
3. Sorted Line Data¶
```python
# Sort x for proper line plotting
sort_idx = np.argsort(x)
ax.plot(x[sort_idx], trend(x[sort_idx]), color='red')
```
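Sorting matters because `plot` joins points in array order. The following sketch, on synthetic curved data, draws the same quadratic fit twice: once with `x` in its original random order and once sorted.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.rand(50) * 10                       # x values in random order
y = 0.3 * x**2 - x + np.random.normal(0, 2, 50)   # synthetic curved data
trend = np.poly1d(np.polyfit(x, y, 2))

fig, (ax_raw, ax_sorted) = plt.subplots(1, 2, figsize=(10, 4), sharey=True)

# Unsorted: the line jumps back and forth along the curve
ax_raw.scatter(x, y, alpha=0.5)
ax_raw.plot(x, trend(x), color='red')
ax_raw.set_title('Unsorted x')

# Sorted: the line traces the curve once, left to right
sort_idx = np.argsort(x)
x_sorted = x[sort_idx]
ax_sorted.scatter(x, y, alpha=0.5)
ax_sorted.plot(x_sorted, trend(x_sorted), color='red')
ax_sorted.set_title('Sorted x')
plt.show()
```

For a straight-line fit the unsorted version looks the same, since every segment lies on the same line. The problem only becomes visible with curved fits.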
Practical Example¶
Create a complete scatter plot with annotations.
1. Generate Sample Data¶
```python
np.random.seed(42)
n = 30
x = np.random.rand(n) * 100
y = 0.5 * x + np.random.normal(0, 10, n)
labels = [f'P{i}' for i in range(n)]
```
2. Create Visualization¶
```python
fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(x, y, s=100, alpha=0.7, edgecolors='black')

ax.set_xlabel('Feature X', fontsize=12)
ax.set_ylabel('Feature Y', fontsize=12)
ax.set_title('Scatter Plot with Labels', fontsize=14)
ax.grid(True, alpha=0.3)
```
3. Add Point Labels¶
```python
for i, label in enumerate(labels):
    ax.annotate(label, (x[i], y[i]), textcoords='offset points',
                xytext=(5, 5), fontsize=8)

plt.tight_layout()
plt.show()
```
Scatter Plots as Multidimensional Encoding¶
A scatter plot maps data points into visual space using multiple channels simultaneously:
```text
2D → x, y (position)
3D → + color (c parameter)
4D → + size (s parameter)
5D → + shape (marker parameter, categorical only)
```
This makes scatter plots a flexible way to show several variables in a 2D figure. Each additional channel adds one variable but also increases cognitive load, so use at most two or three encodings per plot for readability.
Unlike histograms or KDE, scatter plots show raw observations without aggregation. Related views:
- Histogram → distribution of one variable
- KDE / density plot → aggregated scatter (smoothed)
- Heatmap (hist2d) → binned scatter (counted)
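To make the contrast between raw and binned views concrete, the sketch below draws the same synthetic data with `scatter` and with `hist2d`. The sample size and bin count are arbitrary choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.normal(0, 1, 5000)
y = 0.5 * x + np.random.normal(0, 1, 5000)

fig, (ax_raw, ax_binned) = plt.subplots(1, 2, figsize=(10, 4))

# Raw observations: every point is drawn, so dense regions overlap
ax_raw.scatter(x, y, s=5, alpha=0.3)
ax_raw.set_title('scatter (raw points)')

# Binned view: each cell's color is the number of points that fall in it
counts, xedges, yedges, image = ax_binned.hist2d(x, y, bins=40, cmap='viridis')
fig.colorbar(image, ax=ax_binned, label='count')
ax_binned.set_title('hist2d (binned counts)')
plt.show()
```

With a few hundred points the scatter view is usually enough. As the point count grows and markers overlap, the binned view tends to show density more clearly.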
Exercises¶
Exercise 1. Write code that generates 200 random points and creates a scatter plot using ax.scatter(). Set alpha=0.6 and add axis labels.
Solution to Exercise 1
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.rand(200)
y = np.random.rand(200)

fig, ax = plt.subplots()
ax.scatter(x, y, alpha=0.6)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('200 Random Points')
plt.show()
```
Setting alpha=0.6 makes the markers partly transparent, so overlapping points appear as darker regions.
Exercise 2. Explain the difference between ax.plot(x, y, 'o') and ax.scatter(x, y). When would you use each?
Solution to Exercise 2
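ax.plot(x, y, 'o') draws every marker with one shared style and, as noted in the performance comparison, is faster for large datasets with uniform styling. ax.scatter(x, y) draws each point independently, so size (s) and color (c) can vary per point and colors can be mapped through a colormap. Use plot when all markers look the same, and use scatter when the markers must encode additional variables.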
Exercise 3. Create a scatter plot where the marker size is proportional to a third variable using the s parameter.
Solution to Exercise 3
```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x = np.random.rand(50)
y = np.random.rand(50)
z = np.random.rand(50)  # third variable

fig, ax = plt.subplots()
# s is the marker area in points^2, so area scales linearly with z
ax.scatter(x, y, s=300 * z, alpha=0.6)
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_title('Marker Size Encodes a Third Variable')
plt.show()
```
The factor 300 is an arbitrary scale chosen so the markers are large enough to compare.
Exercise 4. Write code that creates a scatter plot of two distinct clusters with different colors and adds a legend.
Solution to Exercise 4
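```python
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(42)
x1, y1 = np.random.normal(2, 0.5, 50), np.random.normal(2, 0.5, 50)
x2, y2 = np.random.normal(4, 0.5, 50), np.random.normal(4, 0.5, 50)

fig, ax = plt.subplots()
ax.scatter(x1, y1, color='blue', label='Cluster A')
ax.scatter(x2, y2, color='red', label='Cluster B')
ax.legend()
ax.set_title('Two Clusters')
plt.show()
```
Each cluster gets its own scatter() call with a label, and ax.legend() collects those labels.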