mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK TILE] Improve oob check (#4791)
## Motivation Improve OOB checks. Remove permutes which have been generated by thread buffer zero clear. at now in assembly there is only condmask instead of permute + condmask. Change number of KPack for generated instances ## Technical Details Remove permute instructions from assembly ## Test Plan test_grouped_convnd_fwd_tile ## Test Result passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: jakpiase <jakpia21@gmail.com>
This commit is contained in:
@@ -197,7 +197,21 @@ def parse_fwd_instances(instances, problem_name):
|
||||
dtype = get_dtype(problem_name)
|
||||
# TODO: Make it more flexible
|
||||
# k_per_xdl = f"ck_tile::get_k_warp_tile<{dtype}, {m_per_xdl}>()"
|
||||
k_per_xdl = 8 if dtype == "float" else 16
|
||||
if dtype == "float":
|
||||
if m_per_xdl == 32:
|
||||
if instance.find("BlkGemmPipelineVersion") == -1:
|
||||
k_per_xdl = 4
|
||||
else:
|
||||
# Increase for universal gemm
|
||||
k_per_xdl = 8
|
||||
else:
|
||||
k_per_xdl = 8
|
||||
else:
|
||||
if m_per_xdl == 32:
|
||||
k_per_xdl = 16
|
||||
else:
|
||||
k_per_xdl = 32
|
||||
k_per_xdl = min(k_per_xdl, k_per_block)
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
|
||||
@@ -2578,6 +2578,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thr
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using has_type = typename T::type;
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
@@ -2608,12 +2611,48 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#else
|
||||
thread_buffer<T, N> tmp =
|
||||
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
if constexpr(oob_conditional_check)
|
||||
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
|
||||
else
|
||||
return tmp;
|
||||
{
|
||||
if(!src_thread_element_valid)
|
||||
{
|
||||
if constexpr(is_detected<has_type, T>::value)
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
// Get raw type from structure
|
||||
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{numeric<T>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{},
|
||||
vector_t{numeric<typename T::type>::zero()});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
using vector_t = T __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(T) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{numeric<T>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{}, vector_t{numeric<T>::zero()});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -2637,13 +2676,47 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
thread_buffer<T, N> tmp =
|
||||
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
|
||||
else
|
||||
return tmp;
|
||||
{
|
||||
if(!src_thread_element_valid)
|
||||
{
|
||||
if constexpr(is_detected<has_type, T>::value)
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
// Get raw type from structure
|
||||
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{customized_value};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
using vector_t = T __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(T) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{customized_value};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -2404,6 +2404,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_max_impl(const thread_buffer<T, N> src_thr
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
using has_type = typename T::type;
|
||||
|
||||
// buffer_load requires:
|
||||
// 1) p_src_wave must point to global memory space
|
||||
// 2) p_src_wave must be a wavewise pointer.
|
||||
@@ -2436,21 +2439,46 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
#else
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
if(src_thread_element_valid)
|
||||
if(!src_thread_element_valid)
|
||||
{
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
return thread_buffer<T, N>{numeric<T>::zero()};
|
||||
if constexpr(is_detected<has_type, T>::value)
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
// Get raw type from structure
|
||||
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{numeric<T>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{},
|
||||
vector_t{numeric<typename T::type>::zero()});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
using vector_t = T __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(T) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{numeric<T>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{}, vector_t{numeric<T>::zero()});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
}
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -2474,13 +2502,47 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave,
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
|
||||
|
||||
thread_buffer<T, N> tmp =
|
||||
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
return src_thread_element_valid ? tmp : thread_buffer<T, N>{customized_value};
|
||||
else
|
||||
return tmp;
|
||||
{
|
||||
if(!src_thread_element_valid)
|
||||
{
|
||||
if constexpr(is_detected<has_type, T>::value)
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
// Get raw type from structure
|
||||
using vector_t = typename T::type __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(typename T::type) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{customized_value};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use vector_t for not valid elements to avoid permute instructions.
|
||||
using vector_t = T __attribute__((ext_vector_type(N)));
|
||||
if constexpr(sizeof(vector_t) != sizeof(T) * N)
|
||||
{
|
||||
// Not possible to use set_as
|
||||
return thread_buffer<T, N>{customized_value};
|
||||
}
|
||||
else
|
||||
{
|
||||
thread_buffer<T, N> tmp;
|
||||
tmp.template set_as<vector_t>(number<0>{}, vector_t{customized_value});
|
||||
return tmp;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return amd_buffer_load_impl<T, N, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -35,6 +35,12 @@ using WarpGemmMfmaF32F32F32M32N32K8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK
|
||||
4,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M32N32K4 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF32F32F32M32N32K2<WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
|
||||
|
||||
@@ -36,6 +36,7 @@ struct Dispatcher;
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 8, false> { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 4, false> { using Type = WarpGemmMfmaF32F32F32M32N32K4<>; };
|
||||
template<> struct Dispatcher<float, float, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
|
||||
// fp16
|
||||
|
||||
@@ -96,7 +96,7 @@ int call_profiler(const ckt::Args<SIGNATURE>& args, bool time_kernel)
|
||||
std::string op_name;
|
||||
bool valid;
|
||||
std::tie(valid, avg_time, op_name) = ckp::run_grouped_conv_forward_tile_algs(
|
||||
args, inputs.get(), outputs.get(), ck_tile::stream_config{nullptr, time_kernel});
|
||||
args, inputs.get(), outputs.get(), ck_tile::stream_config{nullptr, time_kernel, 0, 5, 50});
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best configuration parameters:" << "\nname: " << op_name
|
||||
|
||||
Reference in New Issue
Block a user