PyTorch 中mm和bmm函数的使用示例详解

 更新时间:2025年06月18日 16:46:15   作者:点云SLAM  
PyTorch中torch.mm用于二维矩阵乘法,torch.bmm处理批量矩阵乘法(三维张量),均不支持广播,两者适用于神经网络权重计算、点云变换等场景,区别在于维度要求和批量处理效率,本文给大家介绍PyTorch 中mm和bmm函数的使用,感兴趣的朋友一起看看吧

torch.mm 是 PyTorch 中用于 二维矩阵乘法(matrix-matrix multiplication) 的函数,等价于数学中的 A × B 矩阵乘积。

一、函数定义

torch.mm(input, mat2) → Tensor

执行的是两个 2D Tensor(矩阵)的标准矩阵乘法。

  • input: 第一个二维张量,形状为 (n × m)
  • mat2: 第二个二维张量,形状为 (m × p)
  • 返回:形状为 (n × p) 的张量

二、使用条件和注意事项

条件说明
仅支持 2D 张量一维或三维以上使用 torch.matmul 或 @ 操作符
维度要匹配即 input.shape[1] == mat2.shape[0]
不支持广播两个矩阵维度不匹配会直接报错
结果是普通矩阵乘积不是逐元素乘法(Hadamard),即不是 * 或 torch.mul()

三、示例代码

示例 1:基本矩阵乘法

import torch
A = torch.tensor([[1., 2.], [3., 4.]])   # 2x2
B = torch.tensor([[5., 6.], [7., 8.]])   # 2x2
C = torch.mm(A, B)
print(C)

输出:

tensor([[19., 22.],
        [43., 50.]])

计算步骤:

C[0][0] = 1*5 + 2*7 = 19
C[0][1] = 1*6 + 2*8 = 22
...

示例 2:不匹配维度导致报错

A = torch.rand(2, 3)
B = torch.rand(4, 2)
C = torch.mm(A, B)  # ❌ 会报错

报错:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)

示例 3:推荐写法(推荐使用 @ 或 matmul)

A = torch.rand(3, 4)
B = torch.rand(4, 5)
C1 = torch.mm(A, B)
C2 = A @ B                # 推荐用法
C3 = torch.matmul(A, B)   # 推荐用法

四、与其他乘法函数的比较

函数名支持维度运算类型支持广播
torch.mm仅限二维矩阵乘法❌ 不支持
torch.matmul1D, 2D, ND自动判断点乘 / 矩阵乘✅ 支持
torch.bmm批量二维乘法3D Tensor batch × batch❌ 不支持
torch.mul任意维度元素乘(Hadamard)✅ 支持
* 运算符任意维度元素乘✅ 支持
@ 运算符ND(推荐用)矩阵乘法(和 matmul 一样)

五、典型应用场景

  • 神经网络权重乘法:output = torch.mm(W, x)
  • 点云 / 图像变换:x' = torch.mm(R, x) + t
  • 多层感知机中的矩阵计算
  • 注意力机制中 QK^T 乘积

六、总结:什么时候用 mm?

使用场景用什么
仅二维矩阵乘法torch.mm
高维或支持广播乘法torch.matmul / @
批量矩阵乘法 (如 batch_size×3×3)torch.bmm
元素乘torch.mul or *

在 PyTorch 中,torch.bmm 是 批量矩阵乘法(batch matrix multiplication) 的操作,专用于处理三维张量(batch of matrices)。它的主要作用是对一组矩阵成对进行乘法,效率远高于手动循环计算。

一、torch.bmm 语法

torch.bmm(input, mat2, *, out=None) → Tensor
  • inputTensor,形状为 (B, N, M)
  • mat2Tensor,形状为 (B, M, P)
  • 返回结果形状为 (B, N, P)

这表示对 B 对 N×M 和 M×P 的矩阵进行成对相乘。

二、示例演示

示例 1:基础用法

import torch
# 定义两个 batch 矩阵
A = torch.randn(4, 2, 3)  # shape: (B=4, N=2, M=3)
B = torch.randn(4, 3, 5)  # shape: (B=4, M=3, P=5)
# 批量矩阵乘法
C = torch.bmm(A, B)       # shape: (4, 2, 5)
print(C.shape)  # 输出: torch.Size([4, 2, 5])

示例 2:手动循环 vs bmm 效率对比

# 慢速手动方式
C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))])
# 等效于 bmm
C_bmm = torch.bmm(A, B)
print(torch.allclose(C_manual, C_bmm))  # True

三、注意事项

1. 维度必须是三维张量

  • 否则会报错:
RuntimeError: batch1 must be a 3D tensor

你可以通过 .unsqueeze() 手动调整维度:

a = torch.randn(2, 3)
b = torch.randn(3, 4)
# 升维
a_batch = a.unsqueeze(0)  # (1, 2, 3)
b_batch = b.unsqueeze(0)  # (1, 3, 4)
c = torch.bmm(a_batch, b_batch)  # (1, 2, 4)

2. 维度必须满足矩阵乘法规则

  • (B, N, M) × (B, M, P) → (B, N, P)
  • 若 M 不一致会报错:
RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension of batch1 tensor

3. bmm 不支持广播(broadcasting)

  • 必须显式提供相同的 batch size。
  • 如果只有一个矩阵固定,可以使用 .expand()
A = torch.randn(1, 2, 3)  # 单个矩阵
B = torch.randn(4, 3, 5)  # 4 个矩阵
# 扩展 A 以进行 batch 乘法
A_expand = A.expand(4, -1, -1)
C = torch.bmm(A_expand, B)  # (4, 2, 5)

四、在实际应用中的例子

在点云变换中:批量乘旋转矩阵

# 假设有 B 个旋转矩阵和点坐标
R = torch.randn(B, 3, 3)       # 旋转矩阵
points = torch.randn(B, 3, N)  # 点云
# 先转置点坐标为 (B, N, 3)
points_T = points.transpose(1, 2)  # (B, N, 3)
# 用 bmm 做点变换:每组点乘旋转
transformed = torch.bmm(points_T, R.transpose(1, 2))  # (B, N, 3)

五、总结

特性torch.bmm
操作对象三维张量(batch of matrices)
核心规则(B, N, M) x (B, M, P) = (B, N, P)
是否支持广播❌ 不支持,需要手动 .expand()
与 matmul 区别matmul 支持更多广播,bmm 更高效用于纯批量矩阵乘法
应用场景批量线性变换、点云配准、神经网络前向传播等

到此这篇关于PyTorch 中mm和bmm函数的使用详解的文章就介绍到这了,更多相关PyTorch mm和bmm函数内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Pycharm中切换pytorch的环境和配置的教程详解

    Pycharm中切换pytorch的环境和配置的教程详解

    这篇文章主要介绍了Pycharm中切换pytorch的环境和配置,本文给大家介绍的非常详细,对大家的工作或学习具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-03-03
  • python实现对任意大小图片均匀切割的示例

    python实现对任意大小图片均匀切割的示例

    今天小编就为大家分享一篇python实现对任意大小图片均匀切割的示例,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2018-12-12
  • Python闭包及装饰器运行原理解析

    Python闭包及装饰器运行原理解析

    这篇文章主要介绍了python闭包及装饰器运行原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2020-06-06
  • Python异步编程中asyncio.gather的并发控制详解

    Python异步编程中asyncio.gather的并发控制详解

    在Python异步编程生态中,asyncio.gather是并发任务调度的核心工具,本文将通过实际场景和代码示例,展示如何结合信号量机制实现精准并发控制,希望对大家有所帮助
    2025-03-03
  • pycharm远程开发项目的实现步骤

    pycharm远程开发项目的实现步骤

    这篇文章主要介绍了pycharm远程开发项目的实现步骤,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2019-01-01
  • Python自动化实战之接口请求的实现

    Python自动化实战之接口请求的实现

    本文为大家重点介绍如何通过 python 编码来实现我们的接口测试以及通过Pycharm的实际应用编写一个简单接口测试,感兴趣的可以了解一下
    2022-05-05
  • Python列表切片操作实例探究(提取复制反转)

    Python列表切片操作实例探究(提取复制反转)

    在Python中,列表切片是处理列表数据非常强大且灵活的方法,本文将全面探讨Python中列表切片的多种用法,包括提取子列表、复制列表、反转列表等操作,结合丰富的示例代码进行详细讲解
    2024-01-01
  • Python文件操作之合并文本文件内容示例代码

    Python文件操作之合并文本文件内容示例代码

    众所周知Python文件处理操作方便快捷,下面这篇文章主要给大家介绍了关于Python文件操作之合并文本文件内容的相关资料,文中通过示例代码介绍的非常详细,需要的朋友可以参考借鉴,下面随着小编来一起学习学习吧。
    2017-09-09
  • 基于Python开发云主机类型管理脚本分享

    基于Python开发云主机类型管理脚本分享

    这篇文章主要为大家详细介绍了如何基于Python开发一个云主机类型管理脚本,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下
    2023-02-02
  • Python的deque双端队列详解

    Python的deque双端队列详解

    这篇文章主要介绍了Python的deque双端队列详解,deque(双端队列)是一种数据结构,允许使用O(1)时间复杂度从两端添加和删除元素, Python的deque类实现了此数据结构,需要的朋友可以参考下
    2023-09-09

最新评论