support flatmm scaling

This commit is contained in:
Feng Shijie
2025-07-23 19:04:22 +00:00
parent 3f7d848dd3
commit 5a1183ebbd
7 changed files with 476 additions and 318 deletions

View File

@@ -23,9 +23,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
@@ -81,13 +84,13 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
@@ -217,6 +220,7 @@ int run_flatmm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int scale_opt = arg_parser.get_int("scale");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
@@ -231,13 +235,29 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "fp8")
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else
{

View File

@@ -83,10 +83,10 @@ struct FlatmmConfig16
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 1;
};
template <typename ADataType>
@@ -167,120 +167,6 @@ struct is_8bit_type
{
};
template <typename DataType>
struct GemmConfig
{
#if defined(USING_MFMA_16x16x128_F8) //MI350 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#elif defined(USING_MFMA_32x32x64_F8) //MI350 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_16x16x32_F16) //MI350 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x16_F16) //MI350 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#elif defined(USING_MFMA_16x16x32_F8) //MI300 FP8 16X16
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 64;
#elif defined(USING_MFMA_32x32x16_F8) //MI300 FP8 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_16x16x16_F16) //MI300 FP16 16X16 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
#elif defined(USING_MFMA_32x32x8_F16) //MI300 FP16 32X32 (need tune)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
#endif
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -301,6 +187,7 @@ auto create_args(int argc, char* argv[])
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");

View File

@@ -18,7 +18,7 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
@@ -83,9 +83,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);
template <typename FlatmmConfig,
typename ADataType,
@@ -97,6 +100,8 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
@@ -108,21 +113,25 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
@@ -134,6 +143,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
BLayout,
DsLayout,
CLayout,
ScaleM,
ScaleN,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
@@ -154,6 +165,8 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -197,21 +210,30 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
// TODO: add different init types
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else
{
@@ -222,14 +244,25 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
c_rslt_host.SetZero();
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
// do pre-shuffle
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
@@ -239,18 +272,22 @@ int run_flatmm_example_with_layouts(int argc,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr)>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
@@ -263,6 +300,8 @@ int run_flatmm_example_with_layouts(int argc,
if(arg_parser.get_int("v") == 1)
{
assert(ScaleGranularityM == -1 && ScaleGranularityN == -1 &&
"ScaleAB is not supported for CPU verification!");
ck_tile::HostTensor<CDataType> c_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_ref_host.SetZero();
@@ -310,13 +349,41 @@ int run_flatmm_example_with_layouts(int argc,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
{
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
}
else
{
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A,
d_B,
d_C,
M,
N,
K,
stride_A,
stride_B,
stride_C,
ScaleGranularityM,
ScaleGranularityN,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
}
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
d_C,

View File

@@ -165,6 +165,9 @@ struct sequence
return sequence<Is..., Xs...>{};
}
CK_TILE_HOST_DEVICE static constexpr auto sum() { return (Is + ... + 0); }
CK_TILE_HOST_DEVICE static constexpr auto product() { return (Is * ... * 1); }
// pickup element at index <Ids...>
template <index_t... Ids>
CK_TILE_HOST_DEVICE static constexpr auto extract(number<Ids>...)
@@ -1236,9 +1239,8 @@ constexpr auto reverse_slice_sequence(Seq,
template <typename Seq,
index_t SliceSize,
typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
number<SliceSize>,
Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
constexpr auto
slice_sequence(Seq, number<SliceSize>, Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
constexpr auto r =
reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());

View File

@@ -195,6 +195,104 @@ __global__ void naive_gemm_kernel(ADataType* A,
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void blockwise_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC,
ck_tile::index_t scale_granularity_m,
ck_tile::index_t scale_granularity_n,
ck_tile::index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
{
AccDataType acc = 0.0, acc_temp = 0.0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
float scale_A = 0;
float scale_B = 0;
for(int k = 0; k < K; ++k)
{
if(k % scale_granularity_k == 0)
{
// update acc
acc += acc_temp * scale_A * scale_B;
acc_temp = 0.0;
// update scale factors
scale_A = scale_A_ptr[(row / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B = scale_B_ptr[(col / scale_granularity_n) +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc_temp += v_a * v_b;
}
// final accumulation
acc += acc_temp * scale_A * scale_B;
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -223,6 +321,51 @@ void reference_gemm_gpu(ADataType* a_ptr,
return;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_a,
stride_b,
stride_c,
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_A_ptr,
scale_B_ptr);
return;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -260,4 +403,5 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
return;
}
} // namespace ck_tile

View File

@@ -282,8 +282,8 @@ struct CShuffleEpilogue
{0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
@@ -334,8 +334,8 @@ struct CShuffleEpilogue
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
@@ -360,7 +360,12 @@ struct CShuffleEpilogue
}
});
}
template <typename ODramWindow, typename OAccTile, typename DsDramWindows, typename ScaleM, typename ScaleN>
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM,
typename ScaleN>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
@@ -368,118 +373,133 @@ struct CShuffleEpilogue
ScaleM scale_m,
ScaleN scale_n)
{
// const index_t iMWarp = get_warp_id() / kNWave;
// const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
// const index_t iMLane = get_lane_id() / NPerXdl;
// const index_t iNLane = get_lane_id() % NPerXdl;
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
// constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
// auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
// constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
// auto o_lds_block = make_tensor_view<address_space_enum::lds>(
// static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
// auto in_lds_window = make_tile_window(
// o_lds_block,
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
// {0, 0},
// LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
// auto out_lds_window = make_tile_window(
// o_lds_block,
// make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
// {0, 0});
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
// using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
// sequence<0, 1>,
// sequence<MPerIterationShuffle, NPerIterationShuffle>>;
// constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
// static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
// "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
// using TileEncodingPattern =
// TileDistributionEncodingPattern2D<kBlockSize,
// MPerIterationShuffle,
// NPerIterationShuffle,
// GetVectorSizeC(),
// tile_distribution_pattern::thread_raked,
// Problem::kNumWaveGroups>;
// constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
// auto d_dram_windows = generate_tuple(
// [&](auto idx) {
// return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
// },
// number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr int kM2 = 4; // Val
constexpr int kM1 = (64 / NPerXdl); // Thr
constexpr int kM0 = MPerXdl / kM1; // Val
// static_for<0, num_access, 1>{}([&](auto iAccess) {
// block_sync_lds();
// constexpr auto idx_y_start = SFC::get_index(iAccess);
const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
const index_t iMLane = get_lane_id() / NPerXdl;
const index_t iNLane = get_lane_id() % NPerXdl;
// constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
// constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
constexpr auto idx_y_start = SFC::get_index(iAccess);
// lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
// merge_sequences(
// sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
// c_warp_y_index_zeros),
// merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
// c_warp_y_lengths));
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
// const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
// store_tile(in_lds_window, c_warptile_in_tensor_casted);
// block_sync_lds();
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
float scale_B =
scale_n[nIter * NPerIterationShuffle +
iNWarp * NumNXdlPerWavePerShuffle * NPerXdl + n_xdl * NPerXdl + iNLane];
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl * NumMXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
// auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// auto m1 = iMLane;
// float scale_B = scale_n[nIter * NPerIterationShuffle];
// static_for<0, kM0, 1>{}([&](auto m0) {
// static_for<0, kM2, 1>{}([&](auto m2) {
// float scale_A = scale_m[mIter * MPerIterationShuffle + iMWarp * MPerXdl +
// m0 * kM1 * kM2 + m1 * kM2 + m2];
// c_out_tensor.get_thread_buffer()[m0 * kM2 + m2] *= scale_A * scale_B;
// });
// });
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
float scale_A =
scale_m[mIter * MPerIterationShuffle +
iMWarp * NumMXdlPerWavePerShuffle * MPerXdl +
m_xdl * MPerXdl + m0 * kM1 * kM2 + iMLane * kM2 + m2];
lds_tile.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
scale_A * scale_B;
});
});
});
});
// const auto ds_tensor = generate_tuple(
// [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
// const auto c_ds_tiles = concat_tuple_of_reference(
// tie(c_out_tensor, c_out_tensor),
// generate_tie(
// [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
// tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
// if constexpr(MemoryOperation == memory_operation_enum::set)
// {
// store_tile(out_dram_window, c_out_tensor);
// }
// else
// {
// update_tile(out_dram_window, c_out_tensor);
// }
// if constexpr(iAccess != num_access - 1)
// {
// constexpr auto step = SFC::get_forward_step(iAccess);
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
// move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
// static_for<0, NumDTensor, 1>{}([&](auto idx) {
// move_tile_window(d_dram_windows[idx],
// {step.at(number<0>{}), step.at(number<1>{})});
// });
// }
// });
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}
};
} // namespace ck_tile

View File

@@ -102,17 +102,17 @@ struct BaseFlatmmHostArgs
{
CK_TILE_HOST BaseFlatmmHostArgs() = default;
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
@@ -151,35 +151,49 @@ struct BaseFlatmmHostArgs
index_t k_batch;
};
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>, index_t NumDTensor = 0>
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
{
CK_TILE_HOST ScaleFlatmmHostArgs() = default;
CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
const void* b_shuffle_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_C_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: BaseFlatmmHostArgs(a_ptr_, b_shuffle_ptr_, ds_ptr_, c_ptr_, k_batch_, M_, N_, K_, stride_A_, stride_B_, stride_Ds_, stride_C_),
scale_m(scale_m_),
scale_n(scale_n_)
const void* b_shuffle_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_C_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: BaseFlatmmHostArgs(a_ptr_,
b_shuffle_ptr_,
ds_ptr_,
c_ptr_,
k_batch_,
M_,
N_,
K_,
stride_A_,
stride_B_,
stride_Ds_,
stride_C_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};
template <int NumberTensor=0>
using FlatmmHostArgs = ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
template <int NumberTensor = 0>
using FlatmmHostArgs =
ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
struct FlatmmKernelArgs
@@ -278,7 +292,8 @@ struct FlatmmKernel
struct SplitKBatchOffset
{
template <class KernelArgs>
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) {
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
@@ -681,16 +696,17 @@ struct FlatmmKernel
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
CK_TILE_DEVICE static void
RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -712,19 +728,21 @@ struct FlatmmKernel
if constexpr(ScaleM::granularity != -1 || ScaleN::granularity != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}
@@ -755,15 +773,15 @@ struct FlatmmKernel
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
};