Splitting Arrays¶
NumPy provides functions to split arrays into multiple sub-arrays.
np.split¶
1. Equal Parts¶
import numpy as np
def main():
a = np.arange(12)
# Split into 3 equal parts
parts = np.split(a, 3)
print(f"Original: {a}")
print(f"Number of parts: {len(parts)}")
for i, part in enumerate(parts):
print(f"Part {i}: {part}")
if __name__ == "__main__":
main()
Output:
Original: [ 0 1 2 3 4 5 6 7 8 9 10 11]
Number of parts: 3
Part 0: [0 1 2 3]
Part 1: [4 5 6 7]
Part 2: [ 8 9 10 11]
2. Split at Indices¶
import numpy as np
def main():
a = np.arange(10)
# Split at indices 3 and 7
parts = np.split(a, [3, 7])
print(f"Original: {a}")
print(f"Split at [3, 7]:")
for i, part in enumerate(parts):
print(f" Part {i}: {part}")
if __name__ == "__main__":
main()
3. 2D Array Split¶
import numpy as np
def main():
a = np.arange(12).reshape(3, 4)
print("Original:")
print(a)
print()
# Split along axis=0 (rows)
parts = np.split(a, 3, axis=0)
print("Split along axis=0:")
for i, part in enumerate(parts):
print(f"Part {i}: {part}")
if __name__ == "__main__":
main()
np.array_split¶
1. Unequal Splits¶
import numpy as np
def main():
a = np.arange(10)
# np.split would fail (10 not divisible by 3)
# np.array_split handles unequal splits
parts = np.array_split(a, 3)
print(f"Original: {a}")
print(f"Split into 3 parts:")
for i, part in enumerate(parts):
print(f" Part {i}: {part} (length {len(part)})")
if __name__ == "__main__":
main()
2. Comparison with split¶
import numpy as np
def main():
a = np.arange(10)
# np.split requires equal division
try:
parts = np.split(a, 3)
except ValueError as e:
print(f"np.split error: {e}")
# np.array_split allows unequal
parts = np.array_split(a, 3)
print(f"np.array_split works: {[len(p) for p in parts]}")
if __name__ == "__main__":
main()
3. Distribution of Elements¶
import numpy as np
def main():
a = np.arange(17)
for n in [2, 3, 4, 5]:
parts = np.array_split(a, n)
sizes = [len(p) for p in parts]
print(f"Split 17 into {n}: sizes = {sizes}")
if __name__ == "__main__":
main()
np.hsplit¶
1. Horizontal Split¶
import numpy as np
def main():
a = np.arange(12).reshape(3, 4)
print("Original:")
print(a)
print()
# Split into 2 parts horizontally
left, right = np.hsplit(a, 2)
print("Left:")
print(left)
print()
print("Right:")
print(right)
if __name__ == "__main__":
main()
2. Split at Column Indices¶
import numpy as np
def main():
a = np.arange(20).reshape(4, 5)
print("Original:")
print(a)
print()
# Split at columns 1 and 3
parts = np.hsplit(a, [1, 3])
for i, part in enumerate(parts):
print(f"Part {i}:")
print(part)
print()
if __name__ == "__main__":
main()
3. Equivalent to split axis=1¶
import numpy as np
def main():
a = np.arange(12).reshape(3, 4)
parts1 = np.hsplit(a, 2)
parts2 = np.split(a, 2, axis=1)
print(f"hsplit equal to split axis=1: {np.array_equal(parts1[0], parts2[0])}")
if __name__ == "__main__":
main()
np.vsplit¶
1. Vertical Split¶
import numpy as np
def main():
a = np.arange(12).reshape(4, 3)
print("Original:")
print(a)
print()
# Split into 2 parts vertically
top, bottom = np.vsplit(a, 2)
print("Top:")
print(top)
print()
print("Bottom:")
print(bottom)
if __name__ == "__main__":
main()
2. Split at Row Indices¶
import numpy as np
def main():
a = np.arange(20).reshape(5, 4)
print("Original:")
print(a)
print()
# Split at rows 1 and 3
parts = np.vsplit(a, [1, 3])
for i, part in enumerate(parts):
print(f"Part {i} (shape {part.shape}):")
print(part)
print()
if __name__ == "__main__":
main()
3. Equivalent to split axis=0¶
import numpy as np
def main():
a = np.arange(12).reshape(4, 3)
parts1 = np.vsplit(a, 2)
parts2 = np.split(a, 2, axis=0)
print(f"vsplit equal to split axis=0: {np.array_equal(parts1[0], parts2[0])}")
if __name__ == "__main__":
main()
np.dsplit¶
1. Depth Split¶
import numpy as np
def main():
a = np.arange(24).reshape(2, 3, 4)
print(f"Original shape: {a.shape}")
# Split along axis=2
parts = np.dsplit(a, 2)
print(f"Number of parts: {len(parts)}")
for i, part in enumerate(parts):
print(f"Part {i} shape: {part.shape}")
if __name__ == "__main__":
main()
2. Split RGB Channels¶
import numpy as np
def main():
# Simulated RGB image
image = np.random.randint(0, 256, (100, 100, 3))
print(f"Image shape: {image.shape}")
# Split into channels
r, g, b = np.dsplit(image, 3)
print(f"Red channel shape: {r.shape}")
# Squeeze to get 2D
r_2d = r.squeeze()
print(f"Squeezed shape: {r_2d.shape}")
if __name__ == "__main__":
main()
Applications¶
1. Train/Test Split¶
import numpy as np
def main():
np.random.seed(42)
# Dataset
X = np.random.randn(100, 5)
y = np.random.randint(0, 2, 100)
# 80/20 split
split_idx = int(0.8 * len(X))
X_train, X_test = np.split(X, [split_idx])
y_train, y_test = np.split(y, [split_idx])
print(f"X_train shape: {X_train.shape}")
print(f"X_test shape: {X_test.shape}")
if __name__ == "__main__":
main()
2. Batch Processing¶
import numpy as np
def main():
data = np.arange(1000)
batch_size = 32
n_batches = len(data) // batch_size
batches = np.split(data[:n_batches * batch_size], n_batches)
print(f"Number of batches: {len(batches)}")
print(f"Batch shape: {batches[0].shape}")
if __name__ == "__main__":
main()
3. K-Fold Cross-Validation¶
import numpy as np
def main():
data = np.arange(100)
n_folds = 5
folds = np.array_split(data, n_folds)
for i in range(n_folds):
test = folds[i]
train = np.concatenate([folds[j] for j in range(n_folds) if j != i])
print(f"Fold {i}: train={len(train)}, test={len(test)}")
if __name__ == "__main__":
main()
Summary Table¶
1. Split Functions¶
| Function | Description |
|---|---|
np.split(a, n) |
Split into n equal parts |
np.split(a, [i,j]) |
Split at indices i, j |
np.array_split(a, n) |
Split into n parts (unequal OK) |
np.hsplit(a, n) |
Split horizontally (axis=1) |
np.vsplit(a, n) |
Split vertically (axis=0) |
np.dsplit(a, n) |
Split depth-wise (axis=2) |
2. Equivalences¶
| Function | Equivalent |
|---|---|
np.hsplit(a, n) |
np.split(a, n, axis=1) |
np.vsplit(a, n) |
np.split(a, n, axis=0) |
np.dsplit(a, n) |
np.split(a, n, axis=2) |