BiRefNet Portrait TensorRT
Model Description
BiRefNet Portrait TensorRT is a high-performance background removal model optimized for NVIDIA GPUs using TensorRT. It delivers 5.3x speedup over standard PyTorch inference with a median latency of 123ms at 1024x1024 resolution (RTX 3060, FP16).
This model is a TensorRT-optimized version of BiRefNet trained specifically for portrait matting, ideal for real-time selfie and photo background removal applications.
- Model type: Image Segmentation / Portrait Matting
- License: See original BiRefNet repository for license details
- Parent Model: ZhengPeng7/BiRefNet
- Related Model: briaai/RMBG-2.0
Model Details
Model Architecture
| Attribute | Value |
|---|---|
| Architecture | BiRefNet (Bilateral Reference Network) |
| Backbone | Swin Transformer V1 Large |
| Task | Portrait Matting (Dichotomous Image Segmentation) |
| Input Resolution | 1024 × 1024 pixels |
| Input Channels | 3 (RGB) |
| Output | Single-channel alpha mask |
| Format | TensorRT Engine (.trt) |
| Precision | FP16 (Half Precision) |
Training Data
The original BiRefNet model was trained on:
| Task | Training Sets | Backbone | Test Set |
|---|---|---|---|
| Portrait matting | P3M-10k, humans | Swin-V1-Large | P3M-500-P |
Model Specifications
- File Size: ~615 MB
- Engine Version: TensorRT 10.x compatible
- CUDA Requirements: CUDA 12.0+
- GPU Memory: ~2GB at 1024x1024 FP16
Intended Uses & Limitations
Intended Uses
- Selfie background removal - Real-time processing of portrait photos
- Video conferencing - Virtual background for live video
- Photo editing - Batch background removal for portrait photography
- Portrait matting - High-quality alpha extraction for compositing
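The compositing use case above is a simple per-pixel blend of foreground and background by the predicted alpha. A minimal numpy sketch, where the `mask` array stands in for the model's single-channel output:

```python
import numpy as np

def composite(foreground, background, alpha):
    """Blend foreground over background using a single-channel alpha mask.

    foreground, background: (H, W, 3) uint8 images
    alpha: (H, W) float array in [0, 1], e.g. the model's output mask
    """
    a = alpha[..., None]  # broadcast to (H, W, 1)
    out = a * foreground.astype(np.float32) + (1.0 - a) * background.astype(np.float32)
    return out.clip(0, 255).astype(np.uint8)

# Toy example: white subject over black background with a half-opaque mask
fg = np.full((4, 4, 3), 255, dtype=np.uint8)
bg = np.zeros((4, 4, 3), dtype=np.uint8)
mask = np.full((4, 4), 0.5, dtype=np.float32)
result = composite(fg, bg, mask)
```

The same blend also works against a replacement background image instead of a solid color.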
Limitations
- Input constraints: Optimized for 1024×1024 resolution; resizing required for other sizes
- Subject matter: Trained primarily on human portraits; performance may vary on non-human subjects
- Background complexity: Works best with distinct foreground/background separation
- Hardware: Requires NVIDIA GPU with TensorRT support
- Batch size: Currently optimized for single-image inference (batch=1)
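Because the engine expects a fixed 1024×1024, batch-1 input, arbitrary images must be resized before inference and the mask resized back afterwards. A plain-numpy nearest-neighbor sketch of that round trip (a real pipeline would typically use Pillow or OpenCV bilinear resizing instead):

```python
import numpy as np

def resize_nn(img, size=(1024, 1024)):
    """Nearest-neighbor resize via index mapping. img: (H, W) or (H, W, C)."""
    h, w = img.shape[:2]
    rows = (np.arange(size[0]) * h // size[0]).clip(0, h - 1)
    cols = (np.arange(size[1]) * w // size[1]).clip(0, w - 1)
    return img[rows[:, None], cols]

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
model_input = resize_nn(image)  # (1024, 1024, 3), ready for preprocessing

# Stand-in for the engine's single-channel output, resized back to the source size
mask_1024 = np.zeros((1024, 1024), dtype=np.float32)
mask_full = resize_nn(mask_1024, size=image.shape[:2])  # (480, 640)
```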
How to Get Started with the Model
This TensorRT engine file (`birefnet_portrait.trt`) can be used in three ways:
Option 1: Using Transformers (Recommended for Python)
```python
from transformers import pipeline

# Load the pipeline with TensorRT support
remover = pipeline("image-segmentation", model="israellaguan/birefnet-portrait-tensorrt")

# Process an image
result = remover("input.jpg")
```
Option 2: Using Official TensorRT APIs
For production deployment or C++ applications, use the official TensorRT runtime as documented in the NVIDIA TensorRT Quick Start Guide:
```python
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

# Load the serialized engine
logger = trt.Logger(trt.Logger.WARNING)
with open("birefnet_portrait.trt", "rb") as f:
    runtime = trt.Runtime(logger)
    engine = runtime.deserialize_cuda_engine(f.read())

# Create execution context
context = engine.create_execution_context()

# Allocate buffers and run inference
# See the NVIDIA TensorRT documentation for complete examples
```
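Whichever runtime is used, the engine expects a normalized NCHW tensor and returns a raw mask. A sketch of the pre- and post-processing, assuming BiRefNet's usual ImageNet normalization statistics and a sigmoid over the output logits (both are assumptions here, not read from the engine):

```python
import numpy as np

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # ImageNet statistics
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def preprocess(image_hwc_uint8):
    """(1024, 1024, 3) uint8 RGB -> (1, 3, 1024, 1024) float32 tensor."""
    x = image_hwc_uint8.astype(np.float32) / 255.0
    x = (x - MEAN) / STD
    return x.transpose(2, 0, 1)[None]  # HWC -> NCHW, add batch axis

def postprocess(logits):
    """(1, 1, 1024, 1024) raw logits -> (1024, 1024) uint8 alpha mask."""
    alpha = 1.0 / (1.0 + np.exp(-logits[0, 0]))  # sigmoid
    return (alpha * 255.0).astype(np.uint8)

image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
tensor = preprocess(image)  # feed this to the TensorRT engine
mask = postprocess(np.zeros((1, 1, 1024, 1024), dtype=np.float32))  # stand-in logits
```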
Option 3: Using the Provided Python Package
This repository includes a complete Python package in the rmbg/ folder with CLI and API for running the TensorRT model. See rmbg/README.md for detailed usage:
```bash
# CLI usage
python -m rmbg.cli process input.jpg -o output/

# With options
python -m rmbg.cli process input.jpg -o output/ --verbose --warmup
```
The `rmbg/` package provides:
- CLI tool (`cli.py`) - Command-line interface with progress bars
- Python API (`tools/`) - Pipeline and BackgroundRemover classes
- TensorRT backend (`backends/tensorrt.py`) - Native TensorRT inference
Requirements
```bash
pip install tensorrt pycuda Pillow numpy
```
- CUDA 12.0+
- TensorRT 10.x
- NVIDIA GPU with Compute Capability 7.0+
Performance Benchmarks
| Runtime | Median Latency | FPS | Speedup |
|---|---|---|---|
| TensorRT FP16 | 123ms | 8.1 | 5.3x |
| PyTorch (baseline) | 653ms | 1.5 | 1.0x |
Tested on RTX 3060, CUDA 12.0, TensorRT 10.8, 1024×1024 input
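The derived columns in the table follow directly from the two median latencies:

```python
trt_ms, torch_ms = 123, 653  # median latencies (ms) from the table above

speedup = torch_ms / trt_ms    # TensorRT vs. PyTorch baseline, ~5.3x
trt_fps = 1000 / trt_ms        # ~8.1 frames/second
torch_fps = 1000 / torch_ms    # ~1.5 frames/second
```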
Conversion Process
This model was converted from PyTorch to TensorRT using the ONNX intermediate format:
PyTorch → ONNX → TensorRT Engine
Step 1: Export to ONNX
```python
import torch

# Load the BiRefNet checkpoint (run from the BiRefNet repository root so the
# models/ and utils modules below are importable)
checkpoint_path = "BiRefNet-portrait-epoch_150.pth"  # or any model from https://github.com/ZhengPeng7/BiRefNet/releases
device = torch.device("cuda")

# Create the model and load the weights
from models.birefnet import BiRefNet
from utils import check_state_dict

model = BiRefNet(bb_pretrained=False)
state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
state_dict = check_state_dict(state_dict)
model.load_state_dict(state_dict)
model = model.eval().to(device)

# Export to ONNX
dummy_input = torch.randn(1, 3, 1024, 1024).to(device)
torch.onnx.export(
    model,
    dummy_input,
    "birefnet_portrait.onnx",
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
    input_names=['input_image'],
    output_names=['output_mask'],
    dynamic_axes={
        'input_image': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output_mask': {0: 'batch_size', 2: 'height', 3: 'width'}
    },
)
```
Step 2: Build TensorRT Engine
```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

# Parse the ONNX model
with open("birefnet_portrait.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        raise RuntimeError("ONNX parsing failed")

# Configure the builder
config = builder.create_builder_config()
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 4 * (1 << 30))  # 4 GB workspace
config.set_flag(trt.BuilderFlag.FP16)  # Enable FP16 precision

# The ONNX model was exported with dynamic axes, so an optimization profile is
# required; here every dimension is pinned to the 1x3x1024x1024 target shape
profile = builder.create_optimization_profile()
profile.set_shape("input_image", (1, 3, 1024, 1024), (1, 3, 1024, 1024), (1, 3, 1024, 1024))
config.add_optimization_profile(profile)

# Build and serialize the engine
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:
    raise RuntimeError("Engine build failed")

# Save the engine
with open("birefnet_portrait.trt", "wb") as f:
    f.write(serialized_engine)
print(f"TensorRT engine saved: {len(serialized_engine) / (1024**2):.1f} MB")
```
Available Models
Other BiRefNet models can be converted using the same process. Download checkpoints from:
- Official releases: https://github.com/ZhengPeng7/BiRefNet/releases
- Portrait model (used here): `BiRefNet-portrait-epoch_150.pth`
- General purpose: `BiRefNet-general-epoch_240.pth`
- High-resolution matting: various task-specific models available
Note on prior work: earlier BiRefNet TensorRT implementations (lbq779660843/BiRefNet-Tensorrt and yuanyang1991/birefnet_tensorrt) were last published roughly two years ago and provide only conversion instructions, with no downloadable pre-built engines. This repository provides a ready-to-use TensorRT engine.
Citation
If you use this model, please cite the original BiRefNet paper:
```bibtex
@article{zheng2024birefnet,
  title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
  author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
  journal={CAAI Artificial Intelligence Research},
  year={2024}
}
```