mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
[CK_TILE] Port hw independent changes from internal repo to develop branch (#3301)
* [CK_TILE] Port hw independent changes from internal repo to develop branch It includes PR#96, #114, #120, #121. * correct rebase error
This commit is contained in:
@@ -9,7 +9,7 @@ endif()
|
||||
# Use standard asm for rtn bf16 conversion instead of turncate
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
@@ -29,9 +29,6 @@ template <typename ALayout_,
|
||||
int M_Warp_val_,
|
||||
int N_Warp_val_,
|
||||
int K_Warp_val_,
|
||||
int M_Warp_Tile_val_,
|
||||
int N_Warp_Tile_val_,
|
||||
int K_Warp_Tile_val_,
|
||||
bool DoubleSmemBuffer_val_,
|
||||
ck_tile::GemmPipelineScheduler Scheduler_val_,
|
||||
PipelineType Pipeline_val_,
|
||||
@@ -50,15 +47,21 @@ struct KernelConfig
|
||||
using EDataType = EDataType_;
|
||||
using DsDataType = ck_tile::tuple<D0DataType_, D1DataType_>;
|
||||
|
||||
static constexpr int M_Tile_ = M_Tile_val_;
|
||||
static constexpr int N_Tile_ = N_Tile_val_;
|
||||
static constexpr int K_Tile_ = K_Tile_val_;
|
||||
static constexpr int M_Warp_ = M_Warp_val_;
|
||||
static constexpr int N_Warp_ = N_Warp_val_;
|
||||
static constexpr int K_Warp_ = K_Warp_val_;
|
||||
static constexpr int M_Warp_Tile_ = M_Warp_Tile_val_;
|
||||
static constexpr int N_Warp_Tile_ = N_Warp_Tile_val_;
|
||||
static constexpr int K_Warp_Tile_ = K_Warp_Tile_val_;
|
||||
static constexpr int M_Tile_ = M_Tile_val_;
|
||||
static constexpr int N_Tile_ = N_Tile_val_;
|
||||
static constexpr int K_Tile_ = K_Tile_val_;
|
||||
static constexpr int M_Warp_ = M_Warp_val_;
|
||||
static constexpr int N_Warp_ = N_Warp_val_;
|
||||
static constexpr int K_Warp_ = K_Warp_val_;
|
||||
#if CK_TILE_USE_WMMA
|
||||
static constexpr int M_Warp_Tile_ = 16;
|
||||
static constexpr int N_Warp_Tile_ = 16;
|
||||
static constexpr int K_Warp_Tile_ = 16;
|
||||
#else
|
||||
static constexpr int M_Warp_Tile_ = 32;
|
||||
static constexpr int N_Warp_Tile_ = 32;
|
||||
static constexpr int K_Warp_Tile_ = (M_Warp_val_ == 2) ? 16 : 8;
|
||||
#endif
|
||||
static constexpr bool DoubleSmemBuffer_ = DoubleSmemBuffer_val_;
|
||||
static constexpr auto Scheduler_ = Scheduler_val_;
|
||||
static constexpr PipelineType Pipeline_ = Pipeline_val_;
|
||||
@@ -68,21 +71,21 @@ struct KernelConfig
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, M_N_K_Warp_Tile, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
|
||||
// ALayout, BLayout, ELayout, ADataType, BDataType, D0DataType, D1DataType, AccDataType, EDataType, M_N_KTiles, M_N_K_Warps, DoubleSmemBuffer, Scheduler, Pipeline, Persistent
|
||||
// FP16 A/B/D/E
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, F16, F16, F16, F16, F32, F16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true>, // v4
|
||||
// BF16 A/B/D/E
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 256, 256, 32, 2, 2, 1, 32, 32, 16, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, false>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 32, 64, 4, 1, 1, false, ck_tile::GemmPipelineScheduler::Interwave, PipelineType::Memory, true>, // memory
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, false>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 64, 2, 2, 1, false, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV3, true>, // v3
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, false>, // v4
|
||||
KernelConfig< Row, Col, Row, BF16, BF16, BF16, BF16, F32, BF16, 128, 128, 32, 2, 2, 1, true, ck_tile::GemmPipelineScheduler::Intrawave, PipelineType::CompV4, true> // v4
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user