mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Minor splitk bugfix for gemms and conv (#3387)
* fix for splitk if splitk < grid * add different splitk implementation * minor bugfix for streamk gemm * Add test --------- Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -323,22 +323,38 @@ struct UniversalGemmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
// This structure distributes work evenly among splitk workgroups
|
||||
// It's based on a principle that if there is enough work to fill all workgroups,
|
||||
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
|
||||
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
|
||||
// and leave the potential tail for the last ((splitk - 1)-indexed) workgroup.
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t num_all = amd_wave_read_first_lane(
|
||||
kargs.K / K1); // num of all loops not including potential tail
|
||||
index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
|
||||
num_full = num_full == 0 ? kargs.k_batch : num_full;
|
||||
|
||||
const index_t num_full_iters =
|
||||
amd_wave_read_first_lane(std::max(integer_divide_ceil(num_all, kargs.k_batch), 1));
|
||||
const index_t full_k_read = num_full_iters * K1;
|
||||
const index_t partial_k_read = (num_full_iters - 1) * K1;
|
||||
|
||||
static_for<0, NumATensor, 1>{}([&](auto index) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
|
||||
{
|
||||
as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
|
||||
as_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
|
||||
{
|
||||
as_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
|
||||
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read) *
|
||||
kargs.stride_As[index]);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -347,21 +363,30 @@ struct UniversalGemmKernel
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
|
||||
{
|
||||
bs_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
|
||||
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read) *
|
||||
kargs.stride_Bs[index]);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
|
||||
{
|
||||
bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
|
||||
bs_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read);
|
||||
}
|
||||
});
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
if(k_id == kargs.k_batch - 1)
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(KRead);
|
||||
splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
|
||||
std::max(k_id - num_full, 0) * partial_k_read;
|
||||
}
|
||||
else if(k_id < num_full)
|
||||
{
|
||||
splitted_k = full_k_read;
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
splitted_k = partial_k_read;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,6 +410,15 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
|
||||
: GemmPipeline::template GetVectorSizeA<false>();
|
||||
bool AsTesnorIsValid = {true};
|
||||
|
||||
@@ -568,6 +568,15 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
}
|
||||
|
||||
if(kargs.GemmK < TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
|
||||
@@ -173,6 +173,11 @@ static GroupedConvBwdWeightHostArgs create_2d_host_args(index_t k_batch)
|
||||
return create_2d_host_args(2, 2, 8, 8, 3, 3, 7, 7, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
|
||||
}
|
||||
|
||||
static GroupedConvBwdWeightHostArgs create_large_2d_host_args(index_t k_batch)
|
||||
{
|
||||
return create_2d_host_args(2, 2, 8, 8, 3, 3, 70, 70, 1, 1, 1, 1, 1, 1, 1, 1, k_batch);
|
||||
}
|
||||
|
||||
class GroupedConvBwdWeightIsSupportedArgumentTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
@@ -227,6 +232,25 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreat
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2));
|
||||
}
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
TestConvConfig,
|
||||
tensor_layout::convolution::NHWGC,
|
||||
tensor_layout::convolution::GKYXC,
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
// k_batch = 6 should pass
|
||||
auto host_args_kbatch_6 = create_2d_host_args(6);
|
||||
auto kargs_6 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_6);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_6));
|
||||
|
||||
// k_batch = 7 should fail due to the K0 / k_batch limitation
|
||||
auto host_args_kbatch_7 = create_2d_host_args(7);
|
||||
auto kargs_7 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_7);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_7));
|
||||
}
|
||||
|
||||
TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKBatch)
|
||||
{
|
||||
using Kernel = typename BuildKernel<half_t,
|
||||
@@ -236,13 +260,13 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, NonFloatDoubleOutputLimitsKB
|
||||
tensor_layout::convolution::NHWGK>::type;
|
||||
|
||||
// k_batch = 128 should pass
|
||||
auto host_args_kbatch_128 = create_2d_host_args(128);
|
||||
auto host_args_kbatch_128 = create_large_2d_host_args(128);
|
||||
auto kargs_128 =
|
||||
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_128);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_128));
|
||||
|
||||
// k_batch = 129 should fail for half_t output
|
||||
auto host_args_kbatch_129 = create_2d_host_args(129);
|
||||
auto host_args_kbatch_129 = create_large_2d_host_args(129);
|
||||
auto kargs_129 =
|
||||
typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_129);
|
||||
EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_129));
|
||||
|
||||
Reference in New Issue
Block a user