mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
estimate vgpr
This commit is contained in:
@@ -33,10 +33,10 @@ using DeviceGemmV2Instance =
|
||||
ADataType, BDataType, DsDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmDefault,
|
||||
256,
|
||||
128, 256, 64,
|
||||
128, 128, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
4, 4,
|
||||
4, 2,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, true,
|
||||
S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
|
||||
@@ -81,13 +81,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
GemmSpec,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
2,
|
||||
2,
|
||||
S<8, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
@@ -104,7 +104,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
8,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
@@ -106,7 +106,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -107,7 +107,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Wmm
|
||||
S<1, 32, 1, 8>,
|
||||
S<8, 8, 8>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v3>;
|
||||
ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -65,12 +65,12 @@ using DeviceOpInstance =
|
||||
A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
128,
|
||||
32, 128, 128,
|
||||
32, 128, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
2, 2,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
|
||||
1, 1, S<1, 16, 1, 8>, S<4, 4, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1,
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -565,6 +566,87 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
|
||||
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
|
||||
|
||||
template <bool IsGfx11>
|
||||
static constexpr index_t GetEstimateVgprCount()
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
|
||||
// VGPR used in LDS loading and WMMA
|
||||
constexpr index_t BaseInputVgprCount =
|
||||
MPerBlock * KPerBlock / MWave / WaveSize * sizeof(ComputeTypeA) / sizeof(uint32_t) +
|
||||
NPerBlock * KPerBlock / NWave / WaveSize * sizeof(ComputeTypeB) / sizeof(uint32_t);
|
||||
// WMMA input is duplicated in GFX11
|
||||
constexpr index_t InputVgprCount = IsGfx11 ? BaseInputVgprCount * 2 : BaseInputVgprCount;
|
||||
// VGPR used in Accumulator
|
||||
constexpr index_t AccVgprCount =
|
||||
MPerBlock * NPerBlock / BlockSize * sizeof(AccDataType) / sizeof(uint32_t);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
return InputVgprCount + AccVgprCount;
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
return InputVgprCount * 2 + AccVgprCount;
|
||||
}
|
||||
else
|
||||
{
|
||||
// invalid pipeline version
|
||||
static_assert(0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile-time gate telling whether this instantiation is worth compiling for the
/// architecture currently being targeted.
///
/// The estimated VGPR usage is compared against the register budget selected by the
/// target macros: 256 when `__gfx11__` or `__gfx12__` is defined, 512 otherwise. An
/// estimate of up to 25% above the budget is still accepted.
///
/// @return true when the estimate is within 125% of the budget, false otherwise.
__device__ static bool constexpr IsValidCompilationParameter()
{
    // Only GFX11 uses the estimate variant that doubles the WMMA input registers.
#if defined(__gfx11__)
    constexpr index_t VgprEstimate = GetEstimateVgprCount<true>();
#else
    constexpr index_t VgprEstimate = GetEstimateVgprCount<false>();
#endif

#if defined(__gfx11__) || defined(__gfx12__)
    constexpr index_t VgprBudget = 256;
#else
    constexpr index_t VgprBudget = 512;
#endif

    // Allow the estimate to exceed the budget by a quarter before rejecting.
    constexpr index_t VgprLimit = VgprBudget + VgprBudget / 4;

    return !(VgprEstimate > VgprLimit);
}
|
||||
|
||||
template <typename Argument>
|
||||
__host__ static bool CheckValidity(const Argument& karg, bool allow_short_v3_pipe = false)
|
||||
{
|
||||
const auto availableVgprCount = []() {
|
||||
if(ck::is_gfx12_supported())
|
||||
{
|
||||
return 256;
|
||||
}
|
||||
else if(ck::is_gfx11_supported())
|
||||
{
|
||||
return 256;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 512;
|
||||
}
|
||||
}();
|
||||
const auto estimateVgprCount =
|
||||
ck::is_gfx11_supported() ? GetEstimateVgprCount<true>() : GetEstimateVgprCount<false>();
|
||||
if(estimateVgprCount > (availableVgprCount + availableVgprCount / 4))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return Base::template CheckValidity<Argument>(karg, allow_short_v3_pipe);
|
||||
}
|
||||
/// Number of K blocks per scale handed to `Base::Run` as `num_k_block_per_scale`.
/// This gridwise GEMM always reports 1; the visible callers pair it with empty scale structs.
__device__ static index_t GetKBlockPerScale()
{
    constexpr index_t kBlocksPerScale = 1;
    return kBlocksPerScale;
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -588,74 +670,81 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
const index_t A_k_id = 0,
|
||||
const index_t B_k_id = 0)
|
||||
{
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
|
||||
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
|
||||
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
|
||||
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
if constexpr(IsValidCompilationParameter())
|
||||
{
|
||||
return;
|
||||
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(problem.M,
|
||||
problem.MPadded,
|
||||
problem.K,
|
||||
problem.KPadded,
|
||||
problem.StrideAs,
|
||||
problem.AK0);
|
||||
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
|
||||
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
|
||||
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideE);
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
|
||||
const index_t block_n_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
|
||||
|
||||
// BScale struct (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = Scale{};
|
||||
auto b_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
const index_t block_m_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
|
||||
const index_t block_n_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
|
||||
|
||||
// BScale struct (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = Scale{};
|
||||
auto b_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
A_k_id,
|
||||
B_k_id);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
@@ -937,103 +1026,106 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
|
||||
});
|
||||
|
||||
const auto ds_grid_desc_m_n =
|
||||
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
|
||||
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
if constexpr(IsValidCompilationParameter())
|
||||
{
|
||||
return;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
|
||||
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset;
|
||||
});
|
||||
|
||||
const auto ds_grid_desc_m_n =
|
||||
MakeDsGridDescriptor_M_N(karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, karg.MBlock, karg.NBlock);
|
||||
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// Scale structs (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = Scale{};
|
||||
auto a_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_idx,
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// Scale structs (Empty)
|
||||
using Scale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = Scale{};
|
||||
auto a_scale_struct = Scale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + e_batch_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args,
|
||||
k_idx,
|
||||
k_idx,
|
||||
karg.KBatch);
|
||||
}
|
||||
|
||||
// Run method for convolution fwd (grid descriptors are passed as arguments,
|
||||
@@ -1061,117 +1153,120 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
|
||||
// offset base pointer for each work-group
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) =
|
||||
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) =
|
||||
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
|
||||
});
|
||||
|
||||
DsGridPointer p_ds_grid_grp;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
|
||||
ds_batch_offset[i] + ds_n_offset[i];
|
||||
});
|
||||
|
||||
// Currently supporting one A and one B
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
if constexpr(IsValidCompilationParameter())
|
||||
{
|
||||
return;
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / karg.KBatch);
|
||||
// offset base pointer for each work-group
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t b_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetBPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
const auto ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
|
||||
|
||||
AsGridPointer p_as_grid_;
|
||||
static_for<0, NumATensor, 1>{}([&](auto i) {
|
||||
using ADataType_ = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
|
||||
p_as_grid_(i) =
|
||||
static_cast<const ADataType_*>(karg.p_as_grid[i]) + a_batch_offset + a_n_offset;
|
||||
});
|
||||
|
||||
BsGridPointer p_bs_grid_;
|
||||
static_for<0, NumBTensor, 1>{}([&](auto i) {
|
||||
using BDataType_ = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
|
||||
p_bs_grid_(i) =
|
||||
static_cast<const BDataType_*>(karg.p_bs_grid[i]) + b_batch_offset + b_n_offset;
|
||||
});
|
||||
|
||||
DsGridPointer p_ds_grid_grp;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
p_ds_grid_grp(i) = static_cast<const DDataType_*>(karg.p_ds_grid[i]) +
|
||||
ds_batch_offset[i] + ds_n_offset[i];
|
||||
});
|
||||
|
||||
// Currently supporting one A and one B
|
||||
const auto as_grid_desc_ak0_m_ak1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
},
|
||||
Number<NumATensor>{});
|
||||
|
||||
const auto bs_grid_desc_bk0_n_bk1 = generate_tuple(
|
||||
[&](auto i) {
|
||||
ignore = i;
|
||||
return b_grid_desc_bk0_n_bk1;
|
||||
},
|
||||
Number<NumBTensor>{});
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// AScale struct (Empty)
|
||||
using AScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = AScale{};
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
p_ds_grid_grp,
|
||||
karg.p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
|
||||
// AScale struct (Empty)
|
||||
using AScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto a_scale_struct = AScale{};
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
auto b_scale_struct = BScale{};
|
||||
|
||||
const index_t num_k_block_per_scale = GetKBlockPerScale();
|
||||
|
||||
Base::template Run<decltype(as_grid_desc_ak0_m_ak1),
|
||||
decltype(bs_grid_desc_bk0_n_bk1),
|
||||
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(a_scale_struct),
|
||||
decltype(b_scale_struct),
|
||||
decltype(epilogue_args),
|
||||
HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum>(p_as_grid_,
|
||||
p_bs_grid_,
|
||||
p_ds_grid_grp,
|
||||
karg.p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
as_grid_desc_ak0_m_ak1,
|
||||
bs_grid_desc_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
block_m_id,
|
||||
block_n_id,
|
||||
num_k_block_per_scale,
|
||||
a_scale_struct,
|
||||
b_scale_struct,
|
||||
epilogue_args);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user