diff --git a/advanced_source/cpp_custom_ops_sycl.rst b/advanced_source/cpp_custom_ops_sycl.rst
index 3b3ad069b58..f40d7787353 100644
--- a/advanced_source/cpp_custom_ops_sycl.rst
+++ b/advanced_source/cpp_custom_ops_sycl.rst
@@ -13,13 +13,14 @@ Custom SYCL Operators
    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

-      * PyTorch 2.8 or later
+      * PyTorch 2.8 or later for Linux
+      * PyTorch 2.10 or later for Windows
       * Basic understanding of SYCL programming

 .. note::

    ``SYCL`` serves as the backend programming language for Intel GPUs (device label ``xpu``). For configuration details, see:
-   `Getting Started on Intel GPUs `_. The Intel Compiler, which comes bundled with Intel Deep Learning Essentials, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.
+   `Getting Started on Intel GPUs `_. The Intel Compiler, which comes bundled with `Intel Deep Learning Essentials `_, handles ``SYCL`` compilation. Ensure you install and activate the compiler environment prior to executing the code examples in this tutorial.

 PyTorch offers a large library of operators that work on Tensors (e.g. torch.add, torch.sum, etc). However, you may wish to
 bring a new custom operator to PyTorch. This tutorial demonstrates the
@@ -47,45 +48,65 @@ Using ``sycl_extension`` is as straightforward as writing the following ``setup.
 .. code-block:: python

-    import os
-    import torch
-    import glob
-    from setuptools import find_packages, setup
-    from torch.utils.cpp_extension import SyclExtension, BuildExtension
-
-    library_name = "sycl_extension"
-    py_limited_api = True
-    extra_compile_args = {
-        "cxx": ["-O3",
-                "-fdiagnostics-color=always",
-                "-DPy_LIMITED_API=0x03090000"],
-        "sycl": ["-O3" ]
-    }
-
-    assert(torch.xpu.is_available()), "XPU is not available, please check your environment"
-    # Source files collection
-    this_dir = os.path.dirname(os.path.curdir)
-    extensions_dir = os.path.join(this_dir, library_name)
-    sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
-    # Construct extension
-    ext_modules = [
-        SyclExtension(
-            f"{library_name}._C",
-            sources,
-            extra_compile_args=extra_compile_args,
-            py_limited_api=py_limited_api,
-        )
-    ]
-    setup(
-        name=library_name,
-        packages=find_packages(),
-        ext_modules=ext_modules,
-        install_requires=["torch"],
-        description="Simple Example of PyTorch Sycl extensions",
-        cmdclass={"build_ext": BuildExtension},
-        options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
-    )
-
+    import os
+    import torch
+    import glob
+    import platform
+    from setuptools import find_packages, setup
+    from torch.utils.cpp_extension import SyclExtension, BuildExtension
+
+    library_name = "sycl_extension"
+    py_limited_api = True
+
+    IS_WINDOWS = (platform.system() == 'Windows')
+
+    if IS_WINDOWS:
+        cxx_args = [
+            "/O2",
+            "/std:c++17",
+            "/DPy_LIMITED_API=0x03090000",
+            "-fheader-search=gcc",
+        ]
+        sycl_args = ["/O2", "/std:c++17", "-fheader-search=gcc"]
+    else:
+        cxx_args = [
+            "-O3",
+            "-fdiagnostics-color=always",
+            "-DPy_LIMITED_API=0x03090000"
+        ]
+        sycl_args = ["-O3"]
+
+    extra_compile_args = {
+        "cxx": cxx_args,
+        "sycl": sycl_args
+    }
+
+    assert(torch.xpu.is_available()), "XPU is not available, please check your environment"
+
+    # Source files collection
+    this_dir = os.path.dirname(os.path.curdir)
+    extensions_dir = os.path.join(this_dir, library_name)
+    sources = list(glob.glob(os.path.join(extensions_dir, "*.sycl")))
+
+    # Construct extension
+    ext_modules = [
+        SyclExtension(
+            f"{library_name}._C",
+            sources,
+            extra_compile_args=extra_compile_args,
+            py_limited_api=py_limited_api,
+        )
+    ]
+
+    setup(
+        name=library_name,
+        packages=find_packages(),
+        ext_modules=ext_modules,
+        install_requires=["torch"],
+        description="Simple Example of PyTorch Sycl extensions",
+        cmdclass={"build_ext": BuildExtension},
+        options={"bdist_wheel": {"py_limited_api": "cp39"}} if py_limited_api else {},
+    )
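+
+Before building, it can help to confirm the prerequisites that the ``assert`` above relies on. A quick check, illustrative only, might be:
+
+.. code-block:: python
+
+    import torch
+
+    # The setup.py above refuses to build without a visible XPU device.
+    print(torch.__version__)           # 2.8+ on Linux, 2.10+ on Windows (see prerequisites)
+    print(torch.xpu.is_available())    # must be True for the assert in setup.py to pass
+    print(torch.xpu.device_count())    # number of Intel GPUs PyTorch can see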

 Defining the custom op and adding backend implementations
 ---------------------------------------------------------
@@ -101,82 +122,109 @@ in a separate ``TORCH_LIBRARY_IMPL`` block:

 .. code-block:: cpp

-    #include <ATen/Operators.h>
-    #include <c10/xpu/XPUStream.h>
-    #include <sycl/sycl.hpp>
-    #include <torch/all.h>
-    #include <torch/library.h>
-
-    namespace sycl_extension {
-    // MulAdd Kernel: result = a * b + c
-    static void muladd_kernel(
-        int numel, const float* a, const float* b, float c, float* result,
-        const sycl::nd_item<1>& item) {
-        int idx = item.get_global_id(0);
-        if (idx < numel) {
-            result[idx] = a[idx] * b[idx] + c;
-        }
-    }
-
-    class MulAddKernelFunctor {
-    public:
-        MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
-            : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
-        void operator()(const sycl::nd_item<1>& item) const {
-            muladd_kernel(numel, a, b, c, result, item);
-        }
-
-    private:
-        int numel;
-        const float* a;
-        const float* b;
-        float c;
-        float* result;
-    };
-
-    at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
-        TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
-        TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
-        TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
-        TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
-        TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
-
-        at::Tensor a_contig = a.contiguous();
-        at::Tensor b_contig = b.contiguous();
-        at::Tensor result = at::empty_like(a_contig);
-
-        const float* a_ptr = a_contig.data_ptr<float>();
-        const float* b_ptr = b_contig.data_ptr<float>();
-        float* res_ptr = result.data_ptr<float>();
-        int numel = a_contig.numel();
-
-        sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
-        constexpr int threads = 256;
-        int blocks = (numel + threads - 1) / threads;
-
-        queue.submit([&](sycl::handler& cgh) {
-            cgh.parallel_for(
-                sycl::nd_range<1>(blocks * threads, threads),
-                MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
-            );
-        });
-
-        return result;
-    }
-    // Defines the operators
-    TORCH_LIBRARY(sycl_extension, m) {
+    #include <ATen/Operators.h>
+    #include <c10/xpu/XPUStream.h>
+    #include <sycl/sycl.hpp>
+    #include <torch/all.h>
+    #include <torch/library.h>
+
+    #include <Python.h>
+
+    namespace sycl_extension {
+
+    // ==========================================================
+    // 1. Kernel
+    // ==========================================================
+    static void muladd_kernel(
+        int numel, const float* a, const float* b, float c, float* result,
+        const sycl::nd_item<1>& item) {
+        int idx = item.get_global_id(0);
+        if (idx < numel) {
+            result[idx] = a[idx] * b[idx] + c;
+        }
+    }
+
+    class MulAddKernelFunctor {
+    public:
+        MulAddKernelFunctor(int _numel, const float* _a, const float* _b, float _c, float* _result)
+            : numel(_numel), a(_a), b(_b), c(_c), result(_result) {}
+        void operator()(const sycl::nd_item<1>& item) const {
+            muladd_kernel(numel, a, b, c, result, item);
+        }
+
+    private:
+        int numel;
+        const float* a;
+        const float* b;
+        float c;
+        float* result;
+    };
+
+    // ==========================================================
+    // 2. Wrapper
+    // ==========================================================
+    at::Tensor mymuladd_xpu(const at::Tensor& a, const at::Tensor& b, double c) {
+        TORCH_CHECK(a.sizes() == b.sizes(), "a and b must have the same shape");
+        TORCH_CHECK(a.dtype() == at::kFloat, "a must be a float tensor");
+        TORCH_CHECK(b.dtype() == at::kFloat, "b must be a float tensor");
+        TORCH_CHECK(a.device().is_xpu(), "a must be an XPU tensor");
+        TORCH_CHECK(b.device().is_xpu(), "b must be an XPU tensor");
+
+        at::Tensor a_contig = a.contiguous();
+        at::Tensor b_contig = b.contiguous();
+        at::Tensor result = at::empty_like(a_contig);
+
+        const float* a_ptr = a_contig.data_ptr<float>();
+        const float* b_ptr = b_contig.data_ptr<float>();
+        float* res_ptr = result.data_ptr<float>();
+        int numel = a_contig.numel();
+
+        sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
+        constexpr int threads = 256;
+        int blocks = (numel + threads - 1) / threads;
+
+        queue.submit([&](sycl::handler& cgh) {
+            cgh.parallel_for(
+                sycl::nd_range<1>(blocks * threads, threads),
+                MulAddKernelFunctor(numel, a_ptr, b_ptr, static_cast<float>(c), res_ptr)
+            );
+        });
+
+        return result;
+    }
+
+    // ==========================================================
+    // 3. Registration
+    // ==========================================================
+    TORCH_LIBRARY(sycl_extension, m) {
         m.def("mymuladd(Tensor a, Tensor b, float c) -> Tensor");
-    }
-
-    // ==================================================
-    // Register SYCL Implementations to Torch Library
-    // ==================================================
-    TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
-        m.impl("mymuladd", &mymuladd_xpu);
-    }
-
-    } // namespace sycl_extension
-
+    }
+
+    TORCH_LIBRARY_IMPL(sycl_extension, XPU, m) {
+        m.impl("mymuladd", &mymuladd_xpu);
+    }
+
+    } // namespace sycl_extension
+
+    // ==========================================================
+    // 4. Windows Linker
+    // ==========================================================
+    extern "C" {
+    #ifdef _WIN32
+    __declspec(dllexport)
+    #endif
+    PyObject* PyInit__C(void) {
+        static struct PyModuleDef moduledef = {
+            PyModuleDef_HEAD_INIT,
+            "_C",
+            "XPU Extension Shim",
+            -1,
+            NULL
+        };
+        return PyModule_Create(&moduledef);
+    }
+    }

 Create a Python Interface
@@ -201,26 +249,39 @@ Create ``sycl_extension/__init__.py`` file to make the package importable:

 .. code-block:: python

-    import ctypes
-    from pathlib import Path
+    import ctypes
+    import platform
+    from pathlib import Path

-    import torch
+    import torch
+
+    current_dir = Path(__file__).parent.parent
+    build_dir = current_dir / "build"
+
+    if platform.system() == 'Windows':
+        file_pattern = "**/*.pyd"
+    else:
+        file_pattern = "**/*.so"
+
+    lib_files = list(build_dir.glob(file_pattern))
+
+    if not lib_files:
+        current_package_dir = Path(__file__).parent
+        lib_files = list(current_package_dir.glob(file_pattern))

-    current_dir = Path(__file__).parent.parent
-    build_dir = current_dir / "build"
-    so_files = list(build_dir.glob("**/*.so"))
+    assert len(lib_files) > 0, f"Could not find any {file_pattern} file in {build_dir} or {current_dir}"
+    lib_file = lib_files[0]

-    assert len(so_files) == 1, f"Expected one _C*.so file, found {len(so_files)}"
-    with torch._ops.dl_open_guard():
-        loaded_lib = ctypes.CDLL(so_files[0])
+    with torch._ops.dl_open_guard():
+        loaded_lib = ctypes.CDLL(str(lib_file))

-    from . import ops
+    from . import ops

-    __all__ = [
-        "loaded_lib",
-        "ops",
-    ]
+    __all__ = [
+        "loaded_lib",
+        "ops",
+    ]

 Testing SYCL extension operator
 -------------------------------
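+
+A minimal smoke test compares the operator against its definition (``a * b + c``); the ``reference_muladd`` helper below is illustrative:
+
+.. code-block:: python
+
+    import torch
+    import sycl_extension  # noqa: F401  (loads the compiled library)
+
+    def reference_muladd(a, b, c):
+        return a * b + c
+
+    a = torch.randn(1024, device="xpu")
+    b = torch.randn(1024, device="xpu")
+    out = torch.ops.sycl_extension.mymuladd(a, b, 2.0)
+    torch.testing.assert_close(out, reference_muladd(a, b, 2.0))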