Grouped conv bwd weight with grouped gemm (#2304)

* Grouped conv bwd weight with grouped gemm

* fixes

* fix

* Fixes

* test comments

* restore atol

* fix
This commit is contained in:
Bartłomiej Kocot
2025-06-12 10:15:07 +02:00
committed by GitHub
parent 8aff45a8af
commit bb4f471b09
5 changed files with 242 additions and 153 deletions

View File

@@ -271,6 +271,7 @@ struct BlockToCTileMap_Grouped_M00_N0_M01Adapt
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default;
__host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M,
index_t N,
index_t M01 = 8)
@@ -870,6 +871,7 @@ struct OffsettedBlockToCTileMap
{
using underlying_type = UnderlyingBlockToCTileMap;
__host__ __device__ OffsettedBlockToCTileMap() = default;
__host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
index_t block_start)
{