mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Merging the gfx12 code into public repo. (#1362)
This commit is contained in:
@@ -371,12 +371,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
||||
if constexpr(B0EnableLds)
|
||||
{
|
||||
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
|
||||
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
B0BlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
@@ -428,12 +432,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
||||
if constexpr(B1EnableLds)
|
||||
{
|
||||
// BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
|
||||
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_LRow = I2;
|
||||
#else
|
||||
constexpr auto B_LRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
B1BlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_L1>{})),
|
||||
|
||||
@@ -50,7 +50,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -302,12 +302,16 @@ struct GridwiseFpAintBGemm_Wmma
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
@@ -360,12 +364,16 @@ struct GridwiseFpAintBGemm_Wmma
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
|
||||
@@ -54,7 +54,7 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
@@ -147,7 +147,7 @@ __global__ void
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const Block2CTileMap block_2_etile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
// printf("entry kernel launch");
|
||||
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
|
||||
|
||||
@@ -237,7 +237,7 @@ __global__ void
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
|
||||
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto B_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
|
||||
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -495,12 +497,16 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
@@ -534,12 +540,16 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
// *Caution Here repeat is shuffle repeat
|
||||
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
|
||||
{
|
||||
constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
|
||||
constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
|
||||
make_naive_tensor_descriptor_packed(
|
||||
make_tuple(I1,
|
||||
Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
|
||||
Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
|
||||
I1,
|
||||
Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
|
||||
Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));
|
||||
|
||||
return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
|
||||
}
|
||||
@@ -799,8 +806,9 @@ struct GridwiseGemmMultipleD_Wmma
|
||||
const auto M = e_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = e_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
|
||||
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
e_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
|
||||
@@ -45,7 +45,7 @@ __global__ void
|
||||
const CElementwiseOperation c_element_op,
|
||||
const Block2CTileMap block_2_ctile_map)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
|
||||
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto A_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / A_KRow / K1;
|
||||
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr auto B_KRow = I2;
|
||||
constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
|
||||
constexpr auto K0PerWmma = WmmaK / 2 / K1;
|
||||
constexpr auto K0PerWmma = WmmaK / B_KRow / K1;
|
||||
// KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<KWmmaPerblock>{},
|
||||
@@ -290,12 +293,17 @@ struct GridwiseGemm_Wmma
|
||||
if constexpr(AEnableLds)
|
||||
{
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
@@ -348,12 +356,16 @@ struct GridwiseGemm_Wmma
|
||||
if constexpr(BEnableLds)
|
||||
{
|
||||
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BBlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
|
||||
c_grid_desc_m_n);
|
||||
}
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
struct SharedMemTrait
|
||||
{
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
|
||||
b_block_space_size_aligned * sizeof(BDataType));
|
||||
};
|
||||
|
||||
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}))>;
|
||||
using DefaultBlock2CTileMap =
|
||||
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
|
||||
|
||||
template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
|
||||
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
|
||||
@@ -35,8 +35,9 @@ __global__ void
|
||||
const Block2ETileMap block_2_tile_map,
|
||||
const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
|
||||
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
|
||||
defined(__gfx12__))
|
||||
GridwiseTensorRearrangeKernel::Run(in_grid_desc,
|
||||
p_in_global,
|
||||
out_grid_desc,
|
||||
|
||||
Reference in New Issue
Block a user