mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[CK TILE] Apply get_k_warp_tile_for_preshuffle_b in examples and tests
This commit is contained in:
@@ -148,7 +148,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
@@ -174,7 +174,7 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
@@ -220,7 +220,8 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase<Persistent>
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -145,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
@@ -175,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
|
||||
@@ -77,37 +77,35 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
constexpr int kKLanePerWarp = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kKLanePerWarp,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
int kKLanePerWarp = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
kKLanePerWarp = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
k_ / (gemmConfig.K_Warp_Tile / kKLanePerWarp),
|
||||
gemmConfig.K_Warp_Tile / kKLanePerWarp});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,39 +142,39 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
|
||||
int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp;
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
constexpr int kKLanePerWarp = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / kKLanePerWarp / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kKLanePerWarp,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
int kKLanePerWarp = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
kKLanePerWarp = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
kKLanePerWarp = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
divisor,
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
kKLanePerWarp,
|
||||
gemmConfig.K_Warp_Tile / kKLanePerWarp});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
@@ -69,26 +69,38 @@ constexpr index_t get_k_warp_tile()
|
||||
template <typename PrecType, index_t N_Warp_Tile>
|
||||
constexpr index_t get_k_warp_tile_for_preshuffle_b()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
// When preshuffle B is enabled, the K_Warp_Tile must be sized appropriately
|
||||
// to support both dwordx4 loading instructions and MFMA instruction requirements.
|
||||
// A single dwordx4 load may feed one or more MFMA instructions, or conversely,
|
||||
// multiple loads may be required for a single MFMA instruction with a larger K dimension
|
||||
// (e.g., 16x16x128 on gfx950).
|
||||
|
||||
// To achieve optimal memory bandwidth, each thread loads a minimum of 16 bytes (dwordx4)
|
||||
// from global memory.
|
||||
const int kMaxBytesPerLoad = 16; // buffer load max 16 bytes
|
||||
const int kMaxElementsPerLoad = kMaxBytesPerLoad / sizeof(PrecType);
|
||||
const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
|
||||
const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;
|
||||
const int kKLanePerWarp = ck_tile::get_warp_size() / N_Warp_Tile;
|
||||
const int kKPerWarp = kMaxElementsPerLoad * kKLanePerWarp;
|
||||
|
||||
// Minimum K_Warp_Tile required by MFMA instructions
|
||||
const index_t kMfmaN16Index = 0;
|
||||
const index_t kMfmaN32Index = 1;
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
const index_t kF8MfmaMaxK[2] = {128, 64};
|
||||
const index_t kF8MfmaMaxK[2] = {128, 64};
|
||||
const index_t kF16MfmaMaxK[2] = {32, 16};
|
||||
#else
|
||||
const index_t kF8MfmaMaxK[2] = {32, 16};
|
||||
const index_t kF8MfmaMaxK[2] = {32, 16};
|
||||
const index_t kF16MfmaMaxK[2] = {16, 8};
|
||||
#endif
|
||||
const bool kIsF8 =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
|
||||
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];
|
||||
const bool kIsF8 = std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
const index_t kMfmaIndex = N_Warp_Tile == 16 ? kMfmaN16Index : kMfmaN32Index;
|
||||
const index_t kMfmaMaxK = kIsF8 ? kF8MfmaMaxK[kMfmaIndex] : kF16MfmaMaxK[kMfmaIndex];
|
||||
|
||||
return max(kKPerWarp, kMfmaMaxK);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -39,16 +39,11 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t k_b_per_load =
|
||||
TileShape::WarpTile::at(I1) * TileShape::WarpTile::at(I2) / get_warp_size();
|
||||
|
||||
/* The k_b_per_load should meet the requirement that each thread loads 16 bytes in
|
||||
* Preshuffle B */
|
||||
static_assert(k_b_per_load * sizeof(BDataType) == 16);
|
||||
|
||||
return k_b_per_load;
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
|
||||
// Forward declarations for quant type-specific implementations
|
||||
template <ck_tile::QuantType QT>
|
||||
@@ -74,11 +75,14 @@ class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
|
||||
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
|
||||
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
|
||||
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
GemmConfig::PreshuffleB
|
||||
? ck_tile::get_k_warp_tile_for_preshuffle_b<BDataType, N_Warp_Tile>()
|
||||
: ck_tile::get_k_warp_tile<BDataType, N_Warp_Tile>();
|
||||
static constexpr bool PreshuffleQuant = GemmConfig::PreshuffleQuant;
|
||||
static constexpr bool PreshuffleB = GemmConfig::PreshuffleB;
|
||||
static constexpr bool TiledMMAPermuteN = GemmConfig::TiledMMAPermuteN;
|
||||
static constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool kPadM = GemmConfig::kPadM;
|
||||
static constexpr bool kPadN = GemmConfig::kPadN;
|
||||
|
||||
@@ -6,16 +6,7 @@
|
||||
#include "test_gemm_quant_base.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
|
||||
template <bool is_8bit>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
return is_8bit ? 64 : 32;
|
||||
#endif
|
||||
}
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
@@ -50,23 +41,21 @@ struct GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
|
||||
// K_Warp_Tile is derived from N_Warp_Tile and BDataType
|
||||
};
|
||||
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
};
|
||||
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<true>();
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
struct GemmConfigMxFp4 : public GemmConfigBase
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
|
||||
using AddScale = ck_tile::element_wise::AddScale;
|
||||
using ElementWiseAddAdd = ck_tile::element_wise::MultiDAdd;
|
||||
@@ -23,28 +24,6 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename AccDataType,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
@@ -86,7 +87,7 @@ struct config
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, N_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename Datatype>
|
||||
@@ -102,7 +103,7 @@ struct config_wmma
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, N_Warp_Tile>();
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
|
||||
@@ -48,7 +48,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<BDataType, M_Warp_Tile, true>();
|
||||
ck_tile::get_k_warp_tile_for_preshuffle_b<PrecType, N_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem
|
||||
static constexpr bool TransposeC = false; // transpose c is not supported
|
||||
|
||||
Reference in New Issue
Block a user