mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Ck tile gemm padding dim (#1516)
* Support the N dimension padding
* Finished the padding feature for the different dimensions of K
This commit is contained in:
@@ -123,14 +123,26 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadA ? 1 : 0 > {});
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadB ? 1 : 0 > {});
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
{i_n, 0});
|
||||
|
||||
// allocate LDS
|
||||
@@ -163,12 +175,16 @@ struct GemmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
auto CBlockWindow = make_tile_window(
|
||||
auto c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadC ? 1 : 0 > {});
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
// epilogue.
|
||||
EpiloguePipeline{}(CBlockWindow, acc);
|
||||
EpiloguePipeline{}(CBlockWindow_pad, acc);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr index_t AlignmentB = Problem::AlignmentB;
|
||||
static constexpr index_t AlignmentC = Problem::AlignmentC;
|
||||
|
||||
static constexpr bool kPadA = Problem::kPadA;
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(
|
||||
|
||||
@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1;
|
||||
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1;
|
||||
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1;
|
||||
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
|
||||
static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user