[GEMM] Fix bWarpTile issue and remove redundant pipeline in BlockGemmPipeline

This commit is contained in:
YC Lin
2025-04-21 16:44:23 +00:00
parent 77a96c7a82
commit 8a6cc0e94b
3 changed files with 88 additions and 142 deletions

View File

@@ -43,7 +43,7 @@ struct BlockGemmASmemBSmemCReg
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
#if defined(ENABLE_INSTRUCTION_SCH)
#if defined(ENABLE_PREFETCH)
// A block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{
@@ -92,16 +92,16 @@ struct BlockGemmASmemBSmemCReg
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
ALdsTile a_warp_tile_;
ALdsTile b_warp_tile_;
ALdsTile aWarpTile;
BLdsTile bWarpTile;
// Prefetch from LDS to warp register.
// Loads one tile from each of the supplied A and B block windows with load_tile and
// keeps the results in the aWarpTile / bWarpTile members. When ENABLE_PREFETCH is
// defined, the block GEMM body reads its per-iteration A/B warp tensors out of those
// members via get_y_sliced_thread_data instead of loading from the warp windows.
//   a_block_window: tile window the A tile is loaded from.
//   b_block_window: tile window the B tile is loaded from.
// NOTE(review): this listing is a commit diff rendered without +/- markers, so two
// variants of each load appear below. The load_tile(dst, window) calls on
// a_warp_tile_ / b_warp_tile_ look like the pre-change lines and the assignments to
// aWarpTile / bWarpTile the post-change lines -- confirm against the repository.
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
// Variant passing the destination tile as the first argument.
load_tile(a_warp_tile_, a_block_window);
load_tile(b_warp_tile_, b_block_window);
// Variant assigning the tile returned by load_tile to the member.
aWarpTile = load_tile(a_block_window);
bWarpTile = load_tile(b_block_window);
}
#endif
@@ -178,23 +178,23 @@ struct BlockGemmASmemBSmemCReg
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
#if defined(ENABLE_INSTRUCTION_SCH)
#if defined(ENABLE_PREFETCH)
#pragma message ("local data share prefetch")
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
#else
load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter));
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
#if defined(ENABLE_INSTRUCTION_SCH)
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
#if defined(ENABLE_PREFETCH)
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
#else
load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter));
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -305,22 +305,22 @@ struct BlockGemmASmemBSmemCReg
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
#if defined(ENABLE_INSTRUCTION_SCH)
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
#if defined(ENABLE_PREFETCH)
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
#else
load_tile(a_warp_tensor, a_warp_windows(mIter)(kIter));
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
#endif
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
#if defined(ENABLE_INSTRUCTION_SCH)
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
#if defined(ENABLE_PREFETCH)
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
#else
load_tile(b_warp_tensor, b_warp_windows(nIter)(kIter));
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
#endif
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;

View File

@@ -17,19 +17,31 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if defined(ADJUST_BLOCK_TILE_SHAPE)
constexpr index_t kMWarp = 2;
constexpr index_t kNWarp = 2;
#else
constexpr index_t kMWarp = 4;
constexpr index_t kNWarp = 1;
#endif
#if defined(NAIVE_IMPLEMENTATION)
#pragma message ("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{},
kMWarp,
kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{},
kMWarp,
kNWarp);
}
#elif defined(USING_MFMA_32x32x_8x2)
#pragma message ("mfma m32 n32 k16")
@@ -37,13 +49,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{},
kMWarp,
kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{},
kMWarp,
kNWarp);
}
#elif defined(USING_MFMA_16x16x16)
#pragma message ("mfma m16 n16 k16")
@@ -51,13 +67,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{},
kMWarp,
kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{},
kMWarp,
kNWarp);
}
#elif defined(USING_MFMA_16x16x_16x2)
#pragma message ("mfma m16 n16 k32")
@@ -65,13 +85,17 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{},
kMWarp,
kNWarp);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 2, 2);
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{},
kMWarp,
kNWarp);
}
#endif
else

View File

@@ -42,15 +42,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
}
#if defined(ENABLE_INSTRUCTION_SCH)
static constexpr index_t APackedSize =
static constexpr index_t kPackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
@@ -72,31 +65,31 @@ struct BlockGemmPipelineAGmemBGmemCReg
constexpr index_t AB_LDS_RW_Width = GetSmemPack();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock /
(kBlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num :
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? A_LDS_Read_Inst_Num :
A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num :
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? B_LDS_Read_Inst_Num :
B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
@@ -109,9 +102,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
@@ -266,15 +259,15 @@ struct BlockGemmPipelineAGmemBGmemCReg
{0, 0},
b_copy_dram_window.get_tile_distribution());
#if defined(ENABLE_INSTRUCTION_SCH)
#if defined(ENABLE_PREFETCH)
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0},
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0},
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()));
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0},
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()));
#else
// A LDS tile for block GEMM
@@ -303,23 +296,23 @@ struct BlockGemmPipelineAGmemBGmemCReg
ABlockTile a_block_tile;
BBlockTile b_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
// -------------------------------------------------------------------------------------
// Gemm pipeline start
#if defined(ENABLE_INSTRUCTION_SCH)
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);
#if defined(ENABLE_PREFETCH)
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// Prefetch
// Global read 0
load_tile(a_block_tile, a_copy_dram_window);
load_tile(b_block_tile, b_copy_dram_window);
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
if (num_loop > 1)
{
@@ -331,8 +324,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
store_tile(b_copy_lds_window, b_block_tile);
// Global read 0
load_tile(a_block_tile, a_copy_dram_window);
load_tile(b_block_tile, b_copy_dram_window);
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
@@ -357,8 +350,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
store_tile(b_copy_lds_window, b_block_tile);
// Global read 0
load_tile(a_block_tile, a_copy_dram_window);
load_tile(b_block_tile, b_copy_dram_window);
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
@@ -369,12 +362,14 @@ struct BlockGemmPipelineAGmemBGmemCReg
// Prefetch from LDS to warp register in block gemm
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
#if defined(ENABLE_INSTRUCTION_SCH)
HotLoopScheduler();
#endif
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 2));
iCounter += 1;
} while(iCounter < (num_loop - 2));
}
// Tail
@@ -388,84 +383,12 @@ struct BlockGemmPipelineAGmemBGmemCReg
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
#elif defined(ENABLE_PREFETCH)
// Prefetch
// Global read 0
load_tile(a_block_tile, a_copy_dram_window);
load_tile(b_block_tile, b_copy_dram_window);
{
// Move to 1
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// Initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
store_tile(a_copy_lds_window, a_block_tile);
// Global read 1
load_tile(a_block_tile, a_copy_dram_window);
// LDS write 0
store_tile(b_copy_lds_window, b_block_tile);
// Global read 1
load_tile(b_block_tile, b_copy_dram_window);
}
index_t iCounter = num_loop - 2;
do
{
block_sync_lds();
// GEMM i
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// Move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
// LDS write i + 1
store_tile(a_copy_lds_window, a_block_tile);
// Global read i + 2
load_tile(a_block_tile, a_copy_dram_window);
// LDS write i + 1
store_tile(b_copy_lds_window, b_block_tile);
// Global read i + 2
load_tile(b_block_tile, b_copy_dram_window);
iCounter--;
} while(iCounter > 0);
// Tail
{
block_sync_lds();
// GEMM num_loop - 2
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
// LDS write num_loop - 1
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
// GEMM num_loop - 1
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
#else
// non-prefetch
load_tile(a_block_tile, a_copy_dram_window);
load_tile(b_block_tile, b_copy_dram_window);
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);
@@ -477,11 +400,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
while (iCounter > 0)
{
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
load_tile(a_block_tile, a_copy_dram_window);
load_tile(b_block_tile, b_copy_dram_window);
a_block_tile = load_tile(a_copy_dram_window);
b_block_tile = load_tile(b_copy_dram_window);
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
store_tile(a_copy_lds_window, a_block_tile);
store_tile(b_copy_lds_window, b_block_tile);