mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix ck_tile/basic_gemm build error (#1988)
This commit is contained in:
0
example/ck_tile/03_gemm/gemm_basic.cpp
Normal file → Executable file
0
example/ck_tile/03_gemm/gemm_basic.cpp
Normal file → Executable file
@@ -0,0 +1,14 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -0,0 +1,14 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -32,6 +32,9 @@ function print_log_header(){
|
||||
}
|
||||
|
||||
# run verification tests
|
||||
for dtype in fp16 bf16 fp8 bf8; do
|
||||
example/ck_tile/03_gemm/script/benchmark_basic_$dtype.sh
|
||||
done
|
||||
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
|
||||
|
||||
# run performance benchmarks
|
||||
|
||||
1
include/ck_tile/ops/common/utils.hpp
Normal file → Executable file
1
include/ck_tile/ops/common/utils.hpp
Normal file → Executable file
@@ -18,6 +18,7 @@ template <> struct typeToStr<bf16_t> { static constexpr const char * name = "bf1
|
||||
template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
|
||||
template <> struct typeToStr<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
|
||||
// clang-format on
|
||||
|
||||
template <typename ADataType_, typename BDataType_>
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -182,11 +182,11 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
@@ -196,11 +196,11 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
}
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegBlockDistribution<Problem>());
|
||||
shuffle_tile(b_shuffle_tmp, b_block_tile);
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile);
|
||||
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
|
||||
store_tile(b_copy_lds_window, b_block_tile_tmp);
|
||||
}
|
||||
@@ -229,15 +229,26 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp_loop = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp_loop, a_block_tile);
|
||||
store_tile(a_copy_lds_window,
|
||||
tile_elementwise_in(a_element_func, a_shuffle_tmp_loop));
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegBlockDistribution<Problem>());
|
||||
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp_loop, b_block_tile);
|
||||
store_tile(b_copy_lds_window,
|
||||
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
|
||||
}
|
||||
|
||||
4
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Normal file → Executable file
4
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
Normal file → Executable file
@@ -129,7 +129,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * M0))
|
||||
if constexpr(get_warp_size() >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
@@ -219,7 +219,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
if constexpr(get_warp_size() >= (K2 * N0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
|
||||
Reference in New Issue
Block a user