Fix XOR transform dimension order.

This commit is contained in:
mhYang
2025-03-21 15:00:05 +00:00
parent 1f604e9b0a
commit 8f3b534d29

View File

@@ -27,7 +27,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPack = 8;
#if BANK_CONFLICT_K_FIRST
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
@@ -42,7 +41,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_K_FIRST
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -57,7 +55,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_MN_FIRST
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -72,7 +69,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif XOR
#pragma message ("BANK_CONFLICT: XOR")
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto DataTypeSize = sizeof(ADataType);
@@ -98,9 +94,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<MLdsLayer>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(number<MLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kMPerBlock / MLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
@@ -110,7 +106,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(number<kMPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
@@ -126,7 +122,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr index_t kKPack = 8;
#if BANK_CONFLICT_K_FIRST
#pragma message ("BANK_CONFLICT: K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
@@ -141,7 +136,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_K_FIRST
#pragma message ("BANK_CONFLICT: PADDING_K_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -156,7 +150,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif PADDING_MN_FIRST
#pragma message ("BANK_CONFLICT: PADDING_MN_FIRST")
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
@@ -171,7 +164,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(sequence<0>{}, sequence<1>{}));
#elif XOR
#pragma message ("BANK_CONFLICT: XOR")
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr auto DataTypeSize = sizeof(BDataType);
@@ -197,9 +189,9 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<NLdsLayer>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
@@ -209,7 +201,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif