mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Post-merge fix of PR 1300 (#1313)
* add f8 gemm with multiD for both row/col wise * change compute_type to fp8 * changed tuning parameters in the example * add rcr example * post-merge fix * fix * reduce init range
This commit is contained in:
@@ -146,7 +146,7 @@ template <typename ALayout,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ADataType,
|
||||
typename LDSTypeB = BDataType>
|
||||
struct GridwiseGemm_xdl_cshuffle_v3
|
||||
struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
|
||||
Number<AK0Number * MLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(Number<mpair>{}),
|
||||
make_pass_through_transform(AK1Number)),
|
||||
@@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
|
||||
Number<BK0Number * NLdsLayer>{})),
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(
|
||||
Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
@@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(Number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(
|
||||
make_xor_with_modulo_transform(
|
||||
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(Number<npair>{}),
|
||||
make_pass_through_transform(BK1Number)),
|
||||
|
||||
Reference in New Issue
Block a user