Fixing numerical error, and interchange preshuffle configs to match with flatmm (#2515)

[ROCm/composable_kernel commit: 579bd73435]
This commit is contained in:
Khushbu Agarwal
2025-07-16 22:33:03 -07:00
committed by GitHub
parent e75bbb6a13
commit f1f9b9635c
4 changed files with 16 additions and 8 deletions

View File

@@ -241,8 +241,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
@@ -263,8 +263,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;

View File

@@ -220,7 +220,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
auto [result, arg_parser] = create_args(argc, argv);
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && a_layout != "R" && b_layout != "C")
if(preshuffle && (a_layout != "R" || b_layout != "C"))
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");

View File

@@ -315,8 +315,16 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
if constexpr(preshuffle)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
}
else if(init_method == 1)
{

View File

@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}