Python深度学习:基于TensorFlow(第2版)
上QQ阅读APP看本书,新人免费读10天
设备和账号都新为新人

1.6 广播机制

NumPy的通用函数(ufunc)中要求输入的数组shape是一致的,当数组的shape不一致时,则会用到广播机制。不过,调整数组使得shape一样时需满足一定规则,否则将出错。广播机制中的这些规则可归结为以下四条。

1)让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐;如对于数组a(2×3×2)和数组b(3×2),则b向a看齐,在b的前面加1,变为1×3×2。

2)输出数组的shape是输入数组shape的各个轴上的最大值。

3)如果输入数组的某个轴和输出数组的对应轴的长度相同或者长度为1时,则可以调整,否则将出错。

4)当输入数组的某个轴的长度为1时,沿着此轴运算时都用(或复制)此轴上的第一组值。

广播机制在整个NumPy中用于决定如何处理形状迥异的数组,涉及的算术运算包括+、-、*、/。这些规则虽然很严谨,但不直观。下面我们结合图形与代码做进一步说明。

目的:A+B。其中A为4×1矩阵,B为一维向量(3,)。要相加,需要做如下处理。

1)根据规则1,B需要向A看齐,把B变为(1, 3)。

2)根据规则2,输出的结果为各个轴上的最大值,即输出结果应该为(4, 3)矩阵。那么A如何由(4, 1)变为(4, 3)矩阵?B如何由(1, 3)变为(4, 3)矩阵?

3)根据规则4,用此轴上的第一组值(主要区分是哪个轴)进行复制即可。(但在实际处理中不是真正复制,而是采用其他对象,如ogrid对象,进行网格处理,否则太耗内存。)如图1-10所示。

图1-10 NumPy广播机制示意图

具体实现如下:

运行结果如下: