mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Merge branch 'develop' into tianxing/unified-attention
This commit is contained in:
@@ -42,6 +42,8 @@ option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
|
||||
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
|
||||
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
|
||||
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
|
||||
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
|
||||
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
|
||||
@@ -232,12 +234,12 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}"
|
||||
# Cache SUPPORTED_GPU_TARGETS for debug
|
||||
set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets")
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL)
|
||||
message(STATUS "Enabling XDL instances")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
set(CK_USE_XDL "ON")
|
||||
endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL)
|
||||
message(STATUS "Enabling XDL FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_GFX94)
|
||||
set(CK_USE_GFX94 "ON")
|
||||
@@ -250,7 +252,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
|
||||
add_definitions(-DCK_GFX1030_SUPPORT)
|
||||
endif()
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA)
|
||||
message(STATUS "Enabling WMMA instances")
|
||||
add_definitions(-DCK_USE_WMMA)
|
||||
set(CK_USE_WMMA "ON")
|
||||
@@ -260,7 +262,7 @@ endif()
|
||||
# define the macro with the current value (0 or 1)
|
||||
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
|
||||
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA)
|
||||
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
|
||||
add_definitions(-DCK_USE_WMMA_FP8)
|
||||
set(CK_USE_WMMA_FP8 "ON")
|
||||
|
||||
@@ -37,6 +37,13 @@ if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
|
||||
endif()
|
||||
|
||||
add_custom_target(example_grouped_gemm_wmma)
|
||||
add_example_executable(example_grouped_gemm_wmma_splitk_fp16 grouped_gemm_wmma_splitk_fp16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_fp16)
|
||||
|
||||
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
|
||||
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
|
||||
|
||||
list(APPEND gpu_list_tf32 gfx942 gfx950)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
|
||||
72
example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp
Normal file
72
example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/ignore.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#define EXAMPLE_USE_SPLITK
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
71
example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp
Normal file
71
example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp
Normal file
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/ignore.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::hip_check_error;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using EDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
|
||||
// clang-format off
|
||||
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
#define EXAMPLE_USE_SPLITK
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
|
||||
@@ -19,6 +19,10 @@ struct ProblemSize final
|
||||
std::vector<ck::index_t> stride_Cs;
|
||||
|
||||
ck::index_t group_count;
|
||||
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
ck::index_t k_batch;
|
||||
#endif
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
@@ -177,6 +181,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
|
||||
auto argument = gemm.MakeArgument(
|
||||
p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
gemm.SetKBatchSize(&argument, problem_size.k_batch);
|
||||
#endif
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
|
||||
std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument);
|
||||
@@ -285,12 +293,15 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
ExecutionConfig config;
|
||||
|
||||
problem_size.group_count = 16;
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
problem_size.k_batch = 1;
|
||||
#endif
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default cases
|
||||
}
|
||||
else if(argc == 4 || argc == 6)
|
||||
else if(argc == 4 || argc == 6 || argc == 7)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
@@ -300,6 +311,13 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
config.async_hargs = std::stoi(argv[4]);
|
||||
problem_size.group_count = std::stoi(argv[5]);
|
||||
}
|
||||
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
if(argc == 7)
|
||||
{
|
||||
problem_size.k_batch = std::stoi(argv[6]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -307,7 +325,10 @@ bool run_grouped_gemm_example(int argc, char* argv[])
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg4: async hargs (0=n0, 1=yes)\n");
|
||||
printf("arg5: group count (default=16)");
|
||||
printf("arg5: group count (default=16)\n");
|
||||
#if defined(EXAMPLE_USE_SPLITK)
|
||||
printf("arg6: k-batch count (default=1)\n");
|
||||
#endif
|
||||
exit(1);
|
||||
}
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert(
|
||||
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8")
|
||||
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
|
||||
@@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
HasHotLoop,
|
||||
TailNum>;
|
||||
|
||||
using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
|
||||
@@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode:
|
||||
|:--------|:-----:|:-----:|-------|
|
||||
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
|
||||
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
|
||||
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
|
||||
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
|
||||
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill
|
||||
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
|
||||
@@ -6,6 +6,10 @@
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigQuantDecode<T>;
|
||||
|
||||
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
|
||||
// template <typename T>
|
||||
// using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
void aquant_quantgrouped_instance_factory(
|
||||
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
|
||||
{
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "run_gemm_quant_example.inc"
|
||||
|
||||
template <typename T>
|
||||
using GemmConfig = GemmConfigBQuantPrefill<T>;
|
||||
using GemmConfig = GemmConfigQuantPrefill<T>;
|
||||
|
||||
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
|
||||
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \
|
||||
|
||||
@@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
struct GemmConfigQuantPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
@@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
|
||||
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
|
||||
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
|
||||
@@ -35,27 +35,41 @@ cmake
|
||||
..
|
||||
```
|
||||
|
||||
## Building and testing
|
||||
## Building and Testing
|
||||
|
||||
During development, all CK Builder tests can be built with command
|
||||
The builder test suite is organized into two main categories:
|
||||
|
||||
### Smoke Tests (Fast Unit Tests)
|
||||
Quick unit tests that verify the builder's internal logic without compiling GPU kernels. These complete in under 1 second total and are suitable for frequent execution during development.
|
||||
|
||||
```sh
|
||||
ninja test_ckb_all
|
||||
ninja smoke-builder
|
||||
```
|
||||
|
||||
To execute all tests, run
|
||||
### Regression Tests (Integration Tests)
|
||||
Integration tests that compile actual GPU kernels to verify that the builder generates valid, compilable code. These are more expensive than smoke tests (can take minutes to compile) but cover more fuctionality.
|
||||
|
||||
```sh
|
||||
ls bin/test_ckb_* | xargs -n1 sh -c
|
||||
ninja regression-builder
|
||||
```
|
||||
|
||||
Some tests involve building old CK convolution factories, which will take a long time.
|
||||
Hence, one might want to build only single test targets. For example
|
||||
### Running All Tests
|
||||
To build and run the complete test suite:
|
||||
|
||||
```sh
|
||||
ninja check-builder
|
||||
```
|
||||
|
||||
### Building Individual Tests
|
||||
To build and run a specific test:
|
||||
|
||||
```sh
|
||||
ninja test_ckb_conv_builder && bin/test_ckb_conv_builder
|
||||
```
|
||||
|
||||
When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`.
|
||||
This allows us to filter out the CK Builder tests from the set full CK repository tests.
|
||||
Also, the `test_ckb_all` target that builds all CK Builder tests relies on having the `test_ckb` prefix on the CMake build targets.
|
||||
### Test Organization
|
||||
- **Smoke tests**: Fast feedback during active development
|
||||
- **Regression tests**: Thorough validation before submitting changes
|
||||
- **Factory tests**: Expensive tests that build all MIOpen kernels (included in regression tests)
|
||||
|
||||
When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`. This allows filtering of CK Builder tests from the full CK repository test suite.
|
||||
|
||||
@@ -98,7 +98,7 @@ struct ConvDescription
|
||||
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
|
||||
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
|
||||
|
||||
f.writeLine(1, "Algorithm");
|
||||
f.writeLast(1, "Algorithm");
|
||||
// Compute Block section
|
||||
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
|
||||
f.writeLine(2,
|
||||
@@ -123,7 +123,7 @@ struct ConvDescription
|
||||
algorithm.warp_gemm.n_iter);
|
||||
|
||||
// Memory Access section
|
||||
f.writeLine(2, "Memory access:");
|
||||
f.writeLast(2, "Memory access:");
|
||||
|
||||
f.writeLine(3, "A Tile transfer: ");
|
||||
f.writeLine(4,
|
||||
@@ -219,8 +219,6 @@ struct ConvDescription
|
||||
f.writeLast(4,
|
||||
"Vector access (GMEM write) instruction size: ",
|
||||
algorithm.c_tile_transfer.scalar_per_vector);
|
||||
f.writeLast(2);
|
||||
f.writeLast(1);
|
||||
return f.getString();
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,48 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
################################################################################
|
||||
# CK Builder Test Suite
|
||||
################################################################################
|
||||
#
|
||||
# This file defines the test suite for the Composable Kernel (CK) Builder,
|
||||
# which is responsible for generating optimized GPU kernels for convolution
|
||||
# operations.
|
||||
#
|
||||
# TESTING PHILOSOPHY:
|
||||
# -------------------
|
||||
# Tests are organized into two main categories:
|
||||
#
|
||||
# 1. SMOKE TESTS (fast, < 1 second total)
|
||||
# - Unit tests that verify the builder's internal logic
|
||||
# - Do NOT compile GPU kernels (fast compilation)
|
||||
# - Run these frequently during development for quick feedback
|
||||
# - Target: `ninja smoke-builder`
|
||||
#
|
||||
# 2. REGRESSION TESTS (slower, may take minutes)
|
||||
# - Integration tests that compile and verify actual GPU kernels
|
||||
# - Ensure the builder generates valid, compilable code
|
||||
# - Include expensive "factory tests" that build all MIOpen kernels
|
||||
# - Run these before submitting changes
|
||||
# - Target: `ninja regression-builder`
|
||||
#
|
||||
# QUICK START:
|
||||
# ------------
|
||||
# - During development: ninja smoke-builder
|
||||
# - Before submitting: ninja regression-builder
|
||||
# - Run everything: ninja check-builder
|
||||
# - Build specific test: ninja test_ckb_conv_builder && bin/test_ckb_conv_builder
|
||||
#
|
||||
################################################################################
|
||||
|
||||
include(gtest)
|
||||
|
||||
################################################################################
|
||||
# Helper Functions
|
||||
################################################################################
|
||||
|
||||
# Helper function to create a gtest executable with common properties
|
||||
# All builder tests share the same compilation settings and dependencies
|
||||
function(add_ck_builder_test test_name)
|
||||
add_executable(${test_name} ${ARGN} testing_utils.cpp)
|
||||
target_compile_features(${test_name} PRIVATE cxx_std_20)
|
||||
@@ -19,17 +58,51 @@ function(add_ck_builder_test test_name)
|
||||
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
|
||||
endfunction()
|
||||
|
||||
# The test_ckb_conv_builder target has all the unit tests (each test should run < 10 ms)
|
||||
# Factory tests attempt to build all the kernels needed by MIOpen.
|
||||
# These are only for regression testing and development; the builds are too
|
||||
# expensive for regular use in CI.
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations)
|
||||
endfunction()
|
||||
|
||||
################################################################################
|
||||
# SMOKE TESTS - Fast Unit Tests (No Kernel Compilation)
|
||||
################################################################################
|
||||
# These tests verify the builder's internal logic without compiling GPU kernels.
|
||||
# They should complete in under 10ms each and are suitable for frequent execution
|
||||
# during development.
|
||||
add_ck_builder_test(test_ckb_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_fwd_instance_traits.cpp
|
||||
test_bwd_weight_instance_traits.cpp
|
||||
test_bwd_data_instance_traits.cpp
|
||||
test_instance_traits_util.cpp)
|
||||
test_instance_traits_util.cpp
|
||||
)
|
||||
|
||||
# Tests the inline diff utility used for comparing strings in tests assertions
|
||||
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
|
||||
# Tests convolution trait selection and configuration
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
|
||||
# Tests convolution problem description and parameter handling
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
################################################################################
|
||||
# REGRESSION TESTS - Integration Tests (With Kernel Compilation)
|
||||
################################################################################
|
||||
# These tests compile actual GPU kernels to verify the builder generates valid,
|
||||
# compilable code. They are more expensive but catch real-world issues.
|
||||
|
||||
# Testing the virtual GetInstanceString methods requires kernel compilation.
|
||||
|
||||
# Verifies that GetInstanceString() methods produce valid kernel code.
|
||||
# Tests various convolution types:
|
||||
# - Group convolution (v3, standard, large tensor, WMMA, DL variants)
|
||||
# - Backward weight group convolution (XDL)
|
||||
# Requires kernel compilation to validate the generated strings.
|
||||
add_ck_builder_test(test_ckb_get_instance_string
|
||||
test_get_instance_string_fwd_grp_conv_v3.cpp
|
||||
test_get_instance_string_fwd_grp_conv.cpp
|
||||
@@ -38,8 +111,8 @@ add_ck_builder_test(test_ckb_get_instance_string
|
||||
test_get_instance_string_fwd_grp_conv_dl.cpp
|
||||
test_get_instance_string_bwd_weight_grp_conv_xdl.cpp)
|
||||
|
||||
# Testing the fwd convolution builder requires kernel compilation.
|
||||
# To enable parallel compilation, the individual tests are split into separate files.
|
||||
# Tests the forward convolution builder across multiple data types and dimensions.
|
||||
# Individual tests are split into separate files to enable parallel compilation.
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_1d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_1d_bf16.cpp
|
||||
@@ -55,15 +128,21 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
)
|
||||
|
||||
# Factory tests attempt to build all the kernels need by MIOpen.
|
||||
# This is only for regression testing and development, the builds are too expensive for regular use in CI.
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations)
|
||||
endfunction()
|
||||
|
||||
# TODO: add these tests back in once we have CI working across all GPU architectures.
|
||||
################################################################################
|
||||
# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set)
|
||||
################################################################################
|
||||
# These tests attempt to build ALL kernels needed by MIOpen for various
|
||||
# convolution operations. They are extremely expensive (minutes to compile)
|
||||
# and are intended for deep regression testing and development only.
|
||||
# NOT suitable for regular CI runs.
|
||||
#
|
||||
# Many tests are commented out pending CI support across all GPU architectures.
|
||||
|
||||
# Tests the testing utilities themselves
|
||||
add_ck_factory_test(test_ckb_testing_utils test_testing_utils.cpp)
|
||||
|
||||
# TODO: Re-enable these tests once we have CI working across all GPU architectures.
|
||||
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp)
|
||||
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp)
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp)
|
||||
@@ -75,22 +154,30 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_ab tes
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
|
||||
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_conv_traits
|
||||
conv/test_conv_traits.cpp)
|
||||
################################################################################
|
||||
# CTest Integration - Register Tests and Assign Labels
|
||||
################################################################################
|
||||
# Tests are registered with CTest and labeled for selective execution:
|
||||
# - BUILDER_SMOKE: Fast unit tests for frequent development cycles
|
||||
# - BUILDER_REGRESSION: Slower integration tests for pre-submission validation
|
||||
|
||||
add_ck_builder_test(test_ckb_conv_description
|
||||
test_conv_description.cpp)
|
||||
|
||||
# Register tests with CTest and assign labels
|
||||
include(CTest)
|
||||
|
||||
# Smoke test: fast-compiling unit test
|
||||
add_test(NAME test_ckb_conv_builder COMMAND test_ckb_conv_builder)
|
||||
set_tests_properties(test_ckb_conv_builder PROPERTIES LABELS "BUILDER_SMOKE")
|
||||
|
||||
# Regression tests: all other tests that require kernel compilation
|
||||
set(CKB_REGRESSION_TESTS
|
||||
# Register all smoke tests (fast unit tests, no kernel compilation)
|
||||
set(CKB_SMOKE_TESTS
|
||||
test_ckb_conv_builder
|
||||
test_ckb_inline_diff
|
||||
test_ckb_conv_traits
|
||||
test_ckb_conv_description
|
||||
)
|
||||
|
||||
foreach(test_target ${CKB_SMOKE_TESTS})
|
||||
add_test(NAME ${test_target} COMMAND ${test_target})
|
||||
set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_SMOKE")
|
||||
endforeach()
|
||||
|
||||
# Register all regression tests (integration tests with kernel compilation)
|
||||
set(CKB_REGRESSION_TESTS
|
||||
test_ckb_get_instance_string
|
||||
test_ckb_build_fwd_instances
|
||||
test_ckb_testing_utils
|
||||
@@ -98,8 +185,6 @@ set(CKB_REGRESSION_TESTS
|
||||
test_ckb_factory_grouped_convolution_forward_scaleadd_ab
|
||||
test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu
|
||||
test_ckb_factory_grouped_convolution_forward_dynamic_op
|
||||
test_ckb_conv_traits
|
||||
test_ckb_conv_description
|
||||
)
|
||||
|
||||
foreach(test_target ${CKB_REGRESSION_TESTS})
|
||||
@@ -107,18 +192,31 @@ foreach(test_target ${CKB_REGRESSION_TESTS})
|
||||
set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_REGRESSION")
|
||||
endforeach()
|
||||
|
||||
# Helper target to build all regression tests
|
||||
################################################################################
|
||||
# Custom Build Targets - Convenient Test Execution
|
||||
################################################################################
|
||||
# These targets provide convenient ways to build and run different test suites:
|
||||
# - smoke-builder: Quick sanity check during development
|
||||
# - regression-builder: Thorough validation before submitting changes
|
||||
# - check-builder: Complete test suite execution
|
||||
|
||||
# Helper target to build all smoke tests (without running them)
|
||||
add_custom_target(build-smoke-builder DEPENDS ${CKB_SMOKE_TESTS})
|
||||
|
||||
# Helper target to build all regression tests (without running them)
|
||||
add_custom_target(build-regression-builder DEPENDS ${CKB_REGRESSION_TESTS})
|
||||
|
||||
# Target to run only smoke tests (builds only test_ckb_conv_builder)
|
||||
# Target to run only smoke tests (builds and runs all smoke test executables)
|
||||
# Use this for quick feedback during active development
|
||||
add_custom_target(smoke-builder
|
||||
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_SMOKE"
|
||||
DEPENDS test_ckb_conv_builder
|
||||
DEPENDS build-smoke-builder
|
||||
USES_TERMINAL
|
||||
COMMENT "Running experimental builder smoke tests..."
|
||||
)
|
||||
|
||||
# Target to run only regression tests (builds all regression test executables)
|
||||
# Target to run only regression tests (builds and runs all regression test executables)
|
||||
# Use this before submitting changes to catch integration issues
|
||||
add_custom_target(regression-builder
|
||||
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_REGRESSION"
|
||||
DEPENDS build-regression-builder
|
||||
@@ -126,15 +224,20 @@ add_custom_target(regression-builder
|
||||
COMMENT "Running experimental builder regression tests..."
|
||||
)
|
||||
|
||||
# Target to run all builder tests (builds all test executables)
|
||||
# Target to run all builder tests (builds and runs all test executables)
|
||||
# Use this for comprehensive validation
|
||||
add_custom_target(check-builder
|
||||
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -R "^test_ckb"
|
||||
DEPENDS test_ckb_conv_builder build-regression-builder
|
||||
DEPENDS build-smoke-builder build-regression-builder
|
||||
USES_TERMINAL
|
||||
COMMENT "Running all experimental builder tests..."
|
||||
)
|
||||
|
||||
# Print summary of test organization
|
||||
################################################################################
|
||||
# Build Summary
|
||||
################################################################################
|
||||
|
||||
# Print summary of test organization for developer reference
|
||||
message(STATUS "CK Builder test organization:")
|
||||
message(STATUS " Smoke test: test_ckb_conv_builder")
|
||||
message(STATUS " Smoke tests: ${CKB_SMOKE_TESTS}")
|
||||
message(STATUS " Regression tests: ${CKB_REGRESSION_TESTS}")
|
||||
|
||||
@@ -127,41 +127,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
|
||||
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
|
||||
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
|
||||
"│ └─ Output elementwise operation: PASS_THROUGH\n"
|
||||
"├─ Algorithm\n"
|
||||
"│ ├─ Thread block size: 256\n"
|
||||
"│ ├─ Data tile size: 256×256×32\n"
|
||||
"│ ├─ Gemm padding: DEFAULT\n"
|
||||
"│ ├─ Convolution specialization: DEFAULT\n"
|
||||
"│ ├─ Pipeline version: V4\n"
|
||||
"│ ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
"│ ├─ Warp Gemm parameters: \n"
|
||||
"│ │ ├─ subtile size: 16×16\n"
|
||||
"│ │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
"│ ├─ Memory access:\n"
|
||||
"│ │ ├─ A Tile transfer: \n"
|
||||
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
"│ │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
"│ │ ├─ B Tile transfer: \n"
|
||||
"│ │ │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
"│ │ │ ├─ The innermost K subdimension size: 8\n"
|
||||
"│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
"│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
"│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
"│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
"│ │ │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
"│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
"│ │ └─ C Tile transfer: \n"
|
||||
"│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
"│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
"│ │ └─ Vector access (GMEM write) instruction size: 8\n"
|
||||
"│ └─ \n"
|
||||
"└─ "));
|
||||
"└─ Algorithm\n"
|
||||
" ├─ Thread block size: 256\n"
|
||||
" ├─ Data tile size: 256×256×32\n"
|
||||
" ├─ Gemm padding: DEFAULT\n"
|
||||
" ├─ Convolution specialization: DEFAULT\n"
|
||||
" ├─ Pipeline version: V4\n"
|
||||
" ├─ Pipeline scheduler: INTRAWAVE\n"
|
||||
" ├─ Warp Gemm parameters: \n"
|
||||
" │ ├─ subtile size: 16×16\n"
|
||||
" │ └─ Number of warp gemm iterations: 4×4\n"
|
||||
" └─ Memory access:\n"
|
||||
" ├─ A Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" ├─ B Tile transfer: \n"
|
||||
" │ ├─ Tile dimensions: 4×256×8×\n"
|
||||
" │ ├─ The innermost K subdimension size: 8\n"
|
||||
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
|
||||
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
|
||||
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
|
||||
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
|
||||
" │ ├─ Vector access (LDS write) instruction size: 8\n"
|
||||
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
|
||||
" └─ C Tile transfer: \n"
|
||||
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
|
||||
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
|
||||
" └─ Vector access (GMEM write) instruction size: 8"));
|
||||
}
|
||||
|
||||
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory
|
||||
|
||||
@@ -199,7 +199,7 @@ struct BaseArgument
|
||||
BaseArgument(const BaseArgument&) = default;
|
||||
BaseArgument& operator=(const BaseArgument&) = default;
|
||||
|
||||
virtual ~BaseArgument() {}
|
||||
virtual __host__ __device__ ~BaseArgument() {}
|
||||
|
||||
void* p_workspace_ = nullptr;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,827 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/env.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename GemmDesc,
|
||||
bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
|
||||
typename Block2CTileMap,
|
||||
index_t MinimumOccupancy = 1,
|
||||
TailNumber TailNum = TailNumber::Full>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_gemm_wmma_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count)
|
||||
{
|
||||
#if(defined(__gfx11__) || defined(__gfx12__))
|
||||
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
|
||||
typename GridwiseGemm::EpilogueCShuffle>();
|
||||
__shared__ char p_shared[LDS_size];
|
||||
|
||||
const index_t block_id = get_block_1d_id();
|
||||
const auto gemm_desc_ptr =
|
||||
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
// Binary search lookup to find which group this block is part of
|
||||
index_t left = 0;
|
||||
index_t right = group_count;
|
||||
index_t group_id = index_t((left + right) / 2);
|
||||
while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
|
||||
block_id < gemm_desc_ptr[group_id].block_end_)) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_id < gemm_desc_ptr[group_id].block_start_)
|
||||
{
|
||||
right = group_id;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = group_id;
|
||||
}
|
||||
group_id = index_t((left + right) / 2);
|
||||
}
|
||||
|
||||
// NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index
|
||||
// and thus needs a non-const reference. It's also not feasible to store this in global
|
||||
// memory as different threads would be writing different K values to the same arg struct
|
||||
auto karg = gemm_desc_ptr[group_id].karg_;
|
||||
|
||||
#if defined(__gfx11__)
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
|
||||
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
|
||||
(std::is_same_v<c_data_type, ck::half_t> ||
|
||||
std::is_same_v<c_data_type, ck::bhalf_t>)))
|
||||
{
|
||||
#endif
|
||||
const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
|
||||
|
||||
// Tile index first dimension is the K batch
|
||||
auto tile_index =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
auto splitk_batch_offset =
|
||||
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
|
||||
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop,
|
||||
CGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
typename GridwiseGemm::EpilogueCShuffle,
|
||||
1, // Block2CTileMap MBlock index
|
||||
2 // Block2CTileMap NBlock index
|
||||
>(static_cast<void*>(p_shared),
|
||||
splitk_batch_offset,
|
||||
karg,
|
||||
block_2_ctile_map,
|
||||
epilogue_args);
|
||||
#if defined(__gfx11__)
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
ignore = gemm_descs_const;
|
||||
ignore = group_count;
|
||||
#endif // end of if(defined(__gfx11__) || defined(__gfx12__))
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t NumGemmKPrefetchStage,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t MPerWmma,
|
||||
ck::index_t NPerWmma,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = EDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool PermuteA = false,
|
||||
bool PermuteB = false>
|
||||
struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static_assert(KPerBlock % AK1 == 0);
|
||||
static constexpr index_t K0PerBlock = KPerBlock / AK1;
|
||||
|
||||
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
Tuple<ADataType>,
|
||||
Tuple<BDataType>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
false, // PermuteA not supported by DeviceBatchedGemm base class.
|
||||
false>; // PermuteB not supported by DeviceBatchedGemm base class.
|
||||
|
||||
using CGridDesc_M_N =
|
||||
remove_cvref_t<decltype(GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
1, 1, 1, 1, 1))>;
|
||||
using Block2ETileMapKSplit =
|
||||
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
|
||||
// Block2CTileMap configuration parameter.
|
||||
static constexpr index_t B2E_M01 = 8;
|
||||
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
|
||||
using KernelArgument = typename GridwiseGemm::Argument;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
template <typename KernelArgument_>
|
||||
struct GemmTransKernelArgBase
|
||||
{
|
||||
KernelArgument_ karg_;
|
||||
GroupedGemmBlock2ETileMap block_2_ctile_map_;
|
||||
index_t block_start_, block_end_;
|
||||
|
||||
GemmTransKernelArgBase() = default;
|
||||
GemmTransKernelArgBase(KernelArgument_&& karg,
|
||||
GroupedGemmBlock2ETileMap&& b2c_map,
|
||||
index_t block_start,
|
||||
index_t block_end)
|
||||
: karg_{karg},
|
||||
block_2_ctile_map_{b2c_map},
|
||||
block_start_{block_start},
|
||||
block_end_{block_end}
|
||||
{
|
||||
}
|
||||
};
|
||||
using GemmTransKernelArg = GemmTransKernelArgBase<KernelArgument>;
|
||||
|
||||
static constexpr index_t DefaultKBatch = 1;
|
||||
|
||||
static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg)
|
||||
{
|
||||
index_t k_grain = karg.KBatch * KPerBlock;
|
||||
index_t K_split = (karg.K + k_grain - 1) / karg.KBatch;
|
||||
return GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
}
|
||||
|
||||
// Argument
|
||||
// TODO: Add A/B/CDE element op?
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs)
|
||||
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
|
||||
{
|
||||
// TODO: use occupancy api to calculate appropriate batch size.
|
||||
}
|
||||
|
||||
Argument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
index_t kbatch)
|
||||
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
|
||||
{
|
||||
grid_size_ = 0;
|
||||
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
|
||||
|
||||
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
|
||||
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
|
||||
{
|
||||
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
|
||||
}
|
||||
|
||||
gemm_kernel_args_.reserve(group_count_);
|
||||
|
||||
skipped_group_count_ = 0;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
const index_t M = gemm_descs[i].M_;
|
||||
const index_t N = gemm_descs[i].N_;
|
||||
const index_t K = gemm_descs[i].K_;
|
||||
|
||||
if(M == 0)
|
||||
{
|
||||
skipped_group_count_++;
|
||||
continue;
|
||||
}
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A_;
|
||||
const index_t stride_b = gemm_descs[i].stride_B_;
|
||||
const index_t stride_c = gemm_descs[i].stride_C_;
|
||||
|
||||
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
|
||||
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
M, m_padded, N, n_padded, stride_c);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
|
||||
|
||||
const index_t block_start = grid_size_;
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
|
||||
auto karg = KernelArgument(std::array<const void*, 1>{p_As[i]},
|
||||
std::array<const void*, 1>{p_Bs[i]},
|
||||
std::array<const void*, 0>{}, // p_ds_grid_
|
||||
type_convert<EDataType*>(p_Es[i]),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
std::array<index_t, 1>{stride_a},
|
||||
std::array<index_t, 1>{stride_b},
|
||||
std::array<index_t, 0>{}, // StrideDs_
|
||||
stride_c,
|
||||
K_BATCH,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
false);
|
||||
|
||||
gemm_kernel_args_.emplace_back(
|
||||
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Recalculate group grid size for all gemms and update B2C maps.
|
||||
*
|
||||
* @param[in] kbatch The new splitK parameter value.
|
||||
*/
|
||||
void UpdateKBatch(index_t kbatch)
|
||||
{
|
||||
K_BATCH = kbatch;
|
||||
grid_size_ = 0;
|
||||
|
||||
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
auto& karg = gemm_kernel_args_[i].karg_;
|
||||
|
||||
const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, K_BATCH);
|
||||
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
|
||||
const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, K_BATCH);
|
||||
const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, K_BATCH);
|
||||
|
||||
const auto c_grid_desc_m_n =
|
||||
GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
|
||||
karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE);
|
||||
|
||||
const auto local_b2c_tile_map =
|
||||
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
|
||||
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
|
||||
|
||||
const index_t block_start = grid_size_;
|
||||
const index_t block_end = grid_size_ + grid_size_grp;
|
||||
|
||||
grid_size_ += grid_size_grp;
|
||||
|
||||
// block-to-e-tile map
|
||||
auto grouped_block_2_ctile_map =
|
||||
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
|
||||
|
||||
karg.KRead = k_read;
|
||||
karg.KPadded = k_padded;
|
||||
karg.AK0 = ak0_padded;
|
||||
karg.BK0 = bk0_padded;
|
||||
karg.KBatch = K_BATCH;
|
||||
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
|
||||
gemm_kernel_args_[i].block_start_ = block_start;
|
||||
gemm_kernel_args_[i].block_end_ = block_end;
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
index_t K_BATCH;
|
||||
index_t group_count_;
|
||||
index_t skipped_group_count_;
|
||||
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
void* gemm_kernel_host_args_;
|
||||
index_t grid_size_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg,
|
||||
const StreamConfig& stream_config = StreamConfig{},
|
||||
hipStream_t cpy_stream = nullptr,
|
||||
hipEvent_t cpy_event = nullptr)
|
||||
{
|
||||
using GemmTransKernelArg_ = GemmTransKernelArgBase<typename GridwiseGemm::Argument>;
|
||||
static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg));
|
||||
|
||||
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.KBatch > 1;
|
||||
bool all_have_main_k0_block_loop =
|
||||
CalculateHasMainKBlockLoop(arg.gemm_kernel_args_[0].karg_);
|
||||
|
||||
bool not_all_have_main_k0_block_loop_same = false;
|
||||
bool not_all_have_kbatch_value_same = false;
|
||||
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& karg = reinterpret_cast<const typename GridwiseGemm::Argument&>(
|
||||
arg.gemm_kernel_args_[i].karg_);
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
karg.Print();
|
||||
}
|
||||
|
||||
auto kbatch = karg.KBatch;
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(karg))
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
not_all_have_main_k0_block_loop_same |=
|
||||
all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg);
|
||||
not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1);
|
||||
}
|
||||
|
||||
if(not_all_have_main_k0_block_loop_same)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
// throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
if(not_all_have_kbatch_value_same)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__
|
||||
<< ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
// If the user provides copy stream and copy event, we assume that they're also
|
||||
// responsible for providing allocated host memory (eg. pinned) which
|
||||
// would be used to copy kernel arguments to the device.
|
||||
if(cpy_stream && cpy_event)
|
||||
{
|
||||
if(arg.gemm_kernel_host_args_ == nullptr)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "No memory has been allocated for gemm kernel host args "
|
||||
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
hip_check_error(hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_host_args_,
|
||||
arg.group_count_ * sizeof(GemmTransKernelArg_),
|
||||
hipMemcpyHostToDevice,
|
||||
cpy_stream));
|
||||
hip_check_error(hipEventRecord(cpy_event, cpy_stream));
|
||||
hip_check_error(hipEventSynchronize(cpy_event));
|
||||
}
|
||||
else // In this case CK owns memory allocated on host.
|
||||
{
|
||||
|
||||
hip_check_error(
|
||||
hipMemcpyAsync(arg.p_workspace_,
|
||||
arg.gemm_kernel_args_.data(),
|
||||
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg_),
|
||||
hipMemcpyHostToDevice,
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(all_have_kbatch_gt_one)
|
||||
{
|
||||
for(const auto& trans_arg : arg.gemm_kernel_args_)
|
||||
{
|
||||
const auto& karg = trans_arg.karg_;
|
||||
hip_check_error(hipMemsetAsync(karg.p_e_grid,
|
||||
0,
|
||||
karg.M * karg.N * sizeof(EDataType),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
}
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(arg.grid_size_),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
cast_pointer_to_constant_address_space(arg.p_workspace_),
|
||||
arg.gemm_kernel_args_.size());
|
||||
};
|
||||
|
||||
// NOTE: If at least one gemm problem has a main k0 block loop, we include it for all
|
||||
if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(all_have_kbatch_gt_one)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg_,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
GroupedGemmBlock2ETileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg_,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
GroupedGemmBlock2ETileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(all_have_kbatch_gt_one)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg_,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
GroupedGemmBlock2ETileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
|
||||
GemmTransKernelArg_,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
GroupedGemmBlock2ETileMap>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
|
||||
std::is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if(arg.K_BATCH > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
|
||||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
|
||||
arg.skipped_group_count_) != arg.group_count_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "The group count is not equal to sum of skipped groups "
|
||||
"and kernel args size!"
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supported = true;
|
||||
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
|
||||
{
|
||||
const auto& a = arg.gemm_kernel_args_[i].karg_;
|
||||
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
|
||||
|
||||
if(not group_arg_valid)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "[" << __func__ << "] group id: " << i
|
||||
<< " has invalid GridwiseGemm settings!" << std::endl;
|
||||
a.Print();
|
||||
}
|
||||
}
|
||||
supported = supported && group_arg_valid;
|
||||
}
|
||||
return supported;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc> gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation)
|
||||
{
|
||||
return Argument{p_As, p_Bs, p_Es, gemm_descs};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(std::vector<const void*>& p_As,
|
||||
std::vector<const void*>& p_Bs,
|
||||
std::vector<std::array<const void*, NumDTensor>>&,
|
||||
std::vector<void*>& p_Es,
|
||||
std::vector<GemmDesc>& gemm_descs,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedGemm_WmmaSplitK"
|
||||
<< "<"
|
||||
<< std::string(ALayout::name)[0] << ","
|
||||
<< std::string(BLayout::name)[0] << ","
|
||||
<< std::string(ELayout::name)[0] << ","
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerWmma << ", "
|
||||
<< NPerWmma << ", "
|
||||
<< MRepeat << ", "
|
||||
<< NRepeat << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< CShuffleMRepeatPerShuffle << ", "
|
||||
<< CShuffleNRepeatPerShuffle << ", "
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("The argument pointer is not an object of "
|
||||
"DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!");
|
||||
}
|
||||
|
||||
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
return GetWorkSpaceSize(p_arg);
|
||||
}
|
||||
|
||||
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
|
||||
|
||||
// TODO: deperecation notice.
|
||||
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
|
||||
|
||||
// polymorphic
|
||||
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->UpdateKBatch(kbatch);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("The argument pointer is not an object of "
|
||||
"DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!");
|
||||
}
|
||||
|
||||
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
|
||||
{
|
||||
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Sets the host kernel arguments pointer and copies that data on the host side.
|
||||
/// This function can be utilised to use pinned memory for the host args and
|
||||
/// achieve fully async data copy.
|
||||
///
|
||||
/// @param p_arg The pointer to the Argument we're going to update.
|
||||
/// @param[in] p_host_kernel_args The pointer to the host memory where the kernel
|
||||
/// arguments will be copied
|
||||
///
|
||||
void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
|
||||
{
|
||||
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(!pArg_)
|
||||
{
|
||||
throw std::runtime_error("Failed to cast argument pointer!");
|
||||
}
|
||||
|
||||
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
|
||||
std::copy(pArg_->gemm_kernel_args_.begin(),
|
||||
pArg_->gemm_kernel_args_.end(),
|
||||
static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
DsGridPointer p_ds_grid;
|
||||
EDataType* p_e_grid;
|
||||
|
||||
const AElementwiseOperation a_element_op;
|
||||
const BElementwiseOperation b_element_op;
|
||||
const CDEElementwiseOperation cde_element_op;
|
||||
AElementwiseOperation a_element_op;
|
||||
BElementwiseOperation b_element_op;
|
||||
CDEElementwiseOperation cde_element_op;
|
||||
|
||||
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
|
||||
bool is_reduce;
|
||||
@@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
typename Block2CTileMap,
|
||||
typename EpilogueArgument,
|
||||
int BlockMapMBlockIndex = 0,
|
||||
int BlockMapNBlockIndex = 1>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
@@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
|
||||
|
||||
@@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
return;
|
||||
}
|
||||
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
|
||||
const index_t block_m_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
|
||||
const index_t block_n_id =
|
||||
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
|
||||
|
||||
// BScale struct (Empty)
|
||||
using BScale = typename BlockwiseGemmPipe::Empty;
|
||||
@@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(AsGridPointer& p_as_grid,
|
||||
BsGridPointer& p_bs_grid,
|
||||
DsGridPointer& p_ds_grid,
|
||||
EDataType* p_e_grid,
|
||||
void* p_shared,
|
||||
const Problem& problem,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument>(p_as_grid,
|
||||
p_bs_grid,
|
||||
p_ds_grid,
|
||||
p_e_grid,
|
||||
p_shared,
|
||||
problem,
|
||||
DefaultBlock2CTileMap(problem),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
typename Block2CTileMap,
|
||||
typename EpilogueArgument,
|
||||
int BlockMapMBlockIndex = 0,
|
||||
int BlockMapNBlockIndex = 1>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
const Block2CTileMap& block_2_ctile_map,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
// shift A matrices pointer for splitk
|
||||
@@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
splitk_batch_offset.b_k_split_offset[i];
|
||||
});
|
||||
|
||||
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
|
||||
p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument,
|
||||
BlockMapMBlockIndex,
|
||||
BlockMapNBlockIndex>(p_as_grid_splitk,
|
||||
p_bs_grid_splitk,
|
||||
karg.p_ds_grid,
|
||||
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
block_2_ctile_map,
|
||||
karg.a_element_op,
|
||||
karg.b_element_op,
|
||||
karg.cde_element_op,
|
||||
epilogue_args);
|
||||
}
|
||||
|
||||
// Wrapper function to have __global__ function in common
|
||||
// between gemm_universal, b_scale, ab_scale, etc.
|
||||
template <bool HasMainKBlockLoop,
|
||||
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
|
||||
TailNumber TailNum,
|
||||
typename EpilogueArgument>
|
||||
__device__ static void Run(void* p_shared,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
Argument& karg,
|
||||
EpilogueArgument& epilogue_args)
|
||||
{
|
||||
Run<HasMainKBlockLoop,
|
||||
EGlobalMemoryDataOperation,
|
||||
TailNum,
|
||||
Block2CTileMap,
|
||||
EpilogueArgument>(
|
||||
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args);
|
||||
}
|
||||
|
||||
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
|
||||
{
|
||||
return Block2CTileMap{problem.M, problem.N, 4};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
|
||||
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. AK1: "
|
||||
<< AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch
|
||||
<< ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__
|
||||
<< ", in function: " << __func__ << std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -293,6 +293,15 @@ struct tile_window_with_static_distribution
|
||||
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename offset_t>
|
||||
CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const
|
||||
{
|
||||
constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{});
|
||||
const auto bottom_tensor_coord_off = make_tensor_coordinate(
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
|
||||
return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset());
|
||||
}
|
||||
|
||||
template <typename DataType,
|
||||
typename StaticTileDistribution,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
@@ -316,12 +325,7 @@ struct tile_window_with_static_distribution
|
||||
else if constexpr(is_constant_v<offset_t>)
|
||||
return offset_t::value;
|
||||
else
|
||||
{
|
||||
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
|
||||
auto bottom_tensor_coord_off = make_tensor_coordinate(
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
|
||||
return bottom_tensor_coord_off.get_offset();
|
||||
}
|
||||
return get_load_offset(offset_t{});
|
||||
}();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
|
||||
@@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
|
||||
struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
template <typename Problem, typename PipelinePolicy = MXFlatmmPipelineAgBgCrPolicy>
|
||||
struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
|
||||
{
|
||||
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
|
||||
|
||||
@@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
CK_TILE_DEVICE auto operator()(Args&&... args) const
|
||||
{
|
||||
auto c_warp_tensors = Run_(std::forward<Args>(args)...);
|
||||
|
||||
// Block GEMM Acc register tile
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensors(mIter)(nIter).get_thread_buffer());
|
||||
});
|
||||
});
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_ping,
|
||||
void* __restrict__ p_smem_pong) const
|
||||
CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_ping,
|
||||
void* __restrict__ p_smem_pong) const
|
||||
{
|
||||
#ifndef __gfx950__
|
||||
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
|
||||
@@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
|
||||
static_assert(NWarp == 4);
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto a_dram_window =
|
||||
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
|
||||
make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor<Problem>(
|
||||
a_copy_dram_window_tmp.get_bottom_tensor_view()),
|
||||
a_copy_dram_window_tmp.get_window_lengths(),
|
||||
a_copy_dram_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ADramTileDistribution<Problem>());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor<Problem>();
|
||||
PipelinePolicy::template MakeMX_ALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block_ping =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
|
||||
@@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
|
||||
|
||||
// B flat DRAM window for load
|
||||
|
||||
// pingpong buffer for B
|
||||
auto b_flat_dram_windows = generate_tuple(
|
||||
auto b_flat_dram_window =
|
||||
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
|
||||
auto b_flat_dram_offsets = generate_tuple(
|
||||
[&](auto nIter) {
|
||||
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
auto window_i = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
|
||||
move_tile_window(
|
||||
window_i,
|
||||
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
|
||||
number<0>{}});
|
||||
return window_i;
|
||||
return b_flat_dram_window.get_load_offset(
|
||||
tuple<number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter>,
|
||||
number<0>>{}) +
|
||||
b_flat_dram_window.get_load_offset(
|
||||
tuple<number<packed_n_rank>, number<0>>{});
|
||||
},
|
||||
number<NIterPerWarp>{});
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping, b_warp_tensor_pong;
|
||||
|
||||
@@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
scale_a_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
|
||||
scale_a_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution<Problem>());
|
||||
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
|
||||
const auto scale_a_dram_step_k = amd_wave_read_first_lane(
|
||||
scale_a_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kM>>{}));
|
||||
|
||||
auto scale_b_dram_window = make_tile_window(
|
||||
scale_b_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
|
||||
scale_b_window.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
|
||||
PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
|
||||
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<NWarp * WG::kN>, number<0>>{}));
|
||||
const auto scale_b_dram_step_k = amd_wave_read_first_lane(
|
||||
scale_b_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kN>>{}));
|
||||
|
||||
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
|
||||
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
|
||||
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
|
||||
|
||||
// ping pong buffer for scale A
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_tile_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
MIterPerWarp / MXdlPack>
|
||||
scale_a_tile_tensor_pong;
|
||||
statically_indexed_array<decltype(load_tile(scale_a_dram_window)), KPackIterPerWarp>,
|
||||
MPackIterPerWarp>
|
||||
scale_a_tile_tensor_ping, scale_a_tile_tensor_pong;
|
||||
|
||||
// ping pong buffer for scale B
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
|
||||
KIterPerWarp / KXdlPack>,
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_pong;
|
||||
statically_indexed_array<decltype(load_tile(scale_b_dram_window)), KPackIterPerWarp>,
|
||||
NPackIterPerWarp>
|
||||
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
|
||||
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
|
||||
@@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
});
|
||||
|
||||
// prefetch Scale A
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
// move Scale A window to next K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// prefetch Scale B
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
// move Scale B window to next K
|
||||
@@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
}
|
||||
// initialize C
|
||||
clear_tile(c_block_tile);
|
||||
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp>
|
||||
c_warp_tensors;
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}(
|
||||
[&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); });
|
||||
});
|
||||
|
||||
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
|
||||
|
||||
@@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
@@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -802,81 +791,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+2)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter),
|
||||
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
|
||||
});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
|
||||
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
|
||||
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
|
||||
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
|
||||
});
|
||||
});
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
|
||||
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
|
||||
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
|
||||
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
|
||||
});
|
||||
});
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0]); // scale B
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
(nIter_pack == NPackIterPerWarp - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
@@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{
|
||||
static_assert(false, "Wrong TailNum");
|
||||
}
|
||||
return c_block_tile;
|
||||
return c_warp_tensors;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
@@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
template <typename Problem, typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
@@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
@@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
@@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
@@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
@@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
return sizeof(ADataType) *
|
||||
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
|
||||
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
APackedSize;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -23,7 +23,8 @@ struct BaseGemmPipelineAgBgCrCompV4
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
constexpr index_t HotLoopGlobalReads = 2;
|
||||
return num_loop >= (HotLoopGlobalReads + PrefetchStages);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
|
||||
@@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
{
|
||||
// Need to multiply aquant with accumulated C
|
||||
//
|
||||
// The accumulated C tile has the standard distribution. For example
|
||||
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// The accumulated C tile has the standard distribution. For example, a
|
||||
// 32x32 C lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
|
||||
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
|
||||
// [26,0], [27,0].
|
||||
//
|
||||
@@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
//
|
||||
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
|
||||
|
||||
// MIters per warp
|
||||
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
|
||||
|
||||
// Reg block offset based on mIter
|
||||
constexpr index_t reg_block_offset =
|
||||
((mIter / mIters_per_warp) * Traits::AQPerBlock);
|
||||
|
||||
constexpr index_t lane_base_offset =
|
||||
(mIter % mIters_per_warp) * WarpGemm::kM;
|
||||
|
||||
// Scale tensor offset along K
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
// Directly index into thread buffer corresponding to
|
||||
// desired row coefficient
|
||||
// Each thread stores AQPerBlock scale values per M iteration.
|
||||
constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock;
|
||||
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
|
||||
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
|
||||
;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
// x CNLane
|
||||
constexpr uint32_t row_base =
|
||||
((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
|
||||
((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane);
|
||||
// Divide M dimension of C Warp tile into groups of
|
||||
// (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane)
|
||||
// m_base_offset_of_c_row indicates which group the current c_row belongs
|
||||
// to.
|
||||
constexpr index_t m_base_offset_of_c_row =
|
||||
(c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) *
|
||||
(WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
|
||||
// M offset of each thread within its group (see comment above)
|
||||
index_t m_base_offset_of_lane =
|
||||
(get_lane_id() / WarpGemm::kN *
|
||||
WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
|
||||
|
||||
// M offset wrt. c_row in the subgroup of kCM1PerLane
|
||||
constexpr index_t m_offset_of_c_row =
|
||||
c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1);
|
||||
|
||||
// Lane index to source scale from
|
||||
uint32_t src_lane_idx =
|
||||
lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows);
|
||||
m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row;
|
||||
|
||||
return exchange_quant_value_across_lanes(scale_reg, src_lane_idx);
|
||||
}
|
||||
|
||||
@@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
|
||||
// # of elements per thread
|
||||
constexpr index_t X = XPerTile;
|
||||
|
||||
constexpr index_t Y0 = 1;
|
||||
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t Y2 = MWarps;
|
||||
constexpr index_t Y3 = WarpGemm::kM;
|
||||
static_assert(Y3 >= WarpGemm::kM,
|
||||
constexpr index_t YR = 1;
|
||||
constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
|
||||
constexpr index_t Y1 = MWarps;
|
||||
constexpr index_t Y2 = WarpGemm::kM;
|
||||
static_assert(Y2 >= WarpGemm::kM,
|
||||
"Scales for all rows must be available within the warp.");
|
||||
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
|
||||
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<NWarps>,
|
||||
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 1>>,
|
||||
tuple<sequence<2, 0>, sequence<0, 3>>,
|
||||
tile_distribution_encoding<sequence<NWarps, YR>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
|
||||
tuple<sequence<1, 0>, sequence<0, 1>>,
|
||||
tuple<sequence<1, 0>, sequence<1, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,142 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__)
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F8,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F8,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_BF16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
@@ -409,6 +545,81 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
|
||||
is_same_v<EDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_BF16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
|
||||
@@ -0,0 +1,205 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/utility/loop_scheduler.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AccDataType = F32;
|
||||
using DsDataType = Empty_Tuple;
|
||||
|
||||
using DsLayout = Empty_Tuple;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = PassThrough;
|
||||
|
||||
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding;
|
||||
static constexpr auto GemmDefault = device::GemmSpecialization::Default;
|
||||
|
||||
// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_km_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang`-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType
|
||||
template <typename T,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
enable_if_t<sizeof(T) == 2, bool> = false>
|
||||
using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported
|
||||
// padding/scheduler/pipeline version combinations
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
typename LayoutInstances,
|
||||
typename ADataType, // NOTE: type parameters as last so that they can be inferred from the
|
||||
typename BDataType, // vector argument
|
||||
typename EDataType>
|
||||
void add_device_grouped_gemm_wmma_universal_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<GemmDefault, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(instances,
|
||||
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV3>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
|
||||
}
|
||||
|
||||
// Helper function to add a list of layout instances for instances with matching A/B/E data types
|
||||
// for all supported padding/scheduler/pipeline version combinations
|
||||
template <typename T,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
template <typename T2,
|
||||
device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
typename LayoutInstances>
|
||||
void add_device_grouped_gemm_wmma_universal_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
T,
|
||||
T,
|
||||
DsDataType,
|
||||
T,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmDefault, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV3>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
|
||||
add_device_operation_instances(
|
||||
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -57,7 +57,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Do not build XDL instances if gfx9 targets are not on the target list
|
||||
if(NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl")
|
||||
if(((NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_XDL) AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -67,7 +67,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
# Do not build WMMA instances if gfx11 targets are not on the target list
|
||||
if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma")
|
||||
if(((NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_WMMA) AND source_name MATCHES "_wmma")
|
||||
message(DEBUG "removing wmma instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -88,7 +88,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endif()
|
||||
endif()
|
||||
# Do not build WMMA gemm_universal_f8 for any targets except gfx12+
|
||||
if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_")
|
||||
if((NOT INST_TARGETS MATCHES "gfx12" OR FORCE_DISABLE_WMMA) AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_")
|
||||
message(DEBUG "removing gemm_universal_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -274,7 +274,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
message(DEBUG "Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12" OR FORCE_DISABLE_XDL))
|
||||
message(DEBUG "Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
@@ -282,7 +282,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
message(DEBUG "Found only MX instances, but gfx950 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (((NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) OR FORCE_DISABLE_WMMA))
|
||||
message(DEBUG "Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
@@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
message(DEBUG "Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
|
||||
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND ((NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12") OR (FORCE_DISABLE_XDL AND FORCE_DISABLE_WMMA)))
|
||||
message(DEBUG "Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
@@ -333,20 +333,22 @@ FOREACH(subdir_path ${dir_list})
|
||||
if((add_inst EQUAL 1))
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
if("${cmake_instance}" MATCHES "gemm")
|
||||
list(APPEND CK_DEVICE_GEMM_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "conv")
|
||||
list(APPEND CK_DEVICE_CONV_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "mha")
|
||||
list(APPEND CK_DEVICE_MHA_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "contr")
|
||||
list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "reduce")
|
||||
list(APPEND CK_DEVICE_REDUCTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
else()
|
||||
list(APPEND CK_DEVICE_OTHER_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
endif()
|
||||
message(DEBUG "add_instance_directory ${subdir_path}")
|
||||
if (TARGET device_${target_dir}_instance)
|
||||
if("${cmake_instance}" MATCHES "gemm")
|
||||
list(APPEND CK_DEVICE_GEMM_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "conv")
|
||||
list(APPEND CK_DEVICE_CONV_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "mha")
|
||||
list(APPEND CK_DEVICE_MHA_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "contr")
|
||||
list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
elseif("${cmake_instance}" MATCHES "reduce")
|
||||
list(APPEND CK_DEVICE_REDUCTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
else()
|
||||
list(APPEND CK_DEVICE_OTHER_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
endif()
|
||||
message(DEBUG "add_instance_directory ${subdir_path}")
|
||||
endif()
|
||||
else()
|
||||
message(DEBUG "skip_instance_directory ${subdir_path}")
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
@@ -36,4 +36,17 @@ add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
|
||||
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
BF16,
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
Col,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
BF16,
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
Col,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
DsDataType,
|
||||
BF16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
BF16,
|
||||
Row,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Col,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
|
||||
Col,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Col,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Col,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
F16,
|
||||
F16,
|
||||
DsDataType,
|
||||
F16,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
F16,
|
||||
Row,
|
||||
Col,
|
||||
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using ADataType = F16;
|
||||
using BDataType = F8;
|
||||
using EDataType = F16;
|
||||
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using ADataType = F8;
|
||||
using BDataType = F16;
|
||||
using EDataType = F16;
|
||||
|
||||
template <device::GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer>
|
||||
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
|
||||
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
|
||||
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
|
||||
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
|
||||
Row,
|
||||
DsLayout,
|
||||
Row,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>>>& instances)
|
||||
{
|
||||
|
||||
add_device_grouped_gemm_wmma_universal_instances<
|
||||
Row,
|
||||
Row,
|
||||
device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -42,10 +42,11 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
const std::vector<int>& kbatches = {},
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10,
|
||||
int instance_index = -1)
|
||||
const std::vector<int>& kbatches = {},
|
||||
int n_warmup = 1,
|
||||
int n_iter = 10,
|
||||
int instance_index = -1,
|
||||
bool fail_if_no_supported_instance = false)
|
||||
{
|
||||
bool pass = true;
|
||||
// TODO: Fixme - we do not pass compute data type here but need it
|
||||
@@ -225,6 +226,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
}
|
||||
}
|
||||
// profile device GEMM instances
|
||||
int instances_supporting_all_batch_sizes = 0;
|
||||
for(auto& gemm_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
@@ -268,6 +270,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
kbatch_list = kbatches;
|
||||
}
|
||||
|
||||
bool all_batch_sizes_supported = true;
|
||||
for(std::size_t j = 0; j < kbatch_list.size(); j++)
|
||||
{
|
||||
auto kbatch_curr = kbatch_list[j];
|
||||
@@ -367,10 +370,30 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
}
|
||||
else
|
||||
{
|
||||
all_batch_sizes_supported = false;
|
||||
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// If all batch sizes were supported by this instance, the instance can be marked as
|
||||
// 'supported' for this problem
|
||||
if(all_batch_sizes_supported)
|
||||
{
|
||||
++instances_supporting_all_batch_sizes;
|
||||
}
|
||||
}
|
||||
|
||||
// Warn if not a single instance was supported
|
||||
if(instances_supporting_all_batch_sizes == 0)
|
||||
{
|
||||
std::cout << "Warning! No instance found that supported all of the batch sizes."
|
||||
<< std::endl;
|
||||
|
||||
if(fail_if_no_supported_instance)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
@@ -384,6 +407,7 @@ bool profile_grouped_gemm_impl(int do_verification,
|
||||
std::cout << "grouped_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
|
||||
@@ -53,6 +53,13 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
|
||||
};
|
||||
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128;
|
||||
};
|
||||
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
|
||||
@@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = false && Prefill
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
|
||||
|
||||
// PreshuffleQuant = false && TransposeC = true
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
|
||||
|
||||
@@ -3,10 +3,15 @@
|
||||
|
||||
add_custom_target(test_grouped_gemm)
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
|
||||
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
|
||||
# the instance library if there's no instances present for the current arch.
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "test_grouped_gemm_util.hpp"
|
||||
#include "test_grouped_gemm_interface_xdl.hpp"
|
||||
|
||||
class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
|
||||
{
|
||||
|
||||
205
test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp
Normal file
205
test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp
Normal file
@@ -0,0 +1,205 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
struct DeviceGroupedGemmSplitkInstanceWrapper
|
||||
{
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using EmptyTuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using ABlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
|
||||
using ABlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
|
||||
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
|
||||
|
||||
using BBlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
|
||||
using BBlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
|
||||
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
|
||||
|
||||
using DeviceGroupedGemmSplitKInstance =
|
||||
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
EmptyTuple,
|
||||
ELayout,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
F16,
|
||||
EmptyTuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
GemmSpec,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
KPerBlock,
|
||||
K1,
|
||||
K1,
|
||||
16,
|
||||
16,
|
||||
8,
|
||||
4,
|
||||
S<1, 4, 16, 1>,
|
||||
ABlockTransferThreadClusterArrageOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim::value,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1::value,
|
||||
ABlockLdsAddExtraM::value,
|
||||
S<1, 4, 16, 1>,
|
||||
BBlockTransferThreadClusterArrageOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim::value,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1::value,
|
||||
BBlockLdsAddExtraM::value,
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 8>,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
bool IsSupported(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(&argument, kbatch);
|
||||
}
|
||||
|
||||
return ggemm_instance.IsSupportedArgument(argument);
|
||||
}
|
||||
|
||||
float Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(&argument, kbatch);
|
||||
}
|
||||
if(kbatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
|
||||
auto invoker = ggemm_instance.MakeInvoker();
|
||||
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
|
||||
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
|
||||
return invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
@@ -24,21 +24,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
|
||||
{
|
||||
public:
|
||||
void SetUp() override
|
||||
{
|
||||
ck::test::TestGroupedGemm<Tuple>::SetUp();
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
|
||||
// behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a
|
||||
// specific case is not supported
|
||||
this->fail_if_no_supported_instances_ =
|
||||
ck::is_gfx11_supported() || ck::is_gfx12_supported();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// WWMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test.
|
||||
std::tuple< Col, Col, Row, BF16, BF16, BF16>,
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_XDL) && defined(__gfx9__)
|
||||
// XDL only at the moment, instances for WMMA not defined
|
||||
std::tuple< Row, Row, Row, BF16, I8, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, I8, BF16>,
|
||||
#endif
|
||||
|
||||
#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__))
|
||||
std::tuple< Row, Row, Row, F8, F16, F16>,
|
||||
std::tuple< Row, Row, Row, F16, F8, F16>,
|
||||
#endif
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F16>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F16>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F16>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F16>,
|
||||
|
||||
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, BF16>,
|
||||
std::tuple< Row, Row, Row, BF16, I8, BF16>,
|
||||
std::tuple< Row, Col, Row, BF16, I8, BF16>,
|
||||
std::tuple< Row, Row, Row, F16, F8, F16>,
|
||||
std::tuple< Row, Row, Row, F8, F16, F16>
|
||||
std::tuple< Col, Row, Row, BF16, BF16, BF16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -65,6 +65,13 @@ TYPED_TEST(TestGroupedGemm, MNKPadded)
|
||||
|
||||
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
|
||||
{
|
||||
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
|
||||
// Technically, we could still run the tests for fp32, but we currently don't have instances for
|
||||
// it so we disable it entirely
|
||||
if(ck::is_gfx11_supported())
|
||||
GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add "
|
||||
"instructions";
|
||||
|
||||
const std::vector<int> Ms{188, 210};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 4096;
|
||||
|
||||
@@ -11,16 +11,7 @@
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/stream_config.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/number.hpp"
|
||||
#include "profiler/profile_grouped_gemm_impl.hpp"
|
||||
|
||||
extern ck::index_t param_mask;
|
||||
@@ -41,7 +32,7 @@ std::string serialize_range(const Range& range)
|
||||
return std::string(str.begin(), str.end() - 2);
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
template <typename Tuple, bool FailIfNoSupportedInstances = false>
|
||||
class TestGroupedGemm : public testing::Test
|
||||
{
|
||||
protected:
|
||||
@@ -62,9 +53,26 @@ class TestGroupedGemm : public testing::Test
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
static constexpr int n_warmup_ = 0;
|
||||
static constexpr int n_iter_ = 1;
|
||||
|
||||
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
|
||||
std::vector<int> k_batches_;
|
||||
|
||||
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
|
||||
void SetUp() override
|
||||
{
|
||||
constexpr bool require_16bit_atomic_add =
|
||||
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
|
||||
if(require_16bit_atomic_add && ck::is_gfx11_supported())
|
||||
{
|
||||
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
|
||||
// Technically, we could still use split-K for fp32, but we currently don't have
|
||||
// instances for it so we disable it entirely
|
||||
k_batches_ = {1};
|
||||
}
|
||||
else
|
||||
{
|
||||
k_batches_ = {1, 2, 3, 5, 8};
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename Layout>
|
||||
@@ -132,204 +140,31 @@ class TestGroupedGemm : public testing::Test
|
||||
const std::vector<int>& StrideCs,
|
||||
const std::vector<int>& kbatches)
|
||||
{
|
||||
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatches,
|
||||
n_warmup_,
|
||||
n_iter_,
|
||||
instance_index);
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_gemm_impl<ADataType,
|
||||
BDataType,
|
||||
EDataType,
|
||||
float,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideCs,
|
||||
kbatches,
|
||||
n_warmup_,
|
||||
n_iter_,
|
||||
instance_index,
|
||||
fail_if_no_supported_instances_);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
tensor_operation::device::GemmSpecialization GemmSpec,
|
||||
ck::index_t KPerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t CDEBlockTransferScalarPerVector_NPerBlock>
|
||||
struct DeviceGroupedGemmSplitkInstanceWrapper
|
||||
{
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using EmptyTuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using ABlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
|
||||
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
|
||||
using ABlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
|
||||
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
|
||||
|
||||
using BBlockTransferThreadClusterArrageOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcAccessOrder =
|
||||
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
|
||||
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
|
||||
using BBlockTransferDstScalarPerVector_K1 =
|
||||
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
|
||||
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
|
||||
|
||||
using DeviceGroupedGemmSplitKInstance =
|
||||
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
EmptyTuple,
|
||||
ELayout,
|
||||
F16,
|
||||
F16,
|
||||
F32,
|
||||
F16,
|
||||
EmptyTuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
GemmSpec,
|
||||
1,
|
||||
128,
|
||||
128,
|
||||
128,
|
||||
KPerBlock,
|
||||
K1,
|
||||
K1,
|
||||
16,
|
||||
16,
|
||||
8,
|
||||
4,
|
||||
S<1, 4, 16, 1>,
|
||||
ABlockTransferThreadClusterArrageOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim::value,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1::value,
|
||||
ABlockLdsAddExtraM::value,
|
||||
S<1, 4, 16, 1>,
|
||||
BBlockTransferThreadClusterArrageOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim::value,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1::value,
|
||||
BBlockLdsAddExtraM::value,
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 8>,
|
||||
CDEBlockTransferScalarPerVector_NPerBlock>;
|
||||
|
||||
bool IsSupported(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(&argument, kbatch);
|
||||
}
|
||||
|
||||
return ggemm_instance.IsSupportedArgument(argument);
|
||||
}
|
||||
|
||||
float Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideCs,
|
||||
int kbatch = 1) const
|
||||
{
|
||||
std::size_t n_groups = Ms.size();
|
||||
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
|
||||
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
|
||||
<< "The number of groups is not consistent!";
|
||||
|
||||
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
|
||||
|
||||
for(std::size_t i = 0; i < n_groups; ++i)
|
||||
{
|
||||
gemm_descs.push_back(tensor_operation::device::GemmDesc{
|
||||
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
|
||||
}
|
||||
|
||||
std::vector<const void*> p_As(n_groups, nullptr);
|
||||
std::vector<const void*> p_Bs(n_groups, nullptr);
|
||||
std::vector<void*> p_Cs(n_groups, nullptr);
|
||||
auto p_Ds = std::vector<std::array<const void*, 0>>{};
|
||||
|
||||
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
|
||||
auto argument = ggemm_instance.MakeArgument(
|
||||
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
if(kbatch > 1)
|
||||
{
|
||||
ggemm_instance.SetKBatchSize(&argument, kbatch);
|
||||
}
|
||||
if(kbatch > 1 && ck::is_gfx11_supported())
|
||||
{
|
||||
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
|
||||
auto invoker = ggemm_instance.MakeInvoker();
|
||||
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
|
||||
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
|
||||
return invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user