mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Merge commit '1b1c46e508c1fd40a03f54114b6b78629032fb4f' into develop
This commit is contained in:
@@ -13,7 +13,7 @@ using CDataType = ck::bhalf_t;
|
||||
using ComputeTypeA = ck::f8_t;
|
||||
using ComputeTypeB = ck::f8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using ALayout = Col;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
@@ -30,13 +30,13 @@ using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuf
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
128,
|
||||
128, 64, 64,
|
||||
8, 8,
|
||||
16, 16, // AK1, BK1
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 4, 16, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
2, 16, 16, 0,
|
||||
1, 1, S<1, 32, 1, 4>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1,
|
||||
ComputeTypeA, ComputeTypeB>;
|
||||
|
||||
@@ -5,7 +5,7 @@ endif()
|
||||
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
|
||||
add_executable(tile_example_gemm_quant_basic EXCLUDE_FROM_ALL gemm_quant_basic.cpp)
|
||||
target_compile_options(tile_example_gemm_quant_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
|
||||
@@ -419,6 +419,10 @@ int dispatch_group_size_ct(int m, int n, int k, F&& f)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigBQuantPrefill_Wmma>(argc, argv);
|
||||
#else
|
||||
// Use non-preshuffled GemmConfig for 2D block scale support
|
||||
return !run_gemm_example<GemmConfigBQuantPrefill>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -216,6 +216,14 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
|
||||
Reference in New Issue
Block a user