mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Merge commit 'dcd33a6ecc30e18cc8491ed03926ab5ac8b6f1c3' into develop
This commit is contained in:
@@ -481,13 +481,10 @@ struct CShuffleEpilogue
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
// Build windows only if scales are provided
|
||||
// Build windows only if non-scalar scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -498,9 +495,6 @@ struct CShuffleEpilogue
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -515,8 +509,8 @@ struct CShuffleEpilogue
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If scales provided, load them with identical distribution
|
||||
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
|
||||
// If non-scalar scales provided, load them with identical distribution
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
@@ -535,7 +529,7 @@ struct CShuffleEpilogue
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
@@ -636,9 +630,6 @@ struct CShuffleEpilogue
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
@@ -653,9 +644,6 @@ struct CShuffleEpilogue
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
|
||||
@@ -132,6 +132,10 @@ struct GemmKernelMultiABD
|
||||
static constexpr index_t NumBTensor = BsDataType::size();
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
using DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>;
|
||||
|
||||
CK_TILE_HOST static auto GetName() -> const std::string
|
||||
{
|
||||
return UniversalGemmKernel::GetName();
|
||||
@@ -181,6 +185,14 @@ struct GemmKernelMultiABD
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// Currently MultiABD kernel doesn't support F8 data type
|
||||
if(ck_tile::get_device_name() == "gfx950" &&
|
||||
(std::is_same<ck_tile::fp8_t, ADataType>::value ||
|
||||
std::is_same<ck_tile::fp8_t, BDataType>::value ||
|
||||
std::is_same<ck_tile::fp8_t, DDataType>::value))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
@@ -530,7 +530,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -542,7 +543,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -553,7 +554,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -577,7 +578,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -596,7 +598,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
@@ -607,7 +609,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major)
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
@@ -619,7 +621,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -100,7 +100,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<AsLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
@@ -118,7 +118,7 @@ struct GemmPipelineProblemBase
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BsLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
|
||||
Reference in New Issue
Block a user