Python去除图片背景的两种方式介绍

 更新时间:2025年06月18日 15:07:33   作者:培根芝士  
这篇文章主要为大家详细介绍了Python去除图片背景的两种方式,一个是使用rembg,一个是使用U2-NET,文中的示例代码讲解详细,感兴趣的小伙伴可以了解下

1.使用rembg去除图像背景

rembg 是一个开源的 Python 库,专门用于去除图像背景,它利用深度神经网络能够准确地识别并去除图像背景,使用户无需进行复杂的手动编辑,只需几行代码即可获得专业效果。Rembg 基于 U2-Net 架构,有多种架构修改和经过测试的方法以提供最佳结果,还提供对 GPU 安装的访问以实现更快的处理。

安装依赖:

pip install rembg

示例代码:

from rembg import remove
from PIL import Image
import numpy as np
import cv2
 
def remove_background(source_image_path):
    # 打开要处理的图片
    input_image = Image.open(source_image_path)
 
    # 使用 rembg 去除背景
    output_image = remove(input_image, 
        alpha_matting=False
    )
 
    # 将 PIL.Image 转换为 OpenCV 格式
    output_image_array = np.array(output_image)
 
    # 如果是 RGBA 图像,转换为 BGRA(OpenCV 使用 BGRA 格式)
    if output_image_array.shape[2] == 4:
        output_image_bgra = cv2.cvtColor(output_image_array, cv2.COLOR_RGBA2BGRA)
    else:
        output_image_bgra = cv2.cvtColor(output_image_array, cv2.COLOR_RGB2BGR)
    return output_image_bgra

2.使用U2-NET去除图像背景

U2-Net 是一种用于显著目标检测的深度学习模型,在 CVPR2020 开源后备受关注。其具有两层嵌套的 U 型结构,底层是带有新颖的 ReSidual U-Block(RSU)模块,可在不降低特征图分辨率的情况下提取多尺度特征;顶层则类似于 U-Net 结构,每个阶段都由 RSU 填充,这种结构使网络能够在不显著增加内存和计算成本的情况下,加深网络并获得高分辨率特征图,从而有效捕捉显著目标的多尺度信息。

示例代码:

import torch
import os
import cv2
import numpy as np
from torchvision import transforms
from PIL import Image
import torch.nn as nn
from torch.nn import BatchNorm2d
from u2net import U2NET
from base_util import get_images_in_dir
 
# 定义U-2-Net模型结构
class REBNCONV(nn.Module):
    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)
 
    def forward(self, x):
        hx = x
        xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
        return xout
 
# 定义U-2-Net模型的其他组件...
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)
    dn = (d - mi) / (ma - mi)
    return dn
 
def remove_background(input_path, output_path, model_path='models/u2net.pth'):
    # 加载模型
    net = U2NET(3,1)
    net.load_state_dict(torch.load(model_path))
    if torch.cuda.is_available():
        net.cuda()
    net.eval()
    
    # 读取图像
    img = Image.open(input_path).convert('RGB')
 
    # 图像预处理
    transforms_list = []
    transforms_list.append(transforms.Resize((320, 320)))
    transforms_list.append(transforms.ToTensor())
    transforms_list.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    transform = transforms.Compose(transforms_list)
    
    image = transform(img)
    image = image.unsqueeze(0)
    
    # 模型推理
    if torch.cuda.is_available():
        image = image.cuda()
    
    with torch.no_grad():
        d1, d2, d3, d4, d5, d6, d7 = net(image)
    
    # 获取预测结果
    pred = d1[:, 0, :, :]
    pred = normPRED(pred)
    
    # 将预测结果转换为掩码
    predict = pred.squeeze()
    predict_np = predict.cpu().data.numpy()
    mask = Image.fromarray((predict_np * 255).astype(np.uint8))
    mask = mask.resize(img.size, resample=Image.BILINEAR)
    
    # 应用阈值处理掩码
    mask_threshold = 0.5  # 阈值可以根据实际情况调整
    mask_array = np.array(mask)
    mask_array = np.where(mask_array > mask_threshold * 255, 255, 0).astype(np.uint8)
 
    kernel = np.ones((3, 3), np.uint8)  # 定义膨胀核
    mask_array = cv2.dilate(mask_array, kernel, iterations=1)
 
    # 应用掩码到原始图像
    img_array = np.array(img)
    
    # 创建透明背景的图像
    transparent_img = np.zeros((img_array.shape[0], img_array.shape[1], 4), dtype=np.uint8)
    transparent_img[:, :, :3] = img_array
    transparent_img[:, :, 3] = mask_array
 
    # 保存结果
    Image.fromarray(transparent_img).save(output_path)
    print(f"已保存处理后的图像到: {output_path}")

到此这篇关于Python去除图片背景的两种方式介绍的文章就介绍到这了,更多相关Python去除图片背景内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

相关文章

  • Django配置多个环境的MySQL数据库的完整指南

    Django配置多个环境的MySQL数据库的完整指南

    在 Django 项目中配置多个环境的 MySQL 数据库是一个常见的需求,本文为大家详细介绍了配置的完整方法,感兴趣的小伙伴可以跟随小编一起学习一下
    2025-04-04
  • Python函数返回不定数量的值方法

    Python函数返回不定数量的值方法

    今天小编就为大家分享一篇Python函数返回不定数量的值方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • Python使用Keras库中的LSTM模型生成新文本内容教程

    Python使用Keras库中的LSTM模型生成新文本内容教程

    Python语言使用金庸小说文本库,对文本进行预处理,然后使用Keras库中的LSTM模型创建和训练了模型,根据这个模型,我们可以生成新的文本,并探索小说的不同应用
    2024-01-01
  • python应用之如何使用Python发送通知到微信

    python应用之如何使用Python发送通知到微信

    现在通过发微信信息来做消息通知和告警已经很普遍了,下面这篇文章主要给大家介绍了关于python应用之如何使用Python发送通知到微信的相关资料,文中通过实例代码介绍的非常详细,需要的朋友可以参考下
    2022-03-03
  • Python操作PDF实现制作数据报告

    Python操作PDF实现制作数据报告

    Python操作PDF的库有很多,比如PyPDF2、pdfplumber、PyMuPDF等等。本文将利用FPDF模块操作PDF实现制作数据报告,感兴趣的小伙伴可以尝试一下
    2022-12-12
  • Django模板标签中url使用详解(url跳转到指定页面)

    Django模板标签中url使用详解(url跳转到指定页面)

    这篇文章主要介绍了Django模板标签中url使用详解(url跳转到指定页面),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-03-03
  • np.hstack()和np.dstack()的使用

    np.hstack()和np.dstack()的使用

    本文主要介绍了np.hstack()和np.dstack()的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-03-03
  • openstack中的rpc远程调用的方法

    openstack中的rpc远程调用的方法

    今天通过本文给大家分享openstack中的rpc远程调用的方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧
    2021-07-07
  • python正则表达式去掉数字中的逗号(python正则匹配逗号)

    python正则表达式去掉数字中的逗号(python正则匹配逗号)

    在处理自然语言时123,000,000如果以标点符号分割,就会出现问题,好好的一个数字就被逗号肢解了,因此可以先下手把数字处理干净(逗号去掉)
    2013-12-12
  • Django定制Admin页面详细实例(展示页面和编辑页面)

    Django定制Admin页面详细实例(展示页面和编辑页面)

    django自带的admin因为功能和样式比较简陋,常常需要再次定制,下面这篇文章主要给大家介绍了关于Django定制Admin页面(展示页面和编辑页面)的相关资料,需要的朋友可以参考下
    2023-06-06

最新评论