mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Added Int4 mixed batch gemm support (#1839)
* remove redundant kernels. * added batched_gemm_xdl_fp16int4_b_scale_v3 * Enabled the split K. * added the batched_gemm_b_scale ckProfiler, meet function issue * fix some typo * fix ckProfiler build issue * fix some bugs * updated some debug info * comment some code * Fix * fixed some bugs and refactor the code * fixed a function bug. * formatted files. * formatted * uncommented the ckProfiler CMakeLists * fixed. * fix ckProfiler for batched_gemm_b_scale --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -37,7 +37,7 @@ __global__ void
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
@@ -70,7 +70,7 @@ __global__ void
|
||||
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
|
||||
@@ -638,45 +638,45 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
|
||||
__device__ SplitKBatchOffset(Argument& karg)
|
||||
__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
if constexpr(!PermuteB)
|
||||
{
|
||||
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead / BPackedSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
const int k0_offset = karg.KRead * karg.N;
|
||||
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
|
||||
b_k_split_offset = k_id * k0_offset / BPackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate B scale offset
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
|
||||
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
|
||||
}
|
||||
|
||||
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
|
||||
if(k_id < (karg.KBatch - 1))
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
@@ -687,7 +687,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
if(karg.IsReduceAdd())
|
||||
{
|
||||
c_reduce_offset = blockIdx.z * karg.M * karg.N;
|
||||
c_reduce_offset = k_id * karg.M * karg.N;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user