[rocm-libraries] ROCm/rocm-libraries#4762 (commit 5598eb5)

Revert "[ck] Support VGPR estimate in
 GridwiseGemm_wmma_cshuffle_v3" (#4762)

Reverts ROCm/rocm-libraries#4638
unfortunately, this PR interfered with the PR#4299 and caused build
errors for gfx11:

In file included from
/rocm-libraries/projects/composablekernel/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp:7:
In file included from
/rocm-libraries/projects/composablekernel/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp:11:

/rocm-libraries/projects/composablekernel/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp:553:21:
error: no matching function for call to 'CheckValidity'
  553 |                 if(!GridwiseGemm::CheckValidity(
      |                     ^~~~~~~~~~~~~~~~~~~~~~~~~~~
This commit is contained in:
Illia Silin
2026-02-20 22:41:34 +00:00
committed by assistant-librarian[bot]
parent 6df27cfad6
commit 1915cdfcc2
4 changed files with 374 additions and 465 deletions

View File

@@ -33,10 +33,10 @@ using DeviceGemmV2Instance =
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmDefault,
256,
128, 128, 64,
128, 256, 64,
8, 8,
16, 16,
4, 2,
4, 4,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
1, 1, 8, true,
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,

View File

@@ -65,12 +65,12 @@ using DeviceOpInstance =
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
128,
32, 128, 64,
32, 128, 128,
8, 8,
16, 16,
2, 2,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
1, 1, S<1, 16, 1, 8>, S<4, 4, 1>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1,

View File

@@ -1785,11 +1785,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
p_ds_grid_dummy[i] = nullptr;
StrideDs_dummy[i] = I0;
});
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); i++)
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
const index_t GemmM = arg.a_grid_desc_m_k_container_[i].GetLength(I0);
const index_t GemmN = arg.b_grid_desc_n_k_container_[i].GetLength(I0);
const index_t GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1);
const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1);
const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
// Create gemm arguments with dummy values to check for validity
typename GridwiseGemmCTranspose::Argument gemm_arg{
std::array<const void*, 1>{nullptr}, // p_as_grid

View File

@@ -5,7 +5,6 @@
#include "ck/utility/env.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -569,82 +568,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <bool IsGfx11>
static constexpr index_t GetEstimateVgprCount()
{
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
// VGPR used in LDS loading and WMMA
constexpr index_t BaseInputVgprCount =
MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) +
NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t);
// WMMA input is duplicated in GFX11
constexpr index_t InputVgprCount = IsGfx11 ? BaseInputVgprCount * 2 : BaseInputVgprCount;
// VGPR used in buffer load and LDS store
constexpr index_t TempVgprCount = BaseInputVgprCount / 2;
// VGPR used in Accumulator
constexpr index_t AccVgprCount =
MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t);
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return InputVgprCount + TempVgprCount + AccVgprCount;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return InputVgprCount * 2 + TempVgprCount + AccVgprCount;
}
else
{
static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v1,
"Invalid pipeline version");
}
}
__device__ static bool constexpr IsValidCompilationParameter()
{
constexpr bool IsGfx11 = is_same_v<decltype(get_device_arch()), gfx11_t>;
constexpr auto EstimateVgprCount = GetEstimateVgprCount<IsGfx11>();
constexpr auto AvailableVgprCount = get_max_vgpr_count(get_device_arch());
if constexpr(EstimateVgprCount > (AvailableVgprCount + AvailableVgprCount / 4))
{
return false;
}
else
{
return true;
}
}
template <typename Argument>
__host__ static bool CheckValidity(const Argument& karg, bool allow_short_v3_pipe = false)
{
const auto availableVgprCount = []() {
if(ck::is_gfx12_supported())
{
return get_max_vgpr_count(gfx12_t{});
}
else if(ck::is_gfx11_supported())
{
return get_max_vgpr_count(gfx11_t{});
}
else
{
return get_max_vgpr_count(gfx_invalid_t{});
}
}();
const auto estimateVgprCount =
ck::is_gfx11_supported() ? GetEstimateVgprCount<true>() : GetEstimateVgprCount<false>();
if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4))
{
return false;
}
return Base::template CheckValidity<Argument>(karg, allow_short_v3_pipe);
}
__device__ static index_t GetKBlockPerScale() { return 1; }
template <bool HasMainKBlockLoop,
@@ -744,61 +667,59 @@ struct GridwiseGemm_wmma_cshuffle_v3
const index_t A_k_id = 0,
const index_t B_k_id = 0)
{
if constexpr(IsValidCompilationParameter())
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
const index_t block_n_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using Scale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = Scale{};
auto b_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
A_k_id,
B_k_id);
return;
}
const index_t block_m_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
const index_t block_n_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using Scale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = Scale{};
auto b_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
a_element_op,
b_element_op,
cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
A_k_id,
B_k_id);
}
template <bool HasMainKBlockLoop,
@@ -950,124 +871,117 @@ struct GridwiseGemm_wmma_cshuffle_v3
Argument& karg,
EpilogueArgument& epilogue_args)
{
if constexpr(IsValidCompilationParameter())
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
const index_t k_idx =
__builtin_amdgcn_readfirstlane((blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0;
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
const index_t k_idx = __builtin_amdgcn_readfirstlane(
(blockIdx.z - n_idx * karg.KBatch) * num_k_per_block);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
CTranspose
? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))
: amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
CTranspose ? 0
: amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx))
: 0;
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_batch_offset[i]; });
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_idx,
k_idx,
karg.KBatch);
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_idx,
k_idx,
karg.KBatch);
}
// Run method for convolution (grid descriptors are passed as arguments,
@@ -1091,106 +1005,103 @@ struct GridwiseGemm_wmma_cshuffle_v3
Argument& karg,
EpilogueArgument& epilogue_args)
{
if constexpr(IsValidCompilationParameter())
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
});
const auto ds_grid_desc_m_n =
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
});
const auto ds_grid_desc_m_n =
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// Scale structs (Empty)
using Scale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = Scale{};
auto a_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
karg.p_ds_grid,
karg.p_e_grid + e_batch_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_idx,
k_idx,
karg.KBatch);
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// Scale structs (Empty)
using Scale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = Scale{};
auto a_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
karg.p_ds_grid,
karg.p_e_grid + e_batch_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_idx,
k_idx,
karg.KBatch);
}
// Run method for convolution fwd (grid descriptors are passed as arguments,
@@ -1218,120 +1129,117 @@ struct GridwiseGemm_wmma_cshuffle_v3
Argument& karg,
EpilogueArgument& epilogue_args)
{
if constexpr(IsValidCompilationParameter())
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
ds_batch_offset[i] + ds_n_offset[i];
});
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
// offset base pointer for each work-group
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t b_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
AsGridPointer p_as_grid_;
static_for<0, NumATensor, 1>{}([&](auto i) {
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
p_as_grid_(i) =
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
});
BsGridPointer p_bs_grid_;
static_for<0, NumBTensor, 1>{}([&](auto i) {
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
p_bs_grid_(i) =
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
});
DsGridPointer p_ds_grid_grp;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
ds_batch_offset[i] + ds_n_offset[i];
});
// Currently supporting one A and one B
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
[&](auto i) {
ignore = i;
return a_grid_desc_ak0_m_ak1;
},
Number<NumATensor>{});
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
[&](auto i) {
ignore = i;
return b_grid_desc_bk0_n_bk1;
},
Number<NumBTensor>{});
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct (Empty)
using AScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = AScale{};
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum>(p_as_grid_,
p_bs_grid_,
p_ds_grid_grp,
karg.p_e_grid + e_batch_offset + e_n_offset,
p_shared,
as_grid_desc_ak0_m_ak1,
bs_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
}
};