Splitting Arrays
NumPy provides functions to split arrays into multiple sub-arrays.
Mental Model
Splitting is the reverse of concatenation: it divides one array into a list of smaller arrays along a specified axis. `np.split` requires either a section count that divides the axis length evenly or explicit cut points, while `np.array_split` accepts any section count and gives the leading chunks one extra element each when the division leaves a remainder.
Core Concept
Splitting partitions data along an axis. If you concatenate arrays A, B, and C along axis 0, splitting the result along axis 0 at the original boundaries gives you A, B, and C back.
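The round trip described above can be checked directly. This is a minimal sketch, assuming three small 1D arrays of different lengths; the helper name `round_trip` is illustrative and not part of NumPy. The cut points are the running totals of the piece lengths, excluding the final total.

```python
import numpy as np

def round_trip(pieces):
    """Concatenate 1D arrays, then split the result back at the original boundaries."""
    joined = np.concatenate(pieces)
    # Cut points: cumulative lengths of all pieces except the last one
    cuts = np.cumsum([len(p) for p in pieces])[:-1]
    return joined, np.split(joined, cuts)

def main():
    A, B, C = np.array([1, 2]), np.array([3, 4, 5]), np.array([6])
    joined, parts = round_trip([A, B, C])
    print(f"Joined: {joined}")
    for original, recovered in zip([A, B, C], parts):
        print(f"{original} -> {recovered}, equal: {np.array_equal(original, recovered)}")

if __name__ == "__main__":
    main()
```

Because the pieces have different lengths here, an integer section count would not recover them; explicit cut points are what make the inverse exact.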
np.split
1. Equal Parts
```python
import numpy as np

def main():
    a = np.arange(12)
    # Split into 3 equal parts
    parts = np.split(a, 3)
    print(f"Original: {a}")
    print(f"Number of parts: {len(parts)}")
    for i, part in enumerate(parts):
        print(f"Part {i}: {part}")

if __name__ == "__main__":
    main()
```
Output:
Original: [ 0  1  2  3  4  5  6  7  8  9 10 11]
Number of parts: 3
Part 0: [0 1 2 3]
Part 1: [4 5 6 7]
Part 2: [ 8  9 10 11]
2. Split at Indices
```python
import numpy as np

def main():
    a = np.arange(10)
    # Split at indices 3 and 7
    parts = np.split(a, [3, 7])
    print(f"Original: {a}")
    print("Split at [3, 7]:")
    for i, part in enumerate(parts):
        print(f"  Part {i}: {part}")

if __name__ == "__main__":
    main()
```
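The indices act as slice boundaries: splitting at `[3, 7]` yields `a[:3]`, `a[3:7]`, and `a[7:]`. The sketch below makes that explicit for the 1D case by rebuilding the result with plain slicing. `split_by_slices` is a hypothetical helper written for illustration, not a NumPy function.

```python
import numpy as np

def split_by_slices(a, indices):
    """Reproduce np.split(a, indices) for a 1D array using plain slicing."""
    bounds = [0, *indices, len(a)]
    return [a[start:stop] for start, stop in zip(bounds[:-1], bounds[1:])]

def main():
    a = np.arange(10)
    for got, expected in zip(np.split(a, [3, 7]), split_by_slices(a, [3, 7])):
        print(f"{got} matches slice: {np.array_equal(got, expected)}")

if __name__ == "__main__":
    main()
```

Two cut points always produce three parts; in general, k cut points produce k + 1 parts.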
3. 2D Array Split
```python
import numpy as np

def main():
    a = np.arange(12).reshape(3, 4)
    print("Original:")
    print(a)
    print()
    # Split along axis=0 (rows)
    parts = np.split(a, 3, axis=0)
    print("Split along axis=0:")
    for i, part in enumerate(parts):
        print(f"Part {i}: {part}")

if __name__ == "__main__":
    main()
```
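One detail the examples above do not show: NumPy documents that `np.split` returns views into the original array rather than copies, so writing to a part also changes the source. The following sketch illustrates this under that assumption; `parts_are_views` is an illustrative helper name.

```python
import numpy as np

def parts_are_views(a, parts):
    """Return True if every part shares memory with the source array."""
    return all(np.shares_memory(p, a) for p in parts)

def main():
    a = np.arange(12).reshape(3, 4)
    parts = np.split(a, 3, axis=0)
    print(f"Parts are views: {parts_are_views(a, parts)}")
    parts[0][0, 0] = -1  # writes through to `a`
    print(f"a[0, 0] after the write: {a[0, 0]}")
    # Copy the parts when they must not alias the source
    copies = [p.copy() for p in parts]
    print(f"Copies are views: {parts_are_views(a, copies)}")

if __name__ == "__main__":
    main()
```

Views make splitting cheap even for large arrays, but call `.copy()` on a part before modifying it if the original must stay intact.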
np.array_split
1. Unequal Splits
```python
import numpy as np

def main():
    a = np.arange(10)
    # np.split would fail (10 not divisible by 3)
    # np.array_split handles unequal splits
    parts = np.array_split(a, 3)
    print(f"Original: {a}")
    print("Split into 3 parts:")
    for i, part in enumerate(parts):
        print(f"  Part {i}: {part} (length {len(part)})")

if __name__ == "__main__":
    main()
```
2. Comparison with split
```python
import numpy as np

def main():
    a = np.arange(10)
    # np.split requires equal division
    try:
        parts = np.split(a, 3)
    except ValueError as e:
        print(f"np.split error: {e}")
    # np.array_split allows unequal
    parts = np.array_split(a, 3)
    print(f"np.array_split works: {[len(p) for p in parts]}")

if __name__ == "__main__":
    main()
```
3. Distribution of Elements
```python
import numpy as np

def main():
    a = np.arange(17)
    for n in [2, 3, 4, 5]:
        parts = np.array_split(a, n)
        sizes = [len(p) for p in parts]
        print(f"Split 17 into {n}: sizes = {sizes}")

if __name__ == "__main__":
    main()
```
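The sizes printed by the loop above follow the rule documented for `np.array_split`: for an array of length l split into n sections, the first l % n parts have l // n + 1 elements and the remaining parts have l // n. A minimal sketch that predicts the sizes and compares them with the actual result; `expected_sizes` is a hypothetical helper name.

```python
import numpy as np

def expected_sizes(length, n):
    """Predict the part sizes np.array_split gives for a 1D array of `length` and `n` sections."""
    base, extra = divmod(length, n)
    # The first `extra` parts get one additional element
    return [base + 1] * extra + [base] * (n - extra)

def main():
    a = np.arange(17)
    for n in [2, 3, 4, 5]:
        actual = [len(p) for p in np.array_split(a, n)]
        print(f"n={n}: predicted {expected_sizes(17, n)}, actual {actual}")

if __name__ == "__main__":
    main()
```

For 17 elements and 5 sections, `divmod(17, 5)` is `(3, 2)`, so the sizes are `[4, 4, 3, 3, 3]`.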
np.hsplit
1. Horizontal Split
```python
import numpy as np

def main():
    a = np.arange(12).reshape(3, 4)
    print("Original:")
    print(a)
    print()
    # Split into 2 parts horizontally
    left, right = np.hsplit(a, 2)
    print("Left:")
    print(left)
    print()
    print("Right:")
    print(right)

if __name__ == "__main__":
    main()
```
2. Split at Column Indices
```python
import numpy as np

def main():
    a = np.arange(20).reshape(4, 5)
    print("Original:")
    print(a)
    print()
    # Split at columns 1 and 3
    parts = np.hsplit(a, [1, 3])
    for i, part in enumerate(parts):
        print(f"Part {i}:")
        print(part)
        print()

if __name__ == "__main__":
    main()
```
3. Equivalent to split axis=1
```python
import numpy as np

def main():
    a = np.arange(12).reshape(3, 4)
    parts1 = np.hsplit(a, 2)
    parts2 = np.split(a, 2, axis=1)
    print(f"hsplit equal to split axis=1: {np.array_equal(parts1[0], parts2[0])}")

if __name__ == "__main__":
    main()
```
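The check above compares only the first part. A slightly fuller sketch compares every part, and also covers 1D input, where NumPy documents that `np.hsplit` splits along axis 0 because there is no second axis. `hsplit_matches_split` is an illustrative helper, not a NumPy function.

```python
import numpy as np

def hsplit_matches_split(a, sections):
    """Compare every part of np.hsplit with np.split along the axis hsplit uses."""
    axis = 0 if a.ndim == 1 else 1  # hsplit falls back to axis 0 for 1D input
    pairs = zip(np.hsplit(a, sections), np.split(a, sections, axis=axis))
    return all(np.array_equal(h, s) for h, s in pairs)

def main():
    print(f"2D input: {hsplit_matches_split(np.arange(12).reshape(3, 4), 2)}")
    print(f"1D input: {hsplit_matches_split(np.arange(6), 2)}")
    print(f"1D parts: {np.hsplit(np.arange(6), 2)}")

if __name__ == "__main__":
    main()
```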
np.vsplit
1. Vertical Split
```python
import numpy as np

def main():
    a = np.arange(12).reshape(4, 3)
    print("Original:")
    print(a)
    print()
    # Split into 2 parts vertically
    top, bottom = np.vsplit(a, 2)
    print("Top:")
    print(top)
    print()
    print("Bottom:")
    print(bottom)

if __name__ == "__main__":
    main()
```
2. Split at Row Indices
```python
import numpy as np

def main():
    a = np.arange(20).reshape(5, 4)
    print("Original:")
    print(a)
    print()
    # Split at rows 1 and 3
    parts = np.vsplit(a, [1, 3])
    for i, part in enumerate(parts):
        print(f"Part {i} (shape {part.shape}):")
        print(part)
        print()

if __name__ == "__main__":
    main()
```
3. Equivalent to split axis=0
```python
import numpy as np

def main():
    a = np.arange(12).reshape(4, 3)
    parts1 = np.vsplit(a, 2)
    parts2 = np.split(a, 2, axis=0)
    print(f"vsplit equal to split axis=0: {np.array_equal(parts1[0], parts2[0])}")

if __name__ == "__main__":
    main()
```
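The equivalence with `np.split(..., axis=0)` holds for arrays with two or more dimensions. Unlike `np.split`, `np.vsplit` rejects 1D input with a `ValueError`. The sketch below shows both cases; `try_vsplit` is a hypothetical wrapper added only for illustration.

```python
import numpy as np

def try_vsplit(a, sections):
    """Return the part shapes from np.vsplit, or the error message if it fails."""
    try:
        return [p.shape for p in np.vsplit(a, sections)]
    except ValueError as e:
        return str(e)

def main():
    print(f"2D (4, 3): {try_vsplit(np.arange(12).reshape(4, 3), 2)}")
    print(f"3D (4, 3, 2): {try_vsplit(np.arange(24).reshape(4, 3, 2), 2)}")
    print(f"1D (6,): {try_vsplit(np.arange(6), 2)}")
    # np.split on the same 1D array works because axis 0 exists
    print(f"np.split on 1D: {[p.shape for p in np.split(np.arange(6), 2)]}")

if __name__ == "__main__":
    main()
```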
np.dsplit
1. Depth Split
```python
import numpy as np

def main():
    a = np.arange(24).reshape(2, 3, 4)
    print(f"Original shape: {a.shape}")
    # Split along axis=2
    parts = np.dsplit(a, 2)
    print(f"Number of parts: {len(parts)}")
    for i, part in enumerate(parts):
        print(f"Part {i} shape: {part.shape}")

if __name__ == "__main__":
    main()
```
2. Split RGB Channels
```python
import numpy as np

def main():
    # Simulated RGB image
    image = np.random.randint(0, 256, (100, 100, 3))
    print(f"Image shape: {image.shape}")
    # Split into channels
    r, g, b = np.dsplit(image, 3)
    print(f"Red channel shape: {r.shape}")
    # Squeeze to get 2D
    r_2d = r.squeeze()
    print(f"Squeezed shape: {r_2d.shape}")

if __name__ == "__main__":
    main()
```
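Two follow-ups to the channel split may be useful. First, `np.dstack` reverses `np.dsplit`, so the channels can be reassembled into the original image. Second, each channel keeps a trailing axis of length 1; passing `axis=2` to `squeeze` removes only that axis, whereas a bare `squeeze()` would also drop a height or width of 1. This is a sketch on a small simulated image, and `split_and_restack` is an illustrative helper name.

```python
import numpy as np

def split_and_restack(image):
    """Split an (H, W, C) array into C channels, then rebuild it with np.dstack."""
    channels = np.dsplit(image, image.shape[2])
    return channels, np.dstack(channels)

def main():
    rng = np.random.default_rng(0)
    image = rng.integers(0, 256, size=(100, 100, 3))  # simulated RGB image
    channels, rebuilt = split_and_restack(image)
    print(f"Channel shapes: {[c.shape for c in channels]}")
    print(f"Rebuilt equals original: {np.array_equal(rebuilt, image)}")
    # Remove only the channel axis; indexing image[:, :, 0] gives the same 2D array
    red_2d = channels[0].squeeze(axis=2)
    print(f"Red equals image[:, :, 0]: {np.array_equal(red_2d, image[:, :, 0])}")

if __name__ == "__main__":
    main()
```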
Applications
1. Train/Test Split
```python
import numpy as np

def main():
    np.random.seed(42)
    # Dataset
    X = np.random.randn(100, 5)
    y = np.random.randint(0, 2, 100)
    # 80/20 split
    split_idx = int(0.8 * len(X))
    X_train, X_test = np.split(X, [split_idx])
    y_train, y_test = np.split(y, [split_idx])
    print(f"X_train shape: {X_train.shape}")
    print(f"X_test shape: {X_test.shape}")

if __name__ == "__main__":
    main()
```
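The split above takes the first 80% of rows in stored order, which can skew the result when the data is sorted (for example by label). One possible variation, assuming shuffling is appropriate for the data (it usually is not for time series), applies a single shared permutation to `X` and `y` before cutting. The function name `shuffled_split` and its defaults are assumptions made for this sketch.

```python
import numpy as np

def shuffled_split(X, y, train_fraction=0.8, seed=42):
    """Shuffle rows with one shared permutation, then cut once with np.split."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(X))  # same order for X and y keeps rows aligned
    cut = int(train_fraction * len(X))
    X_train, X_test = np.split(X[order], [cut])
    y_train, y_test = np.split(y[order], [cut])
    return X_train, X_test, y_train, y_test

def main():
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 5))
    y = rng.integers(0, 2, size=100)
    X_train, X_test, y_train, y_test = shuffled_split(X, y)
    print(f"X_train: {X_train.shape}, X_test: {X_test.shape}")
    print(f"y_train: {y_train.shape}, y_test: {y_test.shape}")

if __name__ == "__main__":
    main()
```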
2. Batch Processing
```python
import numpy as np

def main():
    data = np.arange(1000)
    batch_size = 32
    n_batches = len(data) // batch_size
    batches = np.split(data[:n_batches * batch_size], n_batches)
    print(f"Number of batches: {len(batches)}")
    print(f"Batch shape: {batches[0].shape}")

if __name__ == "__main__":
    main()
```
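The version above discards the last 1000 - 31 * 32 = 8 elements so that every batch has the same size. If the leftover elements should be kept as a shorter final batch, one option is to pass explicit cut points at every multiple of the batch size. A minimal sketch; `batches_with_remainder` is a hypothetical helper name.

```python
import numpy as np

def batches_with_remainder(data, batch_size):
    """Cut `data` every `batch_size` elements; the last batch holds any leftover elements."""
    cut_points = np.arange(batch_size, len(data), batch_size)
    return np.split(data, cut_points)

def main():
    data = np.arange(1000)
    batches = batches_with_remainder(data, 32)
    print(f"Number of batches: {len(batches)}")
    print(f"First batch size: {len(batches[0])}")
    print(f"Last batch size: {len(batches[-1])}")

if __name__ == "__main__":
    main()
```

With 1000 elements and a batch size of 32 this gives 31 full batches plus a final batch of 8.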
3. K-Fold Cross-Validation
```python
import numpy as np

def main():
    data = np.arange(100)
    n_folds = 5
    folds = np.array_split(data, n_folds)
    for i in range(n_folds):
        test = folds[i]
        train = np.concatenate([folds[j] for j in range(n_folds) if j != i])
        print(f"Fold {i}: train={len(train)}, test={len(test)}")

if __name__ == "__main__":
    main()
```
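In practice the folds are often built over shuffled row indices rather than over the data itself, so the same indices can select rows from both features and labels. The sketch below is one way to do that, assuming at least two folds; `kfold_indices` and its `seed` parameter are illustrative names, not a NumPy API. It uses 103 samples to show how `np.array_split` copes when the sample count is not divisible by the number of folds.

```python
import numpy as np

def kfold_indices(n_samples, n_folds, seed=0):
    """Yield (train_idx, test_idx) pairs built from a shuffled index array."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n_samples), n_folds)
    for i, test_idx in enumerate(folds):
        # Every fold except the i-th one forms the training set
        train_idx = np.concatenate(folds[:i] + folds[i + 1:])
        yield train_idx, test_idx

def main():
    X = np.arange(103 * 2).reshape(103, 2)
    for i, (train_idx, test_idx) in enumerate(kfold_indices(len(X), 5)):
        print(f"Fold {i}: train={X[train_idx].shape}, test={X[test_idx].shape}")

if __name__ == "__main__":
    main()
```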
Summary Table
1. Split Functions
| Function | Description |
|---|---|
| `np.split(a, n)` | Split into n equal parts |
| `np.split(a, [i, j])` | Split at indices i, j |
| `np.array_split(a, n)` | Split into n parts (unequal OK) |
| `np.hsplit(a, n)` | Split horizontally (axis=1) |
| `np.vsplit(a, n)` | Split vertically (axis=0) |
| `np.dsplit(a, n)` | Split depth-wise (axis=2) |
2. Equivalences
| Function | Equivalent |
|---|---|
| `np.hsplit(a, n)` | `np.split(a, n, axis=1)` |
| `np.vsplit(a, n)` | `np.split(a, n, axis=0)` |
| `np.dsplit(a, n)` | `np.split(a, n, axis=2)` |
Exercises
Exercise 1.
Create a = np.arange(12). Split it into 3 equal parts using np.split and into parts of sizes 2, 5, 5 using np.split with explicit indices.
Solution to Exercise 1
import numpy as np
a = np.arange(12)
parts_equal = np.split(a, 3)
for i, p in enumerate(parts_equal):
    print(f"Part {i}: {p}")
parts_custom = np.split(a, [2, 7])
for i, p in enumerate(parts_custom):
    print(f"Custom part {i}: {p}")
Exercise 2.
Create a 4x6 matrix. Use np.hsplit to split it into 3 equal column blocks and np.vsplit to split it into 2 equal row blocks. Print the shapes of each resulting block.
Solution to Exercise 2
import numpy as np
M = np.arange(24).reshape(4, 6)
col_blocks = np.hsplit(M, 3)
for i, b in enumerate(col_blocks):
    print(f"Col block {i}: shape {b.shape}")
row_blocks = np.vsplit(M, 2)
for i, b in enumerate(row_blocks):
    print(f"Row block {i}: shape {b.shape}")
Exercise 3.
Create a = np.arange(10). Use np.array_split to split it into 3 parts (which handles uneven division). Print the lengths of each part and verify they sum to 10.
Solution to Exercise 3
import numpy as np
a = np.arange(10)
parts = np.array_split(a, 3)
lengths = [len(p) for p in parts]
print(f"Lengths: {lengths}")  # [4, 3, 3]
print(f"Sum: {sum(lengths)}")  # 10