mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] Switch to universal gemm for batched and grouped gemms (#1919)
* switch to universal gemm for batched and grouped gemms
* added reviewer comments
* fixed grouped gemm tests
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
{
|
||||
const int group_count = 16;
|
||||
const int group_count = 8;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
@@ -13,8 +13,8 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(128 + 128 * i);
|
||||
Ks.push_back(128 + 64 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(256 + 64 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
|
||||
@@ -44,65 +44,10 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
static const ck_tile::index_t K_Warp_Tile = 8;
|
||||
};
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
CodegenGemmPipeline<ALayout, BLayout, CLayout>::BlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
CodegenPipelineProblem<ALayout, BLayout, CLayout>::TransposeC>>;
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline<ALayout, BLayout, CLayout>,
|
||||
GemmEpilogue<ALayout, BLayout, CLayout>>;
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
std::size_t GetWorkspaceSize(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs;
|
||||
std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return Kernel<std::nullptr_t, std::nullptr_t, std::nullptr_t>::GetWorkSpaceSize(gemm_descs);
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
@@ -110,35 +55,140 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const ck_tile::stream_config& s,
|
||||
void* p_workspace_)
|
||||
{
|
||||
using GroupedGemmKernel = Kernel<ALayout, BLayout, CLayout>;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
auto arguments = GroupedGemmKernel::MakeKargs(gemm_descs);
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
const dim3 grids = GroupedGemmKernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = GroupedGemmKernel::BlockSize();
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<GroupedGemKernelParam::M_Tile,
|
||||
GroupedGemKernelParam::N_Tile,
|
||||
GroupedGemKernelParam::K_Tile>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::K_Warp>,
|
||||
ck_tile::sequence<GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(
|
||||
p_workspace_,
|
||||
arguments.data(),
|
||||
arguments.size() * sizeof(typename GroupedGemmKernel::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
using Traits = ck_tile::TileGemmTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GroupedGemKernelParam::kPadM,
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GroupedGemKernelParam::K_Tile;
|
||||
const ck_tile::index_t K_split =
|
||||
(gemm_descs[0].K + k_grain - 1) / k_grain * GroupedGemKernelParam::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GroupedGemKernelParam::M_Warp,
|
||||
GroupedGemKernelParam::N_Warp,
|
||||
GroupedGemKernelParam::M_Warp_Tile,
|
||||
GroupedGemKernelParam::N_Warp_Tile,
|
||||
GroupedGemKernelParam::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(p_workspace_,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \""
|
||||
<< tail_num << "\" which is not supported! PrefetchStages: "
|
||||
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GroupedGemKernelParam::kBlockPerCu>(
|
||||
GroupedGemmKernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(p_workspace_),
|
||||
gemm_descs.size()));
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -243,12 +293,14 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
// TODO add support for kbatch > 1
|
||||
static constexpr ck_tile::index_t k_batch = 1;
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
{p_a, p_b, p_c, k_batch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(gemm_descs));
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, false}, gemm_workspace.GetDeviceBuffer());
|
||||
|
||||
Reference in New Issue
Block a user