Python中的Broadcast机制

 更新时间:2023年06月14日 09:15:51   作者:Hayz  
这篇文章主要介绍了Python中的Broadcast机制,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

Python Broadcast机制

最近在用numpy的时候,里面的矩阵和向量之间各种乘法加法搞的我头昏脑胀,整理下总结出来的规则

首先说明array型数据结构有两种类型,一种是一维的向量,比如用np.linspace(1,2,num=2)创建出的对象,shape为(2,);另外一种就是多维的矩阵,如np.zeros(1,2)创建出的对象,其shape为(1,2),这两种类型是不一样的。

矩阵之间的矩阵乘法

不必多说,就是按照正常的矩阵乘法规则来做

(N,M) (M,P) = (N,P)

矩阵之间按元素相乘、相加

这里开始就涉及到广播(broadcast)的问题了。

其实也比较简单,两个矩阵broadcast后的结果每一维都是两个矩阵中最大的。

但broadcast必须满足两个规则,即要么相对应的维数相等,要么其中有一个矩阵的维数是1。

那么问题来了,哪两个维度是相对应的维数呢?规则就是将矩阵的shape写出来,然后按右对齐逐维对比。

通过以上方法,可以得出两矩阵broadcast结果的维数,而最后结果的计算方法就是先将两个矩阵都broadcast到结果的维数,然后再按照相同维度的矩阵对应元素相乘、相加。

例子如下:

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5
A      (2d array):  5 x 4
B      (1d array):      1
Result (2d array):  5 x 4
A      (2d array):  15 x 3 x 5
B      (1d array):  15 x 1 x 5
Result (2d array):  15 x 3 x 5

矩阵和向量之间的矩阵乘法

这里也很简单,规则是

作左乘数的向量是行向量,作右乘数的向量是列向量。

这样做的好处就是,结果矩阵一定也是个向量。这个规则也说明了向量不一定是行向量(虽然print出来看见的是一个行向量)

矩阵和向量之间的按元素乘法、加法

规则其实和“二”中说的是一样的,只不过这里要注意的是,向量在这里永远当作(1,N)来看,也就是是行向量,按照“二”中所说的broadcast的规则,向量的维度永远从右对齐,也就是只有最右边有数,也就说明和他进行broadcast的矩阵,其最低维(也就是最右侧的维度)要么是一维,要么就和向量的维度相同。

举例子如下:

矩阵 (3d array)   : 256 x 256 x 3
向量 (1d array)   :             3
结果 (3d array)   : 256 x 256 x 3

python broadcast机制的模拟实现

tensorflow的算术操作:mul/add/sub等op都支持broadcast机制,该机制支持不同维度的计算,但是在对维度进行逆向比较时需要满足以下要求:

  • 1)二者维度相同
  • 2)二者维度有一个为1
  • 3)如果维度大小不一致,需要用1来对维度小的数据进行扩展,在进行上述判断;

如:a:[256,256,3]、b:[3]这样的维度,需要先将b扩展至与a一致,将b扩展至[1,1,3],再对a、b数据进行mul/add/sub等计算,最后输出维度[256,256,3]

如果为了实现broadcast,可以进行以下操作进行模拟:

  • 1)对维度大小不一致的数组进行维度扩展
  • 2)获取输出维度,即broadcast的维度
  • 3)进行数据广播

粗略代码如下(这里以四维数据为例,进行扩展):

import tensorflow as tf
import numpy as np
if __name__ == "__main__":
	input0_shape = [1,1,3,1]
	input1_shape = [3]
	#维度扩展
	input_len = len(input0_shape) - len(input1_shape)
	for i in range(input_len):
		input1_shape.insert(0,1)
	print input1_shape
	#获取broadcast shape
	broadcast_shape = [0] * len(input0_shape)
	for i in range(len(input0_shape)):
		broadcast_shape[i] = max(input0_shape[i],input1_shape[i])
	print broadcast_shape
	data_a = np.random.random(input0_shape)	#hwcn
	data_b = np.random.random(input1_shape) #h,w,c_out,c_in
	a = tf.placeholder("float")
	b = tf.placeholder("float")
	c = tf.add(a,b)
	with tf.Session() as sess:
		sess.run(tf.global_variables_initializer())
		out = sess.run(c, feed_dict={a: data_a,b:data_b})
		#print data_a
		print data_b
		print out.shape
		#print out - data_a
	res_pre = out - data_b    	#获取input0的扩展结果,用于验证实际值
	out_tf = res_pre.reshape(broadcast_shape[0]*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3])
	data_b_tmp = data_a.reshape(input0_shape[0]*input0_shape[1]*input0_shape[2]*input0_shape[3])
	print "out_tf"
	print out_tf
	f_dets = open("pre_data.dat", "w")
	for k in out_tf:
		b = float(k)
		a = '{:.10f}'.format(b)
		f_dets.write(str(a) + '\n')
	f_dets.close()
	out_res = [0]*broadcast_shape[0]*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3]
	#进行数据扩展
	for i in range(broadcast_shape[0]):
		for j in range(broadcast_shape[1]):
			for k in range(broadcast_shape[2]):
				for m in range(broadcast_shape[3]):
					tmp_idx0 = i*broadcast_shape[1]*broadcast_shape[2]*broadcast_shape[3]  \
					          + j*broadcast_shape[2]*broadcast_shape[3] + k*broadcast_shape[3] + m
					ii = 0
					jj = 0
					kk = 0
					mm = 0
					if i >= input0_shape[0]:
						ii = input0_shape[0] -1
					else:
						ii = i
					if j >= input0_shape[1]:
						jj = input0_shape[1] -1
					else:
						jj = j
					if k >= input0_shape[2]:
						kk = input0_shape[2] -1
					else:
						kk = k
					if m >= input0_shape[3]:
						mm = input0_shape[3] -1
					else:
						mm = m
					tmp_idx1 = ii*input0_shape[1]*input0_shape[2]*input0_shape[3] \
								+ jj*input0_shape[2]*input0_shape[3] + kk*input0_shape[3] + mm
					#print mm
					out_res[tmp_idx0] = data_b_tmp[tmp_idx1]
	f_dets = open("aft_data.dat", "w")
	for k in out_res:
		b = float(k)
		a = '{:.10f}'.format(b)
		f_dets.write(str(a) + '\n')
	f_dets.close()
	#对比
	print "compare"
	print out_res - out_tf

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • django解决订单并发问题【推荐】

    django解决订单并发问题【推荐】

    这篇文章主要介绍了django解决订单并发问题,本文给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-07-07
  • TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集

    TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集

    这篇文章主要为大家介绍了python神经网络的TensorFlow教程基于Softmax逻辑回归识别手写数字的MNIST数据集示例详解,有需要的朋友可以借鉴参考下,希望能够有所帮助
    2021-11-11
  • python实现监控linux性能及进程消耗性能的方法

    python实现监控linux性能及进程消耗性能的方法

    这篇文章主要介绍了python实现监控linux性能及进程消耗性能的方法,需要的朋友可以参考下
    2014-07-07
  • python解决字符串倒序输出的问题

    python解决字符串倒序输出的问题

    今天小编就为大家分享一篇python解决字符串倒序输出的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-06-06
  • Python如何使用input函数获取输入

    Python如何使用input函数获取输入

    这篇文章主要介绍了Python如何使用input函数获取输入,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-08-08
  • Python用字典构建多级菜单功能

    Python用字典构建多级菜单功能

    这篇文章主要介绍了Python用字典构建多级菜单功能,本文通过实例代码给大家介绍的非常详细,具有一定的参考借鉴价值,需要的朋友可以参考下
    2019-07-07
  • 利用keras加载训练好的.H5文件,并实现预测图片

    利用keras加载训练好的.H5文件,并实现预测图片

    今天小编就为大家分享一篇利用keras加载训练好的.H5文件,并实现预测图片,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2020-01-01
  • Python中pickle模块的使用详解

    Python中pickle模块的使用详解

    这篇文章主要介绍了Python中pickle模块的使用详解,python的pickle模块提供了一个简答的持久化功能,可以将对象以文件的形式存放在磁盘上,pickle模块实现了基本的数据序列化和反序列化,需要的朋友可以参考下
    2023-08-08
  • Python Opencv中基础的知识点

    Python Opencv中基础的知识点

    这篇文章主要介绍了Python Opencv中基础的知识点,主要包括创建窗口、保存图片、采集视频、鼠标控制的代码,代码简单易懂,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2022-07-07
  • Python实现经典算法拓扑排序、字符串匹配算法和最小生成树实例

    Python实现经典算法拓扑排序、字符串匹配算法和最小生成树实例

    这篇文章主要介绍了Python实现经典算法拓扑排序、字符串匹配算法和最小生成树实例,拓扑排序、字符串匹配算法和最小生成树是计算机科学中常用的数据结构和算法,它们在解决各种实际问题中具有重要的应用价值,需要的朋友可以参考下
    2023-08-08

最新评论