Registering custom C++/CUDA operators using modern PyTorch APIs
Nowadays, most ML inference engines are built on top of PyTorch's ecosystem. By leveraging custom operators, they deliver state-of-the-art throughput on large workloads, ranging from LLMs to Stable Diffusion. Compared to PyTorch's native kernels, these operators usually offer lower latency and enable optimized implementations for cutting-edge operations that aren't yet supported out of the box.
This tutorial shows you how to build and integrate a simple PyTorch custom operator that runs on both CPU and NVIDIA GPUs. For the PyTorch APIs used in this tutorial, check out the official PyTorch documentation and this Google Doc. An official example project can be found on GitHub. Here, we provide a more compact, self-contained version to illustrate the same concepts in a smaller, easier-to-follow project.
This simple custom operator is provided purely for demonstration; in real-world projects, PyTorch's built-in operators can already handle simple workloads like element-wise multiplication efficiently.
The repository for this tutorial:
Kernel Implementation
The kernel simply multiplies the input by 1.23, equivalent to this Python snippet:
```python
def short_op(x: torch.Tensor):
    return x * 1.23
```
We define our simple kernel in my_ops/csrc/simple_ops.cu.
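A minimal sketch of such a kernel, assuming a simple one-thread-per-element layout; the name short_kernel_impl and the launch style are illustrative choices, and only the multiply-by-1.23 behaviour comes from the description above:

```cuda
// Sketch of the element-wise kernel: out[i] = in[i] * 1.23.
// short_kernel_impl is an assumed name; the wrappers below launch it.
#include <cstdint>

template <typename scalar_t>
__global__ void short_kernel_impl(const scalar_t* __restrict__ in,
                                  scalar_t* __restrict__ out,
                                  int64_t numel) {
  const int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < numel) {
    out[idx] = in[idx] * static_cast<scalar_t>(1.23);
  }
}
```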
We also provide two binding functions that sit between PyTorch and our low-level C++/CUDA code. These wrappers map PyTorch’s tensor dtypes to the correct C++/CUDA scalar types and then launch the appropriate kernel.
PyTorch Binding: CUDA Side
The CUDA-side wrapper also lives in my_ops/csrc/simple_ops.cu.
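A sketch of what this wrapper can look like; the launch configuration and the use of short_kernel_impl from the previous sketch are assumptions, while the short_kernel name and the dispatch macro are the ones discussed in this post:

```cpp
// Sketch of the CUDA-side binding in my_ops/csrc/simple_ops.cu.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

at::Tensor short_kernel(at::Tensor x) {
  TORCH_CHECK(x.is_cuda(), "short_kernel expects a CUDA tensor");
  auto input = x.contiguous();
  auto out = at::empty_like(input);
  const int64_t numel = input.numel();
  const int threads = 256;
  const int blocks = static_cast<int>((numel + threads - 1) / threads);

  // Map the tensor dtype to a concrete scalar_t and launch the kernel.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "short_kernel", [&] {
    short_kernel_impl<scalar_t>
        <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
            input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), numel);
  });
  return out;
}
```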
PyTorch Binding: CPU Side
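The CPU-side wrapper sits in the same file. A sketch, reusing the same dispatch macro with a plain loop; the body is illustrative, and short_kernel_cpu is the name registered in the dispatcher section below:

```cpp
// Sketch of the CPU-side binding in my_ops/csrc/simple_ops.cu.
at::Tensor short_kernel_cpu(at::Tensor x) {
  auto input = x.contiguous();
  auto out = at::empty_like(input);
  const int64_t numel = input.numel();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "short_kernel_cpu", [&] {
    const scalar_t* in_ptr = input.data_ptr<scalar_t>();
    scalar_t* out_ptr = out.data_ptr<scalar_t>();
    for (int64_t i = 0; i < numel; ++i) {
      out_ptr[i] = in_ptr[i] * static_cast<scalar_t>(1.23);
    }
  });
  return out;
}
```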
What is AT_DISPATCH_FLOATING_TYPES_AND_HALF?
Here, AT_DISPATCH_FLOATING_TYPES_AND_HALF expands, roughly, into a switch over x.scalar_type() that instantiates the enclosed lambda once for each supported floating-point type.
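A simplified picture of that expansion; the real macro goes through type traits and covers a few more details, so treat this as an illustration rather than the literal generated code:

```cpp
// Roughly what AT_DISPATCH_FLOATING_TYPES_AND_HALF generates (simplified).
switch (x.scalar_type()) {
  case at::ScalarType::Double: {
    using scalar_t = double;
    // ...your lambda body, compiled with scalar_t = double...
    break;
  }
  case at::ScalarType::Float: {
    using scalar_t = float;
    // ...your lambda body, compiled with scalar_t = float...
    break;
  }
  case at::ScalarType::Half: {
    using scalar_t = at::Half;
    // ...your lambda body, compiled with scalar_t = at::Half...
    break;
  }
  default:
    TORCH_CHECK(false, "short_kernel: unsupported dtype ", x.scalar_type());
}
```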
This macro eliminates boilerplate and guarantees you cover all the major floating-point types seamlessly.
Kernel Registration
Finally, we wire everything into PyTorch's dispatcher; these calls also live in my_ops/csrc/simple_ops.cu.
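Concretely, the registration block can be sketched as follows, using the schema and function names explained below:

```cpp
// my_ops/csrc/simple_ops.cu -- dispatcher registration (sketch).
#include <torch/library.h>

// Declare the op and its schema in the my_ops namespace.
TORCH_LIBRARY(my_ops, m) {
  m.def("short_kernel(Tensor x) -> Tensor");
}

// CUDA backend implementation.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
  m.impl("short_kernel", &short_kernel);
}

// CPU backend implementation.
TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
  m.impl("short_kernel", &short_kernel_cpu);
}
```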
TORCH_LIBRARY(my_ops, m) registers operators to the namespace my_ops in PyTorch, so we can use torch.ops.my_ops.short_kernel to call our implementations. To register this op, we need to pass a schema, "short_kernel(Tensor x) -> Tensor", to tell PyTorch how this op can be called. Please see The Custom Operators Manual for more details.
TORCH_LIBRARY_IMPL(my_ops, CUDA, m) says "for the my_ops operator library, here are the implementations to use when the dispatch key is CUDA." m.impl("short_kernel", &short_kernel) binds the C++ function short_kernel(at::Tensor) as the CUDA-backend kernel for that op. So if you do:
```python
x = torch.randn(..., device="cuda")
y = torch.ops.my_ops.short_kernel(x)
```
the dispatcher will route the call into your short_kernel CUDA function.
TORCH_LIBRARY_IMPL(my_ops, CPU, m) does the same thing for the CPU dispatch key. If you call:
```python
x = torch.randn(..., device="cpu")
y = torch.ops.my_ops.short_kernel(x)
```
then PyTorch will invoke your short_kernel_cpu(at::Tensor) function instead.
Setting Up the Build System
Following Python’s convention, we use setuptools to configure the build system.
The whole build configuration lives in a short setup.py.
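A sketch of what that setup.py can look like, assuming torch.utils.cpp_extension drives the build; the extension name, source path, and the -DPy_LIMITED_API flag come from this post, while the remaining metadata and options are illustrative:

```python
# setup.py (sketch)
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_ops",
    packages=find_packages(),
    python_requires=">=3.9",
    ext_modules=[
        CUDAExtension(
            name="my_ops._C",
            sources=["my_ops/csrc/simple_ops.cu"],
            # Build against the CPython Stable Limited API (CPython >= 3.9).
            extra_compile_args={
                "cxx": ["-DPy_LIMITED_API=0x03090000"],
                "nvcc": ["-DPy_LIMITED_API=0x03090000"],
            },
            py_limited_api=True,  # names the output _C.abi3.so
        )
    ],
    cmdclass={"build_ext": BuildExtension},
    # Tag wheels as abi3 so a single wheel works on CPython >= 3.9.
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
```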
By adding the compiler flag -DPy_LIMITED_API=0x03090000, we ensure that the built extension can run in any Python environment with version ≥ 3.9. It also helps verify that the extension is in fact only using the CPython Stable Limited API, which is what guarantees forward compatibility. According to the PyTorch documentation:
"If this requirement (defining the Py_LIMITED_API flag) is not met, it is possible to build a wheel that looks CPython agnostic but will crash, or worse, be silently incorrect, in another CPython environment."
We need to register an empty _C module, as the following code block shows, so that Python can directly import our built .so library. Because the extension is configured with cpp_extension.CUDAExtension(name="my_ops._C", ...), the output of the build is my_ops/_C.abi3.so. Executing from my_ops import _C then registers our custom ops with PyTorch, since our simple_ops.cu file contains the calls to TORCH_LIBRARY and TORCH_LIBRARY_IMPL.
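A minimal sketch of that empty module in my_ops/csrc/simple_ops.cu; no Python-visible methods are needed, because the ops are registered by the TORCH_LIBRARY blocks when the shared library is loaded:

```cpp
// Minimal stable-ABI module so that `from my_ops import _C` succeeds.
#include <Python.h>

static struct PyModuleDef simple_ops_module = {
    PyModuleDef_HEAD_INIT,
    "_C",      // module name
    nullptr,   // no docstring
    -1,        // no per-module state
    nullptr    // no methods: the ops are registered via TORCH_LIBRARY
};

extern "C" PyObject* PyInit__C(void) {
  return PyModule_Create(&simple_ops_module);
}
```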
How to export a C++ function to Python?
The extern "C" PyObject *PyInit__C(void) entry point is a standard way to expose a C++ function to Python. To register a C++ function, we need to do the following:
- Write a real C++ function (add_one).
- Wrap it in a py_add_one that speaks the Python/C API: parsing arguments and building return values.
- Fill out a PyMethodDef[] with your method(s).
- Point your PyModuleDef at that table.
- Keep your same PyInit__C() entry point.
A complete example lives in my_ops/csrc/simple_ops.cu.
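A sketch of such a complete example, which would replace the empty module definition shown earlier; add_one and py_add_one are the illustrative names from the list above:

```cpp
// my_ops/csrc/simple_ops.cu -- exposing a plain C++ function to Python (sketch).
#include <Python.h>
#include <cstdint>

// 1. A real C++ function.
static int64_t add_one(int64_t x) { return x + 1; }

// 2. A wrapper that speaks the Python/C API.
static PyObject* py_add_one(PyObject* /*self*/, PyObject* args) {
  long long value = 0;
  if (!PyArg_ParseTuple(args, "L", &value)) {
    return nullptr;  // a Python exception is already set
  }
  return PyLong_FromLongLong(add_one(value));
}

// 3. The method table.
static PyMethodDef simple_ops_methods[] = {
    {"add_one", py_add_one, METH_VARARGS, "Add one to an integer."},
    {nullptr, nullptr, 0, nullptr}  // sentinel
};

// 4. The module definition pointing at that table.
static struct PyModuleDef simple_ops_module = {
    PyModuleDef_HEAD_INIT, "_C", nullptr, -1, simple_ops_methods};

// 5. The same PyInit__C entry point as before.
extern "C" PyObject* PyInit__C(void) {
  return PyModule_Create(&simple_ops_module);
}
```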
Here we use extern "C" to disable the name-mangling feature of C++. Without extern "C", the compiler might emit a symbol like _Z10PyInit__Cv (or worse), and Python wouldn't be able to locate the entry point it expects.
To register our custom ops at import time, my_ops/__init__.py should be written as follows.
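A sketch of that file; importing torch and then _C is the essential part, while the short_op wrapper at the end is an optional convenience and purely illustrative:

```python
# my_ops/__init__.py (sketch)
import torch  # ensure libtorch is loaded before our extension

from . import _C  # noqa: F401  (loading _C runs the TORCH_LIBRARY registrations)


def short_op(x: torch.Tensor) -> torch.Tensor:
    """Optional convenience wrapper around the registered op (illustrative)."""
    return torch.ops.my_ops.short_kernel(x)
```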
Project File Organization
The current file organization:
```
my_ops/                    # the Python package
├── __init__.py            # imports _C so the ops get registered
├── _C.abi3.so             # built extension (created later by `pip install -e .`)
└── csrc/
    └── simple_ops.cu      # kernel, bindings, registration, and PyInit__C
setup.py                   # build configuration, at the project root
```
Building and Importing the Python Module
From the project’s root directory, run:
pip install -e .
This command tells pip to install the current directory (.) as an editable package.
During installation, pip will compile the C++/CUDA source code (as specified in setup.py) and place the resulting .so extension inside the package directory ./my_ops.
Once that’s done, you can simply:
```python
import my_ops
```
to invoke your custom operator.
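For example, assuming a CUDA-capable device is available (the tensor shape here is arbitrary):

```python
import torch
import my_ops  # importing the package loads _C and registers the ops

x = torch.randn(4, 4, device="cuda")
y = torch.ops.my_ops.short_kernel(x)  # dispatches to the CUDA implementation
print(torch.allclose(y, x * 1.23))    # True, up to floating-point tolerance
```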