mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Merge commit 'e77a7ca2bc65651b5e87a0127e0335733aca2f35' into develop
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -198,10 +198,10 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
|
||||
}
|
||||
|
||||
Array<long_index_t, NumATensor> BatchStrideA_;
|
||||
Array<long_index_t, NumBTensor> BatchStrideB_;
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs_;
|
||||
long_index_t BatchStrideE_;
|
||||
Array<long_index_t, NumATensor> BatchStrideA_{};
|
||||
Array<long_index_t, NumBTensor> BatchStrideB_{};
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs_{};
|
||||
long_index_t BatchStrideE_{};
|
||||
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
@@ -253,10 +253,10 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
|
||||
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
|
||||
}
|
||||
|
||||
long_index_t BatchStrideA_;
|
||||
long_index_t BatchStrideB_;
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs_;
|
||||
long_index_t BatchStrideE_;
|
||||
long_index_t BatchStrideA_{};
|
||||
long_index_t BatchStrideB_{};
|
||||
Array<long_index_t, NumDTensor> BatchStrideDs_{};
|
||||
long_index_t BatchStrideE_{};
|
||||
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
|
||||
@@ -894,6 +894,144 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
|
||||
// Run method for convolution fwd (grid descriptors are passed as arguments,
|
||||
// not generated internally)
|
||||
template <typename AGridDesc_AK0_M_K1,
|
||||
typename BGridDesc_BK0_N_K1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN& compute_ptr_offset_of_n,
|
||||
[[maybe_unused]] const index_t num_k_per_block,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
|
||||
// offset base pointer for each work-group
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) =
|
||||
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) =
|
||||
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
|
||||
});
|
||||
|
||||
DsGridPointer p_ds_grid_grp;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
|
||||
ds_batch_offset[i] + ds_n_offset[i];
|
||||
});
|
||||
|
||||
// Currently supporting one A and one B
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// AScale struct (Empty)
|
||||
using AScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = AScale{};
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
p_ds_grid_grp,
|
||||
karg.p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -620,8 +620,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
}
|
||||
|
||||
template <typename DsGridDesc>
|
||||
__device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
|
||||
__device__ __host__ static constexpr auto
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc& ds_grid_desc_m_n,
|
||||
index_t MBlock,
|
||||
index_t NBlock)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -737,7 +739,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
template <typename Argument>
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg)
|
||||
__host__ static constexpr bool CheckValidity(const Argument& karg,
|
||||
bool allow_short_v3_pipe = false)
|
||||
{
|
||||
static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
|
||||
(NPerBlock % (NPerWmma * NRepeat)) == 0,
|
||||
@@ -928,14 +931,17 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
{
|
||||
if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
if(!(allow_short_v3_pipe && BlkGemmPipelineVer == BlockGemmPipelineVersion::v3))
|
||||
{
|
||||
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
|
||||
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
|
||||
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop
|
||||
<< ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages
|
||||
<< ") for pipeline version != v1." << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -286,8 +286,25 @@ struct ThreadwiseTensorSliceTransfer_v7r3
|
||||
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
|
||||
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
|
||||
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
|
||||
using scalar_t =
|
||||
remove_cvref_t<decltype(elm_vectors(i).template AsType<elm_vector_t>()[I0])>;
|
||||
|
||||
// This is a bit ugly but necessary to be able to compile f8 instances for grouped
|
||||
// convolution forward. For some reason for that specific type there is an ambiguity
|
||||
// in the type resolution for the ternary expression. I added an explicit cast to
|
||||
// disambiguate and only use it for f8 just in case it affects performance.
|
||||
if constexpr(std::is_same_v<scalar_t, ck::f8_ocp_t>)
|
||||
{
|
||||
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
|
||||
oob_val ? elm_vector_t{elm_vectors(i).template AsType<elm_vector_t>()[I0]}
|
||||
: elm_vector_t{0};
|
||||
}
|
||||
else
|
||||
{
|
||||
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
|
||||
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0]
|
||||
: elm_vector_t{0};
|
||||
}
|
||||
});
|
||||
|
||||
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
|
||||
|
||||
Reference in New Issue
Block a user