Add optimized blockwise gemm using ck wrapper (#1157)
* Add optimized blockwise gemm using ck wrapper
* Add basic gemm example
* Update docs
* Add tutorial for gemm using ck wrapper
* Add perf note
* edits
* Fix cmake
* Fixes
---------
Co-authored-by: Lisa Delaney <lisa.delaney@amd.com>
[ROCm/composable_kernel commit: 1e73adbc28]
@@ -2,3 +2,11 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap
 target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
 add_executable(client_wrapper_img2col wrapper_img2col.cpp)
 target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations)
+if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
+   GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
+   GPU_TARGETS MATCHES "gfx942")
+    add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp)
+    target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations)
+    add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp)
+    target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations)
+endif()
177
client_example/25_wrapper/README.md
Normal file
@@ -0,0 +1,177 @@
# Composable Kernel wrapper GEMM tutorial

This tutorial demonstrates how to implement matrix multiplication using the Composable Kernel (CK)
wrapper. We present a base version of GEMM without most of the available optimizations; however,
it's worth noting that CK has kernels with different optimizations.

To implement these optimizations, you can use the CK wrapper or directly use the available instances in
CK. You can also refer to the
[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp),
which uses the CK wrapper based on the
[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation.

The kernel definition should look similar to:

```cpp
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
                                                        const void* p_b,
                                                        void* p_c,
                                                        const ck::index_t M,
                                                        const ck::index_t N,
                                                        const ck::index_t K,
                                                        const BlockShape tile_shape,
                                                        const ThreadLayout thread_layout)
```

We pass the pointers to global memory and the matrix dimensions as arguments. Additionally, we pass
the lengths of the data processed by each block (`tile_shape`) and the thread layout
(`thread_layout`). As compile-time parameters, we define the data type, the
[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp),
and the scalar-per-vector value used during copies.
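
For context, launching this kernel from the host can be as simple as the following sketch. This is an illustration only; the complete example file shown later uses CK's `launch_and_time_kernel` helper, and the grid here assumes one block per `MPerBlock` x `NPerBlock` output tile:

```cpp
// Sketch only: plain HIP launch with one block per output tile (the example
// itself uses launch_and_time_kernel from ck/host_utility/kernel_launch.hpp).
const ck::index_t grid_x = ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
const ck::index_t grid_y = ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
DeviceGemm<DataType, GemmTraits, scalar_per_vector, decltype(tile_shape), decltype(thread_layout)>
    <<<dim3(grid_x, grid_y, 1), dim3(ck::wrapper::size(thread_layout)), 0, nullptr>>>(
        p_a, p_b, p_c, M, N, K, tile_shape, thread_layout);
```
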
Step 1: Create layouts for global and LDS memory.

```cpp
// Specify layouts for global memory.
const auto a_global_layout =
    ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto b_global_layout =
    ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
const auto c_global_layout =
    ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));

// Specify layouts for tiles.
constexpr auto a_tile_layout = ck::wrapper::make_layout(
    ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto b_tile_layout = ck::wrapper::make_layout(
    ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
constexpr auto c_tile_layout = ck::wrapper::make_layout(
    ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));

// Apply padding for global memory.
auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
```

We pad the layouts for the global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or
`KPerBlock`.
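
Conceptually, padding rounds each shape dimension up to the next multiple of the corresponding tile length, so every block always sees a full tile. The arithmetic amounts to (illustration only, not the wrapper's internal implementation):

```cpp
// What pad() achieves for the shape, dimension by dimension (illustration only).
const ck::index_t M_padded = ck::math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
const ck::index_t N_padded = ck::math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
const ck::index_t K_padded = ck::math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
```
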
Step 2: Create tensors for global and LDS memory.

```cpp
// Make tensors for global memory.
auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
    static_cast<const DataType*>(p_a), a_global_layout_padded);
auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
    static_cast<const DataType*>(p_b), b_global_layout_padded);
auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
    static_cast<DataType*>(p_c), c_global_layout_padded);

// Allocate LDS memory.
__shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
__shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];

// Make tensors for LDS memory.
auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
    static_cast<DataType*>(lds_a), a_tile_layout);
auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
    static_cast<DataType*>(lds_b), b_tile_layout);
```
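
`ck::wrapper::size(...)` on a tile layout returns the number of tile elements, which makes the LDS footprint easy to estimate. For example, with the tile sizes used at the end of this tutorial (`MPerBlock` = 256, `KPerBlock` = 32, fp16 data), the A tile occupies 256 * 32 * 2 bytes = 16 KiB:

```cpp
// Quick LDS budget check for the A tile under the tutorial's tile sizes (fp16).
static_assert(256 * 32 * sizeof(ck::half_t) == 16 * 1024);
```
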
We must specify the access parameters for the copy operation and convert the block indexes to a tuple:

```cpp
// Specify block index as tuple.
const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
                                       static_cast<ck::index_t>(blockIdx.y),
                                       ck::wrapper::slice());
// Specify access parameters for copy.
using DimAccessOrder             = ck::Tuple<ck::Number<0>, ck::Number<1>>;
constexpr ck::index_t vector_dim = 1;
```
We create a local tile (per block) and local partitions (per thread) for the global memory (`C`). We also
define and clear an output register (`c_vgpr_reg`) for the accumulation.

```cpp
auto c_global_local_tile = ck::wrapper::make_local_tile(
    c_global_tensor,
    tile_shape,
    block_idxs,
    make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
auto c_global_local_partition =
    ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                           decltype(a_tile_layout),
                                                           decltype(b_tile_layout),
                                                           ck::wrapper::size(thread_layout),
                                                           GemmTraits>(c_global_local_tile);
// Create C vgpr to accumulate results.
auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                              decltype(a_tile_layout),
                                                              decltype(b_tile_layout),
                                                              ck::wrapper::size(thread_layout),
                                                              GemmTraits>();
// Clear C vgpr.
ck::wrapper::clear(c_vgpr_reg);
```
We use two functions that are specific to `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and
`make_blockwise_gemm_xdl_c_vgpr`. They select the appropriate partition for the `C` output
and define tensors with the specific layouts that `blockwise_gemm` expects. From this point on, we use only
generic CK wrapper functions.

Step 3: Create the compute loop.
```cpp
const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
ck::index_t i = 0;
do
{
    // Get KPerBlock slice.
    const auto k_slice           = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
    auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
    auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
    // Create local tiles for A and B.
    auto a_global_local_tile = ck::wrapper::make_local_tile(
        a_global_tensor_k_slice,
        tile_shape,
        block_idxs,
        make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
    auto b_global_local_tile = ck::wrapper::make_local_tile(
        b_global_tensor_k_slice,
        tile_shape,
        block_idxs,
        make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
    // Copy from global to LDS.
    ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
        a_global_local_tile, a_lds_tensor, thread_layout);
    ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
        b_global_local_tile, b_lds_tensor, thread_layout);
    // Synchronize LDS.
    ck::block_sync_lds();
    // Execute blockwise GEMM.
    ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
        a_lds_tensor, b_lds_tensor, c_vgpr_reg);

    ++i;
} while(i < num_loop);
```
The loop iterates `K / KPerBlock` times. In each iteration, a local tile is created for the A and B tensors (one tile per block), and
data is copied from global memory to LDS. The `blockwise_gemm` function then performs the GEMM
operation on `a_lds_tensor` and `b_lds_tensor`, and accumulates the results in `c_vgpr_reg`.

Finally, the result held in `c_vgpr_reg` is stored to the `C` local partition (one tensor per thread):

```cpp
ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
```
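
To sanity-check the kernel, you can compare the device output against a naive host-side reference. Here is a minimal sketch (`ReferenceGemm` is a hypothetical helper, not part of the example; note that B is stored as (N, K), i.e. row-major over N):

```cpp
// Hypothetical host reference (not part of the example): A is (M, K) row-major,
// B is (N, K) row-major, C is (M, N) row-major, accumulation in float.
void ReferenceGemm(const std::vector<ck::half_t>& a,
                   const std::vector<ck::half_t>& b,
                   std::vector<float>& c,
                   int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += static_cast<float>(a[m * K + k]) * static_cast<float>(b[n * K + k]);
            c[m * N + n] = acc;
        }
}
```
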
If you want to dive deeper into the details, you can find the entire example
[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp).
216
client_example/25_wrapper/wrapper_basic_gemm.cpp
Normal file
@@ -0,0 +1,216 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <numeric>
#include <cstdlib>
#include <iomanip> // std::setw
#include <iostream>
#include <initializer_list>
#include <vector>

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"

#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"
struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
                                                        const void* p_b,
                                                        void* p_c,
                                                        const ck::index_t M,
                                                        const ck::index_t N,
                                                        const ck::index_t K,
                                                        const BlockShape tile_shape,
                                                        const ThreadLayout thread_layout)
{
    constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
    constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
    constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);

    // Specify layouts for global memory.
    const auto a_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
    const auto b_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
    const auto c_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
    // Specify layouts for tiles.
    constexpr auto a_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
    constexpr auto b_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
    constexpr auto c_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));
    // Apply padding for global memory.
    auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout));
    auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout));
    auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout));
    // Make tensors for global memory.
    auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_a), a_global_layout_padded);
    auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_b), b_global_layout_padded);
    auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<DataType*>(p_c), c_global_layout_padded);
    // Allocate lds memory.
    __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
    __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];
    // Make tensors for lds memory.
    auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_a), a_tile_layout);
    auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_b), b_tile_layout);
    // Specify block index as tuple.
    const auto block_idxs = ck::make_tuple(static_cast<ck::index_t>(blockIdx.x),
                                           static_cast<ck::index_t>(blockIdx.y),
                                           ck::wrapper::slice());
    // Specify access parameters for copy.
    using DimAccessOrder             = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    constexpr ck::index_t vector_dim = 1;
    // Create tile and partition for C. Use specific function for blockwise_gemm to assign the
    // appropriate partitions.
    auto c_global_local_tile = ck::wrapper::make_local_tile(
        c_global_tensor,
        tile_shape,
        block_idxs,
        make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
    auto c_global_local_partition =
        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                               decltype(a_tile_layout),
                                                               decltype(b_tile_layout),
                                                               ck::wrapper::size(thread_layout),
                                                               GemmTraits>(c_global_local_tile);
    // Create C vgpr to accumulate results.
    auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                                  decltype(a_tile_layout),
                                                                  decltype(b_tile_layout),
                                                                  ck::wrapper::size(thread_layout),
                                                                  GemmTraits>();
    // Clear C vgpr.
    ck::wrapper::clear(c_vgpr_reg);

    // Iterate over K with KPerBlock step.
    const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
    ck::index_t i = 0;
    do
    {
        // Get KPerBlock slice.
        const auto k_slice           = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
        auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice);
        auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice);
        // Create local tiles for A and B.
        auto a_global_local_tile = ck::wrapper::make_local_tile(
            a_global_tensor_k_slice,
            tile_shape,
            block_idxs,
            make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
        auto b_global_local_tile = ck::wrapper::make_local_tile(
            b_global_tensor_k_slice,
            tile_shape,
            block_idxs,
            make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));
        // Copy from global to lds.
        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
            a_global_local_tile, a_lds_tensor, thread_layout);
        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
            b_global_local_tile, b_lds_tensor, thread_layout);
        // Synchronize lds.
        ck::block_sync_lds();
        // Execute blockwise gemm.
        ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
            a_lds_tensor, b_lds_tensor, c_vgpr_reg);

        ++i;
    } while(i < num_loop);
    // Copy vgpr results to C global memory.
    ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout>
void PerformGemm(const ck::index_t M,
                 const ck::index_t N,
                 const ck::index_t K,
                 const BlockShape& tile_shape,
                 const ThreadLayout& thread_layout)
{
    // Global memory buffers
    SimpleDeviceMem a_mem(M * K * sizeof(DataType));
    SimpleDeviceMem b_mem(K * N * sizeof(DataType));
    SimpleDeviceMem c_mem(M * N * sizeof(DataType));

    const ck::index_t grid_size_x =
        ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
    const ck::index_t grid_size_y =
        ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));

    const auto kernel =
        DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout>;
    const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                  kernel,
                                                  dim3(grid_size_x, grid_size_y, 1),
                                                  dim3(ck::wrapper::size(thread_layout)),
                                                  0,
                                                  a_mem.GetDeviceBuffer(),
                                                  b_mem.GetDeviceBuffer(),
                                                  c_mem.GetDeviceBuffer(),
                                                  M,
                                                  N,
                                                  K,
                                                  tile_shape,
                                                  thread_layout);

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
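
    // Units: avg_time is in milliseconds, so dividing flop by 1e9 and by the time in ms
    // yields flop / 1e12 per second (TFLOPS); likewise bytes / 1e6 / ms gives GB/s.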
    float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
    float gb_per_sec = num_btype / 1.E6 / avg_time;

    std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << std::endl;
}
int main(int argc, char* argv[])
{
    using DataType = ck::half_t;
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
                                 ck::make_tuple(ck::Number<4>{}, ck::Number<1>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8>(
        3840, 4096, 4096, tile_shape, thread_layout);
    return 0;
}
// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s,
@@ -15,6 +15,7 @@
 #include "ck/wrapper/layout.hpp"
 #include "ck/wrapper/tensor.hpp"
 #include "ck/wrapper/operations/copy.hpp"
+#include "ck/wrapper/utils/kernel_utils.hpp"

 static constexpr ck::index_t NumDimSpatial = 3;
 using DataType                             = float;
@@ -36,21 +37,20 @@ struct SimpleDeviceMem
     void* p_mem_;
 };

 // Test copy from Global to Global through LDS and VGPR
-template <typename InputTensor,
-          typename OutputTensor,
-          typename BlockShape,
-          typename ThreadLayoutShape>
-__global__ void DeviceImageToColumnPad0(InputTensor input_tensor,
-                                        OutputTensor output_tensor,
-                                        const BlockShape tile_shape,
-                                        const ThreadLayoutShape thread_layout)
+template <typename InputTensor, typename OutputTensor, typename BlockShape, typename ThreadLayout>
+__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__
+DeviceImageToColumnPad0(InputTensor input_tensor,
+                        OutputTensor output_tensor,
+                        const BlockShape tile_shape,
+                        const ThreadLayout thread_layout)
 {
-    const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
+    // grid layout (dim1, dim0)
+    const auto block_idxs =
+        ck::make_tuple(static_cast<ck::index_t>(blockIdx.y), static_cast<ck::index_t>(blockIdx.x));

     // Get local tiles for global memory
-    auto input_local_tile  = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
-    auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
+    auto input_local_tile  = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
+    auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);

     // Get partition per thread
     const auto input_local_partition =
@@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
     SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType));

     // User can choose appropriate number of threads and sizes per block
-    const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
+    const auto thread_layout =
+        ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}),
+                                 ck::make_tuple(ck::Number<16>{}, ck::Number<1>{}));
     // This example doesn't support padding, user should select tile sizes
-    // which divides the shape completely
+    // which are divisible by the shape.
     const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{});

     // Create buffers for global memory
@@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G,
     auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
         static_cast<DataType*>(out_buf.GetDeviceBuffer()), out_layout);

-    const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
-                                                                ck::wrapper::size<0>(tile_shape)) *
-                                  ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
-                                                                ck::wrapper::size<1>(tile_shape));
+    // grid layout (dim1, dim0)
+    const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout),
+                                                                  ck::wrapper::size<1>(tile_shape));
+    const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout),
+                                                                  ck::wrapper::size<0>(tile_shape));

     const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
                                                 decltype(output_tensor_global),
@@ -134,7 +137,7 @@ void PerformImageToColumnPad0(const ck::index_t G,
                                                 decltype(thread_layout)>;
     const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                   kernel,
-                                                  dim3(grid_size),
+                                                  dim3(grid_size_x, grid_size_y, 1),
                                                   dim3(ck::wrapper::size(thread_layout)),
                                                   0,
                                                   input_tensor_global,
@@ -178,3 +181,4 @@ int main(int argc, char* argv[])
                                 {1, 1, 1} /*filter_dilations*/);
     return 0;
 }
+// MI100 Perf: 0.255178 ms, 1698.9 GB/s,
308
client_example/25_wrapper/wrapper_optimized_gemm.cpp
Normal file
@@ -0,0 +1,308 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <numeric>
#include <cstdlib>
#include <iomanip> // std::setw
#include <iostream>
#include <initializer_list>
#include <vector>

#include "ck/library/utility/host_tensor.hpp"

#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
    if constexpr(DoPad)
    {
        return ck::wrapper::pad(layout, padding_dims);
    }
    else
    {
        return layout;
    }
}
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout,
          bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
                                                        const void* p_b,
                                                        void* p_c,
                                                        const ck::index_t M,
                                                        const ck::index_t N,
                                                        const ck::index_t K,
                                                        const BlockShape tile_shape,
                                                        const ThreadLayout thread_layout)
{
    constexpr auto MPerBlock  = ck::wrapper::size<0>(tile_shape);
    constexpr auto NPerBlock  = ck::wrapper::size<1>(tile_shape);
    constexpr auto KPerBlock  = ck::wrapper::size<2>(tile_shape);
    constexpr auto K1         = GemmTraits::K1;
    constexpr auto K0PerBlock = KPerBlock / K1;
    const auto K0             = ck::math::integer_divide_ceil(K, K1);

    const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);
    // Create layouts for global memory
    const auto a_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
    const auto b_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
    const auto c_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));
    // Apply padding
    auto a_padded_global_layout =
        ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
    auto b_padded_global_layout =
        ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
    auto c_padded_global_layout =
        ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));
    // Reshape from M,K to K0,M,K1
    const auto reshaped_dims_idxs =
        ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
    auto a_padded_unmerged_global_layout =
        ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
    auto b_padded_unmerged_global_layout =
        ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
    // Create tensors for global memory
    auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
    auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
    auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<DataType*>(p_c), c_padded_global_layout);
    // Create layouts and tensors for lds memory.
    constexpr auto a_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(K0PerBlock, MPerBlock, K1),
        ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
    constexpr auto b_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(K0PerBlock, NPerBlock, K1),
        ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
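
    // The K0-dimension stride is (MPerBlock + 1) * K1 rather than MPerBlock * K1, so
    // consecutive K0 "rows" are staggered in LDS by an extra K1 elements. Row padding
    // like this is a common technique for reducing LDS bank conflicts.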
    __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock];
    __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock];

    auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_a), a_tile_layout);
    auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_b), b_tile_layout);

    const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
                                           static_cast<ck::index_t>(blockIdx.x),
                                           static_cast<ck::index_t>(blockIdx.y),
                                           ck::wrapper::slice());
    using DimAccessOrder             = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
    constexpr ck::index_t vector_dim = 2;

    // Create tile and partition for C global memory. Use specific gemm
    // functions to get appropriate layouts.
    auto c_global_local_tile =
        ck::wrapper::make_local_tile(c_global_tensor,
                                     tile_shape_k0_m_n_k1,
                                     block_idxs,
                                     make_tuple(ck::wrapper::slice(K0PerBlock),
                                                ck::Number<1>{},
                                                ck::Number<1>{},
                                                ck::wrapper::slice(K1)));
    auto c_global_local_partition =
        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                               decltype(a_tile_layout),
                                                               decltype(b_tile_layout),
                                                               ck::wrapper::size(thread_layout),
                                                               GemmTraits>(c_global_local_tile);
    // Define and clear c vgpr register
    auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                                  decltype(a_tile_layout),
                                                                  decltype(b_tile_layout),
                                                                  ck::wrapper::size(thread_layout),
                                                                  GemmTraits>();
    ck::wrapper::clear(c_vgpr_reg);
    // Local partitions for lds memory
    auto a_lds_tensor_local_partition =
        ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
    auto b_lds_tensor_local_partition =
        ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);
    // Lambda to slice tensor, then create local tile and partition
    auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
        const auto k_slice =
            ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
                           ck::wrapper::slice(),
                           ck::wrapper::slice());
        auto local_tile = ck::wrapper::make_local_tile(
            tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
        return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
    };

    auto a_global_local_partition = make_global_partition(
        a_global_tensor,
        make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
        0);
    auto b_global_local_partition = make_global_partition(
        b_global_tensor,
        make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
        0);

    // (row-major vgpr layout)
    auto a_vgpr_tensor =
        ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
            ck::wrapper::make_layout(
                shape(a_global_local_partition),
                ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
                                   ck::wrapper::size<2>(a_global_local_partition),
                               ck::wrapper::size<2>(a_global_local_partition),
                               ck::Number<1>{})));
    auto b_vgpr_tensor =
        ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
            ck::wrapper::make_layout(
                shape(b_global_local_partition),
                ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
                                   ck::wrapper::size<2>(a_global_local_partition),
                               ck::wrapper::size<2>(a_global_local_partition),
                               ck::Number<1>{})));
    // Copy first values to lds
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
                                                                     a_vgpr_tensor);
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
                                                                     b_vgpr_tensor);
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
                                                                     a_lds_tensor_local_partition);
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
                                                                     b_lds_tensor_local_partition);
    // Pipeline loop
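    // The loop below forms a two-stage software pipeline: while blockwise_gemm_xdl
    // consumes the tile already resident in LDS, the next K0 tile is prefetched from
    // global memory into VGPRs and written to LDS only after the gemm and a barrier,
    // overlapping global-memory latency with the MFMA work.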
    const ck::index_t num_loop =
        __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
    // Skip if only one tile should be processed
    if(num_loop > 1)
    {
        ck::index_t i = 0;
        do
        {
            auto a_global_local_partition_i = make_global_partition(
                a_global_tensor,
                make_tuple(
                    ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
                i + 1);
            auto b_global_local_partition_i = make_global_partition(
                b_global_tensor,
                make_tuple(
                    ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
                i + 1);
            // Copy data to A vgpr.
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                a_global_local_partition_i, a_vgpr_tensor);
            // Synchronize.
            ck::block_sync_lds();
            // Copy data to B vgpr.
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                b_global_local_partition_i, b_vgpr_tensor);
            // Perform gemm.
            ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
                a_lds_tensor, b_lds_tensor, c_vgpr_reg);
            // Synchronize
            ck::block_sync_lds();
            // Copy data to A and B lds tiles.
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                a_vgpr_tensor, a_lds_tensor_local_partition);
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                b_vgpr_tensor, b_lds_tensor_local_partition);

            ++i;
        } while(i < (num_loop - 1));
    }
    // Handle tail.
    ck::block_sync_lds();
    ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
        a_lds_tensor, b_lds_tensor, c_vgpr_reg);
    // Store data from C vgpr to C global memory.
    ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          bool DoPadding,
          typename BlockShape,
          typename ThreadLayout>
void PerformGemm(const ck::index_t M,
                 const ck::index_t N,
                 const ck::index_t K,
                 const BlockShape& tile_shape,
                 const ThreadLayout& thread_layout)
{
    // Global memory buffers
    SimpleDeviceMem a_mem(M * K * sizeof(DataType));
    SimpleDeviceMem b_mem(K * N * sizeof(DataType));
    SimpleDeviceMem c_mem(M * N * sizeof(DataType));

    const ck::index_t grid_size_x =
        ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
    const ck::index_t grid_size_y =
        ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));

    const auto kernel =
        DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
    const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                  kernel,
                                                  dim3(grid_size_x, grid_size_y, 1),
                                                  dim3(ck::wrapper::size(thread_layout)),
                                                  0,
                                                  a_mem.GetDeviceBuffer(),
                                                  b_mem.GetDeviceBuffer(),
                                                  c_mem.GetDeviceBuffer(),
                                                  M,
                                                  N,
                                                  K,
                                                  tile_shape,
                                                  thread_layout);
    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;

    float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
    float gb_per_sec = num_btype / 1.E6 / avg_time;

    std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << std::endl;
}

int main(int argc, char* argv[])
{
    using DataType = ck::half_t;
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
                                 ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, 8, false>(
        3840, 4096, 4096, tile_shape, thread_layout);
    return 0;
}
// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s,
@@ -12,10 +12,6 @@ Wrapper
 Description
 -------------------------------------

-.. note::
-
-    The wrapper is under development and its functionality is limited.
-

 The CK library provides a lightweight wrapper for more complex operations implemented in
 the library.
@@ -54,9 +50,15 @@ Output::
     2 6 10 14 18 22 26 30


+Tutorials:
+
+* `GEMM tutorial <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/README.md>`_
+
 Advanced examples:

 * `Image to column <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_img2col.cpp>`_
+* `Basic gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp>`_
+* `Optimized gemm <https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp>`_

 -------------------------------------
 Layout
@@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
                          decltype(dim_access_order),
                          VectorDim,
                          ScalarPerVector,
-                         Sequence<false>,
-                         Sequence<false>>{in_grid_desc,
-                                          make_tuple(src_tensor.GetMultiIdxOffsets()),
-                                          out_grid_desc,
-                                          make_tuple(dst_tensor.GetMultiIdxOffsets()),
-                                          tensor_operation::element_wise::PassThrough{}};
+                         Sequence<true>,
+                         Sequence<true>>{in_grid_desc,
+                                         make_tuple(src_tensor.GetMultiIdxOffsets()),
+                                         out_grid_desc,
+                                         make_tuple(dst_tensor.GetMultiIdxOffsets()),
+                                         tensor_operation::element_wise::PassThrough{}};

         transfer.Run(tie(in_grid_desc),
                      tie(src_tensor.GetBuffer()),
@@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor)
     else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer)
     {
         // Perform copy from DynamicBuffer to StaticBuffer
-        const auto src_dst_slice_origin =
+        const auto dst_slice_origin_idxs =
             generate_tuple([&](auto) { return I0; }, Number<num_dims>{});
-        constexpr auto src_vector_tensor_lengths = generate_sequence_v2(
-            [&](auto I) {
-                if constexpr(I == VectorDim)
-                {
-                    return Number<ScalarPerVector>{};
-                }
-                else
-                {
-                    return I1;
-                }
-            },
-            Number<num_dims>{});

-        auto transfer =
-            ThreadwiseTensorSliceTransfer_v4r1<typename SrcTensorType::TensorElementType,
-                                               typename DstTensorType::TensorElementType,
-                                               remove_cvref_t<decltype(in_grid_desc)>,
-                                               remove_cvref_t<decltype(out_grid_desc)>,
-                                               decltype(thread_slice_lengths),
-                                               decltype(dim_access_order),
-                                               decltype(src_vector_tensor_lengths),
-                                               decltype(dim_access_order)>{
-                src_tensor.GetMultiIdxOffsets()};
+        auto transfer = ThreadwiseTensorSliceTransfer_v2<
+            std::remove_const_t<typename SrcTensorType::TensorElementType>,
+            std::remove_const_t<typename DstTensorType::TensorElementType>,
+            remove_cvref_t<decltype(in_grid_desc)>,
+            remove_cvref_t<decltype(out_grid_desc)>,
+            decltype(thread_slice_lengths),
+            decltype(dim_access_order),
+            VectorDim,
+            ScalarPerVector,
+            I1,
+            false,
+            false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()};

         transfer.Run(in_grid_desc,
-                     src_dst_slice_origin,
                      src_tensor.GetBuffer(),
                      out_grid_desc,
-                     src_dst_slice_origin,
+                     dst_slice_origin_idxs,
                      dst_tensor.GetBuffer());
     }
     else
@@ -183,10 +171,12 @@ template <typename DimAccessOrderTuple,
           index_t ScalarPerVector,
           typename SrcTensorType,
           typename DstTensorType,
-          typename ThreadLayoutTuple>
-__device__ void blockwise_copy(const SrcTensorType& src_tensor,
-                               DstTensorType& dst_tensor,
-                               [[maybe_unused]] ThreadLayoutTuple& thread_layout)
+          typename ThreadShape,
+          typename ThreadUnrolledDesc>
+__device__ void
+blockwise_copy(const SrcTensorType& src_tensor,
+               DstTensorType& dst_tensor,
+               [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout)
 {
     static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer);
     static_assert(is_detected<is_tuple, DimAccessOrderTuple>::value);
@@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor,

     constexpr auto tile_lengths_seq =
         generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number<num_dims>{});
-    constexpr auto thread_layout_seq = generate_sequence_v2(
-        [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number<num_dims>{});
+    constexpr auto thread_layout_seq =
+        generate_sequence_v2([](auto I) { return size<I>(ThreadShape{}); }, Number<num_dims>{});
     constexpr auto dim_access_order = generate_sequence_v2(
         [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number<num_dims>{});

-    using ThisThreadBlock = ThisThreadBlock<size(ThreadLayoutTuple{})>;
+    using ThisThreadBlock = ThisThreadBlock<size(ThreadShape{})>;

     // Perform copy between DynamicBuffers
     auto transfer = ThreadGroupTensorSliceTransfer_v7<
@@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor()

 /**
  * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be
- * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B
- * data layout must be (NPerBlock, KPerBlock).
+ * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or
+ * (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock)
+ * or (K0PerBlock, NPerBlock, K1).
  *
  * \note C output Vgpr register layout (8D):
  * - MXdlPerWave - The number of MFMA instructions run by single wave in M
@@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor()
  * \tparam BlockSize Tensor to pad.
  * \tparam GemmTraits Traits of gemm xdl operation.
  * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm
- * (MPerBlock, KPerBlock) layout.
+ * (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout.
  * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm
- * (NPerBlock, KPerBlock) layout.
+ * (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout.
  * \param c_reg_tensor C tensor VGPR memory for blockwise gemm.
  */
 template <typename DataType,
@@ -86,6 +87,8 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
                                    const BTensorType& b_local_tile_tensor,
                                    CTensorType& c_reg_tensor)
 {
+    constexpr auto I3 = Number<3>{};
+
     static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
     static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds);
     static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr);
@@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor,
     using ATileLayout = remove_cvref_t<decltype(layout(a_local_tile_tensor))>;
     using BTileLayout = remove_cvref_t<decltype(layout(b_local_tile_tensor))>;

+    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
+                  typename BTileLayout::LayoutShape{}.Size());
+    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
+
     using ABlockDesc_K0_M_K1_Type =
-        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
+        conditional_t<is_3d_desc,
+                      typename ATileLayout::LayoutUnrolledDescriptorType,
+                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
     using BBlockDesc_K0_N_K1_Type =
-        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
+        conditional_t<is_3d_desc,
+                      typename BTileLayout::LayoutUnrolledDescriptorType,
+                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

     BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                         DataType,
@@ -168,14 +179,22 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)
     constexpr auto I6 = Number<6>{};
     constexpr auto I7 = Number<7>{};

+    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
+                  typename BTileLayout::LayoutShape{}.Size());
+
     constexpr bool is_integer =
         is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
     using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

+    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
     using ABlockDesc_K0_M_K1_Type =
-        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
+        conditional_t<is_3d_desc,
+                      typename ATileLayout::LayoutUnrolledDescriptorType,
+                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
     using BBlockDesc_K0_N_K1_Type =
-        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
+        conditional_t<is_3d_desc,
+                      typename BTileLayout::LayoutUnrolledDescriptorType,
+                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

     using BlockwiseGemmXdlops =
         BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -233,19 +252,45 @@ make_blockwise_gemm_xdl_c_local_partition(CTensorType& c_local_tile_tensor)

     const auto partition_desc = BlockwiseGemmXdlops::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
         layout(c_local_tile_tensor).GetUnrolledDescriptor());

+    const auto lower_upper_dims =
+        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<8>{});
+
+    auto sliced_desc = transform_tensor_descriptor(
+        partition_desc,
+        make_tuple(
+            make_slice_transform(partition_shape.At(Number<0>{}),
+                                 m_thread_data_on_grid_idx[I0],
+                                 partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]),
+            make_slice_transform(partition_shape.At(Number<1>{}),
+                                 n_thread_data_on_grid_idx[I0],
+                                 partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]),
+            make_slice_transform(partition_shape.At(Number<2>{}),
+                                 m_thread_data_on_grid_idx[I1],
+                                 partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]),
+            make_slice_transform(partition_shape.At(Number<3>{}),
+                                 n_thread_data_on_grid_idx[I1],
+                                 partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]),
+            make_slice_transform(partition_shape.At(Number<4>{}),
+                                 m_thread_data_on_grid_idx[I2],
+                                 partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]),
+            make_slice_transform(partition_shape.At(Number<5>{}),
+                                 m_thread_data_on_grid_idx[I3],
+                                 partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]),
+            make_slice_transform(partition_shape.At(Number<6>{}),
+                                 m_thread_data_on_grid_idx[I4],
+                                 partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]),
+            make_slice_transform(partition_shape.At(Number<7>{}),
+                                 n_thread_data_on_grid_idx[I2],
+                                 partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])),
+        lower_upper_dims,
+        lower_upper_dims);
+
     const auto partition_layout =
-        Layout<remove_reference_t<decltype(partition_shape)>, decltype(partition_desc)>(
-            partition_shape, partition_desc);
+        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
+            partition_shape, sliced_desc);
     auto partition_tensor = make_tensor<CTensorType::TensorBufferAddressSpace>(
         c_local_tile_tensor.GetPointer(), partition_layout);
+    partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0],
+                                                        n_thread_data_on_grid_idx[I0],
+                                                        m_thread_data_on_grid_idx[I1],
+                                                        n_thread_data_on_grid_idx[I1],
+                                                        m_thread_data_on_grid_idx[I2],
+                                                        m_thread_data_on_grid_idx[I3],
+                                                        m_thread_data_on_grid_idx[I4],
+                                                        n_thread_data_on_grid_idx[I2]));
     return partition_tensor;
 }

@@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
     constexpr auto I6 = Number<6>{};
     constexpr auto I7 = Number<7>{};

+    static_assert(typename ATileLayout::LayoutShape{}.Size() ==
+                  typename BTileLayout::LayoutShape{}.Size());
+
     constexpr bool is_integer =
         is_same_v<DataType, int8_t> || is_same_v<DataType, int16_t> || is_same_v<DataType, int32_t>;
     using GemmAccDataType = std::conditional_t<is_integer, int32_t, float>;

+    constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3;
     using ABlockDesc_K0_M_K1_Type =
-        decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>());
+        conditional_t<is_3d_desc,
+                      typename ATileLayout::LayoutUnrolledDescriptorType,
+                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, ATileLayout>())>;
     using BBlockDesc_K0_N_K1_Type =
-        decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>());
+        conditional_t<is_3d_desc,
+                      typename BTileLayout::LayoutUnrolledDescriptorType,
+                      decltype(detail::GetBlockDescriptor<GemmTraits::K1, BTileLayout>())>;

     using BlockwiseGemmXdlops =
         BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
@@ -326,9 +379,8 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr()
     const auto vgpr_layout = Layout<remove_reference_t<decltype(vgpr_shape)>, decltype(vgpr_desc)>(
         vgpr_shape, vgpr_desc);
     // Get vector type for Vgpr
-    using BlockwiseGemmCThreadBufferType =
-        remove_reference_t<decltype(BlockwiseGemmXdlops{}.GetCThreadBuffer())>;
-    using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V;
+    constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops();
+    using VgprVectorType = typename vector_type<GemmAccDataType, ScalarPerVector>::type;
     return ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, VgprVectorType>(
         vgpr_layout);
 }
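
For reference, the two LDS tile shapes that `blockwise_gemm_xdl` now accepts (per the updated doc comment above) are exactly what the two GEMM examples in this commit construct. A sketch with the examples' sizes (fp16, MPerBlock = 256, KPerBlock = 32, K1 = 8, so K0PerBlock = 4):

```cpp
// 2D form, as built in wrapper_basic_gemm.cpp: (MPerBlock, KPerBlock), row-major.
constexpr auto a_2d = ck::wrapper::make_layout(
    ck::make_tuple(ck::Number<256>{}, ck::Number<32>{}),
    ck::make_tuple(ck::Number<32>{}, ck::Number<1>{}));
// 3D form, as built in wrapper_optimized_gemm.cpp: (K0PerBlock, MPerBlock, K1),
// with the K0 stride padded by one K1 row to limit LDS bank conflicts.
constexpr auto a_3d = ck::wrapper::make_layout(
    ck::make_tuple(ck::Number<4>{}, ck::Number<256>{}, ck::Number<8>{}),
    ck::make_tuple(ck::Number<(256 + 1) * 8>{}, ck::Number<8>{}, ck::Number<1>{}));
```
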
@@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple<Transforms...>&
     }
 }

-template <typename... Ts, typename Shape, typename FlattenDescriptor>
+template <typename... Ts, typename Shape, typename UnrolledDescriptor>
 __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple<Ts...>& idx,
                                                             const Shape& shape,
-                                                            const FlattenDescriptor& flatten_desc)
+                                                            const UnrolledDescriptor& flatten_desc)
 {
     constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size();
@@ -20,48 +20,57 @@ namespace wrapper {
  * \tparam K1Value The number of K-dim elements that are packed together as
  * a separate logical dimension. Usually aligns with vector load size.
  */
-template <index_t MPerXDLValue,
-          index_t NPerXDLValue,
-          index_t MXdlPerWaveValue,
-          index_t NXdlPerWaveValue,
-          index_t K1Value>
+template <typename MPerXDLValue,
+          typename NPerXDLValue,
+          typename MXdlPerWaveValue,
+          typename NXdlPerWaveValue,
+          typename K1Value>
 struct BlockwisGemmXdlTraits
 {
-    static constexpr index_t MPerXDL     = MPerXDLValue;
-    static constexpr index_t NPerXDL     = NPerXDLValue;
-    static constexpr index_t MXdlPerWave = MXdlPerWaveValue;
-    static constexpr index_t NXdlPerWave = NXdlPerWaveValue;
-    static constexpr index_t K1          = K1Value;
+    static constexpr auto MPerXDL     = MPerXDLValue{};
+    static constexpr auto NPerXDL     = NPerXDLValue{};
+    static constexpr auto MXdlPerWave = MXdlPerWaveValue{};
+    static constexpr auto NXdlPerWave = NXdlPerWaveValue{};
+    static constexpr auto K1          = K1Value{};
 };

 // K1 = 4
-struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4>
+struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<4>>
 {
 };
-struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4>
+struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<4>>
 {
 };
-struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4>
+struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<4>>
 {
 };
 // K1 = 8
-struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8>
+struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<8>>
 {
 };
-struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8>
+struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<8>>
 {
 };
-struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8>
+struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<8>>
 {
 };
 // K1 = 16
-struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16>
+struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<4>, Number<2>, Number<16>>
 {
 };
-struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16>
+struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<4>, Number<16>>
 {
 };
-struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16>
+struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1
+    : BlockwisGemmXdlTraits<Number<32>, Number<32>, Number<2>, Number<2>, Number<16>>
 {
 };
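
These trait structs are what the GEMM examples above pass as `GemmTraits`. With the change from `index_t` to `Number<...>` parameters, the members become integral-constant objects rather than plain integers, but they still compare and convert like integers; a usage sketch:

```cpp
// Usage sketch: the trait type is passed as the GemmTraits template parameter;
// K1 (now a ck::Number, i.e. an integral constant) drives the K0/K1 split.
using Traits = ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1;
static_assert(Traits::K1 == 8);
static_assert(Traits::MPerXDL == 32 && Traits::NPerXDL == 32);
```
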
14
include/ck/wrapper/utils/kernel_utils.hpp
Normal file
@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"

namespace ck {
namespace wrapper {

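// __launch_bounds__ lets the compiler assume a maximum block size (and a minimum
// number of blocks per CU) for kernels annotated with this macro, so it can
// budget registers accordingly.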
#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

} // namespace wrapper
} // namespace ck
@@ -15,6 +15,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

namespace ck {
namespace wrapper {
@@ -29,6 +30,7 @@ template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

namespace {
namespace detail {
/**
 * \brief Generate packed (column-major) strides if not passed
 *
@@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
    return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides);
}
}
} // namespace detail
} // namespace

/// @endcond
@@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha
template <typename Shape, typename Strides>
__host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides)
{
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, strides));
    using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
    return Layout<Shape, UnrolledDescriptorType>(shape,
                                                 detail::MakeUnrolledDescriptor(shape, strides));
}

/**
@@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides
template <typename Shape>
__host__ __device__ constexpr auto make_layout(const Shape& shape)
{
    using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape, MakeUnrolledDescriptor(shape, Tuple<>{}));
    using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
    return Layout<Shape, UnrolledDescriptorType>(shape,
                                                 detail::MakeUnrolledDescriptor(shape, Tuple<>{}));
}
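
The two overloads differ only in stride handling. A hedged usage sketch (values illustrative):

```cpp
// Row-major M x K layout with explicit strides (K, 1); when the strides are
// omitted, the single-argument overload generates packed (column-major)
// strides automatically.
const ck::index_t M = 128;
const ck::index_t K = 64;
const auto row_major_layout =
    ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
const auto packed_layout = ck::wrapper::make_layout(ck::make_tuple(M, K));
```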

// Layout helpers
// get

/**
 * \private
 * \brief Get dim.
@@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple<Dims...>& tuple)
 * \param layout Layout to create sub layout.
 * \return Requested sub layout.
 */
template <index_t idx, typename Shape, typename FlattenDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, FlattenDesc>& layout)
template <index_t idx, typename Shape, typename UnrolledDesc>
__host__ __device__ constexpr auto get(const Layout<Shape, UnrolledDesc>& layout)
{
    const auto& shape = layout.GetShape();
    const auto new_shape = get<idx>(shape);
@@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout)
    return layout.GetShape();
}

// pad
/**
 * \brief Pad layout shapes to be adjusted to tile lengths.
 *
 * \param layout Layout to pad.
 * \param tile_lengths Tile lengths to align layout shape.
 * \return Padded layout.
 */
template <typename Shape, typename UnrolledDesc, typename TileLengths>
__host__ __device__ constexpr auto pad(const Layout<Shape, UnrolledDesc>& layout,
                                       const TileLengths& tile_lengths)
{
    auto& unrolled_desc = layout.GetUnrolledDescriptor();
    // Generate sequence with ones to mark that all dims will be padded
    constexpr auto do_pads_seq =
        generate_sequence_v2([](auto) { return Number<1>{}; }, Number<Shape::Size()>{});
    // Create descriptor with padding
    auto padded_desc =
        tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
    // Generate padded shape
    const auto padded_shape = generate_tuple(
        [&](auto i) { return padded_desc.GetLength(Number<i>{}); }, Number<TileLengths::Size()>{});
    // Create layout
    return Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
}
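
A hedged sketch of the padding arithmetic (values illustrative): each dim is rounded up to the next multiple of the corresponding tile length.

```cpp
// Padding a 129 x 67 layout to 128 x 64 tiles yields a 256 x 128 shape,
// so every tile in the block grid is fully covered.
const auto layout = ck::wrapper::make_layout(ck::make_tuple(129, 67));
const auto padded_layout =
    ck::wrapper::pad(layout, ck::make_tuple(ck::Number<128>{}, ck::Number<64>{}));
```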

// unmerge
/**
 * \brief Unmerge selected dim in layout.
 *
 * \tparam Idx Index to dimension being unmerged.
 * \param layout Layout to unmerge.
 * \param new_lengths Dimensions into which the indicated dimension will be divided.
 * \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested.
 * \return Unmerged layout.
 */
template <index_t Idx, typename Shape, typename UnrolledDesc, typename NewLengths, typename NewIdxs>
__host__ __device__ constexpr auto unmerge(const Layout<Shape, UnrolledDesc>& layout,
                                           const NewLengths& new_lengths,
                                           [[maybe_unused]] const NewIdxs& new_indexes)
{
    const auto& layout_shape = shape(layout);
    auto& unrolled_desc = layout.GetUnrolledDescriptor();
    constexpr auto dims = Shape::Size();
    // Generate transforms
    const auto transforms = generate_tuple(
        [&](auto i) {
            if constexpr(i == Idx)
            {
                return make_unmerge_transform(new_lengths);
            }
            else
            {
                return make_pass_through_transform(layout_shape.At(i));
            }
        },
        Number<dims>{});

    constexpr auto lower_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<dims>{});
    constexpr auto upper_dims = generate_tuple(
        [&](auto i) {
            if constexpr(is_detected<is_tuple, tuple_element_t<i.value, NewIdxs>>::value)
            {
                constexpr auto idxs_tuple = tuple_element_t<i.value, NewIdxs>{};
                return to_sequence(idxs_tuple);
            }
            else
            {
                constexpr index_t index = tuple_element_t<i.value, NewIdxs>{};
                return Sequence<index>{};
            }
        },
        Number<dims>{});

    const auto unmerged_desc =
        transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims);
    const auto unmerged_shape =
        generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number<i>{}); },
                       Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});

    // Create layout
    return Layout<decltype(unmerged_shape), decltype(unmerged_desc)>(unmerged_shape, unmerged_desc);
}
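
This mirrors the K0,M,K1 reshape performed by test_wrapper_gemm.cpp below; `m_k_layout`, `K0` and `K1` are assumed to be in scope:

```cpp
// Split dim 1 (K) of an (M, K) layout into (K0, K1); the nested index tuple
// (1, (0, 2)) places K0 at position 0, M at position 1 and K1 at position 2,
// i.e. the result is a (K0, M, K1) layout.
const auto reshaped_dims_idxs =
    ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
auto k0_m_k1_layout =
    ck::wrapper::unmerge<1>(m_k_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
```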

} // namespace wrapper
} // namespace ck
@@ -6,7 +6,6 @@
#include "tensor_utils.hpp"
#include "layout_utils.hpp"

#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"

@@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple<Ts..
 * \brief Apply projection.
 *
 * \param base_tuple Tuple to apply projection.
 * \param projection Projection to remove selected dim from partitioning.
 *        slice(X) to remove, where X is dim size, Number<1>{} to keep.
 * \param projection Projection is used to remove selected dim from
 *        partitioning. Use `slice(X)` to remove dimension, where X is dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Multi index after projection.
 */
template <typename MultiIndex, typename ProjectionTuple>
@@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
            }
            else
            {
                return base_tuple.At(i_num);
                return make_tuple(base_tuple.At(i_num));
            }
        },
        Number<MultiIndex::Size()>{});
@@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple,
 * \brief Calculate shape with dims from projection.
 *
 * \param shape Base tensor shape.
 * \param projection Projection to remove selected dim from partitioning.
 *        slice(X) to remove, where X is dim size, Number<1>{} to keep.
 * \param projection Projection is used to remove selected dim from
 *        partitioning. Use `slice(X)` to remove dimension, where X is dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Shape with dims from projection.
 */
template <typename... Ts, typename... Ps>
@@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple<Ts..
 *
 * \param shape Base tensor shape.
 * \param tile_shape Tile shape.
 * \param projection Projection is used to remove selected dim from
 *        partitioning. Use `slice(X)` to remove dimension, where X is dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Tuple with blocks number.
 */
template <typename... Ts, typename... Ls, typename... Ps>
__host__ __device__ constexpr auto CalculateGridSize(const Tuple<Ts...>& shape,
                                                     const Tuple<Ls...>& tile_shape,
                                                     const Tuple<Ps...>& projection)
                                                     const Tuple<Ls...>& tile_shape)
{
    auto shape_with_projection = CalculateShapeWithProjection(shape, projection);
    return generate_tuple(
        [&](auto i) {
            return ck::math::integer_divide_ceil(size<i>(shape_with_projection),
                                                 size<i>(tile_shape));
        },
        [&](auto i) { return ck::math::integer_divide_ceil(size<i>(shape), size<i>(tile_shape)); },
        Number<Tuple<Ls...>::Size()>{});
}

@@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs,
    return thread_idxs * partition_lengths_seq + old_offset_idxs;
}

/**
 * \brief Select dims to partition (skip if slice).
 *
 * \param block_idxs Input block indexes.
 * \return Partitioned dims.
 */
template <typename BlockIdxs>
__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs)
{
    const auto dims_to_partition = generate_tuple(
        [&](auto i) {
            if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
            {
                return Number<i>{};
            }
            else
            {
                return Tuple<>{};
            }
        },
        Number<BlockIdxs::Size()>{});
    // Remove empty tuples
    return UnrollNestedTuple<0, 1>(dims_to_partition);
}

/**
 * \brief Replace slices with zeros (Slice dims are not partitioned).
 *
 * \param block_idxs Input block indexes.
 * \return Parsed dims.
 */
template <typename BlockIdxs>
__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs)
{
    return generate_tuple(
        [&](auto i) {
            if constexpr(!is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
            {
                return block_idxs.At(i);
            }
            else
            {
                return Number<0>{};
            }
        },
        Number<BlockIdxs::Size()>{});
}

/**
 * \brief Calculate default projection.
 *
@@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
    return generate_tuple([&](auto) { return Number<1>{}; }, Number<TileShape::Size()>{});
}

/**
 * \brief Calculate thread multi index from 1d thread index.
 *
 * \param thread_layout Layout of threads (must not be nested).
 * \param thread_id Thread index represented as integer.
 * \return Multi index.
 */
template <typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto CalculateThreadMultiIdx(
    [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
    const index_t thread_id)
{
    static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
                  "Thread layout should not be transformed.");
    constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
    constexpr auto shape = ThreadShape{};
    constexpr auto strides = embed_transform.coefficients_;

    return generate_tuple(
        [&](auto i) {
            constexpr auto num_i = Number<i>{};
            return (thread_id / strides.At(num_i)) % shape.At(num_i);
        },
        Number<ThreadShape::Size()>{});
}
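
A worked example of the index math (illustrative values):

```cpp
// For a thread layout with shape (4, 64) and packed strides (1, 4),
// thread_id = 70 maps to ((70 / 1) % 4, (70 / 4) % 64) = (2, 17).
```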
} // namespace detail
} // namespace

@@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape)
 * is supported).
 *
 * \param tensor Tensor for partition.
 * \param thread_lengths Layout of threads (must not be nested).
 * \param thread_layout Layout of threads (must not be transformed).
 * \param thread_id Thread index represented as integer.
 * \param projection Projection is used to remove selected dim from
 *        partitioning. Use `slice(X)` to remove dimension, where X is dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple, typename ProjectionTuple>
template <typename TensorType,
          typename ThreadShape,
          typename ThreadUnrolledDesc,
          typename ProjectionTuple>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     [[maybe_unused]] const ThreadLengthsTuple& thread_lengths,
                     [[maybe_unused]] const Layout<ThreadShape, ThreadUnrolledDesc>& thread_layout,
                     const index_t thread_id,
                     const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(ThreadLengthsTuple{}));
    static_assert(!IsNestedTuple(ThreadShape{}));
    // Calculate new partition shape
    const auto& tensor_shape = shape(tensor);
    // Calculate projected thread lengths
    constexpr auto projected_thread_lengths =
        detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{});
        detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
    constexpr auto partition_shape =
        detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
    // Create Thread Cluster Descriptor
    constexpr auto partition_shape_seq =
        generate_sequence_v2([&](auto I) { return size<I>(partition_shape); },
                             Number<decltype(partition_shape)::Size()>{});
    constexpr auto thread_lengths_seq =
        generate_sequence_v2([&](auto I) { return size<I>(ThreadLengthsTuple{}); },
                             Number<ThreadLengthsTuple::Size()>{});
    constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq);
    // Calculate thread idxs and offsets
    const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
    const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
    // Apply projection on thread idxs to remove not needed idxs
    const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
    const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
        projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
    // Create new layout and tensor
    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    // Slice descriptor
    const auto transforms = generate_tuple(
        [&](auto i) {
            return make_slice_transform(partition_shape.At(i),
                                        offset_multi_idxs.At(i),
                                        partition_shape.At(i) + offset_multi_idxs.At(i));
        },
        Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
    const auto lower_upper_dims =
        generate_tuple([&](auto i) { return Sequence<i.value>{}; },
                       Number<remove_reference_t<decltype(tensor_shape)>::Size()>{});
    auto sliced_desc =
        transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims);
    // Create layout
    const auto partition_layout =
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(unrolled_desc)>(
            partition_shape, unrolled_desc);
        Layout<remove_reference_t<decltype(partition_shape)>, decltype(sliced_desc)>(
            partition_shape, sliced_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
    // Apply offsets
    partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
    return partition_tensor;
}
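
A hedged usage sketch, with names borrowed from the tests below (`a_lds_tensor` and `thread_layout` are assumed to exist in scope):

```cpp
// Each thread takes its private slice of the LDS tile; the default-projection
// overload that follows is used when no dim has to be excluded from partitioning.
auto a_lds_partition =
    ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
```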

@@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor,
 * \param thread_id Thread index represented as integer.
 * \return Partition tensor.
 */
template <typename TensorType, typename ThreadLengthsTuple>
__host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
                                                        const ThreadLengthsTuple& thread_lengths,
                                                        const index_t thread_id)
template <typename TensorType, typename ThreadShape, typename ThreadUnrolledDesc>
__host__ __device__ constexpr auto
make_local_partition(TensorType& tensor,
                     const Layout<ThreadShape, ThreadUnrolledDesc>& thread_lengths,
                     const index_t thread_id)
{
    const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{});
    const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
    return make_local_partition(tensor, thread_lengths, thread_id, projection);
}
@@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor,
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
 * \param block_id Block index represented as integer.
 * \param projection Projection to remove selected dim from partitioning.
 *        slice(X) to remove, where X is dim size, Number<1>{} to keep.
 * \param block_idxs Tuple of block indexes represented as integer. If slice,
 *        then get whole dim.
 * \param projection Projection is used to remove selected dim from
 *        partitioning. Use `slice(X)` to remove dimension, where X is dim
 *        size. Use `Number<1>{}` to keep it.
 * \return Tile tensor.
 */
template <typename TensorType, typename BlockShapeTuple, typename ProjectionTuple>
template <typename TensorType,
          typename BlockShapeTuple,
          typename BlockIdxs,
          typename ProjectionTuple>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const index_t block_id,
                                                   const BlockIdxs& block_idxs,
                                                   const ProjectionTuple& projection)
{
    static_assert(!IsNestedTuple(BlockShapeTuple{}));

    constexpr bool is_default_projection =
        is_same_v<ProjectionTuple, decltype(detail::GenerateDefaultProjection(BlockShapeTuple{}))>;
    static_assert(!IsNestedTuple(BlockIdxs{}));

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
@@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,

    auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor();

    // TODO: Enable block_2_tile_map partitioning for non-default projection.
    if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection)
    constexpr auto projected_tile_shape =
        detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
    // Number of dims which are partitioned
    constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
    const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs);
    if constexpr(decltype(dims_to_partition)::Size() == I2)
    {
        // Optimized version for 2d tile shape [MxK]
        const auto shape_with_projection_dims =
            detail::CalculateShapeWithProjection(shape(tensor), projection);
        // Set Value for M, N partition
        const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0));
        const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1));
        constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
        constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
        auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N));
        // Get 1D block id
        const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
        const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size);
        const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs);
        // Optimized version for 2d tile shape [MxN]
        const auto block_2_tile_map =
            BlockToCTileMap_M00_N0_M01Adapt<BlockShapeTuple{}.At(I0),
                                            BlockShapeTuple{}.At(I1),
                                            remove_cvref_t<decltype(aligned_desc)>>(aligned_desc);
            BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
                                            NPerBlock,
                                            remove_cvref_t<decltype(m_n_desc)>>(m_n_desc);
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d));
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape));
        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape));
        const auto offset_multi_idxs =
            make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid);
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
        // Apply 0 for non partitioned dims
        const auto offset_multi_idxs = generate_tuple(
            [&](auto i) {
                if constexpr(i == dims_to_partition.At(I0))
                {
                    return m_block_data_idx_on_grid;
                }
                else if constexpr(i == dims_to_partition.At(I1))
                {
                    return n_block_data_idx_on_grid;
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<BlockShapeTuple::Size()>{});
        const auto projected_offset_multi_idxs =
            detail::ApplyProjection(offset_multi_idxs, projection);
        // Create new layout and tensor
        const auto tile_layout =
            Layout<remove_reference_t<decltype(tile_shape)>, decltype(aligned_desc)>(tile_shape,
                                                                                     aligned_desc);
            Layout<remove_reference_t<decltype(projected_tile_shape)>, decltype(aligned_desc)>(
                projected_tile_shape, aligned_desc);
        auto tile_tensor =
            make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
        // Apply offsets
        tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs));
        tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs));
        return tile_tensor;
    }
    else
    {
        // Calculate offsets
        // Sequence with data to process per block
        constexpr auto projected_tile_shape =
            detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
        using ProjectedTileShapeTuple = decltype(projected_tile_shape);
        constexpr auto projected_tile_shape_seq =
            generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); },
                                 Number<ProjectedTileShapeTuple::Size()>{});
        // Tuple with number of blocks
        const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection);
        const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths);
        const auto block_idxs =
            block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id));
        const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection);
        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
        const auto projected_block_idxs =
            to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
        const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
            projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
        // Create new layout and tensor
        const auto tile_layout =
@@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
 *
 * \param tensor Tensor for partition.
 * \param tile_shape Shapes of requested tile.
 * \param block_id Block index represented as integer.
 * \param block_idxs Tuple of block indexes represented as integer. If slice,
 *        then get whole dim.
 * \return Tile tensor.
 */
template <typename TensorType, typename BlockShapeTuple>
__host__ __device__ constexpr auto
make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id)
template <typename TensorType, typename BlockShapeTuple, typename BlockIdxs>
__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor,
                                                   const BlockShapeTuple& tile_shape,
                                                   const BlockIdxs& block_idxs)
{
    const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
    return make_local_tile(tensor, tile_shape, block_id, projection);
}

/**
 * \brief Pad tensor shapes to be adjusted to tile lengths.
 *
 * \param tensor Tensor to pad.
 * \param tile_lengths Tile lengths to align tensor shape.
 * \return Padded tensor.
 */
template <typename TensorType, typename TileLengths>
__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths)
{
    const auto& tensor_shape = shape(tensor);
    using TensorShapeType = remove_reference_t<decltype(tensor_shape)>;
    auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor();
    // Generate sequence with ones to mark that all dims will be padded
    constexpr auto do_pads_seq =
        generate_sequence_v2([](auto) { return Number<1>{}; }, Number<TensorShapeType::Size()>{});
    // Create descriptor with padding
    auto padded_desc =
        tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq);
    // Generate padded shape
    const auto padded_shape = generate_tuple(
        [&](auto i) {
            const auto& dim = size<i>(tensor_shape);
            const auto& tile_length = size<i>(tile_lengths);
            return ck::math::integer_divide_ceil(dim, tile_length) * tile_length;
        },
        Number<TileLengths::Size()>{});
    // Create layout and tensor
    const auto padded_layout =
        Layout<decltype(padded_shape), decltype(padded_desc)>(padded_shape, padded_desc);
    auto partition_tensor =
        make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), padded_layout);
    partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets());
    return partition_tensor;
    return make_local_tile(tensor, tile_shape, block_idxs, projection);
}
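
A hedged usage sketch, mirroring how test_wrapper_gemm.cpp below picks the per-block C tile (`c_global_tensor`, `tile_shape_k0_m_n_k1`, `block_idxs`, `K0PerBlock` and `K1` are assumed to exist in scope):

```cpp
// Select this block's (MPerBlock, NPerBlock) tile of C; the K0 and K1 dims
// are sliced away so they are not partitioned across the block grid.
auto c_local_tile =
    ck::wrapper::make_local_tile(c_global_tensor,
                                 tile_shape_k0_m_n_k1,
                                 block_idxs,
                                 make_tuple(ck::wrapper::slice(K0PerBlock),
                                            ck::Number<1>{},
                                            ck::Number<1>{},
                                            ck::wrapper::slice(K1)));
```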

} // namespace wrapper
@@ -1,14 +1,21 @@
add_gtest_executable(test_layout test_layout.cpp)
target_link_libraries(test_layout PRIVATE utility)
add_gtest_executable(test_tensor test_tensor.cpp)
target_link_libraries(test_tensor PRIVATE utility)
add_gtest_executable(test_copy test_copy.cpp)
target_link_libraries(test_copy PRIVATE utility)
add_gtest_executable(test_partition test_partition.cpp)
target_link_libraries(test_partition PRIVATE utility)
add_custom_target(test_wrapper)

add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp)
target_link_libraries(test_wrapper_layout PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_layout)
add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp)
target_link_libraries(test_wrapper_tensor PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_tensor)
add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp)
target_link_libraries(test_wrapper_copy PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_copy)
add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp)
target_link_libraries(test_wrapper_partition PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_partition)
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR
   GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR
   GPU_TARGETS MATCHES "gfx942")
add_gtest_executable(test_gemm test_gemm.cpp)
target_link_libraries(test_gemm PRIVATE utility)
add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp)
target_link_libraries(test_wrapper_gemm PRIVATE utility)
add_dependencies(test_wrapper test_wrapper_gemm)
endif()
@@ -1,257 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>

#include "ck/library/utility/host_tensor.hpp"

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"

template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
                 const std::vector<DataType>& b_data,
                 std::vector<DataType>& c_m_n_device_result,
                 const ck::index_t M,
                 const ck::index_t N,
                 const ck::index_t K)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ReferenceGemmInstance = ck::tensor_operation::host::
        ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;

    Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
    Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
    Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));

    a_m_k.mData = a_data;
    b_k_n.mData = b_data;

    auto ref_op       = ReferenceGemmInstance{};
    auto ref_invoker  = ref_op.MakeInvoker();
    auto ref_argument = ref_op.MakeArgument(
        a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

    ref_invoker.Run(ref_argument);
    EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}
template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayoutShape>
__global__ void DeviceGemm(const void* p_a,
                           const void* p_b,
                           void* p_c,
                           const ck::index_t M,
                           const ck::index_t N,
                           const ck::index_t K,
                           const BlockShape tile_shape,
                           const ThreadLayoutShape thread_layout)
{
    constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape);
    constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape);
    constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape);

    const auto a_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
    const auto b_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
    const auto c_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));

    constexpr auto a_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
    constexpr auto b_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{}));
    constexpr auto c_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{}));

    auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_a), a_global_layout);
    auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_b), b_global_layout);
    auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<DataType*>(p_c), c_global_layout);

    auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout));
    auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout));
    auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout));

    __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)];
    __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)];

    auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_a), a_tile_layout);
    auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_b), b_tile_layout);

    const ck::index_t block_idx = static_cast<ck::index_t>(blockIdx.x);
    using DimAccessOrder        = ck::Tuple<ck::Number<0>, ck::Number<1>>;
    constexpr ck::index_t vector_dim = 1;

    auto c_global_local_tile = ck::wrapper::make_local_tile(
        c_padded_global_tensor,
        tile_shape,
        block_idx,
        make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock)));
    auto c_global_local_partition =
        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                               decltype(a_tile_layout),
                                                               decltype(b_tile_layout),
                                                               ck::wrapper::size(thread_layout),
                                                               GemmTraits>(c_global_local_tile);
    auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                                  decltype(a_tile_layout),
                                                                  decltype(b_tile_layout),
                                                                  ck::wrapper::size(thread_layout),
                                                                  GemmTraits>();
    ck::wrapper::clear(c_vgpr_reg);

    const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock);
    ck::index_t i = 0;
    do
    {
        const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock);
        auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice);
        auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice);
        auto a_global_local_tile = ck::wrapper::make_local_tile(
            a_padded_global_tensor_k_slice,
            tile_shape,
            block_idx,
            make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}));
        auto b_global_local_tile = ck::wrapper::make_local_tile(
            b_padded_global_tensor_k_slice,
            tile_shape,
            block_idx,
            make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}));

        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
            a_global_local_tile, a_lds_tensor, thread_layout);
        ck::wrapper::blockwise_copy<DimAccessOrder, vector_dim, scalar_per_vector>(
            b_global_local_tile, b_lds_tensor, thread_layout);
        ck::block_sync_lds();
        ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
            a_lds_tensor, b_lds_tensor, c_vgpr_reg);

        ++i;
    } while(i < num_loop);

    ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}

template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayoutShape>
void PerformGemm(const ck::index_t M,
                 const ck::index_t N,
                 const ck::index_t K,
                 const BlockShape& tile_shape,
                 const ThreadLayoutShape& thread_layout)
{
    // Global memory buffers
    DeviceMem a_mem(M * K * sizeof(DataType));
    DeviceMem b_mem(K * N * sizeof(DataType));
    DeviceMem c_mem(M * N * sizeof(DataType));

    std::vector<DataType> a_data(M * K);
    std::vector<DataType> b_data(K * N);
    ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
    ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);

    a_mem.ToDevice(a_data.data());
    b_mem.ToDevice(b_data.data());
    c_mem.SetZero();

    const ck::index_t grid_size =
        ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) *
        ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));

    const auto kernel =
        DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayoutShape>;
    launch_and_time_kernel(StreamConfig{nullptr},
                           kernel,
                           dim3(grid_size),
                           dim3(ck::wrapper::size(thread_layout)),
                           0,
                           a_mem.GetDeviceBuffer(),
                           b_mem.GetDeviceBuffer(),
                           c_mem.GetDeviceBuffer(),
                           M,
                           N,
                           K,
                           tile_shape,
                           thread_layout);

    std::vector<DataType> c_data(M * N);
    c_mem.FromDevice(c_data.data());

    CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}

TEST(TestGemm, Float)
{
    using DataType = float;
    const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{});
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4>(
        512, 512, 128, tile_shape, thread_layout);
    // Irregular case
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1>(
        129, 129, 67, tile_shape, thread_layout);
}

TEST(TestGemm, Int8)
{
    using DataType = int8_t;
    const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{});
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 16>(
        512, 512, 128, tile_shape, thread_layout);
    // Irregular case
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1>(
        129, 129, 67, tile_shape, thread_layout);
}

TEST(TestGemm, Half)
{
    using DataType = ck::half_t;
    const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{});
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8>(
        512, 512, 128, tile_shape, thread_layout);
    // Irregular case
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1>(
        129, 129, 67, tile_shape, thread_layout);
}

TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
    using DataType = float;
    const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{});
    const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{});
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4>(
        512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave);
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1, 4>(
        512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave);
}
@@ -20,23 +20,25 @@
template <typename InputTensor,
          typename OutputTensor,
          typename BlockShape,
          typename ThreadLayoutShape,
          typename ThreadLayout,
          bool UseOptimizedCopy>
__global__ void TestCopyDevice(const InputTensor input_tensor,
                               OutputTensor output_tensor,
                               const BlockShape tile_shape,
                               const ThreadLayoutShape thread_layout)
                               const ThreadLayout thread_layout)
{
    __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)];
    const auto tensor_lds = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        p_shared, ck::wrapper::make_layout(tile_shape));

    const auto block_idx = static_cast<ck::index_t>(blockIdx.x);
    const auto block_idxs =
        ck::make_tuple(static_cast<ck::index_t>(blockIdx.x), static_cast<ck::index_t>(blockIdx.y));

    // Get local tiles for global memory
    const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx);
    const auto input_local_tile =
        ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs);
    const auto output_local_tile =
        ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx);
        ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs);

    // Get partition per thread
    const auto input_local_partition =
@@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor,
    // Allocate VGPR
    auto tensor_vgpr =
        ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, ck::index_t>(
            layout(lds_local_partition));
            ck::wrapper::make_layout(shape(lds_local_partition)));

    // Perform copy
    if constexpr(UseOptimizedCopy)
@@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS()
    auto output_tensor_global = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<ck::index_t*>(out_buf.GetDeviceBuffer()), layout);

    const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{});
    const auto tile_shape    = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{});

    const ck::index_t grid_size = ck::math::integer_divide_ceil(
        ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape));
    const ck::index_t grid_size_x = ck::math::integer_divide_ceil(
        ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape));
    const ck::index_t grid_size_y = ck::math::integer_divide_ceil(
        ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape));

    const auto kernel = TestCopyDevice<decltype(input_tensor_global),
                                       decltype(output_tensor_global),
@@ -112,7 +117,7 @@ void PerformCopyGlobalToGlobalViaLDS()
                                       UseOptimizedCopy>;
    launch_and_time_kernel(StreamConfig{},
                           kernel,
                           dim3(grid_size),
                           dim3(grid_size_x, grid_size_y, 1),
                           dim3(ck::wrapper::size(thread_layout)),
                           0,
                           input_tensor_global,
376
test/wrapper/test_wrapper_gemm.cpp
Normal file
@@ -0,0 +1,376 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <numeric>
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <vector>
#include <gtest/gtest.h>

#include "ck/library/utility/host_tensor.hpp"

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#include "ck/host_utility/kernel_launch.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/wrapper/layout.hpp"
#include "ck/wrapper/tensor.hpp"
#include "ck/wrapper/operations/copy.hpp"
#include "ck/wrapper/operations/gemm.hpp"
#include "ck/wrapper/utils/kernel_utils.hpp"

template <typename DataType>
void CheckResult(const std::vector<DataType>& a_data,
                 const std::vector<DataType>& b_data,
                 std::vector<DataType>& c_m_n_device_result,
                 const ck::index_t M,
                 const ck::index_t N,
                 const ck::index_t K)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ReferenceGemmInstance = ck::tensor_operation::host::
        ReferenceGemm<DataType, DataType, DataType, float, PassThrough, PassThrough, PassThrough>;

    Tensor<DataType> a_m_k(HostTensorDescriptor({M, K}));
    Tensor<DataType> b_k_n(HostTensorDescriptor({K, N}, {1, K}));
    Tensor<DataType> c_m_n_host_result(HostTensorDescriptor({M, N}));

    a_m_k.mData = a_data;
    b_k_n.mData = b_data;

    auto ref_op       = ReferenceGemmInstance{};
    auto ref_invoker  = ref_op.MakeInvoker();
    auto ref_argument = ref_op.MakeArgument(
        a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});

    ref_invoker.Run(ref_argument);
    EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData));
}

template <bool DoPad, typename Layout, typename PaddingDims>
__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims)
{
    if constexpr(DoPad)
    {
        return ck::wrapper::pad(layout, padding_dims);
    }
    else
    {
        return layout;
    }
}

template <typename DataType,
          typename GemmTraits,
          ck::index_t scalar_per_vector,
          typename BlockShape,
          typename ThreadLayout,
          bool DoPadding>
__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a,
                                                        const void* p_b,
                                                        void* p_c,
                                                        const ck::index_t M,
                                                        const ck::index_t N,
                                                        const ck::index_t K,
                                                        const BlockShape tile_shape,
                                                        const ThreadLayout thread_layout)
{
    constexpr auto MPerBlock  = ck::wrapper::size<0>(tile_shape);
    constexpr auto NPerBlock  = ck::wrapper::size<1>(tile_shape);
    constexpr auto KPerBlock  = ck::wrapper::size<2>(tile_shape);
    constexpr auto K1         = GemmTraits::K1;
    constexpr auto K0PerBlock = KPerBlock / K1;
    const auto K0             = ck::math::integer_divide_ceil(K, K1);

    const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1);

    const auto a_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1));
    const auto b_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1));
    const auto c_global_layout =
        ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1));

    auto a_padded_global_layout =
        ApplyPadding<DoPadding>(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock));
    auto b_padded_global_layout =
        ApplyPadding<DoPadding>(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock));
    auto c_padded_global_layout =
        ApplyPadding<DoPadding>(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock));

    // Reshape from M,K to K0,M,K1
    const auto reshaped_dims_idxs =
        ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{}));
    auto a_padded_unmerged_global_layout =
        ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);
    auto b_padded_unmerged_global_layout =
        ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs);

    auto a_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_a), a_padded_unmerged_global_layout);
    auto b_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<const DataType*>(p_b), b_padded_unmerged_global_layout);
    auto c_global_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Global>(
        static_cast<DataType*>(p_c), c_padded_global_layout);

    // Add extra M and N
    constexpr auto a_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(K0PerBlock, MPerBlock, K1),
        ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));
    constexpr auto b_tile_layout = ck::wrapper::make_layout(
        ck::make_tuple(K0PerBlock, NPerBlock, K1),
        ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{}));

    __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock];
    __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock];

    auto a_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_a), a_tile_layout);
    auto b_lds_tensor = ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Lds>(
        static_cast<DataType*>(lds_b), b_tile_layout);

    const auto block_idxs = ck::make_tuple(ck::wrapper::slice(),
                                           static_cast<ck::index_t>(blockIdx.x),
                                           static_cast<ck::index_t>(blockIdx.y),
                                           ck::wrapper::slice());
    using DimAccessOrder             = ck::Tuple<ck::Number<1>, ck::Number<0>, ck::Number<2>>;
    constexpr ck::index_t vector_dim = 2;

    auto c_global_local_tile =
        ck::wrapper::make_local_tile(c_global_tensor,
                                     tile_shape_k0_m_n_k1,
                                     block_idxs,
                                     make_tuple(ck::wrapper::slice(K0PerBlock),
                                                ck::Number<1>{},
                                                ck::Number<1>{},
                                                ck::wrapper::slice(K1)));
    auto c_global_local_partition =
        ck::wrapper::make_blockwise_gemm_xdl_c_local_partition<DataType,
                                                               decltype(a_tile_layout),
                                                               decltype(b_tile_layout),
                                                               ck::wrapper::size(thread_layout),
                                                               GemmTraits>(c_global_local_tile);
    auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr<DataType,
                                                                  decltype(a_tile_layout),
                                                                  decltype(b_tile_layout),
                                                                  ck::wrapper::size(thread_layout),
                                                                  GemmTraits>();
    ck::wrapper::clear(c_vgpr_reg);

    auto a_lds_tensor_local_partition =
        ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x);
    auto b_lds_tensor_local_partition =
        ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x);

    auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) {
        const auto k_slice =
            ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock),
                           ck::wrapper::slice(),
                           ck::wrapper::slice());
        auto local_tile = ck::wrapper::make_local_tile(
            tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection);
        return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x);
    };

    auto a_global_local_partition = make_global_partition(
        a_global_tensor,
        make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
        0);
    auto b_global_local_partition = make_global_partition(
        b_global_tensor,
        make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
        0);

    // (row-major vgpr layout)
    auto a_vgpr_tensor =
        ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
            ck::wrapper::make_layout(
                shape(a_global_local_partition),
                ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
                                   ck::wrapper::size<2>(a_global_local_partition),
                               ck::wrapper::size<2>(a_global_local_partition),
                               ck::Number<1>{})));
    auto b_vgpr_tensor =
        ck::wrapper::make_register_tensor<ck::wrapper::MemoryTypeEnum::Vgpr, DataType>(
            ck::wrapper::make_layout(
                shape(b_global_local_partition),
                ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) *
                                   ck::wrapper::size<2>(a_global_local_partition),
                               ck::wrapper::size<2>(a_global_local_partition),
                               ck::Number<1>{})));

    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_global_local_partition,
                                                                     a_vgpr_tensor);
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_global_local_partition,
                                                                     b_vgpr_tensor);
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(a_vgpr_tensor,
                                                                     a_lds_tensor_local_partition);
    ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(b_vgpr_tensor,
                                                                     b_lds_tensor_local_partition);

    const ck::index_t num_loop =
        __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock));
    if(num_loop > 1)
    {
        ck::index_t i = 0;
        do
        {
            auto a_global_local_partition_i = make_global_partition(
                a_global_tensor,
                make_tuple(
                    ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}),
                i + 1);
            auto b_global_local_partition_i = make_global_partition(
                b_global_tensor,
                make_tuple(
                    ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}),
                i + 1);

            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                a_global_local_partition_i, a_vgpr_tensor);

            ck::block_sync_lds();
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                b_global_local_partition_i, b_vgpr_tensor);

            ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
                a_lds_tensor, b_lds_tensor, c_vgpr_reg);

            ck::block_sync_lds();
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                a_vgpr_tensor, a_lds_tensor_local_partition);
            ck::wrapper::copy<DimAccessOrder, vector_dim, scalar_per_vector>(
                b_vgpr_tensor, b_lds_tensor_local_partition);

            ++i;
        } while(i < (num_loop - 1));
    }
    ck::block_sync_lds();
    ck::wrapper::blockwise_gemm_xdl<DataType, ck::wrapper::size(thread_layout), GemmTraits>(
        a_lds_tensor, b_lds_tensor, c_vgpr_reg);

    ck::wrapper::copy(c_vgpr_reg, c_global_local_partition);
}
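
The main loop above is software-pipelined through VGPRs: while the blockwise GEMM consumes the tile already staged in LDS, the next K-tile is fetched from global memory into registers, and only after a sync is it written to LDS. A schematic of the schedule as read from the code:

```cpp
// Schematic of the pipelined main loop above:
//   prologue:  tile 0: global -> VGPR -> LDS
//   for i in [0, num_loop - 1):
//     tile i+1: global -> VGPR           (overlaps the GEMM below)
//     blockwise_gemm_xdl on tile i held in LDS
//     sync; tile i+1: VGPR -> LDS
//   epilogue:  blockwise_gemm_xdl on the last staged tile, then C write-out
```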
|
||||
|
||||
template <typename DataType,
|
||||
typename GemmTraits,
|
||||
ck::index_t scalar_per_vector,
|
||||
bool DoPadding,
|
||||
typename BlockShape,
|
||||
typename ThreadLayout>
|
||||
void PerformGemm(const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const BlockShape& tile_shape,
|
||||
const ThreadLayout& thread_layout)
|
||||
{
|
||||
// Global memory buffers
|
||||
DeviceMem a_mem(M * K * sizeof(DataType));
|
||||
DeviceMem b_mem(K * N * sizeof(DataType));
|
||||
DeviceMem c_mem(M * N * sizeof(DataType));
|
||||
|
||||
std::vector<DataType> a_data(M * K);
|
||||
std::vector<DataType> b_data(K * N);
|
||||
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(a_data);
|
||||
ck::utils::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(b_data);
|
||||
|
||||
a_mem.ToDevice(a_data.data());
|
||||
b_mem.ToDevice(b_data.data());
|
||||
c_mem.SetZero();
|
||||
|
||||
const ck::index_t grid_size_x =
|
||||
ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape));
|
||||
const ck::index_t grid_size_y =
|
||||
ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape));
|
||||
|
||||
const auto kernel =
|
||||
DeviceGemm<DataType, GemmTraits, scalar_per_vector, BlockShape, ThreadLayout, DoPadding>;
    const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
                                                  kernel,
                                                  dim3(grid_size_x, grid_size_y, 1),
                                                  dim3(ck::wrapper::size(thread_layout)),
                                                  0,
                                                  a_mem.GetDeviceBuffer(),
                                                  b_mem.GetDeviceBuffer(),
                                                  c_mem.GetDeviceBuffer(),
                                                  M,
                                                  N,
                                                  K,
                                                  tile_shape,
                                                  thread_layout);
    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N;
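
    // avg_time is reported in ms, so flop / 1e9 / ms yields TFlops and
    // bytes / 1e6 / ms yields GB/s. For example, M = N = 512, K = 128 is
    // 2 * 512 * 512 * 128 ~= 6.7e7 flop; at avg_time = 0.05 ms that is ~1.34 TFlops.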
    float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
    float gb_per_sec = num_btype / 1.E6 / avg_time;

    std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s" << std::endl;

    std::vector<DataType> c_data(M * N);
    c_mem.FromDevice(c_data.data());
    CheckResult<DataType>(a_data, b_data, c_data, M, N, K);
}

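// Each test below drives PerformGemm with a data type, the blockwise GEMM traits,
// a scalar_per_vector for vectorized copies, and a DoPadding flag for problem
// sizes that are not multiples of the tile shape.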
TEST(TestGemm, Float)
{
    using DataType = float;
    // (dim1, dim2, dim0) thread layout
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
                                 ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 4, false>(
        512, 512, 128, tile_shape, thread_layout);
    // Irregular case: dimensions not divisible by the tile shape, so enable padding
    // and fall back to scalar (scalar_per_vector = 1) access.
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1, 1, true>(
        129, 129, 67, tile_shape, thread_layout);
}

TEST(TestGemm, Int8)
{
    using DataType = int8_t;
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
                                 ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{});
    PerformGemm<DataType,
                ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1,
                16,
                false>(512, 512, 128, tile_shape, thread_layout);
    // Irregular case
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1, 1, true>(
        129, 129, 67, tile_shape, thread_layout);
}

TEST(TestGemm, Half)
{
    using DataType = ck::half_t;
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
                                 ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 8, false>(
        512, 512, 128, tile_shape, thread_layout);
    // Irregular case
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1, 1, true>(
        129, 129, 67, tile_shape, thread_layout);
}

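// Note the per-type scaling above: scalar_per_vector of 4 floats, 8 halves, or
// 16 int8 values each amounts to a 128-bit vector access, and KPerBlock
// (16, 32, 64) scales the same way, so every variant stages the same number of
// bytes per block in LDS.
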
TEST(TestGemm, Float_2x4_4x2_XdlPerWave)
{
    using DataType = float;
    const auto thread_layout =
        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
                                 ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}));
    const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{});
    PerformGemm<DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1, 4, false>(
        512, 512, 128, tile_shape, thread_layout);
}

@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #include <cstdlib>
 #include <iostream>
@@ -29,8 +29,11 @@ TEST(TestPartition, LocalPartition)
     const auto tensor =
         ck::wrapper::make_tensor<ck::wrapper::MemoryTypeEnum::Generic>(data.data(), layout);
 
-    const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
-    const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{});
+    const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{});
+    // row-major thread layout
+    const auto thread_layout =
+        ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}),
+                                 ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{}));
     // 3d partition on 2d shape (calculate partition on 3d thread layout, then skip the first dim)
     const auto thread_projection =
         ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{});
@@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile)
         ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{});
     const auto block_projection =
         ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2));
-    constexpr ck::index_t projection_block_dim = ck::Number<2>{};
-    const auto num_blocks =
+
+    const auto grid_shape =
         ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape),
                        ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape),
                        ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape));
-    std::vector<ck::index_t> block_idxs(ck::wrapper::size(num_blocks));
-    std::iota(block_idxs.begin(), block_idxs.end(), 0);
+    std::vector<ck::Tuple<ck::index_t, ck::index_t, ck::index_t, ck::index_t>> block_idxs;
+    for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++)
+    {
+        for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++)
+        {
+            for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++)
+            {
+                block_idxs.emplace_back(i, j, k, 0);
+            }
+        }
+    }
 
     for(auto block_idx : block_idxs)
     {
+        constexpr ck::index_t projection_block_dim = ck::Number<2>{};
         const auto packed_tile =
             ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection);
 
         const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim;
-        auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) *
-                                       ck::wrapper::size<2>(block_shape) *
-                                       ck::wrapper::size<2>(strides);
-        block_idx /= ck::wrapper::size<2>(num_blocks);
-        expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) *
-                                   ck::wrapper::size<1>(block_shape) *
-                                   ck::wrapper::size<1>(strides);
-        block_idx /= ck::wrapper::size<1>(num_blocks);
-        expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) *
-                                   ck::wrapper::size<0>(block_shape) *
-                                   ck::wrapper::size<0>(strides);
+        auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) *
+                                       ck::wrapper::size<2>(block_shape) *
+                                       ck::wrapper::size<2>(strides);
+        expected_tile_first_val += ck::wrapper::size<1>(block_idx) *
+                                   ck::wrapper::size<1>(block_shape) *
+                                   ck::wrapper::size<1>(strides);
+        expected_tile_first_val += ck::wrapper::size<0>(block_idx) *
+                                   ck::wrapper::size<0>(block_shape) *
+                                   ck::wrapper::size<0>(strides);