[CK TILE] Block universal gemm lds<->vgpr optimizations (#1906)

* [CK TILE] Block universal gemm lds<->vgpr optimizations

* Rebase

* Fixes
Author: Bartłomiej Kocot
Date:   2025-02-27 10:36:28 +01:00 (committed by GitHub)
Parent: e9ee568681
Commit: bf1e17007e

7 changed files with 305 additions and 406 deletions


@@ -129,34 +129,34 @@ struct GemmKernel
                                      const std::size_t k_id = blockIdx.z)
     {
         constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
-        const index_t K_t = kargs.k_batch * K1;
-        const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
+        const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
+        const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
         if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
         {
-            a_k_split_offset = k_id * KRead;
+            a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
         }
         else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
         {
-            a_k_split_offset = k_id * KRead * kargs.stride_A;
+            a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A);
         }
         if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
         {
-            b_k_split_offset = k_id * KRead * kargs.stride_B;
+            b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B);
         }
         else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
         {
-            b_k_split_offset = k_id * KRead;
+            b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead);
         }
         if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
         {
-            splitted_k = KRead;
+            splitted_k = __builtin_amdgcn_readfirstlane(KRead);
         }
         else
         {
-            splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
+            splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1));
         }
     }
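
The pattern repeated throughout these hunks is the core of the VGPR optimization: values that are uniform across a wavefront (k-split offsets, loop counts, tile indices) are passed through __builtin_amdgcn_readfirstlane, which broadcasts lane 0's value and thereby tells the compiler the result is wave-uniform, so it can live in a scalar register (SGPR) instead of occupying a vector register (VGPR) in every lane. A minimal standalone sketch of the idiom, assuming a HIP/AMD target; the kernel and names below are hypothetical, not CK code:

#include <hip/hip_runtime.h>

// Hypothetical kernel illustrating the readfirstlane idiom; not CK code.
__global__ void split_k_copy(const int* in, int* out, int K, int k_batch)
{
    // (K + k_batch - 1) / k_batch is identical for every lane, but the
    // compiler may not prove it and would keep the value in a per-lane VGPR.
    // readfirstlane broadcasts lane 0's copy, pinning the result to an SGPR
    // and freeing VGPRs for the actual compute.
    const int k_read = __builtin_amdgcn_readfirstlane((K + k_batch - 1) / k_batch);
    const int offset = __builtin_amdgcn_readfirstlane(static_cast<int>(blockIdx.z) * k_read);
    const int i      = offset + static_cast<int>(threadIdx.x);
    if(i < K)
        out[i] = in[i];
}
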
@@ -523,7 +523,8 @@ struct GemmKernel
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
-        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -574,7 +575,8 @@ struct GemmKernel
         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
-        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
+        const index_t num_loop = __builtin_amdgcn_readfirstlane(
+            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -593,7 +595,8 @@ struct GemmKernel
     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
-        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
+        const auto blockId  = __builtin_amdgcn_readfirstlane(blockIdx.x);
+        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
         const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
         const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
@@ -607,12 +610,12 @@ struct GemmKernel
         // allocate LDS
         __shared__ char smem_ptr_0[GetSmemSize()];
-        __shared__ char smem_ptr_1[GetSmemSize()];
         if(kargs.k_batch == 1)
         {
             if constexpr(GemmPipeline::DoubleSmemBuffer == true)
             {
+                __shared__ char smem_ptr_1[GetSmemSize()];
                 RunGemm2LDS(a_ptr,
                             b_ptr,
                             c_ptr,
@@ -637,6 +640,7 @@ struct GemmKernel
         {
             if constexpr(GemmPipeline::DoubleSmemBuffer == true)
             {
+                __shared__ char smem_ptr_1[GetSmemSize()];
                 RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
                                                                b_ptr,
                                                                c_ptr,
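
The second half of the commit scopes the second LDS buffer inside the if constexpr(GemmPipeline::DoubleSmemBuffer == true) branches instead of allocating it unconditionally. In a template, a statically sized __shared__ array declared in a discarded if constexpr branch is never instantiated, so pipelines that use a single LDS buffer no longer reserve space for the second one. A sketch of the pattern, assuming HIP; the names and size below are hypothetical, not CK code:

#include <hip/hip_runtime.h>

// Hypothetical kernel illustrating conditional LDS allocation; not CK code.
template <bool DoubleSmemBuffer>
__global__ void pipeline_kernel()
{
    constexpr int SmemBytes = 1024; // stand-in for GetSmemSize()
    __shared__ char smem_ptr_0[SmemBytes];
    if constexpr(DoubleSmemBuffer)
    {
        // Instantiated only when DoubleSmemBuffer == true; in the false case
        // the declaration is discarded, so the kernel's LDS footprint shrinks
        // accordingly, which can raise occupancy.
        __shared__ char smem_ptr_1[SmemBytes];
        // ... ping-pong loads between smem_ptr_0 and smem_ptr_1 ...
        (void)smem_ptr_1;
    }
    (void)smem_ptr_0;
}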