mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
Grouped Gemm + SplitK + simplified Kernel Args (#669)
* simplify karg in device/grid split-k op
* fix mk_kn_mn instances
* add more instances
* B2C with 3D grid for KSplit
* Remove unused code.
* Use default B2C (3D grid) in grid gemm v2r4r2.
* Device gemm splitk use B2C map.
* Device GroupedGemmXdlSplitKCShuffle
* Example for GroupedGemm Xdl SplitK
* Introduce Device GroupedGemmSplitK
* Fix updating kbatch size.
* Add instance mk-nk-mn
* Enable set kbatch in profiler.
* Add GGemmSplitK mk-kn-mn instances
* Add more instances & split into multiple files.
* minor fix
* tuning
* clean
* disabled failed instances
* use pipe v2
* Ignore arg on not supported arch.
* fix warning
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: Adam Osewski <aosewski@amd.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Jing Zhang <jizhan@amd.com>
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
[ROCm/composable_kernel commit: 8bb2bb4a05]
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
|
||||
@@ -39,7 +40,8 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs)
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1)
|
||||
{
|
||||
|
||||
bool pass = true;
|
||||
@@ -96,8 +98,6 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
|
||||
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
|
||||
}
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -132,13 +132,12 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
|
||||
b_device_buf.emplace_back(
|
||||
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
|
||||
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
|
||||
|
||||
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
|
||||
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
|
||||
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
|
||||
c_device_buf[i]->SetZero();
|
||||
|
||||
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
|
||||
@@ -197,6 +196,28 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
{
|
||||
std::string gemm_name = gemm_ptr->GetTypeString();
|
||||
|
||||
if(kbatch > 1)
|
||||
{
|
||||
using DeviceOpSplitK =
|
||||
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<>,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
|
||||
{
|
||||
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
|
||||
->SetKBatchSize(argument_ptr.get(), kbatch);
|
||||
}
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user