mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Padding support for wave transfer (#3537)
* Add padding support with transpose. Also move the check before writing, storing is_src_valid during reading.
* Add/modify instances to use wave transfer for gemm universal. The condition is changed so that the vector size of VMEM reading and LDS writing must now both be equal to 8 in order to use the wave transfer.
* Fix clang format
* Modify example
* Fix bwd data
* Add restriction for wave transfer with padding and transpose. Add a test case which shows this limitation.
* Fix validity checks for 8-bit types
* Add validity check for gemm_bias_add_reduce
* Add validity check for grouped gemm tile loop
* Fix validity checks for new flavours
* Minor fixes
* Fix clang format
This commit is contained in:
@@ -19,22 +19,22 @@ using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
PassThrough, PassThrough, PassThrough, GemmDefault,
|
||||
PassThrough, PassThrough, PassThrough, GemmSpec,
|
||||
256,
|
||||
128, 256, 64,
|
||||
8, 8,
|
||||
16, 16,
|
||||
2, 8,
|
||||
S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 8, 8, 1,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
|
||||
1, 1, 8, 1,
|
||||
1, 8, 8, 1,
|
||||
1, 1,
|
||||
S<1, 64, 1, 4>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>;
|
||||
|
||||
Reference in New Issue
Block a user