mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
MNKO padding support on bmm+masking+scale+softmax+bmm+permute (#425)
* add lower triangle bmm * init code for tile skipping * functionality right with lower triangle mask * add decoder lower triangular mask calculation * use 7*13 group * fix n2 compute error * attention with lower triangle mask with tile skipping * add template to distinguish masking kernel * rename template and remove default template value * remove lower triangle gemm reference struct * add some comments on example * add 10 instance for masking bmm + scale + softmax + bmm + permute kernels * add test * add test file * add gtest for bmm masking scale softmax bmm permute * clang-format * fix compile error * check lef bottom corner for tile skipping * fix error: check left bottom corner for tile skipping * add k padding * add test and instance for MNK padding * passing a mask struct * fix instances * delete used comments * format Co-authored-by: danyao12 <yaodan@dc-smc-13.amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMap,
|
||||
typename ComputeBasePtrOfStridedBatch,
|
||||
typename C0MatrixMask,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -57,7 +58,8 @@ __global__ void
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -88,7 +90,8 @@ __global__ void
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
c0_matrix_mask);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
@@ -106,6 +109,7 @@ __global__ void
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
ignore = c0_matrix_mask;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -168,6 +172,7 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool MaskOutUpperTriangle,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
|
||||
@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
|
||||
// FIXME: pad K
|
||||
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
|
||||
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
|
||||
|
||||
// to track the points which need to be set to -inf on C0
|
||||
// Note: no need to reset M padding value, because they will not be stored out.
|
||||
// Mask predicate over (m, n) coordinates of the C0 matrix. An element is
// reported as masked when n > m or when n >= NRaw (the value given at
// construction). Only the N bound is stored; no M bound is kept.
struct C0MatrixMask
|
||||
{
|
||||
// Stores NRaw, the bound later compared against in IsNOutOfBound().
// Unlike the predicates below, this constructor carries no
// __host__ __device__ qualifiers and is not declared explicit.
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
|
||||
|
||||
// True when n is strictly greater than m; n == m is not treated as upper triangle.
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
|
||||
|
||||
// True when n is at or beyond the stored NRaw_ bound. m is not consulted
// (its parameter is commented out in the signature).
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
|
||||
{
|
||||
return n >= NRaw_;
|
||||
}
|
||||
|
||||
// True when either predicate above holds for (m, n).
// NOTE(review): the upper-triangle test is applied unconditionally here;
// presumably callers that do not want triangular masking (see the
// MaskOutUpperTriangle template flag) call IsNOutOfBound() directly - confirm
// against the gridwise kernel.
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
|
||||
{
|
||||
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
|
||||
}
|
||||
|
||||
private:
|
||||
// An M bound is intentionally left disabled; only the N bound is a member.
// index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
};
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN>;
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
// Argument
|
||||
// FIXME: constness
|
||||
@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{
|
||||
BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
|
||||
c0_matrix_mask_{NRaw},
|
||||
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
|
||||
c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
|
||||
c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
|
||||
@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
|
||||
// check C0 masking and padding
|
||||
C0MatrixMask c0_matrix_mask_;
|
||||
|
||||
// For robust IsSupportedArgument() check
|
||||
std::vector<index_t> raw_lengths_m_n_k_o_;
|
||||
index_t c_extent_lowest_;
|
||||
@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
C0MatrixMask,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.batch_count_,
|
||||
arg.compute_base_ptr_of_batch_);
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.c0_matrix_mask_);
|
||||
};
|
||||
|
||||
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
|
||||
|
||||
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
|
||||
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMap,
|
||||
typename ComputeBasePtrOfStridedBatch,
|
||||
typename C0MatrixMask,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -57,7 +58,8 @@ __global__ void
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
|
||||
const C0MatrixMask c0_matrix_mask)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
@@ -88,7 +90,8 @@ __global__ void
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map);
|
||||
block_2_ctile_map,
|
||||
c0_matrix_mask);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
@@ -106,6 +109,7 @@ __global__ void
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
ignore = c0_matrix_mask;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
@@ -177,6 +181,7 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool MaskOutUpperTriangle,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmSoftmaxGemm<ALayout,
|
||||
@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
|
||||
// FIXME: pad K
|
||||
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
// to track the points which need to be set to -inf on C0
|
||||
// Note: no need to reset M padding value, because they will not be stored out.
|
||||
// Mask predicate over (m, n) coordinates of the C0 matrix. An element is
// reported as masked when n > m or when n >= NRaw (the value given at
// construction). Only the N bound is stored; no M bound is kept.
struct C0MatrixMask
|
||||
{
|
||||
// Stores NRaw, the bound later compared against in IsNOutOfBound().
// Unlike the predicates below, this constructor carries no
// __host__ __device__ qualifiers and is not declared explicit.
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
|
||||
|
||||
// True when n is strictly greater than m; n == m is not treated as upper triangle.
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
|
||||
|
||||
// True when n is at or beyond the stored NRaw_ bound. m is not consulted
// (its parameter is commented out in the signature).
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
|
||||
{
|
||||
return n >= NRaw_;
|
||||
}
|
||||
|
||||
// True when either predicate above holds for (m, n).
// NOTE(review): the upper-triangle test is applied unconditionally here;
// presumably callers that do not want triangular masking (see the
// MaskOutUpperTriangle template flag) call IsNOutOfBound() directly - confirm
// against the gridwise kernel.
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
|
||||
{
|
||||
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
|
||||
}
|
||||
|
||||
private:
|
||||
// An M bound is intentionally left disabled; only the N bound is a member.
// index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
};
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN>;
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
c_element_op_{c_element_op},
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
|
||||
c0_matrix_mask_{NRaw},
|
||||
raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
|
||||
{
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
|
||||
@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
|
||||
// check C0 masking and padding
|
||||
C0MatrixMask c0_matrix_mask_;
|
||||
|
||||
// For robust IsSupportedArgument() check
|
||||
std::vector<index_t> raw_lengths_m_n_k_o_;
|
||||
};
|
||||
@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2CTileMap,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
C0MatrixMask,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.batch_count_,
|
||||
arg.compute_base_ptr_of_batch_);
|
||||
arg.compute_base_ptr_of_batch_,
|
||||
arg.c0_matrix_mask_);
|
||||
};
|
||||
|
||||
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
|
||||
|
||||
@@ -98,7 +98,8 @@ __global__ void
|
||||
arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
|
||||
arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
|
||||
arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg_ptr[group_id].block_2_ctile_map_);
|
||||
arg_ptr[group_id].block_2_ctile_map_,
|
||||
arg_ptr[group_id].c0_matrix_mask_);
|
||||
#else
|
||||
ignore = group_kernel_args;
|
||||
ignore = group_count;
|
||||
@@ -169,6 +170,7 @@ template <typename ALayout,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
bool MaskOutUpperTriangle,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
: public DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
|
||||
@@ -209,9 +211,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
|
||||
MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
|
||||
|
||||
// FIXME: pad K
|
||||
static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
|
||||
|
||||
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
|
||||
{
|
||||
const auto a_grid_desc_mraw_kraw = [&]() {
|
||||
@@ -413,6 +412,29 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {}));
|
||||
using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
|
||||
|
||||
// to track the points which need to be set to -inf on C0
|
||||
// Note: no need to reset M padding value, because they will not be stored out.
|
||||
// Mask predicate over (m, n) coordinates of the C0 matrix. An element is
// reported as masked when n > m or when n >= NRaw (the value given at
// construction). Only the N bound is stored; no M bound is kept.
struct C0MatrixMask
|
||||
{
|
||||
// Stores NRaw, the bound later compared against in IsNOutOfBound().
// Unlike the predicates below, this constructor carries no
// __host__ __device__ qualifiers and is not declared explicit.
C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
|
||||
|
||||
// True when n is strictly greater than m; n == m is not treated as upper triangle.
__host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
|
||||
|
||||
// True when n is at or beyond the stored NRaw_ bound. m is not consulted
// (its parameter is commented out in the signature).
__host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
|
||||
{
|
||||
return n >= NRaw_;
|
||||
}
|
||||
|
||||
// True when either predicate above holds for (m, n).
// NOTE(review): the upper-triangle test is applied unconditionally here;
// presumably callers that do not want triangular masking (see the
// MaskOutUpperTriangle template flag) call IsNOutOfBound() directly - confirm
// against the gridwise kernel.
__host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
|
||||
{
|
||||
return IsUpperTriangle(m, n) || IsNOutOfBound(n);
|
||||
}
|
||||
|
||||
private:
|
||||
// An M bound is intentionally left disabled; only the N bound is a member.
// index_t MRaw_;
|
||||
index_t NRaw_;
|
||||
};
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
@@ -513,7 +535,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched,
|
||||
matrix_padder.PadN>;
|
||||
matrix_padder.PadN,
|
||||
MaskOutUpperTriangle>;
|
||||
|
||||
using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
|
||||
|
||||
@@ -536,6 +559,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
index_t num_blocks_per_batch_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
|
||||
// check C0 masking and padding
|
||||
C0MatrixMask c0_matrix_mask_;
|
||||
|
||||
// block-to-c-tile map
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
@@ -623,6 +649,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
problem_desc_vec[i].BatchStrideB1,
|
||||
c_grid_desc_g_m_n);
|
||||
|
||||
// C0 mask
|
||||
const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
group_kernel_args_.push_back({p_a_grid,
|
||||
@@ -635,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
|
||||
compute_base_ptr_of_batch,
|
||||
c0_matrix_mask,
|
||||
block_2_ctile_map,
|
||||
BlockStart,
|
||||
BlockEnd});
|
||||
|
||||
Reference in New Issue
Block a user