Merge branch 'mmflat' of https://github.com/ROCm/composable_kernel into mmflat

This commit is contained in:
AviralGoelAMD
2025-07-23 14:01:56 -05:00
10 changed files with 86 additions and 44 deletions

View File

@@ -2,9 +2,14 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
# Tile GEMM example executables. EXCLUDE_FROM_ALL keeps them out of the default
# "all" target, so they are only built when requested explicitly.
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
# Two independent option lists: one applied to the basic/universal examples and
# one applied only to the weight-preshuffle example (see target_compile_options below).
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
# NOTE(review): the OCP FP8 define is appended to EXAMPLE_GEMM_COMPILE_OPTIONS only;
# the weight-preshuffle list never receives it -- confirm this is intentional.
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
# Silence warnings about unused local typedefs in the weight-preshuffle example.
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
# NOTE(review): --save-temps makes the compiler keep its intermediate files; this looks
# like a debugging aid -- confirm it is meant to stay enabled for every build.
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
# The "SHELL:" prefix passes the quoted string as one group of arguments so the
# repeated -mllvm flags are not de-duplicated by CMake.
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0")
# Attach each option list to its target(s); PRIVATE keeps the flags from propagating.
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})

View File

@@ -249,9 +249,9 @@ struct GemmConfigPreshuffle_1 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V3;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
@@ -271,7 +271,7 @@ struct GemmConfigPreshuffle_2 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V3;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
@@ -291,7 +291,7 @@ struct GemmConfigPreshuffle_3 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;

View File

@@ -36,10 +36,13 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
// using TilePartitioner =
// ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
// GemmConfig::TileParitionerGroupNum,
// GemmConfig::TileParitionerM01>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
ck_tile::GemmTile1DPartitioner<GemmShape>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -69,6 +72,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
std::cout << "k_grain: " << k_grain << " K_split: " << K_split << std::endl;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
@@ -290,5 +294,5 @@ int main(int argc, char* argv[])
// Return a non-zero code to indicate failure
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
//return EXIT_SUCCESS;
}

View File

@@ -41,10 +41,10 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
{
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
r.x = __builtin_amdgcn_readfirstlane(r.x);
r.y = __builtin_amdgcn_readfirstlane(r.y);
r.z = __builtin_amdgcn_readfirstlane(r.z);
r.w = __builtin_amdgcn_readfirstlane(r.w);
// r.x = __builtin_amdgcn_readfirstlane(r.x);
// r.y = __builtin_amdgcn_readfirstlane(r.y);
// r.z = __builtin_amdgcn_readfirstlane(r.z);
// r.w = __builtin_amdgcn_readfirstlane(r.w);
return r;
}

View File

@@ -113,6 +113,7 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
});
});

View File

@@ -246,6 +246,11 @@ struct GemmKernel
{
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
/**
 * @brief Size used for the second shared-memory buffer.
 *
 * Forwards GemmPipeline::GetSmemSize() unchanged; EpiloguePipeline::GetSmemSize()
 * is not taken into account here.
 *
 * @return The value of GemmPipeline::GetSmemSize().
 */
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize2ndBuffer()
{
return GemmPipeline::GetSmemSize();
}
struct SplitKBatchOffset
{
@@ -950,7 +955,7 @@ struct GemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
__shared__ char smem_ptr_1[(GemmPipeline::Preshuffle) ? GetSmemSize2ndBuffer() : GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))

View File

@@ -112,7 +112,7 @@ struct GemmTile1DPartitioner
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST static auto
CK_TILE_HOST_DEVICE static auto
GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
{
const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;

View File

@@ -118,6 +118,10 @@ struct GemmPipelineProblemBase
}
static constexpr index_t VectorSizeA = []() {
// std::cout << "FixedVectorSize: " << FixedVectorSize << std::endl;
// std::cout << "kPadK: " << kPadK << std::endl;
// std::cout << "kPadM: " << kPadM << std::endl;
if constexpr(FixedVectorSize)
{
return VectorSizeA_;
@@ -133,6 +137,9 @@ struct GemmPipelineProblemBase
}();
static constexpr index_t VectorSizeB = []() {
// std::cout << "FixedVectorSize: " << FixedVectorSize << std::endl;
// std::cout << "kPadK: " << kPadK << std::endl;
// std::cout << "kPadN: " << kPadN << std::endl;
if constexpr(FixedVectorSize)
{
return VectorSizeB_;

View File

@@ -14,36 +14,37 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
// using ALayout = remove_cvref_t<typename Problem::ALayout>;
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution() //looks like this function is not getting used
// {
// using ADataType = remove_cvref_t<typename Problem::ADataType>;
// // using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
// constexpr index_t K1 = 16 / sizeof(ADataType);
// constexpr index_t K0 = KPerBlock / K1;
// constexpr index_t M2 = get_warp_size() / K0;
// constexpr index_t M1 = BlockSize / get_warp_size();
// static_assert(K1 == 1, "M2 is zero, which will lead to a division by zero error.");
// static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
// static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// // constexpr index_t M0 = MPerBlock / (M2 * M1);
// // static_assert(M0 * M1 * M2 == MPerBlock,
// // "Incorrect M0, M2, M1 configuration! "
// // "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
}
// return make_static_tile_distribution(
// tile_distribution_encoding<sequence<1>,
// tuple<sequence<M1, M2>, sequence<K0, K1>>,
// tuple<sequence<1>, sequence<1, 2>>,
// tuple<sequence<0>, sequence<1, 0>>,
// sequence<2>,
// sequence<1>>{});
// }
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()

View File

@@ -17,8 +17,12 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t GlobalBufferNum = 1;
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
/**
 * @brief Tells whether the given loop count leaves work for the hot loop.
 *
 * Pure predicate: it performs no I/O, so it stays usable in constant
 * expressions and does not write to stdout on every host-side launch.
 *
 * @param num_loop Loop count to test.
 * @return true when num_loop is strictly greater than PrefetchStages.
 */
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
@@ -33,10 +37,12 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
{
if(tail_number == TailNumber::Odd)
{
std::cout << "TailHandler: Odd" << std::endl;
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
}
else if(tail_number == TailNumber::Even)
{
std::cout << "TailHandler: Even" << std::endl;
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
}
}
@@ -74,8 +80,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
// Vector size for the A operand, obtained from PipelinePolicy for this Problem.
static constexpr index_t GetVectorSizeA() {
return PipelinePolicy::template GetVectorSizeA<Problem>();
// Alternative source, currently disabled: return Problem::VectorSizeA;
}
// Vector size for the B operand, obtained from PipelinePolicy for this Problem.
static constexpr index_t GetVectorSizeB() {
return PipelinePolicy::template GetVectorSizeB<Problem>();
// Alternative source, currently disabled: return Problem::VectorSizeB;
}
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
@@ -127,7 +139,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// clang-format on
}
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t Preshuffle = Problem::Preshuffle;
using Base::UsePersistentKernel;
@@ -226,6 +238,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
#if defined(__gfx950__)
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
{
//printf("Inside gfx950, with 16x16 128x256x256 \n");
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
@@ -273,6 +286,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
else
{
//printf("Inside gfx950, with 16x16 otherwise \n");
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
@@ -311,8 +325,9 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
// MFMA → MFMA → MFMA → MFMA → DS Read
// For other device engines we need more aggressive MFMA with DS writes interleaved
#else
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256)
if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) //TODO :: 128x256x128
{
//printf("Inside gfx942, with 16x16 128x256x256 \n");
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
// Uses loops to amortize scheduling overhead
@@ -388,6 +403,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256)
{
//printf("Inside gfx942, with 16x16 16x64x256 \n");
static_for<0, 1, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -416,6 +432,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
{
//printf("Inside gfx942, with 16x16 128x128x128 \n");
// prioritize MFMA to avoid LDS write conflicts
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
@@ -478,6 +495,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
else
{
//printf("Inside gfx942, with 16x16 otherwise \n");
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
@@ -505,6 +523,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
else
{
//printf("Inside gfx950 or gfx942, with other than 16x16 any block sizes \n");
if constexpr((A_LDS_Read_Inst_Num / 2 >
A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num))
{