mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Added Int4 mixed batch gemm support (#1839)
* remove redundant kernels. * added batched_gemm_xdl_fp16int4_b_scale_v3 * Enabled the split K. * added the batched_gemm_b_scale ckProfiler, meet function issue * fix some typo * fix ckProfiler build issue * fix some bugs * updated some debug info * comment some code * Fix * fixed some bugs and refactor the code * fixed a function bug. * formatted files. * formatted * uncommented the ckProfiler CMakeLists * fixed. * fix ckProfiler for batched_gemm_b_scale --------- Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -44,6 +44,48 @@ struct DeviceBatchedGemm : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceBatchedGemmV2BScale : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t StrideScaleB,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB,
|
||||
ck::index_t BatchStrideC,
|
||||
ck::index_t BatchStrideScaleB,
|
||||
const void* p_b_scale,
|
||||
ck::index_t Batch,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual bool GetPermuteB() = 0;
|
||||
virtual ck::index_t GetKPerBlock() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user