mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Merge branch 'develop' into kylasa_kdim_pr
This commit is contained in:
@@ -105,6 +105,16 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp)
|
||||
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
|
||||
|
||||
|
||||
@@ -310,10 +310,14 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
@@ -351,10 +355,14 @@ inline __host__ __device__ constexpr double get_rtol()
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, typename ComputeDataType = DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
|
||||
85
example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp
Normal file
85
example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#define USING_DIRECT_LOADS 1
|
||||
#if USING_DIRECT_LOADS
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
|
||||
#else
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
#endif
|
||||
|
||||
#define EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
|
||||
using F32 = float;
|
||||
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F32;
|
||||
using ComputeDataType = ck::tf32_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
#if USING_DIRECT_LOADS
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer|
|
||||
// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds|
|
||||
// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type |
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block|
|
||||
// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32,
|
||||
8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1,
|
||||
1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>;
|
||||
// clang-format on
|
||||
#else
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>;
|
||||
// clang-format on
|
||||
#endif
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ComputeDataType,
|
||||
ComputeDataType>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
#undef EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
@@ -4,6 +4,11 @@
|
||||
#pragma once
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
|
||||
// use macro to minimize code change
|
||||
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
using ComputeDataType = AccDataType;
|
||||
#endif
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
@@ -218,8 +223,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
get_rtol<CDataType, ComputeDataType>(),
|
||||
get_atol<CDataType, ComputeDataType>());
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -249,8 +254,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_device_ref_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
get_rtol<CDataType, ComputeDataType>(),
|
||||
get_atol<CDataType, ComputeDataType>());
|
||||
}
|
||||
|
||||
return pass == true;
|
||||
|
||||
@@ -19,4 +19,13 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -27,10 +27,14 @@ void print_helper_msg()
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, typename GemmType = DataType>
|
||||
inline __host__ __device__ constexpr double get_rtol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
|
||||
{
|
||||
return 5e-3;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
@@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol()
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, typename GemmType = DataType>
|
||||
inline __host__ __device__ constexpr double get_atol()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, float>)
|
||||
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
|
||||
{
|
||||
return 1e-2;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
return 1e-3;
|
||||
}
|
||||
@@ -116,7 +124,8 @@ template <ck::index_t NDimSpatial,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename DeviceConvNDFwdInstance>
|
||||
typename DeviceConvNDFwdInstance,
|
||||
typename ComputeDataType = OutDataType>
|
||||
bool run_grouped_conv_fwd(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
@@ -228,7 +237,11 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>();
|
||||
OutElementOp,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
ComputeDataType>();
|
||||
|
||||
auto ref_invoker = ref_conv.MakeInvoker();
|
||||
auto ref_argument = ref_conv.MakeArgument(in,
|
||||
@@ -249,8 +262,8 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
return ck::utils::check_err(out_device,
|
||||
out_host,
|
||||
"Error: incorrect results!",
|
||||
get_rtol<OutDataType>(),
|
||||
get_atol<OutDataType>());
|
||||
get_rtol<OutDataType, ComputeDataType>(),
|
||||
get_atol<OutDataType, ComputeDataType>());
|
||||
}
|
||||
|
||||
return true;
|
||||
|
||||
89
example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp
Normal file
89
example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp
Normal file
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "convnd_fwd_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#define EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
|
||||
using InDataType = float;
|
||||
using WeiDataType = float;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using OutDataType = float;
|
||||
using ComputeDataType = ck::tf32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvSpec =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
|
||||
using DeviceGroupedConvNDFwdInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
InLayout, // ALayout
|
||||
WeiLayout, // BLayout
|
||||
ck::Tuple<>, // DsLayout
|
||||
OutLayout, // ELayout
|
||||
InDataType, // ADataType
|
||||
WeiDataType, // BDataType
|
||||
AccDataType, // AccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
ck::Tuple<>, // DsDataType
|
||||
OutDataType, // EDataType
|
||||
InElementOp, // AElementwiseOperation
|
||||
WeiElementOp, // BElementwiseOperation
|
||||
OutElementOp, // CDEElementwiseOperation
|
||||
ConvSpec, // ConvForwardSpecialization
|
||||
GemmSpec, // GemmSpecialization
|
||||
1, // NumGemmKPrefetchStage
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
192, // NPerBlock
|
||||
16, // KPerBlock
|
||||
4, // AK1
|
||||
4, // BK1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
3, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
4, // ABlockTransferSrcScalarPerVector
|
||||
4, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
4, // BBlockTransferSrcScalarPerVector
|
||||
4, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
4, // CDEBlockTransferScalarPerVector_NPerBlock
|
||||
ComputeDataType, // AComputeDataType
|
||||
ComputeDataType, // BComputeDataType
|
||||
ck::LoopScheduler::Default, // LoopScheduler
|
||||
1 // NumGroupsToMerge
|
||||
>;
|
||||
|
||||
#include "run_convnd_fwd_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
|
||||
|
||||
#undef EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#define EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
|
||||
using InDataType = ck::f8_t;
|
||||
using WeiDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
@@ -87,3 +89,5 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
|
||||
}
|
||||
|
||||
#undef EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// use macro to minimize code change
|
||||
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
|
||||
using ComputeDataType = AccDataType;
|
||||
#endif
|
||||
|
||||
bool run_convnd_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
print_helper_msg();
|
||||
@@ -65,17 +70,17 @@ bool run_convnd_fwd_example(int argc, char* argv[])
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
|
||||
do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>,
|
||||
ComputeDataType>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
|
||||
add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp)
|
||||
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
|
||||
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
|
||||
|
||||
211
example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp
Normal file
211
example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp
Normal file
@@ -0,0 +1,211 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ActivationOp = PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
|
||||
|
||||
using ADataType = I8;
|
||||
using BDataType = I8;
|
||||
using AccDataType = I32;
|
||||
using CShuffleDataType = I32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = I8;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
ActivationOp,
|
||||
ActivationOp,
|
||||
CDEElementOp,
|
||||
GemmDefault,
|
||||
256,
|
||||
128,
|
||||
128,
|
||||
64,
|
||||
8,
|
||||
8,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
2,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
1,
|
||||
8,
|
||||
true,
|
||||
1,
|
||||
1,
|
||||
S<1, 32, 1, 8>,
|
||||
S<1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave,
|
||||
ck::BlockGemmPipelineVersion::v1,
|
||||
I8,
|
||||
I8>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, EDataType, float, PassThrough, PassThrough, CDEElementOp>;
|
||||
|
||||
int main(int /* argc */, char* /* argv */[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = N;
|
||||
ck::index_t StrideE = N;
|
||||
|
||||
float requant_scale = 0.03;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = PassThrough{};
|
||||
auto b_element_op = PassThrough{};
|
||||
auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
|
||||
|
||||
// device GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 0>{},
|
||||
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 0>{},
|
||||
StrideE,
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -323,6 +323,31 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
printf("arg5: group count (default=16)");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ns.push_back(768);
|
||||
@@ -333,21 +358,5 @@ int main(int argc, char* argv[])
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (>0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
@@ -296,6 +296,32 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (> 0)\n");
|
||||
printf("arg5: group count (default=16)");
|
||||
|
||||
exit(0);
|
||||
}
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(128 + rand() % 128);
|
||||
@@ -307,21 +333,5 @@ int main(int argc, char* argv[])
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (> 0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
@@ -297,6 +297,31 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (> 0)\n");
|
||||
printf("arg5: group count (default=16)");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(256 + 256 * i);
|
||||
@@ -308,21 +333,5 @@ int main(int argc, char* argv[])
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.k_batch = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: k_batch (> 0)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
@@ -66,6 +66,28 @@ int main(int argc, char* argv[])
|
||||
|
||||
problem_size.group_count = 16;
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 5)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
problem_size.group_count = std::stoi(argv[4]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: group count (default=16)");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
for(int i = 0; i < problem_size.group_count; i++)
|
||||
{
|
||||
problem_size.Ms.push_back(256 + 256 * i);
|
||||
@@ -77,19 +99,5 @@ int main(int argc, char* argv[])
|
||||
problem_size.stride_Cs.push_back(problem_size.Ns[i]);
|
||||
}
|
||||
|
||||
if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
return !run_grouped_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
@@ -123,7 +123,9 @@ inline bool parse_cmd_args(int argc,
|
||||
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
conv_param = ck::utils::conv::parse_conv_param(
|
||||
num_dim_spatial, threshold_to_catch_partial_args, argv);
|
||||
num_dim_spatial,
|
||||
threshold_to_catch_partial_args + 1, // +1 because we already parsed num_dim_spatial
|
||||
argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -213,8 +213,20 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
--save-temps
|
||||
)
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
|
||||
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
|
||||
if(HAS_DISABLE_PACKED_FP32)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
|
||||
-mllvm --amdgpu-disable-packed-fp32=1
|
||||
)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
|
||||
-DCK_TILE_DISABLE_PACKED_FP32=1
|
||||
)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
|
||||
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
|
||||
@@ -131,4 +131,4 @@ TBD
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
|
||||
|
||||
Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later.
|
||||
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
|
||||
|
||||
@@ -7,7 +7,8 @@ FWD_DTYPE_MAP = {
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
|
||||
@@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
@@ -248,11 +248,11 @@ class FmhaFwdApiTrait:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true
|
||||
else:
|
||||
else:
|
||||
return f'a.seqlen_q <= {self.bm0}'
|
||||
|
||||
@property
|
||||
@@ -351,7 +351,7 @@ class FmhaFwdPipeline:
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
|
||||
if self.F_trload == 't' : n += '_trload'
|
||||
else: n += '_ntrload'
|
||||
|
||||
@@ -378,7 +378,7 @@ class FmhaFwdApiPool:
|
||||
"t": "has_load_tr",
|
||||
"f": "true"
|
||||
}
|
||||
|
||||
|
||||
per_tr_load =str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes=str()
|
||||
@@ -550,12 +550,16 @@ class KernelComponentFactory:
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
elif dtype == 'fp8' or dtype == 'fp8bf16':
|
||||
return {
|
||||
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
elif dtype == 'fp8fp32':
|
||||
return {
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -567,9 +571,9 @@ class KernelComponentFactory:
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
squant = 'f'
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
@@ -589,11 +593,12 @@ class KernelComponentFactory:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
@@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 888:
|
||||
cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
|
||||
@@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
|
||||
@@ -44,21 +44,15 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("scale_s",
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
"note when squant=1, this value will be modified by range_q/k")
|
||||
"note when squant=1, this value will be modified")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
|
||||
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
.insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
|
||||
.insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
|
||||
.insert("squant",
|
||||
"auto",
|
||||
"if using static quantization fusion or not. auto: fp8 will default use squant, "
|
||||
"other will not\n"
|
||||
"0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
|
||||
"P and O.\n"
|
||||
"calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
|
||||
"range_p, range_o")
|
||||
"calculate scale_s, scale_p, scale_o auto")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
@@ -89,7 +83,7 @@ auto create_args(int argc, char* argv[])
|
||||
"uf",
|
||||
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
|
||||
"\n uf or 1 - uniform random float\n nf - normalized random float"
|
||||
"\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization")
|
||||
"\n tf or 2 - trig float\n")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
@@ -148,11 +142,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
std::string mask_str = arg_parser.get_str("mask");
|
||||
float range_q = arg_parser.get_float("range_q");
|
||||
float range_k = arg_parser.get_float("range_k");
|
||||
float range_v = arg_parser.get_float("range_v");
|
||||
float range_p = arg_parser.get_float("range_p");
|
||||
float range_o = arg_parser.get_float("range_o");
|
||||
bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
@@ -201,11 +190,6 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
drop_offset,
|
||||
drop_prefs,
|
||||
mask_str,
|
||||
range_q,
|
||||
range_k,
|
||||
range_v,
|
||||
range_p,
|
||||
range_o,
|
||||
squant,
|
||||
is_rotary_interleaved,
|
||||
num_splits,
|
||||
@@ -237,6 +221,14 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return run<FmhaFwdFp8>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8bf16")
|
||||
{
|
||||
return run<FmhaFwdFp8Bf16>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "fp8fp32")
|
||||
{
|
||||
return run<FmhaFwdFp8Fp32>(arg_parser) == fwd_result::success ? 0 : -2;
|
||||
}
|
||||
std::cerr << "Unsupported precision: " << data_type << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -45,18 +45,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("mask",
|
||||
"0",
|
||||
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
|
||||
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
|
||||
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
|
||||
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
|
||||
"'xt:window_size', xformer style masking from top-left, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
|
||||
"causal, positive is swa\n"
|
||||
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
|
||||
"now)")
|
||||
.insert("causal", "0", "0: no mask, 1: causal mask")
|
||||
.insert("v", "1", "0:no verify, 1:verify")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
@@ -109,7 +98,16 @@ struct Problem
|
||||
softmax_scale = args.get_float("scale_s");
|
||||
if(softmax_scale == .0f)
|
||||
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
|
||||
mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k);
|
||||
|
||||
const auto is_causal = args.get_bool("causal");
|
||||
if(is_causal)
|
||||
{
|
||||
mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
mask = mask_info::decode("0", seqlen_q, seqlen_k);
|
||||
}
|
||||
|
||||
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
|
||||
|
||||
@@ -41,6 +41,10 @@ struct FmhaFwdFp8Bf16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaFwdFp8Fp32
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
|
||||
@@ -108,6 +112,38 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
|
||||
using ODataType = ck_tile::bf8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaFwdTypeConfig<FmhaFwdFp8Fp32>
|
||||
{
|
||||
using QDataType = ck_tile::fp8_t;
|
||||
using KDataType = ck_tile::fp8_t;
|
||||
using VDataType = ck_tile::fp8_t;
|
||||
using BiasDataType = float;
|
||||
using RandValOutputDataType = uint8_t;
|
||||
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
|
||||
using SaccDataType = float; // data type for first gemm accumulation
|
||||
using SMPLComputeDataType = float; // data type for reduction, softmax
|
||||
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
|
||||
using OaccDataType = float; // data type for second gemm accumulation
|
||||
using ODataType = float;
|
||||
};
|
||||
|
||||
struct FmhaMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
|
||||
@@ -50,20 +50,30 @@ auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdFp8>(std::string init_method)
|
||||
auto get_elimit<FmhaFwdFp8>(std::string /*init_method*/)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
using TypeConfig = FmhaFwdTypeConfig<FmhaFwdFp8>;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
double rtol = 0;
|
||||
double atol = 16 * (o_dtype_max > 240 ? 2 : 1);
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdFp8Bf16>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1.8e-1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<FmhaFwdFp8Fp32>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1.8e-1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
|
||||
@@ -157,11 +167,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
uint64_t drop_offset,
|
||||
bool drop_prefs,
|
||||
std::string mask_str,
|
||||
float range_q,
|
||||
float range_k,
|
||||
float range_v,
|
||||
float range_p,
|
||||
float range_o,
|
||||
bool squant,
|
||||
bool is_rotary_interleaved,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -180,6 +185,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
return "fp8";
|
||||
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf8>)
|
||||
return "bf8";
|
||||
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16>)
|
||||
return "fp8bf16";
|
||||
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>)
|
||||
return "fp8fp32";
|
||||
else
|
||||
static_assert(false);
|
||||
}();
|
||||
@@ -367,22 +376,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
using OaccDataType = typename TypeConfig::OaccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
|
||||
if(squant)
|
||||
{
|
||||
scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max);
|
||||
scale_p = p_dtype_max / range_p;
|
||||
scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max);
|
||||
}
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
@@ -528,7 +521,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::HostTensor<int32_t> cache_batch_idx_host(use_cache_batch_idx
|
||||
? std::array<ck_tile::index_t, 1>{batch}
|
||||
: std::array<ck_tile::index_t, 1>{1});
|
||||
|
||||
float max_o = 5.0;
|
||||
if(init_method == "ui" || init_method == "0")
|
||||
{
|
||||
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -576,32 +569,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillTrigValue<VDataType>{}(vnew_host);
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
}
|
||||
else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3")
|
||||
{
|
||||
// suitable for fp8 quantization
|
||||
if(!squant)
|
||||
{
|
||||
std::cerr << "init method " << init_method << " can not be used without quantization"
|
||||
<< std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, q_dtype_max, next_seed()}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, k_dtype_max, next_seed()}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, k_dtype_max, next_seed()}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, v_dtype_max, next_seed()}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, v_dtype_max, next_seed()}(vnew_host);
|
||||
|
||||
// bias_fp8 = qscale_bias * bias_fp32
|
||||
float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k);
|
||||
// Assume bias is in [0.f, 1.f] in original fp32
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, qscale_bias, next_seed()}(bias_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "Unknown value for init argument: " << init_method << std::endl;
|
||||
return fwd_result::invalid_args;
|
||||
}
|
||||
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
|
||||
@@ -625,8 +592,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
@@ -650,10 +617,79 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
|
||||
|
||||
float scale_p = 1.f;
|
||||
float scale_o = 1.f;
|
||||
if(squant)
|
||||
{
|
||||
float q_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::max());
|
||||
float k_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::max());
|
||||
float v_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::max());
|
||||
float p_dtype_max = v_dtype_max; // assume p and v is the same type
|
||||
// Q tensor
|
||||
{
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<QDataType>::min());
|
||||
q_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
|
||||
float scale = q_dtype_max / max_value;
|
||||
|
||||
q_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
self(idx) = ck_tile::type_convert<QDataType>(val * scale);
|
||||
});
|
||||
scale_s = scale_s / scale;
|
||||
}
|
||||
|
||||
// K tensor
|
||||
{
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<KDataType>::min());
|
||||
k_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
float scale = k_dtype_max / max_value;
|
||||
k_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
self(idx) = ck_tile::type_convert<KDataType>(val * scale);
|
||||
});
|
||||
scale_s = scale_s / scale;
|
||||
}
|
||||
|
||||
// V tensor
|
||||
{
|
||||
float max_value = ck_tile::type_convert<float>(ck_tile::numeric<VDataType>::min());
|
||||
v_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
if(val > max_value)
|
||||
max_value = val;
|
||||
});
|
||||
|
||||
float scale = k_dtype_max / max_value;
|
||||
v_host.ForEach([&](auto& self, auto idx) {
|
||||
float val = ck_tile::type_convert<float>(self(idx));
|
||||
self(idx) = ck_tile::type_convert<VDataType>(val * scale);
|
||||
});
|
||||
|
||||
scale_o = (1.0 / p_dtype_max) / scale;
|
||||
}
|
||||
|
||||
scale_p = p_dtype_max;
|
||||
|
||||
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp8>)
|
||||
{
|
||||
float o_dtype_max = ck_tile::type_convert<float>(ck_tile::numeric<ODataType>::max());
|
||||
scale_o = scale_o * o_dtype_max / max_o;
|
||||
}
|
||||
}
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
@@ -1103,7 +1139,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
lse_buf.FromDevice(lse_host.data());
|
||||
randval_buf.FromDevice(randval_host.data());
|
||||
|
||||
constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
|
||||
constexpr bool supports_squant = std::is_same_v<DataTypeConfig, FmhaFwdFp8> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Bf16> ||
|
||||
std::is_same_v<DataTypeConfig, FmhaFwdFp8Fp32>;
|
||||
|
||||
auto p_compute_element_func = [&]() {
|
||||
if constexpr(supports_squant)
|
||||
@@ -1113,9 +1151,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(supports_squant)
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t> && supports_squant)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
else if constexpr(supports_squant)
|
||||
return ck_tile::scales{scale_o};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
@@ -34,7 +34,8 @@ struct fmha_fwd_v3_args
|
||||
|
||||
index_t window_size_left;
|
||||
index_t window_size_right;
|
||||
index_t mask_type;
|
||||
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
|
||||
// window_size_right == 0).
|
||||
|
||||
const void* q_ptr;
|
||||
index_t stride_q;
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
@@ -79,7 +80,7 @@ struct fmha_fwd_v3_kernel_traits
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
using fmha_mask = SimplifiedGenericAttentionMask<IsMasking>;
|
||||
using fmha_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using fmha_pipeline_problem =
|
||||
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
|
||||
@@ -112,6 +113,22 @@ struct fmha_fwd_v3_kernel_traits
|
||||
template <typename Kernel>
|
||||
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
|
||||
{
|
||||
/// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
|
||||
/// maximizes the kernel's performance.
|
||||
int remap_opt = 2;
|
||||
if(args.mask_type != static_cast<int>(mask_enum::no_mask) &&
|
||||
((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q)))
|
||||
{
|
||||
if(65536 <= args.seqlen_q)
|
||||
{
|
||||
remap_opt = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
remap_opt = 1;
|
||||
}
|
||||
}
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
@@ -140,7 +157,8 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi
|
||||
args.batch_stride_o,
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
remap_opt);
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
@@ -8,22 +8,16 @@ for prec in "fp16" "bf16" ; do
|
||||
for hdim in 128 ; do
|
||||
for perm in 0 ; do
|
||||
|
||||
if [ $causal -eq 0 ]; then
|
||||
mask=0
|
||||
else
|
||||
mask=b:-1,0
|
||||
fi
|
||||
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
@@ -0,0 +1,2 @@
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
@@ -0,0 +1,31 @@
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1
|
||||
@@ -0,0 +1,4 @@
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1
|
||||
@@ -2,13 +2,35 @@
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_bwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"}
|
||||
rm -f $CURR_FAILS_FILE
|
||||
touch $CURR_FAILS_FILE
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1'
|
||||
|
||||
run_exe() {
|
||||
set +ex
|
||||
$EXE $@
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ] ; then
|
||||
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
|
||||
fi
|
||||
set -ex
|
||||
}
|
||||
|
||||
set -x
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
@@ -19,12 +41,12 @@ for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
@@ -35,3 +57,24 @@ done
|
||||
done
|
||||
done
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f $KNOWN_FAILS_FILE ] ; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$(($known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$(($new_fails_count + 1))
|
||||
fi
|
||||
done < $CURR_FAILS_FILE
|
||||
else
|
||||
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
cat $CURR_FAILS_FILE
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $(($new_fails_count != 0))
|
||||
|
||||
@@ -2,12 +2,23 @@
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"}
|
||||
rm -f $CURR_FAILS_FILE
|
||||
touch $CURR_FAILS_FILE
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
@@ -30,6 +41,16 @@ while getopts ":sa" opt; do
|
||||
esac
|
||||
done
|
||||
|
||||
run_exe() {
|
||||
set +ex
|
||||
$EXE $@
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ] ; then
|
||||
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
|
||||
fi
|
||||
set -ex
|
||||
}
|
||||
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS="1"
|
||||
local PAGE_BLOCK_SIZE="0"
|
||||
@@ -52,16 +73,16 @@ run_fp16_bf16_tests() {
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
@@ -73,7 +94,29 @@ run_fp8_tests() {
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
@@ -88,7 +131,7 @@ run_fp16_appendkv_tests() {
|
||||
for page_block_size in 0 128 ; do
|
||||
for cache_batch_idx in 0 1 ; do
|
||||
|
||||
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
|
||||
run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done
|
||||
@@ -98,9 +141,32 @@ set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
run_fp8_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
|
||||
if [ $TEST_APPENDKV -eq 1 ] ; then
|
||||
run_fp16_appendkv_tests
|
||||
fi
|
||||
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f $KNOWN_FAILS_FILE ] ; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$(($known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$(($new_fails_count + 1))
|
||||
fi
|
||||
done < $CURR_FAILS_FILE
|
||||
else
|
||||
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
cat $CURR_FAILS_FILE
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $(($new_fails_count != 0))
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -356,6 +356,8 @@ int main(int argc, char* argv[])
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
|
||||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
|
||||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -91,7 +91,7 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool Persistent = false;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
@@ -139,6 +139,29 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_V2 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
{
|
||||
|
||||
1
example/ck_tile/22_gemm_multi_abd/CMakeLists.txt
Normal file
1
example/ck_tile/22_gemm_multi_abd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp)
|
||||
35
example/ck_tile/22_gemm_multi_abd/README.md
Normal file
35
example/ck_tile/22_gemm_multi_abd/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
#Multiple ABD GEMM
|
||||
|
||||
This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
#in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
#you can replace < arch> with the appropriate architecture(for example gfx90a or gfx942) or \
|
||||
leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
#The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_multi_abd_fp16 -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m M dimensions - (Default: 3840)
|
||||
-n N dimensions - (Default: 4096)
|
||||
-k K dimensions - (Default: 4096)
|
||||
-as_layout Tensor A layout (default:R)
|
||||
-bs_layout Tensor B layout (default:C)
|
||||
-ds_layout Tensor D layout (default:R)
|
||||
-e_layout Tensor E layout (default:R)
|
||||
-stride_as Tensor A strides - (Default: 0)
|
||||
-stride_bs Tensor B strides - (Default: 0)
|
||||
-stride_e Tensor C strides - (Default: 0)
|
||||
-stride_ds Tensor D strides - (Default: 0)
|
||||
-validate 0. No validation, 1. Validation on GPU. (Default: 1)
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10)
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100)
|
||||
-kbatch kbatch for SplitK. (Default: 1)
|
||||
```
|
||||
184
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp
Normal file
184
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp
Normal file
@@ -0,0 +1,184 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_abd_fp16.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float
|
||||
{
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, AsLayout, BsLayout, ELayout>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
ELayout,
|
||||
TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<AsDataType, BsDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<AsDataType,
|
||||
BsDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
AElementWise,
|
||||
BElementWise>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernelMultiABD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_gemm_multi_abd_fp16_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_multiple_abd_gemm_example<GemmConfigV3_Wmma>(argc, argv);
|
||||
#else
|
||||
return !run_multiple_abd_gemm_example<GemmConfigV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
186
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp
Normal file
186
example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp
Normal file
@@ -0,0 +1,186 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
using A0DataType = ck_tile::half_t;
|
||||
using A1DataType = ck_tile::half_t;
|
||||
|
||||
using B0DataType = ck_tile::half_t;
|
||||
using B1DataType = ck_tile::half_t;
|
||||
|
||||
using D0DataType = ck_tile::half_t;
|
||||
using D1DataType = ck_tile::half_t;
|
||||
|
||||
using EDataType = ck_tile::half_t;
|
||||
|
||||
using AsDataType = ck_tile::tuple<A0DataType, A1DataType>;
|
||||
using BsDataType = ck_tile::tuple<B0DataType, B1DataType>;
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
struct GemmConfigMemory
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV4
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
struct GemmConfigV3_Wmma
|
||||
{
|
||||
// Compute friendly for Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "4096", "k dimension")
|
||||
.insert("as_layout", "R", "As tensor data layout - Row by default")
|
||||
.insert("bs_layout", "C", "Bs tensor data layout - Col by default")
|
||||
.insert("ds_layout", "R", "Ds tensor data layout - Row by default")
|
||||
.insert("e_layout", "R", "E tensor data layout - Row by default")
|
||||
.insert("stride_as", "0", "Tensor A stride")
|
||||
.insert("stride_bs", "0", "Tensor B stride")
|
||||
.insert("stride_ds", "0", "Tensor Ds stride")
|
||||
.insert("stride_e", "0", "Tensor E stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
using gemm_multi_abd_kargs =
|
||||
ck_tile::GemmMultiABDHostArgs<AsDataType::size(), BsDataType::size(), DsDataType::size()>;
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename AElementWise,
|
||||
typename BElementWise,
|
||||
typename CDEElementWise>
|
||||
float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s);
|
||||
@@ -0,0 +1,311 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename AsDataType,
|
||||
typename BsDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename EDataType,
|
||||
typename AsLayout,
|
||||
typename BsLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename AElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm_multi_abd(const std::array<const void*, AsDataType::size()>& as_m_k_dev_buf,
|
||||
const std::array<const void*, BsDataType::size()>& bs_k_n_dev_buf,
|
||||
const std::array<const void*, DsDataType::size()>& ds_m_n_dev_buf,
|
||||
void* e_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
const std::array<ck_tile::index_t, AsDataType::size()>& StrideAs,
|
||||
const std::array<ck_tile::index_t, BsDataType::size()>& StrideBs,
|
||||
const std::array<ck_tile::index_t, DsDataType::size()>& StrideDs,
|
||||
ck_tile::index_t StrideE,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
int k_batch)
|
||||
{
|
||||
gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf,
|
||||
bs_k_n_dev_buf,
|
||||
ds_m_n_dev_buf,
|
||||
e_m_n_dev_buf,
|
||||
k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE});
|
||||
|
||||
float ave_time = gemm_multi_abd<GemmConfig,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWise,
|
||||
BElementWise,
|
||||
CDEElementWise>(
|
||||
gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Gemm Multiple-ABD"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * M * N * K;
|
||||
|
||||
num_btype +=
|
||||
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm Multiple-ABD kernel with:\n";
|
||||
std::cout << "M =" << M << " N =" << N << " K =" << K << "\n";
|
||||
std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0]
|
||||
<< " StrideE = " << StrideE << "\n";
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< "\n";
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename A0Layout,
|
||||
typename A1Layout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout>
|
||||
int run_gemm_multi_abd_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const A0Layout a0_layout = A0Layout{},
|
||||
const A1Layout a1_layout = A1Layout{},
|
||||
const B0Layout b0_layout = B0Layout{},
|
||||
const B1Layout b1_layout = B1Layout{},
|
||||
const D0Layout d0_layout = D0Layout{},
|
||||
const D1Layout d1_layout = D1Layout{},
|
||||
const ELayout e_layout = ELayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
using AElementWiseFn = ck_tile::element_wise::AddScale;
|
||||
using BElementWiseFn = ck_tile::element_wise::AddScale;
|
||||
using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply;
|
||||
using AsLayout = ck_tile::tuple<A0Layout, A1Layout>;
|
||||
using BsLayout = ck_tile::tuple<B0Layout, B1Layout>;
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t StrideA = arg_parser.get_int("stride_as");
|
||||
ck_tile::index_t StrideB = arg_parser.get_int("stride_bs");
|
||||
ck_tile::index_t StrideD = arg_parser.get_int("stride_ds");
|
||||
ck_tile::index_t StrideE = arg_parser.get_int("stride_e");
|
||||
|
||||
ck_tile::index_t StrideA0 = StrideA;
|
||||
ck_tile::index_t StrideA1 = StrideA;
|
||||
|
||||
ck_tile::index_t StrideB0 = StrideB;
|
||||
ck_tile::index_t StrideB1 = StrideB;
|
||||
|
||||
ck_tile::index_t StrideD0 = StrideD;
|
||||
ck_tile::index_t StrideD1 = StrideD;
|
||||
|
||||
const int n_warmup = arg_parser.get_int("warmup");
|
||||
const int n_repeat = arg_parser.get_int("repeat");
|
||||
const int k_batch = arg_parser.get_int("kbatch");
|
||||
|
||||
StrideA0 = get_default_stride(M, N, StrideA0, is_row_major(a1_layout));
|
||||
StrideA1 = get_default_stride(M, N, StrideA1, is_row_major(a1_layout));
|
||||
|
||||
StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout));
|
||||
StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout));
|
||||
|
||||
StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout));
|
||||
StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout));
|
||||
|
||||
StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout));
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a0_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
|
||||
ck_tile::HostTensor<A1DataType> a1_m_k_tesnor(
|
||||
host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout)));
|
||||
|
||||
ck_tile::HostTensor<B0DataType> b0_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
|
||||
ck_tile::HostTensor<B1DataType> b1_k_n_tensors(
|
||||
host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout)));
|
||||
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n_tensors(
|
||||
host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout)));
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<A0DataType>{-1.f, 1.f}(a0_m_k_tesnor);
|
||||
ck_tile::FillUniformDistribution<A1DataType>{-1.f, 1.f}(a1_m_k_tesnor);
|
||||
|
||||
ck_tile::FillUniformDistribution<B0DataType>{-1.f, 1.f}(b0_k_n_tensors);
|
||||
ck_tile::FillUniformDistribution<B1DataType>{-1.f, 1.f}(b1_k_n_tensors);
|
||||
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors);
|
||||
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors);
|
||||
|
||||
ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data());
|
||||
a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data());
|
||||
|
||||
b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data());
|
||||
b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data());
|
||||
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(),
|
||||
a1_m_k_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(),
|
||||
b1_k_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, AsDataType::size()> strideAs = {StrideA0, StrideA1};
|
||||
std::array<ck_tile::index_t, BsDataType::size()> strideBs = {StrideB0, StrideB1};
|
||||
std::array<ck_tile::index_t, DsDataType::size()> strideDs = {StrideD0, StrideD1};
|
||||
|
||||
invoke_gemm_multi_abd<GemmConfig,
|
||||
AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>(as_ptr_buf,
|
||||
bs_ptr_buf,
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
strideAs,
|
||||
strideBs,
|
||||
strideDs,
|
||||
StrideE,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
k_batch);
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_device_result.data());
|
||||
|
||||
ck_tile::HostTensor<A0DataType> a_m_k_host_ref_element_result(
|
||||
host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout)));
|
||||
ck_tile::HostTensor<B0DataType> b_k_n_host_ref_element_result(
|
||||
host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_ref(
|
||||
host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout)));
|
||||
a_m_k_host_ref_element_result.SetZero();
|
||||
b_k_n_host_ref_element_result.SetZero();
|
||||
e_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_multiple_abd<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
AElementWiseFn,
|
||||
BElementWiseFn,
|
||||
CDEElementWiseFn>({a0_m_k_tesnor, a1_m_k_tesnor},
|
||||
{b0_k_n_tensors, b1_k_n_tensors},
|
||||
{d0_m_n_tensors, d1_m_n_tensors},
|
||||
a_m_k_host_ref_element_result,
|
||||
b_k_n_host_ref_element_result,
|
||||
e_m_n_host_ref);
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v"))
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
pass &= ck_tile::check_err(e_m_n_device_result,
|
||||
e_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< std::endl;
|
||||
std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig>
|
||||
int run_multiple_abd_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string as_layout = arg_parser.get_str("as_layout");
|
||||
const std::string bs_layout = arg_parser.get_str("bs_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(as_layout == "R" && bs_layout == "C")
|
||||
{
|
||||
return run_gemm_multi_abd_example_with_layouts<GemmConfig>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
38
example/ck_tile/22_gemm_multi_abd/utils.hpp
Normal file
38
example/ck_tile/22_gemm_multi_abd/utils.hpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(A0DataType) < sizeof(B0DataType), A0DataType, B0DataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
@@ -21,6 +21,7 @@ add_subdirectory(18_flatmm)
|
||||
add_subdirectory(19_gemm_multi_d)
|
||||
add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(21_elementwise)
|
||||
add_subdirectory(22_gemm_multi_abd)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(39_copy)
|
||||
|
||||
Reference in New Issue
Block a user