pytorch中使用cuda扩展的实现示例

 更新时间:2020年02月12日 11:16:17   作者:outthinker  
这篇文章主要介绍了pytorch中使用cuda扩展的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧

以下面这个例子作为教程,实现功能是element-wise add;

(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

第一步:cuda编程的源文件和头文件

// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"

// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
  int k = (n - 1) / BLOCK + 1;
  int x = k;
  int y = 1;
  if(x > 65535) {
    x = ceil(sqrt(k));
    y = (n - 1) / (x * BLOCK) + 1;
  }
  dim3 d(x, y, 1);
  return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
  int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
  if(i >= size) return;
  int j = i % x; i = i / x;
  int k = i % y;
  a[IDX2D(j, k, y)] += b[k];
}


// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
  int size = x * y;
  cudaError_t err;
  
  // 上面定义的函数
  broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);

  err = cudaGetLastError();
  if (cudaSuccess != err)
  {
    fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL

#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))

#define BLOCK 512
#define MAX_STREAMS 512

#ifdef __cplusplus
extern "C" {
#endif

void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);

#ifdef __cplusplus
}
#endif

#endif

第二步:C编程的源文件和头文件(接口函数)

// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"

extern THCState *state;

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
  float *a = THCudaTensor_data(state, a_tensor);
  float *b = THCudaTensor_data(state, b_tensor);
  cudaStream_t stream = THCState_getCurrentStream(state);

  // 这里调用之前在cuda中编写的接口函数
  broadcast_sum_cuda(a, b, x, y, stream);

  return 1;
}

int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)

sources = []
headers = []
defines = []
with_cuda = False

if torch.cuda.is_available():
  print('Including CUDA code.')
  sources += ['src/mathutil_cuda.c']
  headers += ['src/mathutil_cuda.h']
  defines += [('WITH_CUDA', None)]
  with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))

extra_objects = ['src/mathutil_cuda_kernel.cu.o']  # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]


ffi = create_extension(
  '_ext.cuda_util',
  headers=headers,
  sources=sources,
  define_macros=defines,
  relative_to=__file__,
  with_cuda=with_cuda,
  extra_objects=extra_objects
)

if __name__ == '__main__':
  ffi.build()

第四步:调用cuda模块

from _ext import cuda_util #从对应路径中调用编译好的模块

a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))

# 上面等价于下面的效果:

a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • windows上安装Anaconda和python的教程详解

    windows上安装Anaconda和python的教程详解

    本文主要给大家介绍windows上安装Anaconda和python的教程详解,非常不错,具有参考借鉴价值,需要的朋友参考下
    2017-03-03
  • Python批量实现Word/EXCEL/PPT转PDF

    Python批量实现Word/EXCEL/PPT转PDF

    在日常办公和文档处理中,有时我们需要将多个Word文档、Excel表格或PPT演示文稿转换为PDF文件,本文将介绍如何使用Python编程语言批量实现将多个Word、Excel和PPT文件转换为PDF文件,需要的可以参考下
    2023-09-09
  • Python实现批量图片去重

    Python实现批量图片去重

    在日常办公的时候,我们经常需要对图片进行去重后保存,如果一张张进行寻找将会非常的耗时,下面我们就来看看如何使用Python实现批量图片去重吧
    2024-11-11
  • python在前端页面使用 MySQLdb 连接数据

    python在前端页面使用 MySQLdb 连接数据

    这篇文章主要介绍了MySQLdb 连接数据的使用,文章主要介绍的相关内容又插入数据,删除数据,更新数据,搜索数据,需要的小伙伴可以参考一下
    2022-03-03
  • Python利用Flask动态生成汉字头像

    Python利用Flask动态生成汉字头像

    这篇文章主要为大家详细介绍了Python如何利用Flask动态生成一个汉字头像,文中的示例代码讲解详细,对我们学习Python有一定的帮助,需要的可以参考一下
    2023-01-01
  • 详解如何用OpenCV + Python 实现人脸识别

    详解如何用OpenCV + Python 实现人脸识别

    这篇文章主要介绍了详解如何用OpenCV + Python 实现人脸识别,非常具有实用价值,需要的朋友可以参考下
    2017-10-10
  • Python OpenCV读取中文路径图像的方法

    Python OpenCV读取中文路径图像的方法

    这篇文章主要介绍了Python OpenCV读取中文路径图像的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下
    2020-07-07
  • Python爬虫常用库的安装及其环境配置

    Python爬虫常用库的安装及其环境配置

    今天小编就为大家分享一篇关于python爬虫常用库的安装及其环境配置的文章,小编觉得内容挺不错的,现在分享给大家,具有很好的参考价值,需要的朋友一起跟随小编来看看吧
    2018-09-09
  • Python开发之os与os.path的使用小结

    Python开发之os与os.path的使用小结

    这篇文章主要介绍了Python开发之os与os.path的使用小结,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧
    2024-05-05
  • Python中PyExecJS(执行JS代码库)的具体使用

    Python中PyExecJS(执行JS代码库)的具体使用

    pyexecjs是一个用Python来执行JavaScript代码的工具库,本文主要介绍了Python中PyExecJS(执行JS代码库)的具体使用,具有一定的参考价值,感兴趣的可以了解一下
    2024-02-02

最新评论