Merge commit 'e77a7ca2bc65651b5e87a0127e0335733aca2f35' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-18 21:13:07 +00:00
parent cef729b554
commit 87b8b502e6
93 changed files with 8182 additions and 127 deletions

View File

@@ -198,10 +198,10 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
Array<long_index_t, NumATensor> BatchStrideA_;
Array<long_index_t, NumBTensor> BatchStrideB_;
Array<long_index_t, NumDTensor> BatchStrideDs_;
long_index_t BatchStrideE_;
Array<long_index_t, NumATensor> BatchStrideA_{};
Array<long_index_t, NumBTensor> BatchStrideB_{};
Array<long_index_t, NumDTensor> BatchStrideDs_{};
long_index_t BatchStrideE_{};
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
@@ -253,10 +253,10 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
long_index_t BatchStrideA_;
long_index_t BatchStrideB_;
Array<long_index_t, NumDTensor> BatchStrideDs_;
long_index_t BatchStrideE_;
long_index_t BatchStrideA_{};
long_index_t BatchStrideB_{};
Array<long_index_t, NumDTensor> BatchStrideDs_{};
long_index_t BatchStrideE_{};
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};

View File

@@ -894,6 +894,144 @@ struct GridwiseGemm_wmma_cshuffle_v3
k_idx,
karg.KBatch);
}
// Run method for convolution fwd (grid descriptors are passed as arguments,
// not generated internally)
template <typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
[[maybe_unused]] const index_t num_k_per_block,
Argument& karg,
EpilogueArgument& epilogue_args)
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
ds_batch_offset[i] + ds_n_offset[i];
});
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
}
};
} // namespace ck

View File

@@ -620,8 +620,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
template <typename DsGridDesc>
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
__device__ __host__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n,
index_t MBlock,
index_t NBlock)
{
return generate_tuple(
[&](auto i) {
@@ -737,7 +739,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename Argument>
__host__ static constexpr bool CheckValidity(const Argument& karg)
__host__ static constexpr bool CheckValidity(const Argument& karg,
bool allow_short_v3_pipe = false)
{
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
(NPerBlock % (NPerWmma * NRepeat)) == 0,
@@ -928,14 +931,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
{
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
if(!(allow_short_v3_pipe && BlkGemmPipelineVer == BlockGemmPipelineVersion::v3))
{
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
return false;
}
}

View File

@@ -286,8 +286,25 @@ struct ThreadwiseTensorSliceTransfer_v7r3
static_for<0, nDst, 1>{}([&](auto i) {
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
using scalar_t =
remove_cvref_t<decltype(elm_vectors(i).template AsType<elm_vector_t>()[I0])>;
// This is a bit ugly but necessary to be able to compile f8 instances for grouped
// convolution forward. For some reason for that specific type there is an ambiguity
// in the type resolution for the ternary expression. I added an explicit cast to
// disambiguate and only use it for f8 just in case it affects performance.
if constexpr(std::is_same_v<scalar_t, ck::f8_ocp_t>)
{
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vector_t{elm_vectors(i).template AsType<elm_vector_t>()[I0]}
: elm_vector_t{0};
}
else
{
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0]
: elm_vector_t{0};
}
});
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;