mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#5842 (commit 04c5690)
[CK][CK Tile] Force padding for atomic_add bf16 C tensor (#5842) ## Motivation Force padding for the atomic_add bf16 C tensor to avoid memfaults. ## Technical Details - add global atomic add for bf16 and enable it - add padding for atomic add bf16 due to the lack of OOB checking - remove padding for non-contiguous dims in conv for other cases - minor bwd data conv fixes ## Test Plan test_grouped_conv_*_tile ## Test Result pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
66dc81d530
commit
ef4ff4667d
@@ -18,6 +18,10 @@
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
|
||||
|
||||
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
|
||||
|
||||
// This attribute gives a hint to the compiler that a branch is likely to be taken.
|
||||
// Then, the compiler should remove if possible the associated s_cbranch_execz branch that would
|
||||
// have been generated.
|
||||
@@ -2317,6 +2321,34 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void
|
||||
amd_global_atomic_add_impl([[maybe_unused]] const thread_buffer<T, N>& src_thread_data,
|
||||
[[maybe_unused]] T* addr)
|
||||
{
|
||||
static_assert((std::is_same<T, ck_tile::bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, ck_tile::fp16_t>::value && (N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
if constexpr(__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16) &&
|
||||
std::is_same<T, ck_tile::bf16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_global_atomic_fadd_v2bf16(
|
||||
bit_cast<ck_tile::bf16x2_t*>(addr) + i,
|
||||
src_thread_data.template get_as<ck_tile::bf16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Not supported!");
|
||||
}
|
||||
#else
|
||||
static_assert(false, "Not supported!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
@@ -2325,8 +2357,11 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4))
|
||||
#if defined(__gfx950__)
|
||||
|| (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8))
|
||||
#endif
|
||||
,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
@@ -2931,16 +2966,27 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
#if defined(__gfx942__)
|
||||
if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_global_atomic_add_impl<T, N>(src_thread_data,
|
||||
p_dst_wave + dst_thread_element_offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_atomic_add_impl<T, N>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
amd_buffer_atomic_add_impl<T, N>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
@@ -2948,6 +2994,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
#if defined(__gfx942__)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
#include "ck_tile/core/arch/amd_buffer_coherence.hpp"
|
||||
|
||||
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
|
||||
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
|
||||
|
||||
using as3_uint32_ptr = uint32_t __attribute__((address_space(3)))*;
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -2143,6 +2147,33 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void
|
||||
amd_global_atomic_add_impl([[maybe_unused]] const thread_buffer<T, N>& src_thread_data,
|
||||
[[maybe_unused]] T* addr)
|
||||
{
|
||||
static_assert((std::is_same<T, ck_tile::bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, ck_tile::fp16_t>::value && (N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
|
||||
if constexpr(std::is_same<T, ck_tile::bf16_t>::value)
|
||||
{
|
||||
static_for<0, N / 2, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_global_atomic_fadd_v2bf16(
|
||||
bit_cast<ck_tile::bf16x2_t*>(addr) + i,
|
||||
src_thread_data.template get_as<ck_tile::bf16x2_t>()[i]);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Not supported!");
|
||||
}
|
||||
#else
|
||||
static_assert(false, "Not supported!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
@@ -2151,8 +2182,11 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
|
||||
{
|
||||
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4))
|
||||
#if defined(__gfx950__)
|
||||
|| (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8))
|
||||
#endif
|
||||
,
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(std::is_same<T, float>::value)
|
||||
@@ -2759,16 +2793,28 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size)
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
#if defined(__gfx942__)
|
||||
if constexpr(std::is_same<T, bf16_t>::value)
|
||||
{
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_global_atomic_add_impl<T, N>(src_thread_data,
|
||||
p_dst_wave + dst_thread_element_offset);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
|
||||
#if CK_TILE_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_atomic_add_impl<T, N>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
amd_buffer_atomic_add_impl<T, N>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
@@ -2776,6 +2822,9 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
#endif
|
||||
#if defined(__gfx942__)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -630,7 +630,7 @@ struct buffer_view<address_space_enum::global,
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
|
||||
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // only gfx942 and gfx950 support atomic_pk_add_bf16
|
||||
||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
|
||||
#endif
|
||||
@@ -642,7 +642,7 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
|
||||
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
|
||||
#if defined(__gfx942__) || defined(__gfx950__) // only gfx942 and gfx950 support atomic_pk_add_bf16
|
||||
||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
|
||||
#endif
|
||||
|
||||
@@ -1021,6 +1021,11 @@ struct UniversalGemmKernel
|
||||
const auto& e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(e_ptr, e_desc);
|
||||
|
||||
// For bf16_t and atomic_add global_atomic_add is used instead of buffer_atomic_add
|
||||
// Add padding for not contiguous dim due to the lack of OOB check
|
||||
constexpr bool pad_not_contiguous_dim =
|
||||
std::is_same_v<EDataType, bf16_t> && DstInMemOp == memory_operation_enum::atomic_add;
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& e_pad_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -1028,14 +1033,14 @@ struct UniversalGemmKernel
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
sequence<pad_not_contiguous_dim, GemmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<GemmPipeline::kPadM, false>{});
|
||||
sequence<GemmPipeline::kPadM, pad_not_contiguous_dim>{});
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -531,11 +531,11 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using InDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using OutDataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
|
||||
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using InDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvBwdDataKernelArgsSpecialized =
|
||||
GroupedConvBwdDataKernelArgs<GroupedConvTraitsType_, TilePartitioner>;
|
||||
@@ -561,7 +561,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
|
||||
// clang-format off
|
||||
return concat('_', "grouped_convolution_backward_data",
|
||||
gemm_prec_str<InDataType, WeiDataType>(),
|
||||
gemm_prec_str<OutDataType, WeiDataType>(),
|
||||
InLayout::name,
|
||||
WeiLayout::name,
|
||||
OutLayout::name,
|
||||
@@ -632,7 +632,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto& a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto a_block_window = make_tile_window(
|
||||
@@ -644,7 +644,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeBBlockWindow(const InDataType* b_ptr,
|
||||
MakeBBlockWindow(const WeiDataType* b_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_n,
|
||||
@@ -658,7 +658,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto& b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto b_block_window = make_tile_window(
|
||||
@@ -681,14 +681,14 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
[&](auto i) {
|
||||
// Step 1: Create tensor view for D
|
||||
const auto& d_tensor_view = make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
|
||||
static_cast<const InDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
|
||||
|
||||
// Step 2: Create padded view
|
||||
const auto& d_pad_view =
|
||||
pad_tensor_view(d_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(d_pad_view,
|
||||
@@ -703,7 +703,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeCBlockWindow(WeiDataType* c_ptr,
|
||||
MakeCBlockWindow(InDataType* c_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id,
|
||||
const index_t i_m,
|
||||
@@ -713,11 +713,20 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto& c_tensor_view = make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr, kargs.c_grid_descs_m_n[group_id]);
|
||||
|
||||
// For bf16_t and atomic_add global_atomic_add is used instead of buffer_atomic_add
|
||||
// Add padding for not contiguous dim due to the lack of OOB check
|
||||
// Not needed from gfx950.
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool pad_not_contiguous_dim = false;
|
||||
#else
|
||||
constexpr bool pad_not_contiguous_dim =
|
||||
std::is_same_v<InDataType, bf16_t> && DstInMemOp == memory_operation_enum::atomic_add;
|
||||
#endif
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<pad_not_contiguous_dim, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
auto c_block_window = make_tile_window(
|
||||
@@ -739,7 +748,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
}
|
||||
}
|
||||
if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value)
|
||||
is_any_of<InDataType, fp16_t, bf16_t>::value)
|
||||
{
|
||||
if(kargs.k_batch != 1)
|
||||
{
|
||||
@@ -862,133 +871,6 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t group_id)
|
||||
{
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
kargs.a_grid_descs_m_k[group_id]); // A: out
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
kargs.b_grid_descs_n_k[group_id]); // B: weight
|
||||
}();
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
c_ptr, kargs.c_grid_descs_m_n[group_id]);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = views.at(I2);
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I3);
|
||||
return pad_tensor_view(c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
const index_t i_m,
|
||||
const index_t i_n,
|
||||
const index_t i_k)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& c_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
@@ -1002,9 +884,9 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
InDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvBwdDataKernelArgsSpecialized& kargs,
|
||||
const index_t splitted_k,
|
||||
@@ -1044,7 +926,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
else
|
||||
{
|
||||
if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
is_any_of<InDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
|
||||
c_ptr, kargs, group_id, block_idx_m, block_idx_n);
|
||||
|
||||
@@ -869,10 +869,19 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& c_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
|
||||
// For bf16_t and atomic_add global_atomic_add is used instead of buffer_atomic_add
|
||||
// Add padding for not contiguous dim due to the lack of OOB check
|
||||
// Not needed from gfx950.
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool pad_not_contiguous_dim = false;
|
||||
#else
|
||||
constexpr bool pad_not_contiguous_dim =
|
||||
std::is_same_v<WeiDataType, bf16_t> && DstInMemOp == memory_operation_enum::atomic_add;
|
||||
#endif
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<pad_not_contiguous_dim, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
c_pad_view,
|
||||
@@ -905,7 +914,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -933,7 +942,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
@@ -955,7 +964,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * kargs.k_batch,
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
return make_tile_window(
|
||||
a_pad_view,
|
||||
|
||||
@@ -898,7 +898,7 @@ struct GroupedConvolutionForwardKernel
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(a_pad_view,
|
||||
@@ -924,7 +924,7 @@ struct GroupedConvolutionForwardKernel
|
||||
pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(a_pad_view,
|
||||
@@ -945,7 +945,7 @@ struct GroupedConvolutionForwardKernel
|
||||
const auto& b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
@@ -981,7 +981,7 @@ struct GroupedConvolutionForwardKernel
|
||||
return pad_tensor_view(ds_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<false, true>{});
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -1006,11 +1006,20 @@ struct GroupedConvolutionForwardKernel
|
||||
const auto& c_tensor_view =
|
||||
make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr, c_desc);
|
||||
|
||||
// For bf16_t and atomic_add global_atomic_add is used instead of buffer_atomic_add
|
||||
// Add padding for not contiguous dim due to the lack of OOB check
|
||||
// Not needed from gfx950.
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool pad_not_contiguous_dim = false;
|
||||
#else
|
||||
constexpr bool pad_not_contiguous_dim =
|
||||
std::is_same_v<OutDataType, bf16_t> && DstInMemOp == memory_operation_enum::atomic_add;
|
||||
#endif
|
||||
// Step 2: Create padded view
|
||||
const auto& c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
sequence<pad_not_contiguous_dim, true>{});
|
||||
|
||||
// Step 3: Create tile window
|
||||
return make_tile_window(
|
||||
|
||||
Reference in New Issue
Block a user