mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Merging the gfx12 code into public repo. (#1362)
[ROCm/composable_kernel commit: 941d1f7ce0]
This commit is contained in:
@@ -23,45 +23,45 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
|
||||
< ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
< ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
GemmDefault,
|
||||
1, // Prefetch stage
|
||||
128, // BlockSize
|
||||
64, // MPerBlock
|
||||
128, // NPerBlock
|
||||
64, // KPerBlock
|
||||
8, // K1
|
||||
2, // K1
|
||||
16, // MPerWmma
|
||||
16, // NPerWmma
|
||||
2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave
|
||||
4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
true,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
true,
|
||||
1, // C shuffle (M Repeat) Per store
|
||||
1, // C shuffle (N Repeat) Per store
|
||||
S<1, 32, 1, 4>,
|
||||
S<1, 32, 1, 4>,
|
||||
8>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
break;
|
||||
case 4:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
|
||||
break;
|
||||
case 5:
|
||||
|
||||
@@ -24,4 +24,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
@@ -83,14 +83,14 @@ using DeviceOpInstanceKKNN =
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
false,
|
||||
S<4, 32, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
false,
|
||||
1,
|
||||
1,
|
||||
S<1, 64, 1, 2>,
|
||||
|
||||
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
#define CK_MHA_USE_WAVE_2
|
||||
#define CK_MHA_USE_WAVE_4
|
||||
#define CK_MHA_USE_WAVE_8
|
||||
//#define CK_MHA_USE_WAVE_8
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
MaskingSpec>,
|
||||
MaskingSpec>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
|
||||
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
|
||||
@@ -71,7 +71,7 @@ static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecial
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
#define CK_MHA_USE_WAVE_2
|
||||
#define CK_MHA_USE_WAVE_4
|
||||
#define CK_MHA_USE_WAVE_8
|
||||
//#define CK_MHA_USE_WAVE_8
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
@@ -277,10 +277,10 @@ using DeviceMHAFactory =
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
MaskingSpec>,
|
||||
MaskingSpec>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
,ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle<
|
||||
NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
|
||||
ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
|
||||
@@ -67,7 +67,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
@@ -154,7 +154,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
|
||||
if(NOT GPU_TARGETS MATCHES "gfx11" AND NOT GPU_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
message("removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
|
||||
Reference in New Issue
Block a user