mirror of https://github.com/ROCm/composable_kernel.git
Add SplitK support into Batched GEMM V3 (#1729)
* add bmm api
* add bf16 multi_d
* add ckProfiler for bf16
* add ckProfiler files
* add more instances; fixed 64-bit index issue
* fixed naming
* enabled batched Ds
* use long_index for ds offsets
* clean
* add bmm fp8 ckProfiler
* Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
* Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp
* Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
* Update profiler/src/profile_gemm_universal_batched.cpp
* Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* refactor batch offset func
* add splitk support into bmm_v3
* clean
* format
* fixed
* fix

Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
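Background for the diff below: split-K partitions the GEMM reduction dimension K into KBatch chunks of (roughly) KRead elements each, one per slice of the grid's z dimension; each slice reduces its own chunk, and the partial results are combined in the output buffer. The following host-side sketch illustrates only that partitioning scheme; the function name and the serial loop are illustrative assumptions, not code from this PR.

#include <algorithm>
#include <vector>

// Hypothetical reference for split-K GEMM: each k_id plays the role of one
// grid z-slice (blockIdx.z in the kernels below). On the GPU the per-split
// accumulations into C would be combined, e.g. via atomic adds.
void gemm_splitk_reference(const std::vector<float>& A, // M x K, row-major
                           const std::vector<float>& B, // K x N, row-major
                           std::vector<float>& C,       // M x N, zero-initialized
                           int M, int N, int K, int KBatch)
{
    const int KRead = (K + KBatch - 1) / KBatch; // K chunk owned by one split
    for(int k_id = 0; k_id < KBatch; ++k_id)
    {
        const int k_begin = k_id * KRead;
        const int k_end   = std::min(K, k_begin + KRead); // last split takes the tail
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = k_begin; k < k_end; ++k)
                    C[m * N + n] += A[m * K + k] * B[k * N + n];
    }
}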
@@ -41,7 +41,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -76,7 +76,7 @@ __global__ void
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
     __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
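In both kernel entry points above, the only change is that the split index is now passed into SplitKBatchOffset explicitly (as blockIdx.z) rather than read from the launch geometry inside the struct, so the offset computation becomes launch-independent and host-testable. A minimal sketch of such a launch, assuming hypothetical names (launch_splitk_kernel; the x/y grid decomposition is likewise an assumption):

#include <hip/hip_runtime.h>

// Hypothetical launch helper: gridDim.z carries the split-K index, so inside
// the kernel blockIdx.z is exactly the k_id handed to SplitKBatchOffset.
template <typename Kernel, typename Karg>
void launch_splitk_kernel(Kernel kernel, const Karg& karg,
                          unsigned grid_x, unsigned grid_y, unsigned k_batch,
                          unsigned block_size, hipStream_t stream)
{
    const dim3 grid(grid_x, grid_y, k_batch); // z dimension = number of K splits
    hipLaunchKernelGGL(kernel, grid, dim3(block_size), 0, stream, karg);
}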
@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
 
     struct SplitKBatchOffset
     {
-        __device__ SplitKBatchOffset(Argument& karg)
+        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
         {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead;
+                a_k_split_offset = k_id * karg.KRead;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
             {
-                a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
+                a_k_split_offset = k_id * karg.KRead * karg.StrideA;
             }
 
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
             {
-                b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
+                b_k_split_offset = k_id * karg.KRead * karg.StrideB;
             }
             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
             {
-                b_k_split_offset = blockIdx.z * karg.KRead;
+                b_k_split_offset = k_id * karg.KRead;
             }
 
-            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
+            if(k_id < karg.KBatch - 1)
             {
                 karg.K = karg.KRead;
             }
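The offset arithmetic in SplitKBatchOffset follows directly from the layouts: when K is the contiguous dimension of a tensor (row-major A, column-major B), advancing k_id * KRead elements along K is a flat offset of k_id * KRead; when K is the strided dimension (column-major A, row-major B), each K step advances by the leading stride, giving k_id * KRead * stride. All splits except the last then see an effective K of KRead. A host-side restatement under those assumptions (helper names are hypothetical; the remainder handling for the last split is inferred, since its else branch falls outside the quoted hunk):

#include <cstdint>

// Element offset into A or B for split k_id (hypothetical helper).
// k_contiguous: true if K is the fastest-varying dimension of the tensor.
int64_t split_k_offset(int64_t k_id, int64_t k_read, int64_t stride, bool k_contiguous)
{
    return k_contiguous ? k_id * k_read : k_id * k_read * stride;
}

// Effective K seen by split k_id: KRead for all but the last split, which
// presumably reduces the remainder K - (KBatch - 1) * KRead.
int64_t split_k_extent(int64_t k_id, int64_t K, int64_t k_read, int64_t k_batch)
{
    return (k_id < k_batch - 1) ? k_read : K - (k_batch - 1) * k_read;
}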