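scatter_matrix
The pd.plotting.scatter_matrix() function draws a grid of scatter plots showing the pairwise relationships between the numeric columns of a DataFrame; non-numeric columns are ignored. The diagonal cells show the distribution of each variable. The function returns an n×n NumPy array of Matplotlib Axes objects, not a Figure.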
Basic Usage
import pandas as pd
import matplotlib.pyplot as plt
url = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/drinks.csv'
df = pd.read_csv(url)
# Create scatter matrix
pd.plotting.scatter_matrix(df[['beer_servings', 'spirit_servings', 'wine_servings']],
                           figsize=(10, 10))
plt.tight_layout()
plt.show()
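The example above needs network access to fetch the CSV. The following is a minimal offline sketch, assuming a small synthetic DataFrame (the column names `a`, `b`, `c` and `label` are hypothetical). It illustrates two behaviours of the call: the non-numeric `label` column is skipped, and the return value is a NumPy array of Axes whose shape matches the number of numeric columns.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "a": rng.normal(size=100),
    "b": rng.normal(size=100),
    "c": rng.normal(size=100),
    "label": ["x", "y"] * 50,  # non-numeric: ignored by scatter_matrix
})

axes = pd.plotting.scatter_matrix(df, figsize=(6, 6))

print(type(axes).__name__, axes.shape)
fig = axes[0, 0].get_figure()  # the Figure that holds every cell
plt.close(fig)
```

Because the return value is an array of Axes, figure-level operations such as saving or titling go through `axes[0, 0].get_figure()`.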
Understanding the Output
For n numeric columns, the scatter matrix is an n×n grid:
- Off-diagonal cells: a scatter plot of one variable against another (the column's variable on the x-axis, the row's variable on the y-axis)
- Diagonal cells: the distribution of each variable (a histogram by default)
           beer    spirit    wine
        ┌────────┬────────┬────────┐
beer    │  hist  │ scatter│ scatter│
        ├────────┼────────┼────────┤
spirit  │ scatter│  hist  │ scatter│
        ├────────┼────────┼────────┤
wine    │ scatter│ scatter│  hist  │
        └────────┴────────┴────────┘
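The layout can be inspected programmatically. This sketch, using hypothetical columns `a`, `b`, `c` on random data, walks the returned grid and prints each cell's axis labels. It relies on how the current pandas implementation draws the grid (each cell `axes[i, j]` labelled with column `j` on x and column `i` on y, histogram bars in the diagonal cells' `patches`, one scatter collection elsewhere), so treat it as a check rather than a documented contract.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
cols = ["a", "b", "c"]
df = pd.DataFrame(rng.normal(size=(50, 3)), columns=cols)

axes = pd.plotting.scatter_matrix(df)

for i in range(len(cols)):
    for j in range(len(cols)):
        ax = axes[i, j]
        kind = "hist" if i == j else "scatter"
        # x-axis shows column j, y-axis shows column i
        print(f"[{i},{j}] {kind:7} x={ax.get_xlabel()} y={ax.get_ylabel()}")

plt.close(axes[0, 0].get_figure())
```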
Key Parameters
figsize - Figure Size
pd.plotting.scatter_matrix(df, figsize=(12, 12))  # (width, height) in inches
diagonal - Distribution Plot Type
# Histogram (default)
pd.plotting.scatter_matrix(df, diagonal='hist')
# Kernel density estimate (requires SciPy)
pd.plotting.scatter_matrix(df, diagonal='kde')
alpha - Point Transparency
pd.plotting.scatter_matrix(df, alpha=0.5)  # 0 = fully transparent, 1 = opaque
marker - Point Style
pd.plotting.scatter_matrix(df, marker='o')
pd.plotting.scatter_matrix(df, marker='.')
pd.plotting.scatter_matrix(df, marker='+')
s - Point Size (passed through to Matplotlib's Axes.scatter)
pd.plotting.scatter_matrix(df, s=50) # Larger points
pd.plotting.scatter_matrix(df, s=10) # Smaller points
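Since `alpha`, `marker` and extra keywords such as `s` end up in Matplotlib's `Axes.scatter` for the off-diagonal cells, their effect can be read back from the resulting `PathCollection`. A minimal sketch on random data with hypothetical columns `a` and `b`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
df = pd.DataFrame(rng.normal(size=(40, 2)), columns=["a", "b"])

axes = pd.plotting.scatter_matrix(df, alpha=0.3, marker="+", s=50)

points = axes[0, 1].collections[0]  # scatter of b (x) against a (y)
print(points.get_alpha())           # 0.3
print(points.get_sizes())           # marker area in points squared
plt.close(axes[0, 0].get_figure())
```

The diagonal cells are unaffected by these options; they are controlled by `hist_kwds` or `density_kwds` instead.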
hist_kwds - Histogram Customization
pd.plotting.scatter_matrix(
    df,
    diagonal='hist',
    hist_kwds={'bins': 20, 'edgecolor': 'black'}
)
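The `hist_kwds` dictionary is forwarded to `Axes.hist` on the diagonal. As a sanity check, this sketch (random data, hypothetical columns `a` and `b`) compares the bar heights drawn with `bins=20` against `numpy.histogram` of the same column, which uses the same binning rule.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(3)
df = pd.DataFrame(rng.normal(size=(200, 2)), columns=["a", "b"])

axes = pd.plotting.scatter_matrix(df, diagonal="hist",
                                  hist_kwds={"bins": 20, "edgecolor": "black"})

bars = axes[0, 0].patches                      # histogram of column "a"
heights = np.array([bar.get_height() for bar in bars])
expected, _ = np.histogram(df["a"], bins=20)   # counts per bin

print(len(bars), np.array_equal(heights, expected))
plt.close(axes[0, 0].get_figure())
```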
density_kwds - KDE Customization
pd.plotting.scatter_matrix(
    df,
    diagonal='kde',
    density_kwds={'linewidth': 2}
)
ax - Specify Axes Array
# The array must hold exactly n × n Axes for n numeric columns
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
pd.plotting.scatter_matrix(df[['beer_servings', 'spirit_servings', 'wine_servings']], ax=axes)
plt.show()
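Passing `ax` lets you prepare the figure yourself and have pandas draw into those Axes instead of opening a new figure. A sketch with random data and hypothetical columns `a`, `b`, `c`; the check that the returned array reuses the same Axes reflects the current pandas behaviour rather than a documented guarantee.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(4)
df = pd.DataFrame(rng.normal(size=(30, 3)), columns=["a", "b", "c"])

fig, axes = plt.subplots(3, 3, figsize=(8, 8))  # one Axes per cell
fig.suptitle("Pre-built axes")                   # figure-level settings stay yours

result = pd.plotting.scatter_matrix(df, ax=axes)

print(result.shape, result[1, 2] is axes[1, 2])
plt.close(fig)
```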
Practical Example: Iris Dataset
from sklearn.datasets import load_iris
import pandas as pd
import matplotlib.pyplot as plt
# Load iris data
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
# Create scatter matrix (returns a 4×4 array of Axes, not a Figure)
axes = pd.plotting.scatter_matrix(
    df,
    figsize=(12, 12),
    diagonal='kde',  # requires SciPy
    alpha=0.6,
    marker='o',
    s=30
)
plt.suptitle('Iris Dataset - Pairwise Relationships')
plt.tight_layout(rect=(0, 0, 1, 0.97))  # leave room at the top for the suptitle
plt.show()
Practical Example: Financial Data
# Requires the third-party yfinance package and network access
import yfinance as yf
import pandas as pd
import matplotlib.pyplot as plt
# Download daily closing prices for several stocks
tickers = ['AAPL', 'MSFT', 'GOOGL', 'AMZN']
df = yf.download(tickers, start='2023-01-01', end='2024-01-01')['Close']
# Daily percentage returns (the first row is NaN, so drop it)
returns = df.pct_change().dropna()
# Scatter matrix of returns
pd.plotting.scatter_matrix(
    returns,
    figsize=(10, 10),
    diagonal='kde',  # requires SciPy
    alpha=0.3,
    marker='.'
)
plt.suptitle('Stock Return Correlations')
plt.tight_layout(rect=(0, 0, 1, 0.97))  # leave room at the top for the suptitle
plt.show()
Practical Example: Drinks Dataset
import pandas as pd
import matplotlib.pyplot as plt
url = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/drinks.csv'
df = pd.read_csv(url)
# Select numeric columns for alcohol consumption
alcohol_cols = ['beer_servings', 'spirit_servings', 'wine_servings', 'total_litres_of_pure_alcohol']
# scatter_matrix creates its own figure, so set its size here
# (a separate plt.subplots() call would only produce an extra, empty figure)
pd.plotting.scatter_matrix(
    df[alcohol_cols],
    figsize=(12, 12),
    diagonal='hist',
    alpha=0.5,
    hist_kwds={'bins': 15, 'edgecolor': 'black'}
)
plt.tight_layout()
plt.show()
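To save such a plot, the Figure has to be fetched from one of the returned Axes. A sketch that uses random gamma-distributed stand-in values under the drinks column names (not the real dataset) and a hypothetical output file in a temporary directory:

```python
import os
import tempfile
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(5)
df = pd.DataFrame(rng.gamma(2.0, 50.0, size=(60, 3)),
                  columns=["beer_servings", "spirit_servings", "wine_servings"])

axes = pd.plotting.scatter_matrix(df, figsize=(8, 8), alpha=0.5,
                                  hist_kwds={"bins": 15, "edgecolor": "black"})
fig = axes[0, 0].get_figure()  # every cell shares this Figure

out_path = os.path.join(tempfile.mkdtemp(), "scatter_matrix.png")
fig.savefig(out_path, dpi=100, bbox_inches="tight")
plt.close(fig)
print(out_path, os.path.getsize(out_path) > 0)
```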
Color by Category
To color both the scatter points and the diagonal histograms by a categorical variable, build the grid directly with Matplotlib:
from sklearn.datasets import load_iris
import pandas as pd
import matplotlib.pyplot as plt
iris = load_iris()
columns = ['sepal_l', 'sepal_w', 'petal_l', 'petal_w']
df = pd.DataFrame(iris.data, columns=columns)
df['species'] = iris.target  # integer codes 0, 1, 2
# Create an n×n grid of axes
n = len(columns)
fig, axes = plt.subplots(n, n, figsize=(12, 12))
colors = ['red', 'green', 'blue']
for i, col1 in enumerate(columns):      # row variable (y-axis)
    for j, col2 in enumerate(columns):  # column variable (x-axis)
        ax = axes[i, j]
        if i == j:
            # Diagonal: overlaid histograms, one per species
            for species in range(3):
                mask = df['species'] == species
                ax.hist(df.loc[mask, col1], alpha=0.5, color=colors[species])
        else:
            # Off-diagonal: scatter of col2 (x) against col1 (y)
            for species in range(3):
                mask = df['species'] == species
                ax.scatter(df.loc[mask, col2], df.loc[mask, col1],
                           alpha=0.5, color=colors[species], s=10)
        # Label only the outer column and row of the grid
        if j == 0:
            ax.set_ylabel(col1)
        if i == n - 1:
            ax.set_xlabel(col2)
plt.tight_layout()
plt.show()
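A lighter-weight option, when only the scatter cells need coloring: extra keywords are forwarded to `Axes.scatter`, so a numeric `c` array of class codes plus a `cmap` colors the off-diagonal points. The diagonal histograms stay single-colored, and pandas does not filter `c` for missing values, so this sketch assumes a frame without NaNs. The data and the three-class `codes` array are synthetic and hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

rng = np.random.default_rng(6)
codes = np.repeat([0, 1, 2], 20)                             # three hypothetical classes
features = rng.normal(size=(60, 2)) + codes[:, None] * 3.0  # shift each class apart
df = pd.DataFrame(features, columns=["x1", "x2"])            # codes kept out of the frame

axes = pd.plotting.scatter_matrix(df, c=codes, cmap="viridis",
                                  figsize=(6, 6), marker="o", alpha=0.8)

colored = axes[0, 1].collections[0]  # values mapped through the colormap
print(np.array_equal(np.asarray(colored.get_array()), codes))
plt.close(axes[0, 0].get_figure())
```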
Interpreting Scatter Matrices
What to Look For
- Linear relationships: points that fall along a straight line indicate a linear correlation
- Clusters: separate groups of points may correspond to categories in the data
- Outliers: isolated points far from the main cloud
- Distribution shape: the diagonal plots show whether each variable is roughly symmetric, skewed, or multimodal
Correlation Patterns
| Pattern | Interpretation |
|---|---|
| Points rising from lower left to upper right (↗) | Positive correlation |
| Points falling from upper left to lower right (↘) | Negative correlation |
| Round, shapeless cloud | No linear correlation |
| Distinct clusters | Possible categorical grouping |
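These visual patterns can be cross-checked numerically with `DataFrame.corr()`, which computes Pearson correlations by default. A sketch on synthetic columns constructed to rise, fall, or be unrelated to `x` (all column names hypothetical):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(7)
x = rng.normal(size=500)
df = pd.DataFrame({
    "x": x,
    "rises": 2 * x + rng.normal(scale=0.5, size=500),   # ↗ pattern
    "falls": -2 * x + rng.normal(scale=0.5, size=500),  # ↘ pattern
    "cloud": rng.normal(size=500),                      # round, shapeless cloud
})

corr = df.corr()
print(corr.loc["x"].round(2))
```

A coefficient near zero only rules out a linear trend; a curved pattern in the scatter cell can still coexist with a small Pearson value.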
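Method Signature
pandas.plotting.scatter_matrix(
    frame,               # DataFrame; only numeric columns are plotted
    alpha=0.5,           # Point transparency
    figsize=None,        # Figure size (width, height) in inches
    ax=None,             # Axes to draw into (e.g. an n×n array)
    grid=False,          # Whether to draw grid lines
    diagonal='hist',     # 'hist' or 'kde'
    marker='.',          # Point marker
    density_kwds=None,   # Keyword arguments for the KDE plot
    hist_kwds=None,      # Keyword arguments for Axes.hist
    range_padding=0.05,  # Fraction of each column's range added as axis padding
    **kwargs             # Other keyword arguments passed to Axes.scatter
)
# Returns: numpy.ndarray of Matplotlib Axes with shape (n, n)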
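In the current pandas implementation, `range_padding` is multiplied by each column's data range and split evenly between the two ends of that column's axis. The sketch below checks this on hypothetical columns `a` (0 to 10) and `b` (-5 to 5) with `range_padding=0.2`, so one unit of padding should appear on each side of `a`'s axis; treat the exact formula as an implementation detail.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({"a": np.linspace(0, 10, 11), "b": np.linspace(-5, 5, 11)})

axes = pd.plotting.scatter_matrix(df, range_padding=0.2)

cell = axes[1, 0]  # x = column "a", y = column "b"
# "a": range 10 -> 0.2 * 10 / 2 = 1 unit each side; "b": same range, same padding
print(cell.get_xlim(), cell.get_ylim())
plt.close(axes[0, 0].get_figure())
```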
Summary
# Basic scatter matrix
pd.plotting.scatter_matrix(df)
# With KDE on diagonal
pd.plotting.scatter_matrix(df, diagonal='kde')
# Customized (col1, col2, col3 stand for your own column names)
pd.plotting.scatter_matrix(
    df[['col1', 'col2', 'col3']],
    figsize=(10, 10),
    diagonal='hist',
    alpha=0.5,
    marker='o',
    hist_kwds={'bins': 20}
)
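Alternatives
For pairwise plots that color every cell, including the diagonal, by a categorical column, seaborn's pairplot is a common alternative:
# seaborn is a separate package (pip install seaborn)
import seaborn as sns
sns.pairplot(df, hue='category_column')  # 'category_column' is a placeholder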