Added Int4 mixed batch gemm support (#1839)

* remove redundant kernels.

* added batched_gemm_xdl_fp16int4_b_scale_v3

* Enabled the split K.

* added the batched_gemm_b_scale ckProfiler, meet function issue

* fix some typo

* fix ckProfiler build issue

* fix some bugs

* updated some debug info

* comment some code

* Fix

* fixed some bugs and refactor the code

* fixed a function bug.

* formatted files.

* formatted

* uncommented the ckProfiler CMakeLists

* fixed.

* fix ckProfiler for batched_gemm_b_scale

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
Mingtao Gu
2025-02-10 11:17:02 +08:00
committed by GitHub
parent a8c5bd9b9a
commit d9f1ead347
14 changed files with 2678 additions and 50 deletions

View File

@@ -37,7 +37,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -70,7 +70,7 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -638,45 +638,45 @@ struct GridwiseGemm_xdl_cshuffle_v3
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
if(k_id < (karg.KBatch - 1))
{
karg.K = karg.KRead;
}
@@ -687,7 +687,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
c_reduce_offset = k_id * karg.M * karg.N;
}
else
{