mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
update
This commit is contained in:
@@ -25,7 +25,7 @@ template <typename ADataType,
|
||||
typename BScaleCLayout,
|
||||
typename CLayout,
|
||||
uint32_t BlockScaleSize>
|
||||
float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
float gemm_mx_calc(const ck_tile::GemmMXHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
@@ -37,7 +37,7 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 128;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -45,7 +45,7 @@ float gemm_mx_calc(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::strea
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 256;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
|
||||
@@ -169,10 +169,10 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
|
||||
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
|
||||
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
|
||||
@@ -324,7 +324,14 @@ struct UniversalGemmBasePolicy
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
if constexpr(std::is_same_v(remove_cvref_t<DataType>, ck_tile::pk_fp4_t >))
|
||||
{
|
||||
return 16; // special processing for packed fp4 to avoid re-packing
|
||||
}
|
||||
else
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
@@ -636,15 +643,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -117,10 +117,19 @@ struct GemmMXKernel
|
||||
using BScaleDataType = remove_cvref_t<typename GemmPipeline::BScaleDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using APackedSize = remove_cvref_t<typename GemmPipeline::PackedSize>;
|
||||
using BPackedSize = remove_cvref_t<typename GemmPipeline::PackedSize>;
|
||||
using BlockScaleSize = remove_cvref_t<typename GemmPipeline::BlockScaleSize>;
|
||||
|
||||
static constexpr auto MXdlPack = 2;
|
||||
static constexpr auto NXdlPack = 2;
|
||||
static constexpr auto KXdlPack = 2;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -136,22 +145,23 @@ struct GemmMXKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
|
||||
CK_TILE_HOST static constexpr AQuantGemmKernelArgs
|
||||
MakeKernelArgs(const AQuantGemmHostArgs& hostArgs)
|
||||
CK_TILE_HOST static constexpr GemmMXKernelArgs MakeKernelArgs(const GemmMXHostArgs& hostArgs)
|
||||
{
|
||||
return AQuantGemmKernelArgs{hostArgs.a_ptr,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.aq_ptr,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K,
|
||||
hostArgs.QK,
|
||||
hostArgs.stride_A,
|
||||
hostArgs.stride_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_AQ,
|
||||
hostArgs.k_batch};
|
||||
return GemmMXKernelArgs{hostArgs.a_ptr,
|
||||
hostArs.a_scale_ptr_,
|
||||
hostArgs.b_ptr,
|
||||
hostArgs.b_scale_ptr_,
|
||||
hostArgs.c_ptr,
|
||||
hostArgs.M,
|
||||
hostArgs.N,
|
||||
hostArgs.K / APackedSize,
|
||||
hostArgs.stride_A / APackedSize,
|
||||
hostArgs.stride_scale_A,
|
||||
hostArgs.stride_B / BPackedSize,
|
||||
hostArgs.stride_scale_B,
|
||||
hostArgs.stride_C,
|
||||
hostArgs.stride_AQ,
|
||||
hostArgs.k_batch};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -194,10 +204,20 @@ struct GemmMXKernel
|
||||
{
|
||||
splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
}
|
||||
|
||||
// Calculate A scale offset
|
||||
a_scale_k_split_offset = __builtin_amdgcn_readfirstlane(
|
||||
k_id * kargs.KRead / (BlockScaleSize / APackedSize) * MXdlPack * NPerXdl)
|
||||
|
||||
// Calculate B scale offset
|
||||
b_scale_k_split_offset = __builtin_amdgcn_readfirstlane(
|
||||
k_id * kargs.KRead / (BlockScaleSize / BPackedSize) * NXdlPack * NPerXdl);
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t a_scale_k_split_offset;
|
||||
index_t b_scale_k_split_offset;
|
||||
index_t splitted_k;
|
||||
};
|
||||
|
||||
@@ -351,8 +371,9 @@ struct GemmMXKernel
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const AScaleDataType* a_scale_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BScaleDataType* b_scale_ptr,
|
||||
CDataType* c_ptr,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
@@ -379,16 +400,27 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& aq_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
// A scale tensor view
|
||||
const auto& a_scale_tensor_view = [&]() {
|
||||
static_asssert(std::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
aq_ptr,
|
||||
make_tuple(kargs.M, kargs.QK),
|
||||
make_tuple(kargs.stride_AQ, 1),
|
||||
a_scale_ptr,
|
||||
make_tuple(kargs.M, kargs.K / BlockScaleSize),
|
||||
make_tuple(kargs.stride_scale_A, 1),
|
||||
number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
// const auto& aq_tensor_view = [&]() {
|
||||
// static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
// return make_naive_tensor_view<address_space_enum::global>(
|
||||
// aq_ptr,
|
||||
// make_tuple(kargs.M, kargs.QK),
|
||||
// make_tuple(kargs.stride_AQ, 1),
|
||||
// number<GemmPipeline::GetVectorSizeAQ()>{},
|
||||
// number<1>{});
|
||||
// }();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -452,6 +484,17 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// B scale tensor view
|
||||
const auto& b_scale_tensor_view = [&]() {
|
||||
static_assert(std::is_same_v<BScaleLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_scale_ptr,
|
||||
make_tuple(kargs.N, kargs.K / BlockScaleSize),
|
||||
make_tuple(kargs.stride_scale_B, 1),
|
||||
number<GemmPipeline::GetVectorSizeBQ()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -474,7 +517,8 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
|
||||
return make_tuple(
|
||||
a_tensor_view, a_scale_tensor_view, b_tensor_view, b_scale_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -498,13 +542,13 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& aq_pad_view = [&]() {
|
||||
const auto& aq_tensor_view = views.at(I1);
|
||||
const auto& a_scale_pad_view = [&]() {
|
||||
const auto& a_scale_tensor_view = views.at(I1);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
return pad_tensor_view(
|
||||
aq_tensor_view,
|
||||
a_scale_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
|
||||
// TODO: Add support for padding.
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
@@ -527,9 +571,19 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_scale_pad_view = [&]() {
|
||||
const auto& b_scale_tensor_view = views.at(I3);
|
||||
static_assert(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return pad_tensor_view(
|
||||
b_scale_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
|
||||
sequence<false, false>{});
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
const auto& c_tensor_view = views.at(I4);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
@@ -546,17 +600,18 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
|
||||
return make_tuple(a_pad_view, a_scale_pad_view, b_pad_view, b_scale_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& a_scale_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& b_scale_pad_view = views.at(I3);
|
||||
const auto& c_pad_view = views.at(I4);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -575,12 +630,12 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& aq_block_window = [&]() {
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
const auto& a_scale_block_window = [&]() {
|
||||
static_assert(std::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>);
|
||||
return make_tile_window(
|
||||
aq_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
|
||||
{i_m, 0});
|
||||
}();
|
||||
|
||||
@@ -601,6 +656,15 @@ struct GemmMXKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_scale_block_window = [&]() {
|
||||
static_assert(std::is_same_v<BScaleLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
return make_tile_window(
|
||||
b_scale_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock / GemmPipeline::BlockScaleSize>{}),
|
||||
{0, i_n});
|
||||
}();
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
@@ -626,8 +690,9 @@ struct GemmMXKernel
|
||||
*/
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const AScaleDataType* a_scale_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BScaleDataType* b_scale_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const AQuantGemmKernelArgs& kargs,
|
||||
@@ -670,16 +735,26 @@ struct GemmMXKernel
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const AScaleDataType* a_scale_ptr = static_cast<const AScaleDataType*>(kargs.a_scale_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
const BScaleDataType* b_scale_ptr = static_cast<const BScaleDataType*>(kargs.b_scale_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
assert(kargs.k_batch == 1);
|
||||
RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
RunGemm(a_ptr,
|
||||
a_scale_ptr,
|
||||
b_ptr,
|
||||
b_scale_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -28,7 +28,20 @@ struct GemmMXPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgCrPol
|
||||
constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize;
|
||||
|
||||
static_assert(std::is_same_v<AScaleLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
return GetAScaleGlobalVectorLoadSize<Problem, AScaleDataType, MPerBlock, KPerBlockScale>();
|
||||
return GetScaleGlobalVectorLoadSize<Problem, AScaleDataType, MPerBlock, KPerBlockScale>();
|
||||
}
|
||||
|
||||
// Returns the B-scale vector size reported by GetScaleGlobalVectorLoadSize.
//
// The helper is instantiated with Problem, the B-scale data type, the block N extent
// (Problem::BlockGemmShape::kN) and KPerBlockScale, where
// KPerBlockScale = Problem::BlockGemmShape::kK / Problem::kBlockScaleSize.
// Only a ColumnMajor BScaleLayout is accepted; any other layout fails the static_assert.
// NOTE(review): presumably the result is the element count of one global vector load of
// B-scale values -- the helper's body is not visible here; confirm against its definition.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBScale()
|
||||
{
|
||||
using BScaleLayout = remove_cvref_t<typename Problem::BScaleLayout>;
|
||||
using BScaleDataType = remove_cvref_t<typename Problem::BScaleDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KPerBlockScale = KPerBlock / Problem::kBlockScaleSize;
|
||||
|
|
||||
static_assert(std::is_same_v<BScaleLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
return GetScaleGlobalVectorLoadSize<Problem, BScaleDataType, NPerBlock, KPerBlockScale>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -125,9 +125,14 @@ struct GemmMXPipelineAgBgCrCompV3 : public BaseGemmMXPipelineAgBgCrCompV3<Proble
|
||||
// Compile-time accessors for the A, B and C vector sizes. Each one simply forwards to the
// Policy query of the same name, instantiated with this pipeline's Problem.
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeScale()
|
||||
|
||||
static constexpr index_t GetVectorSizeAScale()
|
||||
{
|
||||
return Policy::template GetVectorSizeAQ<Problem>();
|
||||
return Policy::template GetVectorSizeAScale<Problem>();
|
||||
}
|
||||
// Compile-time accessor for the B-scale vector size; forwards to
// Policy::GetVectorSizeBScale<Problem>().
static constexpr index_t GetVectorSizeBScale()
|
||||
{
|
||||
return Policy::template GetVectorSizeBScale<Problem>();
|
||||
}
|
||||
|
||||
// Compile-time accessor that forwards to Policy::GetSmemPackA<Problem>().
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename DataType, index_t YPerTile, index_t XPerTile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAScaleGlobalVectorLoadSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetScaleGlobalVectorLoadSize()
|
||||
{
|
||||
using I1 = number<1>;
|
||||
constexpr index_t NWarps = Problem::BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
Reference in New Issue
Block a user