Numpy Broadcast
Table of Contents
1. Numpy Broadcast
1.1. Overview
当需要计算 f(a,b) 且需要 a.shape==b.shape 时, 若 a 与 b 的 shape 不同, numpy 会根据一定的规则尝试把 a,b 转换成相同的 shape
例如:
import numpy as np print((np.ones((1, 2, 3)) + np.ones((3,))).shape) print((np.ones((1, 2, 3)) + np.ones((1, 3))).shape) print((np.ones((1, 2, 3)) + np.ones((2, 3))).shape) print((np.ones((1, 1, 3)) + np.ones((2, 3))).shape)
(1, 2, 3) (1, 2, 3) (1, 2, 3) (1, 2, 3)
1.2. broadcasting rules
https://numpy.org/doc/stable/user/basics.broadcasting.html
若两者的 rank (dim.size) 不同, 在小的左边添加上 1, 使两者的 rank 相同
if a.rank == b.rank: return out_rank = max(a.rank, b.rank) while a.rank != out_rank: a.dim.insert(0, 1) while b.rank != out_rank: b.dim.insert(0, 1)
例如
- a=(1,1,2,3), rank 为 3
- b=(1,3), rank 为 2
给 b 左边添加新的 dim 后:
- a=(1,1,2,3)
- b=(1,1,1,3)
两者 rank 相同后, 从右向左扫描
for i in reversed(range(a.rank)): if a.dim[i] == b.dim[i]: continue if a.dim[i] == 1: a.dim[i] = b.dim[i] elif b.dim[i] == 1: b.dim[i] = a.dim[i] else: rainse(error)