mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
improved zeroing (#1221)
This commit is contained in:
@@ -23,8 +23,8 @@ add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
|
||||
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
|
||||
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
|
||||
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16_fp8 grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16_fp8)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
||||
|
||||
@@ -36,7 +36,7 @@ using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F32;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -55,7 +55,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -298,9 +298,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(256 + 256 * i);
|
||||
problem_size.Ns.push_back(256);
|
||||
problem_size.Ks.push_back(128);
|
||||
problem_size.Ms.push_back(128 + rand() % 128);
|
||||
problem_size.Ns.push_back(1024);
|
||||
problem_size.Ks.push_back(1024);
|
||||
|
||||
problem_size.stride_As.push_back(problem_size.Ks[i]);
|
||||
problem_size.stride_Bs.push_back(problem_size.Ks[i]);
|
||||
|
||||
@@ -35,7 +35,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ADataType = F16;
|
||||
using BDataType = F8;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
|
||||
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// clang-format on
|
||||
|
||||
struct ProblemSize final
|
||||
@@ -23,6 +23,7 @@ namespace device {
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
GemmSpecialization GemmSpec,
|
||||
bool Zeroing,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
@@ -106,33 +107,63 @@ __global__ void
|
||||
const auto block_2_etile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
|
||||
|
||||
auto barrier_count_finished =
|
||||
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
auto barrier_count_finished =
|
||||
barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
|
||||
GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout>(gemm_desc_ptr[group_id].p_a_grid,
|
||||
gemm_desc_ptr[group_id].p_b_grid,
|
||||
p_ds_grid_,
|
||||
gemm_desc_ptr[group_id].p_e_grid,
|
||||
p_shared,
|
||||
nullptr,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
block_2_etile_map);
|
||||
}
|
||||
|
||||
id_off += grid_size_grp;
|
||||
id_local += grid_size_grp;
|
||||
@@ -193,8 +224,11 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
typename ComputeType = ADataType,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler()>
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
typename ComputeType = ADataType,
|
||||
typename ALDSType = ComputeType,
|
||||
typename BLDSType = ComputeType>
|
||||
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using AComputeType = ComputeType;
|
||||
using BComputeType = ComputeType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
|
||||
ADataType, // TODO: distinguish A/B datatype
|
||||
BDataType,
|
||||
ComputeType,
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
LoopSched,
|
||||
PipelineVer,
|
||||
ALDSType,
|
||||
BLDSType>;
|
||||
|
||||
template <typename UnderlyingBlockToCTileMap>
|
||||
struct OffsettedBlockToCTileMapMLoops
|
||||
@@ -613,45 +654,85 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
float ave_time = 0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
if(arg.k_batch_ == 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
false,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
reinterpret_cast<uint32_t*>(arg.p_workspace_),
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
nullptr,
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
|
||||
GroupedGemmKernelArgument<NumDTensor>,
|
||||
GemmSpec,
|
||||
true,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
DsDataType,
|
||||
Block2ETileMap,
|
||||
GroupedGemmBlock2ETileMap,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
e_global_memory_operation_,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
|
||||
reinterpret_cast<uint32_t*>(arg.p_workspace_),
|
||||
arg.barrier_size_grp_,
|
||||
arg.gemm_desc_kernel_arg_.size(),
|
||||
arg.grid_size_grp_,
|
||||
arg.k_batch_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
|
||||
constexpr auto Set = InMemoryDataOperationEnum::Set;
|
||||
|
||||
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
|
||||
// in IsSupportedArgument function
|
||||
// For bf16 datatype only kbatch = 1 scenario is supported. This condition is
|
||||
// enforced in IsSupportedArgument function
|
||||
if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if(has_main_k_block_loop)
|
||||
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
|
||||
|
||||
bool supported = true;
|
||||
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by vector
|
||||
// load size.
|
||||
// If we use padding we do not support vector loads for dimensions not divisible by
|
||||
// vector load size.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default)
|
||||
{
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
|
||||
// thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
// [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
|
||||
// layout, thus we have to adapt it to the {M,K} or {N,K} layout.
|
||||
const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
|
||||
|
||||
|
||||
@@ -31,7 +31,8 @@ namespace ck {
|
||||
// D0, D1, ... and E have the same layout
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ComputeType,
|
||||
typename AComputeType,
|
||||
typename BComputeType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
@@ -71,7 +72,9 @@ template <typename ADataType,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched,
|
||||
PipelineVersion PipelineVer = PipelineVersion::v1>
|
||||
PipelineVersion PipelineVer,
|
||||
typename ALDSType,
|
||||
typename BLDSType>
|
||||
struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
@@ -186,8 +189,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
|
||||
sizeof(ComputeType),
|
||||
return math::max(a_block_space_size_aligned * sizeof(ALDSType) +
|
||||
b_block_space_size_aligned * sizeof(BLDSType),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -455,6 +458,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
index_t NumDTensor_,
|
||||
typename DsDataType_,
|
||||
bool Zeroing,
|
||||
typename AGridDesc_KBatch_AK0_M_AK1,
|
||||
typename BGridDesc_KBatch_BK0_N_BK1,
|
||||
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -530,7 +534,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ComputeType,
|
||||
ALDSType,
|
||||
decltype(a_grid_desc_kbatch_ak0_m_ak1),
|
||||
decltype(a_block_desc_kbatch_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -561,7 +565,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
ComputeType,
|
||||
BLDSType,
|
||||
decltype(b_grid_desc_kbatch_bk0_n_bk1),
|
||||
decltype(b_block_desc_kbatch_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -597,12 +601,12 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
ComputeType,
|
||||
ComputeType,
|
||||
ALDSType,
|
||||
BLDSType,
|
||||
AccDataType,
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
@@ -611,62 +615,65 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
LoopSched>();
|
||||
LoopSched,
|
||||
AComputeType,
|
||||
BComputeType>();
|
||||
|
||||
#if 1
|
||||
if(block_work_idx[I0] == 0)
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
const index_t numNThreads = NPerBlock / nThreadSize;
|
||||
const index_t numMThreads = BlockSize / numNThreads;
|
||||
const index_t mThreadSize = MPerBlock / numMThreads;
|
||||
|
||||
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
|
||||
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
|
||||
|
||||
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
EDataType,
|
||||
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
|
||||
true>
|
||||
e_thread_zero_buf;
|
||||
|
||||
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
EDataType,
|
||||
EDataType,
|
||||
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, mThreadSize, 1, nThreadSize>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I1],
|
||||
m_tid * mThreadSize,
|
||||
block_work_idx[I2],
|
||||
n_tid * nThreadSize),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
e_thread_zero_buf,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_buf);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
if(block_work_idx[I0] == 0)
|
||||
{
|
||||
atomicAdd(barrier_count_finished, 1);
|
||||
const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock;
|
||||
const index_t numNThreads = NPerBlock / nThreadSize;
|
||||
const index_t numMThreads = BlockSize / numNThreads;
|
||||
const index_t mThreadSize = MPerBlock / numMThreads;
|
||||
|
||||
const index_t m_tid = get_thread_local_1d_id() / numNThreads;
|
||||
const index_t n_tid = get_thread_local_1d_id() % numNThreads;
|
||||
|
||||
auto c_thread_desc_mblock_mperblock_nblock_nperblock =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1, Number<mThreadSize>{}, I1, Number<nThreadSize>{}));
|
||||
|
||||
StaticBuffer<AddressSpaceEnum::Vgpr,
|
||||
EDataType,
|
||||
c_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(),
|
||||
true>
|
||||
e_thread_zero_buf;
|
||||
|
||||
auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
|
||||
EDataType,
|
||||
EDataType,
|
||||
decltype(c_thread_desc_mblock_mperblock_nblock_nperblock),
|
||||
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
Sequence<1, mThreadSize, 1, nThreadSize>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
1,
|
||||
true>{e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_multi_index(block_work_idx[I1],
|
||||
m_tid * mThreadSize,
|
||||
block_work_idx[I2],
|
||||
n_tid * nThreadSize),
|
||||
ck::tensor_operation::element_wise::PassThrough{}};
|
||||
|
||||
c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock,
|
||||
make_tuple(I0, I0, I0, I0),
|
||||
e_thread_zero_buf,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_buf);
|
||||
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
atomicAdd(barrier_count_finished, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
|
||||
|
||||
@@ -675,10 +682,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
static_cast<ALDSType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
|
||||
static_cast<BLDSType*>(p_shared) + a_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);
|
||||
@@ -711,13 +718,15 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
|
||||
// shuffle C and write out
|
||||
{
|
||||
if(threadIdx.x == 0)
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
|
||||
}
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
|
||||
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
|
||||
"wrong!");
|
||||
@@ -951,18 +960,131 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
}
|
||||
});
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
if constexpr(Zeroing)
|
||||
{
|
||||
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
|
||||
|
||||
if(k_id_finished_t == KBatch)
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
*barrier_count_finished = 0;
|
||||
index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
|
||||
|
||||
if(k_id_finished_t == KBatch)
|
||||
{
|
||||
*barrier_count_finished = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename Block2ETileMap>
|
||||
__device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_,
|
||||
const void* __restrict__ p_b_grid_,
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
uint32_t* barrier_count_finished,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideE,
|
||||
const index_t KBatch,
|
||||
const Block2ETileMap& block_2_etile_map)
|
||||
{
|
||||
const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
|
||||
const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
|
||||
const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
|
||||
|
||||
using DsGridDesc_M_N =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
const auto a_grid_desc_kbatch_ak0_m_ak1 =
|
||||
MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
|
||||
|
||||
const auto b_grid_desc_kbatch_bk0_n_bk1 =
|
||||
MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}))>;
|
||||
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock(j) =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
if(kbatch_id == KBatch - 1)
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>, true>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
@@ -976,7 +1098,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
DsGridPointer p_ds_grid,
|
||||
void* __restrict__ p_e_grid_,
|
||||
void* __restrict__ p_shared,
|
||||
uint32_t* barrier_count_finished,
|
||||
uint32_t*,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
@@ -1028,49 +1150,22 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
|
||||
if(kbatch_id == KBatch - 1)
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
else
|
||||
{
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
barrier_count_finished,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType, false>(
|
||||
p_a_grid,
|
||||
p_b_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
nullptr,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
a_grid_desc_kbatch_ak0_m_ak1,
|
||||
b_grid_desc_kbatch_bk0_n_bk1,
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_etile_map);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user