Skip to content

Creating 3D Axes

Create 3D plotting axes in Matplotlib using the projection parameter.

Mental Model

A 3D Axes is a regular Axes with projection='3d' that adds a z-axis and perspective rendering. The same Figure/Axes pattern applies -- you just pass the projection keyword when creating the Axes. Once you have a 3D Axes, plot() and scatter() accept an extra z argument, and 3D-only methods such as plot_surface() take X, Y, and Z arrays.

Method 1: plt.subplots with subplot_kw

Basic 3D Subplot

import matplotlib.pyplot as plt


def main():
    """Create a figure with a single, empty 3D Axes and display it."""
    # subplot_kw is forwarded to the Axes that plt.subplots creates, so
    # projection='3d' makes ``ax`` a 3D Axes instead of a 2D one.
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
    plt.show()


if __name__ == "__main__":
    main()

Multiple 3D Subplots

import matplotlib.pyplot as plt


def main():
    """Create two side-by-side 3D Axes in one figure and display them."""
    # With nrows=1, ncols=2, plt.subplots returns two Axes, which unpack
    # into ax0 and ax1; subplot_kw applies projection='3d' to both.
    fig, (ax0, ax1) = plt.subplots(1, 2, subplot_kw={'projection': '3d'})
    plt.show()


if __name__ == "__main__":
    main()

Grid of 3D Subplots

import matplotlib.pyplot as plt


def main():
    """Create a 2x2 grid of 3D Axes and display it."""
    # ``axes`` is a 2x2 array of 3D Axes (e.g. axes[0, 1] is top-right).
    fig, axes = plt.subplots(2, 2, subplot_kw={'projection': '3d'},
                             figsize=(10, 10))
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()

Method 2: fig.add_subplot

Mixed 2D and 3D

import matplotlib.pyplot as plt


def main():
    """Place a 3D Axes and a 2D Axes side by side in one figure."""
    fig = plt.figure()
    # projection is passed per call, so each cell chooses its own type.
    ax0 = fig.add_subplot(1, 2, 1, projection='3d')  # 3D
    ax1 = fig.add_subplot(1, 2, 2)  # 2D (default projection)
    plt.show()


if __name__ == "__main__":
    main()

Multiple Mixed Subplots

import matplotlib.pyplot as plt


def main():
    """Build a 2x3 grid: three 3D Axes on top, three 2D Axes below."""
    fig = plt.figure(figsize=(12, 8))

    # Row 1: 3D plots (add_subplot numbers cells 1-6 row by row)
    ax1 = fig.add_subplot(2, 3, 1, projection='3d')
    ax2 = fig.add_subplot(2, 3, 2, projection='3d')
    ax3 = fig.add_subplot(2, 3, 3, projection='3d')

    # Row 2: 2D plots (no projection argument -> regular Axes)
    ax4 = fig.add_subplot(2, 3, 4)
    ax5 = fig.add_subplot(2, 3, 5)
    ax6 = fig.add_subplot(2, 3, 6)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()

Comparison

When to Use Each Method

Method Use Case
plt.subplots(subplot_kw={'projection': '3d'}) All subplots are 3D
fig.add_subplot(projection='3d') Mixed 2D and 3D subplots

Example Comparison

import matplotlib.pyplot as plt


def main():
    """Show both Axes-creation methods, each in its own figure."""
    # Method 1: All 3D (subplot_kw)
    fig1, axes1 = plt.subplots(1, 2, subplot_kw={'projection': '3d'},
                               figsize=(10, 4))
    fig1.suptitle('Method 1: subplot_kw (all 3D)')

    # Method 2: Mixed (add_subplot)
    fig2 = plt.figure(figsize=(10, 4))
    ax2_3d = fig2.add_subplot(1, 2, 1, projection='3d')
    ax2_2d = fig2.add_subplot(1, 2, 2)
    fig2.suptitle('Method 2: add_subplot (mixed)')

    # A single plt.show() displays every open figure, i.e. both of these.
    plt.show()


if __name__ == "__main__":
    main()

3D Axes Properties

Axes3D Object

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


def main():
    """Show that projection='3d' produces an Axes3D instance."""
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

    print(f"Type: {type(ax)}")
    # Since Matplotlib 3.2 the Axes3D import is not required merely to
    # register the '3d' projection; it is used here for an explicit type
    # check, which is more direct than probing for a ``zaxis`` attribute.
    print(f"Is 3D: {isinstance(ax, Axes3D)}")


if __name__ == "__main__":
    main()

Available 3D Methods

Method Description
ax.plot3D() 3D line plot
ax.scatter3D() 3D scatter plot
ax.plot_surface() Surface plot
ax.plot_wireframe() Wireframe plot
ax.contour3D() 3D contour plot
ax.bar3d() 3D bar chart

Basic 3D Plotting

3D Line Plot

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Draw a helix (x = sin t, y = cos t, z = t) as a 3D line."""
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

    t = np.linspace(0, 10, 100)
    x = np.sin(t)
    y = np.cos(t)
    z = t

    ax.plot3D(x, y, z)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Helix')

    plt.show()


if __name__ == "__main__":
    main()

3D Scatter Plot

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Scatter random 3D points, coloured by their z value."""
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

    # Seeded generator so the figure is identical from run to run.
    rng = np.random.default_rng(0)
    n = 100
    x = rng.standard_normal(n)
    y = rng.standard_normal(n)
    z = rng.standard_normal(n)

    # c=z maps each point's height onto the colormap.
    ax.scatter3D(x, y, z, c=z, cmap='viridis')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Scatter')

    plt.show()


if __name__ == "__main__":
    main()

3D Surface Plot

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Plot z = sin(sqrt(x^2 + y^2)) as a colour-mapped 3D surface."""
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'})

    x = np.linspace(-5, 5, 50)
    y = np.linspace(-5, 5, 50)
    # plot_surface needs 2D coordinate arrays, so expand the 1D axes.
    X, Y = np.meshgrid(x, y)
    Z = np.sin(np.sqrt(X**2 + Y**2))

    ax.plot_surface(X, Y, Z, cmap='viridis')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title('3D Surface')

    plt.show()


if __name__ == "__main__":
    main()

View Angle

Setting View Angle

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Render the same surface from three (elevation, azimuth) viewpoints."""
    x = np.linspace(-5, 5, 50)
    y = np.linspace(-5, 5, 50)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(np.sqrt(X**2 + Y**2))

    fig, axes = plt.subplots(1, 3, subplot_kw={'projection': '3d'},
                             figsize=(15, 5))

    # (elevation, azimuth) pairs in degrees, one per subplot.
    views = [(30, 45), (60, 45), (30, 135)]

    for ax, (elev, azim) in zip(axes, views):
        ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
        ax.view_init(elev=elev, azim=azim)
        ax.set_title(f'elev={elev}, azim={azim}')

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()

Figure Size

Adjusting 3D Figure Size

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Draw a single 3D surface in a larger-than-default figure."""
    # Single 3D plot with custom size
    fig, ax = plt.subplots(subplot_kw={'projection': '3d'}, figsize=(10, 8))

    x = np.linspace(-5, 5, 50)
    y = np.linspace(-5, 5, 50)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(np.sqrt(X**2 + Y**2))

    ax.plot_surface(X, Y, Z, cmap='coolwarm')
    ax.set_title('3D Surface Plot')

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()

Practical Example

Dashboard with 2D and 3D

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Combine 3D and 2D Axes in a single 2x2 dashboard figure."""
    fig = plt.figure(figsize=(14, 10))

    # 3D Surface (top left)
    ax1 = fig.add_subplot(2, 2, 1, projection='3d')
    x = np.linspace(-5, 5, 50)
    y = np.linspace(-5, 5, 50)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(np.sqrt(X**2 + Y**2))
    ax1.plot_surface(X, Y, Z, cmap='viridis', alpha=0.8)
    ax1.set_title('3D Surface')

    # 3D Scatter (top right); seeded so the figure is reproducible.
    ax2 = fig.add_subplot(2, 2, 2, projection='3d')
    rng = np.random.default_rng(0)
    n = 100
    xs = rng.standard_normal(n)
    ys = rng.standard_normal(n)
    zs = rng.standard_normal(n)
    ax2.scatter3D(xs, ys, zs, c=zs, cmap='plasma')
    ax2.set_title('3D Scatter')

    # 2D Contour (bottom left): the surface's Z seen from directly above.
    ax3 = fig.add_subplot(2, 2, 3)
    ax3.contourf(X, Y, Z, levels=20, cmap='viridis')
    ax3.set_title('2D Contour (Top View)')
    ax3.set_aspect('equal')

    # 2D Line plot (bottom right)
    ax4 = fig.add_subplot(2, 2, 4)
    ax4.plot(x, np.sin(x), label='sin(x)')
    ax4.plot(x, np.cos(x), label='cos(x)')
    ax4.legend()
    ax4.set_title('2D Line Plot')
    ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()

When to Use 3D (and When Not To)

Limitations of Matplotlib 3D

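Matplotlib's 3D support is meant for exploration, not high-end rendering. It draws with the painter's algorithm: each artist is drawn back-to-front, with no true per-pixel depth buffering, so intersecting surfaces can render in the wrong order. It offers only basic shading rather than a full lighting model, and its perspective can obscure structure. For publication, a well-chosen 2D projection is often clearer.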

Use 3D when:

  • Relationships depend on three variables
  • Structure is inherently geometric (curves, surfaces)
  • Spatial intuition is needed (loss landscapes, trajectories)

Avoid 3D when:

  • Data is easily understood in 2D (use a contour plot instead)
  • Perspective hides more than it reveals
  • Readers cannot interact with the plot (static images in papers)

Exercises

Exercise 1. Create a figure with a single 3D axes using fig.add_subplot(111, projection='3d'). Plot a helix defined by \(x = \cos t\), \(y = \sin t\), \(z = t\) for \(t \in [0, 4\pi]\). Add axis labels and a title.

Solution to Exercise 1
import matplotlib.pyplot as plt
import numpy as np

# Two full turns of the helix, sampled densely enough to look smooth.
theta = np.linspace(0, 4 * np.pi, 500)
helix_x, helix_y, helix_z = np.cos(theta), np.sin(theta), theta

fig = plt.figure(figsize=(8, 6))
# 111 = a 1x1 grid, first (only) cell; projection='3d' makes it a 3D Axes.
ax = fig.add_subplot(111, projection='3d')
ax.plot(helix_x, helix_y, helix_z)
ax.set(xlabel='X', ylabel='Y', zlabel='Z', title='Helix')
plt.show()

Exercise 2. Create a 1x3 grid of 3D subplots using plt.subplots(1, 3, subplot_kw={'projection': '3d'}). In each subplot, plot the surface \(z = \sin(\sqrt{x^2 + y^2})\) but from three different viewing angles using view_init. Use elevations of 10, 45, and 80 degrees, all with azimuth 45. Title each subplot with its elevation angle.

Solution to Exercise 2
import matplotlib.pyplot as plt
import numpy as np

# The x and y axes share the same sample points, so build one 1D grid.
grid = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(grid, grid)
Z = np.sin(np.sqrt(X**2 + Y**2))

elevations = (10, 45, 80)
fig, axes = plt.subplots(1, len(elevations), figsize=(15, 5),
                         subplot_kw={'projection': '3d'})

# Same surface in every panel; only the camera elevation changes.
for idx, elevation in enumerate(elevations):
    panel = axes[idx]
    panel.plot_surface(X, Y, Z, cmap='viridis', alpha=0.9)
    panel.view_init(elev=elevation, azim=45)
    panel.set_title(f'Elevation = {elevation}°')

plt.tight_layout()
plt.show()

Exercise 3. Use fig.add_subplot with a 2x2 grid layout to create four 3D axes. In each axes, plot a different 3D surface: a paraboloid (\(z = x^2 + y^2\)), a saddle (\(z = x^2 - y^2\)), a plane (\(z = x + y\)), and a cone (\(z = \sqrt{x^2 + y^2}\)). Use a different colormap for each and add titles.

Solution to Exercise 3
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x, y)

# (height array, title, colormap) for each panel of the 2x2 grid.
surfaces = [
    (X**2 + Y**2, 'Paraboloid', 'viridis'),
    (X**2 - Y**2, 'Saddle', 'coolwarm'),
    (X + Y, 'Plane', 'plasma'),
    (np.sqrt(X**2 + Y**2), 'Cone', 'inferno'),
]

fig = plt.figure(figsize=(12, 10))
# add_subplot numbers grid cells from 1, hence enumerate(..., 1).
for i, (Z, title, cmap) in enumerate(surfaces, 1):
    ax = fig.add_subplot(2, 2, i, projection='3d')
    ax.plot_surface(X, Y, Z, cmap=cmap, alpha=0.9)
    ax.set_title(title)
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

plt.tight_layout()
plt.show()