tempsave. compile pass, function wrong

This commit is contained in:
aska-0096
2025-05-20 10:57:26 +00:00
parent f3a296bad4
commit e1084fe7d6
12 changed files with 5418 additions and 620 deletions

View File

@@ -10,6 +10,7 @@ add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
set(FP4_MXGEMM_OPTIONS)
@@ -19,4 +20,5 @@ set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
list(APPEND FP8_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0)
target_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
target_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
target_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})

View File

@@ -0,0 +1,585 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ck::type_convert;
struct ExecutionConfig final
{
int do_verification = 1; // (0=no, 1=CPU)
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
};
struct ProblemSizeSplitK final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;
ck::index_t KBatch = 1;
};
bool parse_cmd_args(int argc,
char* argv[],
ProblemSizeSplitK& problem_size,
ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 5)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.verbosity = std::stoi(argv[4]);
}
else if(argc >= 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
config.verbosity = std::stoi(argv[4]);
problem_size.M = std::stoi(argv[5]);
problem_size.N = std::stoi(argv[6]);
problem_size.K = std::stoi(argv[7]);
problem_size.StrideA = std::stoi(argv[8]);
problem_size.StrideB = std::stoi(argv[9]);
problem_size.StrideC = std::stoi(argv[10]);
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
}
}
else
{
std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl
<< "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
<< "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl;
return false;
}
return true;
}
#if 1
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[(outputIndex+1)/2] = src[(n * K + k + 1)/2];
}
}
}
#endif
template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename XPackedDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementOp,
typename BElementOp,
typename CElementOp,
typename AccDataType,
typename CShuffleDataType,
ck::index_t ScaleBlockSize>
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
{
auto M = problem_size.M;
auto N = problem_size.N;
auto K = problem_size.K;
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1});
}
else
{
return HostTensorDescriptor({row, col}, {1, stride});
}
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<ck::index_t>(col);
}
else
{
return static_cast<ck::index_t>(row);
}
}
else
return static_cast<ck::index_t>(stride);
};
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
// Hardcode scale layouts as per pipeline assumptions
// TODO: Allow user to specify scale layouts
using AScaleLayout = Row;
using BScaleLayout = Col;
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_preshuffled(
f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<XDataType> b_k_n_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Tensor<XDataType> a_shuffled_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<XDataType> b_shuffled_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host
if(config.verbosity >= 0)
{
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl;
}
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else
return ck::type_convert<BDataType>(x);
};
switch(config.init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<BDataType>{b_data_element(1.0f)}(b_k_n);
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
ck::utils::FillConstant<BDataType>{b_data_element(1.0f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
break;
case 3:
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
break;
case 4:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
break;
default:
if(config.verbosity > 0)
{
std::cout << "NOTE: No input data initialization." << std::endl;
}
}
#if 1
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(
a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<BLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
int NPerXdl = 16; // Fixed 16
preShuffleBuffer(b_k_n.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl);
#endif
printf("b:\n");
for(ck::index_t i = 0; i < N; i++)
{
for(ck::index_t j = 0; j < K; j+=2)
{
printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n(j, i)));
if ( j %32 == 31)
{
printf("\n");
}
}
printf("\n");
}
// printf("b_scale:\n");
// for(ck::index_t i = 0; i < N; i++)
// {
// for(ck::index_t j = 0; j < K / ScaleBlockSize; j++)
// {
// // // b_k_n_scale(j, i) =
// // // ck::type_convert<XDataType>(static_cast<float>(powf(2.0f, (j / 4) % 4)));
// // b_k_n_scale(j, i) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
// // b_shuffled_scale(j, i) =ck::type_convert<XDataType>(static_cast<float>(1.0f));
// printf("%02x ", *reinterpret_cast<uint8_t*>(&b_k_n_scale(j, i)));
// }
// printf("\n");
// }
// printf("a_shuffled_scale:\n");
// for(ck::index_t i = 0; i < M * K / ScaleBlockSize; i++)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&(a_shuffled_scale.mData.data()[i])));
// if(i % 64 == 63)
// printf("\n");
// }
// printf("b_shuffled_scale:\n");
// for(ck::index_t i = 0; i < N * K / ScaleBlockSize; i++)
// {
// printf("%02x ", *reinterpret_cast<uint8_t*>(&(b_shuffled_scale.mData.data()[i])));
// if(i % 64 == 63)
// printf("\n");
// }
if(config.verbosity > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize());
if(config.verbosity > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
b_device_buf.ToDevice(b_preshuffled.mData.data());
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
if(config.verbosity > 0)
std::cout << "Done." << std::endl;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// run GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error("wrong!\n"
"Provided combination of compilation and runtime parameters is "
"not consistent with the supported device_gemm arguments.");
}
if(config.verbosity > 0)
{
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
bool res_verified = true;
if(config.do_verification > 0)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(config.verbosity > 0)
{
std::cout << "Done." << std::endl;
std::cout << "Computing GEMM on host..." << std::endl;
}
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm<ADataType,
BDataType,
CDataType,
AccDataType,
XDataType,
PassThrough,
PassThrough,
PassThrough,
float,
float>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
b_k_n_scale,
c_m_n_host_result,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
if(config.verbosity > 0)
{
std::cout << "Done." << std::endl;
std::cout << "Comparing results..." << std::endl;
}
// if(config.init_method == 0)
// {
// auto expected = static_cast<float>(K);
// auto computed = type_convert<float>(c_m_n_device_result(1, 12));
// res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
// std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
// << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
// << std::endl;
// }
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!");
if(config.verbosity > 0 && res_verified)
std::cout << "Verification Successful!" << std::endl;
}
else
{
if(config.verbosity > 0)
std::cout << "Done." << std::endl;
}
if(config.time_kernel)
{
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
return res_verified;
}
template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename XPackedDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename AElementOp,
typename BElementOp,
typename CElementOp,
typename AccDataType,
typename CShuffleDataType,
ck::index_t MXVectorSize>
bool run_mx_gemm_example(int argc, char* argv[])
{
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) &&
run_mx_gemm<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
MXVectorSize>(problem_size, config);
}

View File

@@ -0,0 +1,105 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_bpreshuffle_common.hpp"
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
// using ADataType = ck::f4_t;
// using BDataType = ck::f4_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// AB DataType: f4x2_pk_t
// Mathmatically, all numbers are represented as f4x2.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
64, // BlockSize: Thread block size
32, // MPerBlock
32, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<8, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -123,6 +123,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
using Base::I0;
using Base::I1;
using Base::I2;
using Base::KGroup;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -139,10 +140,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::MWaves;
static constexpr index_t PrefetchStages = 2;
@@ -156,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
constexpr index_t K0 = KRepeat * KGroup;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
@@ -184,298 +181,137 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename Stage>
__device__ static constexpr auto HotLoopScheduler(Stage stage)
__device__ static constexpr auto HotLoopScheduler()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
? HotLoopInstList::A_LDS_Read_Inst_Num
: HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
// constexpr auto num_dsread_a_mfma =
// (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
constexpr auto num_total_stages = MRepeat;
// Group num_mfma_perstage num_ds_read_a_perstage
// since we want to reuse a local register buffer
constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
(num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_stages_more =
(num_buffer_load_inst_a + num_buffer_load_inst_b) -
math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
(num_total_stages - 2)) *
((num_total_stages - 2));
constexpr auto buffer_load_b_stages =
buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
? num_buffer_load_inst_b / buffer_load_perstage_more
: (buffer_load_stages_more +
(num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
buffer_load_perstage_less);
constexpr auto buffer_load_a_stages =
num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
num_mfma_perstage / buffer_load_perstage_less;
constexpr auto ds_write_issue_point = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
// B global read
static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
if constexpr(((i < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_b)) ||
((i >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_b)))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 2)
{
constexpr auto staged_num_mfma_per_buffer_load_a =
math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
}
template <typename Stage>
__device__ static constexpr auto EpilogueScheduler_1(Stage stage)
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
{
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
#if 0
constexpr auto staged_num_ds_write_a_per_ds_read_a =
num_ds_write_inst_a / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a;
// A local write
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) {
ignore = idswrite_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
});
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
#elif 1
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
else
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
#endif
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(
0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
}
__device__ static constexpr auto EpilogueScheduler_2()
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
// A global read + A local write
static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
ds_write_issue_point)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
ds_write_issue_point)))
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_more ==
buffer_load_issue_point_a)) ||
(((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
(imfma % buffer_load_issue_point_interval_less ==
buffer_load_issue_point_a)))
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
// lds synchronization, prefetch next loop local A
static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
});
});
#endif
}
template <bool HasMainLoop,
@@ -528,22 +364,26 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
// // Global prefetch A2
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(I0, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(I0, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, 2, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
});
// Initialize C
@@ -558,26 +398,18 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
else if constexpr(m0.value == 1)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
}
else if constexpr(m0.value == 2)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
}
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
@@ -613,49 +445,88 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value == MRepeat - 1)
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(
Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
HotLoopScheduler(m0);
});
HotLoopScheduler();
};
LoopFunc(I0, I1);
@@ -667,20 +538,14 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
// tail
if constexpr(TailNum == TailNumber::Even)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
if constexpr(m0.value == 0)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
}
else if constexpr(m0.value == MRepeat - 1)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
}
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
@@ -707,36 +572,72 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value == MRepeat - 1)
if constexpr(m0.value == (MRepeat - 2))
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else if constexpr(m0.value == (MRepeat - 1))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
else
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
I0,
I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
}
EpilogueScheduler_1(m0);
});
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -764,25 +665,31 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value != (MRepeat - 1))
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
I0,
I0,
k0,
I0,
Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
}
});
HotLoopScheduler();
// Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
@@ -813,18 +720,21 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
});
});
if constexpr(m0.value != (MRepeat - 1))
if constexpr(m0.value < (MRepeat - 2))
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(Number<m0 + 1>{}, I0, I0, k0, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(
Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
a_thread_buf);
});
});
EpilogueScheduler_2();
}
});
}
@@ -841,7 +751,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
ComputeDataType,
decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
decltype(a_thread_desc_),
Sequence<1, 1, 1, 1, 1, KPack>,
Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -857,4 +767,4 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
using Base::c_thread_desc_;
};
} // namespace ck
} // namespace ck

View File

@@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp"
namespace ck {
/**
* @brief Define matrix data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
using U = element_type_t<T>;
return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
}
/**
* @brief Define scale data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
* @brief Combination of data types that have hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
// must be used for compute
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
// Hardware MX GEMM pipeline
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
}
}
} // namespace ck

View File

@@ -163,7 +163,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
static constexpr auto ScalesPerXdlopsRun =
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
@@ -202,9 +203,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
? HotLoopInstList::B_LDS_Read_Inst_Num
: HotLoopInstList::B_LDS_Read_Inst_Num / 2;
// constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
// constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
@@ -231,29 +229,28 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// stage 1
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_buffer_load_total = num_buffer_load_inst_a+num_buffer_load_inst_b+num_buffer_load_a_scale+num_buffer_load_b_scale;
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale + num_buffer_load_b_scale;
constexpr auto mfma_perstage_more = math::integer_divide_ceil(
num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less = math::integer_divide_floor(
num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_more =
math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less =
math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_stages_more =
num_mfma_stage1 -
mfma_perstage_less * num_buffer_load_total;
// constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
// constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
if constexpr(i< mfma_stages_more){
if constexpr(i < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else{
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -263,14 +260,16 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
if constexpr((i+num_buffer_load_inst_a)< mfma_stages_more){
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else{
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -280,14 +279,16 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) {
if constexpr((i+num_buffer_load_inst_a+num_buffer_load_inst_b)< mfma_stages_more){
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else{
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -297,14 +298,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
});
static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) {
if constexpr((i+num_buffer_load_inst_a+num_buffer_load_inst_b+num_buffer_load_a_scale)< mfma_stages_more){
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
num_buffer_load_a_scale) < mfma_stages_more)
{
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else{
else
{
static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
ignore = imfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
@@ -463,47 +467,50 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops/APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I0),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I0),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
@@ -529,8 +536,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// __builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
a_blockwise_copy.Run(
a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf));
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf));
// Prefetch a_scales
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
@@ -613,10 +622,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
vector_type<ComputeTypeA, KPack>
a_thread_vec;
vector_type<ComputeTypeB, KPack>
b_thread_vec;
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(
@@ -682,52 +689,54 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
// __builtin_amdgcn_s_waitcnt(3952);
// block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops/APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk *
xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(scale_mem_buf),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0,
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(scale_mem_buf),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk *
xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(scale_mem_buf),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
static_for<0,
xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(scale_mem_buf),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
@@ -859,50 +868,54 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_s_waitcnt(3952);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops/APackedSize * (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
(APackedSize * KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_bufs(I1),
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_bufs(I1),
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});

View File

@@ -0,0 +1,637 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// clang-format off
/**
* \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types
*
* This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for
* microscale-compliant data types.
*
* Assumptions:
* - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification
* - Each scale applies to ScaleBlockSize elements in K direction
* - A scale matrix is a row-major
* - B scale matrix is a column-major
* - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the
* exponent will be interpreted as conventional biased Float32 exponent (E8M0)
*
* Tunable parameters.
* The CK instance includes a series of tunable template parameters to control the parallel
* granularity of the workload to achieve load balancing on different hardware platforms. These
* parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc.
* - Block Size determines the number of threads in the thread block.
* - M/N/K Per Block determines the size of tile that each thread block is responsible for
* calculating.
* - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA)
* instructions operating on a per-wavefront basis.
* - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To
* achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B
* loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will
* result in compilation errors.
*
* Conditions for achieving computational load balancing on different hardware platforms can vary.
*
* Serialized version of the algorithm:
* \code
* // E = A * B + C
* // Loop over E[MPerBlock,NPerBlock] tiles
* for(int mb = 0; mb < M; mb += MPerBlock){
* for(int nb = 0; nb < N; nb += NPerBlock){
* // initialize E[MPerBlock,NPerBlock] tile
* for(int mt = mb; mt < mb + MPerBlock; mt++){
* for(int nt = nb; nt < nb + NPerBlock; nt++){
* E[mt,nt] = C[mt,nt];
* }
* }
*
* // multiply-accumulate per tile
* for(int kb = 0; kb < K; kb += KPerBlock){
* for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){
* for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){
* for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){
* for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){
* for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){
* // MFMA accumulation
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){
* // MFMA instruction
* for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){
* for(int m = mw; m < mw + MPerXDL; m++){
* for(int n = nw; n < nw + NPerXDL; n++){
* for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){
* E[m,n] += A[m,k] * B[k,n];
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* \endcode
*
*/
// clang-format on
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename GemmAccDataType, // TODO: always float
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t ScaleBlockSize, // Scaling block size
index_t BlockSize, // Thread block size
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA =
ADataType, // XXX: These should always be the same as ADataType and BDataType
typename ComputeTypeB =
BDataType // TODO: Hardcode them and remove from the list of template parameters
>
struct DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockSize,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
// TODO: Check if this is the right algorithm for minimum_occupancy
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 2
: 1
: 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
#if 0
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
#endif
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
else
{
throw std::runtime_error("wrong! BlkGemmPipelineVer");
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3){
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
static_assert(is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>(),
"Only microscaling formats are supported for ADataType and BDataType");
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
static_assert(is_same_v<ComputeTypeA, ADataType> && is_same_v<ComputeTypeB, BDataType>,
"ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType");
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(!IsValidCompilationParameter())
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const AScaleDataType* p_a_scale,
const BDataType* p_b,
const BScaleDataType* p_b_scale,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_a_scale,
p_b,
p_b_scale,
p_c,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideScaleA,
ck::index_t StrideB,
ck::index_t StrideScaleB,
ck::index_t StrideC,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMX_Xdl_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride << ", "
<< "ScaleBlockSize: "
<< ScaleBlockSize;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -200,7 +200,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk/APackedSize);
is_scale_mfma>::selected_mfma.k_per_blk /
APackedSize);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -270,13 +271,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
constexpr auto permuted_desc = transform_tensor_descriptor(
constexpr auto permuted_desc = transform_tensor_descriptor(
TileDesc_K0_MN_K1{},
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
make_pass_through_transform(Number<K1>{})),
make_pass_through_transform(Number<K1>{})),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
return transform_tensor_descriptor(
permuted_desc,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
@@ -361,24 +362,25 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// not pad M or K
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_mraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(K/KPerBlock, AK0Number, AK1Value)),
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
const auto a_grid_desc_permuted = transform_tensor_descriptor(
a_grid_desc_ak0_m_ak1,
make_tuple(make_pass_through_transform(K/KPerBlock),
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(M, AK0Number)),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto a_grid_desc = transform_tensor_descriptor(
a_grid_desc_permuted,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K/KPerBlock, AK0Number)),
make_pass_through_transform(M),
make_pass_through_transform(AK1Value)),
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)),
make_pass_through_transform(M),
make_pass_through_transform(AK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
@@ -467,25 +469,27 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
{
// not pad N or K
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_nraw_kraw,
make_tuple(make_unmerge_transform(make_tuple(K/KPerBlock, BK0Number, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
b_grid_desc_nraw_kraw,
make_tuple(
make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
const auto b_grid_desc_permuted = transform_tensor_descriptor(
b_grid_desc_bk0_n_bk1,
make_tuple(make_pass_through_transform(K/KPerBlock),
make_tuple(make_pass_through_transform(K / KPerBlock),
make_xor_with_modulo_transform(make_tuple(N, BK0Number)),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
const auto b_grid_desc = transform_tensor_descriptor(
b_grid_desc_permuted,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K/KPerBlock, BK0Number)),
make_pass_through_transform(N),
make_pass_through_transform(BK1Value)),
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)),
make_pass_through_transform(N),
make_pass_through_transform(BK1Value)),
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
@@ -690,10 +694,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
bool is_reduce_ = false)
: Problem{M_,
N_,
K_/APackedSize,
StrideA_/APackedSize,
K_ / APackedSize,
StrideA_ / APackedSize,
StrideScaleA_,
StrideB_/BPackedSize,
StrideB_ / BPackedSize,
StrideScaleB_,
StrideC_,
k_batch_},
@@ -765,21 +769,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize);
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize) * karg.StrideScaleA;
a_scale_k_split_offset =
k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_scale_k_split_offset = k_id * (karg.KRead / (ScaleBlockSize/BPackedSize)) * karg.StrideScaleB;
b_scale_k_split_offset =
k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/BPackedSize);
b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
}
if(k_id < (karg.KBatch - 1))
@@ -817,7 +823,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
return make_naive_tensor_descriptor(
make_tuple(Number<AK0Number>{}, Number<MPerBlock>{}, AK1Number),
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
}
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
// in some cases.
@@ -1119,7 +1124,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % (ScaleBlockSize/BPackedSize) == 0,
static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
"KPerBlock should be multiple of ScaleBlockSize");
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
@@ -1444,7 +1449,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
@@ -1470,12 +1474,11 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
sizeof(ADataType)),
reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
a_block_space_size_aligned * sizeof(ADataType)),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
@@ -1821,15 +1824,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// A/B shuffled scale for better 8-bit scale access pattern
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(make_tuple(
problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize/APackedSize)) / (KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(make_tuple(
problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize/BPackedSize)) / (KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),
@@ -1945,7 +1950,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
@@ -2312,13 +2316,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const ADataType* p_a_grid,
const AScaleDataType* p_a_scale_grid,
const BDataType* p_b_grid,
const BScaleDataType* p_b_scale_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem)
const AScaleDataType* p_a_scale_grid,
const BDataType* p_b_grid,
const BScaleDataType* p_b_scale_grid,
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
@@ -2332,36 +2336,38 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// A/B shuffled scale for better 8-bit scale access pattern
// MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(make_tuple(
problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize/APackedSize)) / (KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.M / (MXdlPack * MPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
(KXdlPack * 64 / MPerXdl),
64 * KXdlPack * MXdlPack / scale_pack_size_a));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(make_tuple(
problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize/BPackedSize)) / (KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
make_tuple(problem.N / (NXdlPack * NPerXdl),
math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
(KXdlPack * 64 / NPerXdl),
64 * KXdlPack * NXdlPack / scale_pack_size_b));
Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
decltype(a_scale_grid_desc_am_ak),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_a_scale_grid,
p_b_grid,
p_b_scale_grid,
p_c_grid,
p_shared_0,
p_shared_1,
problem,
a_grid_desc_ak0_m_ak1,
a_scale_grid_desc_am_ak,
b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock);
decltype(a_scale_grid_desc_am_ak),
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_scale_grid_desc_bn_ak),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_a_scale_grid,
p_b_grid,
p_b_scale_grid,
p_c_grid,
p_shared_0,
p_shared_1,
problem,
a_grid_desc_ak0_m_ak1,
a_scale_grid_desc_am_ak,
b_grid_desc_bk0_n_bk1,
b_scale_grid_desc_bn_ak,
c_grid_desc_mblock_mperblock_nblock_nperblock);
}
};

View File

@@ -1260,7 +1260,7 @@ struct ThreadwiseTensorSliceTransfer_v4
}
});
}
#if 0
// Fuse scale
template <typename SrcRefToOriginDisplacement,
typename DstOriginIdx,
@@ -1460,7 +1460,7 @@ struct ThreadwiseTensorSliceTransfer_v4
}
});
}
#endif
template <typename SrcSliceMoveStepIdx>
__device__ void MoveSrcSliceWindow(const SrcDesc&,
const SrcSliceMoveStepIdx& src_slice_move_step_idx)

View File

@@ -1020,7 +1020,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
// constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
// static_assert(bytes_per_thread == dword_bytes);