终于明白tf.reduce_sum()函数和tf.reduce_mean()函数用法

 更新时间:2022年11月28日 10:13:10   作者:不想秃顶还想当程序猿  
这篇文章主要介绍了终于明白tf.reduce_sum()函数和tf.reduce_mean()函数用法,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

解读tf.reduce_sum()函数和tf.reduce_mean()函数

在学习搭建神经网络的时候,照着敲别人的代码,有一句代码一直搞不清楚,就是下面这句了

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), reduction_indices=[1]))

刚开始照着up主写的代码是这样滴:

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction)))

然后就出现了这样的结果:

709758.1
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan

怎么肥事,对于萌新小白首先想到的就是找度娘,结果找到的方法都不行,然后开始查函数,终于发现了原因,问题就出在reduce_sum()函数上,哈哈哈,然后小白又叕叕叕开始找博客学习reduce_sum()顺带学下reduce_mean(),结果看了好几篇,还是脑袋一片浆糊,为啥用reduction_indices=[1],不用reduction_indices=[0]或者干脆不用,费了九牛二虎之力终于让我给弄懂了,赶紧记录下来!!

-------------------分割线-------------------

1.tf.reduce_mean 函数

用于计算张量tensor沿着指定的数轴(tensor的某一维度)上的的平均值,主要用作降维或者计算tensor(图像)的平均值。

reduce_mean(input_tensor,
                axis=None,
                keep_dims=False,
                name=None,
                reduction_indices=None)
  • 第一个参数input_tensor: 输入的待降维的tensor;
  • 第二个参数axis: 指定的轴,如果不指定,则计算所有元素的均值;
  • 第三个参数keep_dims:是否降维度,设置为True,输出的结果保持输入tensor的形状,设置为False,输出结果会降低维度;
  • 第四个参数name: 操作的名称;
  • 第五个参数 reduction_indices:在以前版本中用来指定轴,已弃用;

2.tf.reduce_sum函数

计算一个张量的各个维度上元素的总和,一般只需设置两个参数

reduce_sum ( 
    input_tensor , 
    axis = None , 
    keep_dims = False , 
    name = None , 
    reduction_indices = None
 )
  • 第一个参数input_tensor: 输入的tensor
  • 第二个参数 reduction_indices:指定沿哪个维度计算元素的总和

最难的就是维度问题,反正本小白看了好几个博客都没弄太懂,最后还是按自己的理解,直接上例子

  • reduce_sum()
tf.reduce_sum
matrix1 = [[1.,2.,3.],            #二维,元素为列表
          [4.,5.,6.]]
matrix2 = [[[1.,2.],[3.,4.]],      #三维,元素为矩阵
           [[5.,6.],[7.,8.]]]

res_2 = tf.reduce_sum(matrix1)
res_3 = tf.reduce_sum(matrix2)
res1_2 = tf.reduce_sum(matrix1,reduction_indices=[0])
res1_3 = tf.reduce_sum(matrix2,reduction_indices=[0])
res2_2 = tf.reduce_sum(matrix1,reduction_indices=[1])
res2_3 = tf.reduce_sum(matrix2,reduction_indices=[1])

sess = tf.Session()
print("reduction_indices=None:res_2={},res_3={}".format(sess.run(res_2),sess.run(res_3)))
print("reduction_indices=[0]:res1_2={},res1_3={}".format(sess.run(res1_2),sess.run(res1_3)))
print("reduction_indices=[1]:res2_2={},res2_3={}".format(sess.run(res2_2),sess.run(res2_3)))

结果如下:

axis=None:res_2=21.0,res_3=36.0
axis=[0]:res1_2=[5. 7. 9.],res1_3=[[ 6.  8.]
                                    [10. 12.]]
axis=[1]:res2_2=[ 6. 15.],res2_3=[[ 4.  6.]
                                   [12. 14.]]

  • tf.reduce_mean

只需要把上面代码的reduce_sum部分换成renduce_mean即可

res_2 = tf.reduce_mean(matrix1)
res_3 = tf.reduce_mean(matrix2)
res1_2 = tf.reduce_mean(matrix1,axis=[0])
res1_3 = tf.reduce_mean(matrix2,axis=[0])
res2_2 = tf.reduce_mean(matrix1,axis=[1])
res2_3 = tf.reduce_mean(matrix2,axis=[1])

结果如下:

axis=None:res_2=3.5,res_3=4.5
axis=[0]:res1_2=[2.5 3.5 4.5],res1_3=[[3. 4.]
                                       [5. 6.]]
axis=[1]:res2_2=[2. 5.],res2_3=[[2. 3.]
                                 [6. 7.]]

可以看到,reduction_indices和axis其实都是代表维度,当为None时,reduce_sum和reduce_mean对所有元素进行操作,当为[0]时,其实就是按行操作,当为[1]时,就是按列操作,对于三维情况,把最里面的括号当成是一个数,这样就可以用二维的情况代替,最后得到的结果都是在原来的基础上降一维,下面按专业的方法讲解:

对于一个多维的array,最外层的括号里的元素的axis为0,然后每减一层括号,axis就加1,直到最后的元素为单个数字

如上例中的matrix1 = [[1., 2., 3.], [4., 5., 6.]]:

  • axis=0时,所包含的元素有:[1., 2., 3.]、[4., 5., 6.]
  • axis=1时,所包含的元素有:1.、2.、3.、4.、5.、6.

所以当reduction_indices/axis=[0],应对axis=0上的元素进行操作,故reduce_sum()得到的结果为[5. 7. 9.],即把两个数组对应元素相加;当reduction_indices/axis=[1],应对axis=1上的元素进行操作,故reduce_sum()得到的结果为[ 6. 15.],即把每个数组里的元素相加。reduce_mean()同理。

不难看出对于三维情况也是同样的思路,如上例中的matrix2 = [[[1,2],[3,4]], [[5,6],[7,8]]]:

  • axis=0时,所包含的元素有:[[1., 2.],[3., 4.]]、[[5., 6.],[7., 8.]]
  • axis=1时,所包含的元素有:[1., 2.]、[3., 4.]、[5., 6.]、[7., 8.]
  • axis=2时,所包含的的元素有:1.、2.、3.、4.、5.、6.、7.、8.

当reduction_indices/axis=[0],reduce_sum()得到的结果应为[[ 6. 8.], [10. 12.]],即把两个矩阵对应位置元素相加;当reduction_indices/axis=[1],reduce_sum()得到的结果应为[[ 4. 6.], [12. 14.]],即把数组对应元素相加。reduce_mean()同理。

一句话就是对哪一维操作,计算完后外面的括号就去掉,相当于降维。

那么问题来了,当reduction_indices/axis=[2]时呢???

  • 对于二维情况,当然是报错了,因为axis最大为1

ValueError: Invalid reduction dimension 2 for input with 2 dimensions. for 'Sum_4' (op: 'Sum') with input shapes: [2,3], [1] and with computed input tensors: input[1] = <2>.

  • 对于三维情况,reduce_sum()得到的结果为:[[ 3. 7.], [11. 15.]],即对最内层括号里的元素求和。

-------------------分割线-------------------

回到最开始自己的问题,为什么只有设置参数reduction_indices=[1],loss才不为Nan

loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))

本程序构建的是一个3层神经网络,输入层只有1个神经元,输入数据为100个样本点,即shape为(100,1)的列向量,隐藏层有10个神经元,输出层同样只有1个神经元,故最后输出数据的shape也为(100,1)的列向量,那么reduce_sum的参数即为一个二维数组。

  • 若reduction_indices=[0],最后得到的是只有一个元素的数组,即[n]
  • 若reduction_indices=[1],最后得到的是有100个元素的数组,即[n1,n2…n100]
  • 若reduction_indices=None,最后得到的则是一个数

那么再使用reduce_mean()求平均时,想要得到的结果是sum/100,这时就只有reduce_sum()传入参数reduction_indices=[1],才能实现想要的效果了。

完美解决!!!

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Qt Quick QML-500行代码实现合成大西瓜游戏

    Qt Quick QML-500行代码实现合成大西瓜游戏

    合成大西瓜游戏是前段时间比较火的小游戏,最近小编闲来无事,通过研究小球碰撞原理亲自写碰撞算法实现一个合成大西瓜游戏,下面小编把我的实现思路及核心代码分析出来,供大家参考
    2021-05-05
  • python 检查数据中是否有缺失值,删除缺失值的方式

    python 检查数据中是否有缺失值,删除缺失值的方式

    今天小编就为大家分享一篇python 检查数据中是否有缺失值,删除缺失值的方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Python collections模块的使用技巧

    Python collections模块的使用技巧

    Python的最大优势之一是其广泛的模块和软件包。这将Python的功能扩展到许多受欢迎的领域,包括机器学习、数据科学和Web开发等, 其中最好的模块之一是Python的内置collections 模块。
    2021-04-04
  • python爬虫模拟登录之图片验证码实现详解

    python爬虫模拟登录之图片验证码实现详解

    众所周知python是一个很强大的语言,它拥有众多的库,今天我尝试了使用python进行验证码的识别,下面这篇文章主要给大家介绍了关于python爬虫模拟登录之图片验证码实现的相关资料,需要的朋友可以参考下
    2022-08-08
  • Python3去除头尾指定字符的函数strip()、lstrip()、rstrip()用法详解

    Python3去除头尾指定字符的函数strip()、lstrip()、rstrip()用法详解

    这篇文章主要介绍了Python3去除头尾指定字符的函数strip()、lstrip()、rstrip()用法详解,需要的朋友可以参考下
    2021-04-04
  • python创建ArcGIS shape文件的实现

    python创建ArcGIS shape文件的实现

    今天小编就为大家分享一篇python创建ArcGIS shape文件的实现,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-12-12
  • Visual Studio code 配置Python开发环境

    Visual Studio code 配置Python开发环境

    这篇文章主要介绍了Visual Studio code 配置Python开发环境,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-09-09
  • python pandas分组聚合详细

    python pandas分组聚合详细

    分组聚合是数据处理中常见的场景,在pandas中用groupby方法实现分组操作,用agg方法实现聚合操作,在这篇文章里有主要介绍,感兴趣的朋友请参考下文
    2021-09-09
  • Python中set方法的使用教程详解

    Python中set方法的使用教程详解

    在Python中,set是一种集合数据类型,表示一个无序且不重复的集合。本文主要为大家详细介绍了Python中set方法的使用,需要的可以参考一下
    2023-04-04
  • Django框架表单操作实例分析

    Django框架表单操作实例分析

    这篇文章主要介绍了Django框架表单操作,结合实例形式分析了Django框架表单数据发送、请求相关操作技巧与注意事项,需要的朋友可以参考下
    2019-11-11

最新评论