[rocm-libraries] ROCm/rocm-libraries#4791 (commit 6cc17c6)

[CK][CK TILE] Improve oob check

## Motivation

Improve OOB checks. Remove the permute instructions that were previously generated by the thread-buffer zero clear. Now the assembly contains only a condmask instead of a permute + condmask.

Change the number of KPack for the generated instances.

## Technical Details

Remove permute instructions from assembly

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-02-24 21:41:44 +00:00
committed by assistant-librarian[bot]
parent f3f4d7d842
commit 1a2c0d835a
6 changed files with 187 additions and 31 deletions

View File

@@ -2578,6 +2578,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thr
}
}
template <typename T>
using has_type = typename T::type;
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
@@ -2608,12 +2611,48 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
#else
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
else
return tmp;
{
if(!src_thread_element_valid)
{
if constexpr(is_detected<has_type, T>::value)
{
// Use vector_t for not valid elements to avoid permute instructions.
// Get raw type from structure
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{numeric<T>::zero()};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{},
vector_t{numeric<typename T::type>::zero()});
return tmp;
}
}
else
{
// Use vector_t for not valid elements to avoid permute instructions.
using vector_t = T __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(T) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{numeric<T>::zero()};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{}, vector_t{numeric<T>::zero()});
return tmp;
}
}
}
}
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
#endif
}
@@ -2637,13 +2676,47 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
else
return tmp;
{
if(!src_thread_element_valid)
{
if constexpr(is_detected<has_type, T>::value)
{
// Use vector_t for not valid elements to avoid permute instructions.
// Get raw type from structure
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{customized_value};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
return tmp;
}
}
else
{
// Use vector_t for not valid elements to avoid permute instructions.
using vector_t = T __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(T) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{customized_value};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
return tmp;
}
}
}
}
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
}
template <typename T,

View File

@@ -2404,6 +2404,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thr
}
}
template <typename T>
using has_type = typename T::type;
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
@@ -2436,21 +2439,46 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#else
if constexpr(oob_conditional_check)
{
if(src_thread_element_valid)
if(!src_thread_element_valid)
{
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
}
else
{
return thread_buffer<T, N>{numeric<T>::zero()};
if constexpr(is_detected<has_type, T>::value)
{
// Use vector_t for not valid elements to avoid permute instructions.
// Get raw type from structure
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{numeric<T>::zero()};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{},
vector_t{numeric<typename T::type>::zero()});
return tmp;
}
}
else
{
// Use vector_t for not valid elements to avoid permute instructions.
using vector_t = T __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(T) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{numeric<T>::zero()};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{}, vector_t{numeric<T>::zero()});
return tmp;
}
}
}
}
else
{
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
}
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
#endif
}
@@ -2474,13 +2502,47 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
else
return tmp;
{
if(!src_thread_element_valid)
{
if constexpr(is_detected<has_type, T>::value)
{
// Use vector_t for not valid elements to avoid permute instructions.
// Get raw type from structure
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{customized_value};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
return tmp;
}
}
else
{
// Use vector_t for not valid elements to avoid permute instructions.
using vector_t = T __attribute__((ext_vector_type(N)));
if constexpr(sizeof(vector_t) != sizeof(T) * N)
{
// Not possible to use set_as
return thread_buffer<T, N>{customized_value};
}
else
{
thread_buffer<T, N> tmp;
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
return tmp;
}
}
}
}
return amd_buffer_load_impl<T, N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
}
template <typename T,

View File

@@ -35,6 +35,12 @@ using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK
4,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<

View File

@@ -36,6 +36,7 @@ struct Dispatcher;
template<> struct Dispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
template<> struct Dispatcher<float, float, float, 16, 16, 8, false> { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 4, false> { using Type = WarpGemmMfmaF32F32F32M32N32K4<>; };
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
// fp16