mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Merge some updates for ck_tile headers (#3342)
* fix some issues from internal branch
* update cshuffle_epilogue
* update cshuffle_epilogue
* update cshuffle
* update warp_gemm
This commit is contained in:
@@ -423,7 +423,7 @@ struct UniversalGemmKernel
|
||||
|
||||
const auto vectorSizeA = is_wave32() ? GemmPipeline::template GetVectorSizeA<true>()
|
||||
: GemmPipeline::template GetVectorSizeA<false>();
|
||||
bool AsTesnorIsValid = {true};
|
||||
bool AsTensorIsValid = {true};
|
||||
static_for<0, NumATensor, 1>{}([&](auto index) {
|
||||
using AiLayout = remove_cvref_t<std::tuple_element_t<index.value, AsLayout>>;
|
||||
if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -437,15 +437,27 @@ struct UniversalGemmKernel
|
||||
"Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
if(kargs.K % vectorSizeA != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
const auto remainder = kargs.K % vectorSizeA;
|
||||
constexpr ck_tile::index_t APackedSize =
|
||||
ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
AsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -457,20 +469,33 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR(
|
||||
"Can't support M that is not a multiple of MPerBlock without padding!");
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
if(kargs.M % vectorSizeA != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
const auto remainder = kargs.M % vectorSizeA;
|
||||
constexpr ck_tile::index_t APackedSize =
|
||||
ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(ADataType) / APackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
|
||||
AsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
|
||||
}
|
||||
AsTensorIsValid = false;
|
||||
}
|
||||
AsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bool BsTesnorIsValid = {true};
|
||||
bool BsTensorIsValid = {true};
|
||||
const auto vectorSizeB = is_wave32() ? GemmPipeline::template GetVectorSizeB<true>()
|
||||
: GemmPipeline::template GetVectorSizeB<false>();
|
||||
static_for<0, NumBTensor, 1>{}([&](auto index) {
|
||||
@@ -484,47 +509,72 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR(
|
||||
"Can't support N that is not a multiple of NPerBlock without padding!");
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
if(kargs.N % vectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
const auto remainder = kargs.N % vectorSizeB;
|
||||
constexpr ck_tile::index_t BPackedSize =
|
||||
ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
BsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
|
||||
GemmPipeline::kPadK == false)
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Can't support K that is not a multiple of k_batch * KPerBlock "
|
||||
"without padding!");
|
||||
}
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
if(kargs.K % vectorSizeB != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
if(kargs.K % vectorSizeB != 0)
|
||||
{
|
||||
CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
|
||||
const auto remainder = kargs.K % vectorSizeB;
|
||||
constexpr ck_tile::index_t BPackedSize =
|
||||
ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
const auto remainder_in_bytes = remainder * sizeof(BDataType) / BPackedSize;
|
||||
// oob can support to dword level
|
||||
if(remainder_in_bytes % 4 == 0)
|
||||
{
|
||||
BsTensorIsValid = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"K is not a multiple of vector load size for B tensor!");
|
||||
}
|
||||
BsTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
BsTesnorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
bool DTesnorIsValid = {true};
|
||||
bool DTensorIsValid = {true};
|
||||
static_for<0, NumDTensor, 1>{}([&](auto index) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if(std::is_same_v<DiLayout, CLayout> == false)
|
||||
{
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -535,7 +585,7 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
|
||||
"NPerBlock without padding!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
@@ -543,7 +593,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -555,7 +605,7 @@ struct UniversalGemmKernel
|
||||
CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
|
||||
"MPerBlock without padding!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
|
||||
{
|
||||
@@ -563,7 +613,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
|
||||
}
|
||||
DTesnorIsValid = false;
|
||||
DTensorIsValid = false;
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -608,7 +658,7 @@ struct UniversalGemmKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
|
||||
return AsTensorIsValid && BsTensorIsValid && DTensorIsValid;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
|
||||
Reference in New Issue
Block a user