mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Minor splitk bugfix for gemms and conv (#3387)
* fix for splitk if splitk < grid * add different splitk implementation * minor bugfix for streamk gemm * Add test --------- Co-authored-by: Bartlomiej Kocot <barkocot@amd.com>
This commit is contained in:
@@ -323,22 +323,38 @@ struct UniversalGemmKernel
|
||||
|
||||
struct SplitKBatchOffset
|
||||
{
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
// This structure distributes work evenly among splitkk workgroups
|
||||
// It's based on a principle that if there is enough work to fill all workgroups,
|
||||
// then we can distribute the (K / K1) parts among k_batch workgroups in such a way
|
||||
// that each workgroup will be doing ceil((K / K1) / splitk) or ceil((K / K1) / splitk) - 1
|
||||
// and leave the potential tail for last(splitk - 1) indexed workgroup.
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const index_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t num_all = amd_wave_read_first_lane(
|
||||
kargs.K / K1); // num of all loops not including potential tail
|
||||
index_t num_full = amd_wave_read_first_lane(num_all % kargs.k_batch);
|
||||
num_full = num_full == 0 ? kargs.k_batch : num_full;
|
||||
|
||||
const index_t num_full_iters =
|
||||
amd_wave_read_first_lane(std::max(integer_divide_ceil(num_all, kargs.k_batch), 1));
|
||||
const index_t full_k_read = num_full_iters * K1;
|
||||
const index_t partial_k_read = (num_full_iters - 1) * K1;
|
||||
|
||||
static_for<0, NumATensor, 1>{}([&](auto index) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
|
||||
{
|
||||
as_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
|
||||
as_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
|
||||
{
|
||||
as_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(k_id * KRead * kargs.stride_As[index]);
|
||||
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read) *
|
||||
kargs.stride_As[index]);
|
||||
}
|
||||
});
|
||||
|
||||
@@ -347,21 +363,30 @@ struct UniversalGemmKernel
|
||||
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
|
||||
{
|
||||
bs_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(k_id * KRead * kargs.stride_Bs[index]);
|
||||
amd_wave_read_first_lane((std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read) *
|
||||
kargs.stride_Bs[index]);
|
||||
}
|
||||
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
|
||||
{
|
||||
bs_k_split_offset[index] = amd_wave_read_first_lane(k_id * KRead);
|
||||
bs_k_split_offset[index] =
|
||||
amd_wave_read_first_lane(std::min(k_id, num_full) * full_k_read +
|
||||
std::max(k_id - num_full, 0) * partial_k_read);
|
||||
}
|
||||
});
|
||||
|
||||
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
|
||||
if(k_id == kargs.k_batch - 1)
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(KRead);
|
||||
splitted_k = kargs.K - std::min(k_id, num_full) * full_k_read -
|
||||
std::max(k_id - num_full, 0) * partial_k_read;
|
||||
}
|
||||
else if(k_id < num_full)
|
||||
{
|
||||
splitted_k = full_k_read;
|
||||
}
|
||||
else
|
||||
{
|
||||
splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1));
|
||||
splitted_k = partial_k_read;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -385,6 +410,15 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
if(kargs.K < GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}) * kargs.k_batch)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("KBatch is too large, part of GPU wouldn't be utilized!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
|
||||
: GemmPipeline::template GetVectorSizeA<false>();
|
||||
bool AsTesnorIsValid = {true};
|
||||
|
||||
Reference in New Issue
Block a user