Merge commit 'a7da3c68b979bd46c315da09208271d26f5e2900' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-01 23:11:22 +00:00
parent 24ac4febf4
commit ba74f76cbc
17 changed files with 952 additions and 76 deletions

View File

@@ -76,6 +76,7 @@ struct GemmConfigMemory : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool Persistent = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
@@ -116,6 +117,7 @@ struct GemmConfigV4 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;

View File

@@ -182,9 +182,9 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
<< std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 /* + 256 * i */);
Ns.push_back(256 /* + 512 * i */);
Ks.push_back(64 /* + 384 * i */);
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 384 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
@@ -256,8 +256,8 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
ck_tile::FillUniformDistribution<D0DataType>{2.f, -2.f}(d0_m_n_tensors[i]);
ck_tile::FillUniformDistribution<D1DataType>{2.f, -2.f}(d1_m_n_tensors[i]);
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n_tensors[i]);
ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(a_m_k_tensors[i]));