mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5050 (commit 033dad7)
[CK TILE] Skip work if any of Grouped GEMM groups M/N/K are zero. (#5050) ## Motivation It's common in MoE workloads that some experts receive zero tokens, which would result in some of the dimensions equal to zero. Currently we handle such case only for non-persistent kernels where we have all GEMMs information beforehand on host - we validate this during creation of kernel arguments. However for the "dynamic" input path (persistent kernel) this information is not available before kernel launch. Thus we have to validate this during kernel execution. The goal is to add this validation. ## Technical Details Skip work if any of Grouped GEMM groups M/N/K are zero for persistent kernel path. ## Test Plan Add unit-tests which cover "dynamic" inputs with zero dims for persistent kernel execution path. ## Test Result All tests pass. ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
2c3f9bfa52
commit
b09ce811d5
@@ -507,6 +507,12 @@ struct GroupedGemmKernel
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
|
||||
// Early exit if no work to do.
|
||||
if(kargs.group_karg.M == 0 || kargs.group_karg.N == 0 || kargs.group_karg.K == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0,
|
||||
@@ -534,6 +540,22 @@ struct GroupedGemmKernel
|
||||
const auto& k_batch = kargs.k_batch;
|
||||
const auto block_start = cum_grid_size;
|
||||
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
|
||||
|
||||
// Early exit if no work to do.
|
||||
// If M or N is zero, TilePartitioner::GridSize(kargs.M, kargs.N) returns zero,
|
||||
// so this group contributes no blocks and cum_grid_size is unchanged. The group
|
||||
// is naturally skipped by the block_id < cum_grid_size check below.
|
||||
if(kargs.K == 0)
|
||||
{
|
||||
// Advance only if this workgroup was assigned to this group's range,
|
||||
// matching the pattern of the normal while loop below.
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
block_id += grid_size;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm test_grouped_gemm.cpp)
|
||||
|
||||
add_custom_target(test_ck_tile_grouped_gemm)
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_f16 test_grouped_gemm_f16.cpp)
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_bf16 test_grouped_gemm_bf16.cpp)
|
||||
|
||||
add_dependencies(test_ck_tile_grouped_gemm
|
||||
test_ck_tile_grouped_gemm_f16
|
||||
test_ck_tile_grouped_gemm_bf16)
|
||||
endif()
|
||||
|
||||
41
test/ck_tile/grouped_gemm/test_grouped_gemm_bf16.cpp
Normal file
41
test/ck_tile/grouped_gemm/test_grouped_gemm_bf16.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemmBF16 : public TestCkTileGroupedGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmBF16, KernelTypes);
|
||||
|
||||
#define TEST_CKTILE_GGEMM_SUITE_NAME TestCkTileGroupedGemmBF16
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
@@ -10,7 +10,6 @@
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
@@ -21,25 +20,22 @@ using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, False>,
|
||||
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, False>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, True>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>,
|
||||
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, False>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, True>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, False>
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemm, KernelTypes);
|
||||
template <typename Tuple>
|
||||
class TestCkTileGroupedGemmF16 : public TestCkTileGroupedGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmF16, KernelTypes);
|
||||
|
||||
#define TEST_CKTILE_GGEMM_SUITE_NAME TestCkTileGroupedGemmF16
|
||||
|
||||
#include "test_grouped_gemm_ut_cases.inc"
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
TYPED_TEST(TEST_CKTILE_GGEMM_SUITE_NAME, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 1;
|
||||
@@ -16,19 +16,19 @@ TYPED_TEST(TestCkTileGroupedGemm, Basic)
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
Ms.push_back(64 + 64 * i);
|
||||
Ns.push_back(128 + 64 * i);
|
||||
Ks.push_back(64 + 32 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemm, SplitK)
|
||||
TYPED_TEST(TEST_CKTILE_GGEMM_SUITE_NAME, SplitK)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 2;
|
||||
@@ -41,14 +41,64 @@ TYPED_TEST(TestCkTileGroupedGemm, SplitK)
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
Ms.push_back(64 + 64 * i);
|
||||
Ns.push_back(128 + 64 * i);
|
||||
Ks.push_back(64 + 32 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
// Verify that groups with M=0 are silently skipped (primary MoE scenario where some
|
||||
// experts receive zero tokens) and that non-zero groups produce correct results.
|
||||
TYPED_TEST(TEST_CKTILE_GGEMM_SUITE_NAME, ZeroM)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 1;
|
||||
|
||||
const std::vector<int> Ms = {256, 0, 256, 0, 256, 256, 0, 256};
|
||||
const std::vector<int> Ns = {256, 256, 256, 256, 256, 256, 256, 256};
|
||||
const std::vector<int> Ks = {512, 512, 512, 512, 512, 512, 512, 512};
|
||||
std::vector<int> stride_As(group_count, 0);
|
||||
std::vector<int> stride_Bs(group_count, 0);
|
||||
std::vector<int> stride_Cs(group_count, 0);
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
// Verify that groups with K=0 produce all-zero output and that surrounding groups
|
||||
// with non-zero K are unaffected.
|
||||
TYPED_TEST(TEST_CKTILE_GGEMM_SUITE_NAME, ZeroK)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 1;
|
||||
|
||||
const std::vector<int> Ms = {256, 256, 256, 256, 256, 256, 256, 256};
|
||||
const std::vector<int> Ns = {256, 256, 256, 256, 256, 256, 256, 256};
|
||||
const std::vector<int> Ks = {512, 512, 512, 0, 0, 512, 512, 512};
|
||||
std::vector<int> stride_As(group_count, 0);
|
||||
std::vector<int> stride_Bs(group_count, 0);
|
||||
std::vector<int> stride_Cs(group_count, 0);
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
// Verify that a mix of M=0, N=0, and K=0 groups all behave correctly together.
|
||||
TYPED_TEST(TEST_CKTILE_GGEMM_SUITE_NAME, ZeroMixed)
|
||||
{
|
||||
const int group_count = 8;
|
||||
const int kbatch = 1;
|
||||
|
||||
const std::vector<int> Ms = {256, 0, 256, 256, 512, 256, 0, 256};
|
||||
const std::vector<int> Ns = {256, 256, 512, 0, 256, 256, 256, 256};
|
||||
const std::vector<int> Ks = {512, 512, 512, 512, 512, 0, 512, 512};
|
||||
std::vector<int> stride_As(group_count, 0);
|
||||
std::vector<int> stride_Bs(group_count, 0);
|
||||
std::vector<int> stride_Cs(group_count, 0);
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count);
|
||||
}
|
||||
|
||||
@@ -32,14 +32,14 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
|
||||
struct GroupedGemKernelParam_Mfma
|
||||
{
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
static const bool kPadK = false;
|
||||
static const bool kPadM = true;
|
||||
static const bool kPadN = true;
|
||||
static const bool kPadK = true;
|
||||
|
||||
static const int kBlockPerCu = 1;
|
||||
static const ck_tile::index_t M_Tile = 256;
|
||||
static const ck_tile::index_t N_Tile = 256;
|
||||
static const ck_tile::index_t K_Tile = 64;
|
||||
static const ck_tile::index_t M_Tile = 64;
|
||||
static const ck_tile::index_t N_Tile = 64;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static const ck_tile::index_t M_Warp = 2;
|
||||
static const ck_tile::index_t N_Warp = 2;
|
||||
@@ -52,9 +52,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
|
||||
struct GroupedGemKernelParam_Wmma : public GroupedGemKernelParam_Mfma
|
||||
{
|
||||
static const ck_tile::index_t M_Tile = 128;
|
||||
static const ck_tile::index_t N_Tile = 128;
|
||||
static const ck_tile::index_t K_Tile = 64;
|
||||
static const ck_tile::index_t M_Tile = 64;
|
||||
static const ck_tile::index_t N_Tile = 64;
|
||||
static const ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 16;
|
||||
static const ck_tile::index_t N_Warp_Tile = 16;
|
||||
@@ -131,14 +131,20 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
|
||||
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
// Use the filtered kargs (zero-dim groups are excluded by MakeKargs) to derive
|
||||
// the correct grid size and group count — not the raw gemm_descs vector.
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
if(kargs.empty())
|
||||
return;
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
const dim3 grids = dim3(kargs.back().block_end, 1, 1);
|
||||
|
||||
ck_tile::hip_check_error(
|
||||
hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
@@ -155,7 +161,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
kargs.size()));
|
||||
}
|
||||
|
||||
template <typename GroupedGemKernelParam, typename ALayout, typename BLayout, typename CLayout>
|
||||
@@ -296,11 +302,13 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
// Use stride 1, in case the dim equals to zero
|
||||
return std::max(col, std::size_t{1});
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
// Use stride 1, in case the dim equals to zero
|
||||
return std::max(row, std::size_t{1});
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -332,7 +340,7 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = f_get_default_stride(M, N, stride_As[i], ALayout{});
|
||||
stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{});
|
||||
stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{});
|
||||
stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{});
|
||||
|
||||
@@ -442,17 +450,27 @@ class TestCkTileGroupedGemm : public ::testing::Test
|
||||
bool pass{true};
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
// Groups with M=0 or N=0 produce no output — skip validation.
|
||||
// K=0 groups do produce output (all zeros) and are validated normally.
|
||||
if(Ms[i] == 0 || Ns[i] == 0)
|
||||
continue;
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value = std::abs(static_cast<float>(*std::max_element(
|
||||
// Use max absolute value (not algebraic max) to calibrate atol.
|
||||
// The absolute threshold in calculate_rtol_atol scales with this value,
|
||||
// so using the algebraic max (which may be a small positive number when
|
||||
// most outputs are negative) would produce a near-zero atol. Near-zero
|
||||
// reference elements then have no tolerance headroom for the ~1 ULP
|
||||
// error introduced by SplitK atomicAdd accumulation.
|
||||
const float max_accumulated_value = std::accumulate(
|
||||
c_m_n_host_ref.mData.begin(),
|
||||
c_m_n_host_ref.mData.end(),
|
||||
[](CDataType a, CDataType b) {
|
||||
return std::abs(static_cast<float>(a)) < std::abs(static_cast<float>(b));
|
||||
})));
|
||||
0.0f,
|
||||
[](float acc, auto v) { return std::max(acc, std::abs(static_cast<float>(v))); });
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
|
||||
Reference in New Issue
Block a user