mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Ck tile gemm padding dim (#1516)
* Support the N dimension padding
* Finished the padding feature for different dimension of K
[ROCm/composable_kernel commit: 694c300145]
This commit is contained in:
@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
|
||||
|
||||
std::cout << "The overall perfomance of the GEMM with "
|
||||
<< "[" << data_type << "]"
|
||||
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K
|
||||
<< "is: \n";
|
||||
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n"
|
||||
<< "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
|
||||
<< " is: \n";
|
||||
std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
|
||||
<< std::flush;
|
||||
|
||||
return ave_time;
|
||||
@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = false;
|
||||
constexpr bool kPadC = true;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
@@ -348,7 +348,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
|
||||
|
||||
std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail")
|
||||
std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
|
||||
<< std::flush;
|
||||
}
|
||||
|
||||
|
||||
@@ -123,14 +123,26 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadA ? 1 : 0 > {});
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadB ? 1 : 0 > {});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
{i_n, 0});
|
||||
|
||||
// allocate LDS
|
||||
@@ -163,12 +175,16 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto CBlockWindow = make_tile_window(
|
||||
auto c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadC ? 1 : 0 > {});
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
// epilogue.
|
||||
EpiloguePipeline{}(CBlockWindow, acc);
|
||||
EpiloguePipeline{}(CBlockWindow_pad, acc);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t AlignmentB = Problem::AlignmentB;
|
||||
static constexpr index_t AlignmentC = Problem::AlignmentC;
|
||||
|
||||
static constexpr bool kPadA = Problem::kPadA;
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(
|
||||
|
||||
@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
|
||||
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
|
||||
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
|
||||
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
|
||||
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user