Files
composable_kernel/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp
Brock Hargreaves 2a16d53cce [rocm-libraries] ROCm/rocm-libraries#5045 (commit 64a5502)
[CK] Address a bunch of errors associated with targeting
 gfx1200 on Windows (#5045)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Still addressing errors that are blocking the merge of TheRock PR:
https://github.com/ROCm/TheRock/actions/runs/22545831304/job/65308264096?pr=3382

## Technical Details

1. There are multiple fmha python scripts that are writing native paths
which are confusing cmake. I addressed one of these in an earlier PR
https://github.com/ROCm/rocm-libraries/pull/4812 and now I'm addressing
more that are exposed with gfx1200 target:

```
[composable_kernel configure] CMake Error at example/ck_tile/50_sparse_attn/CMakeLists.txt:61 (add_library):
[composable_kernel configure]   Syntax error in cmake code when parsing string
[composable_kernel configure]
[composable_kernel configure]     B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp
[composable_kernel configure]
[composable_kernel configure]   Invalid character escape '\b'.
```

2. In the following compiler error we see gemm_prec_str<ADataType,
BDataType> being passed to concat(...) as a function, instead of being
called with the call operator (), i.e.,
gemm_prec_str<ADataType, BDataType>(). There are multiple instances of
this; I wonder what non-MSVC compilers do here:

```
[composable_kernel] FAILED: [code=1] example/ck_tile/38_block_scale_gemm/CMakeFiles/tile_example_gemm_quant.dir/gemm_bquant_quantgrouped_mx_bf16bf8.cpp.obj
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp:4:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/example/ck_tile/38_block_scale_gemm\run_gemm_quant_example.inc:17:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host.hpp:7:
[composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host/concat.hpp:119:21: error: implicit conversion between pointer-to-function and pointer-to-object is a Microsoft extension [-Werror,-Wmicrosoft-cast]
[composable_kernel]   119 |     ((oss << sep << rest), ...);
[composable_kernel]       |                     ^~~~
[composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp:248:16: note: in instantiation of function template specialization 'ck_tile::concat<char, char[11], std::basic_string<char> (), std::basic_string<char>>' requested here
[composable_kernel]   248 |         return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
[composable_kernel]       |                ^
```

There are plenty of other places where we use gemm_prec_str with the
operator(), so I'm pretty sure these were just typos...but I'd like some
eyes on it.

3. There are 2 tests that fail to build on Windows, which I've excluded
from the build but will open bug tickets for:
    1.  gemm_weight_preshuffle
    2.  grouped_gemm_preshuffle

Here's a sample of the compiler error for these tests:

```
[composable_kernel] [16/19] Building HIP object test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj
[composable_kernel] FAILED: [code=1] test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj
[composable_kernel] E:\TheRock\build\core\clr\dist\lib\llvm\bin\clang++.exe  -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_TILE_USE_WMMA=1 -DCK_TIME_KERNEL=1 -DCK_USE_OCP_FP8 -DCK_USE_WMMA -DCK_USE_WMMA_FP8 -DCK_USE_XDL -DDPP_KERNELS -DUSE_PROF_API=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -D__HIP_ROCclr__=1 -IE:/TheRock/rocm-libraries/projects/composablekernel/profiler/include -IE:/TheRock/rocm-libraries/projects/composablekernel -IE:/TheRock/rocm-libraries/projects/composablekernel/library/include -IE:/TheRock/rocm-libraries/projects/composablekernel/include -IE:/TheRock/build/ml-libs/composable_kernel/build/include -IE:/TheRock/build/base/half/stage/include -isystem E:/TheRock/build/core/clr/dist/include -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest/include -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/getopt-src/src -O3 -DNDEBUG -std=gnu++20 --offload-arch=gfx1200 -D_DLL -D_MT -Xclang --dependent-lib=msvcrt   -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Wno-missing-field-initializers -Wno-error=deprecated-declarations -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Weverything -Wno-c++98-compat -Wno-c++98-compat-pedantic -Wno-conversion -Wno-double-promotion -Wno-exit-time-destructors -Wno-extra-semi -Wno-float-conversion -Wno-gnu-anonymous-struct -Wno-gnu-zero-variadic-macro-arguments 
-Wno-missing-prototypes -Wno-nested-anon-types -Wno-padded -Wno-return-std-move-in-c++11 -Wno-shorten-64-to-32 -Wno-sign-conversion -Wno-unknown-warning-option -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture -Wno-nvcc-compat -Wno-c++20-compat -Wno-bit-int-extension -Wno-pass-failed -Wno-switch-default -Wno-unique-object-duplication -fbracket-depth=1024 -Wno-nrvo -Werror -Weverything -fcolor-diagnostics -Wno-c++20-extensions -Wno-global-constructors -Wno-undef -DCK_TILE_USE_OCP_FP8 -MD -MT test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj -MF test\ck_tile\grouped_gemm_preshuffle\CMakeFiles\test_ck_tile_grouped_gemm_preshuffle.dir\test_grouped_gemm_preshuffle.cpp.obj.d -o test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj -x hip -c E:/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp:8:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host.hpp:6:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host/check_err.hpp:16:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/core.hpp:89:
[composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/core/utility/env.hpp:110:31: warning: 'getenv' is deprecated: This function or variable may be unsafe. Consider using _dupenv_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. [-Wdeprecated-declarations]
[composable_kernel]   110 |         const char* vp = std::getenv(name);
[composable_kernel]       |                               ^
[composable_kernel] C:\Program Files (x86)\Windows Kits\10\include\10.0.22621.0\ucrt\stdlib.h:1183:20: note: 'getenv' has been explicitly marked deprecated here
[composable_kernel]  1183 |     _Check_return_ _CRT_INSECURE_DEPRECATE(_dupenv_s)
[composable_kernel]       |                    ^
[composable_kernel] C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\vcruntime.h:368:55: note: expanded from macro '_CRT_INSECURE_DEPRECATE'
[composable_kernel]   368 |         #define _CRT_INSECURE_DEPRECATE(_Replacement) _CRT_DEPRECATE_TEXT(    \
[composable_kernel]       |                                                       ^
[composable_kernel] C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\vcruntime.h:358:47: note: expanded from macro '_CRT_DEPRECATE_TEXT'
[composable_kernel]   358 | #define _CRT_DEPRECATE_TEXT(_Text) __declspec(deprecated(_Text))
[composable_kernel]       |                                               ^
[composable_kernel] clang++: error: clang frontend command failed due to signal (use -v to see invocation)
[composable_kernel] AMD clang version 22.0.0git (https://github.com/ROCm/llvm-project.git a2dc42b87c63e686377a69f09ea23aec7550babc+PATCHED:e4d5bf498b7b8626bb9716f1f5a5946d45025918)
[composable_kernel] Target: x86_64-pc-windows-msvc
[composable_kernel] Thread model: posix
[composable_kernel] InstalledDir: E:\TheRock\build\core\clr\dist\lib\llvm\bin
[composable_kernel] clang++: note: diagnostic msg: Error generating preprocessed source(s).
[composable_kernel] ninja: build stopped: subcommand failed.
[composable_kernel FAILED WITH CODE 1 in 238 seconds]
ninja: build stopped: subcommand failed.
```

## Test Plan

Wait for internal CI and make sure build compiles locally.

## Test Result

Waiting on CI

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-03 21:55:14 +00:00

973 lines
39 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
struct FlatmmProblem
{
CK_TILE_HOST FlatmmProblem() = default;
CK_TILE_HOST FlatmmProblem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
/// Thin wrapper around a pointer to scale values, tagged at compile time with
/// the granularity along M/N and along K.
///
/// This primary template is the one selected when SharedGranularityK is
/// non-zero; the K == 0 cases are covered by the partial specializations of
/// this template. Element access through operator[] is deleted here.
template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
struct FlatmmScalePointer
{
    using ScaleType = ScaleType_;

    static constexpr int GranularityMN = SharedGranularityMN;
    static constexpr int GranularityK  = SharedGranularityK;

    const ScaleType* ptr;

    CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;

    // Intentionally implicit: default arguments elsewhere in this file are
    // written as `ScaleM scale_m_ = nullptr`.
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}

    // The length argument is accepted but not stored by this template.
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
        : ptr(ptr_)
    {
    }

    /// Returns a pointer advanced by `offset` scaled down by the granularities:
    /// offset / GranularityK when GranularityMN is 0, otherwise
    /// offset / GranularityMN / GranularityK.
    CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
    {
        if constexpr(GranularityMN != 0)
        {
            return FlatmmScalePointer{ptr + offset / GranularityMN / GranularityK};
        }
        else
        {
            return FlatmmScalePointer{ptr + offset / GranularityK};
        }
    }

    CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
};
/// Specialization for GranularityK == 0 with a non-zero M/N granularity.
/// Besides the pointer it carries a `length`, which operator[] uses as an
/// upper bound on the element index.
template <int SharedGranularityMN, typename ScaleType_>
struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
{
    using ScaleType = ScaleType_;

    static constexpr int GranularityMN = SharedGranularityMN;
    static constexpr int GranularityK  = 0;
    static_assert(GranularityMN != 0);

    const ScaleType* ptr;
    index_t length;

    CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;

    // A bare pointer is treated as having length 1.
    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}

    CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
        : ptr(ptr_), length(length_)
    {
    }

    /// Advances the pointer by offset / GranularityMN elements and shrinks
    /// `length` by the same amount.
    CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
    {
        // Dividing by a granularity of 1 leaves the offset unchanged, so a
        // single expression covers both the per-element and the grouped case.
        const index_t advance = offset / GranularityMN;
        return FlatmmScalePointer{ptr + advance, length - advance};
    }

    /// Reads the scale for index `i` (element i / GranularityMN); yields 0 when
    /// that element index is not below `length`.
    CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
    {
        const index_t slot = i / GranularityMN;
        return slot < length ? ptr[slot] : 0;
    }
};
// shared granularityMN = -1 means no scale
//
// Specialization used as the "scaling disabled" placeholder: every constructor
// ignores its arguments, `ptr` always stays nullptr, pointer arithmetic returns
// another empty placeholder, and indexing always yields the constant 1.
template <typename ScaleType_>
struct FlatmmScalePointer<-1, 0, ScaleType_>
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
// Never assigned by any constructor below; kept so the member exists on this specialization too.
const ScaleType* ptr = nullptr;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
// Both converting constructors discard their arguments.
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
// Offsetting is a no-op: the result is a fresh default-constructed placeholder.
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
{
return FlatmmScalePointer{};
}
// The index is ignored.
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
{
return 1; // always return 1, it doesn't change the result (presumably because the scale is applied multiplicatively - confirm)
}
};
// Host-side argument bundle: untyped pointers to A, B, NumDTensor D tensors and
// the output, the problem extents, the matching strides, and the k_batch value.
// The output pointer and the output stride are each reachable under two names
// (e_ptr/c_ptr and stride_E/stride_C) through anonymous unions.
template <index_t NumDTensor = 0>
struct BaseFlatmmHostArgs
{
CK_TILE_HOST BaseFlatmmHostArgs() = default;
// Note the parameter order: k_batch_ is the fifth parameter, although the
// k_batch member is declared (and therefore initialized) last.
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
e_ptr(e_ptr_),
M(M_),
N(N_),
K(K_),
stride_A(stride_A_),
stride_B(stride_B_),
stride_Ds(stride_Ds_),
stride_E(stride_E_),
k_batch(k_batch_)
{
}
const void* a_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
// e_ptr and c_ptr alias the same storage; the constructor writes e_ptr.
union
{
void* e_ptr;
void* c_ptr;
};
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
const std::array<index_t, NumDTensor> stride_Ds;
// stride_E and stride_C alias the same storage; the constructor writes stride_E.
union
{
index_t stride_E;
index_t stride_C;
};
index_t k_batch;
};
/// Host arguments for a flatmm launch with optional scale pointers.
///
/// Extends BaseFlatmmHostArgs with one scale pointer for the M dimension and
/// one for the N dimension. Both default to `nullptr`, relying on the implicit
/// pointer constructor of the scale-pointer types.
///
/// The base class is instantiated with the same NumDTensor as this template so
/// that the `ds_ptr_` / `stride_Ds_` arrays accepted by the constructor have
/// the element count the base constructor expects. For NumDTensor == 0 the
/// base type is exactly what it was before (BaseFlatmmHostArgs<0>).
///
/// @tparam ScaleM      Scale-pointer type for the M dimension.
/// @tparam ScaleN      Scale-pointer type for the N dimension.
/// @tparam NumDTensor  Number of D tensors carried in ds_ptr / stride_Ds.
template <class ScaleM       = FlatmmScalePointer<-1>,
          class ScaleN       = FlatmmScalePointer<-1>,
          index_t NumDTensor = 0>
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<NumDTensor>
{
    CK_TILE_HOST ScaleFlatmmHostArgs() = default;

    /// Forwards the first twelve arguments to the base in the order the base
    /// constructor declares them (k_batch_ is fifth), then stores the scales.
    CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
                                     const void* b_shuffle_ptr_,
                                     const std::array<const void*, NumDTensor>& ds_ptr_,
                                     void* c_ptr_,
                                     index_t k_batch_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     index_t stride_A_,
                                     index_t stride_B_,
                                     const std::array<index_t, NumDTensor>& stride_Ds_,
                                     index_t stride_C_,
                                     ScaleM scale_m_ = nullptr,
                                     ScaleN scale_n_ = nullptr)
        // The base is a dependent type, so it has to be named with its
        // template argument in the mem-initializer.
        : BaseFlatmmHostArgs<NumDTensor>(a_ptr_,
                                         b_shuffle_ptr_,
                                         ds_ptr_,
                                         c_ptr_,
                                         k_batch_,
                                         M_,
                                         N_,
                                         K_,
                                         stride_A_,
                                         stride_B_,
                                         stride_Ds_,
                                         stride_C_),
          scale_m(scale_m_),
          scale_n(scale_n_)
    {
    }

    ScaleM scale_m = nullptr;
    ScaleN scale_n = nullptr;
};
// Convenience alias: host arguments with both scale slots set to the
// FlatmmScalePointer<-1> ("no scale") placeholder type.
// NOTE(review): the parameter is declared `int` here while ScaleFlatmmHostArgs
// declares its count as `index_t` - confirm the two are the same type.
template <int NumberTensor = 0>
using FlatmmHostArgs =
ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
// Aggregate of kernel arguments. It has no constructors; FlatmmKernel::MakeKernelArgs
// fills it with a braced list in exactly this member order, so the order of the
// declarations below must not change.
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
struct FlatmmKernelArgs
{
const void* a_ptr;
// const void* b_shuffle_ptr;
const void* b_ptr;
const std::array<const void*, NumDTensor> ds_ptr;
void* e_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
std::array<index_t, NumDTensor> stride_Ds;
index_t stride_E;
index_t k_batch;
// Scale pointers default to nullptr when the aggregate initializer omits them.
ScaleM scale_m_ptr = nullptr;
ScaleN scale_n_ptr = nullptr;
};
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct FlatmmKernel
{
// Template parameters with cv/ref qualifiers stripped.
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
// Layouts of A, B and the output come from the pipeline (the output layout is
// the pipeline's CLayout); the D-tensor layouts and data types come from the epilogue.
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
// NOTE(review): the alias is taken from EpiloguePipeline::ODataType, i.e. the
// epilogue's output type - confirm that the comment above is still accurate.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
// Number of D tensors, derived from the epilogue's data-type list.
static constexpr index_t NumDTensor = DsDataType::size();
// Compile-time index constants 0..3.
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
/// Builds the kernel's name by passing '_' as the separator to concat together
/// with the literal "gemm", the precision string for the A/B data types and
/// the pipeline's own name.
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    const auto precision     = gemm_prec_str<ADataType, BDataType>();
    const auto pipeline_name = FlatmmPipeline::GetName();
    return concat('_', "gemm", precision, pipeline_name);
}
/// Launch grid for the non-persistent kernel: the tile partitioner's grid size
/// for (M, N) in x, 1 in y, and KBatch in z. Asserts that the persistent
/// kernel mode is not in use.
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
    assert(!UsePersistentKernel);
    const auto tile_count = TilePartitioner::GridSize(M, N);
    return dim3(tile_count, 1, KBatch);
}
/// Launch grid computed from a kernel-argument struct.
///
/// Non-persistent mode: (tile count for M x N, 1, k_batch).
/// Persistent mode: the x extent is capped at
/// multiProcessorCount * max-active-blocks-per-multiprocessor as reported by
/// the HIP runtime, and k_batch is asserted to be 1.
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
    if constexpr(!UsePersistentKernel)
    {
        return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
    }
    else
    {
        using KernelArgsT = FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>;

        constexpr int threads_per_block = FlatmmKernel::BlockSize().x;
        const int device_id             = 0; // default device
        const int dynamic_smem_size     = 0;

        hipDeviceProp_t device_props;
        int max_blocks_per_cu = 0;

        // NOTE(review): both status codes are discarded, and device 0 is
        // queried rather than the currently selected device - confirm intended.
        [[maybe_unused]] auto status = hipGetDeviceProperties(&device_props, device_id);
        status                       = hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_blocks_per_cu,
            reinterpret_cast<void*>(kentry<1, FlatmmKernel, KernelArgsT>),
            threads_per_block,
            dynamic_smem_size);

        const int resident_blocks = device_props.multiProcessorCount * max_blocks_per_cu;
        const int work_tiles      = TilePartitioner::GridSize(kargs.M, kargs.N);

        assert(kargs.k_batch == 1);
        return dim3(min(resident_blocks, work_tiles), 1, kargs.k_batch);
    }
}
// Launch block dimensions: kBlockSize in x (dim3 supplies the remaining components).
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
/// Copies the host arguments, field by field, into the kernel-argument
/// aggregate. The initializer order follows the member order of
/// FlatmmKernelArgs; the host-side scale_m / scale_n become
/// scale_m_ptr / scale_n_ptr.
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
{
    using KernelArgsT = FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>;

    return KernelArgsT{hostArgs.a_ptr,
                       hostArgs.b_ptr,
                       hostArgs.ds_ptr,
                       hostArgs.e_ptr,
                       hostArgs.M,
                       hostArgs.N,
                       hostArgs.K,
                       hostArgs.stride_A,
                       hostArgs.stride_B,
                       hostArgs.stride_Ds,
                       hostArgs.stride_E,
                       hostArgs.k_batch,
                       hostArgs.scale_m,
                       hostArgs.scale_n};
}
// Size of the "ping" shared-memory buffer: the larger of what the flatmm
// pipeline and the epilogue pipeline each report via GetSmemSize().
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
{
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
// Size of the "pong" shared-memory buffer: only the flatmm pipeline's
// GetSmemSize() is considered here (the epilogue is not).
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
return FlatmmPipeline::GetSmemSize();
}
/// Per-split bookkeeping for split-K: the element offsets to add to the A and
/// B pointers for split `k_id`, and the K extent that split covers.
struct SplitKBatchOffset
{
    /// @param kargs  Kernel arguments; K, k_batch, stride_A and stride_B are read.
    /// @param k_id   Index of the K split; defaults to blockIdx.z.
    template <class KernelArgs>
    __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
    {
        using RowMajor    = tensor_layout::gemm::RowMajor;
        using ColumnMajor = tensor_layout::gemm::ColumnMajor;

        constexpr auto warp_n = BlockGemmShape::WarpTile::at(number<1>{});
        constexpr auto warp_k = BlockGemmShape::WarpTile::at(number<2>{});

        // K handled by every split except the last: K divided by
        // (k_batch * warp_k), rounded up, then multiplied back by warp_k.
        const index_t k_granule   = kargs.k_batch * warp_k;
        const index_t k_per_split = (kargs.K + k_granule - 1) / k_granule * warp_k;

        if constexpr(std::is_same_v<ALayout, RowMajor>)
        {
            a_k_split_offset = k_id * k_per_split;
        }
        else if constexpr(std::is_same_v<ALayout, ColumnMajor>)
        {
            a_k_split_offset = k_id * k_per_split * kargs.stride_A;
        }

        if constexpr(std::is_same_v<BLayout, RowMajor>)
        {
            b_k_split_offset = k_id * k_per_split * kargs.stride_B * warp_n;
        }
        else if constexpr(std::is_same_v<BLayout, ColumnMajor>)
        {
            b_k_split_offset = k_id * k_per_split * warp_n;
        }

        // Every split but the last takes k_per_split; the last takes whatever
        // remains of K after the preceding (k_batch - 1) splits.
        const bool is_last_split = !(k_id < static_cast<uint32_t>(kargs.k_batch - 1));
        splitted_k =
            is_last_split ? kargs.K - k_per_split * (kargs.k_batch - 1) : k_per_split;
    }

    index_t a_k_split_offset;
    index_t b_k_split_offset;
    index_t splitted_k;
};
/// Host-side validation of the kernel arguments against the compile-time
/// configuration.
///
/// Checks, in order: the k_batch restrictions, then for A, B, every D tensor
/// and E that the relevant extent is a multiple of the block tile (unless the
/// matching padding flag is set) and of the vector size. A/B/E failures print
/// to std::cerr and return false immediately. D-tensor failures are reported
/// through CK_TILE_ERROR and only make the final result false; the E checks
/// still run afterwards.
template <class KernelArgs>
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
    using RowMajor = tensor_layout::gemm::RowMajor;

    // Writes one diagnostic line to std::cerr (flushed) and yields false, so a
    // failed check can be written as `return reject(...)`.
    const auto reject = [](const char* reason) {
        std::cerr << reason << std::endl;
        return false;
    };

    // --- k_batch restrictions -------------------------------------------
    if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                 is_any_of<EDataType, fp16_t, bf16_t>::value)
    {
        if(kargs.k_batch != 1)
        {
            return reject("Conditions not met for Kbatch >1 !");
        }
    }
    if constexpr(UsePersistentKernel)
    {
        if(kargs.k_batch != 1)
        {
            return reject("Persistent mode doesn't support Kbatch >1 !");
        }
    }

    // --- A tensor ---------------------------------------------------------
    if constexpr(std::is_same_v<ALayout, RowMajor>)
    {
        if(kargs.K % TilePartitioner::KPerBlock != 0 && !FlatmmPipeline::kPadK)
        {
            return reject("Can't support K that is not a multiple of KPerBlock without padding!");
        }
        if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
        {
            return reject("K is not a multiple of vector load size for A tensor!");
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && !FlatmmPipeline::kPadM)
        {
            return reject("Can't support M that is not a multiple of MPerBlock without padding!");
        }
        if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
        {
            return reject("M is not a multiple of vector load size for A tensor!");
        }
    }

    // --- B tensor ---------------------------------------------------------
    if constexpr(std::is_same_v<BLayout, RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && !FlatmmPipeline::kPadN)
        {
            return reject("Can't support N that is not a multiple of NPerBlock without padding!");
        }
        if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
        {
            return reject("N is not a multiple of vector load size for B tensor!");
        }
    }
    else
    {
        if(kargs.K % TilePartitioner::KPerBlock != 0 && !FlatmmPipeline::kPadK)
        {
            return reject("Can't support K that is not a multiple of KPerBlock without padding!");
        }
        if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
        {
            return reject("K is not a multiple of vector load size for B tensor!");
        }
    }

    // --- D tensors --------------------------------------------------------
    // Failures here are accumulated rather than returned immediately.
    bool ds_supported = true;
    static_for<0, NumDTensor, 1>{}([&](auto index) {
        using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;

        // Each D tensor must use the same layout as E.
        if(!std::is_same_v<DiLayout, ELayout>)
        {
            ds_supported = false;
        }

        if constexpr(std::is_same_v<DiLayout, RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && !FlatmmPipeline::kPadN)
            {
                CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                              "NPerBlock without padding!");
                ds_supported = false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
                ds_supported = false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && !FlatmmPipeline::kPadM)
            {
                CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                              "MPerBlock without padding!");
                ds_supported = false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
                ds_supported = false;
            }
        }
    });

    // --- E tensor ---------------------------------------------------------
    if constexpr(std::is_same_v<ELayout, RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && !FlatmmPipeline::kPadN)
        {
            return reject("Can't support N that is not a multiple of NPerBlock without padding!");
        }
        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            return reject("N is not a multiple of vector load size for C tensor!");
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && !FlatmmPipeline::kPadM)
        {
            return reject("Can't support M that is not a multiple of MPerBlock without padding!");
        }
        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            return reject("M is not a multiple of vector load size for C tensor!");
        }
    }

    return ds_supported;
}
/// Builds the tile window over A for one block.
///
/// For either layout the steps are the same: wrap `a_ptr` in a global-memory
/// tensor view, pad that view to the block tile, and open a window at the
/// block's M position. Row-major A is viewed as (M, k_size) with window origin
/// {block_idx_m, 0}; otherwise A is viewed as (k_size, M) with window origin
/// {0, block_idx_m}. Only the second dimension is ever padded (kPadK for
/// row-major, kPadM otherwise).
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr,
                                            const KernelArgs& kargs,
                                            const index_t k_size,
                                            const index_t block_idx_m)
{
    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
    {
        const auto view = make_naive_tensor_view<address_space_enum::global>(
            a_ptr,
            make_tuple(kargs.M, k_size),
            make_tuple(kargs.stride_A, 1),
            number<FlatmmPipeline::GetVectorSizeA()>{},
            number<1>{});

        const auto padded = pad_tensor_view(
            view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            sequence<false, FlatmmPipeline::kPadK>{});

        return make_tile_window(
            padded,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
            {block_idx_m, 0});
    }
    else
    {
        const auto view = make_naive_tensor_view<address_space_enum::global>(
            a_ptr,
            make_tuple(k_size, kargs.M),
            make_tuple(kargs.stride_A, 1),
            number<FlatmmPipeline::GetVectorSizeA()>{},
            number<1>{});

        const auto padded = pad_tensor_view(
            view,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
            sequence<false, FlatmmPipeline::kPadM>{});

        return make_tile_window(
            padded,
            make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
            {0, block_idx_m});
    }
}
/// Builds the tile window over the flat B buffer for one block.
///
/// B is viewed as a 2-D tensor of extents (kFlatN, kFlatK) with row stride
/// kFlatK, where
///   kFlatK = flatKPerWarp * (K / warp-tile K extent)
///   kFlatN = N * K / kFlatK
/// No padding is applied. The window is flatNPerWarp x flatKPerWarp and starts
/// at row block_idx_n / (warp-tile N extent), column 0.
///
/// @param b_flat_ptr   Base pointer of the flat B buffer.
/// @param kargs        Kernel arguments; N and K are read.
/// @param block_idx_n  Block position along N.
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr,
                                                const KernelArgs& kargs,
                                                const index_t block_idx_n)
{
    // Step 1: Create tensor view
    const index_t kFlatK =
        FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));

    // N * K is the total element count of B and can exceed the range of
    // index_t even when the quotient fits, so the product is formed in a
    // wider type before dividing. The result is identical to the narrow
    // computation whenever that one did not overflow.
    const index_t kFlatN =
        static_cast<index_t>(static_cast<long long>(kargs.N) * kargs.K / kFlatK);

    const auto b_flat_tensor_view = make_naive_tensor_view<address_space_enum::global>(
        b_flat_ptr,
        make_tuple(kFlatN, kFlatK),
        make_tuple(kFlatK, 1),
        number<FlatmmPipeline::GetVectorSizeB()>{},
        number<1>{});

    // Step 2: No padding needed for b_flat
    // Step 3: Create tile window
    return make_tile_window(
        b_flat_tensor_view,
        make_tuple(number<FlatmmPipeline::flatNPerWarp>{}, number<FlatmmPipeline::flatKPerWarp>{}),
        {static_cast<int>(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0});
}
/// Builds one tile window per D tensor and returns them as a tuple.
///
/// Each D tensor goes through the same three steps - global-memory tensor
/// view, padded view, tile window - with the dimension order chosen by that
/// tensor's layout: (M, N) for row-major, (N, M) otherwise. Only the second
/// dimension is padded (kPadN for row-major, kPadM otherwise).
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array<const void*, NumDTensor>& ds_ptr,
                                             const KernelArgs& kargs,
                                             const index_t block_idx_m,
                                             const index_t block_idx_n)
{
    return generate_tuple(
        [&](auto i) {
            using DiLayout   = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
            using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;

            if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
            {
                const auto view = make_naive_tensor_view<address_space_enum::global>(
                    static_cast<const DDataType_*>(ds_ptr[i]),
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_Ds[i], 1),
                    number<EpiloguePipeline::GetVectorSizeD(i)>{},
                    number<1>{});

                const auto padded = pad_tensor_view(view,
                                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                                               number<TilePartitioner::NPerBlock>{}),
                                                    sequence<false, FlatmmPipeline::kPadN>{});

                return make_tile_window(padded,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {block_idx_m, block_idx_n});
            }
            else
            {
                const auto view = make_naive_tensor_view<address_space_enum::global>(
                    static_cast<const DDataType_*>(ds_ptr[i]),
                    make_tuple(kargs.N, kargs.M),
                    make_tuple(kargs.stride_Ds[i], 1),
                    number<EpiloguePipeline::GetVectorSizeD(i)>{},
                    number<1>{});

                const auto padded = pad_tensor_view(view,
                                                    make_tuple(number<TilePartitioner::NPerBlock>{},
                                                               number<TilePartitioner::MPerBlock>{}),
                                                    sequence<false, FlatmmPipeline::kPadM>{});

                return make_tile_window(padded,
                                        make_tuple(number<TilePartitioner::NPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {block_idx_n, block_idx_m});
            }
        },
        number<NumDTensor>{});
}
/// Builds the output (E) tile window for one block.
///
/// The view is created over `e_ptr` in global memory with DstInMemOp as its
/// memory operation, padded to the (MPerBlock, NPerBlock) tile, and a window
/// of that tile size is opened at {block_idx_m, block_idx_n}.
///
/// Row-major E: extents (M, N), vector size from the epilogue, padding on the
/// second dimension (kPadN). Otherwise: extents (N, M), vector size 1, padding
/// on the first dimension (kPadM).
///
/// NOTE(review): in the non-row-major branch the view's extents are ordered
/// (N, M) while the pad tile, window lengths and window origin are ordered
/// (M, N) - confirm this is intended.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, typename KernelArgs>
CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr,
                                            const KernelArgs& kargs,
                                            const index_t block_idx_m,
                                            const index_t block_idx_n)
{
    if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
    {
        const auto view = make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
            e_ptr,
            make_tuple(kargs.M, kargs.N),
            make_tuple(kargs.stride_E, 1),
            number<EpiloguePipeline::GetVectorSizeC()>{},
            number<1>{});

        const auto padded = pad_tensor_view(
            view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            sequence<false, FlatmmPipeline::kPadN>{});

        return make_tile_window(
            padded,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_m, block_idx_n});
    }
    else
    {
        const auto view = make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
            e_ptr,
            make_tuple(kargs.N, kargs.M),
            make_tuple(kargs.stride_E, 1),
            number<1>{},
            number<1>{});

        const auto padded = pad_tensor_view(
            view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            sequence<FlatmmPipeline::kPadM, false>{});

        return make_tile_window(
            padded,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {block_idx_m, block_idx_n});
    }
}
/// Builds the tile window over the M-side scale values.
///
/// The scale buffer is viewed as a 2-D tensor (m_dim, k_dim):
///   - m_dim is 1 with stride 0 when the M granularity is 0, otherwise M / GM
///     with stride 1;
///   - k_dim is 1 when the K granularity is 0, otherwise splitted_k / GK; its
///     stride is always 0.
/// The window spans MPerBlock in the first dimension and, in the second,
/// NPerBlock when GK == 0 or KPerBlock otherwise; it starts at {block_idx_m, 0}.
/// Must not be instantiated for the "no scale" pointer type (GM == -1).
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs,
                                            const SplitKBatchOffset& splitk_batch_offset,
                                            const index_t block_idx_m)
{
    using ScalePtr = decltype(kargs.scale_m_ptr);

    constexpr int GM = ScalePtr::GranularityMN;
    constexpr int GK = ScalePtr::GranularityK;
    static_assert(GM != -1, "MakeScaleMWindow should only be instantiated when scale is enabled");

    // Compile-time shape parameters.
    constexpr index_t vector_size = (GM == 1) ? FlatmmPipeline::GetVectorSizeA() : 1;
    constexpr index_t second_len =
        (GK == 0) ? TilePartitioner::NPerBlock : TilePartitioner::KPerBlock;

    // Run-time extents and strides of the scale view.
    const index_t m_dim    = (GM != 0) ? (kargs.M / GM) : 1;
    const index_t m_stride = (GM != 0) ? 1 : 0;
    const index_t k_dim    = (GK != 0) ? (splitk_batch_offset.splitted_k / GK) : 1;
    const index_t k_stride = 0;

    const auto scale_m_view =
        make_naive_tensor_view<address_space_enum::global>(kargs.scale_m_ptr.ptr,
                                                           make_tuple(m_dim, k_dim),
                                                           make_tuple(m_stride, k_stride),
                                                           number<vector_size>{},
                                                           number<1>{});

    return make_tile_window(
        scale_m_view,
        make_tuple(number<TilePartitioner::MPerBlock>{}, number<second_len>{}),
        {block_idx_m, 0});
}
/// Builds the tile window over the N-side scale values.
///
/// The scale buffer is viewed as a 2-D tensor (k_dim, n_dim):
///   - n_dim is 1 with stride 0 when the N granularity is 0, otherwise N / GN
///     with stride 1;
///   - k_dim is 1 when the K granularity is 0, otherwise splitted_k / GK; its
///     stride is always 0.
/// The window spans, in the first dimension, MPerBlock when GK == 0 or
/// KPerBlock otherwise, and NPerBlock in the second; it starts at
/// {0, block_idx_n}.
/// Must not be instantiated for the "no scale" pointer type (GN == -1).
template <typename KernelArgs>
CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs,
                                            const SplitKBatchOffset& splitk_batch_offset,
                                            const index_t block_idx_n)
{
    using ScalePtr = decltype(kargs.scale_n_ptr);

    constexpr int GN = ScalePtr::GranularityMN;
    constexpr int GK = ScalePtr::GranularityK;
    static_assert(GN != -1, "MakeScaleNWindow should only be instantiated when scale is enabled");

    // Compile-time shape parameters.
    constexpr index_t vector_size = (GN == 1) ? FlatmmPipeline::GetVectorSizeB() : 1;
    constexpr index_t first_len =
        (GK == 0) ? TilePartitioner::MPerBlock : TilePartitioner::KPerBlock;

    // Run-time extents and strides of the scale view.
    const index_t n_dim    = (GN != 0) ? (kargs.N / GN) : 1;
    const index_t n_stride = (GN != 0) ? 1 : 0;
    const index_t k_dim    = (GK != 0) ? (splitk_batch_offset.splitted_k / GK) : 1;
    const index_t k_stride = 0;

    const auto scale_n_view =
        make_naive_tensor_view<address_space_enum::global>(kargs.scale_n_ptr.ptr,
                                                           make_tuple(k_dim, n_dim),
                                                           make_tuple(k_stride, n_stride),
                                                           number<vector_size>{},
                                                           number<1>{});

    return make_tile_window(
        scale_n_view,
        make_tuple(number<first_len>{}, number<TilePartitioner::NPerBlock>{}),
        {0, block_idx_n});
}
// Computes one output tile: builds the block windows, runs the main flatmm pipeline,
// then runs the epilogue that writes the result through an E block window.
//
// Template parameters:
//   ScaleM, ScaleN       : scale descriptor types; GranularityMN == -1 means "disabled".
//   UseDefaultScheduler  : when false (and scaling is disabled), only the warp with
//                          get_warp_id() == 0 runs the epilogue.
//
// Parameters:
//   a_ptr, b_flat_ptr    : A / flat-B inputs used to build the pipeline windows.
//   ds_ptr               : pointers to the D tensors used to build the D windows.
//   e_ptr                : output tensor pointer.
//   smem_ptr_ping        : scratch buffer handed to both the pipeline and the epilogue.
//   smem_ptr_pong        : scratch buffer handed to the pipeline only.
//   kargs                : kernel arguments; `k_batch` selects the store mode.
//   splitk_batch_offset  : provides `splitted_k` (loop count and window sizing).
//   block_idx_m/n        : tile origin passed to the window builders.
//
// Store mode: k_batch == 1 uses memory_operation_enum::set, any other value uses
// memory_operation_enum::atomic_add.
//
// Note: the scaled path is taken when either ScaleM or ScaleN is enabled, and it
// instantiates both MakeScaleMWindow and MakeScaleNWindow; each of those static_asserts
// that its own granularity is not -1.
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunFlatmm(const ADataType* a_ptr,
          const BDataType* b_flat_ptr,
          const std::array<const void*, NumDTensor>& ds_ptr,
          EDataType* e_ptr,
          void* smem_ptr_ping,
          void* smem_ptr_pong,
          const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
          const SplitKBatchOffset& splitk_batch_offset,
          const index_t block_idx_m,
          const index_t block_idx_n)
{
    // Create block windows using specialized methods
    const auto& a_block_window =
        MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
    const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n);
    const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);

    const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

    // Run GEMM cooperatively by whole workgroup.
    // The call operator is invoked directly: a `.template operator()` spelling without a
    // template-argument list is ill-formed ([temp.names]) and rejected by recent Clang.
    // `c_block_tile` intentionally stays `const auto&` because decltype(c_block_tile) is
    // forwarded below as an explicit template argument of the epilogue.
    const auto& c_block_tile = FlatmmPipeline{}(
        a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);

    // Run Epilogue Pipeline with k_batch dispatching
    if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
    {
        // Scaled epilogue: the scale windows are passed as two extra arguments.
        const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m);
        const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n);

        if(kargs.k_batch == 1)
        {
            auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
                e_ptr, kargs, block_idx_m, block_idx_n);
            EpiloguePipeline{}
                .template operator()<decltype(e_block_window),
                                     decltype(c_block_tile),
                                     decltype(ds_block_window)>(e_block_window,
                                                                c_block_tile,
                                                                ds_block_window,
                                                                smem_ptr_ping,
                                                                scale_m_window,
                                                                scale_n_window);
        }
        else
        {
            // k_batch != 1: accumulate into E with atomic adds.
            auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
                e_ptr, kargs, block_idx_m, block_idx_n);
            EpiloguePipeline{}
                .template operator()<decltype(e_block_window),
                                     decltype(c_block_tile),
                                     decltype(ds_block_window)>(e_block_window,
                                                                c_block_tile,
                                                                ds_block_window,
                                                                smem_ptr_ping,
                                                                scale_m_window,
                                                                scale_n_window);
        }
    }
    else if(UseDefaultScheduler || (get_warp_id() == 0))
    {
        // Unscaled epilogue; with a non-default scheduler only warp 0 gets here.
        if(kargs.k_batch == 1)
        {
            auto e_block_window = MakeEBlockWindow<memory_operation_enum::set>(
                e_ptr, kargs, block_idx_m, block_idx_n);
            EpiloguePipeline{}
                .template operator()<decltype(e_block_window),
                                     decltype(c_block_tile),
                                     decltype(ds_block_window)>(
                    e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
        }
        else
        {
            // k_batch != 1: accumulate into E with atomic adds.
            auto e_block_window = MakeEBlockWindow<memory_operation_enum::atomic_add>(
                e_ptr, kargs, block_idx_m, block_idx_n);
            EpiloguePipeline{}
                .template operator()<decltype(e_block_window),
                                     decltype(c_block_tile),
                                     decltype(ds_block_window)>(
                    e_block_window, c_block_tile, ds_block_window, smem_ptr_ping);
        }
    }
}
// Kernel entry point. Processes the output tile selected by `partition_idx` and, when
// UsePersistentKernel is set, keeps advancing by gridDim.x until every work tile
// reported by TilePartitioner::GridSize(M, N) has been visited.
//
// kargs         : kernel arguments (taken by value).
// partition_idx : index of the first tile to process; defaults to blockIdx.x.
//
// The body always executes at least once (do/while), even for a non-persistent launch.
// No work is issued for configurations where the epilogue's C vector size is odd and
// EDataType is fp16_t or bf16_t.
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
                               int partition_idx = blockIdx.x) const
{
    const int num_work_tiles = TilePartitioner::GridSize(kargs.M, kargs.N);

    do
    {
        // Map the linear partition index to a 2-D tile index, then to element origins.
        const auto [tile_m, tile_n] =
            TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
        const index_t origin_m =
            amd_wave_read_first_lane(tile_m * TilePartitioner::MPerBlock);
        const index_t origin_n =
            amd_wave_read_first_lane(tile_n * TilePartitioner::NPerBlock);

        const SplitKBatchOffset splitk_offset(kargs);

        // Apply the split-K offsets to the A and B base pointers; E is used as-is.
        const ADataType* a_base =
            static_cast<const ADataType*>(kargs.a_ptr) + splitk_offset.a_k_split_offset;
        const BDataType* b_flat_base =
            static_cast<const BDataType*>(kargs.b_ptr) + splitk_offset.b_k_split_offset;
        EDataType* e_base = static_cast<EDataType*>(kargs.e_ptr);

        // allocate LDS
        __shared__ char smem_ptr_ping[GetSmemPingSize()];
        __shared__ char smem_ptr_pong[GetSmemPongSize()];

        // Run unless the C vector size is odd AND the output type is fp16/bf16.
        constexpr bool kOddVectorSizeC = (EpiloguePipeline::GetVectorSizeC() % 2 != 0);
        constexpr bool kHalfOutput     = is_any_of<EDataType, fp16_t, bf16_t>::value;

        if constexpr(!kOddVectorSizeC || !kHalfOutput)
        {
            // A single wave group selects the default scheduler in RunFlatmm.
            constexpr bool kUseDefaultScheduler = (FlatmmPipeline::NumWaveGroups == 1);

            RunFlatmm<ScaleM, ScaleN, kUseDefaultScheduler>(a_base,
                                                            b_flat_base,
                                                            kargs.ds_ptr,
                                                            e_base,
                                                            smem_ptr_ping,
                                                            smem_ptr_pong,
                                                            kargs,
                                                            splitk_offset,
                                                            origin_m,
                                                            origin_n);
        }

        partition_idx += gridDim.x;
    } while(UsePersistentKernel && partition_idx < num_work_tiles);
}
};
} // namespace ck_tile