[CK TILE] Apply get_k_warp_tile_for_preshuffle_b in examples and tests

This commit is contained in:
Cong Ma
2026-01-23 18:41:35 -05:00
parent 109bfa1558
commit bc91bb7dd7
11 changed files with 68 additions and 89 deletions

View File

@@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr bool kPadK = true;
@@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
@@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr bool kPadK = true;

View File

@@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;

View File

@@ -145,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
@@ -175,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;

View File

@@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp),
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
}
@@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
constexpr int kKLanePerWarp = 2;
constexpr int kABK1PerLane = 8;
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
kABK0PerLane,
divisor,
kKLanePerWarp,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
}
else
{
int divisor = 1;
int kKLanePerWarp = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
kKLanePerWarp = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
kKLanePerWarp,
gemmConfig.K_Warp_Tile / kKLanePerWarp});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}

View File

@@ -69,26 +69,38 @@ constexpr index_t get_k_warp_tile()
template <typename PrecType, index_t N_Warp_Tile>
constexpr index_t get_k_warp_tile_for_preshuffle_b()
{
#if CK_TILE_USE_WMMA
return 16;
#else
// When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately
// to support both dwordx4 loading instructions and MFMA instruction requirements.
// A single dwordx4 load may feed one or more MFMA instructions, or conversely,
// multiple loads may be required for a single MFMA instruction with a larger K dimension
// (e.g., 16x16x128 on gfx950).
// To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4)
// from global memory.
const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes
const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;
const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;
// Minimum K_Warp_Tile required by MFMA instructions
const index_t kMfmaN16Index = 0;
const index_t kMfmaN32Index = 1;
#if defined(CK_GFX950_SUPPORT)
const index_t kF8MfmaMaxK[2] = {128, 64};
const index_t kF8MfmaMaxK[2] = {128, 64};
const index_t kF16MfmaMaxK[2] = {32, 16};
#else
const index_t kF8MfmaMaxK[2] = {32, 16};
const index_t kF8MfmaMaxK[2] = {32, 16};
const index_t kF16MfmaMaxK[2] = {16, 8};
#endif
const bool kIsF8 =
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];
const bool kIsF8 = std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];
return max(kKPerWarp, kMfmaMaxK);
#endif
}
} // namespace ck_tile

View File

@@ -39,16 +39,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t k_b_per_load =
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();
/* The k_b_per_load should meet the requirement that each thread loads 16 bytes in
* Preshuffle B */
static_assert(k_b_per_load * sizeof(BDataType) == 16);
return k_b_per_load;
}

View File

@@ -16,6 +16,7 @@
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
// Forward declarations for quant type-specific implementations
template <ck_tile::QuantType QT>
@@ -74,11 +75,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
static constexpr ck_tile::index_t K_Warp_Tile =
GemmConfig::PreshuffleB
? ck_tile::get_k_warp_tile_for_preshuffle_b<BDataType, N_Warp_Tile>()
: ck_tile::get_k_warp_tile<BDataType, N_Warp_Tile>();
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
static constexpr bool kPadM = GemmConfig::kPadM;
static constexpr bool kPadN = GemmConfig::kPadN;

View File

@@ -6,16 +6,7 @@
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
template <bool is_8bit>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
return is_8bit ? 64 : 32;
#endif
}
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
struct GemmConfigBase
{
@@ -50,23 +41,21 @@ struct GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
// K_Warp_Tile is derived from N_Warp_Tile and BDataType
};
struct GemmConfigDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
};
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
};
struct GemmConfigMxFp4 : public GemmConfigBase

View File

@@ -11,6 +11,7 @@
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
using AddScale = ck_tile::element_wise::AddScale;
using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd;
@@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if CK_TILE_USE_WMMA
return 16;
#else
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
#endif
}
template <typename A0DataType,
typename B0DataType,
typename AccDataType,

View File

@@ -12,6 +12,7 @@
#include "ck_tile/host/tensor_shuffle_utils.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -86,7 +87,7 @@ struct config
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, N_Warp_Tile>();
};
template <typename Datatype>
@@ -102,7 +103,7 @@ struct config_wmma
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, N_Warp_Tile>();
};
template <typename Tuple>

View File

@@ -48,7 +48,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
static const ck_tile::index_t M_Warp_Tile = 16;
static const ck_tile::index_t N_Warp_Tile = 16;
static const ck_tile::index_t K_Warp_Tile =
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
static constexpr bool TransposeC = false; // transpose c is not supported