mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add SplitK support into Batched GEMM V3 (#1729)
* add bmm api * add bf16 multi_d * add ckProfiler for bf16 * add ckProfiler files * add more instance; fixed 64bit index issue * fixed naming * enabled batched Ds * use long_index for ds offsets * clean * add bmm fp8 ckProfiler * Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update profiler/src/profile_gemm_universal_batched.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * clean * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * refactor batch offset func * add splitk suppport into bmm_v3 * clean * clean * format * fixed * fix --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -41,12 +41,15 @@ __global__ void
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t g_idx = blockIdx.z % karg.Batch;
|
||||
const index_t k_idx = blockIdx.z / karg.Batch;
|
||||
|
||||
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
// D pointer
|
||||
@@ -54,8 +57,8 @@ __global__ void
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
@@ -87,12 +90,15 @@ __global__ void
|
||||
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
const index_t g_idx = blockIdx.z % karg.Batch;
|
||||
const index_t k_idx = blockIdx.z / karg.Batch;
|
||||
|
||||
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
|
||||
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
|
||||
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
|
||||
|
||||
// populate pointer, desc for Ds
|
||||
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
|
||||
// D pointer
|
||||
@@ -100,8 +106,8 @@ __global__ void
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_a_grid + a_batch_offset,
|
||||
karg.p_b_grid + b_batch_offset,
|
||||
karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
|
||||
karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid + c_batch_offset,
|
||||
p_shared_0,
|
||||
@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
index_t Batch_,
|
||||
AElementwiseOperation a_element_op_,
|
||||
BElementwiseOperation b_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
CElementwiseOperation c_element_op_,
|
||||
index_t KBatch_)
|
||||
: GridwiseGemm::Argument{p_a_grid_,
|
||||
p_b_grid_,
|
||||
p_ds_grid_,
|
||||
@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
StrideB_,
|
||||
StrideDs_,
|
||||
StrideE_,
|
||||
1,
|
||||
KBatch_,
|
||||
a_element_op_,
|
||||
b_element_op_,
|
||||
c_element_op_},
|
||||
@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1)
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch);
|
||||
std::tie(gdx, gdy, gdz) =
|
||||
GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
@@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch = 1)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
Batch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
c_element_op,
|
||||
KBatch};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
index_t BatchStrideE,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
CElementwiseOperation c_element_op,
|
||||
index_t KBatch = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
Batch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
c_element_op,
|
||||
KBatch);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
|
||||
Reference in New Issue
Block a user