Ck/joye/revert oob check (#5640)

## Motivation

fix ck_tile's oob check. 


## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
joyeamd
2026-03-20 20:30:08 +08:00
committed by GitHub
parent 005f9fc582
commit 1cc5380ee9
2 changed files with 27 additions and 91 deletions

View File

@@ -447,23 +447,11 @@ struct UniversalGemmKernel
}
if(kargs.K % vectorSizeA != 0)
{
const auto remainder = kargs.K % vectorSizeA;
constexpr ck_tile::index_t APackedSize =
ck_tile::numeric_traits<ADataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
AsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
AsTensorIsValid = false;
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
}
AsTensorIsValid = false;
}
}
else
@@ -479,24 +467,11 @@ struct UniversalGemmKernel
}
if(kargs.M % vectorSizeA != 0)
{
const auto remainder = kargs.M % vectorSizeA;
constexpr ck_tile::index_t APackedSize =
ck_tile::numeric_traits<ADataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
AsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
}
AsTensorIsValid = false;
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
}
AsTensorIsValid = false;
}
}
});
@@ -519,58 +494,33 @@ struct UniversalGemmKernel
}
if(kargs.N % vectorSizeB != 0)
{
const auto remainder = kargs.N % vectorSizeB;
constexpr ck_tile::index_t BPackedSize =
ck_tile::numeric_traits<BDataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
BsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
}
BsTensorIsValid = false;
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
}
BsTensorIsValid = false;
}
else
}
else
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
{
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
GemmPipeline::kPadK == false)
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
BsTensorIsValid = false;
CK_TILE_ERROR(
"Can't support K that is not a multiple of k_batch * KPerBlock "
"without padding!");
}
if(kargs.K % vectorSizeB != 0)
BsTensorIsValid = false;
}
if(kargs.K % vectorSizeB != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
const auto remainder = kargs.K % vectorSizeB;
constexpr ck_tile::index_t BPackedSize =
ck_tile::numeric_traits<BDataType>::PackedSize;
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
// oob can support to dword level
if(remainder_in_bytes % 4 == 0)
{
BsTensorIsValid = true;
}
else
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"K is not a multiple of vector load size for B tensor!");
}
BsTensorIsValid = false;
}
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
}
BsTensorIsValid = false;
}
}
});

View File

@@ -31,14 +31,7 @@ TYPED_TEST(TEST_SUITE_NAME, SmallM)
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword
{
this->Run(M, N, K);
}
else
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
@@ -91,14 +84,7 @@ TYPED_TEST(TEST_SUITE_NAME, MidLargeM)
}
else
{
if(M * sizeof(typename TestFixture::ADataType) % 4 == 0) // oob fit dword
{
this->Run(M, N, K);
}
else
{
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
}
else