Numpy Broadcast

Table of Contents

1. Numpy Broadcast

1.1. Overview

当需要计算 f(a,b) 且需要 a.shape==b.shape 时, 若 a 与 b 的 shape 不同, numpy 会根据一定的规则尝试把 a,b 转换成相同的 shape

例如:

import numpy as np

print((np.ones((1, 2, 3)) + np.ones((3,))).shape)
print((np.ones((1, 2, 3)) + np.ones((1, 3))).shape)
print((np.ones((1, 2, 3)) + np.ones((2, 3))).shape)
print((np.ones((1, 1, 3)) + np.ones((2, 3))).shape)

(1, 2, 3) (1, 2, 3) (1, 2, 3) (1, 2, 3)

1.2. broadcasting rules

https://numpy.org/doc/stable/user/basics.broadcasting.html

  1. 若两者的 rank (dim.size) 不同, 在小的左边添加上 1, 使两者的 rank 相同

    if a.rank == b.rank:
        return
    
    out_rank = max(a.rank, b.rank)
    
    while a.rank != out_rank:
        a.dim.insert(0, 1)
    
    while b.rank != out_rank:
        b.dim.insert(0, 1)
    

    例如

    • a=(1,1,2,3), rank 为 3
    • b=(1,3), rank 为 2

    给 b 左边添加新的 dim 后:

    • a=(1,1,2,3)
    • b=(1,1,1,3)
  2. 两者 rank 相同后, 从右向左扫描

    for i in reversed(range(a.rank)):
        if a.dim[i] == b.dim[i]:
            continue
    
        if a.dim[i] == 1:
            a.dim[i] = b.dim[i]
        elif b.dim[i] == 1:
            b.dim[i] = a.dim[i]
        else:
            rainse(error)
    

Author: [email protected]
Date: 2021-09-27 Mon 00:00
Last updated: 2022-01-24 Mon 19:34

知识共享许可协议