[CK_TILE] support group from cmdline (#1295)

* support cmdline seqlen decode

* silent print

* update readme

* update kernel launch 3d

* update tile partitioner

* fix spill for bf16

* modify based on comment

* modify payload_t

* fix bug for alibi mode

* fix alibi test err

* refactor kernel launch, support select timer

* add missing file

* remove useless code

* add some comments
This commit is contained in:
carlushuang
2024-05-28 11:13:21 +08:00
committed by GitHub
parent 02fa2c298b
commit 5055b3bdcb
16 changed files with 534 additions and 261 deletions

View File

@@ -29,6 +29,25 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
return __builtin_bit_cast(int32x4_t, res);
}
namespace impl {
// below traits map a (load-size-in-bytes N, thread-buffer type T) pair to the
// register data type (payload_t) bound to the "+v" operand of the buffer-load
// inline asm below. sub-dword loads (N = 2/1) still occupy a full dword
// register, hence float.
// clang-format off
template<index_t N, typename T> struct buffer_load_trait;
template<typename T> struct buffer_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct buffer_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct buffer_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct buffer_load_trait<1 , T> { using payload_t = float; };
#if CK_TILE_BUFFER_LOAD_RAW_BF16_WA
// workaround: for bf16 thread buffers use bf16 vector payloads instead of the
// fp32 default (intended to avoid register spills for bf16 per the change that
// introduced it — controlled by CK_TILE_BUFFER_LOAD_RAW_BF16_WA, default on)
template<> struct buffer_load_trait<16, thread_buffer<bf16_t, 8>> { using payload_t = bf16x8_t; };
template<> struct buffer_load_trait<8 , thread_buffer<bf16_t, 4>> { using payload_t = bf16x4_t; };
template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payload_t = bf16x2_t; };
#endif
// clang-format on
} // namespace impl
// TODO: glc/slc/...
template <index_t bytes>
struct buffer_load;
@@ -48,7 +67,7 @@ struct buffer_load<16>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
asm volatile("buffer_load_dwordx4 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -68,7 +87,7 @@ struct buffer_load<8>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile("buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -88,7 +107,7 @@ struct buffer_load<4>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile("buffer_load_dword %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -108,7 +127,7 @@ struct buffer_load<2>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile("buffer_load_ushort %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -128,7 +147,7 @@ struct buffer_load<1>
index_t /*flag*/ = 0)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile("buffer_load_ubyte %0, %1, %2, %3 offen offset:%4"
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "s"(s_offset), "n"(i_offset)
@@ -152,7 +171,7 @@ struct buffer_load_if<16>
{
static_assert(sizeof(T) == 16);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x4_t;
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
static_assert(sizeof(mbuf_t) == sizeof(T));
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
@@ -177,7 +196,7 @@ struct buffer_load_if<8>
{
static_assert(sizeof(T) == 8);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = fp32x2_t;
using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
@@ -201,7 +220,7 @@ struct buffer_load_if<4>
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_dword %0, %1, %2, %3 offen offset:%4\n"
@@ -225,7 +244,7 @@ struct buffer_load_if<2>
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ushort %0, %1, %2, %3 offen offset:%4\n"
@@ -249,7 +268,7 @@ struct buffer_load_if<1>
{
static_assert(sizeof(T) == 4);
auto saved_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t;
asm volatile(
"v_cmpx_le_u32 exec, 1, %5\n"
"buffer_load_ubyte %0, %1, %2, %3 offen offset:%4\n"

View File

@@ -171,3 +171,7 @@
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
#endif
// workaround switch for raw bf16 buffer loads: when enabled, the buffer_load
// inline-asm helpers select bf16 vector payload types for bf16 thread buffers
// (see impl::buffer_load_trait). enabled by default.
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif

View File

@@ -20,3 +20,4 @@
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"

View File

@@ -27,7 +27,14 @@ struct DeviceMem
DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
}
void Realloc(std::size_t mem_size)
{
@@ -36,7 +43,14 @@ struct DeviceMem
HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
}
mMemSize = mem_size;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
}
void* GetDeviceBuffer() const { return mpDeviceBuf; }
std::size_t GetBufferSize() const { return mMemSize; }
@@ -47,15 +61,18 @@ struct DeviceMem
HIP_CHECK_ERROR(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
}
else
{
throw std::runtime_error("ToDevice with an empty pointer");
}
// else
// {
// throw std::runtime_error("ToDevice with an empty pointer");
// }
}
void ToDevice(const void* p, const std::size_t cpySize) const
{
HIP_CHECK_ERROR(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(
hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
}
}
void FromDevice(void* p) const
{
@@ -63,14 +80,17 @@ struct DeviceMem
{
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
}
else
{
throw std::runtime_error("FromDevice with an empty pointer");
}
// else
// {
// throw std::runtime_error("FromDevice with an empty pointer");
// }
}
void FromDevice(void* p, const std::size_t cpySize) const
{
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
}
void SetZero() const
{
@@ -82,13 +102,16 @@ struct DeviceMem
template <typename T>
void SetValue(T x) const
{
if(mMemSize % sizeof(T) != 0)
if(mpDeviceBuf)
{
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
if(mMemSize % sizeof(T) != 0)
{
throw std::runtime_error("wrong! not entire DeviceMem will be set");
}
// TODO: call a gpu kernel to set the value (?)
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
// TODO: call a gpu kernel to set the value (?)
set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
}
}
~DeviceMem()
{

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/timer.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename...
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
__global__ void kentry(Kernel f, Args... args)
__global__ void kentry(Args... args)
{
f(args...);
}
template <typename... Args, typename F>
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TILE_TIME_KERNEL
if(s.time_kernel_)
{
// warm up
for(int i = 0; i < s.cold_niters_; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
const int nrepeat = s.nrepeat_;
hipEvent_t start, stop;
HIP_CHECK_ERROR(hipEventCreate(&start));
HIP_CHECK_ERROR(hipEventCreate(&stop));
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop));
float total_time = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
return total_time / nrepeat;
}
else
{
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
#endif
}
template <typename... Args, typename F, typename PreProcessFunc>
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
PreProcessFunc preprocess,
F kernel,
dim3 grid_dim,
dim3 block_dim,
std::size_t lds_byte,
Args... args)
{
#if CK_TILE_TIME_KERNEL
if(s.time_kernel_)
{
#if CK_TILE_DEBUG_LOG
printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
__func__,
grid_dim.x,
grid_dim.y,
grid_dim.z,
block_dim.x,
block_dim.y,
block_dim.z);
printf("Warm up 1 time\n");
#endif
// warm up
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
const int nrepeat = 10;
#if CK_TILE_DEBUG_LOG
printf("Start running %d times...\n", nrepeat);
#endif
hipEvent_t start, stop;
HIP_CHECK_ERROR(hipEventCreate(&start));
HIP_CHECK_ERROR(hipEventCreate(&stop));
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
for(int i = 0; i < nrepeat; ++i)
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
}
HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop));
float total_time = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
return total_time / nrepeat;
}
else
{
preprocess();
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
}
#else
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
hip_check_error(hipGetLastError());
return 0;
#endif
Kernel{}(args...);
}
//
// return an anonymous functor (lambda) to be called later
// KernelImpl should be a class without non-static data members, i.e. it can
// be instantiated as "KernelImpl{}"
//
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
//
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
CK_TILE_HOST float launch_kernel(const stream_config& s,
KernelImpl kernel_impl,
dim3 grid_dim,
dim3 block_dim,
std::size_t dynamic_smem_byte,
Args... args)
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
return launch_and_time_kernel(
s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
return [=](const stream_config& s) {
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
};
}
// clang-format off
/*
* launch_kernel()
*
* this is the function to launch an arbitrary number of kernels, with an optional timer (selected by stream_config)
* the callables should have signature as "operator()(const stream_config& s){ ... }" to call
*
* the simplest way is to pass in a lambda with the signature
* "[=](const stream_config& s){ call_your_kernel_here() }" for each callable (pay attention to the capture list)
*
* e.g.
* ck_tile::launch_kernel(s,
* [=](const stream_config& s){ hipMemset(ptr, 0, size) },
* [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
* );
*
* if you use a ck_tile kernel, or something similar in style (a structure with "static __device__ operator()(...){}"),
* you can pass your kernel to ck_tile::make_kernel(), which will create an anonymous functor for you,
* then pass it to ck_tile::launch_kernel()
*
* e.g.
* ck_tile::launch_kernel(s,
* ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
* ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
* ...);
**/
// clang-format on
// launch one or more callables (each with signature "void(const stream_config&)")
// on the given stream config. when s.time_kernel_ is set, warms up for
// s.cold_niters_ iterations and returns the mean time in ms over s.nrepeat_
// iterations, measured by the GPU or CPU timer as selected by s.is_gpu_timer_;
// otherwise launches once and returns 0.
template <typename... Callables>
CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
{
    // invoke every callable once, in order, against the given stream config
    auto run_all = [&]() { (callables(s), ...); };

    if(!s.time_kernel_)
    {
        // no timing requested: single launch, then surface any launch error
        run_all();
        hip_check_error(hipGetLastError());
        return 0;
    }

    // timed path, parameterized over the timer type (gpu_timer / cpu_timer)
    auto timed_run = [&](auto timer) {
        // warmup
        for(int i = 0; i < s.cold_niters_; i++)
        {
            run_all();
        }
        hip_check_error(hipGetLastError());

        timer.start(s.stream_id_);
        for(int i = 0; i < s.nrepeat_; i++)
        {
            run_all();
        }
        hip_check_error(hipGetLastError());
        timer.stop(s.stream_id_);

        return timer.duration() / s.nrepeat_;
    };

    return s.is_gpu_timer_ ? timed_run(gpu_timer{}) : timed_run(cpu_timer{});
}
} // namespace ck_tile

View File

@@ -6,6 +6,22 @@
#include <hip/hip_runtime.h>
namespace ck_tile {
/*
* construction of this structure behaves as follows:
*
* // create stream config with default stream(NULL), and not timing the kernel
* stream_config s = stream_config{};
*
* // create stream config with _some_stream_id_, and not timing the kernel
* stream_config s = stream_config{_some_stream_id_};
*
* // create stream config with _some_stream_id_, and benchmark with warmup/repeat as default
* stream_config s = stream_config{_some_stream_id_, true};
*
* // create stream config with _some_stream_id_, and benchmark using cpu timer
* stream_config s = stream_config{_some_stream_id_, true, 0, 3, 10, false};
**/
struct stream_config
{
hipStream_t stream_id_ = nullptr;
@@ -13,5 +29,6 @@ struct stream_config
int log_level_ = 0;
int cold_niters_ = 3;
int nrepeat_ = 10;
bool is_gpu_timer_ = true; // keep compatible
};
} // namespace ck_tile

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <chrono>
namespace ck_tile {
// event-based timer measuring elapsed device time on a stream.
// RAII: the hip events are created in the ctor and destroyed in the dtor.
// NOTE(review): copy/move operations are implicitly defaulted; copying an
// instance would double-destroy the events — consider deleting copy ops.
struct gpu_timer
{
CK_TILE_HOST gpu_timer()
{
HIP_CHECK_ERROR(hipEventCreate(&start_evt));
HIP_CHECK_ERROR(hipEventCreate(&stop_evt));
}
// noexcept(false): HIP_CHECK_ERROR presumably throws on failure — TODO confirm
CK_TILE_HOST ~gpu_timer() noexcept(false)
{
HIP_CHECK_ERROR(hipEventDestroy(start_evt));
HIP_CHECK_ERROR(hipEventDestroy(stop_evt));
}
// record the start event on stream s; device-wide sync first so previously
// enqueued work does not leak into the measurement
CK_TILE_HOST void start(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipDeviceSynchronize());
HIP_CHECK_ERROR(hipEventRecord(start_evt, s));
}
// record the stop event on stream s and wait for it to complete
CK_TILE_HOST void stop(const hipStream_t& s)
{
HIP_CHECK_ERROR(hipEventRecord(stop_evt, s));
HIP_CHECK_ERROR(hipEventSynchronize(stop_evt));
}
// return in ms (elapsed time between the recorded start/stop events)
CK_TILE_HOST float duration() const
{
float ms = 0;
HIP_CHECK_ERROR(hipEventElapsedTime(&ms, start_evt, stop_evt));
return ms;
}
private:
hipEvent_t start_evt, stop_evt;
};
// host-side wall-clock timer. like torch.utils.benchmark.Timer(), each timer
// callback synchronizes the device before taking its timestamp, so the
// measured interval includes all device work enqueued in between.
struct cpu_timer
{
    // sync the device, then take the start timestamp (stream arg unused)
    CK_TILE_HOST void start(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        start_tick = clock_type::now();
    }
    // sync the device, then take the stop timestamp (stream arg unused)
    CK_TILE_HOST void stop(const hipStream_t&)
    {
        HIP_CHECK_ERROR(hipDeviceSynchronize());
        stop_tick = clock_type::now();
    }
    // return in ms (elapsed wall-clock time between start() and stop())
    CK_TILE_HOST float duration() const
    {
        const std::chrono::duration<double, std::milli> elapsed = stop_tick - start_tick;
        return static_cast<float>(elapsed.count());
    }

    private:
    using clock_type = std::chrono::high_resolution_clock;
    clock_type::time_point start_tick;
    clock_type::time_point stop_tick;
};
} // namespace ck_tile

View File

@@ -23,13 +23,13 @@ VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
TOP_LEFT:
TOP_LEFT(but negative):
[0] 1 2 3 4 5
1 [0] 1 2 3 4
2 1 [0] 1 2 3
3 2 1 [0] 1 2
FROM_BOTTOM_RIGHT:
FROM_BOTTOM_RIGHT(but negative):
2 1 [0] 1 2 3
3 2 1 [0] 1 2
4 3 2 1 [0] 1
@@ -54,7 +54,7 @@ struct Alibi
index_t x_total_,
AlibiMode mode_ = AlibiMode::VERTICAL)
{
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope;
slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;
shift_left_up = [&]() {
if(RowMajor)

View File

@@ -76,7 +76,7 @@ struct FmhaFwdKernel
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
@@ -702,7 +702,7 @@ struct FmhaFwdKernel
else
{
return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
}
}
else

View File

@@ -18,10 +18,12 @@ struct FmhaFwdTilePartitioner
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
static constexpr const char* name = "shb";
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
@@ -51,4 +53,53 @@ struct FmhaFwdTilePartitioner
}
};
// backward-compatible alias: the pre-existing FmhaFwdTilePartitioner is the
// "shb" layout variant, so existing code keeps working under the new naming
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
// tile partitioner with an "hbs" grid layout:
//   grid.x = nhead, grid.y = batch, grid.z = (#seqlen_q tiles) * (#hdim_v tiles)
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
// short tag used in kernel names to identify this grid layout
static constexpr const char* name = "hbs";
// compute the 3-D launch grid for the given problem sizes (see layout above)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
{
// TODO: this may need tuning
return dim3(nhead_,
batch_size_,
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1));
}
// map the current block's 3-D id back to tile coordinates; returns the tuple
// (i_tile_m, i_tile_n, i_nhead, i_batch). blockIdx.z linearizes the m/n tile
// pair in row-major order (m outer, n inner).
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.z;
const index_t i_nhead = blockIdx.x;
const index_t i_batch = blockIdx.y;
// divmod helper: returns (dividend / divisor, dividend % divisor)
const auto f = [](index_t dividend, index_t divisor) {
index_t quotient = dividend / divisor;
index_t modulus = dividend - quotient * divisor;
return ck_tile::make_tuple(quotient, modulus);
};
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
}
};
} // namespace ck_tile