Merge branch 'develop' into tianxing/unified-attention

This commit is contained in:
Tianxing Wu
2025-12-02 10:58:31 +00:00
52 changed files with 3258 additions and 1102 deletions

View File

@@ -42,6 +42,8 @@ option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
if(CK_EXPERIMENTAL_BUILDER)
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
@@ -232,12 +234,12 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}"
# Cache SUPPORTED_GPU_TARGETS for debug
set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL)
message(STATUS "Enabling XDL instances")
add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL)
message(STATUS "Enabling XDL FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
@@ -250,7 +252,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
add_definitions(-DCK_GFX1030_SUPPORT)
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA)
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
@@ -260,7 +262,7 @@ endif()
# define the macro with the current value (0 or 1)
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA)
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
add_definitions(-DCK_USE_WMMA_FP8)
set(CK_USE_WMMA_FP8 "ON")

View File

@@ -37,6 +37,13 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
add_custom_target(example_grouped_gemm_wmma)
add_example_executable(example_grouped_gemm_wmma_splitk_fp16 grouped_gemm_wmma_splitk_fp16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_fp16)
add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp)
add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)

View File

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16;
using BDataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on
#define EXAMPLE_USE_SPLITK
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

View File

@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <tuple>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
using ::ck::DeviceMem;
using ::ck::hip_check_error;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F16;
using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>;
// clang-format on
#define EXAMPLE_USE_SPLITK
#include "run_grouped_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

View File

@@ -19,6 +19,10 @@ struct ProblemSize final
std::vector<ck::index_t> stride_Cs;
ck::index_t group_count;
#if defined(EXAMPLE_USE_SPLITK)
ck::index_t k_batch;
#endif
};
struct ExecutionConfig final
@@ -177,6 +181,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
auto argument = gemm.MakeArgument(
p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);
#if defined(EXAMPLE_USE_SPLITK)
gemm.SetKBatchSize(&argument, problem_size.k_batch);
#endif
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument);
std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument);
@@ -285,12 +293,15 @@ bool run_grouped_gemm_example(int argc, char* argv[])
ExecutionConfig config;
problem_size.group_count = 16;
#if defined(EXAMPLE_USE_SPLITK)
problem_size.k_batch = 1;
#endif
if(argc == 1)
{
// use default cases
}
else if(argc == 4 || argc == 6)
else if(argc == 4 || argc == 6 || argc == 7)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
@@ -300,6 +311,13 @@ bool run_grouped_gemm_example(int argc, char* argv[])
config.async_hargs = std::stoi(argv[4]);
problem_size.group_count = std::stoi(argv[5]);
}
#if defined(EXAMPLE_USE_SPLITK)
if(argc == 7)
{
problem_size.k_batch = std::stoi(argv[6]);
}
#endif
}
else
{
@@ -307,7 +325,10 @@ bool run_grouped_gemm_example(int argc, char* argv[])
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=n0, 1=yes)\n");
printf("arg4: async hargs (0=n0, 1=yes)\n");
printf("arg5: group count (default=16)");
printf("arg5: group count (default=16)\n");
#if defined(EXAMPLE_USE_SPLITK)
printf("arg6: k-batch count (default=1)\n");
#endif
exit(1);
}

View File

@@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[])
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert(
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8")
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")

View File

@@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
HasHotLoop,
TailNum>;
using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,

View File

@@ -74,7 +74,7 @@ User need to select correct mapping of config for each quant mode:
|:--------|:-----:|:-----:|-------|
| For selecting AQuant | aquant | gemm_aquant_quantgrouped.cpp| GemmConfigQuantDecode |
| For selecting AQuant with Preshuffle quant | aquant | gemm_aquant_quantgrouped_preshufflequant.cpp | GemmConfigPreshuffleQuantDecode |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill |
| For selecting BQuant | bquant | gemm_bquant_quantgrouped_<prec_type>.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill |
| For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill |
| For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill
| For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill

View File

@@ -6,6 +6,10 @@
template <typename T>
using GemmConfig = GemmConfigQuantDecode<T>;
// GemmConfigQuantPrefill is also supported for aquant grouped quantization
// template <typename T>
// using GemmConfig = GemmConfigQuantPrefill<T>;
void aquant_quantgrouped_instance_factory(
std::unordered_map<size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
{

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -4,7 +4,7 @@
#include "run_gemm_quant_example.inc"
template <typename T>
using GemmConfig = GemmConfigBQuantPrefill<T>;
using GemmConfig = GemmConfigQuantPrefill<T>;
#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, \

View File

@@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill
};
template <typename PrecType>
struct GemmConfigBQuantPrefill : public GemmConfigBase
struct GemmConfigQuantPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
@@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase
};
template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill<PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
static constexpr bool PreshuffleQuant = true;
};
template <typename PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill<PrecType>
struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;

View File

@@ -35,27 +35,41 @@ cmake
..
```
## Building and testing
## Building and Testing
During development, all CK Builder tests can be built with command
The builder test suite is organized into two main categories:
### Smoke Tests (Fast Unit Tests)
Quick unit tests that verify the builder's internal logic without compiling GPU kernels. These complete in under 1 second total and are suitable for frequent execution during development.
```sh
ninja test_ckb_all
ninja smoke-builder
```
To execute all tests, run
### Regression Tests (Integration Tests)
Integration tests that compile actual GPU kernels to verify that the builder generates valid, compilable code. These are more expensive than smoke tests (can take minutes to compile) but cover more fuctionality.
```sh
ls bin/test_ckb_* | xargs -n1 sh -c
ninja regression-builder
```
Some tests involve building old CK convolution factories, which will take a long time.
Hence, one might want to build only single test targets. For example
### Running All Tests
To build and run the complete test suite:
```sh
ninja check-builder
```
### Building Individual Tests
To build and run a specific test:
```sh
ninja test_ckb_conv_builder && bin/test_ckb_conv_builder
```
When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`.
This allows us to filter out the CK Builder tests from the set full CK repository tests.
Also, the `test_ckb_all` target that builds all CK Builder tests relies on having the `test_ckb` prefix on the CMake build targets.
### Test Organization
- **Smoke tests**: Fast feedback during active development
- **Regression tests**: Thorough validation before submitting changes
- **Factory tests**: Expensive tests that build all MIOpen kernels (included in regression tests)
When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`. This allows filtering of CK Builder tests from the full CK repository test suite.

View File

@@ -98,7 +98,7 @@ struct ConvDescription
f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op);
f.writeLast(2, "Output elementwise operation: ", signature.output_element_op);
f.writeLine(1, "Algorithm");
f.writeLast(1, "Algorithm");
// Compute Block section
f.writeLine(2, "Thread block size: ", algorithm.thread_block_size);
f.writeLine(2,
@@ -123,7 +123,7 @@ struct ConvDescription
algorithm.warp_gemm.n_iter);
// Memory Access section
f.writeLine(2, "Memory access:");
f.writeLast(2, "Memory access:");
f.writeLine(3, "A Tile transfer: ");
f.writeLine(4,
@@ -219,8 +219,6 @@ struct ConvDescription
f.writeLast(4,
"Vector access (GMEM write) instruction size: ",
algorithm.c_tile_transfer.scalar_per_vector);
f.writeLast(2);
f.writeLast(1);
return f.getString();
}

View File

@@ -1,9 +1,48 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
################################################################################
# CK Builder Test Suite
################################################################################
#
# This file defines the test suite for the Composable Kernel (CK) Builder,
# which is responsible for generating optimized GPU kernels for convolution
# operations.
#
# TESTING PHILOSOPHY:
# -------------------
# Tests are organized into two main categories:
#
# 1. SMOKE TESTS (fast, < 1 second total)
# - Unit tests that verify the builder's internal logic
# - Do NOT compile GPU kernels (fast compilation)
# - Run these frequently during development for quick feedback
# - Target: `ninja smoke-builder`
#
# 2. REGRESSION TESTS (slower, may take minutes)
# - Integration tests that compile and verify actual GPU kernels
# - Ensure the builder generates valid, compilable code
# - Include expensive "factory tests" that build all MIOpen kernels
# - Run these before submitting changes
# - Target: `ninja regression-builder`
#
# QUICK START:
# ------------
# - During development: ninja smoke-builder
# - Before submitting: ninja regression-builder
# - Run everything: ninja check-builder
# - Build specific test: ninja test_ckb_conv_builder && bin/test_ckb_conv_builder
#
################################################################################
include(gtest)
################################################################################
# Helper Functions
################################################################################
# Helper function to create a gtest executable with common properties
# All builder tests share the same compilation settings and dependencies
function(add_ck_builder_test test_name)
add_executable(${test_name} ${ARGN} testing_utils.cpp)
target_compile_features(${test_name} PRIVATE cxx_std_20)
@@ -19,17 +58,51 @@ function(add_ck_builder_test test_name)
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
endfunction()
# The test_ckb_conv_builder target has all the unit tests (each test should run < 10 ms)
# Factory tests attempt to build all the kernels needed by MIOpen.
# These are only for regression testing and development; the builds are too
# expensive for regular use in CI.
function(add_ck_factory_test test_name)
add_ck_builder_test(${test_name} ${ARGN})
target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations)
endfunction()
################################################################################
# SMOKE TESTS - Fast Unit Tests (No Kernel Compilation)
################################################################################
# These tests verify the builder's internal logic without compiling GPU kernels.
# They should complete in under 10ms each and are suitable for frequent execution
# during development.
add_ck_builder_test(test_ckb_conv_builder
test_conv_builder.cpp
test_fwd_instance_traits.cpp
test_bwd_weight_instance_traits.cpp
test_bwd_data_instance_traits.cpp
test_instance_traits_util.cpp)
test_instance_traits_util.cpp
)
# Tests the inline diff utility used for comparing strings in tests assertions
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
test_conv_description.cpp)
################################################################################
# REGRESSION TESTS - Integration Tests (With Kernel Compilation)
################################################################################
# These tests compile actual GPU kernels to verify the builder generates valid,
# compilable code. They are more expensive but catch real-world issues.
# Testing the virtual GetInstanceString methods requires kernel compilation.
# Verifies that GetInstanceString() methods produce valid kernel code.
# Tests various convolution types:
# - Group convolution (v3, standard, large tensor, WMMA, DL variants)
# - Backward weight group convolution (XDL)
# Requires kernel compilation to validate the generated strings.
add_ck_builder_test(test_ckb_get_instance_string
test_get_instance_string_fwd_grp_conv_v3.cpp
test_get_instance_string_fwd_grp_conv.cpp
@@ -38,8 +111,8 @@ add_ck_builder_test(test_ckb_get_instance_string
test_get_instance_string_fwd_grp_conv_dl.cpp
test_get_instance_string_bwd_weight_grp_conv_xdl.cpp)
# Testing the fwd convolution builder requires kernel compilation.
# To enable parallel compilation, the individual tests are split into separate files.
# Tests the forward convolution builder across multiple data types and dimensions.
# Individual tests are split into separate files to enable parallel compilation.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
@@ -55,15 +128,21 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_3d_fp32.cpp
)
# Factory tests attempt to build all the kernels need by MIOpen.
# This is only for regression testing and development, the builds are too expensive for regular use in CI.
function(add_ck_factory_test test_name)
add_ck_builder_test(${test_name} ${ARGN})
target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations)
endfunction()
# TODO: add these tests back in once we have CI working across all GPU architectures.
################################################################################
# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set)
################################################################################
# These tests attempt to build ALL kernels needed by MIOpen for various
# convolution operations. They are extremely expensive (minutes to compile)
# and are intended for deep regression testing and development only.
# NOT suitable for regular CI runs.
#
# Many tests are commented out pending CI support across all GPU architectures.
# Tests the testing utilities themselves
add_ck_factory_test(test_ckb_testing_utils test_testing_utils.cpp)
# TODO: Re-enable these tests once we have CI working across all GPU architectures.
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp)
# add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp)
@@ -75,22 +154,30 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_ab tes
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp)
add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp)
add_ck_builder_test(test_ckb_conv_traits
conv/test_conv_traits.cpp)
################################################################################
# CTest Integration - Register Tests and Assign Labels
################################################################################
# Tests are registered with CTest and labeled for selective execution:
# - BUILDER_SMOKE: Fast unit tests for frequent development cycles
# - BUILDER_REGRESSION: Slower integration tests for pre-submission validation
add_ck_builder_test(test_ckb_conv_description
test_conv_description.cpp)
# Register tests with CTest and assign labels
include(CTest)
# Smoke test: fast-compiling unit test
add_test(NAME test_ckb_conv_builder COMMAND test_ckb_conv_builder)
set_tests_properties(test_ckb_conv_builder PROPERTIES LABELS "BUILDER_SMOKE")
# Regression tests: all other tests that require kernel compilation
set(CKB_REGRESSION_TESTS
# Register all smoke tests (fast unit tests, no kernel compilation)
set(CKB_SMOKE_TESTS
test_ckb_conv_builder
test_ckb_inline_diff
test_ckb_conv_traits
test_ckb_conv_description
)
foreach(test_target ${CKB_SMOKE_TESTS})
add_test(NAME ${test_target} COMMAND ${test_target})
set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_SMOKE")
endforeach()
# Register all regression tests (integration tests with kernel compilation)
set(CKB_REGRESSION_TESTS
test_ckb_get_instance_string
test_ckb_build_fwd_instances
test_ckb_testing_utils
@@ -98,8 +185,6 @@ set(CKB_REGRESSION_TESTS
test_ckb_factory_grouped_convolution_forward_scaleadd_ab
test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu
test_ckb_factory_grouped_convolution_forward_dynamic_op
test_ckb_conv_traits
test_ckb_conv_description
)
foreach(test_target ${CKB_REGRESSION_TESTS})
@@ -107,18 +192,31 @@ foreach(test_target ${CKB_REGRESSION_TESTS})
set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_REGRESSION")
endforeach()
# Helper target to build all regression tests
################################################################################
# Custom Build Targets - Convenient Test Execution
################################################################################
# These targets provide convenient ways to build and run different test suites:
# - smoke-builder: Quick sanity check during development
# - regression-builder: Thorough validation before submitting changes
# - check-builder: Complete test suite execution
# Helper target to build all smoke tests (without running them)
add_custom_target(build-smoke-builder DEPENDS ${CKB_SMOKE_TESTS})
# Helper target to build all regression tests (without running them)
add_custom_target(build-regression-builder DEPENDS ${CKB_REGRESSION_TESTS})
# Target to run only smoke tests (builds only test_ckb_conv_builder)
# Target to run only smoke tests (builds and runs all smoke test executables)
# Use this for quick feedback during active development
add_custom_target(smoke-builder
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_SMOKE"
DEPENDS test_ckb_conv_builder
DEPENDS build-smoke-builder
USES_TERMINAL
COMMENT "Running experimental builder smoke tests..."
)
# Target to run only regression tests (builds all regression test executables)
# Target to run only regression tests (builds and runs all regression test executables)
# Use this before submitting changes to catch integration issues
add_custom_target(regression-builder
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_REGRESSION"
DEPENDS build-regression-builder
@@ -126,15 +224,20 @@ add_custom_target(regression-builder
COMMENT "Running experimental builder regression tests..."
)
# Target to run all builder tests (builds all test executables)
# Target to run all builder tests (builds and runs all test executables)
# Use this for comprehensive validation
add_custom_target(check-builder
COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -R "^test_ckb"
DEPENDS test_ckb_conv_builder build-regression-builder
DEPENDS build-smoke-builder build-regression-builder
USES_TERMINAL
COMMENT "Running all experimental builder tests..."
)
# Print summary of test organization
################################################################################
# Build Summary
################################################################################
# Print summary of test organization for developer reference
message(STATUS "CK Builder test organization:")
message(STATUS " Smoke test: test_ckb_conv_builder")
message(STATUS " Smoke tests: ${CKB_SMOKE_TESTS}")
message(STATUS " Regression tests: ${CKB_REGRESSION_TESTS}")

View File

@@ -127,41 +127,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription)
"│ ├─ Input elementwise operation: PASS_THROUGH\n"
"│ ├─ Weights elementwise operation: PASS_THROUGH\n"
"│ └─ Output elementwise operation: PASS_THROUGH\n"
"─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 256×256×32\n"
" ├─ Gemm padding: DEFAULT\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V4\n"
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" └─ Vector access (GMEM write) instruction size: 8\n"
"│ └─ \n"
"└─ "));
"─ Algorithm\n"
" ├─ Thread block size: 256\n"
" ├─ Data tile size: 256×256×32\n"
" ├─ Gemm padding: DEFAULT\n"
" ├─ Convolution specialization: DEFAULT\n"
" ├─ Pipeline version: V4\n"
" ├─ Pipeline scheduler: INTRAWAVE\n"
" ├─ Warp Gemm parameters: \n"
" │ ├─ subtile size: 16×16\n"
" │ └─ Number of warp gemm iterations: 4×4\n"
" ─ Memory access:\n"
" ├─ A Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" ├─ B Tile transfer: \n"
" │ ├─ Tile dimensions: 4×256×8×\n"
" │ ├─ The innermost K subdimension size: 8\n"
" │ ├─ Spatial thread distribution over the data tile: 0×1×2\n"
" │ ├─ The order of accessing data tile axes: 0×1×2\n"
" │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n"
" │ ├─ Vector access (GMEM read) instruction size: 8\n"
" │ ├─ Vector access (LDS write) instruction size: 8\n"
" │ └─ LDS data layout padding (to prevent bank conflicts): 8\n"
" └─ C Tile transfer: \n"
" ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n"
" ├─ Spatial thread distribution used to store data: 1×32×1×8\n"
" └─ Vector access (GMEM write) instruction size: 8"));
}
// NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory

View File

@@ -199,7 +199,7 @@ struct BaseArgument
BaseArgument(const BaseArgument&) = default;
BaseArgument& operator=(const BaseArgument&) = default;
virtual ~BaseArgument() {}
virtual __host__ __device__ ~BaseArgument() {}
void* p_workspace_ = nullptr;
};

View File

@@ -0,0 +1,827 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/ck.hpp"
#include "ck/utility/env.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename GemmDesc,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename Block2CTileMap,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_gemm_wmma_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count)
{
#if(defined(__gfx11__) || defined(__gfx12__))
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
const index_t block_id = get_block_1d_id();
const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
// Binary search lookup to find which group this block is part of
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
block_id < gemm_desc_ptr[group_id].block_end_)) &&
left <= right)
{
if(block_id < gemm_desc_ptr[group_id].block_start_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
// NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index
// and thus needs a non-const reference. It's also not feasible to store this in global
// memory as different threads would be writing different K values to the same arg struct
auto karg = gemm_desc_ptr[group_id].karg_;
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<c_data_type, ck::half_t> ||
std::is_same_v<c_data_type, ck::bhalf_t>)))
{
#endif
const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_;
// Tile index first dimension is the K batch
auto tile_index =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
auto splitk_batch_offset =
typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]);
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
typename GridwiseGemm::EpilogueCShuffle,
1, // Block2CTileMap MBlock index
2 // Block2CTileMap NBlock index
>(static_cast<void*>(p_shared),
splitk_batch_offset,
karg,
block_2_ctile_map,
epilogue_args);
#if defined(__gfx11__)
}
#endif
#else
ignore = gemm_descs_const;
ignore = group_count;
#endif // end of if(defined(__gfx11__) || defined(__gfx12__))
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t NumGemmKPrefetchStage,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t NPerBlock,
ck::index_t KPerBlock,
ck::index_t AK1,
ck::index_t BK1,
ck::index_t MPerWmma,
ck::index_t NPerWmma,
ck::index_t MRepeat,
ck::index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
ck::index_t BBlockTransferSrcVectorDim,
ck::index_t BBlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static_assert(KPerBlock % AK1 == 0);
static constexpr index_t K0PerBlock = KPerBlock / AK1;
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
ALayout,
BLayout,
DsLayout,
ELayout,
Tuple<ADataType>,
Tuple<BDataType>,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<CDEBlockTransferScalarPerVector_NPerBlock>,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
false, // PermuteA not supported by DeviceBatchedGemm base class.
false>; // PermuteB not supported by DeviceBatchedGemm base class.
using CGridDesc_M_N =
remove_cvref_t<decltype(GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
1, 1, 1, 1, 1))>;
using Block2ETileMapKSplit =
BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
// Block2CTileMap configuration parameter.
static constexpr index_t B2E_M01 = 8;
using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
using KernelArgument = typename GridwiseGemm::Argument;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename KernelArgument_>
struct GemmTransKernelArgBase
{
KernelArgument_ karg_;
GroupedGemmBlock2ETileMap block_2_ctile_map_;
index_t block_start_, block_end_;
GemmTransKernelArgBase() = default;
GemmTransKernelArgBase(KernelArgument_&& karg,
GroupedGemmBlock2ETileMap&& b2c_map,
index_t block_start,
index_t block_end)
: karg_{karg},
block_2_ctile_map_{b2c_map},
block_start_{block_start},
block_end_{block_end}
{
}
};
using GemmTransKernelArg = GemmTransKernelArgBase<KernelArgument>;
static constexpr index_t DefaultKBatch = 1;
static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg)
{
index_t k_grain = karg.KBatch * KPerBlock;
index_t K_split = (karg.K + k_grain - 1) / karg.KBatch;
return GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
}
// Argument
// TODO: Add A/B/CDE element op?
struct Argument : public BaseArgument
{
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs)
: Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
{
// TODO: use occupancy api to calculate appropriate batch size.
}
Argument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
index_t kbatch)
: K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr}
{
grid_size_ = 0;
group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());
if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
{
throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
}
gemm_kernel_args_.reserve(group_count_);
skipped_group_count_ = 0;
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
{
const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M == 0)
{
skipped_group_count_++;
continue;
}
const index_t stride_a = gemm_descs[i].stride_A_;
const index_t stride_b = gemm_descs[i].stride_B_;
const index_t stride_c = gemm_descs[i].stride_C_;
const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
const auto c_grid_desc_m_n =
GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
M, m_padded, N, n_padded, stride_c);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
auto karg = KernelArgument(std::array<const void*, 1>{p_As[i]},
std::array<const void*, 1>{p_Bs[i]},
std::array<const void*, 0>{}, // p_ds_grid_
type_convert<EDataType*>(p_Es[i]),
M,
N,
K,
std::array<index_t, 1>{stride_a},
std::array<index_t, 1>{stride_b},
std::array<index_t, 0>{}, // StrideDs_
stride_c,
K_BATCH,
PassThrough{},
PassThrough{},
PassThrough{},
false);
gemm_kernel_args_.emplace_back(
std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
}
}
/**
* @brief Recalculate group grid size for all gemms and update B2C maps.
*
* @param[in] kbatch The new splitK parameter value.
*/
void UpdateKBatch(index_t kbatch)
{
K_BATCH = kbatch;
grid_size_ = 0;
for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
{
auto& karg = gemm_kernel_args_[i].karg_;
const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, K_BATCH);
const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, K_BATCH);
const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, K_BATCH);
const auto c_grid_desc_m_n =
GridwiseGemm::template MakeDEGridDescriptor_M_N<ELayout>(
karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE);
const auto local_b2c_tile_map =
Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);
const index_t block_start = grid_size_;
const index_t block_end = grid_size_ + grid_size_grp;
grid_size_ += grid_size_grp;
// block-to-e-tile map
auto grouped_block_2_ctile_map =
GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);
karg.KRead = k_read;
karg.KPadded = k_padded;
karg.AK0 = ak0_padded;
karg.BK0 = bk0_padded;
karg.KBatch = K_BATCH;
gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
gemm_kernel_args_[i].block_start_ = block_start;
gemm_kernel_args_[i].block_end_ = block_end;
}
}
// private:
index_t K_BATCH;
index_t group_count_;
index_t skipped_group_count_;
std::vector<GemmTransKernelArg> gemm_kernel_args_;
void* gemm_kernel_host_args_;
index_t grid_size_;
};
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg,
const StreamConfig& stream_config = StreamConfig{},
hipStream_t cpy_stream = nullptr,
hipEvent_t cpy_event = nullptr)
{
using GemmTransKernelArg_ = GemmTransKernelArgBase<typename GridwiseGemm::Argument>;
static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg));
bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.KBatch > 1;
bool all_have_main_k0_block_loop =
CalculateHasMainKBlockLoop(arg.gemm_kernel_args_[0].karg_);
bool not_all_have_main_k0_block_loop_same = false;
bool not_all_have_kbatch_value_same = false;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& karg = reinterpret_cast<const typename GridwiseGemm::Argument&>(
arg.gemm_kernel_args_[i].karg_);
if(stream_config.log_level_ > 0)
{
karg.Print();
}
auto kbatch = karg.KBatch;
if(!GridwiseGemm::CheckValidity(karg))
{
std::ostringstream err;
err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
not_all_have_main_k0_block_loop_same |=
all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg);
not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1);
}
if(not_all_have_main_k0_block_loop_same)
{
std::ostringstream err;
err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
// throw std::runtime_error(err.str());
}
if(not_all_have_kbatch_value_same)
{
std::ostringstream err;
err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
// If the user provides copy stream and copy event, we assume that they're also
// responsible for providing allocated host memory (eg. pinned) which
// would be used to copy kernel arguments to the device.
if(cpy_stream && cpy_event)
{
if(arg.gemm_kernel_host_args_ == nullptr)
{
std::ostringstream err;
err << "No memory has been allocated for gemm kernel host args "
<< "when providing the copy stream and copy event! In " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
hip_check_error(hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_host_args_,
arg.group_count_ * sizeof(GemmTransKernelArg_),
hipMemcpyHostToDevice,
cpy_stream));
hip_check_error(hipEventRecord(cpy_event, cpy_stream));
hip_check_error(hipEventSynchronize(cpy_event));
}
else // In this case CK owns memory allocated on host.
{
hip_check_error(
hipMemcpyAsync(arg.p_workspace_,
arg.gemm_kernel_args_.data(),
arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg_),
hipMemcpyHostToDevice,
stream_config.stream_id_));
}
float ave_time = 0;
const auto Run = [&](const auto& kernel) {
if(all_have_kbatch_gt_one)
{
for(const auto& trans_arg : arg.gemm_kernel_args_)
{
const auto& karg = trans_arg.karg_;
hip_check_error(hipMemsetAsync(karg.p_e_grid,
0,
karg.M * karg.N * sizeof(EDataType),
stream_config.stream_id_));
}
}
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
cast_pointer_to_constant_address_space(arg.p_workspace_),
arg.gemm_kernel_args_.size());
};
// NOTE: If at least one gemm problem has a main k0 block loop, we include it for all
if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
GemmTransKernelArg_,
true,
InMemoryDataOperationEnum::AtomicAdd,
GroupedGemmBlock2ETileMap>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
GemmTransKernelArg_,
true,
InMemoryDataOperationEnum::Set,
GroupedGemmBlock2ETileMap>;
Run(kernel);
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(all_have_kbatch_gt_one)
{
const auto kernel =
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
GemmTransKernelArg_,
false,
InMemoryDataOperationEnum::AtomicAdd,
GroupedGemmBlock2ETileMap>;
Run(kernel);
}
else
{
const auto kernel =
kernel_grouped_gemm_wmma_splitk<GridwiseGemm,
GemmTransKernelArg_,
false,
InMemoryDataOperationEnum::Set,
GroupedGemmBlock2ETileMap>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>)
{
if(arg.K_BATCH > 1 && ck::is_gfx11_supported())
{
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
arg.skipped_group_count_) != arg.group_count_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "The group count is not equal to sum of skipped groups "
"and kernel args size!"
<< std::endl;
}
return false;
}
bool supported = true;
for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
{
const auto& a = arg.gemm_kernel_args_[i].karg_;
bool group_arg_valid = GridwiseGemm::CheckValidity(a);
if(not group_arg_valid)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[" << __func__ << "] group id: " << i
<< " has invalid GridwiseGemm settings!" << std::endl;
a.Print();
}
}
supported = supported && group_arg_valid;
}
return supported;
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_Es,
std::vector<GemmDesc> gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation)
{
return Argument{p_As, p_Bs, p_Es, gemm_descs};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
std::vector<const void*>& p_Bs,
std::vector<std::array<const void*, NumDTensor>>&,
std::vector<void*>& p_Es,
std::vector<GemmDesc>& gemm_descs,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation) override
{
return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGroupedGemm_WmmaSplitK"
<< "<"
<< std::string(ALayout::name)[0] << ","
<< std::string(BLayout::name)[0] << ","
<< std::string(ELayout::name)[0] << ","
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerWmma << ", "
<< NPerWmma << ", "
<< MRepeat << ", "
<< NRepeat << ", "
<< ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", "
<< CShuffleMRepeatPerShuffle << ", "
<< CShuffleNRepeatPerShuffle << ", "
<< getGemmSpecializationString(GemmSpec) << ", "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
<< ">";
// clang-format on
return str.str();
}
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
auto p_arg_ = dynamic_cast<const Argument*>(p_arg);
if(p_arg_)
{
return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!");
}
size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override
{
return GetWorkSpaceSize(p_arg);
}
size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); }
// TODO: deperecation notice.
static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); }
// polymorphic
void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override
{
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
if(p_arg_)
{
p_arg_->UpdateKBatch(kbatch);
}
else
throw std::runtime_error("The argument pointer is not an object of "
"DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!");
}
void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override
{
return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args);
}
//----------------------------------------------------------------------------------------------
/// @brief Sets the host kernel arguments pointer and copies that data on the host side.
/// This function can be utilised to use pinned memory for the host args and
/// achieve fully async data copy.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_host_kernel_args The pointer to the host memory where the kernel
/// arguments will be copied
///
void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const
{
Argument* pArg_ = dynamic_cast<Argument*>(p_arg);
if(!pArg_)
{
throw std::runtime_error("Failed to cast argument pointer!");
}
pArg_->gemm_kernel_host_args_ = p_host_kernel_args;
std::copy(pArg_->gemm_kernel_args_.begin(),
pArg_->gemm_kernel_args_.end(),
static_cast<GemmTransKernelArg*>(pArg_->gemm_kernel_host_args_));
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CDEElementwiseOperation cde_element_op;
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CDEElementwiseOperation cde_element_op;
// TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd
bool is_reduce;
@@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem,
const Block2CTileMap& block_2_ctile_map,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
@@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
const index_t block_m_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapMBlockIndex>{}]);
const index_t block_n_id =
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
@@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3
epilogue_args);
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(AsGridPointer& p_as_grid,
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
{
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument>(p_as_grid,
p_bs_grid,
p_ds_grid,
p_e_grid,
p_shared,
problem,
DefaultBlock2CTileMap(problem),
a_element_op,
b_element_op,
cde_element_op,
epilogue_args);
}
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
typename Block2CTileMap,
typename EpilogueArgument,
int BlockMapMBlockIndex = 0,
int BlockMapNBlockIndex = 1>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
const Block2CTileMap& block_2_ctile_map,
EpilogueArgument& epilogue_args)
{
// shift A matrices pointer for splitk
@@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3
splitk_batch_offset.b_k_split_offset[i];
});
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument,
BlockMapMBlockIndex,
BlockMapNBlockIndex>(p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg,
block_2_ctile_map,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
}
// Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
TailNumber TailNum,
typename EpilogueArgument>
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
{
Run<HasMainKBlockLoop,
EGlobalMemoryDataOperation,
TailNum,
Block2CTileMap,
EpilogueArgument>(
p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args);
}
__device__ static auto DefaultBlock2CTileMap(const Problem& problem)
{
return Block2CTileMap{problem.M, problem.N, 4};
}
};

View File

@@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. AK1: "
<< AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch
<< ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__ << std::endl;
}
return false;
}
}

View File

@@ -293,6 +293,15 @@ struct tile_window_with_static_distribution
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <typename offset_t>
CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const
{
constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{});
const auto bottom_tensor_coord_off = make_tensor_coordinate(
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset());
}
template <typename DataType,
typename StaticTileDistribution,
index_t i_access_unsupport_ = -1,
@@ -316,12 +325,7 @@ struct tile_window_with_static_distribution
else if constexpr(is_constant_v<offset_t>)
return offset_t::value;
else
{
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
auto bottom_tensor_coord_off = make_tensor_coordinate(
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
return bottom_tensor_coord_off.get_offset();
}
return get_load_offset(offset_t{});
}();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {

View File

@@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
};
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
template <typename Problem, typename PipelinePolicy = MXFlatmmPipelineAgBgCrPolicy>
struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>
{
using Underlying = FlatmmPipelineAGmemBGmemCRegV1<Problem, PipelinePolicy>;
@@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
return PipelinePolicy::template MakeADramTileDistribution<Problem>();
}
template <typename... Args>
CK_TILE_DEVICE auto operator()(Args&&... args) const
{
auto c_warp_tensors = Run_(std::forward<Args>(args)...);
// Block GEMM Acc register tile
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensors(mIter)(nIter).get_thread_buffer());
});
});
return c_block_tile;
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_ping,
void* __restrict__ p_smem_pong) const
CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_ping,
void* __restrict__ p_smem_pong) const
{
#ifndef __gfx950__
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
@@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
static_assert(NWarp == 4);
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto a_dram_window =
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor<Problem>(
a_copy_dram_window_tmp.get_bottom_tensor_view()),
a_copy_dram_window_tmp.get_window_lengths(),
a_copy_dram_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
PipelinePolicy::template MakeMX_ADramTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
@@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor<Problem>();
PipelinePolicy::template MakeMX_ALdsBlockDescriptor<Problem>();
auto a_lds_block_ping =
make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
@@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{0, 0},
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
auto a_warp_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{0, 0},
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
// Block GEMM
auto block_flatmm = BlockFlatmm();
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
PipelinePolicy::template MakeMX_ALDS_TileDistribution<Problem>());
// B flat DRAM window for load
// pingpong buffer for B
auto b_flat_dram_windows = generate_tuple(
auto b_flat_dram_window =
make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMX_BFlatDramTileDistribution<Problem>());
auto b_flat_dram_offsets = generate_tuple(
[&](auto nIter) {
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
auto window_i = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
move_tile_window(
window_i,
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
number<0>{}});
return window_i;
return b_flat_dram_window.get_load_offset(
tuple<number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter>,
number<0>>{}) +
b_flat_dram_window.get_load_offset(
tuple<number<packed_n_rank>, number<0>>{});
},
number<NIterPerWarp>{});
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping, b_warp_tensor_pong;
@@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
scale_a_window.get_bottom_tensor_view(),
make_tuple(number<MWarp * WG::kM>{}, number<64 / WG::kM>{}),
scale_a_window.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution<Problem>());
PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution<Problem>());
const auto scale_a_dram_step_m = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<MWarp * WG::kM>, number<0>>{}));
const auto scale_a_dram_step_k = amd_wave_read_first_lane(
scale_a_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kM>>{}));
auto scale_b_dram_window = make_tile_window(
scale_b_window.get_bottom_tensor_view(),
make_tuple(number<NWarp * WG::kN>{}, number<64 / WG::kN>{}),
scale_b_window.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution<Problem>());
PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution<Problem>());
const auto scale_b_dram_step_n = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<NWarp * WG::kN>, number<0>>{}));
const auto scale_b_dram_step_k = amd_wave_read_first_lane(
scale_b_dram_window.get_load_offset(tuple<number<0>, number<64 / WG::kN>>{}));
constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack;
constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack;
constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack;
// ping pong buffer for scale A
statically_indexed_array<
statically_indexed_array<decltype(scale_a_dram_window), KIterPerWarp / KXdlPack>,
MIterPerWarp / MXdlPack>
scale_a_dram_windows;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
KIterPerWarp / KXdlPack>,
MIterPerWarp / MXdlPack>
scale_a_tile_tensor_ping;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_a_dram_window)),
KIterPerWarp / KXdlPack>,
MIterPerWarp / MXdlPack>
scale_a_tile_tensor_pong;
statically_indexed_array<decltype(load_tile(scale_a_dram_window)), KPackIterPerWarp>,
MPackIterPerWarp>
scale_a_tile_tensor_ping, scale_a_tile_tensor_pong;
// ping pong buffer for scale B
statically_indexed_array<
statically_indexed_array<decltype(scale_b_dram_window), KIterPerWarp / KXdlPack>,
NIterPerWarp / NXdlPack>
scale_b_dram_windows;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
KIterPerWarp / KXdlPack>,
NIterPerWarp / NXdlPack>
scale_b_tile_tensor_ping;
statically_indexed_array<statically_indexed_array<decltype(load_tile(scale_b_dram_window)),
KIterPerWarp / KXdlPack>,
NIterPerWarp / NXdlPack>
scale_b_tile_tensor_pong;
statically_indexed_array<decltype(load_tile(scale_b_dram_window)), KPackIterPerWarp>,
NPackIterPerWarp>
scale_b_tile_tensor_ping, scale_b_tile_tensor_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
@@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
});
// move B window to next flat K
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
});
// prefetch Scale A
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
scale_a_dram_window,
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
// move Scale A window to next K
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// prefetch Scale B
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
scale_b_dram_window,
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// move Scale B window to next K
@@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
move_tile_window(a_dram_window, {0, kKPerBlock});
}
// initialize C
clear_tile(c_block_tile);
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp>
c_warp_tensors;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}(
[&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); });
});
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
@@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
move_tile_window(b_flat_dram_windows(nIter),
{0, BlockGemmShape::flatKPerBlock});
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
});
});
// prefetch Scale A and Scale B (2i+1)
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
scale_a_dram_window,
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
scale_b_dram_window,
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// GEMM 2i
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
@@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<m_iter, n_iter>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]);
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<m_iter, n_iter>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
(nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -802,81 +791,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
// move B window to next flat K
if constexpr(kIter == KIterPerWarp - 1)
move_tile_window(b_flat_dram_windows(nIter),
{0, BlockGemmShape::flatKPerBlock});
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
tuple<number<0>, number<KIterPerWarp * KFlatPerBlockPerIter>>{});
});
});
// prefetch Scale A and Scale B (2i+2)
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset(
scale_a_dram_window,
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset(
scale_b_dram_window,
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// GEMM 2i+1
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
(nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter),
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
b_flat_dram_window,
b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter);
});
});
// prefetch Scale A and Scale B (2i+1)
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window;
move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack),
{mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)});
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) =
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset(
scale_a_dram_window,
mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k);
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack),
{nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)});
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) =
load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack));
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset(
scale_b_dram_window,
nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k);
});
});
// GEMM loopK-1
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
(nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
Last2ndHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
b_warp_tensor_pong(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
(nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) {
static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) {
static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) {
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
c_warp_tensors(number<m_iter>{})(number<n_iter>{}),
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
b_warp_tensor_ping(number<n_iter>{})(number<k_iter>{}),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
.get_thread_buffer()[0]); // scale B
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
(nIter_pack == NPackIterPerWarp - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
@@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{
static_assert(false, "Wrong TailNum");
}
return c_block_tile;
return c_warp_tensors;
}
};

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
@@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
@@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
@@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
@@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
@@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
return sizeof(ADataType) *
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor<Problem>().get_element_space_size() /
APackedSize;
}
template <typename Problem>

View File

@@ -23,7 +23,8 @@ struct BaseGemmPipelineAgBgCrCompV4
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
constexpr index_t HotLoopGlobalReads = 2;
return num_loop >= (HotLoopGlobalReads + PrefetchStages);
}
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)

View File

@@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
{
// Need to multiply aquant with accumulated C
//
// The accumulated C tile has the standard distribution. For example
// lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// The accumulated C tile has the standard distribution. For example, a
// 32x32 C lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
// [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
// [26,0], [27,0].
//
@@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
//
// These scales can be obtained using __builtin_amdgcn_ds_bpermute.
// MIters per warp
constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
// Reg block offset based on mIter
constexpr index_t reg_block_offset =
((mIter / mIters_per_warp) * Traits::AQPerBlock);
constexpr index_t lane_base_offset =
(mIter % mIters_per_warp) * WarpGemm::kM;
// Scale tensor offset along K
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
// Directly index into thread buffer corresponding to
// desired row coefficient
// Each thread stores AQPerBlock scale values per M iteration.
constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock;
constexpr index_t src_reg_offset = reg_block_offset + kQScale;
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
;
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
// Multiply by 4 because output is stored in tiles of 4
// x CNLane
constexpr uint32_t row_base =
((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane);
// Divide M dimension of C Warp tile into groups of
// (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane)
// m_base_offset_of_c_row indicates which group the current c_row belongs
// to.
constexpr index_t m_base_offset_of_c_row =
(c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) *
(WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
// M offset of each thread within its group (see comment above)
index_t m_base_offset_of_lane =
(get_lane_id() / WarpGemm::kN *
WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane);
// M offset wrt. c_row in the subgroup of kCM1PerLane
constexpr index_t m_offset_of_c_row =
c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1);
// Lane index to source scale from
uint32_t src_lane_idx =
lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows);
m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row;
return exchange_quant_value_across_lanes(scale_reg, src_lane_idx);
}

View File

@@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding
// # of elements per thread
constexpr index_t X = XPerTile;
constexpr index_t Y0 = 1;
constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y2 = MWarps;
constexpr index_t Y3 = WarpGemm::kM;
static_assert(Y3 >= WarpGemm::kM,
constexpr index_t YR = 1;
constexpr index_t Y0 = MIterPerWarp ? MIterPerWarp : 1;
constexpr index_t Y1 = MWarps;
constexpr index_t Y2 = WarpGemm::kM;
static_assert(Y2 >= WarpGemm::kM,
"Scales for all rows must be available within the warp.");
static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile,
"Y0, Y1, Y2, Y3 must cover the blocktile along Y.");
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y.");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<NWarps>,
tuple<sequence<Y0, Y1, Y2, Y3>, sequence<X>>,
tuple<sequence<1, 0>, sequence<1, 1>>,
tuple<sequence<2, 0>, sequence<0, 3>>,
tile_distribution_encoding<sequence<NWarps, YR>,
tuple<sequence<Y0, Y1, Y2>, sequence<X>>,
tuple<sequence<1, 0>, sequence<0, 1>>,
tuple<sequence<1, 0>, sequence<1, 2>>,
sequence<1, 2>,
sequence<1, 0>>{});
sequence<0, 0>>{});
}
}
};

View File

@@ -15,6 +15,142 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__)
void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F16,
F8,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
F8,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if defined(CK_ENABLE_BF16)
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
Empty_Tuple,
Row,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif // CK_ENABLE_BF16
#endif // CK_USE_WMMA
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16)
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
@@ -409,6 +545,81 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if defined(CK_USE_WMMA)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif // CK_ENABLE_FP16
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(op_ptrs);
}
}
#endif
#if defined(CK_ENABLE_BF16)
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<EDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
}
}
#endif // CK_ENABLE_BF16
#endif // CK_USE_WMMA
#if defined(CK_USE_XDL)
#if defined(CK_ENABLE_FP16)
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&

View File

@@ -0,0 +1,205 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/utility/loop_scheduler.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F8 = ck::f8_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Empty_Tuple = ck::Tuple<>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AccDataType = F32;
using DsDataType = Empty_Tuple;
using DsLayout = Empty_Tuple;
using ELayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;
static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3;
static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave;
static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = device::GemmSpecialization::Default;
// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType
template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_km_kn_mn_instances =
std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang`-format on
>;
// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType
template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType
template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType
template <typename T,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
enable_if_t<sizeof(T) == 2, bool> = false>
using device_grouped_gemm_wmma_universal_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported
// padding/scheduler/pipeline version combinations
template <typename ALayout,
typename BLayout,
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
typename LayoutInstances,
typename ADataType, // NOTE: type parameters as last so that they can be inferred from the
typename BDataType, // vector argument
typename EDataType>
void add_device_grouped_gemm_wmma_universal_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_operation_instances(instances,
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(instances,
LayoutInstances<GemmDefault, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(instances,
LayoutInstances<GemmDefault, IntrawaveScheduler, PipelineV3>{});
add_device_operation_instances(
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
}
// Helper function to add a list of layout instances for instances with matching A/B/E data types
// for all supported padding/scheduler/pipeline version combinations
template <typename T,
typename ALayout,
typename BLayout,
template <typename T2,
device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
typename LayoutInstances>
void add_device_grouped_gemm_wmma_universal_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
T,
T,
DsDataType,
T,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_operation_instances(
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmDefault, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmDefault, IntrawaveScheduler, PipelineV3>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmMNKPadding, InterwaveScheduler, PipelineV1>{});
add_device_operation_instances(
instances, LayoutInstances<T, GemmMNKPadding, IntrawaveScheduler, PipelineV3>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -57,7 +57,7 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
# Do not build XDL instances if gfx9 targets are not on the target list
if(NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl")
if(((NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_XDL) AND source_name MATCHES "_xdl")
message(DEBUG "removing xdl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -67,7 +67,7 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
# Do not build WMMA instances if gfx11 targets are not on the target list
if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma")
if(((NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_WMMA) AND source_name MATCHES "_wmma")
message(DEBUG "removing wmma instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -88,7 +88,7 @@ function(add_instance_library INSTANCE_NAME)
endif()
endif()
# Do not build WMMA gemm_universal_f8 for any targets except gfx12+
if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_")
if((NOT INST_TARGETS MATCHES "gfx12" OR FORCE_DISABLE_WMMA) AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_")
message(DEBUG "removing gemm_universal_f8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -274,7 +274,7 @@ FOREACH(subdir_path ${dir_list})
message(DEBUG "Found only dl instances, but DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12" OR FORCE_DISABLE_XDL))
message(DEBUG "Found only xdl instances, but gfx9 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
@@ -282,7 +282,7 @@ FOREACH(subdir_path ${dir_list})
message(DEBUG "Found only MX instances, but gfx950 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12"))
if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (((NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) OR FORCE_DISABLE_WMMA))
message(DEBUG "Found only wmma instances, but gfx11 is not on the targets list. Skipping.")
set(add_inst 0)
endif()
@@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list})
message(DEBUG "Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12"))
if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND ((NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12") OR (FORCE_DISABLE_XDL AND FORCE_DISABLE_WMMA)))
message(DEBUG "Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. Skipping.")
set(add_inst 0)
endif()
@@ -333,20 +333,22 @@ FOREACH(subdir_path ${dir_list})
if((add_inst EQUAL 1))
get_filename_component(target_dir ${subdir_path} NAME)
add_subdirectory(${target_dir})
if("${cmake_instance}" MATCHES "gemm")
list(APPEND CK_DEVICE_GEMM_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "conv")
list(APPEND CK_DEVICE_CONV_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "mha")
list(APPEND CK_DEVICE_MHA_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "contr")
list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "reduce")
list(APPEND CK_DEVICE_REDUCTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
else()
list(APPEND CK_DEVICE_OTHER_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
message(DEBUG "add_instance_directory ${subdir_path}")
if (TARGET device_${target_dir}_instance)
if("${cmake_instance}" MATCHES "gemm")
list(APPEND CK_DEVICE_GEMM_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "conv")
list(APPEND CK_DEVICE_CONV_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "mha")
list(APPEND CK_DEVICE_MHA_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "contr")
list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
elseif("${cmake_instance}" MATCHES "reduce")
list(APPEND CK_DEVICE_REDUCTION_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
else()
list(APPEND CK_DEVICE_OTHER_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
message(DEBUG "add_instance_directory ${subdir_path}")
endif()
else()
message(DEBUG "skip_instance_directory ${subdir_path}")
endif()

View File

@@ -1,7 +1,7 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -36,4 +36,17 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp
)

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
DsLayout,
ELayout,
BF16,
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,
Col,
Row,
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
DsLayout,
ELayout,
BF16,
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,
Col,
Col,
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,
Row,
Row,
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
DsLayout,
Row,
BF16,
BF16,
DsDataType,
BF16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
BF16,
Row,
Col,
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Col,
Row,
device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,37 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Col,
Col,
DsLayout,
ELayout,
F16,
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Col,
Col,
device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
DsLayout,
Row,
F16,
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Row,
Row,
device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,38 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Col,
DsLayout,
ELayout,
F16,
F16,
DsDataType,
F16,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
F16,
Row,
Col,
device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,57 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using ADataType = F16;
using BDataType = F8;
using EDataType = F16;
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
DsLayout,
Row,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
Row,
Row,
device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,57 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstdlib>
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using ADataType = F8;
using BDataType = F16;
using EDataType = F16;
template <device::GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer>
using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>,
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>
// clang-format on
>;
void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedGemm<Row,
Row,
DsLayout,
Row,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp>>>& instances)
{
add_device_grouped_gemm_wmma_universal_instances<
Row,
Row,
device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -42,10 +42,11 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches = {},
int n_warmup = 1,
int n_iter = 10,
int instance_index = -1)
const std::vector<int>& kbatches = {},
int n_warmup = 1,
int n_iter = 10,
int instance_index = -1,
bool fail_if_no_supported_instance = false)
{
bool pass = true;
// TODO: Fixme - we do not pass compute data type here but need it
@@ -225,6 +226,7 @@ bool profile_grouped_gemm_impl(int do_verification,
}
}
// profile device GEMM instances
int instances_supporting_all_batch_sizes = 0;
for(auto& gemm_ptr : op_ptrs)
{
auto argument_ptr =
@@ -268,6 +270,7 @@ bool profile_grouped_gemm_impl(int do_verification,
kbatch_list = kbatches;
}
bool all_batch_sizes_supported = true;
for(std::size_t j = 0; j < kbatch_list.size(); j++)
{
auto kbatch_curr = kbatch_list[j];
@@ -367,10 +370,30 @@ bool profile_grouped_gemm_impl(int do_verification,
}
else
{
all_batch_sizes_supported = false;
std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem"
<< std::endl;
}
}
// If all batch sizes were supported by this instance, the instance can be marked as
// 'supported' for this problem
if(all_batch_sizes_supported)
{
++instances_supporting_all_batch_sizes;
}
}
// Warn if not a single instance was supported
if(instances_supporting_all_batch_sizes == 0)
{
std::cout << "Warning! No instance found that supported all of the batch sizes."
<< std::endl;
if(fail_if_no_supported_instance)
{
return false;
}
}
if(time_kernel)
@@ -384,6 +407,7 @@ bool profile_grouped_gemm_impl(int do_verification,
std::cout << "grouped_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -53,6 +53,13 @@ struct GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<false>();
};
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
};
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr bool PreshuffleQuant = true;

View File

@@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
// PreshuffleQuant = false && TransposeC = false && Prefill
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigPrefill, GroupSize>,
// PreshuffleQuant = false && TransposeC = true
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigTransposeC, GroupSize>,

View File

@@ -3,10 +3,15 @@
add_custom_target(test_grouped_gemm)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and dont have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there's no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif()
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)

View File

@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "test_grouped_gemm_util.hpp"
#include "test_grouped_gemm_interface_xdl.hpp"
class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
{

View File

@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
namespace ck {
namespace test {
template <typename ALayout,
typename BLayout,
typename ELayout,
tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector,
index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
using F16 = half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using EmptyTuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t N>
using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
using DeviceGroupedGemmSplitKInstance =
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
ALayout,
BLayout,
EmptyTuple,
ELayout,
F16,
F16,
F32,
F16,
EmptyTuple,
F16,
PassThrough,
PassThrough,
PassThrough,
GemmSpec,
1,
128,
128,
128,
KPerBlock,
K1,
K1,
16,
16,
8,
4,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1::value,
BBlockLdsAddExtraM::value,
1,
1,
S<1, 16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
}
float Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
if(kbatch > 1 && ck::is_gfx11_supported())
{
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
return 0;
}
else
{
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
}
};
} // namespace test
} // namespace ck

View File

@@ -24,21 +24,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
public:
void SetUp() override
{
ck::test::TestGroupedGemm<Tuple>::SetUp();
#if defined(CK_USE_WMMA)
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
// behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a
// specific case is not supported
this->fail_if_no_supported_instances_ =
ck::is_gfx11_supported() || ck::is_gfx12_supported();
#endif
}
};
// clang-format off
using KernelTypes = ::testing::Types<
#if defined(CK_USE_WMMA)
// WWMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test.
std::tuple< Col, Col, Row, BF16, BF16, BF16>,
#endif
#if defined(CK_USE_XDL) && defined(__gfx9__)
// XDL only at the moment, instances for WMMA not defined
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
#endif
#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__))
std::tuple< Row, Row, Row, F8, F16, F16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
#endif
std::tuple< Row, Row, Row, F16, F16, F16>,
std::tuple< Row, Col, Row, F16, F16, F16>,
std::tuple< Col, Row, Row, F16, F16, F16>,
std::tuple< Col, Col, Row, F16, F16, F16>,
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
std::tuple< Col, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
std::tuple< Col, Row, Row, BF16, BF16, BF16>
>;
// clang-format on

View File

@@ -65,6 +65,13 @@ TYPED_TEST(TestGroupedGemm, MNKPadded)
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still run the tests for fp32, but we currently don't have instances for
// it so we disable it entirely
if(ck::is_gfx11_supported())
GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add "
"instructions";
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;

View File

@@ -11,16 +11,7 @@
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
extern ck::index_t param_mask;
@@ -41,7 +32,7 @@ std::string serialize_range(const Range& range)
return std::string(str.begin(), str.end() - 2);
}
template <typename Tuple>
template <typename Tuple, bool FailIfNoSupportedInstances = false>
class TestGroupedGemm : public testing::Test
{
protected:
@@ -62,9 +53,26 @@ class TestGroupedGemm : public testing::Test
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
void SetUp() override
{
constexpr bool require_16bit_atomic_add =
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
if(require_16bit_atomic_add && ck::is_gfx11_supported())
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
k_batches_ = {1};
}
else
{
k_batches_ = {1, 2, 3, 5, 8};
}
}
private:
template <typename Layout>
@@ -132,204 +140,31 @@ class TestGroupedGemm : public testing::Test
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index);
bool pass =
ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
EXPECT_TRUE(pass);
}
};
template <typename ALayout,
typename BLayout,
typename ELayout,
tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector,
index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
using F16 = half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using EmptyTuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t N>
using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
using DeviceGroupedGemmSplitKInstance =
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
ALayout,
BLayout,
EmptyTuple,
ELayout,
F16,
F16,
F32,
F16,
EmptyTuple,
F16,
PassThrough,
PassThrough,
PassThrough,
GemmSpec,
1,
128,
128,
128,
KPerBlock,
K1,
K1,
16,
16,
8,
4,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1::value,
BBlockLdsAddExtraM::value,
1,
1,
S<1, 16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
}
float Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
if(kbatch > 1 && ck::is_gfx11_supported())
{
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
return 0;
}
else
{
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
}
};
} // namespace test
} // namespace ck