diff --git a/python/mscclpp/language/program.py b/python/mscclpp/language/program.py index 1baf0de2..681a2fe3 100644 --- a/python/mscclpp/language/program.py +++ b/python/mscclpp/language/program.py @@ -431,6 +431,12 @@ class Ref(ChunkRef): def chunk(rank, buffer, index, size=1) -> Ref: + if buffer is Buffer.scratch: + if buffer not in _curr().buffers[rank]: + _curr().buffers[rank][buffer] = BufferSlice(Buffer.scratch, buffer) + if index >= len(_curr().buffers[rank][buffer]): + _curr().buffers[rank][buffer][index] = ChunkRef(rank, buffer, index, size) + if _curr().buffers[rank][buffer][index] is None: return None return _curr().get_ref(rank, buffer, index, size)