mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Add optimized blockwise gemm using ck wrapper (#1157)
* Add optimized blockwise gemm using ck wrapper * Add basic gemm example * Update docs * Add tutorial for gemm using ck wrapper * Add perf note * edits * Fix cmake * Fixes --------- Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
This commit is contained in:
@@ -2,3 +2,11 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
|
||||
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_wrapper_img2col wrapper_img2col.cpp)
|
||||
target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
|
||||
GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
|
||||
GPU_TARGETS MATCHES "gfx942")
|
||||
add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
|
||||
target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
|
||||
add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)
|
||||
target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations)
|
||||
endif()
|
||||
|
||||
177
client_example/25_wrapper/README.md
Normal file
177
client_example/25_wrapper/README.md
Normal file
@@ -0,0 +1,177 @@
|
||||
# Composable Kernel wrapper GEMM tutorial
|
||||
|
||||
This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK)
|
||||
wrapper. We present the base version of GEMM without most of the available optimizations; however,
|
||||
it's worth noting that CK has kernels with different optimizations.
|
||||
|
||||
To implement these optimizations, you can use the CK wrapper or directly use available instances in
|
||||
CK. You can also refer to the
|
||||
[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp),
|
||||
that uses CK wrapper based on the
|
||||
[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation.
|
||||
|
||||
The kernel definition should look similar to:
|
||||
|
||||
```cpp
|
||||
template <typename DataType,
|
||||
typename GemmTraits,
|
||||
ck::index_t scalar_per_vector,
|
||||
typename BlockShape,
|
||||
typename ThreadLayout>
|
||||
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayout thread_layout)
|
||||
```
|
||||
|
||||
We pass pointers to global memory and matrix dimensions via arguments. Additionally, we pass
|
||||
selected lengths of processed data through each block (`tile_shape`) and thread layout
|
||||
(`thread_layout`). For compilation time parameters, we define the data type,
|
||||
[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp)
|
||||
and scalar per vector value during copy.
|
||||
|
||||
Step 1: Create layouts for global and LDS memory.
|
||||
|
||||
```cpp
|
||||
// Specify layouts for global memory.
|
||||
const auto a_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
|
||||
const auto b_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
|
||||
const auto c_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
|
||||
|
||||
// Specify layouts for tiles.
|
||||
constexpr auto a_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
|
||||
constexpr auto b_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
|
||||
constexpr auto c_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
|
||||
|
||||
// Apply padding for global memory.
|
||||
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
|
||||
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
|
||||
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
|
||||
```
|
||||
|
||||
We pad layouts for global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or
|
||||
`KPerBlock`.
|
||||
|
||||
Step 2: Create tensors for global and LDS memory.
|
||||
|
||||
```cpp
|
||||
// Make tensors for global memory.
|
||||
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<const DataType*>(p_a), a_global_layout_padded);
|
||||
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<const DataType*>(p_b), b_global_layout_padded);
|
||||
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<DataType*>(p_c), c_global_layout_padded);
|
||||
|
||||
// Allocate LDS memory.
|
||||
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
|
||||
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
|
||||
|
||||
// Make tensors for lds memory.
|
||||
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
static_cast<DataType*>(lds_a), a_tile_layout);
|
||||
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
static_cast<DataType*>(lds_b), b_tile_layout);
|
||||
```
|
||||
|
||||
We must specify parameters for copy and convert block indexes to tuple:
|
||||
|
||||
```cpp
|
||||
// Specify block index as tuple.
|
||||
const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
|
||||
static_cast<ck::index_t>(blockIdx.y),
|
||||
ck::wrapper::slice());
|
||||
// Specify access parameters for copy.
|
||||
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
|
||||
constexpr ck::index_t vector_dim = 1;
|
||||
```
|
||||
|
||||
We create a local tile (per block) and local partitions (per thread) for the global memory (`C`). We also
|
||||
define and clear an output register (`c_vgpr_reg`) for the accumulation.
|
||||
|
||||
```cpp
|
||||
auto c_global_local_tile = ck::wrapper::make_local_tile(
|
||||
c_global_tensor,
|
||||
tile_shape,
|
||||
block_idxs,
|
||||
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
|
||||
auto c_global_local_partition =
|
||||
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
|
||||
decltype(a_tile_layout),
|
||||
decltype(b_tile_layout),
|
||||
ck::wrapper::size(thread_layout),
|
||||
GemmTraits>(c_global_local_tile);
|
||||
// Create C vgpr to accumulate results.
|
||||
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
|
||||
decltype(a_tile_layout),
|
||||
decltype(b_tile_layout),
|
||||
ck::wrapper::size(thread_layout),
|
||||
GemmTraits>();
|
||||
// Clear C vgpr.
|
||||
ck::wrapper::clear(c_vgpr_reg);
|
||||
```
|
||||
|
||||
We use two specific functions for `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and
|
||||
`make_blockwise_gemm_xdl_c_vgpr`. This helps to choose the appropriate partition for the `C` output
|
||||
and define tensors with specific layouts for `blockwise_gemm`. In the following step, we use only
|
||||
generic functions for the CK wrapper.
|
||||
|
||||
Step 3: Create the compute loop.
|
||||
|
||||
```cpp
|
||||
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
|
||||
ck::index_t i = 0;
|
||||
do
|
||||
{
|
||||
// Get KPerBlock slice.
|
||||
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
|
||||
auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
|
||||
auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
|
||||
// Create local tiles for A and B.
|
||||
auto a_global_local_tile = ck::wrapper::make_local_tile(
|
||||
a_global_tensor_k_slice,
|
||||
tile_shape,
|
||||
block_idxs,
|
||||
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
|
||||
auto b_global_local_tile = ck::wrapper::make_local_tile(
|
||||
b_global_tensor_k_slice,
|
||||
tile_shape,
|
||||
block_idxs,
|
||||
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
|
||||
// Copy from global to LDS.
|
||||
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
a_global_local_tile, a_lds_tensor, thread_layout);
|
||||
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
b_global_local_tile, b_lds_tensor, thread_layout);
|
||||
// Synchronize lds.
|
||||
ck::block_sync_lds();
|
||||
// Execute blockwise GEMM.
|
||||
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
|
||||
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
|
||||
|
||||
++i;
|
||||
} while(i < num_loop);
|
||||
```
|
||||
|
||||
Loop iterate over `K / KPerBlock`. Each time a local tile is created for A and B tensors (tensor per block),
|
||||
data is copied from global memory to LDS. The `blockwise_gemm` function performs the GEMM
|
||||
operation on `a_lds_tensor` and `b_lds_tensor`, and stores results in `c_vgpr_reg`.
|
||||
|
||||
The end result from `c_vgpr_reg` is stored in the `C` local partition (tensor per thread):
|
||||
|
||||
```cpp
|
||||
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
|
||||
```
|
||||
|
||||
If you want to dive deep into the details, you can find the entire example
|
||||
[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp).
|
||||
216
client_example/25_wrapper/wrapper_basic_gemm.cpp
Normal file
216
client_example/25_wrapper/wrapper_basic_gemm.cpp
Normal file
@@ -0,0 +1,216 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <numeric>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
#include "ck/wrapper/operations/copy.hpp"
|
||||
#include "ck/wrapper/operations/gemm.hpp"
|
||||
#include "ck/wrapper/utils/kernel_utils.hpp"
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
template <typename DataType,
|
||||
typename GemmTraits,
|
||||
ck::index_t scalar_per_vector,
|
||||
typename BlockShape,
|
||||
typename ThreadLayout>
|
||||
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayout thread_layout)
|
||||
{
|
||||
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
|
||||
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
|
||||
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
|
||||
|
||||
// Specify layouts for global memory.
|
||||
const auto a_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
|
||||
const auto b_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
|
||||
const auto c_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
|
||||
// Specify layouts for tiles.
|
||||
constexpr auto a_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
|
||||
constexpr auto b_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
|
||||
constexpr auto c_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
|
||||
// Apply padding for global memory.
|
||||
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
|
||||
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
|
||||
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
|
||||
// Make tensors for global memory.
|
||||
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<const DataType*>(p_a), a_global_layout_padded);
|
||||
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<const DataType*>(p_b), b_global_layout_padded);
|
||||
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<DataType*>(p_c), c_global_layout_padded);
|
||||
// Allocate lds memory.
|
||||
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
|
||||
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
|
||||
// Make tensors for lds memory.
|
||||
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
static_cast<DataType*>(lds_a), a_tile_layout);
|
||||
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
static_cast<DataType*>(lds_b), b_tile_layout);
|
||||
// Specify block index as tuple.
|
||||
const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
|
||||
static_cast<ck::index_t>(blockIdx.y),
|
||||
ck::wrapper::slice());
|
||||
// Specify access parameters for copy.
|
||||
using DimAccessOrder = ck::Tuple<ck::Number<0>, ck::Number<1>>;
|
||||
constexpr ck::index_t vector_dim = 1;
|
||||
// Create tile and partition for C. Use specific function for blockwise_gemm to assign the
|
||||
// appropriate partitions.
|
||||
auto c_global_local_tile = ck::wrapper::make_local_tile(
|
||||
c_global_tensor,
|
||||
tile_shape,
|
||||
block_idxs,
|
||||
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
|
||||
auto c_global_local_partition =
|
||||
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
|
||||
decltype(a_tile_layout),
|
||||
decltype(b_tile_layout),
|
||||
ck::wrapper::size(thread_layout),
|
||||
GemmTraits>(c_global_local_tile);
|
||||
// Create C vgpr to accumulate results.
|
||||
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
|
||||
decltype(a_tile_layout),
|
||||
decltype(b_tile_layout),
|
||||
ck::wrapper::size(thread_layout),
|
||||
GemmTraits>();
|
||||
// Clear C vgpr.
|
||||
ck::wrapper::clear(c_vgpr_reg);
|
||||
|
||||
// Iterate over K with KPerBlock step.
|
||||
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
|
||||
ck::index_t i = 0;
|
||||
do
|
||||
{
|
||||
// Get KPerBlock slice.
|
||||
const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
|
||||
auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
|
||||
auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
|
||||
// Create local tiles for A and B.
|
||||
auto a_global_local_tile = ck::wrapper::make_local_tile(
|
||||
a_global_tensor_k_slice,
|
||||
tile_shape,
|
||||
block_idxs,
|
||||
make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
|
||||
auto b_global_local_tile = ck::wrapper::make_local_tile(
|
||||
b_global_tensor_k_slice,
|
||||
tile_shape,
|
||||
block_idxs,
|
||||
make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
|
||||
// Copy from global to lds.
|
||||
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
a_global_local_tile, a_lds_tensor, thread_layout);
|
||||
ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
b_global_local_tile, b_lds_tensor, thread_layout);
|
||||
// Synchronize lds.
|
||||
ck::block_sync_lds();
|
||||
// Execute blockwise gemm.
|
||||
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
|
||||
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
|
||||
|
||||
++i;
|
||||
} while(i < num_loop);
|
||||
// Copy vgpr results to C global memory.
|
||||
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
|
||||
}
|
||||
|
||||
template <typename DataType,
|
||||
typename GemmTraits,
|
||||
ck::index_t scalar_per_vector,
|
||||
typename BlockShape,
|
||||
typename ThreadLayout>
|
||||
void PerformGemm(const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const BlockShape& tile_shape,
|
||||
const ThreadLayout& thread_layout)
|
||||
{
|
||||
// Global memory buffers
|
||||
SimpleDeviceMem a_mem(M * K * sizeof(DataType));
|
||||
SimpleDeviceMem b_mem(K * N * sizeof(DataType));
|
||||
SimpleDeviceMem c_mem(M * N * sizeof(DataType));
|
||||
|
||||
const ck::index_t grid_size_x =
|
||||
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
|
||||
const ck::index_t grid_size_y =
|
||||
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
|
||||
|
||||
const auto kernel =
|
||||
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
dim3(ck::wrapper::size(thread_layout)),
|
||||
0,
|
||||
a_mem.GetDeviceBuffer(),
|
||||
b_mem.GetDeviceBuffer(),
|
||||
c_mem.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
tile_shape,
|
||||
thread_layout);
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using DataType = ck::half_t;
|
||||
const auto thread_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
|
||||
ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8>(
|
||||
3840, 4096, 4096, tile_shape, thread_layout);
|
||||
return 0;
|
||||
}
|
||||
// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s,
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
#include "ck/wrapper/operations/copy.hpp"
|
||||
#include "ck/wrapper/utils/kernel_utils.hpp"
|
||||
|
||||
static constexpr ck::index_t NumDimSpatial = 3;
|
||||
using DataType = float;
|
||||
@@ -36,21 +37,20 @@ struct SimpleDeviceMem
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
// Test copy from Global to Global through LDS and VGPR
|
||||
template <typename InputTensor,
|
||||
typename OutputTensor,
|
||||
typename BlockShape,
|
||||
typename ThreadLayoutShape>
|
||||
__global__ void DeviceImageToColumnPad0(InputTensor input_tensor,
|
||||
OutputTensor output_tensor,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayoutShape thread_layout)
|
||||
template <typename InputTensor, typename OutputTensor, typename BlockShape, typename ThreadLayout>
|
||||
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__
|
||||
DeviceImageToColumnPad0(InputTensor input_tensor,
|
||||
OutputTensor output_tensor,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayout thread_layout)
|
||||
{
|
||||
const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
|
||||
// grid layout (dim1, dim0)
|
||||
const auto block_idxs =
|
||||
ck::make_tuple(static_cast<ck::index_t>(blockIdx.y), static_cast<ck::index_t>(blockIdx.x));
|
||||
|
||||
// Get local tiles for global memory
|
||||
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
|
||||
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
|
||||
auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
|
||||
auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);
|
||||
|
||||
// Get partition per thread
|
||||
const auto input_local_partition =
|
||||
@@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));
|
||||
|
||||
// User can choose appropriate number of threads and sizes per block
|
||||
const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
|
||||
const auto thread_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}),
|
||||
ck::make_tuple(ck::Number<16>{}, ck::Number<1>{}));
|
||||
// This example doesn't support padding, user should select tile sizes
|
||||
// which divides the shape completely
|
||||
// which are divisible by the shape.
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});
|
||||
|
||||
// Create buffers for global memory
|
||||
@@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout);
|
||||
|
||||
const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
|
||||
ck::wrapper::size<0>(tile_shape)) *
|
||||
ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
|
||||
ck::wrapper::size<1>(tile_shape));
|
||||
// grid layout (dim1, dim0)
|
||||
const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
|
||||
ck::wrapper::size<1>(tile_shape));
|
||||
const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
|
||||
ck::wrapper::size<0>(tile_shape));
|
||||
|
||||
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
|
||||
decltype(output_tensor_global),
|
||||
@@ -134,7 +137,7 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
decltype(thread_layout)>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
dim3(ck::wrapper::size(thread_layout)),
|
||||
0,
|
||||
input_tensor_global,
|
||||
@@ -178,3 +181,4 @@ int main(int argc, char* argv[])
|
||||
{1, 1, 1} /*filter_dilations*/);
|
||||
return 0;
|
||||
}
|
||||
// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
|
||||
|
||||
308
client_example/25_wrapper/wrapper_optimized_gemm.cpp
Normal file
308
client_example/25_wrapper/wrapper_optimized_gemm.cpp
Normal file
@@ -0,0 +1,308 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <numeric>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/wrapper/layout.hpp"
|
||||
#include "ck/wrapper/tensor.hpp"
|
||||
#include "ck/wrapper/operations/copy.hpp"
|
||||
#include "ck/wrapper/operations/gemm.hpp"
|
||||
#include "ck/wrapper/utils/kernel_utils.hpp"
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
template <bool DoPad, typename Layout, typename PaddingDims>
|
||||
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
|
||||
{
|
||||
if constexpr(DoPad)
|
||||
{
|
||||
return ck::wrapper::pad(layout, padding_dims);
|
||||
}
|
||||
else
|
||||
{
|
||||
return layout;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType,
|
||||
typename GemmTraits,
|
||||
ck::index_t scalar_per_vector,
|
||||
typename BlockShape,
|
||||
typename ThreadLayout,
|
||||
bool DoPadding>
|
||||
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const BlockShape tile_shape,
|
||||
const ThreadLayout thread_layout)
|
||||
{
|
||||
constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
|
||||
constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
|
||||
constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);
|
||||
constexpr auto K1 = GemmTraits::K1;
|
||||
constexpr auto K0PerBlock = KPerBlock / K1;
|
||||
const auto K0 = ck::math::integer_divide_ceil(K, K1);
|
||||
|
||||
const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
|
||||
// Create layouts for global memory
|
||||
const auto a_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
|
||||
const auto b_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
|
||||
const auto c_global_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
|
||||
// Apply padding
|
||||
auto a_padded_global_layout =
|
||||
ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
|
||||
auto b_padded_global_layout =
|
||||
ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
|
||||
auto c_padded_global_layout =
|
||||
ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
|
||||
// Reshape from M,K to K0,M,K1
|
||||
const auto reshaped_dims_idxs =
|
||||
ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
|
||||
auto a_padded_unmerged_global_layout =
|
||||
ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
|
||||
auto b_padded_unmerged_global_layout =
|
||||
ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
|
||||
// Create tensors for global memory
|
||||
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
|
||||
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
|
||||
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
|
||||
static_cast<DataType*>(p_c), c_padded_global_layout);
|
||||
// Create layouts and tensors for lds memory.
|
||||
constexpr auto a_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(K0PerBlock, MPerBlock, K1),
|
||||
ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
|
||||
constexpr auto b_tile_layout = ck::wrapper::make_layout(
|
||||
ck::make_tuple(K0PerBlock, NPerBlock, K1),
|
||||
ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
|
||||
|
||||
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock];
|
||||
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock];
|
||||
|
||||
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
static_cast<DataType*>(lds_a), a_tile_layout);
|
||||
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
|
||||
static_cast<DataType*>(lds_b), b_tile_layout);
|
||||
|
||||
const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
|
||||
static_cast<ck::index_t>(blockIdx.x),
|
||||
static_cast<ck::index_t>(blockIdx.y),
|
||||
ck::wrapper::slice());
|
||||
using DimAccessOrder = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
|
||||
constexpr ck::index_t vector_dim = 2;
|
||||
|
||||
// Create tile and partition for C global memory. Use specific gemm
|
||||
// functions to get appropriate layouts.
|
||||
auto c_global_local_tile =
|
||||
ck::wrapper::make_local_tile(c_global_tensor,
|
||||
tile_shape_k0_m_n_k1,
|
||||
block_idxs,
|
||||
make_tuple(ck::wrapper::slice(K0PerBlock),
|
||||
ck::Number<1>{},
|
||||
ck::Number<1>{},
|
||||
ck::wrapper::slice(K1)));
|
||||
auto c_global_local_partition =
|
||||
ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
|
||||
decltype(a_tile_layout),
|
||||
decltype(b_tile_layout),
|
||||
ck::wrapper::size(thread_layout),
|
||||
GemmTraits>(c_global_local_tile);
|
||||
// Define and clear c vgpr register
|
||||
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
|
||||
decltype(a_tile_layout),
|
||||
decltype(b_tile_layout),
|
||||
ck::wrapper::size(thread_layout),
|
||||
GemmTraits>();
|
||||
ck::wrapper::clear(c_vgpr_reg);
|
||||
// Local partitions for lds memory
|
||||
auto a_lds_tensor_local_partition =
|
||||
ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
|
||||
auto b_lds_tensor_local_partition =
|
||||
ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
|
||||
// Lamda to slice tensor, then create local tile and partition
|
||||
auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
|
||||
const auto k_slice =
|
||||
ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
|
||||
ck::wrapper::slice(),
|
||||
ck::wrapper::slice());
|
||||
auto local_tile = ck::wrapper::make_local_tile(
|
||||
tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
|
||||
return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
|
||||
};
|
||||
|
||||
auto a_global_local_partition = make_global_partition(
|
||||
a_global_tensor,
|
||||
make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
|
||||
0);
|
||||
auto b_global_local_partition = make_global_partition(
|
||||
b_global_tensor,
|
||||
make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
|
||||
0);
|
||||
|
||||
// (row-major vgpr layout)
|
||||
auto a_vgpr_tensor =
|
||||
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
|
||||
ck::wrapper::make_layout(
|
||||
shape(a_global_local_partition),
|
||||
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
|
||||
ck::wrapper::size<2>(a_global_local_partition),
|
||||
ck::wrapper::size<2>(a_global_local_partition),
|
||||
ck::Number<1>{})));
|
||||
auto b_vgpr_tensor =
|
||||
ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
|
||||
ck::wrapper::make_layout(
|
||||
shape(b_global_local_partition),
|
||||
ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
|
||||
ck::wrapper::size<2>(a_global_local_partition),
|
||||
ck::wrapper::size<2>(a_global_local_partition),
|
||||
ck::Number<1>{})));
|
||||
// Copy first values to lds
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
|
||||
a_vgpr_tensor);
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
|
||||
b_vgpr_tensor);
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
|
||||
a_lds_tensor_local_partition);
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
|
||||
b_lds_tensor_local_partition);
|
||||
// Pipeline loop
|
||||
const ck::index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
|
||||
// Skip if only tile should be processed
|
||||
if(num_loop > 1)
|
||||
{
|
||||
ck::index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto a_global_local_partition_i = make_global_partition(
|
||||
a_global_tensor,
|
||||
make_tuple(
|
||||
ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
|
||||
i + 1);
|
||||
auto b_global_local_partition_i = make_global_partition(
|
||||
b_global_tensor,
|
||||
make_tuple(
|
||||
ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
|
||||
i + 1);
|
||||
// Copy data to A vgpr.
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
a_global_local_partition_i, a_vgpr_tensor);
|
||||
// Synchronize.
|
||||
ck::block_sync_lds();
|
||||
// Copy data to B vgpr.
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
b_global_local_partition_i, b_vgpr_tensor);
|
||||
// Perform gemm.
|
||||
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
|
||||
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
|
||||
// Synchronize
|
||||
ck::block_sync_lds();
|
||||
// Copy data to A and B lds tiles.
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
a_vgpr_tensor, a_lds_tensor_local_partition);
|
||||
ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
|
||||
b_vgpr_tensor, b_lds_tensor_local_partition);
|
||||
|
||||
++i;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// Handle tail.
|
||||
ck::block_sync_lds();
|
||||
ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
|
||||
a_lds_tensor, b_lds_tensor, c_vgpr_reg);
|
||||
// Store data from C vgpr to C global memory.
|
||||
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
|
||||
}
|
||||
|
||||
template <typename DataType,
|
||||
typename GemmTraits,
|
||||
ck::index_t scalar_per_vector,
|
||||
bool DoPadding,
|
||||
typename BlockShape,
|
||||
typename ThreadLayout>
|
||||
void PerformGemm(const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const BlockShape& tile_shape,
|
||||
const ThreadLayout& thread_layout)
|
||||
{
|
||||
// Global memory buffers
|
||||
SimpleDeviceMem a_mem(M * K * sizeof(DataType));
|
||||
SimpleDeviceMem b_mem(K * N * sizeof(DataType));
|
||||
SimpleDeviceMem c_mem(M * N * sizeof(DataType));
|
||||
|
||||
const ck::index_t grid_size_x =
|
||||
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
|
||||
const ck::index_t grid_size_y =
|
||||
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
|
||||
|
||||
const auto kernel =
|
||||
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
dim3(ck::wrapper::size(thread_layout)),
|
||||
0,
|
||||
a_mem.GetDeviceBuffer(),
|
||||
b_mem.GetDeviceBuffer(),
|
||||
c_mem.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
tile_shape,
|
||||
thread_layout);
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << std::endl;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using DataType = ck::half_t;
|
||||
const auto thread_layout =
|
||||
ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
|
||||
ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
|
||||
const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
|
||||
PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8, false>(
|
||||
3840, 4096, 4096, tile_shape, thread_layout);
|
||||
return 0;
|
||||
}
|
||||
// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s,
|
||||
Reference in New Issue
Block a user