From 22bfdc11686112fce03b6d56453a410d700c6b2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 13 Feb 2024 17:04:36 +0100 Subject: [PATCH] Add optimized blockwise gemm using ck wrapper (#1157) * Add optimized blockwise gemm using ck wrapper * Add basic gemm example * Update docs * Add tutorial for gemm using ck wrapper * Add perf note * edits * Fix cmake * Fixes --------- Co-authored-by: Lisa Delaney [ROCm/composable_kernel commit: 1e73adbc2809fb582c40f91daa8ecd7cd6737aff] --- client_example/25_wrapper/CMakeLists.txt | 8 + client_example/25_wrapper/README.md | 177 +++++++++ .../25_wrapper/wrapper_basic_gemm.cpp | 216 ++++++++++ client_example/25_wrapper/wrapper_img2col.cpp | 42 +- .../25_wrapper/wrapper_optimized_gemm.cpp | 308 ++++++++++++++ docs/wrapper.rst | 10 +- include/ck/wrapper/operations/copy.hpp | 68 ++-- include/ck/wrapper/operations/gemm.hpp | 98 +++-- include/ck/wrapper/tensor.hpp | 4 +- .../traits/blockwise_gemm_xdl_traits.hpp | 47 ++- include/ck/wrapper/utils/kernel_utils.hpp | 14 + include/ck/wrapper/utils/layout_utils.hpp | 105 ++++- include/ck/wrapper/utils/tensor_partition.hpp | 290 +++++++++----- test/wrapper/CMakeLists.txt | 27 +- test/wrapper/test_gemm.cpp | 257 ------------ .../{test_copy.cpp => test_wrapper_copy.cpp} | 27 +- test/wrapper/test_wrapper_gemm.cpp | 376 ++++++++++++++++++ ...est_layout.cpp => test_wrapper_layout.cpp} | 2 +- ...rtition.cpp => test_wrapper_partition.cpp} | 33 +- ...est_tensor.cpp => test_wrapper_tensor.cpp} | 0 20 files changed, 1597 insertions(+), 512 deletions(-) create mode 100644 client_example/25_wrapper/README.md create mode 100644 client_example/25_wrapper/wrapper_basic_gemm.cpp create mode 100644 client_example/25_wrapper/wrapper_optimized_gemm.cpp create mode 100644 include/ck/wrapper/utils/kernel_utils.hpp delete mode 100644 test/wrapper/test_gemm.cpp rename test/wrapper/{test_copy.cpp => test_wrapper_copy.cpp} (83%) create mode 100644 test/wrapper/test_wrapper_gemm.cpp rename test/wrapper/{test_layout.cpp => test_wrapper_layout.cpp} (99%) rename test/wrapper/{test_partition.cpp => test_wrapper_partition.cpp} (79%) rename test/wrapper/{test_tensor.cpp => test_wrapper_tensor.cpp} (100%) diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index eb3be0e6c8..fdfc1d8d2e 100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -2,3 +2,11 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR + GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR + GPU_TARGETS MATCHES "gfx942") + add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) + target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) + add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) + target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations) +endif() diff --git a/client_example/25_wrapper/README.md b/client_example/25_wrapper/README.md new file mode 100644 index 0000000000..eba3de017f --- /dev/null +++ b/client_example/25_wrapper/README.md @@ -0,0 +1,177 @@ +# 
Composable Kernel wrapper GEMM tutorial
+
+This tutorial demonstrates how to implement matrix multiplication using the Composable Kernel (CK)
+wrapper. We present a basic version of GEMM without most of the available optimizations; note,
+however, that CK also provides kernels with many different optimizations.
+
+To implement these optimizations, you can use the CK wrapper or directly use the instances
+available in CK. You can also refer to the
+[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp),
+which uses the CK wrapper and is based on the
+[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation.
+
+The kernel definition should look similar to:
+
+```cpp
+template <typename DataType,
+          typename GemmTraits,
+          ck::index_t scalar_per_vector,
+          typename BlockShape,
+          typename ThreadLayout>
+__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
+                                                        const void* p_b,
+                                                        void* p_c,
+                                                        const ck::index_t M,
+                                                        const ck::index_t N,
+                                                        const ck::index_t K,
+                                                        const BlockShape tile_shape,
+                                                        const ThreadLayout thread_layout)
+```
+
+We pass the pointers to global memory and the matrix dimensions as runtime arguments, together
+with the tile lengths processed by each block (`tile_shape`) and the thread layout
+(`thread_layout`). As compile-time parameters, we define the data type, the
+[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp),
+and the number of scalars per vector used during copies.
+
+Step 1: Create layouts for global and LDS memory.
+
+```cpp
+    // Specify layouts for global memory.
+    const auto a_global_layout =
+        ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
+    const auto b_global_layout =
+        ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
+    const auto c_global_layout =
+        ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
+
+    // Specify layouts for tiles.
+    constexpr auto a_tile_layout = ck::wrapper::make_layout(
+        ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
+    constexpr auto b_tile_layout = ck::wrapper::make_layout(
+        ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
+    constexpr auto c_tile_layout = ck::wrapper::make_layout(
+        ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
+
+    // Apply padding for global memory.
+    auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
+    auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
+    auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
+```
+
+We pad the layouts for the global tensors in case M, N, and K are not divisible by `MPerBlock`,
+`NPerBlock`, or `KPerBlock`.
+
+Step 2: Create tensors for global and LDS memory.
+
+```cpp
+    // Make tensors for global memory.
+    auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
+        static_cast<const DataType*>(p_a), a_global_layout_padded);
+    auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
+        static_cast<const DataType*>(p_b), b_global_layout_padded);
+    auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
+        static_cast<DataType*>(p_c), c_global_layout_padded);
+
+    // Allocate LDS memory.
+    __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
+    __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
+
+    // Make tensors for lds memory.
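+    // Note: make_tensor only pairs a raw pointer with a layout to form a
+    // non-owning view; it does not allocate memory or copy any data.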
+    auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
+        static_cast<DataType*>(lds_a), a_tile_layout);
+    auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
+        static_cast<DataType*>(lds_b), b_tile_layout);
+```
+
+Next, we specify the access parameters for copies and express the block index as a tuple:
+
+```cpp
+    // Specify block index as tuple.
+    const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
+                                           static_cast<ck::index_t>(blockIdx.y),
+                                           ck::wrapper::slice());
+    // Specify access parameters for copy.
+    using DimAccessOrder             = ck::Tuple<ck::Number<0>, ck::Number<1>>;
+    constexpr ck::index_t vector_dim = 1;
+```
+
+We create a local tile (per block) and a local partition (per thread) for the `C` tensor in global
+memory. We also define and clear an output register (`c_vgpr_reg`) for the accumulation.
+
+```cpp
+    auto c_global_local_tile = ck::wrapper::make_local_tile(
+        c_global_tensor,
+        tile_shape,
+        block_idxs,
+        make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
+    auto c_global_local_partition =
+        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile);
+    // Create C vgpr to accumulate results.
+    auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr();
+    // Clear C vgpr.
+    ck::wrapper::clear(c_vgpr_reg);
+```
+
+We use two functions specific to `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and
+`make_blockwise_gemm_xdl_c_vgpr`. They select the appropriate partition for the `C` output and
+define tensors with the specific layouts that `blockwise_gemm` expects. From this point on, we use
+only generic CK wrapper functions.
+
+Step 3: Create the compute loop.
+
+```cpp
+    const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
+    ck::index_t i = 0;
+    do
+    {
+        // Get KPerBlock slice.
+        const auto k_slice           = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
+        auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
+        auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
+        // Create local tiles for A and B.
+        auto a_global_local_tile = ck::wrapper::make_local_tile(
+            a_global_tensor_k_slice,
+            tile_shape,
+            block_idxs,
+            make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
+        auto b_global_local_tile = ck::wrapper::make_local_tile(
+            b_global_tensor_k_slice,
+            tile_shape,
+            block_idxs,
+            make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
+        // Copy from global to LDS.
+        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
+            a_global_local_tile, a_lds_tensor, thread_layout);
+        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
+            b_global_local_tile, b_lds_tensor, thread_layout);
+        // Synchronize lds.
+        ck::block_sync_lds();
+        // Execute blockwise GEMM.
+        ck::wrapper::blockwise_gemm_xdl(
+            a_lds_tensor, b_lds_tensor, c_vgpr_reg);
+
+        ++i;
+    } while(i < num_loop);
+```
+
+The loop iterates `K / KPerBlock` times. In each iteration, a local tile is created for the A and B
+tensors (one tile per block) and its data is copied from global memory to LDS. The `blockwise_gemm`
+function then performs the GEMM operation on `a_lds_tensor` and `b_lds_tensor`, accumulating the
+results in `c_vgpr_reg`.
+
+The end result from `c_vgpr_reg` is stored in the `C` local partition (tensor per thread):
+
+```cpp
+    ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
+```
+
+If you want to dive deep into the details, you can find the entire example
+[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp).
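+
+For completeness, here is a minimal sketch of how the kernel can be launched from the host,
+mirroring `PerformGemm` from the full example. The thread layout and tile shape below are the
+values used in that example; the traits type and scalar-per-vector value are illustrative choices
+rather than the only valid ones.
+
+```cpp
+// Host-side sketch: DataType, GemmTraits, scalar_per_vector, p_a, p_b, p_c,
+// M, N, and K are assumed to be in scope, as in the kernel definition above.
+const auto thread_layout =
+    ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
+                             ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
+const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
+// Launch one block per (MPerBlock, NPerBlock) tile of C; 256 threads per block.
+const ck::index_t grid_size_x = ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
+const ck::index_t grid_size_y = ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
+DeviceGemm<DataType, GemmTraits, scalar_per_vector, decltype(tile_shape), decltype(thread_layout)>
+    <<<dim3(grid_size_x, grid_size_y, 1), dim3(ck::wrapper::size(thread_layout)), 0, nullptr>>>(
+        p_a, p_b, p_c, M, N, K, tile_shape, thread_layout);
+```
+
+The full example additionally wraps this launch in CK's `launch_and_time_kernel` helper to measure
+the average kernel time and derive TFLOPS and effective bandwidth.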
diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp new file mode 100644 index 0000000000..1f1a4de751 --- /dev/null +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Specify layouts for tiles. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + // Allocate lds memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + // Specify block index as tuple. 
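+    // (The third entry, ck::wrapper::slice(), marks the K dimension as not
+    // partitioned across the grid - every block traverses the full K range.)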
+ const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + // Create tile and partition for C. Use specific function for blockwise_gemm to assign the + // appropriate partitions. + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); + + // Iterate over K with KPerBlock step. + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to lds. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); + // Copy vgpr results to C global memory. 
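+    // (No synchronization is needed before this copy: the accumulator lives in
+    // thread-private VGPRs and each thread writes only its own C partition.)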
+ ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index 35074be4c1..2a4034d62f 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -15,6 +15,7 @@ #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" static constexpr ck::index_t NumDimSpatial = 3; using DataType = float; @@ -36,21 +37,20 @@ struct SimpleDeviceMem void* p_mem_; }; -// Test copy from Global to Global through LDS and VGPR -template -__global__ void DeviceImageToColumnPad0(InputTensor input_tensor, - OutputTensor output_tensor, - const BlockShape tile_shape, - const ThreadLayoutShape thread_layout) +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ +DeviceImageToColumnPad0(InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayout thread_layout) { - const ck::index_t block_idx = static_cast(blockIdx.x); + // grid layout (dim1, dim0) + const auto block_idxs = + ck::make_tuple(static_cast(blockIdx.y), static_cast(blockIdx.x)); // Get local tiles for global memory - auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); - auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); + auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs); + auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs); // Get partition per thread const auto input_local_partition = @@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G, 
    SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));
 
     // User can choose appropriate number of threads and sizes per block
-    const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
+    const auto thread_layout =
+        ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}),
+                                 ck::make_tuple(ck::Number<16>{}, ck::Number<1>{}));
     // This example doesn't support padding, user should select tile sizes
-    // which divides the shape completely
+    // that divide the shape evenly.
     const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});
 
     // Create buffers for global memory
@@ -123,10 +125,11 @@
     auto output_tensor_global = ck::wrapper::make_tensor(
         static_cast(out_buf.GetDeviceBuffer()), out_layout);
 
-    const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
-                                                                ck::wrapper::size<0>(tile_shape)) *
-                                  ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
-                                                                ck::wrapper::size<1>(tile_shape));
+    // grid layout (dim1, dim0)
+    const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
+                                                                  ck::wrapper::size<1>(tile_shape));
+    const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
+                                                                  ck::wrapper::size<0>(tile_shape));
 
     const auto kernel = DeviceImageToColumnPad0;
     const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                   kernel,
-                                                  dim3(grid_size),
+                                                  dim3(grid_size_x, grid_size_y, 1),
                                                   dim3(ck::wrapper::size(thread_layout)),
                                                   0,
                                                   input_tensor_global,
@@ -178,3 +181,4 @@ int main(int argc, char* argv[])
                        {1, 1, 1} /*filter_dilations*/);
     return 0;
 }
+// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp
new file mode 100644
index 0000000000..ddf01de612
--- /dev/null
+++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp
@@ -0,0 +1,308 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
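+//
+// Compared to wrapper_basic_gemm.cpp, this example software-pipelines the
+// main loop (staging tiles global -> VGPR -> LDS), uses a (K0, M or N, K1)
+// LDS tile layout, and pads the LDS strides, presumably to avoid LDS bank
+// conflicts.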
+ +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + // Create layouts for global memory + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Apply padding + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + // Create tensors for global memory + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + // Create layouts and tensors for lds memory. 
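+    // (The K0 stride below is (MPerBlock + 1) * K1 instead of MPerBlock * K1,
+    // i.e. one extra K1-wide element per K0 slice; the __shared__ arrays are
+    // therefore allocated with K0PerBlock extra elements.)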
+ constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + // Create tile and partition for C global memory. Use specific gemm + // functions to get appropriate layouts. + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Define and clear c vgpr register + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + // Local partitions for lds memory + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + // Lamda to slice tensor, then create local tile and partition + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + // Copy first values to lds + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + a_lds_tensor_local_partition); + 
ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + // Pipeline loop + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + // Skip if only tile should be processed + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + // Copy data to A vgpr. + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + // Synchronize. + ck::block_sync_lds(); + // Copy data to B vgpr. + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + // Perform gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Synchronize + ck::block_sync_lds(); + // Copy data to A and B lds tiles. + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + // Handle tail. + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Store data from C vgpr to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s, diff --git a/docs/wrapper.rst b/docs/wrapper.rst index c64c0bf17f..39e2fd0bbd 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -12,10 +12,6 @@ Wrapper Description ------------------------------------- -.. note:: - - The wrapper is under development and its functionality is limited. 
- The CK library provides a lightweight wrapper for more complex operations implemented in the library. @@ -54,9 +50,15 @@ Output:: 2 6 10 14 18 22 26 30 +Tutorials: + +* `GEMM tutorial `_ + Advanced examples: * `Image to column `_ +* `Basic gemm `_ +* `Optimized gemm `_ ------------------------------------- Layout diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index 614dfd758e..5f64031ebe 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) decltype(dim_access_order), VectorDim, ScalarPerVector, - Sequence, - Sequence>{in_grid_desc, - make_tuple(src_tensor.GetMultiIdxOffsets()), - out_grid_desc, - make_tuple(dst_tensor.GetMultiIdxOffsets()), - tensor_operation::element_wise::PassThrough{}}; + Sequence, + Sequence>{in_grid_desc, + make_tuple(src_tensor.GetMultiIdxOffsets()), + out_grid_desc, + make_tuple(dst_tensor.GetMultiIdxOffsets()), + tensor_operation::element_wise::PassThrough{}}; transfer.Run(tie(in_grid_desc), tie(src_tensor.GetBuffer()), @@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer) { // Perform copy from DynamicBuffer to StaticBuffer - const auto src_dst_slice_origin = + const auto dst_slice_origin_idxs = generate_tuple([&](auto) { return I0; }, Number{}); - constexpr auto src_vector_tensor_lengths = generate_sequence_v2( - [&](auto I) { - if constexpr(I == VectorDim) - { - return Number{}; - } - else - { - return I1; - } - }, - Number{}); - - auto transfer = - ThreadwiseTensorSliceTransfer_v4r1, - remove_cvref_t, - decltype(thread_slice_lengths), - decltype(dim_access_order), - decltype(src_vector_tensor_lengths), - decltype(dim_access_order)>{ - src_tensor.GetMultiIdxOffsets()}; + auto transfer = ThreadwiseTensorSliceTransfer_v2< + std::remove_const_t, + std::remove_const_t, + remove_cvref_t, + remove_cvref_t, + decltype(thread_slice_lengths), + decltype(dim_access_order), + VectorDim, + ScalarPerVector, + I1, + false, + false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()}; transfer.Run(in_grid_desc, - src_dst_slice_origin, src_tensor.GetBuffer(), out_grid_desc, - src_dst_slice_origin, + dst_slice_origin_idxs, dst_tensor.GetBuffer()); } else @@ -183,10 +171,12 @@ template -__device__ void blockwise_copy(const SrcTensorType& src_tensor, - DstTensorType& dst_tensor, - [[maybe_unused]] ThreadLayoutTuple& thread_layout) + typename ThreadShape, + typename ThreadUnrolledDesc> +__device__ void +blockwise_copy(const SrcTensorType& src_tensor, + DstTensorType& dst_tensor, + [[maybe_unused]] const Layout& thread_layout) { static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer); static_assert(is_detected::value); @@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor, constexpr auto tile_lengths_seq = generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number{}); - constexpr auto thread_layout_seq = generate_sequence_v2( - [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number{}); + constexpr auto thread_layout_seq = + generate_sequence_v2([](auto I) { return size(ThreadShape{}); }, Number{}); constexpr auto dim_access_order = generate_sequence_v2( [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number{}); - using ThisThreadBlock = ThisThreadBlock; + using ThisThreadBlock = 
ThisThreadBlock; // Perform copy between DynamicBuffers auto transfer = ThreadGroupTensorSliceTransfer_v7< diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp index 9b8c0543fd..e41cd5bd8a 100644 --- a/include/ck/wrapper/operations/gemm.hpp +++ b/include/ck/wrapper/operations/gemm.hpp @@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor() /** * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be - * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B - * data layout must be (NPerBlock, KPerBlock). + * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or + * (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock) + * or (K0PerBlock, NPerBlock, K1). * * \note C output Vgpr register layout (8D): * - MXdlPerWave - The number of MFMA instructions run by single wave in M @@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor() * \tparam BlockSize Tensor to pad. * \tparam GemmTraits Traits of gemm xdl operation. * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm - * (MPerBlock, KPerBlock) layout. + * (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout. * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm - * (NPerBlock, KPerBlock) layout. + * (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout. * \param c_reg_tensor C tensor VGPR memory for blockwise gemm. */ template {}; + static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr); @@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, using ATileLayout = remove_cvref_t; using BTileLayout = remove_cvref_t; + static_assert(typename ATileLayout::LayoutShape{}.Size() == + typename BTileLayout::LayoutShape{}.Size()); + constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3; + using ABlockDesc_K0_M_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BBlockDesc_K0_N_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; constexpr auto I7 = Number<7>{}; + static_assert(typename ATileLayout::LayoutShape{}.Size() == + typename BTileLayout::LayoutShape{}.Size()); + constexpr bool is_integer = is_same_v || is_same_v || is_same_v; using GemmAccDataType = std::conditional_t; + constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3; using ABlockDesc_K0_M_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BBlockDesc_K0_N_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BlockwiseGemmXdlops = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; }, Number<8>{}); + + auto sliced_desc = transform_tensor_descriptor( + partition_desc, + make_tuple( + make_slice_transform(partition_shape.At(Number<0>{}), + m_thread_data_on_grid_idx[I0], + partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]), + make_slice_transform(partition_shape.At(Number<1>{}), + n_thread_data_on_grid_idx[I0], + partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]), + make_slice_transform(partition_shape.At(Number<2>{}), + m_thread_data_on_grid_idx[I1], + partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]), + 
make_slice_transform(partition_shape.At(Number<3>{}), + n_thread_data_on_grid_idx[I1], + partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]), + make_slice_transform(partition_shape.At(Number<4>{}), + m_thread_data_on_grid_idx[I2], + partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]), + make_slice_transform(partition_shape.At(Number<5>{}), + m_thread_data_on_grid_idx[I3], + partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]), + make_slice_transform(partition_shape.At(Number<6>{}), + m_thread_data_on_grid_idx[I4], + partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]), + make_slice_transform(partition_shape.At(Number<7>{}), + n_thread_data_on_grid_idx[I2], + partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])), + lower_upper_dims, + lower_upper_dims); + const auto partition_layout = - Layout, decltype(partition_desc)>( - partition_shape, partition_desc); + Layout, decltype(sliced_desc)>( + partition_shape, sliced_desc); auto partition_tensor = make_tensor( c_local_tile_tensor.GetPointer(), partition_layout); - partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])); return partition_tensor; } @@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr() constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; + static_assert(typename ATileLayout::LayoutShape{}.Size() == + typename BTileLayout::LayoutShape{}.Size()); + constexpr bool is_integer = is_same_v || is_same_v || is_same_v; using GemmAccDataType = std::conditional_t; + constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3; using ABlockDesc_K0_M_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BBlockDesc_K0_N_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BlockwiseGemmXdlops = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, decltype(vgpr_desc)>( vgpr_shape, vgpr_desc); // Get vector type for Vgpr - using BlockwiseGemmCThreadBufferType = - remove_reference_t; - using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V; + constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops(); + using VgprVectorType = typename vector_type::type; return ck::wrapper::make_register_tensor( vgpr_layout); } diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index e344399dbf..6946e79ea4 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple& } } -template +template __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& idx, const Shape& shape, - const FlattenDescriptor& flatten_desc) + const UnrolledDescriptor& flatten_desc) { constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp index 8301636a9f..54804dea3c 100644 --- a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -20,48 +20,57 @@ namespace wrapper { * \tparam K1Value The number of K-dim elements that are packed together as * a separate logical 
dimension. Usually aligns with vector load size. */ -template +template struct BlockwisGemmXdlTraits { - static constexpr index_t MPerXDL = MPerXDLValue; - static constexpr index_t NPerXDL = NPerXDLValue; - static constexpr index_t MXdlPerWave = MXdlPerWaveValue; - static constexpr index_t NXdlPerWave = NXdlPerWaveValue; - static constexpr index_t K1 = K1Value; + static constexpr auto MPerXDL = MPerXDLValue{}; + static constexpr auto NPerXDL = NPerXDLValue{}; + static constexpr auto MXdlPerWave = MXdlPerWaveValue{}; + static constexpr auto NXdlPerWave = NXdlPerWaveValue{}; + static constexpr auto K1 = K1Value{}; }; // K1 = 4 -struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 + : BlockwisGemmXdlTraits, Number<32>, Number<4>, Number<2>, Number<4>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4> +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<4>, Number<4>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<2>, Number<4>> { }; // K1 = 8 -struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8> +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 + : BlockwisGemmXdlTraits, Number<32>, Number<4>, Number<2>, Number<8>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8> +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<4>, Number<8>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8> +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<2>, Number<8>> { }; // K1 = 16 -struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16> +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 + : BlockwisGemmXdlTraits, Number<32>, Number<4>, Number<2>, Number<16>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16> +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<4>, Number<16>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16> +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<2>, Number<16>> { }; diff --git a/include/ck/wrapper/utils/kernel_utils.hpp b/include/ck/wrapper/utils/kernel_utils.hpp new file mode 100644 index 0000000000..add94ec6ae --- /dev/null +++ b/include/ck/wrapper/utils/kernel_utils.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
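+//
+// Defines __CK_WRAPPER_LAUNCH_BOUNDS__, a convenience macro around
+// __launch_bounds__ that applies CK's project-wide occupancy limits
+// (CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) to wrapper-based kernels.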
+ +#pragma once + +#include "ck/ck.hpp" + +namespace ck { +namespace wrapper { + +#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index d04bd5078b..e077fade2c 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace wrapper { @@ -29,6 +30,7 @@ template using is_tuple = decltype(std::declval().IsTuple()); namespace { +namespace detail { /** * \brief Generate packed (column-major) strides if not passed * @@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); } } +} // namespace detail } // namespace /// @endcond @@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha template __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides) { - using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{})); - return Layout(shape, MakeUnrolledDescriptor(shape, strides)); + using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{})); + return Layout(shape, + detail::MakeUnrolledDescriptor(shape, strides)); } /** @@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides template __host__ __device__ constexpr auto make_layout(const Shape& shape) { - using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{})); - return Layout(shape, MakeUnrolledDescriptor(shape, Tuple<>{})); + using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{})); + return Layout(shape, + detail::MakeUnrolledDescriptor(shape, Tuple<>{})); } - // Layout helpers // get - /** * \private * \brief Get dim. @@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple& tuple) * \param layout Layout to create sub layout. * \return Requsted sub layout. */ -template -__host__ __device__ constexpr auto get(const Layout& layout) +template +__host__ __device__ constexpr auto get(const Layout& layout) { const auto& shape = layout.GetShape(); const auto new_shape = get(shape); @@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout) return layout.GetShape(); } +// pad +/** + * \brief Pad layout shapes to be adjusted to tile lengths. + * + * + * \param layout Layout to pad. + * \param tile_lengths Tile lengths to align layout shape. + * \return Padded layout. 
+ */ +template +__host__ __device__ constexpr auto pad(const Layout& layout, + const TileLengths& tile_lengths) +{ + auto& unrolled_desc = layout.GetUnrolledDescriptor(); + // Generate sequence with ones to mark that all dims will be padded + constexpr auto do_pads_seq = + generate_sequence_v2([](auto) { return Number<1>{}; }, Number{}); + // Create descriptor with padding + auto padded_desc = + tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq); + // Generate padded shape + const auto padded_shape = generate_tuple( + [&](auto i) { return padded_desc.GetLength(Number{}); }, Number{}); + // Create layout + return Layout(padded_shape, padded_desc); +} + +// unmerge +/** + * \brief Unmerge selected dim in layout. + * + * \tparam Idx Index to dimension being unmerged. + * \param layout Layout to pad. + * \param new_lengths Dimensions into which the indicated dimension will be divided. + * \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested. + * \return Unmerged layout. + */ +template +__host__ __device__ constexpr auto unmerge(const Layout& layout, + const NewLengths& new_lengths, + [[maybe_unused]] const NewIdxs& new_indexes) +{ + const auto& layout_shape = shape(layout); + auto& unrolled_desc = layout.GetUnrolledDescriptor(); + constexpr auto dims = Shape::Size(); + // Generate transforms + const auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == Idx) + { + return make_unmerge_transform(new_lengths); + } + else + { + return make_pass_through_transform(layout_shape.At(i)); + } + }, + Number{}); + + constexpr auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto upper_dims = generate_tuple( + [&](auto i) { + if constexpr(is_detected>::value) + { + constexpr auto idxs_tuple = tuple_element_t{}; + return to_sequence(idxs_tuple); + } + else + { + constexpr index_t index = tuple_element_t{}; + return Sequence{}; + } + }, + Number{}); + + const auto unmerged_desc = + transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims); + const auto unmerged_shape = + generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number{}); }, + Number{}); + + // Create layout + return Layout(unmerged_shape, unmerged_desc); +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 5638382dba..141e0a58e5 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -6,7 +6,6 @@ #include "tensor_utils.hpp" #include "layout_utils.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" @@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple{} to keep. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. * \return Multi index after projection. */ template @@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, } else { - return base_tuple.At(i_num); + return make_tuple(base_tuple.At(i_num)); } }, Number{}); @@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, * \brief Calculate shape with dims from projection. * * \param shape Base tensor shape. 
- * \param projection Projection to remove selected dim from partitioning. - * slice(X) to remove, where X is dim size, Number<1>{} to keep. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. * \return Shape with dims from projection */ template @@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple{}` to keep it. * \return Tuple with blocks number. */ template __host__ __device__ constexpr auto CalculateGridSize(const Tuple& shape, - const Tuple& tile_shape, - const Tuple& projection) + const Tuple& tile_shape) { - auto shape_with_projection = CalculateShapeWithProjection(shape, projection); return generate_tuple( - [&](auto i) { - return ck::math::integer_divide_ceil(size(shape_with_projection), - size(tile_shape)); - }, + [&](auto i) { return ck::math::integer_divide_ceil(size(shape), size(tile_shape)); }, Number::Size()>{}); } @@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, return thread_idxs * partition_lengths_seq + old_offset_idxs; } +/** + * \brief Select dims to partition (skip if slice). + * + * \param block_idxs Input block indexes. + * \return Partitioned dims. + */ +template +__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs) +{ + const auto dims_to_partition = generate_tuple( + [&](auto i) { + if constexpr(!is_detected>::value) + { + return Number{}; + } + else + { + return Tuple<>{}; + } + }, + Number{}); + // Remove empty tuples + return UnrollNestedTuple<0, 1>(dims_to_partition); +} + +/** + * \brief Replace slices with zeros (Slice dims are not partitioned). + * + * \param block_idxs Input block indexes. + * \return Parsed dims. + */ +template +__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs) +{ + return generate_tuple( + [&](auto i) { + if constexpr(!is_detected>::value) + { + return block_idxs.At(i); + } + else + { + return Number<0>{}; + } + }, + Number{}); +} + /** * \brief Calculate default projection. * @@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) return generate_tuple([&](auto) { return Number<1>{}; }, Number{}); } +/** + * \brief Calculate thread multi index from 1d thread index. + * + * \param thread_layout Layout of threads (could not be nested). + * \param thread_id Thread index represented as integer. + * \return Multi index. + */ +template +__host__ __device__ constexpr auto CalculateThreadMultiIdx( + [[maybe_unused]] const Layout& thread_layout, + const index_t thread_id) +{ + static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1, + "Thread layout should not be transformed."); + constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{}); + constexpr auto shape = ThreadShape{}; + constexpr auto strides = embed_transform.coefficients_; + + return generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + return (thread_id / strides.At(num_i)) % shape.At(num_i); + }, + Number{}); +} } // namespace detail } // namespace @@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) * is supported). * * \param tensor Tensor for partition. - * \param thread_lengths Layout of threads (could not be nested). + * \param thread_layout Layout of threads (could not be transformed). * \param thread_id Thread index represented as integer. 
* \param projection Projection is used to remove selected dim from * partitioning. Use `slice(X)` to remove dimension, where X is dim * size. Use `Number<1>{}` to keep it. * \return Partition tensor. */ -template +template __host__ __device__ constexpr auto make_local_partition(TensorType& tensor, - [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, + [[maybe_unused]] const Layout& thread_layout, const index_t thread_id, const ProjectionTuple& projection) { - static_assert(!IsNestedTuple(ThreadLengthsTuple{})); + static_assert(!IsNestedTuple(ThreadShape{})); // Calculate new partition shape const auto& tensor_shape = shape(tensor); // Calculate projected thread lengths constexpr auto projected_thread_lengths = - detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{}); + detail::ApplyProjection(ThreadShape{}, ProjectionTuple{}); constexpr auto partition_shape = detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths); - // Create Thread Cluster Descriptor constexpr auto partition_shape_seq = generate_sequence_v2([&](auto I) { return size(partition_shape); }, Number{}); - constexpr auto thread_lengths_seq = - generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, - Number{}); - constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); // Calculate thread idxs and offsets - const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); + const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id); // Apply projection on thread idxs to remove not needed idxs const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection); const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets()); // Create new layout and tensor auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); + // Slice descriptor + const auto transforms = generate_tuple( + [&](auto i) { + return make_slice_transform(partition_shape.At(i), + offset_multi_idxs.At(i), + partition_shape.At(i) + offset_multi_idxs.At(i)); + }, + Number::Size()>{}); + const auto lower_upper_dims = + generate_tuple([&](auto i) { return Sequence{}; }, + Number::Size()>{}); + auto sliced_desc = + transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims); + // Create layout const auto partition_layout = - Layout, decltype(unrolled_desc)>( - partition_shape, unrolled_desc); + Layout, decltype(sliced_desc)>( + partition_shape, sliced_desc); auto partition_tensor = make_tensor(tensor.GetPointer(), partition_layout); // Apply offsets - partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); return partition_tensor; } @@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor, * \param thread_id Thread index represented as integer. * \return Partition tensor. 
 */
-template <typename TensorType, typename ThreadLengthsTuple>
-__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
-                                                        const ThreadLengthsTuple& thread_lengths,
-                                                        const index_t thread_id)
+template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
+__host__ __device__ constexpr auto
+make_local_partition(TensorType& tensor,
+                     const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
+                     const index_t thread_id)
 {
-    const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
+    const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
     return make_local_partition(tensor, thread_lengths, thread_id, projection);
 }

@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
  *
  * \param tensor Tensor for partition.
  * \param tile_shape Shapes of requested tile.
- * \param block_id Block index represented as integer.
- * \param projection Projection to remove selected dim from partitioning.
- *        slice(X) to remove, where X is dim size, Number<1>{} to keep.
+ * \param block_idxs Tuple of block indexes represented as integers. Pass
+ *        `slice()` for a dimension to take the whole dimension.
+ * \param projection Projection is used to remove selected dims from
+ *        partitioning. Use `slice(X)` to remove a dimension, where X is the
+ *        dimension size. Use `Number<1>{}` to keep it.
 * \return Tile tensor.
 */
-template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
+template <typename TensorType,
+          typename BlockShapeTuple,
+          typename BlockIdxs,
+          typename ProjectionTuple>
 __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                    const BlockShapeTuple& tile_shape,
-                                                   const index_t block_id,
+                                                   const BlockIdxs& block_idxs,
                                                    const ProjectionTuple& projection)
 {
     static_assert(!IsNestedTuple(BlockShapeTuple{}));
-
-    constexpr bool is_default_projection =
-        is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
+    static_assert(!IsNestedTuple(BlockIdxs{}));

     constexpr auto I0 = Number<0>{};
     constexpr auto I1 = Number<1>{};

@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,

     auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();

-    // TODO: Enable block_2_tile_map partitioning for non-default projection.
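A caller-side sketch of the new interface may help here: block indexes are now passed as a tuple, with `slice()` marking dimensions that are not partitioned across blocks. This mirrors the usage in `test_wrapper_gemm.cpp` later in this patch; the tensor and tile-shape names are placeholders taken from that test.

```cpp
// Sketch: tile a (K0, M, N, K1) view; the grid is laid out over (M, N).
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),                 // K0: keep whole dim
                                       static_cast<ck::index_t>(blockIdx.x), // M block index
                                       static_cast<ck::index_t>(blockIdx.y), // N block index
                                       ck::wrapper::slice());                // K1: keep whole dim
// The projection removes K0 and K1 from the resulting tile's dims.
auto c_tile = ck::wrapper::make_local_tile(c_global_tensor,
                                           tile_shape_k0_m_n_k1,
                                           block_idxs,
                                           make_tuple(ck::wrapper::slice(K0PerBlock),
                                                      ck::Number<1>{},
                                                      ck::Number<1>{},
                                                      ck::wrapper::slice(K1)));
```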
-    if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
+    constexpr auto projected_tile_shape =
+        detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
+    // Number of dims which are partitioned
+    constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
+    const auto parsed_block_idxs     = detail::ReplaceSlicesWithZeros(block_idxs);
+    if constexpr(decltype(dims_to_partition)::Size() == I2)
     {
-        // Optimized version for 2d tile shape [MxK]
+        const auto shape_with_projection_dims =
+            detail::CalculateShapeWithProjection(shape(tensor), projection);
+        // Set values for the M, N partition
+        const auto M              = shape_with_projection_dims.At(dims_to_partition.At(I0));
+        const auto N              = shape_with_projection_dims.At(dims_to_partition.At(I1));
+        constexpr auto MPerBlock  = BlockShapeTuple{}.At(dims_to_partition.At(I0));
+        constexpr auto NPerBlock  = BlockShapeTuple{}.At(dims_to_partition.At(I1));
+        auto m_n_desc             = make_naive_tensor_descriptor_packed(make_tuple(M, N));
+        // Get 1D block id
+        const auto grid_size          = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
+        const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
+        const index_t block_id_1d     = block_lengths_desc.CalculateOffset(parsed_block_idxs);
+        // Optimized version for 2d tile shape [MxN]
         const auto block_2_tile_map =
-            BlockToCTileMap_M00_N0_M01Adapt<size<0>(BlockShapeTuple{}), size<1>(BlockShapeTuple{}), decltype(aligned_desc)>(aligned_desc);
+            BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, decltype(m_n_desc)>(m_n_desc);
         const auto block_work_idx =
-            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
+            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
         const index_t m_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
-        const index_t k_block_data_idx_on_grid =
-            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
-        const auto offset_multi_idxs =
-            make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
+            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+        const index_t n_block_data_idx_on_grid =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+        // Apply 0 for non-partitioned dims
+        const auto offset_multi_idxs = generate_tuple(
+            [&](auto i) {
+                if constexpr(i == dims_to_partition.At(I0))
+                {
+                    return m_block_data_idx_on_grid;
+                }
+                else if constexpr(i == dims_to_partition.At(I1))
+                {
+                    return n_block_data_idx_on_grid;
+                }
+                else
+                {
+                    return Number<0>{};
+                }
+            },
+            Number<BlockShapeTuple::Size()>{});
+        const auto projected_offset_multi_idxs =
+            detail::ApplyProjection(offset_multi_idxs, projection);
         // Create new layout and tensor
         const auto tile_layout =
-            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
-                                                                                     aligned_desc);
+            Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
+                projected_tile_shape, aligned_desc);
         auto tile_tensor =
             make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
         // Apply offsets
-        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
+        tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
         return tile_tensor;
     }
     else
     {
         // Calculate offsets
         // Sequence with data to process per block
-        constexpr auto projected_tile_shape =
-            detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
         using ProjectedTileShapeTuple = decltype(projected_tile_shape);
         constexpr auto projected_tile_shape_seq =
             generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
                                  Number<ProjectedTileShapeTuple::Size()>{});
         // Tuple with number of blocks
-        const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
-        const auto block_cluster_desc_ =
-            make_cluster_descriptor(block_lengths);
-        const auto block_idxs =
-            block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
-        const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
-        const auto offset_multi_idxs    = detail::CalculateOffsetMultiIdxs(
+        const auto projected_block_idxs =
+            to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
+        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
             projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
         // Create new layout and tensor
         const auto tile_layout =
@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
  *
  * \param tensor Tensor for partition.
  * \param tile_shape Shapes of requested tile.
- * \param block_id Block index represented as integer.
+ * \param block_idxs Tuple of block indexes represented as integers. Pass
+ *        `slice()` for a dimension to take the whole dimension.
  * \return Tile tensor.
 */
-template <typename TensorType, typename BlockShapeTuple>
-__host__ __device__ constexpr auto
-make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
+template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
+__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
+                                                   const BlockShapeTuple& tile_shape,
+                                                   const BlockIdxs& block_idxs)
 {
     const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
-    return make_local_tile(tensor, tile_shape, block_id, projection);
-}
-
-/**
- * \brief Pad tensor shapes to be adjusted to tile lengths.
- *
- *
- * \param tensor Tensor to pad.
- * \param tile_lengths Tile lengths to align tensor shape.
- * \return Padded tensor.
- */
-template <typename TensorType, typename TileLengths>
-__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
-{
-    const auto& tensor_shape = shape(tensor);
-    using TensorShapeType    = remove_reference_t<decltype(tensor_shape)>;
-    auto& unrolled_desc      = layout(tensor).GetUnrolledDescriptor();
-    // Generate sequence with ones to mark that all dims will be padded
-    constexpr auto do_pads_seq =
-        generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
-    // Create descriptor with padding
-    auto padded_desc =
-        tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
-    // Generate padded shape
-    const auto padded_shape = generate_tuple(
-        [&](auto i) {
-            const auto& dim         = size<i>(tensor_shape);
-            const auto& tile_length = size<i>(tile_lengths);
-            return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
-        },
-        Number<TensorShapeType::Size()>{});
-    // Create layout and tensor
-    const auto padded_layout =
-        Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
-    auto partition_tensor =
-        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
-    partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
-    return partition_tensor;
+    return make_local_tile(tensor, tile_shape, block_idxs, projection);
 }

 } // namespace wrapper
diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt
index cadc146795..383707828c 100644
--- a/test/wrapper/CMakeLists.txt
+++ b/test/wrapper/CMakeLists.txt
@@ -1,14 +1,21 @@
-add_gtest_executable(test_layout test_layout.cpp)
-target_link_libraries(test_layout PRIVATE utility)
-add_gtest_executable(test_tensor test_tensor.cpp)
-target_link_libraries(test_tensor PRIVATE utility)
-add_gtest_executable(test_copy test_copy.cpp)
-target_link_libraries(test_copy PRIVATE utility)
-add_gtest_executable(test_partition test_partition.cpp)
-target_link_libraries(test_partition PRIVATE utility)
+add_custom_target(test_wrapper)
+
+add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
+target_link_libraries(test_wrapper_layout PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_layout) +add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp) +target_link_libraries(test_wrapper_tensor PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_tensor) +add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp) +target_link_libraries(test_wrapper_copy PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_copy) +add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp) +target_link_libraries(test_wrapper_partition PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_partition) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_gtest_executable(test_gemm test_gemm.cpp) - target_link_libraries(test_gemm PRIVATE utility) + add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp) + target_link_libraries(test_wrapper_gemm PRIVATE utility) + add_dependencies(test_wrapper test_wrapper_gemm) endif() diff --git a/test/wrapper/test_gemm.cpp b/test/wrapper/test_gemm.cpp deleted file mode 100644 index 12245490d1..0000000000 --- a/test/wrapper/test_gemm.cpp +++ /dev/null @@ -1,257 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include -#include - -#include "ck/library/utility/host_tensor.hpp" - -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/wrapper/layout.hpp" -#include "ck/wrapper/tensor.hpp" -#include "ck/wrapper/operations/copy.hpp" -#include "ck/wrapper/operations/gemm.hpp" - -template -void CheckResult(const std::vector& a_data, - const std::vector& b_data, - std::vector& c_m_n_device_result, - const ck::index_t M, - const ck::index_t N, - const ck::index_t K) -{ - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - - Tensor a_m_k(HostTensorDescriptor({M, K})); - Tensor b_k_n(HostTensorDescriptor({K, N}, {1, K})); - Tensor c_m_n_host_result(HostTensorDescriptor({M, N})); - - a_m_k.mData = a_data; - b_k_n.mData = b_data; - - auto ref_op = ReferenceGemmInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - auto ref_argument = ref_op.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); - - ref_invoker.Run(ref_argument); - EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); -} - -template -__global__ void DeviceGemm(const void* p_a, - const void* p_b, - void* p_c, - const ck::index_t M, - const ck::index_t N, - const ck::index_t K, - const BlockShape tile_shape, - const ThreadLayoutShape thread_layout) -{ - constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); - constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); - constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); - - const auto a_global_layout = - ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); - const auto b_global_layout = - 
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); - const auto c_global_layout = - ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); - - constexpr auto a_tile_layout = ck::wrapper::make_layout( - ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); - constexpr auto b_tile_layout = ck::wrapper::make_layout( - ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); - constexpr auto c_tile_layout = ck::wrapper::make_layout( - ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); - - auto a_global_tensor = ck::wrapper::make_tensor( - static_cast(p_a), a_global_layout); - auto b_global_tensor = ck::wrapper::make_tensor( - static_cast(p_b), b_global_layout); - auto c_global_tensor = ck::wrapper::make_tensor( - static_cast(p_c), c_global_layout); - - auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout)); - auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout)); - auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout)); - - __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; - __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; - - auto a_lds_tensor = ck::wrapper::make_tensor( - static_cast(lds_a), a_tile_layout); - auto b_lds_tensor = ck::wrapper::make_tensor( - static_cast(lds_b), b_tile_layout); - - const ck::index_t block_idx = static_cast(blockIdx.x); - using DimAccessOrder = ck::Tuple, ck::Number<1>>; - constexpr ck::index_t vector_dim = 1; - - auto c_global_local_tile = ck::wrapper::make_local_tile( - c_padded_global_tensor, - tile_shape, - block_idx, - make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); - auto c_global_local_partition = - ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); - auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); - ck::wrapper::clear(c_vgpr_reg); - - const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); - ck::index_t i = 0; - do - { - const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); - auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice); - auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice); - auto a_global_local_tile = ck::wrapper::make_local_tile( - a_padded_global_tensor_k_slice, - tile_shape, - block_idx, - make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); - auto b_global_local_tile = ck::wrapper::make_local_tile( - b_padded_global_tensor_k_slice, - tile_shape, - block_idx, - make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); - - ck::wrapper::blockwise_copy( - a_global_local_tile, a_lds_tensor, thread_layout); - ck::wrapper::blockwise_copy( - b_global_local_tile, b_lds_tensor, thread_layout); - ck::block_sync_lds(); - ck::wrapper::blockwise_gemm_xdl( - a_lds_tensor, b_lds_tensor, c_vgpr_reg); - - ++i; - } while(i < num_loop); - - ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); -} - -template -void PerformGemm(const ck::index_t M, - const ck::index_t N, - const ck::index_t K, - const BlockShape& tile_shape, - const ThreadLayoutShape& thread_layout) -{ - // Global memory buffers - DeviceMem a_mem(M * K * sizeof(DataType)); - DeviceMem b_mem(K * N * sizeof(DataType)); - DeviceMem c_mem(M * N * sizeof(DataType)); - - std::vector a_data(M * K); - std::vector b_data(K * N); - 
ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_data); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_data); - - a_mem.ToDevice(a_data.data()); - b_mem.ToDevice(b_data.data()); - c_mem.SetZero(); - - const ck::index_t grid_size = - ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) * - ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); - - const auto kernel = - DeviceGemm; - launch_and_time_kernel(StreamConfig{nullptr}, - kernel, - dim3(grid_size), - dim3(ck::wrapper::size(thread_layout)), - 0, - a_mem.GetDeviceBuffer(), - b_mem.GetDeviceBuffer(), - c_mem.GetDeviceBuffer(), - M, - N, - K, - tile_shape, - thread_layout); - - std::vector c_data(M * N); - c_mem.FromDevice(c_data.data()); - - CheckResult(a_data, b_data, c_data, M, N, K); -} - -TEST(TestGemm, Float) -{ - using DataType = float; - const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout); - // Irregular case - PerformGemm( - 129, 129, 67, tile_shape, thread_layout); -} - -TEST(TestGemm, Int8) -{ - using DataType = int8_t; - const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout); - // Irregular case - PerformGemm( - 129, 129, 67, tile_shape, thread_layout); -} - -TEST(TestGemm, Half) -{ - using DataType = ck::half_t; - const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout); - // Irregular case - PerformGemm( - 129, 129, 67, tile_shape, thread_layout); -} - -TEST(TestGemm, Float_2x4_4x2_XdlPerWave) -{ - using DataType = float; - const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{}); - const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave); -} diff --git a/test/wrapper/test_copy.cpp b/test/wrapper/test_wrapper_copy.cpp similarity index 83% rename from test/wrapper/test_copy.cpp rename to test/wrapper/test_wrapper_copy.cpp index e7fa3c539b..4721006435 100644 --- a/test/wrapper/test_copy.cpp +++ b/test/wrapper/test_wrapper_copy.cpp @@ -20,23 +20,25 @@ template __global__ void TestCopyDevice(const InputTensor input_tensor, OutputTensor output_tensor, const BlockShape tile_shape, - const ThreadLayoutShape thread_layout) + const ThreadLayout thread_layout) { __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)]; const auto tensor_lds = ck::wrapper::make_tensor( p_shared, ck::wrapper::make_layout(tile_shape)); - const auto block_idx = static_cast(blockIdx.x); + const auto block_idxs = + ck::make_tuple(static_cast(blockIdx.x), static_cast(blockIdx.y)); // Get local tiles for global memory - const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); + const auto input_local_tile = + ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs); const auto output_local_tile = - 
ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); + ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs); // Get partition per thread const auto input_local_partition = @@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor, // Allocate VGPR auto tensor_vgpr = ck::wrapper::make_register_tensor( - layout(lds_local_partition)); + ck::wrapper::make_layout(shape(lds_local_partition))); // Perform copy if constexpr(UseOptimizedCopy) @@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS() auto output_tensor_global = ck::wrapper::make_tensor( static_cast(out_buf.GetDeviceBuffer()), layout); - const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}); - const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}); + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{})); + const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}); - const ck::index_t grid_size = ck::math::integer_divide_ceil( - ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape)); + const ck::index_t grid_size_x = ck::math::integer_divide_ceil( + ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = ck::math::integer_divide_ceil( + ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape)); const auto kernel = TestCopyDevice; launch_and_time_kernel(StreamConfig{}, kernel, - dim3(grid_size), + dim3(grid_size_x, grid_size_y, 1), dim3(ck::wrapper::size(thread_layout)), 0, input_tensor_global, diff --git a/test/wrapper/test_wrapper_gemm.cpp b/test/wrapper/test_wrapper_gemm.cpp new file mode 100644 index 0000000000..fd2cb7d4f3 --- /dev/null +++ b/test/wrapper/test_wrapper_gemm.cpp @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
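Note that the copy test above now launches one block per output tile in a two-dimensional grid instead of flattening everything into `blockIdx.x`. A condensed sketch of the host-side launch math, assembled from that test:

```cpp
// Ceil-divide handles tensors whose extents are not multiples of the tile.
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(
    ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(
    ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape));
// The kernel reads (blockIdx.x, blockIdx.y) back as a tuple of block indexes.
launch_and_time_kernel(StreamConfig{},
                       kernel,
                       dim3(grid_size_x, grid_size_y, 1),
                       dim3(ck::wrapper::size(thread_layout)),
                       0,
                       input_tensor_global,
                       output_tensor_global,
                       tile_shape,
                       thread_layout);
```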
+ +#include +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +template +void CheckResult(const std::vector& a_data, + const std::vector& b_data, + std::vector& c_m_n_device_result, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + Tensor a_m_k(HostTensorDescriptor({M, K})); + Tensor b_k_n(HostTensorDescriptor({K, N}, {1, K})); + Tensor c_m_n_host_result(HostTensorDescriptor({M, N})); + + a_m_k.mData = a_data; + b_k_n.mData = b_data; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); +} + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + + 
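The `unmerge` step above only changes the indexing view: K is split into (K0, K1) and the dimensions are reordered to (K0, M, K1), the shape the xdlops pipeline consumes, while the row-major storage is untouched. A small host-side sketch of the equivalent index arithmetic (an illustrative helper, not part of the kernel):

```cpp
// In the row-major (M, K) view, element (m, k) sits at offset m * K + k.
// In the unmerged (K0, M, K1) view the same element is (k0, m, k1) with
// k = k0 * K1 + k1, so both views address identical memory.
ck::index_t k0_m_k1_offset(ck::index_t k0, ck::index_t m, ck::index_t k1,
                           ck::index_t K, ck::index_t K1)
{
    return m * K + (k0 * K1 + k1);
}
```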
auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + + // Add extra M and N + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + 
a_lds_tensor_local_partition); + ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + + ck::block_sync_lds(); + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ck::block_sync_lds(); + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + DeviceMem a_mem(M * K * sizeof(DataType)); + DeviceMem b_mem(K * N * sizeof(DataType)); + DeviceMem c_mem(M * N * sizeof(DataType)); + + std::vector a_data(M * K); + std::vector b_data(K * N); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_data); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_data); + + a_mem.ToDevice(a_data.data()); + b_mem.ToDevice(b_data.data()); + c_mem.SetZero(); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; + + std::vector c_data(M * N); + c_mem.FromDevice(c_data.data()); + CheckResult(a_data, b_data, c_data, M, N, K); +} + +TEST(TestGemm, Float) +{ + using DataType = float; + // (dim1, dim2, dim0 thread layout) + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Int8) +{ + using DataType = int8_t; + const auto thread_layout = + 
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm(512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Half) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Float_2x4_4x2_XdlPerWave) +{ + using DataType = float; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); +} diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_wrapper_layout.cpp similarity index 99% rename from test/wrapper/test_layout.cpp rename to test/wrapper/test_wrapper_layout.cpp index a128a6d84f..0b07303299 100644 --- a/test/wrapper/test_layout.cpp +++ b/test/wrapper/test_wrapper_layout.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
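All GEMM tests above use a (4, 64, 1) thread layout with strides (1, 4, 1), so consecutive thread ids advance the first dimension fastest. A sketch of how `detail::CalculateThreadMultiIdx` decodes a flat id for such a layout (the helper name and sample values are illustrative; the formula is the one in `tensor_partition.hpp`):

```cpp
// idx[i] = (thread_id / stride[i]) % shape[i]
constexpr ck::index_t decode_dim(ck::index_t tid, ck::index_t stride, ck::index_t extent)
{
    return (tid / stride) % extent;
}
// Shape (4, 64, 1), strides (1, 4, 1), tid = 37 decodes to:
//   dim0: decode_dim(37, 1, 4)  == 1
//   dim1: decode_dim(37, 4, 64) == 9
//   dim2: decode_dim(37, 1, 1)  == 0
```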
#include #include diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_wrapper_partition.cpp similarity index 79% rename from test/wrapper/test_partition.cpp rename to test/wrapper/test_wrapper_partition.cpp index 8b6d220cd7..08d196c4ca 100644 --- a/test/wrapper/test_partition.cpp +++ b/test/wrapper/test_wrapper_partition.cpp @@ -29,8 +29,11 @@ TEST(TestPartition, LocalPartition) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{}); - const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}); + const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{}); + // row-major thread layout + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{})); // 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim) const auto thread_projection = ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{}); @@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile) ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{}); const auto block_projection = ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2)); - constexpr ck::index_t projection_block_dim = ck::Number<2>{}; - const auto num_blocks = + + const auto grid_shape = ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape), ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape), ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape)); - std::vector block_idxs(ck::wrapper::size(num_blocks)); - std::iota(block_idxs.begin(), block_idxs.end(), 0); + std::vector> block_idxs; + for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++) + { + for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++) + { + for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++) + { + block_idxs.emplace_back(i, j, k, 0); + } + } + } for(auto block_idx : block_idxs) { + constexpr ck::index_t projection_block_dim = ck::Number<2>{}; const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection); const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim; - auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) * + auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) * ck::wrapper::size<2>(block_shape) * ck::wrapper::size<2>(strides); - block_idx /= ck::wrapper::size<2>(num_blocks); - expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) * + expected_tile_first_val += ck::wrapper::size<1>(block_idx) * ck::wrapper::size<1>(block_shape) * ck::wrapper::size<1>(strides); - block_idx /= ck::wrapper::size<1>(num_blocks); - expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) * + expected_tile_first_val += ck::wrapper::size<0>(block_idx) * ck::wrapper::size<0>(block_shape) * ck::wrapper::size<0>(strides); diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_wrapper_tensor.cpp similarity index 100% rename from test/wrapper/test_tensor.cpp rename to test/wrapper/test_wrapper_tensor.cpp
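For reference, the expected first element of each tile in the `LocalTile` test above reduces to a dot product of the block origin with the tensor strides. A standalone restatement (the helper is illustrative; names mirror the test, and the trailing sliced index of `block_idxs` contributes nothing):

```cpp
// offset = sum over dims of block_idx[d] * block_extent[d] * stride[d]
template <typename BlockShape, typename Strides>
ck::index_t ExpectedTileFirstVal(ck::index_t i, ck::index_t j, ck::index_t k,
                                 const BlockShape& block_shape, const Strides& strides)
{
    return i * ck::wrapper::size<0>(block_shape) * ck::wrapper::size<0>(strides) +
           j * ck::wrapper::size<1>(block_shape) * ck::wrapper::size<1>(strides) +
           k * ck::wrapper::size<2>(block_shape) * ck::wrapper::size<2>(strides);
}
```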