mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
[CK_TILE] Vector stores for C Column Layout
This commit is contained in:
@@ -350,7 +350,7 @@ CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
|
||||
new_.f162 = add_f16x2_t(cur_v.f162, x);
|
||||
new_v = new_.u32;
|
||||
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
} while(cur_v.u32 != old_v);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -361,6 +361,7 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
|
||||
(std::is_same<T, uint32_t>::value && (N == 1)) ||
|
||||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(std::is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
|
||||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
|
||||
|
||||
@@ -631,7 +631,12 @@ struct buffer_view<address_space_enum::global,
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
|
||||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
|
||||
#if defined(__gfx94__)
|
||||
|| (std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
|
||||
#else
|
||||
;
|
||||
#endif
|
||||
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
|
||||
bool constexpr use_amd_buffer_addressing =
|
||||
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
|
||||
|
||||
@@ -114,6 +114,14 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
|
||||
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
|
||||
|
||||
static constexpr bool IsERowMajor =
|
||||
std::is_same_v<ELayout, tensor_layout::gemm::RowMajor> ? true : false;
|
||||
static constexpr bool IsEColMajor =
|
||||
std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor> ? true : false;
|
||||
|
||||
static_assert(std::disjunction_v<bool_constant<IsERowMajor>, bool_constant<IsEColMajor>>,
|
||||
"Unsupported ELayout!");
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
/**
|
||||
@@ -133,12 +141,12 @@ struct CShuffleEpilogue
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(IsERowMajor)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
else if constexpr(IsEColMajor)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
@@ -160,12 +168,15 @@ struct CShuffleEpilogue
|
||||
constexpr index_t max_vector_size = 16;
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
static_assert(std::is_same_v<DiLayout, ELayout>, "ELayout is not equal to DiLayout!");
|
||||
constexpr bool IsDRowMajor = std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor> ? true : false;
|
||||
constexpr bool IsDColMajor = std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor> ? true : false;
|
||||
if constexpr(IsDRowMajor)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
else if constexpr(IsDColMajor)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
@@ -174,7 +185,6 @@ struct CShuffleEpilogue
|
||||
{
|
||||
static_assert(false, "Unsupported DLayout!");
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
@@ -225,6 +235,22 @@ struct CShuffleEpilogue
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
|
||||
static constexpr index_t NumYXdlPerWavePerShuffle =
|
||||
IsERowMajor ? NumMXdlPerWavePerShuffle : NumNXdlPerWavePerShuffle;
|
||||
static constexpr index_t NumXXdlPerWavePerShuffle =
|
||||
IsERowMajor ? NumNXdlPerWavePerShuffle : NumMXdlPerWavePerShuffle;
|
||||
|
||||
static constexpr index_t YPerIterationShuffle =
|
||||
IsERowMajor ? MPerIterationShuffle : NPerIterationShuffle;
|
||||
static constexpr index_t XPerIterationShuffle =
|
||||
IsERowMajor ? NPerIterationShuffle : MPerIterationShuffle;
|
||||
|
||||
static constexpr index_t YPerBlock = IsERowMajor ? kMPerBlock : kNPerBlock;
|
||||
static constexpr index_t XPerBlock = IsERowMajor ? kNPerBlock : kMPerBlock;
|
||||
|
||||
static constexpr index_t YWave = IsERowMajor ? MWave : NWave;
|
||||
static constexpr index_t XWave = IsERowMajor ? NWave : MWave;
|
||||
|
||||
using WG = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
@@ -243,32 +269,17 @@ struct CShuffleEpilogue
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<YPerIterationShuffle>{}, number<XPerIterationShuffle>{}),
|
||||
make_tuple(number<XPerIterationShuffle>{}, number<1>{}));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<NumYXdlPerWavePerShuffle, YWave>,
|
||||
sequence<NumXXdlPerWavePerShuffle, XWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
@@ -585,24 +596,21 @@ struct CShuffleEpilogue
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<YPerIterationShuffle>{}, number<XPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
LdsTileDistr);
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<YPerIterationShuffle>{}, number<XPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
|
||||
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
YPerIterationShuffle,
|
||||
XPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
|
||||
@@ -743,7 +743,7 @@ struct UniversalGemmKernel
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_E),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{});
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user