HIP: Heterogeneous-computing Interface for Portability
amd_hip_cooperative_groups.h
/*
Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
*/

#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if __cplusplus
#if !defined(__HIPCC_RTC__)
#endif

namespace cooperative_groups {

class thread_group {
 protected:
  uint32_t _type;  // thread_group type
  uint32_t _size;  // total number of threads in the thread_group
  uint64_t _mask;  // Lane mask for coalesced and tiled partitioned group types;
                   // LSB represents lane 0, and MSB represents lane 63

  // Construct a thread group, and set the thread group type and other essential
  // thread group properties. This generic thread group is directly constructed
  // only when the group is supposed to contain only the calling thread
  // (through the API `this_thread()`); in all other cases, this thread
  // group object is a sub-object of some other derived thread group object.
  __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size = static_cast<uint32_t>(0),
                                uint64_t mask = static_cast<uint64_t>(0)) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
    unsigned int meta_group_rank;
    unsigned int meta_group_size;
  };

  struct _coalesced_info {
    lane_mask member_mask;
    unsigned int size;
    struct _tiled_info tiled_info;
  } coalesced_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;

 public:
  // Total number of threads in the thread group. This also serves all derived
  // cooperative group types, since their size is saved directly at construction.
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is this cooperative group type valid?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the thread group
  __CG_QUALIFIER__ void sync() const;
};
class multi_grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Construct a multi-grid thread group (through the API this_multi_grid())
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}

 public:
  // Number of invocations participating in this multi-grid group. In other
  // words, the number of GPUs.
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of this invocation. In other words, an ID number within the range
  // [0, num_grids()) identifying the GPU this kernel is running on.
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};

__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}
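
// Usage sketch (illustrative, not part of this header): a kernel that uses the
// multi-grid group to coordinate across all GPUs taking part in a launch made
// with hipLaunchCooperativeKernelMultiDevice. The kernel name and buffer layout
// below are hypothetical, and the buffer is assumed to be visible to every device.
//
//   __global__ void multi_gpu_step(float* per_gpu_out) {
//     cooperative_groups::multi_grid_group mg = cooperative_groups::this_multi_grid();
//     if (mg.thread_rank() == 0) {
//       per_gpu_out[mg.grid_rank()] = 1.0f;  // one slot per participating GPU
//     }
//     mg.sync();                             // every grid on every GPU arrives here
//     // per_gpu_out[0 .. mg.num_grids()-1] can now be read by any grid
//   }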

class grid_group : public thread_group {
  // Only these friend functions are allowed to construct an object of this class
  // and access its resources
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Construct grid thread group (through the API this_grid())
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}

 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};

__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }

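// Usage sketch (illustrative, not part of this header): grid-wide synchronization
// is only valid when the kernel is launched cooperatively, e.g. with
// hipLaunchCooperativeKernel. The kernel and buffer names are hypothetical.
//
//   __global__ void two_phase(float* data, float* scratch, unsigned int n) {
//     cooperative_groups::grid_group grid = cooperative_groups::this_grid();
//     unsigned int tid = grid.thread_rank();
//     if (tid < n) scratch[tid] = data[tid] * 2.0f;     // phase 1
//     grid.sync();                                      // all phase-1 writes are visible
//     if (tid < n) data[tid] = scratch[(tid + 1) % n];  // phase 2 reads phase-1 results
//   }
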
class thread_block : public thread_group {
  // Only these friend functions are allowed to construct an object of this
  // class and access its resources
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);
 protected:
  // Construct a workgroup thread group (through the API this_thread_block())
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}

  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    // Invalid tile size, assert
    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
    tiledGroup.coalesced_info.tiled_info.size = tile_size;
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
    tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size;
    return tiledGroup;
  }

 public:
  // 3-dimensional block index within the grid
  __CG_STATIC_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_STATIC_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_STATIC_QUALIFIER__ uint32_t thread_rank() { return internal::workgroup::thread_rank(); }
  __CG_STATIC_QUALIFIER__ uint32_t size() { return internal::workgroup::size(); }
  __CG_STATIC_QUALIFIER__ bool is_valid() { return internal::workgroup::is_valid(); }
  __CG_STATIC_QUALIFIER__ void sync() { internal::workgroup::sync(); }
  __CG_QUALIFIER__ dim3 group_dim() { return internal::workgroup::block_dim(); }
};

__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}

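// Usage sketch (illustrative, not part of this header): this_thread_block()
// gives cooperative-groups access to the current workgroup. The kernel name
// and the fixed shared-memory size below are hypothetical.
//
//   __global__ void block_sum(const float* in, float* per_block_sum) {
//     __shared__ float tmp[256];  // assumes the kernel is launched with 256 threads per block
//     cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
//     unsigned int local = block.thread_rank();
//     unsigned int global = block.group_index().x * block.group_dim().x +
//                           block.thread_index().x;
//     tmp[local] = in[global];
//     block.sync();               // equivalent to __syncthreads()
//     if (local == 0) {
//       float s = 0.0f;
//       for (unsigned int i = 0; i < block.size(); ++i) s += tmp[i];
//       per_block_sum[block.group_index().x] = s;
//     }
//   }
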
class tiled_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);

  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
      __hip_assert(false && "invalid tile size");
    }

    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.coalesced_info.tiled_info.is_tiled = true;
    return tiledGroup;
  }

 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    coalesced_info.tiled_info.size = tileSize;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const { return (coalesced_info.tiled_info.size); }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const {
    internal::tiled_group::sync();
  }
};

class coalesced_group : public thread_group {
 private:
  friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent,
                                                          unsigned int tile_size);

  __CG_QUALIFIER__ coalesced_group new_tiled_group(unsigned int tile_size) const {
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);

    if (!tile_size || (tile_size > size()) || !pow2) {
      return coalesced_group(0);
    }

    // If a tiled group is passed to be partitioned further into a coalesced_group,
    // prepare a mask for further partitioning it so that it stays coalesced.
    if (coalesced_info.tiled_info.is_tiled) {
      unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
      unsigned int masklength = min(static_cast<unsigned int>(size()) - base_offset, tile_size);
      lane_mask member_mask = static_cast<lane_mask>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);

      member_mask <<= (__lane_id() & ~(tile_size - 1));
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.is_tiled = true;
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
      return coalesced_tile;
    }
    // Here the parent coalesced_group is not partitioned.
    else {
      lane_mask member_mask = 0;
      unsigned int tile_rank = 0;
      int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;

      for (unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
        // Use a lane_mask-wide shift so the test is well defined for lanes >= 32
        lane_mask active = coalesced_info.member_mask & (static_cast<lane_mask>(1) << i);
        // Make sure the lane is active
        if (active) {
          if (lanes_to_skip <= 0 && tile_rank < tile_size) {
            // Prepare a member_mask that is appropriate for a tile
            member_mask |= active;
            tile_rank++;
          }
          lanes_to_skip--;
        }
      }
      coalesced_group coalesced_tile = coalesced_group(member_mask);
      coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
      coalesced_tile.coalesced_info.tiled_info.meta_group_size =
          (size() + tile_size - 1) / tile_size;
      return coalesced_tile;
    }
    return coalesced_group(0);
  }

 protected:
  // Constructor
  explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
      : thread_group(internal::cg_coalesced_group) {
    coalesced_info.member_mask = member_mask;                    // Which threads are active
    coalesced_info.size = __popcll(coalesced_info.member_mask);  // How many threads are active
    coalesced_info.tiled_info.is_tiled = false;                  // Not a partitioned group
    coalesced_info.tiled_info.meta_group_rank = 0;
    coalesced_info.tiled_info.meta_group_size = 1;
  }

 public:
  __CG_QUALIFIER__ unsigned int size() const {
    return coalesced_info.size;
  }

  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
  }

  __CG_QUALIFIER__ void sync() const {
    internal::coalesced_group::sync();
  }

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }

  template <class T>
  __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    srcRank = srcRank % static_cast<int>(size());

    int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
        : (__AMDGCN_WAVEFRONT_SIZE == 64) ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
                                          : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }
    else {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }

  template <class T>
  __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");

    // Note: The CUDA implementation appears to use the remainder of lane_delta
    // and WARP_SIZE as the shift value rather than lane_delta itself.
    // This is not described in the documentation and is not done here.

    if (size() == __AMDGCN_WAVEFRONT_SIZE) {
      return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
    }

    int lane;
    if (__AMDGCN_WAVEFRONT_SIZE == 64) {
      lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }
    else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
      lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
    }

    if (lane == -1) {
      lane = __lane_id();
    }

    return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
  }
};

__CG_QUALIFIER__ coalesced_group coalesced_threads() {
  return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
}

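// Usage sketch (illustrative, not part of this header): warp-aggregated atomics
// with coalesced_threads(). Inside a divergent branch, only the leader of the
// currently active lanes performs the atomic, and the result is broadcast to
// the rest of the group. The kernel and counter names are hypothetical.
//
//   __global__ void filter(const int* in, int* out, int* counter, unsigned int n) {
//     unsigned int i = blockIdx.x * blockDim.x + threadIdx.x;
//     if (i < n && in[i] > 0) {                        // divergent predicate
//       cooperative_groups::coalesced_group active = cooperative_groups::coalesced_threads();
//       int base = 0;
//       if (active.thread_rank() == 0) {
//         base = atomicAdd(counter, active.size());    // one atomic per coalesced group
//       }
//       base = active.shfl(base, 0);                   // broadcast from the leader lane
//       out[base + active.thread_rank()] = in[i];      // compacted write
//     }
//   }
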
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->thread_rank());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return -1;
    }
  }
}

__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    case internal::cg_coalesced_group: {
      return (static_cast<const coalesced_group*>(this)->is_valid());
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
      return false;
    }
  }
}

__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    case internal::cg_coalesced_group: {
      static_cast<const coalesced_group*>(this)->sync();
      break;
    }
    default: {
      __hip_assert(false && "invalid cooperative group type");
    }
  }
}

template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }
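
// Usage sketch (illustrative, not part of this header): the free functions above
// let device code stay generic over the concrete group type. The helper name is
// hypothetical.
//
//   template <class Group>
//   __device__ void scale(const Group& g, float* data, unsigned int n, float s) {
//     // Each member of the group handles a strided slice of the data.
//     for (unsigned int i = cooperative_groups::thread_rank(g); i < n;
//          i += cooperative_groups::group_size(g)) {
//       data[i] *= s;
//     }
//     cooperative_groups::sync(g);  // every member sees the scaled data afterwards
//   }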

template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the thread within this tile
  _CG_STATIC_CONST_DECL_ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within this tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};

template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() {
    internal::tiled_group::sync();
  }

  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};

template <unsigned int tileSize, typename ParentCGTy>
class parent_group_info {
 public:
  // Returns the linear rank of the group within the set of tiles partitioned
  // from a parent group (bounded by meta_group_size)
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_rank() {
    return ParentCGTy::thread_rank() / tileSize;
  }

  // Returns the number of groups created when the parent group was partitioned.
  __CG_STATIC_QUALIFIER__ unsigned int meta_group_size() {
    return (ParentCGTy::size() + tileSize - 1) / tileSize;
  }
};

template <unsigned int tileSize, class ParentCGTy>
class thread_block_tile_type : public thread_block_tile_base<tileSize>,
                               public tiled_group,
                               public parent_group_info<tileSize, ParentCGTy> {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;
  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;
};

// Partial template specialization
template <unsigned int tileSize>
class thread_block_tile_type<tileSize, void> : public thread_block_tile_base<tileSize>,
                                               public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type(unsigned int meta_group_rank, unsigned int meta_group_size)
      : tiled_group(numThreads) {
    coalesced_info.tiled_info.size = numThreads;
    coalesced_info.tiled_info.is_tiled = true;
    coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
    coalesced_info.tiled_info.meta_group_size = meta_group_size;
  }

 public:
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;

  __CG_QUALIFIER__ unsigned int meta_group_rank() const {
    return coalesced_info.tiled_info.meta_group_rank;
  }

  __CG_QUALIFIER__ unsigned int meta_group_size() const {
    return coalesced_info.tiled_info.meta_group_size;
  }
// end of cooperative group
};

__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  }
  else if (parent.cg_type() == internal::cg_coalesced_group) {
    const coalesced_group* cg = static_cast<const coalesced_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  }
  else {
    const thread_block* tb = static_cast<const thread_block*>(&parent);
    return tb->new_tiled_group(tile_size);
  }
}

// Thread block type overload
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

// If a coalesced group is passed to be partitioned, it should remain coalesced
__CG_QUALIFIER__ coalesced_group tiled_partition(const coalesced_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}

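// Usage sketch (illustrative, not part of this header): the run-time tiled_partition
// overloads return a generic thread_group, so only size(), thread_rank() and sync()
// are available on the result. The kernel name is hypothetical.
//
//   __global__ void tile_demo(unsigned int* out) {
//     cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_group tile8 = cooperative_groups::tiled_partition(block, 8);
//     out[block.thread_rank()] = tile8.thread_rank();  // 0..7, repeating across the block
//     tile8.sync();
//   }
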
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;

template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl

template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};

template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 protected:
 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};

template <unsigned int size, class ParentCGTy = void> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> struct tiled_partition_internal;

template <unsigned int size>
struct tiled_partition_internal<size, thread_block> : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};

}  // namespace impl

template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size is currently not supported");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
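
// Usage sketch (illustrative, not part of this header): compile-time tiles give
// access to the tile-wide shuffle operations. A power-of-two tree reduction over
// a 16-lane tile; the kernel name is hypothetical.
//
//   __global__ void tile_reduce(const float* in, float* out) {
//     cooperative_groups::thread_block block = cooperative_groups::this_thread_block();
//     auto tile = cooperative_groups::tiled_partition<16>(block);
//     float v = in[block.thread_rank()];
//     for (unsigned int offset = tile.size() / 2; offset > 0; offset /= 2) {
//       v += tile.shfl_down(v, offset);   // tile-local tree reduction
//     }
//     if (tile.thread_rank() == 0) {
//       out[tile.meta_group_rank()] = v;  // one partial sum per tile
//     }
//   }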
}  // namespace cooperative_groups

#endif  // __cplusplus
#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
Device-side implementation of the cooperative groups feature.