MNKO padding support on bmm+masking+scale+softmax+bmm+permute (#425)
* add lower triangle bmm
* init code for tile skipping
* get functionality right with lower triangle mask
* add decoder lower triangular mask calculation
* use 7*13 group
* fix n2 compute error
* attention with lower triangle mask with tile skipping
* add template to distinguish masking kernel
* rename template and remove default template value
* remove lower triangle gemm reference struct
* add some comments on example
* add 10 instances for masking bmm + scale + softmax + bmm + permute kernels
* add test
* add test file
* add gtest for bmm masking scale softmax bmm permute
* clang-format
* fix compile error
* check left bottom corner for tile skipping
* fix error: check left bottom corner for tile skipping
* add K padding
* add test and instance for MNK padding
* pass a mask struct
* fix instances
* delete unused comments
* format

Co-authored-by: danyao12 <yaodan@dc-smc-13.amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
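For context, the masking wired through these kernels follows the usual decoder (causal) attention rule: before the softmax between the two GEMMs, any score whose column n exceeds its row m (upper triangle), or whose n lies at or beyond the unpadded NRaw extent, is forced to -inf, so it contributes exp(-inf) = 0 to its row's sum. A minimal host-side sketch of that reference behavior (plain C++, not the kernel code; names are illustrative):

#include <cmath>
#include <vector>

// Reference semantics only: 'scores' is an m_rows x n_cols row-major matrix
// of scaled QK^T values, n_cols possibly padded up from the raw extent n_raw.
// Masked entries become -inf so softmax assigns them zero probability.
// Assumes every row keeps at least one unmasked entry (the diagonal).
void masked_softmax_rows(std::vector<float>& scores, int m_rows, int n_cols, int n_raw)
{
    for(int m = 0; m < m_rows; ++m)
    {
        float row_max = -INFINITY;
        for(int n = 0; n < n_cols; ++n)
        {
            float& s = scores[m * n_cols + n];
            if(n > m || n >= n_raw) // upper triangle or N padding
                s = -INFINITY;
            row_max = std::fmax(row_max, s);
        }
        float sum = 0.f;
        for(int n = 0; n < n_cols; ++n)
        {
            float& s = scores[m * n_cols + n];
            s = std::exp(s - row_max); // exp(-inf) == 0 for masked entries
            sum += s;
        }
        for(int n = 0; n < n_cols; ++n)
            scores[m * n_cols + n] /= sum;
    }
}

The diff below threads exactly this predicate (as C0MatrixMask) from the device ops into the gridwise kernel.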
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
+          typename C0MatrixMask,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
@@ -57,7 +58,8 @@ __global__ void
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+            const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -88,7 +90,8 @@ __global__ void
         b_grid_desc_bk0_n_bk1,
         b1_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
-        block_2_ctile_map);
+        block_2_ctile_map,
+        c0_matrix_mask);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -106,6 +109,7 @@ __global__ void
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
+    ignore = c0_matrix_mask;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
@@ -168,6 +172,7 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          bool MaskOutUpperTriangle,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     : public DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
@@ -194,9 +199,6 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
             MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
 
-    // FIXME: pad K
-    static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
-
     static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
@@ -398,6 +400,29 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
     using CGridDesc_M_N   = decltype(MakeCGridDescriptor_M_N({}, {}));
     using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
 
+    // to track the points which need to be set to -inf on C0
+    // Note: no need to reset M padding value, because they will not be stored out.
+    struct C0MatrixMask
+    {
+        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
+
+        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
+
+        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
+        {
+            return n >= NRaw_;
+        }
+
+        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
+        {
+            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
+        }
+
+        private:
+        // index_t MRaw_;
+        index_t NRaw_;
+    };
+
     struct ComputeBasePtrOfStridedBatch
     {
         ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
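The same C0MatrixMask helper is added to all three device ops in this commit, and since its predicates are plain index logic they can be sanity-checked on the host. A minimal sketch (the __host__ __device__ qualifiers dropped so it builds as ordinary C++; NRaw = 100 is an arbitrary example):

#include <cassert>

using index_t = int; // stand-in for ck::index_t in this sketch

struct C0MatrixMask
{
    C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}

    bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
    bool IsNOutOfBound(index_t n) const { return n >= NRaw_; }
    bool IsMaskedElement(index_t m, index_t n) const
    {
        return IsUpperTriangle(m, n) || IsNOutOfBound(n);
    }

    private:
    index_t NRaw_;
};

int main()
{
    const C0MatrixMask mask(100); // raw (unpadded) N extent of 100

    assert(!mask.IsMaskedElement(5, 5));    // diagonal survives
    assert(!mask.IsMaskedElement(5, 0));    // lower triangle survives
    assert(mask.IsMaskedElement(5, 6));     // upper triangle -> -inf
    assert(mask.IsMaskedElement(200, 100)); // below diagonal but in N padding -> -inf
    return 0;
}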
@@ -498,7 +523,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN>;
+        matrix_padder.PadN,
+        MaskOutUpperTriangle>;
 
     // Argument
     // FIXME: constness
@@ -548,6 +574,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
               batch_count_(Batch),
               compute_base_ptr_of_batch_{
                   BatchStrideA, BatchStrideB, BatchStrideB1, c_grid_desc_g_m_n_},
+              c0_matrix_mask_{NRaw},
               raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw},
               c_extent_lowest_{c_gs_ms_gemm1ns_lengths.back()},
               c_stride_lowest_{c_gs_ms_gemm1ns_strides.back()}
@@ -585,6 +612,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
 
+        // check C0 masking and padding
+        C0MatrixMask c0_matrix_mask_;
+
         // For robust IsSupportedArgument() check
         std::vector<index_t> raw_lengths_m_n_k_o_;
         index_t c_extent_lowest_;
@@ -632,6 +662,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
+                C0MatrixMask,
                 has_main_k_block_loop_>;
 
             return launch_and_time_kernel(stream_config,
@@ -654,7 +685,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
-                arg.compute_base_ptr_of_batch_);
+                arg.compute_base_ptr_of_batch_,
+                arg.c0_matrix_mask_);
         };
 
     // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -35,6 +35,7 @@ template <typename GridwiseGemm,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
           typename ComputeBasePtrOfStridedBatch,
+          typename C0MatrixMask,
           bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
@@ -57,7 +58,8 @@ __global__ void
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             const Block2CTileMap block_2_ctile_map,
             const index_t batch_count,
-            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
+            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
+            const C0MatrixMask c0_matrix_mask)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -88,7 +90,8 @@ __global__ void
         b_grid_desc_bk0_n_bk1,
         b1_grid_desc_bk0_n_bk1,
         c_grid_desc_mblock_mperblock_nblock_nperblock,
-        block_2_ctile_map);
+        block_2_ctile_map,
+        c0_matrix_mask);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -106,6 +109,7 @@ __global__ void
     ignore = block_2_ctile_map;
     ignore = batch_count;
     ignore = compute_base_ptr_of_batch;
+    ignore = c0_matrix_mask;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
@@ -177,6 +181,7 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          bool MaskOutUpperTriangle,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
     : public DeviceBatchedGemmSoftmaxGemm<ALayout,
@@ -203,9 +208,6 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
             MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
 
-    // FIXME: pad K
-    static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
-
     static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
@@ -313,6 +315,29 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
     }
 
+    // to track the points which need to be set to -inf on C0
+    // Note: no need to reset M padding value, because they will not be stored out.
+    struct C0MatrixMask
+    {
+        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
+
+        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
+
+        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
+        {
+            return n >= NRaw_;
+        }
+
+        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
+        {
+            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
+        }
+
+        private:
+        // index_t MRaw_;
+        index_t NRaw_;
+    };
+
     struct ComputeBasePtrOfStridedBatch
     {
         ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
@@ -418,7 +443,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN>;
+        matrix_padder.PadN,
+        MaskOutUpperTriangle>;
 
     // Argument
     struct Argument : public BaseArgument
@@ -463,6 +489,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
               c_element_op_{c_element_op},
               batch_count_(Batch),
               compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC},
+              c0_matrix_mask_{NRaw},
               raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
@@ -497,6 +524,9 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
         index_t batch_count_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
 
+        // check C0 masking and padding
+        C0MatrixMask c0_matrix_mask_;
+
         // For robust IsSupportedArgument() check
         std::vector<index_t> raw_lengths_m_n_k_o_;
     };
@@ -542,6 +572,7 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                 typename GridwiseGemm::DefaultBlock2CTileMap,
                 ComputeBasePtrOfStridedBatch,
+                C0MatrixMask,
                 has_main_k_block_loop_>;
 
             return launch_and_time_kernel(stream_config,
@@ -564,7 +595,8 @@ struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                 arg.block_2_ctile_map_,
                 arg.batch_count_,
-                arg.compute_base_ptr_of_batch_);
+                arg.compute_base_ptr_of_batch_,
+                arg.c0_matrix_mask_);
         };
 
     // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
@@ -98,7 +98,8 @@ __global__ void
             arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
             arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-            arg_ptr[group_id].block_2_ctile_map_);
+            arg_ptr[group_id].block_2_ctile_map_,
+            arg_ptr[group_id].c0_matrix_mask_);
 #else
     ignore = group_kernel_args;
     ignore = group_count;
@@ -169,6 +170,7 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          bool MaskOutUpperTriangle,
           LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     : public DeviceGroupedGemmSoftmaxGemmPermute<ALayout,
@@ -209,9 +211,6 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         GemmGemmPadder<GemmSpec, index_t, index_t, index_t, index_t>{
             MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
 
-    // FIXME: pad K
-    static_assert(!matrix_padder.PadK, "KPadding is currently not supported");
-
     static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
@@ -413,6 +412,29 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
     using CGridDesc_M_N   = decltype(MakeCGridDescriptor_M_N({}, {}));
     using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {}));
 
+    // to track the points which need to be set to -inf on C0
+    // Note: no need to reset M padding value, because they will not be stored out.
+    struct C0MatrixMask
+    {
+        C0MatrixMask(index_t NRaw) : NRaw_(NRaw) {}
+
+        __host__ __device__ bool IsUpperTriangle(index_t m, index_t n) const { return n > m; }
+
+        __host__ __device__ bool IsNOutOfBound(/*index_t m, */ index_t n) const
+        {
+            return n >= NRaw_;
+        }
+
+        __host__ __device__ bool IsMaskedElement(index_t m, index_t n) const
+        {
+            return IsUpperTriangle(m, n) || IsNOutOfBound(n);
+        }
+
+        private:
+        // index_t MRaw_;
+        index_t NRaw_;
+    };
+
     struct ComputeBasePtrOfStridedBatch
     {
         ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
@@ -513,7 +535,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
         CShuffleBlockTransferScalarPerVector_NPerBlock,
         LoopSched,
-        matrix_padder.PadN>;
+        matrix_padder.PadN,
+        MaskOutUpperTriangle>;
 
     using Block2CTileMap = OffsettedBlockToCTileMap<typename GridwiseGemm::DefaultBlock2CTileMap>;
 
@@ -536,6 +559,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
         index_t num_blocks_per_batch_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
 
+        // check C0 masking and padding
+        C0MatrixMask c0_matrix_mask_;
+
         // block-to-c-tile map
         Block2CTileMap block_2_ctile_map_;
 
@@ -623,6 +649,9 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                     problem_desc_vec[i].BatchStrideB1,
                     c_grid_desc_g_m_n);
 
+                // C0 mask
+                const auto c0_matrix_mask = C0MatrixMask(problem_desc_vec[i].N);
+
                 grid_size_ += grid_size_grp;
 
                 group_kernel_args_.push_back({p_a_grid,
@@ -635,6 +664,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
                                               c_grid_desc_mblock_mperblock_nblock_nperblock,
                                               block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
                                               compute_base_ptr_of_batch,
+                                              c0_matrix_mask,
                                               block_2_ctile_map,
                                               BlockStart,
                                               BlockEnd});
@@ -76,7 +76,8 @@ template <typename FloatAB,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          bool PadN>
+          bool PadN,
+          bool MaskOutUpperTriangle>
 struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
@@ -97,6 +98,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
+
+    static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
+    static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
+
     // Gemm1
     static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
     static constexpr auto B1K1 = Number<B1K1Value>{};
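The new Gemm0MWaves/Gemm0NWaves constants are plain tile arithmetic: how many waves tile the block along M and along N. With illustrative numbers (not a tile configuration taken from this commit):

// Illustrative values only: a 128x128 block tile computed with 32x32 xdlops
// ops, each wave repeating 2x along M and 4x along N.
constexpr int MPerBlock = 128, MPerXdl = 32, MXdlPerWave = 2;
constexpr int NPerBlock = 128, NPerXdl = 32, NXdlPerWave = 4;

constexpr int Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); // 128 / 64  = 2
constexpr int Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); // 128 / 128 = 1
static_assert(Gemm0MWaves * Gemm0NWaves == 2, "waves covering one gemm0 block tile");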
@@ -361,7 +366,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };
 
-    template <bool HasMainKBlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                const FloatAB* __restrict__ p_b1_grid,
@@ -377,22 +382,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map)
+                               const Block2CTileMap& block_2_ctile_map,
+                               const C0MatrixMask& c0_matrix_mask)
     {
-        const auto a_grid_buf =
-            conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_a_grid,
-                                       a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
-                                       NumericLimits<FloatAB>::QuietNaN()),
-                                   make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
-        const auto b_grid_buf =
-            conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_b_grid,
-                                       b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
-                                       NumericLimits<FloatAB>::QuietNaN()),
-                                   make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
+        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
+        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
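This hunk also retires the earlier PadN workaround, in which padded global loads were filled with QuietNaN and an ElementOpPredicatedResetNaNToMinusInf pass later rewrote those NaNs to -inf; with C0MatrixMask the out-of-bound columns are instead rejected by index (the IsNOutOfBound branch in the masking loop further down). A sketch contrasting the two ideas (plain C++, illustrative helper names):

#include <cmath>

// Old approach (removed): rely on a NaN sentinel written by padded loads,
// then detect it after gemm0 and rewrite the accumulator to -inf.
inline float reset_nan_to_minus_inf(float acc)
{
    return std::isnan(acc) ? -INFINITY : acc;
}

// New approach: decide by index; no sentinel value in the data path.
inline float mask_by_index(float acc, int n_global, int n_raw)
{
    return (n_global >= n_raw) ? -INFINITY : acc;
}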
@@ -749,10 +745,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max     = NumericLimits<FloatGemmAcc>::Lowest();
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
 
+        // decoder lower triangular mask
+        const auto thread_cluster_idx =
+            threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_thread_local_1d_id()));
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = thread_cluster_idx[I1];
+        const index_t MPerRepeat       = MPerBlock / MXdlPerWave;
+        const index_t NPerRepeat       = NPerBlock / NXdlPerWave;
+        const index_t mstart           = m_block_data_idx_on_grid + thread_m_cluster_id;
+
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
         {
+            if constexpr(MaskOutUpperTriangle)
+            {
+                auto gemm0_n_block_idx =
+                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+                if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
+                   c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
+                                                  gemm0_n_block_idx))
+                {
+                    continue;
+                }
+            }
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                                    a_block_desc_ak0_m_ak1,
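The skip test above only needs the tile's left edge: under the causal mask n > m, the bottom-left element of a score tile has the tile's largest m and smallest n, so if even that element is in the upper triangle, every element of the tile is, and gemm0 plus the softmax update can be skipped. The kernel checks both left corners, but the bottom-left test alone is the binding one. A host-side restatement of the predicate, as a sketch under the same MPerBlock tiling:

// True iff the whole MPerBlock-row score tile starting at (m_start, n_start)
// falls in the masked upper triangle, i.e. the tile can be skipped.
inline bool tile_fully_masked(int m_start, int n_start, int m_per_block)
{
    const int m_bottom = m_start + m_per_block - 1; // largest row index in tile
    return n_start > m_bottom;                      // IsUpperTriangle(m_bottom, n_start)
}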
@@ -770,16 +786,63 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                    acc_thread_buf,
                                                                    num_k_block_main_loop);
 
-            // Acc0 elementwise Op
-#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-            static_for<0, acc_thread_buf.Size(), 1>{}(
-                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
-#else
-            static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
-                ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
-                    acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
-            });
-#endif
+            // do MNK padding or upper triangular masking
+            if constexpr(MaskOutUpperTriangle || PadN)
+            {
+                const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
+
+                static_for<0, m0, 1>{}([&](auto m0_i) {
+                    const index_t m_global   = mstart + m0_i * MPerRepeat;
+                    const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
+                    static_for<0, n0, 1>{}([&](auto n0_i) {
+                        // constexpr auto nrepeat_i = n0_i * NPerRepeat;
+                        // const index_t nstartxdl = nstart + nrepeat_i;
+                        const index_t nstartxdl  = nstart + n0_i * NPerRepeat;
+                        const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
+                        static_for<0, n2, 1>{}([&](auto n2_i) {
+                            const index_t nstartgroup =
+                                nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
+                            const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
+                            static_for<0, n4, 1>{}([&](auto n4_i) {
+                                const index_t n_global = nstartgroup + n4_i;
+                                const auto acc_offset  = Number<acc_idx_n2 + n4_i>{};
+                                if constexpr(MaskOutUpperTriangle)
+                                {
+                                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+                                    }
+                                }
+                                else
+                                {
+                                    // ignore m_global;
+                                    if(c0_matrix_mask.IsNOutOfBound(n_global))
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+                                    }
+                                }
+                            });
+                        });
+                    });
+                });
+            }
+            else
+            {
+                static_for<0, acc_thread_buf.Size(), 1>{}(
+                    [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
+            }
 
             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
 
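The nested static_for above keeps two coordinates in lockstep for every accumulator element: its flat offset in acc_thread_buf and the global (m, n) it represents in the score matrix. A scalar sketch of that index bookkeeping (the m0/n0/n2/n4 extents and AccN3 come from the xdlops accumulator layout and are taken as inputs here):

// Sketch of the index arithmetic in the masking loops above. For each
// register element, acc_offset is its flat position in acc_thread_buf and
// (m_global, n_global) is the score-matrix coordinate it holds.
void visit_acc_elements(int m0, int n0, int n2, int n4, int AccN3,
                        int mstart, int nstart, int MPerRepeat, int NPerRepeat,
                        int thread_n_cluster_id)
{
    for(int m0_i = 0; m0_i < m0; ++m0_i)
        for(int n0_i = 0; n0_i < n0; ++n0_i)
            for(int n2_i = 0; n2_i < n2; ++n2_i)
                for(int n4_i = 0; n4_i < n4; ++n4_i)
                {
                    const int acc_offset =
                        ((m0_i * n0 + n0_i) * n2 + n2_i) * n4 + n4_i;
                    const int m_global = mstart + m0_i * MPerRepeat;
                    const int n_global = nstart + n0_i * NPerRepeat +
                                         thread_n_cluster_id * n4 +
                                         n2_i * AccN3 * n4 + n4_i;
                    // ...mask or apply acc_element_op at (m_global, n_global)
                    (void)acc_offset;
                    (void)m_global;
                    (void)n_global;
                }
}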