mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
fix tail
This commit is contained in:
@@ -66,32 +66,36 @@ struct MultiplyMultiply
|
||||
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
|
||||
const int NRepeat = 1;
|
||||
const int KRepeat = 4;
|
||||
const int NWave = 4;
|
||||
const int KLane = 2;
|
||||
const int NLane = 128;
|
||||
const int NLane = 32;
|
||||
const int KPack = 16;
|
||||
int N0 = N / (NRepeat * NLane);
|
||||
int N0 = N / (NRepeat * NLane * NWave);
|
||||
int K0 = K / (KRepeat * KLane * KPack);
|
||||
|
||||
int tempn, tempk;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K; ++k) {
|
||||
int n0 = n / (NRepeat * NLane);
|
||||
int n0 = n / (NRepeat * NLane * NWave);
|
||||
int k0 = k / (KRepeat * KLane * KPack);
|
||||
tempn = n % (NRepeat * NLane);
|
||||
tempn = n % (NRepeat * NLane * NWave);
|
||||
tempk = k % (KRepeat * KLane * KPack);
|
||||
int n1 = tempn / NLane;
|
||||
int n1 = tempn / (NLane * NWave);
|
||||
int k1 = tempk / (KLane * KPack);
|
||||
int n2 = n1 % NLane;
|
||||
tempn = tempn % (NLane * NWave);
|
||||
tempk = tempk % (KLane * KPack);
|
||||
int n2 = tempn / NLane;
|
||||
int k2 = tempk / KPack;
|
||||
int n3 = tempn % NLane;
|
||||
int k3 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0
|
||||
+ k0 * KPack * NLane * KLane * KRepeat * NRepeat
|
||||
+ n1 * KPack * NLane * KLane * KRepeat
|
||||
+ k1 * KPack * NLane * KLane
|
||||
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
|
||||
+ k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
|
||||
+ n1 * KPack * NLane * KLane * NWave * KRepeat
|
||||
+ k1 * KPack * NLane * KLane * NWave
|
||||
+ n2 * KPack * NLane * KLane
|
||||
+ k2 * KPack * NLane
|
||||
+ n2 * KPack
|
||||
+ n3 * KPack
|
||||
+ k3;
|
||||
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
@@ -269,7 +273,7 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 1});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
|
||||
@@ -357,7 +357,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
|
||||
// if(threadIdx.x==0) {
|
||||
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), ype_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
|
||||
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
|
||||
// }
|
||||
});
|
||||
|
||||
@@ -451,6 +451,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
|
||||
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
|
||||
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
@@ -462,6 +467,48 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
xdlops_gemm.Run(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
|
||||
a_block_buf1,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec =
|
||||
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
Reference in New Issue
Block a user