This commit is contained in:
mtgu0705
2025-08-06 04:20:20 -05:00
parent 1afa2b61e3
commit f96d5b74b2
6 changed files with 162 additions and 62 deletions

View File

@@ -25,7 +25,7 @@ template <typename ADataType,
typename BScaleCLayout,
typename CLayout,
uint32_t BlockScaleSize>
float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
float gemm_mx_calc(const ck_tile::GemmMXHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -37,7 +37,7 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 256;
constexpr ck_tile::index_t K_Tile = 128;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
@@ -45,7 +45,7 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea
constexpr ck_tile::index_t M_Warp_Tile = 16;
constexpr ck_tile::index_t N_Warp_Tile = 16;
constexpr ck_tile::index_t K_Warp_Tile = 128;
constexpr ck_tile::index_t K_Warp_Tile = 256;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,

View File

@@ -169,10 +169,10 @@ struct UniversalGemmBasePolicy
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
@@ -324,7 +324,14 @@ struct UniversalGemmBasePolicy
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
if constexpr(std::is_same_v(remove_cvref_t<DataType>, ck_tile::pk_fp4_t >))
{
return 16; // special procssing for packed fp4 to avoid re-packing
}
else
{
return (PackedSize * 16 / sizeof(DataType));
}
}
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
@@ -636,15 +643,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
: WGAttrNumAccessEnum::Invalid;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
Problem::UseStructuredSparsity,
wg_attr_num_access>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,

View File

@@ -117,10 +117,19 @@ struct GemmMXKernel
using BScaleDataType = remove_cvref_t<typename GemmPipeline::BScaleDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using APackedSize = remove_cvref_t<typename GemmPipeline::PackedSize>;
using BPackedSize = remove_cvref_t<typename GemmPipeline::PackedSize>;
using BlockScaleSize = remove_cvref_t<typename GemmPipeline::BlockScaleSize>;
static constexpr auto MXdlPack = 2;
static constexpr auto NXdlPack = 2;
static constexpr auto KXdlPack = 2;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -136,22 +145,23 @@ struct GemmMXKernel
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
CK_TILE_HOST static constexpr GemmMXKernelArgs MakeKernelArgs(const GemmMXHostArgs& hostArgs)
{
return AQuantGemmKernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.aq_ptr,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.QK,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.k_batch};
return GemmMXKernelArgs{hostArgs.a_ptr,
hostArs.a_scale_ptr_,
hostArgs.b_ptr,
hostArgs.b_scale_ptr_,
hostArgs.c_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K / APackedSize,
hostArgs.stride_A / APackedSize,
hostArgs.stride_scale_A,
hostArgs.stride_B / BPackedSize,
hostArgs.stride_scale_B,
hostArgs.stride_C,
hostArgs.stride_AQ,
hostArgs.k_batch};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -194,10 +204,20 @@ struct GemmMXKernel
{
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
}
// Calculate A scale offset
a_scale_k_split_offset = __builtin_amdgcn_readfirstlane(
k_id * kargs.KRead / (BlockScaleSize / APackedSize) * MXdlPack * NPerXdl)
// Caluculate B scale offset
b_scale_k_split_offset = __builtin_amdgcn_readfirstlane(
k_id * kargs.KRead / (BlockScaleSize / BPackedSize) * NXdlPack * NPerXdl);
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t a_scale_k_split_offset;
index_t b_scale_k_split_offset;
index_t splitted_k;
};
@@ -351,8 +371,9 @@ struct GemmMXKernel
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
const AScaleDataType* a_scale_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BScaleDataType* b_scale_ptr,
CDataType* c_ptr,
const AQuantGemmKernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
@@ -379,16 +400,27 @@ struct GemmMXKernel
}
}();
const auto& aq_tensor_view = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
// A scale tensor view
const auto& a_scale_tensor_view = [&]() {
static_asssert(std::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>);
return make_naive_tensor_view<address_space_enum::global>(
aq_ptr,
make_tuple(kargs.M, kargs.QK),
make_tuple(kargs.stride_AQ, 1),
a_scale_ptr,
make_tuple(kargs.M, kargs.K / BlockScaleSize),
make_tuple(kargs.stride_scale_A, 1),
number<GemmPipeline::GetVectorSizeAQ()>{},
number<1>{});
}();
// const auto& aq_tensor_view = [&]() {
// static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
// return make_naive_tensor_view<address_space_enum::global>(
// aq_ptr,
// make_tuple(kargs.M, kargs.QK),
// make_tuple(kargs.stride_AQ, 1),
// number<GemmPipeline::GetVectorSizeAQ()>{},
// number<1>{});
// }();
const auto& b_tensor_view = [&]() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
@@ -452,6 +484,17 @@ struct GemmMXKernel
}
}();
// B scale tensor view
const auto& b_scale_tensor_view = [&]() {
static_assert(std::is_same_v<BScaleLayout, tensor_layout::gemm::ColumnMajor>);
return make_naive_tensor_view<address_space_enum::global>(
b_scale_ptr,
make_tuple(kargs.N, kargs.K / BlockScaleSize),
make_tuple(kargs.stride_scale_B, 1),
number<GemmPipeline::GetVectorSizeBQ()>{},
number<1>{});
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
@@ -474,7 +517,8 @@ struct GemmMXKernel
}
}();
return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
return make_tuple(
a_tensor_view, a_scale_tensor_view, b_tensor_view, b_scale_tensor_view, c_tensor_view);
}
template <typename TensorView>
@@ -498,13 +542,13 @@ struct GemmMXKernel
}
}();
const auto& aq_pad_view = [&]() {
const auto& aq_tensor_view = views.at(I1);
const auto& a_scale_pad_view = [&]() {
const auto& a_scale_tensor_view = views.at(I1);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
return pad_tensor_view(
aq_tensor_view,
a_scale_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
// TODO: Add support for padding.
sequence<false, false>{});
}();
@@ -527,9 +571,19 @@ struct GemmMXKernel
}
}();
const auto& b_scale_pad_view = [&]() {
const auto& b_scale_tensor_view = views.at(I3);
static_assert(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>);
return pad_tensor_view(
b_scale_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
sequence<false, false>{});
}();
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I3);
const auto& c_tensor_view = views.at(I4);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(c_tensor_view,
@@ -546,17 +600,18 @@ struct GemmMXKernel
}
}();
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
return make_tuple(a_pad_view, a_scale_pad_view, b_pad_view, b_scale_pad_view, c_pad_view);
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& aq_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& c_pad_view = views.at(I3);
const auto& a_pad_view = views.at(I0);
const auto& a_scale_pad_view = views.at(I1);
const auto& b_pad_view = views.at(I2);
const auto& b_scale_pad_view = views.at(I3);
const auto& c_pad_view = views.at(I4);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
@@ -575,12 +630,12 @@ struct GemmMXKernel
}
}();
const auto& aq_block_window = [&]() {
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
const auto& a_scale_block_window = [&]() {
static_assert(std::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>);
return make_tile_window(
aq_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
{i_m, 0});
}();
@@ -601,6 +656,15 @@ struct GemmMXKernel
}
}();
const auto& b_scale_block_window = [&]() {
static_assert(std::is_same_v<BScaleLayout, tensor_layout::gemm::ColumnMajor>);
return make_tile_window(
b_scale_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
{0, i_n});
}();
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
@@ -626,8 +690,9 @@ struct GemmMXKernel
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const AScaleDataType* a_scale_ptr,
const BDataType* b_ptr,
const AQDataType* aq_ptr,
const BScaleDataType* b_scale_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const AQuantGemmKernelArgs& kargs,
@@ -670,16 +735,26 @@ struct GemmMXKernel
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
const AScaleDataType* a_scale_ptr = static_cast<const AScaleDataType*>(kargs.a_scale_ptr);
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
const BScaleDataType* b_scale_ptr = static_cast<const BScaleDataType*>(kargs.b_scale_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
assert(kargs.k_batch == 1);
RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
RunGemm(a_ptr,
a_scale_ptr,
b_ptr,
b_scale_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
};

View File

@@ -28,7 +28,20 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol
constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize;
static_assert(std::is_same_v<AScaleLayout, ck_tile::tensor_layout::gemm::RowMajor>);
return GetAScaleGlobalVectorLoadSize<Problem, AScaleDataType, MPerBlock, KPerBlockScale>();
return GetScaleGlobalVectorLoadSize<Problem, AScaleDataType, MPerBlock, KPerBlockScale>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBScale()
{
using BScaleLayout = remove_cvref_t<typename Problem::BScaleLayout>;
using BScaleDataType = remove_cvref_t<typename Problem::BScaleDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize;
static_assert(std::is_same_v<BScaleLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetScaleGlobalVectorLoadSize<Problem, BScaleDataType, NPerBlock, KPerBlockScale>();
}
template <typename Problem>

View File

@@ -125,9 +125,14 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetVectorSizeScale()
static constexpr index_t GetVectorSizeAScale()
{
return Policy::template GetVectorSizeAQ<Problem>();
return Policy::template GetVectorSizeAScale<Problem>();
}
static constexpr index_t GetVectorSizeBScale()
{
return Policy::template GetVectorSizeBScale<Problem>();
}
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }

View File

@@ -8,7 +8,7 @@
namespace ck_tile {
template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetAScaleGlobalVectorLoadSize()
CK_TILE_HOST_DEVICE static constexpr auto GetScaleGlobalVectorLoadSize()
{
using I1 = number<1>;
constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});