diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp index 84e47483de..bafff4447d 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v1.hpp @@ -72,7 +72,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; - static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t kLdsAlignmentInBytes = 16; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr auto I0 = number<0>(); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 2cb1e22ea0..692e4b4218 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -81,7 +81,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadK = Problem::kPadK; - static constexpr index_t kLdsAlignmentInBytes = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t kLdsAlignmentInBytes = 16; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; static constexpr auto I0 = number<0>(); @@ -107,7 +107,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = 16 / sizeof(ADataType); + static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; static constexpr auto TailNum = Problem::TailNum; diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp index 9ac45b9a9a..b8df644b34 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v3.hpp @@ -107,7 +107,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV3 static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; - static constexpr index_t K1 = VectorLoadSize / sizeof(ADataType); + static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp; static constexpr index_t ACopyPerLoadM = kMPerBlock / ACopyLoadNum;