slot 0.0.1
A real time UI render framework
载入中...
搜索中...
未找到
slang-cuda-prelude.h
浏览该文件的文档.
1#ifndef SLANG_CUDA_PRELUDE_H
2#define SLANG_CUDA_PRELUDE_H
3
// Export decoration for prelude symbols; it expands to nothing here.
#define SLANG_PRELUDE_EXPORT

// SLANG_CUDA_RTC is 1 when the prelude is compiled by NVRTC (signalled by __CUDACC_RTC__),
// and 0 otherwise (e.g. offline nvcc compilation).
#if defined(__CUDACC_RTC__)
#define SLANG_CUDA_RTC 1
#else
#define SLANG_CUDA_RTC 0
#endif
11
12#if SLANG_CUDA_RTC
13
14#else
15
16#include <cstdint>
17#include <stdio.h>
18
19#endif
20
21// Define SLANG_CUDA_ENABLE_HALF to use the cuda_fp16 include to add half support.
22// For this to work NVRTC needs to have the path to the CUDA SDK.
23//
// As it stands, the include paths defined for Slang are passed down to NVRTC. Similarly, the
// defines set for the Slang compile are passed down.
26
27#ifdef SLANG_CUDA_ENABLE_HALF
28// We don't want half2 operators from cuda_fp16.h (comparison returns bool). Arithmetic for
29// __half2 is defined in the macro SLANG_CUDA_VECTOR_FLOAT_OP_HALF2 below (CUDA intrinsics).
30#define __CUDA_NO_HALF2_OPERATORS__
31#include <cuda_fp16.h>
32#endif
33
34#ifdef SLANG_CUDA_ENABLE_FP8
35#include <cuda_fp8.h>
36#endif
37
38#ifdef SLANG_CUDA_ENABLE_BF16
39#include <cuda_bf16.h>
40#endif
41
42#ifdef SLANG_CUDA_ENABLE_OPTIX
43#include <optix.h>
44#endif
45
// Define slang offsetof implementation (can be overridden by defining SLANG_OFFSET_OF first).
// NOTE(review): this is the classic null-pointer offsetof idiom, which is formally undefined
// behavior in ISO C++. It is presumably used because no standard headers (and so no
// `offsetof`) are included when compiling under NVRTC -- confirm before replacing it.
#ifndef SLANG_OFFSET_OF
#define SLANG_OFFSET_OF(type, member) (size_t)((char*)&(((type*)0)->member) - (char*)0)
#endif

// The product must be large enough to overflow a double and therefore produce infinity;
// the result is then converted to float.
#ifndef SLANG_INFINITY
#define SLANG_INFINITY ((float)(1e+300 * 1e+300))
#endif

// For now we'll disable any asserts in this prelude: SLANG_PRELUDE_ASSERT expands to nothing,
// so its argument is never evaluated.
#define SLANG_PRELUDE_ASSERT(x)

#ifndef SLANG_CUDA_WARP_SIZE
#define SLANG_CUDA_WARP_SIZE 32
#endif

// Used for masking threadIdx.x to the warp lane index. This is only a valid lane mask when
// SLANG_CUDA_WARP_SIZE is a power of two.
#define SLANG_CUDA_WARP_MASK (SLANG_CUDA_WARP_SIZE - 1)
// All bits set (~0 as an int).
#define SLANG_CUDA_WARP_BITMASK (~int(0))

// Function qualifiers used throughout this prelude; each is defined exactly once.
// NOTE(review): SLANG_FORCE_INLINE maps to plain `inline`, which is only a hint;
// `__forceinline__` would be needed to actually force inlining in device code.
#define SLANG_FORCE_INLINE inline
#define SLANG_INLINE inline

// Qualifier for functions callable from device code.
#define SLANG_CUDA_CALL __device__
74
75
// Since we are using unsigned arithmetic, care is needed in this comparison.
// It is *assumed* that sizeInBytes >= elemSize, which means (sizeInBytes - elemSize) >= 0,
// which means only a single test is needed.

// Asserts for bounds checking.
// It is assumed index/count are unsigned types. Every macro argument is parenthesized so that
// expression arguments such as `a + b` keep their meaning after substitution. The byte-address
// variant additionally requires the index to be a multiple of 4.
#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT((index) < (count));
#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_PRELUDE_ASSERT((index) <= ((sizeInBytes) - (elemSize)) && ((index) & 3) == 0);

// Macros to zero index if an access is out of range. `index` must be an assignable lvalue;
// an in-range index is left unchanged.
#define SLANG_BOUND_ZERO_INDEX(index, count) (index) = ((index) < (count)) ? (index) : 0;
#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    (index) = ((index) <= ((sizeInBytes) - (elemSize))) ? (index) : 0;
90
// The 'FIX' macros define how an out-of-range index is fixed. The default is to do nothing.
// If SLANG_ENABLE_BOUND_ZERO_INDEX is defined, the fix macros zero the index when it is out of
// range (via SLANG_BOUND_ZERO_INDEX / SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS).
#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
#define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
#define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
// Zeroing an out-of-range index is idempotent, so a single application is sufficient.
#define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
#else
#define SLANG_BOUND_FIX(index, count)
#define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
#define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
#endif

// Default bounds-check macros: an assert followed by the configured fix. Each one can be
// overridden by defining it before this point.
#ifndef SLANG_BOUND_CHECK
#define SLANG_BOUND_CHECK(index, count) \
    SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
#endif

#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
#define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
#endif

#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
#define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) \
    SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
#endif
120
// SLANG_CUDA_BOUNDARY_MODE controls how out-of-range surface coordinates are handled.
// It can be one of:
//   cudaBoundaryModeClamp - out-of-range coordinates are clamped to the valid range;
//   cudaBoundaryModeZero  - out-of-range reads return zero and out-of-range writes are ignored;
//   cudaBoundaryModeTrap  - out-of-range accesses cause the kernel execution to fail.
#ifndef SLANG_CUDA_BOUNDARY_MODE
#define SLANG_CUDA_BOUNDARY_MODE cudaBoundaryModeZero
#endif

// SLANG_PTX_BOUNDARY_MODE is the PTX counterpart. It only applies to *PTX* emitted CUDA
// operations, which currently is just RWTextureRW format writes:
//   .trap  causes an execution trap on out-of-bounds addresses;
//   .clamp stores data at the nearest surface location (sized appropriately);
//   .zero  drops stores to out-of-bounds addresses.
// It has its own guard, so overriding either boundary-mode macro on its own still leaves a
// valid default for the other.
#ifndef SLANG_PTX_BOUNDARY_MODE
#define SLANG_PTX_BOUNDARY_MODE "zero"
#endif

// Describes a type by its size.
struct TypeInfo
{
    // Size of the described type -- presumably in bytes (sizeof); confirm against users.
    size_t typeSize;
};
145
146template<typename T, size_t SIZE>
147struct FixedArray
148{
149 SLANG_CUDA_CALL const T& operator[](size_t index) const
150 {
152 return m_data[index];
153 }
155 {
157 return m_data[index];
158 }
159
160 T m_data[SIZE];
161};
162
163// An array that has no specified size, becomes a 'Array'. This stores the size so it can
164// potentially do bounds checking.
165template<typename T>
166struct Array
167{
168 SLANG_CUDA_CALL const T& operator[](size_t index) const
169 {
170 SLANG_BOUND_CHECK(index, count);
171 return data[index];
172 }
174 {
175 SLANG_BOUND_CHECK(index, count);
176 return data[index];
177 }
178
179 T* data;
180 size_t count;
181};
182
// Texture/surface object handles. These are typically defined in cuda.h, but we can't
// ship/rely on that, so they are defined here as unsigned long long handles.
using CUtexObject = unsigned long long;
using CUsurfObject = unsigned long long;

// On CUDA, sampler state is bound up with the texture object itself. A SamplerState type is
// still provided -- as a pointer to an incomplete type -- to simplify code generation. The
// downside is that such a binding takes up uniform space even though it has no effect.
// TODO(JS): Consider ways to strip use of variables of this type so they have no binding.
struct SamplerStateUnused;
using SamplerState = SamplerStateUnused*;
193
194
195// TODO(JS): Not clear yet if this can be handled on CUDA, by just ignoring.
196// For now, just map to the index type.
198
// Code generator will generate the specific type.
template<typename T, int ROWS, int COLS>
struct Matrix;

// Boolean vector types follow CUDA's builtin vector alignment rules: each boolN is aligned like
// the corresponding charN per the CUDA spec (char1/uchar1: 1 byte, char2/uchar2: 2 bytes,
// char3/uchar3: 1 byte, char4/uchar4: 4 bytes).
// operator[] treats the members as a contiguous array starting at x; idx is not range-checked.
struct __align__(1) bool1
{
    bool x;

    SLANG_FORCE_INLINE SLANG_CUDA_CALL bool& operator[](int idx) { return *(&x + idx); }
    SLANG_FORCE_INLINE SLANG_CUDA_CALL const bool& operator[](int idx) const { return *(&x + idx); }
};

struct __align__(2) bool2
{
    bool x;
    bool y;

    SLANG_FORCE_INLINE SLANG_CUDA_CALL bool& operator[](int idx) { return *(&x + idx); }
    SLANG_FORCE_INLINE SLANG_CUDA_CALL const bool& operator[](int idx) const { return *(&x + idx); }
};

struct __align__(1) bool3
{
    bool x;
    bool y;
    bool z;

    SLANG_FORCE_INLINE SLANG_CUDA_CALL bool& operator[](int idx) { return *(&x + idx); }
    SLANG_FORCE_INLINE SLANG_CUDA_CALL const bool& operator[](int idx) const { return *(&x + idx); }
};

struct __align__(4) bool4
{
    bool x;
    bool y;
    bool z;
    bool w;

    SLANG_FORCE_INLINE SLANG_CUDA_CALL bool& operator[](int idx) { return *(&x + idx); }
    SLANG_FORCE_INLINE SLANG_CUDA_CALL const bool& operator[](int idx) const { return *(&x + idx); }
};
262
264{
265 return (bool)(__ldg((const char*)ptr));
266}
267
269{
270 auto val = __ldg((const char2*)ptr);
271 return {val.x != 0, val.y != 0};
272}
273
275{
276 auto val = __ldg((const char4*)ptr);
277 return {val.x != 0, val.y != 0, val.z != 0, val.w != 0};
278}
279
#if SLANG_CUDA_RTC

// Under NVRTC this prelude includes no <cstdint>, so the fixed-width integer types are spelled
// out here.
using int8_t = signed char;
using int16_t = short;
using int32_t = int;
using int64_t = long long;
using intptr_t = ptrdiff_t;

using uint8_t = unsigned char;
using uint16_t = unsigned short;
using uint32_t = unsigned int;
using uint64_t = unsigned long long;
using uintptr_t = size_t;

using longlong = long long;
using ulonglong = unsigned long long;

#else

// When not using NVRTC, match the platform's int64_t definition for the signed type
// (on Linux int64_t is 'long', on Windows it is 'long long').
using longlong = int64_t;
// ulonglong must remain 'unsigned long long' to match CUDA's atomic operations.
using ulonglong = unsigned long long;

#endif

// Short unsigned aliases.
using uchar = unsigned char;
using ushort = unsigned short;
using uint = unsigned int;

#if SLANG_CUDA_ENABLE_HALF
using half = __half;
#endif

// Unions overlaying unsigned, signed and floating-point views of the same 32 / 64 bits
// (presumably used for bit-casting between them).
union Union32
{
    uint32_t u;
    int32_t i;
    float f;
};

union Union64
{
    uint64_t u;
    int64_t i;
    double d;
};
328
329template<typename T>
331{
332 return (float)val;
333}
334
336{
337 return ::fmodf(x, y);
338}
340{
341 return ::fmod(x, y);
342}
343
#if SLANG_CUDA_ENABLE_HALF

// Additional half vector types. The 2-component type is not defined here;
// the 1-, 3- and 4-component variants are. The 3- and 4-component ones are 4-byte aligned.
struct __half1
{
    __half x;
};
struct __align__(4) __half3
{
    __half x;
    __half y;
    __half z;
};
struct __align__(4) __half4
{
    __half x;
    __half y;
    __half z;
    __half w;
};
#endif

#if SLANG_CUDA_ENABLE_BF16

// Additional bfloat16 vector types (1-, 3- and 4-component; no explicit alignment).
struct __nv_bfloat161
{
    __nv_bfloat16 x;
};
struct __nv_bfloat163
{
    __nv_bfloat16 x;
    __nv_bfloat16 y;
    __nv_bfloat16 z;
};
struct __nv_bfloat164
{
    __nv_bfloat16 x;
    __nv_bfloat16 y;
    __nv_bfloat16 z;
    __nv_bfloat16 w;
};
#endif

#if SLANG_CUDA_ENABLE_FP8

// 1- to 4-component vector types for both fp8 formats (e4m3 and e5m2).
struct __nv_fp8_e4m31
{
    __nv_fp8_e4m3 x;
};
struct __nv_fp8_e4m32
{
    __nv_fp8_e4m3 x;
    __nv_fp8_e4m3 y;
};
struct __nv_fp8_e4m33
{
    __nv_fp8_e4m3 x;
    __nv_fp8_e4m3 y;
    __nv_fp8_e4m3 z;
};
struct __nv_fp8_e4m34
{
    __nv_fp8_e4m3 x;
    __nv_fp8_e4m3 y;
    __nv_fp8_e4m3 z;
    __nv_fp8_e4m3 w;
};
struct __nv_fp8_e5m21
{
    __nv_fp8_e5m2 x;
};
struct __nv_fp8_e5m22
{
    __nv_fp8_e5m2 x;
    __nv_fp8_e5m2 y;
};
struct __nv_fp8_e5m23
{
    __nv_fp8_e5m2 x;
    __nv_fp8_e5m2 y;
    __nv_fp8_e5m2 z;
};
struct __nv_fp8_e5m24
{
    __nv_fp8_e5m2 x;
    __nv_fp8_e5m2 y;
    __nv_fp8_e5m2 z;
    __nv_fp8_e5m2 w;
};
#endif
414
// _slang_vector_get_element(v, i): returns component i of a 1- to 4-component vector of T
// (passed by value). The vector is read as a contiguous array of T starting at its first
// member; i is not range-checked.
#define SLANG_VECTOR_GET_ELEMENT(T) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##1 x, int index) \
    { \
        return *(reinterpret_cast<T*>(&x) + index); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##2 x, int index) \
    { \
        return *(reinterpret_cast<T*>(&x) + index); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##3 x, int index) \
    { \
        return *(reinterpret_cast<T*>(&x) + index); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##4 x, int index) \
    { \
        return *(reinterpret_cast<T*>(&x) + index); \
    }

// _slang_vector_get_element_ptr(p, i): returns a pointer to component i of *p.
// NOTE: the parameter's const is cast away, so the returned pointer is writable even when the
// caller passed a pointer-to-const. i is not range-checked.
#define SLANG_VECTOR_GET_ELEMENT_PTR(T) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##1 * x, int index) \
    { \
        return const_cast<T*>(reinterpret_cast<const T*>(x)) + index; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##2 * x, int index) \
    { \
        return const_cast<T*>(reinterpret_cast<const T*>(x)) + index; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##3 * x, int index) \
    { \
        return const_cast<T*>(reinterpret_cast<const T*>(x)) + index; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##4 * x, int index) \
    { \
        return const_cast<T*>(reinterpret_cast<const T*>(x)) + index; \
    }
472
473#if SLANG_CUDA_ENABLE_HALF
476#endif
477
478#if SLANG_CUDA_ENABLE_BF16
479SLANG_VECTOR_GET_ELEMENT(__nv_bfloat16)
480SLANG_VECTOR_GET_ELEMENT_PTR(__nv_bfloat16)
481
483_slang_vector_dot(__nv_bfloat162 v0, __nv_bfloat162 v1)
484{
485 __nv_bfloat16 result = __nv_bfloat16(0.0f);
486 for (int i = 0; i < 2; i++)
487 {
489 }
490 return result;
491}
493_slang_vector_dot(__nv_bfloat163 v0, __nv_bfloat163 v1)
494{
495 __nv_bfloat16 result = __nv_bfloat16(0.0f);
496 for (int i = 0; i < 3; i++)
497 {
499 }
500 return result;
501}
503_slang_vector_dot(__nv_bfloat164 v0, __nv_bfloat164 v1)
504{
505 __nv_bfloat16 result = __nv_bfloat16(0.0f);
506 for (int i = 0; i < 4; i++)
507 {
509 }
510 return result;
511}
512#endif
513
#if SLANG_CUDA_ENABLE_FP8
// Component accessors for both fp8 formats: by-value getters first, then pointer getters.
SLANG_VECTOR_GET_ELEMENT(__nv_fp8_e4m3)
SLANG_VECTOR_GET_ELEMENT(__nv_fp8_e5m2)
SLANG_VECTOR_GET_ELEMENT_PTR(__nv_fp8_e4m3)
SLANG_VECTOR_GET_ELEMENT_PTR(__nv_fp8_e5m2)
#endif
520
// Component-wise operator generators for the CUDA vector types. Each macro expands to a free
// operator on T##n (for example int3) that applies `op` to every component, reading through
// _slang_vector_get_element and writing through _slang_vector_get_element_ptr.

// Binary operator whose result has the operand vector type (arithmetic, bitwise, logical).
#define SLANG_CUDA_VECTOR_BINARY_OP(T, n, op) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator op(T##n lhs, T##n rhs) \
    { \
        T##n out; \
        for (int k = 0; k < n; ++k) \
        { \
            *_slang_vector_get_element_ptr(&out, k) = \
                _slang_vector_get_element(lhs, k) op _slang_vector_get_element(rhs, k); \
        } \
        return out; \
    }

// Comparison operator: compares component-wise and returns a bool##n.
#define SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, op) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL bool##n operator op(T##n lhs, T##n rhs) \
    { \
        bool##n out; \
        for (int k = 0; k < n; ++k) \
        { \
            *_slang_vector_get_element_ptr(&out, k) = \
                (_slang_vector_get_element(lhs, k) op _slang_vector_get_element(rhs, k)); \
        } \
        return out; \
    }

// Prefix unary operator applied to every component.
#define SLANG_CUDA_VECTOR_UNARY_OP(T, n, op) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator op(T##n value) \
    { \
        T##n out; \
        for (int k = 0; k < n; ++k) \
        { \
            *_slang_vector_get_element_ptr(&out, k) = op _slang_vector_get_element(value, k); \
        } \
        return out; \
    }

// Full operator set for an integer vector of width n: arithmetic, bitwise, shifts, logical,
// comparisons, and the unary -, ~ and ! operators.
#define SLANG_CUDA_VECTOR_INT_OP(T, n) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, +) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, -) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, *) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, /) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, %) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, &) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, |) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, ^) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, <<) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, >>) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, &&) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, ||) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, ==) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, !=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >=) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, -) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, ~) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, !)

// Integer operator sets for widths 2, 3 and 4.
#define SLANG_CUDA_VECTOR_INT_OPS(T) \
    SLANG_CUDA_VECTOR_INT_OP(T, 2) \
    SLANG_CUDA_VECTOR_INT_OP(T, 3) \
    SLANG_CUDA_VECTOR_INT_OP(T, 4)
575
585
// Operator set for a floating-point vector of width n: arithmetic, logical, comparisons and
// unary minus. There are no bitwise operators and no %; floating-point % is generated
// separately by SLANG_CUDA_FLOAT_VECTOR_MOD.
#define SLANG_CUDA_VECTOR_FLOAT_OP(T, n) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, +) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, -) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, *) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, /) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, &&) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, ||) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, ==) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, !=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >=) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, -)

/* __half2 is special-cased: arithmetic calls the packed CUDA intrinsics (__hadd2, __hsub2,
   __hmul2, __h2div, __hneg2), giving one add.f16x2 per operation, whereas the generic
   per-component macro would give two add.f16. Comparison and logical operators stay
   component-wise because they produce bool2. */
#define SLANG_CUDA_VECTOR_FLOAT_OP_HALF2 \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& a, const __half2& b) \
    { \
        return __hadd2(a, b); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& a, const __half2& b) \
    { \
        return __hsub2(a, b); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& a, const __half2& b) \
    { \
        return __hmul2(a, b); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& a, const __half2& b) \
    { \
        return __h2div(a, b); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& v) \
    { \
        return __hneg2(v); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator+=(__half2& a, const __half2& b) \
    { \
        a = __hadd2(a, b); \
        return a; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator-=(__half2& a, const __half2& b) \
    { \
        a = __hsub2(a, b); \
        return a; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator*=(__half2& a, const __half2& b) \
    { \
        a = __hmul2(a, b); \
        return a; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator/=(__half2& a, const __half2& b) \
    { \
        a = __h2div(a, b); \
        return a; \
    } \
    SLANG_CUDA_VECTOR_BINARY_OP(__half, 2, &&) \
    SLANG_CUDA_VECTOR_BINARY_OP(__half, 2, ||) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, ==) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, !=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, <) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, <=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, >) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, >=)

/* Per-type expansion lists. Each list is spelled out explicitly so that NVRTC and all
 * compilers behave. SLANG_CUDA_VECTOR_FLOAT_OPS(T) pastes T onto the list name, so T must be
 * spelled exactly float, double or __half. The __half list routes width 2 through
 * SLANG_CUDA_VECTOR_FLOAT_OP_HALF2. */
#define SLANG_CUDA_VECTOR_FLOAT_OPS_float \
    SLANG_CUDA_VECTOR_FLOAT_OP(float, 2) \
    SLANG_CUDA_VECTOR_FLOAT_OP(float, 3) \
    SLANG_CUDA_VECTOR_FLOAT_OP(float, 4)
#define SLANG_CUDA_VECTOR_FLOAT_OPS_double \
    SLANG_CUDA_VECTOR_FLOAT_OP(double, 2) \
    SLANG_CUDA_VECTOR_FLOAT_OP(double, 3) \
    SLANG_CUDA_VECTOR_FLOAT_OP(double, 4)
#define SLANG_CUDA_VECTOR_FLOAT_OPS___half \
    SLANG_CUDA_VECTOR_FLOAT_OP_HALF2 \
    SLANG_CUDA_VECTOR_FLOAT_OP(__half, 3) \
    SLANG_CUDA_VECTOR_FLOAT_OP(__half, 4)
#define SLANG_CUDA_VECTOR_FLOAT_OPS(T) SLANG_CUDA_VECTOR_FLOAT_OPS_##T
666
669#if SLANG_CUDA_ENABLE_HALF
671#endif
// Floating-point vector %: the component-wise remainder computed by _slang_fmod
// (presumably the fmodf/fmod-backed overloads; confirm against their definitions).
#define SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, n) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator%(const T##n& left, const T##n& right) \
    { \
        T##n remainder; \
        for (int k = 0; k < n; ++k) \
        { \
            const T a = _slang_vector_get_element(left, k); \
            const T b = _slang_vector_get_element(right, k); \
            *_slang_vector_get_element_ptr(&remainder, k) = _slang_fmod(a, b); \
        } \
        return remainder; \
    }

// % operators for widths 2, 3 and 4.
#define SLANG_CUDA_FLOAT_VECTOR_MOD(T) \
    SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 2) \
    SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 3) \
    SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 4)
686
689
// SLANG_MAKE_VECTOR(T) generates component-wise constructors make_T2 / make_T3 / make_T4.
// It must be defined in every configuration that invokes it below. That includes the
// BF16-only and FP8-only configurations: with the guard limited to RTC || HALF, enabling bf16
// or fp8 under nvcc without half support left SLANG_MAKE_VECTOR undefined at its invocation.
#if SLANG_CUDA_RTC || SLANG_CUDA_ENABLE_HALF || SLANG_CUDA_ENABLE_BF16 || SLANG_CUDA_ENABLE_FP8
#define SLANG_MAKE_VECTOR(T) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x, T y) \
    { \
        return T##2 {x, y}; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x, T y, T z) \
    { \
        return T##3 {x, y, z}; \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x, T y, T z, T w) \
    { \
        return T##4 {x, y, z, w}; \
    }
#endif

// Under NVRTC the builtin make_* vector constructors are generated here for every basic type
// (presumably because NVRTC lacks the vector_functions header -- confirm).
#if SLANG_CUDA_RTC
SLANG_MAKE_VECTOR(int)
SLANG_MAKE_VECTOR(uint)
SLANG_MAKE_VECTOR(short)
SLANG_MAKE_VECTOR(ushort)
SLANG_MAKE_VECTOR(char)
SLANG_MAKE_VECTOR(uchar)
SLANG_MAKE_VECTOR(float)
SLANG_MAKE_VECTOR(double)
SLANG_MAKE_VECTOR(longlong)
SLANG_MAKE_VECTOR(ulonglong)
#endif

#if SLANG_CUDA_ENABLE_HALF
SLANG_MAKE_VECTOR(__half)
#endif

#if SLANG_CUDA_ENABLE_BF16
SLANG_MAKE_VECTOR(__nv_bfloat16)
#endif

#if SLANG_CUDA_ENABLE_FP8
SLANG_MAKE_VECTOR(__nv_fp8_e4m3)
SLANG_MAKE_VECTOR(__nv_fp8_e5m2)
#endif
731
733{
734 return bool1{x};
735}
737{
738 return bool2{x, y};
739}
// Component-wise constructors for the 3- and 4-component bool vectors.
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool3 make_bool3(bool x, bool y, bool z)
{
    bool3 packed = {x, y, z};
    return packed;
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool4 make_bool4(bool x, bool y, bool z, bool w)
{
    bool4 packed = {x, y, z, w};
    return packed;
}
749{
750 return bool2{x, x};
751}
753{
754 return bool3{x, x, x};
755}
757{
758 return bool4{x, x, x, x};
759}
760
// Broadcast constructors make_T2(x) / make_T3(x) / make_T4(x): replicate one scalar into every
// component by delegating to the component-wise make_TN overloads. Shared by both branches
// below.
#define SLANG_MAKE_VECTOR_FROM_SCALAR_2_TO_4(T) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) \
    { \
        return make_##T##2(x, x); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) \
    { \
        return make_##T##3(x, x, x); \
    } \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) \
    { \
        return make_##T##4(x, x, x, x); \
    }

#if SLANG_CUDA_RTC
// Under NVRTC the 1-component constructor make_T1(x) is generated as well.
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) \
    { \
        return T##1 {x}; \
    } \
    SLANG_MAKE_VECTOR_FROM_SCALAR_2_TO_4(T)
#else
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) SLANG_MAKE_VECTOR_FROM_SCALAR_2_TO_4(T)
#endif
804#if SLANG_CUDA_ENABLE_HALF
#if !SLANG_CUDA_RTC
// Wraps a half scalar in the 1-component __half1 struct. This is the non-NVRTC definition; the
// NVRTC variant of SLANG_MAKE_VECTOR_FROM_SCALAR generates make_T1 itself.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half1 make___half1(__half x)
{
    __half1 wrapped = {x};
    return wrapped;
}
#endif
812#endif
813#if SLANG_CUDA_ENABLE_BF16
#if !SLANG_CUDA_RTC
// Wraps a bfloat16 scalar in the 1-component __nv_bfloat161 struct, mirroring make___half1.
// It must return the struct (not the bare scalar):
// GetVectorTypeImpl<__nv_bfloat16, 1>::fromScalar returns __nv_bfloat161 by calling this
// function, and a bare __nv_bfloat16 cannot initialize that aggregate.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __nv_bfloat161 make___nv_bfloat161(__nv_bfloat16 x)
{
    return __nv_bfloat161{x};
}
#endif
821#endif
822
823#if SLANG_CUDA_ENABLE_FP8
#if !SLANG_CUDA_RTC
// Wrap an fp8 scalar in its 1-component vector struct (__nv_fp8_e4m31 / __nv_fp8_e5m21),
// mirroring make___half1. Returning the struct (not the bare scalar) is required:
// GetVectorTypeImpl<T, 1>::fromScalar returns T##1 by calling make_T##1.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __nv_fp8_e4m31 make___nv_fp8_e4m31(__nv_fp8_e4m3 x)
{
    return __nv_fp8_e4m31{x};
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL __nv_fp8_e5m21 make___nv_fp8_e5m21(__nv_fp8_e5m2 x)
{
    return __nv_fp8_e5m21{x};
}
#endif
836#endif
837
// Vector overload of a scalar atomic-style function Fn (presumably atomicAdd and similar):
// calls the scalar Fn once per component address and collects the per-component return
// values into a vector. The components are separate calls, so the vector as a whole is not
// updated as a single operation.
#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn, T, N) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T##N Fn(T##N* address, T##N val) \
    { \
        T##N results; \
        for (int k = 0; k < N; ++k) \
        { \
            auto slot = _slang_vector_get_element_ptr(address, k); \
            *_slang_vector_get_element_ptr(&results, k) = \
                Fn(slot, _slang_vector_get_element(val, k)); \
        } \
        return results; \
    }
847
848#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 900
851#endif
852#if defined(SLANG_CUDA_ENABLE_HALF) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
855#endif
866
867template<typename T, int n>
869{
870};
871
// Specializes GetVectorTypeImpl<T, n>: `type` names the T##n vector type and fromScalar(v)
// broadcasts a scalar into it through make_T##n(v).
#define GET_VECTOR_TYPE_IMPL(T, n) \
    template<> \
    struct GetVectorTypeImpl<T, n> \
    { \
        using type = T##n; \
        static SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n fromScalar(T scalar) \
        { \
            return make_##T##n(scalar); \
        } \
    };

// Specializations for widths 1 through 4.
#define GET_VECTOR_TYPE_IMPL_N(T) \
    GET_VECTOR_TYPE_IMPL(T, 1) \
    GET_VECTOR_TYPE_IMPL(T, 2) \
    GET_VECTOR_TYPE_IMPL(T, 3) \
    GET_VECTOR_TYPE_IMPL(T, 4)
887
899#if SLANG_CUDA_ENABLE_HALF
901#endif
#if SLANG_CUDA_ENABLE_BF16
// Vector-type mapping for bfloat16.
GET_VECTOR_TYPE_IMPL_N(__nv_bfloat16)
#endif
#if SLANG_CUDA_ENABLE_FP8
// Vector-type mappings for both fp8 formats.
GET_VECTOR_TYPE_IMPL_N(__nv_fp8_e5m2)
GET_VECTOR_TYPE_IMPL_N(__nv_fp8_e4m3)
#endif
909
910template<typename T, int n>
912
913template<typename T, int n, typename OtherT, int m>
915{
916 Vector<T, n> result;
917 for (int i = 0; i < n; i++)
918 {
919 OtherT otherElement = T(0);
920 if (i < m)
921 otherElement = _slang_vector_get_element(other, i);
922 *_slang_vector_get_element_ptr(&result, i) = (T)otherElement;
923 }
924 return result;
925}
926
927template<typename T, int ROWS, int COLS>
928struct Matrix
929{
930 Vector<T, COLS> rows[ROWS];
932 {
933 return rows[index];
934 }
935
937 {
938 return rows[index];
939 }
940};
941
942
943template<typename T, int ROWS, int COLS>
945{
947 for (int i = 0; i < ROWS; i++)
949 return result;
950}
951
952template<typename T, int ROWS, int COLS>
954{
956 result.rows[0] = row0;
957 return result;
958}
959
960template<typename T, int ROWS, int COLS>
962 const Vector<T, COLS>& row0,
963 const Vector<T, COLS>& row1)
964{
966 result.rows[0] = row0;
967 result.rows[1] = row1;
968 return result;
969}
970
971template<typename T, int ROWS, int COLS>
973 const Vector<T, COLS>& row0,
974 const Vector<T, COLS>& row1,
975 const Vector<T, COLS>& row2)
976{
978 result.rows[0] = row0;
979 result.rows[1] = row1;
980 result.rows[2] = row2;
981 return result;
982}
983
984template<typename T, int ROWS, int COLS>
986 const Vector<T, COLS>& row0,
987 const Vector<T, COLS>& row1,
988 const Vector<T, COLS>& row2,
989 const Vector<T, COLS>& row3)
990{
992 result.rows[0] = row0;
993 result.rows[1] = row1;
994 result.rows[2] = row2;
995 result.rows[3] = row3;
996 return result;
997}
998
999template<typename T, int ROWS, int COLS, typename U, int otherRow, int otherCol>
1001 const Matrix<U, otherRow, otherCol>& other)
1002{
1003 Matrix<T, ROWS, COLS> result;
1004 int minRow = ROWS;
1005 int minCol = COLS;
1006 if (minRow > otherRow)
1007 minRow = otherRow;
1008 if (minCol > otherCol)
1009 minCol = otherCol;
1010 for (int i = 0; i < minRow; i++)
1011 for (int j = 0; j < minCol; j++)
1012 *_slang_vector_get_element_ptr(result.rows + i, j) =
1013 (T)_slang_vector_get_element(other.rows[i], j);
1014 return result;
1015}
1016
1017template<typename T, int ROWS, int COLS>
1019{
1021 rs.rows[0].x = v0;
1022 rs.rows[0].y = v1;
1023 rs.rows[1].x = v2;
1024 rs.rows[1].y = v3;
1025 return rs;
1026}
1027
1028template<typename T, int ROWS, int COLS>
1030 T v0,
1031 T v1,
1032 T v2,
1033 T v3,
1034 T v4,
1035 T v5)
1036{
1038 if (COLS == 3)
1039 {
1040 *_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
1041 *_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
1042 *_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
1043 *_slang_vector_get_element_ptr(&rs.rows[1], 0) = v3;
1044 *_slang_vector_get_element_ptr(&rs.rows[1], 1) = v4;
1045 *_slang_vector_get_element_ptr(&rs.rows[1], 2) = v5;
1046 }
1047 else
1048 {
1049 rs.rows[0].x = v0;
1050 rs.rows[0].y = v1;
1051 rs.rows[1].x = v2;
1052 rs.rows[1].y = v3;
1053 rs.rows[2].x = v4;
1054 rs.rows[2].y = v5;
1055 }
1056 return rs;
1057}
1058
1059template<typename T, int ROWS, int COLS>
1061 T v0,
1062 T v1,
1063 T v2,
1064 T v3,
1065 T v4,
1066 T v5,
1067 T v6,
1068 T v7)
1069{
1071 if (COLS == 4)
1072 {
1073 *_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
1074 *_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
1075 *_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
1076 *_slang_vector_get_element_ptr(&rs.rows[0], 3) = v3;
1077 *_slang_vector_get_element_ptr(&rs.rows[1], 0) = v4;
1078 *_slang_vector_get_element_ptr(&rs.rows[1], 1) = v5;
1079 *_slang_vector_get_element_ptr(&rs.rows[1], 2) = v6;
1080 *_slang_vector_get_element_ptr(&rs.rows[1], 3) = v7;
1081 }
1082 else
1083 {
1084 rs.rows[0].x = v0;
1085 rs.rows[0].y = v1;
1086 rs.rows[1].x = v2;
1087 rs.rows[1].y = v3;
1088 rs.rows[2].x = v4;
1089 rs.rows[2].y = v5;
1090 rs.rows[3].x = v6;
1091 rs.rows[3].y = v7;
1092 }
1093 return rs;
1094}
1095
1096template<typename T, int ROWS, int COLS>
1098 T v0,
1099 T v1,
1100 T v2,
1101 T v3,
1102 T v4,
1103 T v5,
1104 T v6,
1105 T v7,
1106 T v8)
1107{
1109 rs.rows[0].x = v0;
1110 rs.rows[0].y = v1;
1111 rs.rows[0].z = v2;
1112 rs.rows[1].x = v3;
1113 rs.rows[1].y = v4;
1114 rs.rows[1].z = v5;
1115 rs.rows[2].x = v6;
1116 rs.rows[2].y = v7;
1117 rs.rows[2].z = v8;
1118 return rs;
1119}
1120
1121template<typename T, int ROWS, int COLS>
1123 T v0,
1124 T v1,
1125 T v2,
1126 T v3,
1127 T v4,
1128 T v5,
1129 T v6,
1130 T v7,
1131 T v8,
1132 T v9,
1133 T v10,
1134 T v11)
1135{
1137 if (COLS == 4)
1138 {
1139 *_slang_vector_get_element_ptr(&rs.rows[0], 0) = v0;
1140 *_slang_vector_get_element_ptr(&rs.rows[0], 1) = v1;
1141 *_slang_vector_get_element_ptr(&rs.rows[0], 2) = v2;
1142 *_slang_vector_get_element_ptr(&rs.rows[0], 3) = v3;
1143 *_slang_vector_get_element_ptr(&rs.rows[1], 0) = v4;
1144 *_slang_vector_get_element_ptr(&rs.rows[1], 1) = v5;
1145 *_slang_vector_get_element_ptr(&rs.rows[1], 2) = v6;
1146 *_slang_vector_get_element_ptr(&rs.rows[1], 3) = v7;
1147 *_slang_vector_get_element_ptr(&rs.rows[2], 0) = v8;
1148 *_slang_vector_get_element_ptr(&rs.rows[2], 1) = v9;
1149 *_slang_vector_get_element_ptr(&rs.rows[2], 2) = v10;
1150 *_slang_vector_get_element_ptr(&rs.rows[2], 3) = v11;
1151 }
1152 else
1153 {
1154 rs.rows[0].x = v0;
1155 rs.rows[0].y = v1;
1156 rs.rows[0].z = v2;
1157 rs.rows[1].x = v3;
1158 rs.rows[1].y = v4;
1159 rs.rows[1].z = v5;
1160 rs.rows[2].x = v6;
1161 rs.rows[2].y = v7;
1162 rs.rows[2].z = v8;
1163 rs.rows[3].x = v9;
1164 rs.rows[3].y = v10;
1165 rs.rows[3].z = v11;
1166 }
1167 return rs;
1168}
1169
1170template<typename T, int ROWS, int COLS>
1172 T v0,
1173 T v1,
1174 T v2,
1175 T v3,
1176 T v4,
1177 T v5,
1178 T v6,
1179 T v7,
1180 T v8,
1181 T v9,
1182 T v10,
1183 T v11,
1184 T v12,
1185 T v13,
1186 T v14,
1187 T v15)
1188{
1190 rs.rows[0].x = v0;
1191 rs.rows[0].y = v1;
1192 rs.rows[0].z = v2;
1193 rs.rows[0].w = v3;
1194 rs.rows[1].x = v4;
1195 rs.rows[1].y = v5;
1196 rs.rows[1].z = v6;
1197 rs.rows[1].w = v7;
1198 rs.rows[2].x = v8;
1199 rs.rows[2].y = v9;
1200 rs.rows[2].z = v10;
1201 rs.rows[2].w = v11;
1202 rs.rows[3].x = v12;
1203 rs.rows[3].y = v13;
1204 rs.rows[3].z = v14;
1205 rs.rows[3].w = v15;
1206 return rs;
1207}
1208
// Element-wise matrix operator generators. Each expands to an operator template over all
// Matrix<T, R, C> shapes that applies `op` to every element row by row.

// Binary operator between two matrices of identical shape.
#define SLANG_MATRIX_BINARY_OP(T, op) \
    template<int R, int C> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator op( \
        const Matrix<T, R, C>& lhs, \
        const Matrix<T, R, C>& rhs) \
    { \
        Matrix<T, R, C> out; \
        for (int row = 0; row < R; ++row) \
        { \
            for (int col = 0; col < C; ++col) \
            { \
                *_slang_vector_get_element_ptr(&out.rows[row], col) = \
                    _slang_vector_get_element(lhs.rows[row], col) \
                        op _slang_vector_get_element(rhs.rows[row], col); \
            } \
        } \
        return out; \
    }

// Prefix unary operator applied to every element.
#define SLANG_MATRIX_UNARY_OP(T, op) \
    template<int R, int C> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator op(const Matrix<T, R, C>& m) \
    { \
        Matrix<T, R, C> out; \
        for (int row = 0; row < R; ++row) \
        { \
            for (int col = 0; col < C; ++col) \
            { \
                *_slang_vector_get_element_ptr(&out.rows[row], col) = \
                    op _slang_vector_get_element(m.rows[row], col); \
            } \
        } \
        return out; \
    }

// Integer matrices: arithmetic, %, bitwise and logical binaries plus unary ! and ~.
#define SLANG_INT_MATRIX_OPS(T) \
    SLANG_MATRIX_BINARY_OP(T, +) \
    SLANG_MATRIX_BINARY_OP(T, -) \
    SLANG_MATRIX_BINARY_OP(T, *) \
    SLANG_MATRIX_BINARY_OP(T, /) \
    SLANG_MATRIX_BINARY_OP(T, %) \
    SLANG_MATRIX_BINARY_OP(T, &) \
    SLANG_MATRIX_BINARY_OP(T, |) \
    SLANG_MATRIX_BINARY_OP(T, ^) \
    SLANG_MATRIX_BINARY_OP(T, &&) \
    SLANG_MATRIX_BINARY_OP(T, ||) \
    SLANG_MATRIX_UNARY_OP(T, !) \
    SLANG_MATRIX_UNARY_OP(T, ~)

// Floating-point matrices: arithmetic binaries plus unary minus.
#define SLANG_FLOAT_MATRIX_OPS(T) \
    SLANG_MATRIX_BINARY_OP(T, +) \
    SLANG_MATRIX_BINARY_OP(T, -) \
    SLANG_MATRIX_BINARY_OP(T, *) \
    SLANG_MATRIX_BINARY_OP(T, /) \
    SLANG_MATRIX_UNARY_OP(T, -)
1264#if SLANG_CUDA_ENABLE_HALF
1266#endif
// Integer matrix negation, computed element-wise as `0 - element`.
// NOTE(review): presumably written as a subtraction rather than unary minus to sidestep
// unsigned-negation diagnostics -- confirm.
#define SLANG_MATRIX_INT_NEG_OP(T) \
    template<int R, int C> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator-(Matrix<T, R, C> m) \
    { \
        Matrix<T, R, C> out; \
        for (int row = 0; row < R; ++row) \
        { \
            for (int col = 0; col < C; ++col) \
            { \
                *_slang_vector_get_element_ptr(&out.rows[row], col) = \
                    0 - _slang_vector_get_element(m.rows[row], col); \
            } \
        } \
        return out; \
    }
1286
// Floating-point matrix %: the element-wise remainder computed by _slang_fmod.
#define SLANG_FLOAT_MATRIX_MOD(T) \
    template<int R, int C> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator%( \
        Matrix<T, R, C> left, \
        Matrix<T, R, C> right) \
    { \
        Matrix<T, R, C> out; \
        for (int row = 0; row < R; ++row) \
        { \
            for (int col = 0; col < C; ++col) \
            { \
                const T a = _slang_vector_get_element(left.rows[row], col); \
                const T b = _slang_vector_get_element(right.rows[row], col); \
                *_slang_vector_get_element_ptr(&out.rows[row], col) = _slang_fmod(a, b); \
            } \
        } \
        return out; \
    }
1301
1304#if SLANG_CUDA_ENABLE_HALF
1305template<int R, int C>
1309{
1310 Matrix<__half, R, C> result;
1311 for (int i = 0; i < R; i++)
1312 for (int j = 0; j < C; j++)
1313 *_slang_vector_get_element_ptr(result.rows + i, j) = __float2half(_slang_fmod(
1314 __half2float(_slang_vector_get_element(left.rows[i], j)),
1315 __half2float(_slang_vector_get_element(right.rows[i], j))));
1316 return result;
1317}
1318#endif
// Undefine the matrix-operator generator macros so they do not leak into code
// that includes this prelude. Each macro is undefined exactly once.
#undef SLANG_MATRIX_BINARY_OP
#undef SLANG_MATRIX_UNARY_OP
#undef SLANG_INT_MATRIX_OPS
#undef SLANG_FLOAT_MATRIX_OPS
#undef SLANG_MATRIX_INT_NEG_OP
#undef SLANG_FLOAT_MATRIX_MOD
1326
// Component-wise select for N-wide vectors: component k of the result is v0[k]
// when condition[k] is true, otherwise v1[k].
#define SLANG_SELECT_IMPL(T, N) \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, N> _slang_select( \
        bool##N condition, \
        Vector<T, N> v0, \
        Vector<T, N> v1) \
    { \
        Vector<T, N> picked; \
        for (int lane = 0; lane < N; ++lane) \
        { \
            *_slang_vector_get_element_ptr(&picked, lane) = \
                _slang_vector_get_element(condition, lane) \
                    ? _slang_vector_get_element(v0, lane) \
                    : _slang_vector_get_element(v1, lane); \
        } \
        return picked; \
    }

// Emits the 2-, 3- and 4-wide _slang_select overloads for element type T.
#define SLANG_SELECT_T(T) \
    SLANG_SELECT_IMPL(T, 2) \
    SLANG_SELECT_IMPL(T, 3) \
    SLANG_SELECT_IMPL(T, 4)
1346
1347SLANG_SELECT_T(int)
1348SLANG_SELECT_T(bool)
1350SLANG_SELECT_T(short)
1352SLANG_SELECT_T(char)
1354SLANG_SELECT_T(float)
1355SLANG_SELECT_T(double)
1356
1357template<typename T>
1359{
1360 return condition ? v0 : v1;
1361}
1362
1363//
1364// Half support
1365//
1366
1367#if SLANG_CUDA_ENABLE_HALF
SLANG_SELECT_T(__half)

// Per-component bit-casts: ushortN -> halfN.

SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 __ushort_as_half(const ushort2& i)
{
    const __half lo = __ushort_as_half(i.x);
    const __half hi = __ushort_as_half(i.y);
    return __halves2half2(lo, hi);
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half3 __ushort_as_half(const ushort3& i)
{
    const __half hx = __ushort_as_half(i.x);
    const __half hy = __ushort_as_half(i.y);
    const __half hz = __ushort_as_half(i.z);
    return __half3{hx, hy, hz};
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 __ushort_as_half(const ushort4& i)
{
    const __half hx = __ushort_as_half(i.x);
    const __half hy = __ushort_as_half(i.y);
    const __half hz = __ushort_as_half(i.z);
    const __half hw = __ushort_as_half(i.w);
    return __half4{hx, hy, hz, hw};
}

// Per-component bit-casts: halfN -> ushortN.

SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort2 __half_as_ushort(const __half2& i)
{
    const unsigned short ux = __half_as_ushort(i.x);
    const unsigned short uy = __half_as_ushort(i.y);
    return make_ushort2(ux, uy);
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort3 __half_as_ushort(const __half3& i)
{
    const unsigned short ux = __half_as_ushort(i.x);
    const unsigned short uy = __half_as_ushort(i.y);
    const unsigned short uz = __half_as_ushort(i.z);
    return make_ushort3(ux, uy, uz);
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL ushort4 __half_as_ushort(const __half4& i)
{
    const unsigned short ux = __half_as_ushort(i.x);
    const unsigned short uy = __half_as_ushort(i.y);
    const unsigned short uz = __half_as_ushort(i.z);
    const unsigned short uw = __half_as_ushort(i.w);
    return make_ushort4(ux, uy, uz, uw);
}
1407
// Workaround: CUDA's templated surface functions (include/surface_indirect_functions.h)
// require a __nv_isurf_trait specialization before a surface *write* function can
// be specialized for a given element type. These specializations make the
// __half, __half2 and __half4 write specializations possible. Surface *reads* do
// not need this, because their return type does not go through the trait.
template<>
struct __nv_isurf_trait<__half>
{
    using type = void;
};
template<>
struct __nv_isurf_trait<__half2>
{
    using type = void;
};
template<>
struct __nv_isurf_trait<__half4>
{
    using type = void;
};
1430
// Removes one level of parentheses: SLANG_DROP_PARENS (a, b) expands to a, b.
// Used to splice parenthesised parameter/argument lists into the macros below.
#define SLANG_DROP_PARENS(...) __VA_ARGS__

// Specializes the CUDA surface read template FUNC_NAME for __half, __half2 and
// __half4. Each reads the same-sized unsigned 16-bit type and bit-casts it.
//   TYPE_ARGS: parenthesised coordinate parameter list, e.g. (int x, int y)
//   ARGS:      matching parenthesised argument list,    e.g. (x, y)
#define SLANG_SURFACE_READ(FUNC_NAME, TYPE_ARGS, ARGS) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half FUNC_NAME<__half>( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const ushort bits = FUNC_NAME<ushort>(surfHandle, SLANG_DROP_PARENS ARGS, mode); \
        return __ushort_as_half(bits); \
    } \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 FUNC_NAME<__half2>( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const ushort2 bits = FUNC_NAME<ushort2>(surfHandle, SLANG_DROP_PARENS ARGS, mode); \
        return __ushort_as_half(bits); \
    } \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 FUNC_NAME<__half4>( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const ushort4 bits = FUNC_NAME<ushort4>(surfHandle, SLANG_DROP_PARENS ARGS, mode); \
        return __ushort_as_half(bits); \
    }

SLANG_SURFACE_READ(surf1Dread, (int x), (x))
SLANG_SURFACE_READ(surf2Dread, (int x, int y), (x, y))
SLANG_SURFACE_READ(surf3Dread, (int x, int y, int z), (x, y, z))
SLANG_SURFACE_READ(surf1DLayeredread, (int x, int layer), (x, layer))
SLANG_SURFACE_READ(surf2DLayeredread, (int x, int y, int layer), (x, y, layer))
SLANG_SURFACE_READ(surfCubemapread, (int x, int y, int face), (x, y, face))
SLANG_SURFACE_READ(surfCubemapLayeredread, (int x, int y, int layerFace), (x, y, layerFace))
1470
// Specializes the CUDA surface write template FUNC_NAME for __half, __half2 and
// __half4. Each bit-casts the value to the same-sized unsigned 16-bit type and
// writes that. TYPE_ARGS / ARGS are the parenthesised coordinate parameter and
// argument lists, e.g. (int x, int y) and (x, y).
#define SLANG_SURFACE_WRITE(FUNC_NAME, TYPE_ARGS, ARGS) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void FUNC_NAME<__half>( \
        __half value, \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const ushort bits = __half_as_ushort(value); \
        FUNC_NAME<ushort>(bits, surfHandle, SLANG_DROP_PARENS ARGS, mode); \
    } \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void FUNC_NAME<__half2>( \
        __half2 value, \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const ushort2 bits = __half_as_ushort(value); \
        FUNC_NAME<ushort2>(bits, surfHandle, SLANG_DROP_PARENS ARGS, mode); \
    } \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void FUNC_NAME<__half4>( \
        __half4 value, \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const ushort4 bits = __half_as_ushort(value); \
        FUNC_NAME<ushort4>(bits, surfHandle, SLANG_DROP_PARENS ARGS, mode); \
    }

SLANG_SURFACE_WRITE(surf1Dwrite, (int x), (x))
SLANG_SURFACE_WRITE(surf2Dwrite, (int x, int y), (x, y))
SLANG_SURFACE_WRITE(surf3Dwrite, (int x, int y, int z), (x, y, z))
SLANG_SURFACE_WRITE(surf1DLayeredwrite, (int x, int layer), (x, layer))
SLANG_SURFACE_WRITE(surf2DLayeredwrite, (int x, int y, int layer), (x, y, layer))
SLANG_SURFACE_WRITE(surfCubemapwrite, (int x, int y, int face), (x, y, face))
SLANG_SURFACE_WRITE(surfCubemapLayeredwrite, (int x, int y, int layerFace), (x, y, layerFace))
1509
1510// ! Hack to test out reading !!!
1511// Only works converting *from* half
1512
1513// template <typename T>
1514// SLANG_FORCE_INLINE SLANG_CUDA_CALL T surf2Dread_convert(cudaSurfaceObject_t surfObj, int x, int
1515// y, cudaSurfaceBoundaryMode boundaryMode);
1516
// Declares FUNC_NAME##_convert<T> and specializes it for float, float2 and float4.
// Each specialization reads half data (as 16-bit unsigned storage) from the surface
// and widens every component to float.
#define SLANG_SURFACE_READ_HALF_CONVERT(FUNC_NAME, TYPE_ARGS, ARGS) \
    template<typename T> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T FUNC_NAME##_convert( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode); \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL float FUNC_NAME##_convert<float>( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const __half h = \
            __ushort_as_half(FUNC_NAME<uint16_t>(surfHandle, SLANG_DROP_PARENS ARGS, mode)); \
        return static_cast<float>(h); \
    } \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 FUNC_NAME##_convert<float2>( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const __half2 pair = \
            __ushort_as_half(FUNC_NAME<ushort2>(surfHandle, SLANG_DROP_PARENS ARGS, mode)); \
        return make_float2(pair.x, pair.y); \
    } \
    \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 FUNC_NAME##_convert<float4>( \
        cudaSurfaceObject_t surfHandle, \
        SLANG_DROP_PARENS TYPE_ARGS, \
        cudaSurfaceBoundaryMode mode) \
    { \
        const __half4 quad = \
            __ushort_as_half(FUNC_NAME<ushort4>(surfHandle, SLANG_DROP_PARENS ARGS, mode)); \
        return make_float4(quad.x, quad.y, quad.z, quad.w); \
    }

SLANG_SURFACE_READ_HALF_CONVERT(surf1Dread, (int x), (x))
SLANG_SURFACE_READ_HALF_CONVERT(surf2Dread, (int x, int y), (x, y))
SLANG_SURFACE_READ_HALF_CONVERT(surf3Dread, (int x, int y, int z), (x, y, z))
1560
1561#endif
1562
1563// Support for doing format conversion when writing to a surface/RWTexture
1564
1565// NOTE! For normal surface access x values are *byte* addressed.
1566// For the _convert versions they are *not*. They don't need to be because sust.p does not require
1567// it.
1568
1569// https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html
1570// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#surface-instructions-sust
1571
1572
1573// surf1Dwrite_convert
1574
1575template<typename T>
1577 T v,
1578 cudaSurfaceObject_t surfObj,
1579 int x,
1580 cudaSurfaceBoundaryMode boundaryMode);
1581
// Specializes surf1Dwrite_convert for T, T##2 and T##4 using a formatted
// `sust.p.1d` store of 32-bit components.
//   c: inline-asm input constraint string literal for T; applied as c(component).
// The boundaryMode argument is accepted but not forwarded: the PTX boundary mode
// is fixed by SLANG_PTX_BOUNDARY_MODE.
#define SLANG_SURF1DWRITE_CONVERT_IMPL(T, c) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert<T>( \
        T value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        asm volatile("sust.p.1d.b32." SLANG_PTX_BOUNDARY_MODE " [%0, {%1}], {%2};" \
                     : \
                     : "l"(surfObj), "r"(x), c(value)); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert<T##2>( \
        T##2 value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        const T e0 = value.x; \
        const T e1 = value.y; \
        asm volatile("sust.p.1d.v2.b32." SLANG_PTX_BOUNDARY_MODE " [%0, {%1}], {%2, %3};" \
                     : \
                     : "l"(surfObj), "r"(x), c(e0), c(e1)); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert<T##4>( \
        T##4 value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        const T e0 = value.x; \
        const T e1 = value.y; \
        const T e2 = value.z; \
        const T e3 = value.w; \
        asm volatile("sust.p.1d.v4.b32." SLANG_PTX_BOUNDARY_MODE \
                     " [%0, {%1}], {%2, %3, %4, %5};" \
                     : \
                     : "l"(surfObj), "r"(x), c(e0), c(e1), c(e2), c(e3)); \
    }
1626
1630
1631// surf1DLayeredwrite_convert (not supported)
1632
1633template<typename T>
1635 T v,
1636 cudaSurfaceObject_t surfObj,
1637 int x,
1638 int layer,
1639 cudaSurfaceBoundaryMode boundaryMode)
1640{
1641 // TODO: static_assert(false) can fail on some compilers, even if template is not instantiated.
1642 // We should check for this in hlsl.meta.slang instead.
1643 // static_assert(false, "CUDA doesn't support formatted surface writes on 1D array surfaces");
1644}
1645
1646// surf2Dwrite_convert
1647
1648template<typename T>
1650 T v,
1651 cudaSurfaceObject_t surfObj,
1652 int x,
1653 int y,
1654 cudaSurfaceBoundaryMode boundaryMode);
1655
// Specializes surf2Dwrite_convert for T, T##2 and T##4 using a formatted
// `sust.p.2d` store of 32-bit components.
//   c: inline-asm input constraint string literal for T; applied as c(component).
// The boundaryMode argument is accepted but not forwarded: the PTX boundary mode
// is fixed by SLANG_PTX_BOUNDARY_MODE.
#define SLANG_SURF2DWRITE_CONVERT_IMPL(T, c) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert<T>( \
        T value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        int y, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        asm volatile("sust.p.2d.b32." SLANG_PTX_BOUNDARY_MODE " [%0, {%1, %2}], {%3};" \
                     : \
                     : "l"(surfObj), "r"(x), "r"(y), c(value)); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert<T##2>( \
        T##2 value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        int y, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        const T e0 = value.x; \
        const T e1 = value.y; \
        asm volatile("sust.p.2d.v2.b32." SLANG_PTX_BOUNDARY_MODE \
                     " [%0, {%1, %2}], {%3, %4};" \
                     : \
                     : "l"(surfObj), "r"(x), "r"(y), c(e0), c(e1)); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert<T##4>( \
        T##4 value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        int y, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        const T e0 = value.x; \
        const T e1 = value.y; \
        const T e2 = value.z; \
        const T e3 = value.w; \
        asm volatile("sust.p.2d.v4.b32." SLANG_PTX_BOUNDARY_MODE \
                     " [%0, {%1, %2}], {%3, %4, %5, %6};" \
                     : \
                     : "l"(surfObj), "r"(x), "r"(y), c(e0), c(e1), c(e2), c(e3)); \
    }
1707
1711
1712// surf2DLayeredwrite_convert (not supported)
1713
1714template<typename T>
1716 T v,
1717 cudaSurfaceObject_t surfObj,
1718 int x,
1719 int y,
1720 int layer,
1721 cudaSurfaceBoundaryMode boundaryMode)
1722{
1723 // TODO: static_assert(false) can fail on some compilers, even if template is not instantiated.
1724 // We should check for this in hlsl.meta.slang instead.
1725 // static_assert(false, "CUDA doesn't support formatted surface writes on 2D array surfaces");
1726}
1727
1728// surf3Dwrite_convert
1729
1730template<typename T>
1732 T v,
1733 cudaSurfaceObject_t surfObj,
1734 int x,
1735 int y,
1736 int z,
1737 cudaSurfaceBoundaryMode boundaryMode);
1738
// Specializes surf3Dwrite_convert for T, T##2 and T##4 using a formatted
// `sust.p.3d` store of 32-bit components. The coordinate operand is a 4-element
// vector {x, y, z, 0}; the fourth element is passed as a constant 0.
//   c: inline-asm input constraint string literal for T; applied as c(component).
// The boundaryMode argument is accepted but not forwarded: the PTX boundary mode
// is fixed by SLANG_PTX_BOUNDARY_MODE.
#define SLANG_SURF3DWRITE_CONVERT_IMPL(T, c) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert<T>( \
        T value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        int y, \
        int z, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        asm volatile("sust.p.3d.b32." SLANG_PTX_BOUNDARY_MODE \
                     " [%0, {%1, %2, %3, %4}], {%5};" \
                     : \
                     : "l"(surfObj), "r"(x), "r"(y), "r"(z), "r"(0), c(value)); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert<T##2>( \
        T##2 value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        int y, \
        int z, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        const T e0 = value.x; \
        const T e1 = value.y; \
        asm volatile("sust.p.3d.v2.b32." SLANG_PTX_BOUNDARY_MODE \
                     " [%0, {%1, %2, %3, %4}], {%5, %6};" \
                     : \
                     : "l"(surfObj), "r"(x), "r"(y), "r"(z), "r"(0), c(e0), c(e1)); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert<T##4>( \
        T##4 value, \
        cudaSurfaceObject_t surfObj, \
        int x, \
        int y, \
        int z, \
        cudaSurfaceBoundaryMode boundaryMode) \
    { \
        const T e0 = value.x; \
        const T e1 = value.y; \
        const T e2 = value.z; \
        const T e3 = value.w; \
        asm volatile("sust.p.3d.v4.b32." SLANG_PTX_BOUNDARY_MODE \
                     " [%0, {%1, %2, %3, %4}], {%5, %6, %7, %8};" \
                     : \
                     : "l"(surfObj), \
                       "r"(x), \
                       "r"(y), \
                       "r"(z), \
                       "r"(0), \
                       c(e0), \
                       c(e1), \
                       c(e2), \
                       c(e3)); \
    }
1800
1804
1805// ----------------------------- F16 -----------------------------------------
1806#if SLANG_CUDA_ENABLE_HALF
1807// Unary
1809{
1810 return ::hceil(f);
1811}
1812
1814{
1815 return ::hfloor(f);
1816}
1817
1819{
1820 return ::hrint(f);
1821}
1822
1824{
1825 return ::hsin(f);
1826}
1827
1829{
1830 return ::hcos(f);
1831}
1832
// Writes sin(f) to *s and cos(f) to *c, using the half-precision hsin/hcos.
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F16_sincos(__half f, __half* s, __half* c)
{
    const __half sine = ::hsin(f);
    *s = sine;
    const __half cosine = ::hcos(f);
    *c = cosine;
}
1838
1840{
1841 return __float2half(::tanf(__half2float(f)));
1842}
1844{
1845 return __float2half(::asinf(__half2float(f)));
1846}
1848{
1849 return __float2half(::acosf(__half2float(f)));
1850}
1852{
1853 return __float2half(::atanf(__half2float(f)));
1854}
1856{
1857 return __float2half(::sinhf(__half2float(f)));
1858}
1860{
1861 return __float2half(::coshf(__half2float(f)));
1862}
1864{
1865 return __float2half(::tanhf(__half2float(f)));
1866}
1868{
1869 return __float2half(::asinhf(__half2float(f)));
1870}
1872{
1873 return __float2half(::acoshf(__half2float(f)));
1874}
1876{
1877 return __float2half(::atanhf(__half2float(f)));
1878}
1880{
1881 return ::hlog2(f);
1882}
1884{
1885 return ::hlog(f);
1886}
1888{
1889 return ::hlog10(f);
1890}
1892{
1893 return ::hexp2(f);
1894}
1896{
1897 return ::hexp(f);
1898}
1900{
1901 return __habs(f);
1902}
1904{
1905 return ::htrunc(f);
1906}
1908{
1909 return ::hsqrt(f);
1910}
1912{
1913 return ::hrsqrt(f);
1914}
1916{
1917 return (f == __half(0.0f)) ? 0 : ((f < __half(0.0f)) ? -1 : 1);
1918}
1919
1921{
1922 return f - F16_floor(f);
1923}
1924
1926{
1927 return __hisnan(f);
1928}
1930{
1931 return !__hisinf(f) && !__hisnan(f);
1932}
1934{
1935 return __hisinf(f);
1936}
1937
1938// Binary
// Half-precision min/max via the __hmin/__hmax intrinsics.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_min(__half a, __half b)
{
    const __half smaller = __hmin(a, b);
    return smaller;
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_max(__half a, __half b)
{
    const __half larger = __hmax(a, b);
    return larger;
}

// pow / fmod / remainder: operands are widened to float, evaluated with the
// single-precision routine, and the result is rounded back to half.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_pow(__half a, __half b)
{
    const float base = __half2float(a);
    const float exponent = __half2float(b);
    return __float2half(::powf(base, exponent));
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_fmod(__half a, __half b)
{
    const float numer = __half2float(a);
    const float denom = __half2float(b);
    return __float2half(::fmodf(numer, denom));
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_remainder(__half a, __half b)
{
    const float numer = __half2float(a);
    const float denom = __half2float(b);
    return __float2half(::remainderf(numer, denom));
}
// Two-argument arctangent atan2(a, b) for half values.
// Both operands are widened to float and evaluated with the single-precision
// atan2f, matching the other promoted F16 helpers (powf, fmodf, tanf, ...).
// Naming atan2f explicitly guarantees the float routine is used; the previous
// unqualified ::atan2 call could resolve to the double-precision overload,
// depending on which overloads are visible.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_atan2(__half a, __half b)
{
    return __float2half(::atan2f(__half2float(a), __half2float(b)));
}
1963
// Splits x into a mantissa (returned, as half) and a power-of-two exponent
// (stored to *e). Evaluated in float via frexpf.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_frexp(__half x, int* e)
{
    const float mantissa = frexpf(__half2float(x), e);
    return __float2half(mantissa);
}

// Splits x into its fractional part (returned) and integral part (stored to *ip).
// Both parts are computed in float via modff and rounded to half.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_modf(__half x, __half* ip)
{
    float integralPart;
    const float fractionalPart = ::modff(__half2float(x), &integralPart);
    *ip = __float2half(integralPart);
    return __float2half(fractionalPart);
}
1976
1978{
1979 return __half_as_ushort(h);
1980}
1982{
1983 return __half_as_short(h);
1984}
1985
1986// Ternary
// a * b + c for half values, via the CUDA half intrinsic __hfma.
SLANG_FORCE_INLINE SLANG_CUDA_CALL __half F16_fma(__half a, __half b, __half c)
{
    const __half fused = __hfma(a, b, c);
    return fused;
}
1991
1992#endif
1993
1994// ----------------------------- F32 -----------------------------------------
1995
1996// Unary
1998{
1999 return ::ceilf(f);
2000}
2002{
2003 return ::floorf(f);
2004}
2006{
2007 return ::roundf(f);
2008}
2010{
2011 return ::sinf(f);
2012}
2014{
2015 return ::cosf(f);
2016}
// Computes sin(f) into *s and cos(f) into *c with a single sincosf call.
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F32_sincos(float f, float* s, float* c)
{
    return ::sincosf(f, s, c);
}
2022{
2023 return ::tanf(f);
2024}
2026{
2027 return ::asinf(f);
2028}
2030{
2031 return ::acosf(f);
2032}
2034{
2035 return ::atanf(f);
2036}
2038{
2039 return ::sinhf(f);
2040}
2042{
2043 return ::coshf(f);
2044}
2046{
2047 return ::tanhf(f);
2048}
2050{
2051 return ::asinhf(f);
2052}
2054{
2055 return ::acoshf(f);
2056}
2058{
2059 return ::atanhf(f);
2060}
2062{
2063 return ::log2f(f);
2064}
2066{
2067 return ::logf(f);
2068}
2070{
2071 return ::log10f(f);
2072}
2074{
2075 return ::exp2f(f);
2076}
2078{
2079 return ::expf(f);
2080}
2082{
2083 return ::fabsf(f);
2084}
2086{
2087 return ::truncf(f);
2088}
2090{
2091 return ::sqrtf(f);
2092}
2094{
2095 return ::rsqrtf(f);
2096}
2098{
2099 return (f == 0.0f) ? 0 : ((f < 0.0f) ? -1 : 1);
2100}
2102{
2103 return f - F32_floor(f);
2104}
2105
2107{
2108 return isnan(f);
2109}
2111{
2112 return isfinite(f);
2113}
2115{
2116 return isinf(f);
2117}
2118
2119// Binary
2121{
2122 return ::fminf(a, b);
2123}
2125{
2126 return ::fmaxf(a, b);
2127}
2129{
2130 return ::powf(a, b);
2131}
2133{
2134 return ::fmodf(a, b);
2135}
2137{
2138 return ::remainderf(a, b);
2139}
2141{
2142 return float(::atan2(a, b));
2143}
2144
2146{
2147 return frexpf(x, e);
2148}
2149
2151{
2152 return ::modff(x, ip);
2153}
2154
2156{
2157 Union32 u;
2158 u.f = f;
2159 return u.u;
2160}
2162{
2163 Union32 u;
2164 u.f = f;
2165 return u.i;
2166}
2167
2168// Ternary
// a * b + c in single precision via fmaf.
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_fma(float a, float b, float c)
{
    const float fused = ::fmaf(a, b, c);
    return fused;
}
2173
2174
2175// ----------------------------- F64 -----------------------------------------
2176
2177// Unary
2179{
2180 return ::ceil(f);
2181}
2183{
2184 return ::floor(f);
2185}
2187{
2188 return ::round(f);
2189}
2191{
2192 return ::sin(f);
2193}
2195{
2196 return ::cos(f);
2197}
// Computes sin(f) into *s and cos(f) into *c with a single double-precision sincos call.
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_sincos(double f, double* s, double* c)
{
    return ::sincos(f, s, c);
}
2203{
2204 return ::tan(f);
2205}
2207{
2208 return ::asin(f);
2209}
2211{
2212 return ::acos(f);
2213}
2215{
2216 return ::atan(f);
2217}
2219{
2220 return ::sinh(f);
2221}
2223{
2224 return ::cosh(f);
2225}
2227{
2228 return ::tanh(f);
2229}
2231{
2232 return ::log2(f);
2233}
2235{
2236 return ::log(f);
2237}
2239{
2240 return ::log10(f);
2241}
2243{
2244 return ::exp2(f);
2245}
2247{
2248 return ::exp(f);
2249}
2251{
2252 return ::fabs(f);
2253}
2255{
2256 return ::trunc(f);
2257}
2259{
2260 return ::sqrt(f);
2261}
2263{
2264 return ::rsqrt(f);
2265}
2267{
2268 return (f == 0.0) ? 0 : ((f < 0.0) ? -1 : 1);
2269}
2271{
2272 return f - F64_floor(f);
2273}
2274
2276{
2277 return isnan(f);
2278}
2280{
2281 return isfinite(f);
2282}
2284{
2285 return isinf(f);
2286}
2287
2288// Binary
2290{
2291 return ::fmin(a, b);
2292}
2294{
2295 return ::fmax(a, b);
2296}
2298{
2299 return ::pow(a, b);
2300}
2302{
2303 return ::fmod(a, b);
2304}
2306{
2307 return ::remainder(a, b);
2308}
2310{
2311 return ::atan2(a, b);
2312}
2313
2315{
2316 return ::frexp(x, e);
2317}
2318
// Returns the fractional part of x and stores the integral part to *ip.
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_modf(double x, double* ip)
{
    const double fraction = ::modf(x, ip);
    return fraction;
}

// Reinterprets the 64 bits of d (through Union64) as two unsigned 32-bit words:
// *low receives the low word, *hi the high word.
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_asuint(double d, uint32_t* low, uint32_t* hi)
{
    Union64 bits;
    bits.d = d;
    const auto word = bits.u;
    *low = uint32_t(word);
    *hi = uint32_t(word >> 32);
}

// Same as F64_asuint, but the two 32-bit words are converted to int32_t.
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_asint(double d, int32_t* low, int32_t* hi)
{
    Union64 bits;
    bits.d = d;
    const auto word = bits.u;
    *low = int32_t(word);
    *hi = int32_t(word >> 32);
}
2339
2340// Ternary
// a * b + c in double precision via fma.
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_fma(double a, double b, double c)
{
    const double fused = ::fma(a, b, c);
    return fused;
}
2345
2346// ----------------------------- U8 -----------------------------------------
2347
2349{
2350 // No native 8bit popc yet, just cast and use 32bit variant
2351 return __popc(uint32_t(v));
2352}
2353
2354// ----------------------------- I8 -----------------------------------------
2355
2357{
2358 return U8_countbits(uint8_t(v));
2359}
2360
2361// ----------------------------- U16 -----------------------------------------
2362
2364{
2365 // No native 16bit popc yet, just cast and use 32bit variant
2366 return __popc(uint32_t(v));
2367}
2368
2369// ----------------------------- I16 -----------------------------------------
2370
2372{
2373 return U16_countbits(uint16_t(v));
2374}
2375
2376// ----------------------------- U32 -----------------------------------------
2377
2378// Unary
2380{
2381 return f;
2382}
2383
2384// Binary
// Smaller of two unsigned 32-bit values.
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_min(uint32_t a, uint32_t b)
{
    return (b < a) ? b : a;
}
// Larger of two unsigned 32-bit values.
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_max(uint32_t a, uint32_t b)
{
    return (b > a) ? b : a;
}
2393
2395{
2396 Union32 u;
2397 u.u = x;
2398 return u.f;
2399}
2401{
2402 return uint32_t(x);
2403}
2404
// Builds a double from two 32-bit words (hi in bits 63..32, low in bits 31..0)
// by reinterpreting the packed 64-bit value through Union64.
SLANG_FORCE_INLINE SLANG_CUDA_CALL double U32_asdouble(uint32_t low, uint32_t hi)
{
    const uint64_t packed = (uint64_t(hi) << 32) | uint64_t(low);
    Union64 bits;
    bits.u = packed;
    return bits.d;
}
2411
2413{
2414 return __popc(v);
2415}
2416
2418{
2419 // __ffs returns 1-based bit position or 0 if no bits set
2420 // firstbitlow should return 0-based bit position or ~0u if no bits set
2421 return v == 0 ? ~0u : (__ffs(v) - 1);
2422}
2423
2425{
2426 // maps to hlsl firstbithigh
2427 if ((int32_t)v < 0)
2428 v = ~v;
2429 if (v == 0)
2430 return ~0u;
2431 return 31 - __clz(v);
2432}
2433
2435{
2436 return __brev(v);
2437}
2438
2439// ----------------------------- I32 -----------------------------------------
2440
2441// Unary
2443{
2444 return (f < 0) ? -f : f;
2445}
2446
2447// Binary
// Smaller of two signed 32-bit values.
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_min(int32_t a, int32_t b)
{
    return (b < a) ? b : a;
}
// Larger of two signed 32-bit values.
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_max(int32_t a, int32_t b)
{
    return (b > a) ? b : a;
}
2456
2458{
2459 Union32 u;
2460 u.i = x;
2461 return u.f;
2462}
2464{
2465 return uint32_t(x);
2466}
// Builds a double from two signed 32-bit words. Each word is reinterpreted as
// unsigned before packing (hi in bits 63..32, low in bits 31..0), so a negative
// low word does not sign-extend into the high half.
SLANG_FORCE_INLINE SLANG_CUDA_CALL double I32_asdouble(int32_t low, int32_t hi)
{
    const uint64_t highBits = uint64_t(hi) << 32;
    const uint64_t lowBits = uint32_t(low);
    Union64 bits;
    bits.u = highBits | lowBits;
    return bits.d;
}
2473
2475{
2476 return U32_countbits(uint32_t(v));
2477}
2478
2480{
2481 return U32_firstbitlow(uint32_t(v));
2482}
2483
2485{
2486 return U32_firstbithigh(uint32_t(v));
2487}
2488
2490{
2491 return int32_t(U32_reversebits(uint32_t(v)));
2492}
2493
2494// ----------------------------- U64 -----------------------------------------
2495
2497{
2498 return f;
2499}
2500
// Smaller / larger of two unsigned 64-bit values.
// The return type is uint64_t to match the operands, so a result >= 2^63 is not
// reinterpreted as a negative number when used in a signed context.
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint64_t U64_min(uint64_t a, uint64_t b)
{
    return a < b ? a : b;
}
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint64_t U64_max(uint64_t a, uint64_t b)
{
    return a > b ? a : b;
}
2509
2511{
2512 return __popcll(v);
2513}
2514
2516{
2517 // __ffs returns 1-based bit position or 0 if no bits set
2518 // firstbitlow should return 0-based bit position or ~0u if no bits set
2519 return v == 0 ? ~uint32_t(0) : (__ffsll(v) - 1u);
2520}
2521
2523{
2524 if (v == 0)
2525 return ~uint32_t(0);
2526 return 63 - __clzll(v);
2527}
2528
2530{
2531 return __brevll(v);
2532}
2533
2534// ----------------------------- I64 -----------------------------------------
2535
2537{
2538 return (f < 0) ? -f : f;
2539}
2540
// Smaller of two signed 64-bit values.
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_min(int64_t a, int64_t b)
{
    return (b < a) ? b : a;
}
// Larger of two signed 64-bit values.
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_max(int64_t a, int64_t b)
{
    return (b > a) ? b : a;
}
2549
2551{
2552 return U64_countbits(uint64_t(v));
2553}
2554
2556{
2557 return U64_firstbitlow(uint64_t(v));
2558}
2559
2561{
2562 if (v < 0)
2563 v = ~v;
2564 return U64_firstbithigh(uint64_t(v));
2565}
2566
2568{
2569 return int64_t(U64_reversebits(uint64_t(v)));
2570}
2571
2572// ----------------------------- IPTR -----------------------------------------
2573
2575{
2576 return (f < 0) ? -f : f;
2577}
2578
2580{
2581 return a < b ? a : b;
2582}
2583
2585{
2586 return a > b ? a : b;
2587}
2588
2589// ----------------------------- UPTR -----------------------------------------
2590
2595
2597{
2598 return a < b ? a : b;
2599}
2600
2602{
2603 return a > b ? a : b;
2604}
2605
2606// ----------------------------- ResourceType -----------------------------------------
2607
2608
2609// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-structuredbuffer-getdimensions
2610// Missing Load(_In_ int Location, _Out_ uint Status);
2611
template<typename T>
struct StructuredBuffer
{
    // Element access. Unless SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT is defined,
    // index and count are first passed to SLANG_BOUND_CHECK.
    // A mutable T& is returned even through a const buffer object.
    SLANG_CUDA_CALL T& Load(size_t index) const
    {
#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
        SLANG_BOUND_CHECK(index, count);
#endif
        return data[index];
    }

    // Subscript form of Load (same check, same reference).
    SLANG_CUDA_CALL T& operator[](size_t index) const { return Load(index); }

#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
    // Reports the element count and the per-element stride (sizeof(T)) in bytes.
    SLANG_CUDA_CALL void GetDimensions(uint32_t* outNumStructs, uint32_t* outStride) const
    {
        *outNumStructs = static_cast<uint32_t>(count);
        *outStride = static_cast<uint32_t>(sizeof(T));
    }
#endif

    // NOTE(review): member order (data, then the optional count) presumably
    // mirrors what host code writes into the buffer -- keep it as-is.
    T* data;
#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
    size_t count;
#endif
};
2644
2645template<typename T>
2647{
2648 SLANG_CUDA_CALL T& operator[](size_t index) const
2649 {
2650#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
2651 SLANG_BOUND_CHECK(index, this->count);
2652#endif
2653 return this->data[index];
2654 }
2655};
2656
2657// Missing Load(_In_ int Location, _Out_ uint Status);
2658struct ByteAddressBuffer
2659{
2660 SLANG_CUDA_CALL void GetDimensions(uint32_t* outDim) const { *outDim = uint32_t(sizeInBytes); }
2661 SLANG_CUDA_CALL uint32_t Load(size_t index) const
2662 {
2664 return data[index >> 2];
2665 }
2666 SLANG_CUDA_CALL uint2 Load2(size_t index) const
2667 {
2669 const size_t dataIdx = index >> 2;
2670 return uint2{data[dataIdx], data[dataIdx + 1]};
2671 }
2672 SLANG_CUDA_CALL uint3 Load3(size_t index) const
2673 {
2675 const size_t dataIdx = index >> 2;
2676 return uint3{data[dataIdx], data[dataIdx + 1], data[dataIdx + 2]};
2677 }
2678 SLANG_CUDA_CALL uint4 Load4(size_t index) const
2679 {
2681 const size_t dataIdx = index >> 2;
2682 return uint4{data[dataIdx], data[dataIdx + 1], data[dataIdx + 2], data[dataIdx + 3]};
2683 }
2684 template<typename T>
2685 SLANG_CUDA_CALL T Load(size_t index) const
2686 {
2688 T data;
2689 memcpy(&data, ((const char*)this->data) + index, sizeof(T));
2690 return data;
2691 }
2692 template<typename T>
2694 {
2696 rs.data = (T*)data;
2697 rs.count = sizeInBytes / sizeof(T);
2698 return rs;
2699 }
2700 const uint32_t* data;
2701 size_t sizeInBytes; //< Must be multiple of 4
2702};
2703
2704// https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sm5-object-rwbyteaddressbuffer
2705// Atomic operations support
2706
2707// Signed 64-bit atomic wrappers
2708// CUDA only supports unsigned long long atomics, so we cast signed to unsigned
2709// Use longlong type with explicit unsigned long long casts for platform portability
// Atomic exchange on a signed 64-bit location, forwarded to the
// unsigned long long overload.
__device__ __forceinline__ longlong atomicExch(longlong* address, longlong val)
{
    unsigned long long* const target = reinterpret_cast<unsigned long long*>(address);
    const unsigned long long previous =
        atomicExch(target, static_cast<unsigned long long>(val));
    return static_cast<longlong>(previous);
}

// Atomic compare-and-swap on a signed 64-bit location, forwarded to the
// unsigned long long overload.
__device__ __forceinline__ longlong atomicCAS(longlong* address, longlong compare, longlong val)
{
    unsigned long long* const target = reinterpret_cast<unsigned long long*>(address);
    const unsigned long long previous = atomicCAS(
        target,
        static_cast<unsigned long long>(compare),
        static_cast<unsigned long long>(val));
    return static_cast<longlong>(previous);
}

// Atomic add on a signed 64-bit location, forwarded to the unsigned long long
// overload (two's-complement addition gives the same bits).
__device__ __forceinline__ longlong atomicAdd(longlong* address, longlong val)
{
    unsigned long long* const target = reinterpret_cast<unsigned long long*>(address);
    const unsigned long long previous =
        atomicAdd(target, static_cast<unsigned long long>(val));
    return static_cast<longlong>(previous);
}
2727
2728// Float bitwise atomic compare-and-swap
2729// Uses integer atomics to preserve exact float bit patterns
// Bitwise compare-and-swap on a float location: the comparison is done on the
// raw 32-bit patterns through the int atomicCAS, and the previous bits are
// returned as a float.
__device__ __forceinline__ float atomicCAS(float* address, float compare, float val)
{
    const int expectedBits = __float_as_int(compare);
    const int desiredBits = __float_as_int(val);
    const int previousBits = atomicCAS(reinterpret_cast<int*>(address), expectedBits, desiredBits);
    return __int_as_float(previousBits);
}
2736
2737// =====================================================================
2738// Atomic Reduction Operations (PTX `red` instruction)
2739// These are in-place atomic operations that don't return the old value.
2740// They are faster than the corresponding atomic operations that return values
2741// because they use the PTX `red` instruction with relaxed memory ordering.
2742//
2743// Supported operations based on PTX ISA:
2744// - add: .s32, .u32, .u64, .s64, .f16, .f16x2, .bf16, .bf16x2, .f32, .f64
2745// - min/max: .s32, .u32, .s64, .u64, .f32, .f64, .f16, .f16x2
2746// - and/or/xor: .b32, .b64
2747// - inc/dec: .u32
2748// =====================================================================
2749
2750// Atomic reduction ADD operations
// In-place atomic add via PTX `red` (no old value returned).
// Every variant issues a relaxed, GPU-scope reduction in the .global state
// space; the `order` argument is currently not used.
// NOTE(review): `.global` means addr must point to global memory -- confirm
// callers never pass shared-memory pointers here.

__device__ __forceinline__ void __slang_atomic_reduce_add(int32_t* addr, int32_t val, int /* order */)
{
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;"
                 : /* no outputs */
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(uint32_t* addr, uint32_t val, int /* order */)
{
    asm volatile("red.relaxed.gpu.global.add.u32 [%0], %1;"
                 : /* no outputs */
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(int64_t* addr, int64_t val, int /* order */)
{
    asm volatile("red.relaxed.gpu.global.add.s64 [%0], %1;"
                 : /* no outputs */
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(uint64_t* addr, uint64_t val, int /* order */)
{
    asm volatile("red.relaxed.gpu.global.add.u64 [%0], %1;"
                 : /* no outputs */
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(float* addr, float val, int /* order */)
{
    asm volatile("red.relaxed.gpu.global.add.f32 [%0], %1;"
                 : /* no outputs */
                 : "l"(addr), "f"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(double* addr, double val, int /* order */)
{
    asm volatile("red.relaxed.gpu.global.add.f64 [%0], %1;"
                 : /* no outputs */
                 : "l"(addr), "d"(val)
                 : "memory");
}
2780
#if SLANG_CUDA_ENABLE_HALF
// Half-precision reduction ADD (compiled only when half support is enabled).
// PTX receives the raw bit pattern of the value: 16 bits ("h" constraint) for __half, and the
// packed 32-bit pair ("r" constraint) for __half2. `.noftz` requests that subnormals are not
// flushed to zero. `order` is ignored (always relaxed, GPU scope), and the `.global` state space
// means addr is expected to be a global-memory address.
__device__ __forceinline__ void __slang_atomic_reduce_add(__half* addr, __half val, int order)
{
    (void)order;
    const unsigned short halfBits = *reinterpret_cast<const unsigned short*>(&val);
    asm volatile("red.relaxed.gpu.global.add.noftz.f16 [%0], %1;"
                 :
                 : "l"(addr), "h"(halfBits)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(__half2* addr, __half2 val, int order)
{
    (void)order;
    const unsigned int pairBits = *reinterpret_cast<const unsigned int*>(&val);
    asm volatile("red.relaxed.gpu.global.add.noftz.f16x2 [%0], %1;"
                 :
                 : "l"(addr), "r"(pairBits)
                 : "memory");
}
#endif
2800
#if SLANG_CUDA_ENABLE_BF16
// bfloat16 reduction ADD (compiled only when bf16 support is enabled).
// The value is handed to PTX as its raw bits: 16 bits ("h") for __nv_bfloat16 and the packed
// 32-bit pair ("r") for __nv_bfloat162. `order` is ignored (always relaxed, GPU scope), and the
// `.global` state space means addr is expected to be a global-memory address.
__device__ __forceinline__ void __slang_atomic_reduce_add(
    __nv_bfloat16* addr,
    __nv_bfloat16 val,
    int order)
{
    (void)order;
    const unsigned short bf16Bits = *reinterpret_cast<const unsigned short*>(&val);
    asm volatile("red.relaxed.gpu.global.add.noftz.bf16 [%0], %1;"
                 :
                 : "l"(addr), "h"(bf16Bits)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_add(
    __nv_bfloat162* addr,
    __nv_bfloat162 val,
    int order)
{
    (void)order;
    const unsigned int pairBits = *reinterpret_cast<const unsigned int*>(&val);
    asm volatile("red.relaxed.gpu.global.add.noftz.bf16x2 [%0], %1;"
                 :
                 : "l"(addr), "r"(pairBits)
                 : "memory");
}
#endif
2826
// Atomic reduction MIN operations
//
// One `red.relaxed.gpu.global.min.<type>` per overload: relaxed ordering at GPU scope (`order`
// is not consulted) on a global-memory address. Signed and unsigned variants use .s/.u types so
// the comparison matches the C++ type.
__device__ __forceinline__ void __slang_atomic_reduce_min(int32_t* addr, int32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.min.s32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_min(uint32_t* addr, uint32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.min.u32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_min(int64_t* addr, int64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.min.s64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_min(uint64_t* addr, uint64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.min.u64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}
2847
2848// NOTE: PTX `red` instruction does NOT support min/max for floating-point types.
2849// Only integer types (.u32, .u64, .s32, .s64) are supported for min/max.
2850// For floating-point min/max atomics, use the regular `atom` instruction via
2851// __atomic_min/__atomic_max.
2852
// Atomic reduction MAX operations
//
// One `red.relaxed.gpu.global.max.<type>` per overload: relaxed ordering at GPU scope (`order`
// is not consulted) on a global-memory address. Signed and unsigned variants use .s/.u types so
// the comparison matches the C++ type.
__device__ __forceinline__ void __slang_atomic_reduce_max(int32_t* addr, int32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.max.s32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_max(uint32_t* addr, uint32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.max.u32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_max(int64_t* addr, int64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.max.s64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_max(uint64_t* addr, uint64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.max.u64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}
2873
2874// NOTE: PTX `red` instruction does NOT support min/max for floating-point types.
2875// Only integer types (.u32, .u64, .s32, .s64) are supported for min/max.
2876// For floating-point min/max atomics, use the regular `atom` instruction via
2877// __atomic_min/__atomic_max.
2878
// Atomic reduction AND operations (bitwise, integers only)
//
// Bitwise ops are sign-agnostic, so signed and unsigned overloads share the .b32/.b64 forms.
// Relaxed ordering at GPU scope (`order` is not consulted); addr is a global-memory address.
__device__ __forceinline__ void __slang_atomic_reduce_and(int32_t* addr, int32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.and.b32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_and(uint32_t* addr, uint32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.and.b32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_and(int64_t* addr, int64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.and.b64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_and(uint64_t* addr, uint64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.and.b64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}
2899
// Atomic reduction OR operations (bitwise, integers only)
//
// Bitwise ops are sign-agnostic, so signed and unsigned overloads share the .b32/.b64 forms.
// Relaxed ordering at GPU scope (`order` is not consulted); addr is a global-memory address.
__device__ __forceinline__ void __slang_atomic_reduce_or(int32_t* addr, int32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.or.b32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_or(uint32_t* addr, uint32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.or.b32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_or(int64_t* addr, int64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.or.b64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_or(uint64_t* addr, uint64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.or.b64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}
2920
// Atomic reduction XOR operations (bitwise, integers only)
//
// Bitwise ops are sign-agnostic, so signed and unsigned overloads share the .b32/.b64 forms.
// Relaxed ordering at GPU scope (`order` is not consulted); addr is a global-memory address.
__device__ __forceinline__ void __slang_atomic_reduce_xor(int32_t* addr, int32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.xor.b32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_xor(uint32_t* addr, uint32_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.xor.b32 [%0], %1;"
                 :
                 : "l"(addr), "r"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_xor(int64_t* addr, int64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.xor.b64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_xor(uint64_t* addr, uint64_t val, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.xor.b64 [%0], %1;"
                 :
                 : "l"(addr), "l"(val)
                 : "memory");
}
2941
// Atomic reduction INC/DEC operations
// PTX's own inc/dec (unsigned 32-bit only) have wrap-at-bound semantics:
//   inc: d = (old >= b) ? 0 : old + 1
//   dec: d = ((old == 0) || (old > b)) ? b : old - 1
// A plain +1 / -1 is wanted here, so these use `red...add` with an immediate of 1 or -1 instead.
// For the unsigned overloads the -1 immediate wraps modulo 2^32, so decrementing 0 gives
// 0xFFFFFFFF. `order` is ignored (always relaxed, GPU scope); addr is a global-memory address.
__device__ __forceinline__ void __slang_atomic_reduce_inc(uint32_t* addr, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.add.u32 [%0], 1;" : : "l"(addr) : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_inc(int32_t* addr, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], 1;" : : "l"(addr) : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_dec(uint32_t* addr, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.add.u32 [%0], -1;" : : "l"(addr) : "memory");
}

__device__ __forceinline__ void __slang_atomic_reduce_dec(int32_t* addr, int order)
{
    (void)order;
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], -1;" : : "l"(addr) : "memory");
}
2966
2967// =====================================================================
2968// End of Atomic Reduction Operations
2969// =====================================================================
2970
// Byte-addressed read/write view over a buffer of 32-bit words (`data`, `sizeInBytes`).
// All `index` arguments are byte offsets into the buffer.
// NOTE(review): in this view the struct's declaration line and one line after several opening
// braces (and two template signatures) are absent; any bounds checking against sizeInBytes is
// not visible here, so none is described.
2971// Missing support for Load with status
2973{
 // Reports the buffer size in bytes, truncated to 32 bits.
2974 SLANG_CUDA_CALL void GetDimensions(uint32_t* outDim) const { *outDim = uint32_t(sizeInBytes); }
2975
 // Returns the 32-bit word containing byte offset `index`. `index >> 2` discards the low two
 // bits, so an offset that is not a multiple of 4 is effectively rounded down.
2976 SLANG_CUDA_CALL uint32_t Load(size_t index) const
2977 {
2979 return data[index >> 2];
2980 }
 // Load2/Load3/Load4 read 2/3/4 consecutive words starting at the word containing `index`.
2981 SLANG_CUDA_CALL uint2 Load2(size_t index) const
2982 {
2984 const size_t dataIdx = index >> 2;
2985 return uint2{data[dataIdx], data[dataIdx + 1]};
2986 }
2987 SLANG_CUDA_CALL uint3 Load3(size_t index) const
2988 {
2990 const size_t dataIdx = index >> 2;
2991 return uint3{data[dataIdx], data[dataIdx + 1], data[dataIdx + 2]};
2992 }
2993 SLANG_CUDA_CALL uint4 Load4(size_t index) const
2994 {
2996 const size_t dataIdx = index >> 2;
2997 return uint4{data[dataIdx], data[dataIdx + 1], data[dataIdx + 2], data[dataIdx + 3]};
2998 }
 // Copies sizeof(T) bytes starting exactly at byte offset `index` (no rounding to a word).
 // The local `data` shadows the member, which is why the buffer is reached via `this->data`.
2999 template<typename T>
3000 SLANG_CUDA_CALL T Load(size_t index) const
3001 {
3003 T data;
3004 memcpy(&data, ((const char*)this->data) + index, sizeof(T));
3005 return data;
3006 }
3007
 // Store/Store2/Store3/Store4 mirror the word loads above: `index >> 2` selects the first word.
 // They are const member functions but write through the member pointer `data`.
3008 SLANG_CUDA_CALL void Store(size_t index, uint32_t v) const
3009 {
3011 data[index >> 2] = v;
3012 }
3013 SLANG_CUDA_CALL void Store2(size_t index, uint2 v) const
3014 {
3016 const size_t dataIdx = index >> 2;
3017 data[dataIdx + 0] = v.x;
3018 data[dataIdx + 1] = v.y;
3019 }
3020 SLANG_CUDA_CALL void Store3(size_t index, uint3 v) const
3021 {
3023 const size_t dataIdx = index >> 2;
3024 data[dataIdx + 0] = v.x;
3025 data[dataIdx + 1] = v.y;
3026 data[dataIdx + 2] = v.z;
3027 }
3028 SLANG_CUDA_CALL void Store4(size_t index, uint4 v) const
3029 {
3031 const size_t dataIdx = index >> 2;
3032 data[dataIdx + 0] = v.x;
3033 data[dataIdx + 1] = v.y;
3034 data[dataIdx + 2] = v.z;
3035 data[dataIdx + 3] = v.w;
3036 }
 // Copies sizeof(T) bytes of `value` to byte offset `index` (no rounding to a word).
3037 template<typename T>
3038 SLANG_CUDA_CALL void Store(size_t index, T const& value) const
3039 {
3041 memcpy((char*)data + index, &value, sizeof(T));
3042 }
3043
 // NOTE(review): the signature line of this template is absent from this view; its body returns
 // a T* pointing at byte offset `index` of the buffer.
3045 template<typename T>
3047 {
3049 return (T*)(((char*)data) + index);
3050 }
 // NOTE(review): the signature and first body line of this template are absent from this view;
 // the visible body fills `rs` so that it views the whole buffer as sizeInBytes / sizeof(T)
 // elements of T.
3051 template<typename T>
3053 {
3055 rs.data = (T*)data;
3056 rs.count = sizeInBytes / sizeof(T);
3057 return rs;
3058 }
3059 uint32_t* data;
3060 size_t sizeInBytes; //< Must be multiple of 4
3061};
3062
3063
3064// ---------------------- Wave --------------------------------------
3065
3066// TODO(JS): It appears that cuda does not have a simple way to get a lane index.
3067//
3068// Another approach could be...
// laneId = ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) &
// SLANG_CUDA_WARP_MASK. If that is really true, another way to do this would be for the code
// generator to add this function with the [numthreads] baked in.
3072//
3073// For now I'll just assume you have a launch that makes the following correct if the kernel uses
3074// WaveGetLaneIndex()
#ifndef SLANG_USE_ASM_LANE_ID
// Lane index computed from the linear thread index within the block, masked to the warp width
// with SLANG_CUDA_WARP_MASK. Assumes the launch makes this linear index line up with the
// hardware lane (define SLANG_USE_ASM_LANE_ID to read the lane register instead).
__forceinline__ __device__ uint32_t _getLaneId()
{
    const uint32_t linearThreadIndex =
        (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
    return linearThreadIndex & SLANG_CUDA_WARP_MASK;
}
#else
// Reads the PTX %laneid special register directly.
// https://stackoverflow.com/questions/44337309/whats-the-most-efficient-way-to-calculate-the-warp-id-lane-id-in-a-1-d-grid#
// This mechanism is not the fastest way to do it, and that is why the other mechanism is the
// default. But the other mechanism relies on a launch that makes the assumption true.
__forceinline__ __device__ uint32_t _getLaneId()
{
    uint32_t laneId;
    asm volatile("mov.u32 %0, %laneid;" : "=r"(laneId));
    return laneId;
}
#endif
3093
// Set of warp lanes as a bit mask: bit i stands for lane i. It is held in a signed int, so lane 31
// occupies the sign bit; bit arithmetic on it is safest after converting to unsigned.
using WarpMask = int;
3095
3096// It appears that the __activemask() cannot always be used because
3097// threads need to be converged.
3098//
3099// For CUDA the article claims mask has to be used carefully
3100// https://devblogs.nvidia.com/using-cuda-warp-level-primitives/
3101// With the Warp intrinsics there is no mask, and it's just the 'active lanes'.
3102// __activemask() though does not require there is convergence, so that doesn't work.
3103//
// '__ballot_sync' produces a convergence.
3105//
3106// From the CUDA docs:
3107// ```For __all_sync, __any_sync, and __ballot_sync, a mask must be passed that specifies the
3108// threads participating in the call. A bit, representing the thread's lane ID, must be set for each
3109// participating thread to ensure they are properly converged before the intrinsic is executed by
3110// the hardware. All active threads named in mask must execute the same intrinsic with the same
3111// mask, or the result is undefined.```
3112//
3113// Currently there isn't a mechanism to correctly get the mask without it being passed through.
3114// Doing so will most likely require some changes to slang code generation to track masks, for now
3115// then we use _getActiveMask.
3116
// Return mask of all the lanes less than the current lane.
// The arithmetic is unsigned: for lane 31, (int(1) << 31) - 1 overflows a signed int (undefined
// behavior), whereas (1u << 31) - 1u is the intended 0x7FFFFFFF. Every result is at most
// 0x7FFFFFFF, so converting back to WarpMask is value-preserving.
__forceinline__ __device__ WarpMask _getLaneLtMask()
{
    return WarpMask((1u << _getLaneId()) - 1u);
}
3122
// TODO(JS):
// THIS IS NOT CORRECT! Determining the appropriate active mask requires appropriate
// mask tracking.
//
// Takes the lanes reported by __activemask() and runs a ballot over them in which every lane
// votes true; the returned bits are the lanes of that set that executed the ballot.
__forceinline__ __device__ WarpMask _getActiveMask()
{
    const unsigned reportedActive = __activemask();
    const unsigned voted = __ballot_sync(reportedActive, true);
    return voted;
}
3130
// Return a mask suitable for the 'MultiPrefix' style functions.
// No adjustment is currently needed, so the incoming mask is passed through unchanged.
__forceinline__ __device__ WarpMask _getMultiPrefixMask(int mask)
{
    const WarpMask prefixMask = mask;
    return prefixMask;
}
3136
// Returns true when at most one bit of `mask` is set, i.e. a single active lane.
// Note: this also returns true if mask is 0, but that's okay, because there must be at least
// one lane active to execute anything.
// The test runs on the unsigned bit pattern: a mask holding only lane 31 is INT_MIN as a
// WarpMask, and `mask - 1` would then be signed overflow (undefined behavior).
__inline__ __device__ bool _waveIsSingleLane(WarpMask mask)
{
    const unsigned bits = unsigned(mask);
    return (bits & (bits - 1u)) == 0u;
}
3143
// Returns the power of 2 size of run of set bits. Returns 0 if not a suitable run.
// Examples:
// 0b00000000'00000000'00000000'11111111 -> 8
// 0b11111111'11111111'11111111'11111111 -> 32
// 0b00000000'00000000'00000000'00011111 -> 0 (since 5 is not a power of 2)
// 0b00000000'00000000'00000000'11110000 -> 0 (since the run of bits does not start at the LSB)
// 0b00000000'00000000'00000000'00100111 -> 0 (since it is not a single contiguous run)
// A mask of 0 also returns 0.
__inline__ __device__ int _waveCalcPow2Offset(WarpMask mask)
{
    // This should be the most common case, so fast path it
    if (mask == SLANG_CUDA_WARP_BITMASK)
    {
        return SLANG_CUDA_WARP_SIZE;
    }
    // Use the unsigned bit pattern so that `bits + 1` cannot overflow a signed int
    // (e.g. mask == 0x7FFFFFFF, lanes 0..30).
    const unsigned bits = unsigned(mask);
    // A run of set bits starting at bit 0 has the form 2^k - 1, so adding one clears all of it.
    if ((bits & (bits + 1u)) == 0u)
    {
        // Run length = index of the highest set bit + 1.
        const int runLength = 32 - __clz(mask);
        // Is it a power of 2 size
        if ((runLength & (runLength - 1)) == 0)
        {
            return runLength;
        }
    }
    return 0;
}
3171
// Returns true only on the lowest-numbered lane in the set reported by __activemask().
//
// The first lane is the lowest set bit of the mask. __ffs gives its 1-based position, so this
// is a single comparison. A shortcut of `(mask & 1) || ...` is not valid: when lane 0 is active
// it would make *every* active lane report itself as the first lane.
__inline__ __device__ bool _waveIsFirstLane()
{
    const WarpMask mask = __activemask();
    const int laneId = int(_getLaneId());
    return (__ffs(mask) - 1) == laneId;
}
3183
// Operation policies used by the wave reduce / prefix helpers. Each policy provides:
//   getInitial(a)   - starting accumulator value for the op
//   doOp(a, b)      - combines two values
//   doInverse(a, b) - (invertible ops only) removes b from a; used to turn an inclusive prefix
//                     result into an exclusive one
template<typename T>
struct WaveOpOr
{
    __inline__ __device__ static T getInitial(T a) { return 0; }
    __inline__ __device__ static T doOp(T a, T b) { return a | b; }
};

template<typename T>
struct WaveOpAnd
{
    __inline__ __device__ static T getInitial(T a) { return ~T(0); }
    __inline__ __device__ static T doOp(T a, T b) { return a & b; }
};

template<typename T>
struct WaveOpXor
{
    __inline__ __device__ static T getInitial(T a) { return 0; }
    __inline__ __device__ static T doOp(T a, T b) { return a ^ b; }
    __inline__ __device__ static T doInverse(T a, T b) { return a ^ b; }
};

template<typename T>
struct WaveOpAdd
{
    __inline__ __device__ static T getInitial(T a) { return 0; }
    __inline__ __device__ static T doOp(T a, T b) { return a + b; }
    __inline__ __device__ static T doInverse(T a, T b) { return a - b; }
};

template<typename T>
struct WaveOpMul
{
    __inline__ __device__ static T getInitial(T a) { return T(1); }
    __inline__ __device__ static T doOp(T a, T b) { return a * b; }
    // Using this inverse for int is probably undesirable - because in general it requires T to
    // have more precision. There is also a performance aspect to it, where divides are generally
    // significantly slower.
    __inline__ __device__ static T doInverse(T a, T b) { return a / b; }
};

template<typename T>
struct WaveOpMax
{
    // Only declared here; it is defined per element type by explicit specialization. It returns
    // `a`, or the identity of max (the type's lowest value) when `exclusive` is true.
    __inline__ __device__ static T getInitial(T a, bool exclusive = false);
    __inline__ __device__ static T doOp(T a, T b) { return a > b ? a : b; }
};

template<typename T>
struct WaveOpMin
{
    // Only declared here; it is defined per element type by explicit specialization. It returns
    // `a`, or the identity of min (the type's largest value) when `exclusive` is true.
    __inline__ __device__ static T getInitial(T a, bool exclusive = false);
    __inline__ __device__ static T doOp(T a, T b) { return a < b ? a : b; }
};
3238
// WaveOpMin<T>::getInitial and WaveOpMax<T>::getInitial are only declared in the op structs;
// the explicit specializations generated here define them per element type.
// getInitial(a, exclusive) returns `a`, or EXCL_VAL when `exclusive` is true: the largest value of
// T for min and the smallest value of T for max (the identity element of each op).
// NOTE(review): some specialization lines are absent from this view of the file.
3239// Compact specializations using macro for getInitial
3240#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL) \
3241 template<> \
3242 __inline__ __device__ T WaveOpMin<T>::getInitial(T a, bool exclusive) \
3243 { \
3244 return exclusive ? (EXCL_VAL) : a; \
3245 }
3246
3247#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL) \
3248 template<> \
3249 __inline__ __device__ T WaveOpMax<T>::getInitial(T a, bool exclusive) \
3250 { \
3251 return exclusive ? (EXCL_VAL) : a; \
3252 }
3253
3254// Min specializations (exclusive identity = max value)
3257SLANG_WAVE_MIN_SPEC(int, 0x7FFFFFFF)
3258SLANG_WAVE_MIN_SPEC(uint, 0xFFFFFFFF)
// NOTE(review): the `char` identities here and in the max list assume char is signed; where char
// is unsigned the min identity would need to be 0xFF and the max identity 0 — confirm.
3259SLANG_WAVE_MIN_SPEC(char, (char)0x7F)
3260SLANG_WAVE_MIN_SPEC(int8_t, (int8_t)0x7F)
3261SLANG_WAVE_MIN_SPEC(uint8_t, (uint8_t)0xFF)
3262SLANG_WAVE_MIN_SPEC(int16_t, (int16_t)0x7FFF)
3263SLANG_WAVE_MIN_SPEC(uint16_t, (uint16_t)0xFFFF)
3264SLANG_WAVE_MIN_SPEC(int64_t, 0x7FFFFFFFFFFFFFFFLL)
3265SLANG_WAVE_MIN_SPEC(uint64_t, 0xFFFFFFFFFFFFFFFFULL)
3266#if SLANG_CUDA_ENABLE_HALF
// NOTE(review): 0x7BFF is the largest finite half (65504), not +infinity (0x7C00).
3267SLANG_WAVE_MIN_SPEC(__half, __ushort_as_half(0x7BFF))
3268#endif
3269
3270// Max specializations (exclusive identity = min value)
3273SLANG_WAVE_MAX_SPEC(int, (int)0x80000000)
3275SLANG_WAVE_MAX_SPEC(char, (char)0x80)
3276SLANG_WAVE_MAX_SPEC(int8_t, (int8_t)0x80)
3277SLANG_WAVE_MAX_SPEC(uint8_t, 0)
3278SLANG_WAVE_MAX_SPEC(int16_t, (int16_t)0x8000)
3279SLANG_WAVE_MAX_SPEC(uint16_t, 0)
3280SLANG_WAVE_MAX_SPEC(int64_t, (int64_t)0x8000000000000000LL)
3281SLANG_WAVE_MAX_SPEC(uint64_t, 0)
3282#if SLANG_CUDA_ENABLE_HALF
// NOTE(review): 0xFBFF is the most negative finite half (-65504), not -infinity (0xFC00).
3283SLANG_WAVE_MAX_SPEC(__half, __ushort_as_half(0xFBFF))
3284#endif
3285
3286#undef SLANG_WAVE_MIN_SPEC
3287#undef SLANG_WAVE_MAX_SPEC
3288
// ElementTypeTrait<T>::Type names the scalar element type of T: T itself for scalar types, and the
// component type for CUDA vector types. The _wave*Multiple helpers use it to treat a value as an
// array of sizeof(T) / sizeof(Type) scalars.
// NOTE(review): in this view the primary-template declaration and the `struct
// ElementTypeTrait<...>` line of several specializations below are absent; each remaining body
// still shows the element type it declares.
3289template<typename T>
3291
3292// Scalar
3293template<>
3295{
3296 typedef int Type;
3297};
3298template<>
3300{
3301 typedef uint Type;
3302};
3303template<>
3304struct ElementTypeTrait<float>
3305{
3306 typedef float Type;
3307};
3308template<>
3309struct ElementTypeTrait<double>
3310{
3311 typedef double Type;
3312};
3313template<>
3314struct ElementTypeTrait<uint64_t>
3315{
3316 typedef uint64_t Type;
3317};
3318template<>
3319struct ElementTypeTrait<int64_t>
3320{
3321 typedef int64_t Type;
3322};
3323template<>
3325{
3326 typedef char Type;
3327};
3328template<>
3330{
3331 typedef uchar Type;
3332};
3333template<>
3334struct ElementTypeTrait<short>
3335{
3336 typedef short Type;
3337};
3338template<>
3340{
3341 typedef ushort Type;
3342};
3343#if SLANG_CUDA_ENABLE_HALF
3344template<>
3345struct ElementTypeTrait<__half>
3346{
3347 typedef __half Type;
3348};
3349#endif
3350
3351// Vector
3352template<>
3354{
3355 typedef int Type;
3356};
3357template<>
3359{
3360 typedef int Type;
3361};
3362template<>
3364{
3365 typedef int Type;
3366};
3367template<>
3369{
3370 typedef int Type;
3371};
3372
3373template<>
3374struct ElementTypeTrait<uint1>
3375{
3376 typedef uint Type;
3377};
3378template<>
3380{
3381 typedef uint Type;
3382};
3383template<>
3385{
3386 typedef uint Type;
3387};
3388template<>
3390{
3391 typedef uint Type;
3392};
3393
3394template<>
3395struct ElementTypeTrait<float1>
3396{
3397 typedef float Type;
3398};
3399template<>
3401{
3402 typedef float Type;
3403};
3404template<>
3406{
3407 typedef float Type;
3408};
3409template<>
3411{
3412 typedef float Type;
3413};
3414
// Element types of the double, 8/16/64-bit integer and (when enabled) half CUDA vector types.
// Each generated specialization maps the vector type to its component type. The helper macro
// is local to this list and is #undef'd at the end.
#define SLANG_PRELUDE_ELEMENT_TRAIT(VECTOR_TYPE, ELEMENT_TYPE) \
    template<>                                                 \
    struct ElementTypeTrait<VECTOR_TYPE>                       \
    {                                                          \
        typedef ELEMENT_TYPE Type;                             \
    };

SLANG_PRELUDE_ELEMENT_TRAIT(double1, double)
SLANG_PRELUDE_ELEMENT_TRAIT(double2, double)
SLANG_PRELUDE_ELEMENT_TRAIT(double3, double)
SLANG_PRELUDE_ELEMENT_TRAIT(double4, double)

// Additional vector types
SLANG_PRELUDE_ELEMENT_TRAIT(char2, char)
SLANG_PRELUDE_ELEMENT_TRAIT(char3, char)
SLANG_PRELUDE_ELEMENT_TRAIT(char4, char)
SLANG_PRELUDE_ELEMENT_TRAIT(uchar2, uchar)
SLANG_PRELUDE_ELEMENT_TRAIT(uchar3, uchar)
SLANG_PRELUDE_ELEMENT_TRAIT(uchar4, uchar)
SLANG_PRELUDE_ELEMENT_TRAIT(short2, short)
SLANG_PRELUDE_ELEMENT_TRAIT(short3, short)
SLANG_PRELUDE_ELEMENT_TRAIT(short4, short)
SLANG_PRELUDE_ELEMENT_TRAIT(ushort2, ushort)
SLANG_PRELUDE_ELEMENT_TRAIT(ushort3, ushort)
SLANG_PRELUDE_ELEMENT_TRAIT(ushort4, ushort)
SLANG_PRELUDE_ELEMENT_TRAIT(longlong2, int64_t)
SLANG_PRELUDE_ELEMENT_TRAIT(longlong3, int64_t)
SLANG_PRELUDE_ELEMENT_TRAIT(longlong4, int64_t)
SLANG_PRELUDE_ELEMENT_TRAIT(ulonglong2, uint64_t)
SLANG_PRELUDE_ELEMENT_TRAIT(ulonglong3, uint64_t)
SLANG_PRELUDE_ELEMENT_TRAIT(ulonglong4, uint64_t)
#if SLANG_CUDA_ENABLE_HALF
SLANG_PRELUDE_ELEMENT_TRAIT(__half2, __half)
SLANG_PRELUDE_ELEMENT_TRAIT(__half3, __half)
SLANG_PRELUDE_ELEMENT_TRAIT(__half4, __half)
#endif

#undef SLANG_PRELUDE_ELEMENT_TRAIT
3544
// Matrix: the element type is the matrix's scalar template argument, whatever its dimensions.
template<typename ElemT, int R, int C>
struct ElementTypeTrait<Matrix<ElemT, R, C>>
{
    typedef ElemT Type;
};
3551
// Scalar
//
// Reduces `val` across the lanes in `mask` with INTF::doOp. Every participating lane receives
// the full reduction.
template<typename INTF, typename T>
__device__ T _waveReduceScalar(WarpMask mask, T val)
{
    const int offsetSize = _waveCalcPow2Offset(mask);
    if (offsetSize > 0)
    {
        // Fast path O(log2(activeLanes)): butterfly exchange with xor shuffles
        for (int offset = offsetSize >> 1; offset > 0; offset >>= 1)
        {
            val = INTF::doOp(val, __shfl_xor_sync(mask, val, offset));
        }
    }
    else if (!_waveIsSingleLane(mask))
    {
        // Slow path: broadcast each active lane's value in turn and fold it in.
        T result = INTF::getInitial(val);
        // Iterate the lane bits as unsigned. With a signed int, isolating bit 31 via
        // `remaining & -remaining` is signed overflow once only lane 31 remains.
        unsigned remaining = unsigned(mask);
        while (remaining)
        {
            // Get the sourceLane: the lowest remaining set bit (__ffs is 1-based)
            const int srcLane = __ffs(int(remaining)) - 1;
            // Broadcast (can also broadcast to self)
            result = INTF::doOp(result, __shfl_sync(mask, val, srcLane));
            remaining &= remaining - 1u; // clear the lowest set bit
        }
        return result;
    }
    return val;
}
3582
3583
// Multiple values
//
// Reduces each of the COUNT values in `val` (in place) across the lanes in `mask` with
// INTF::doOp. Afterwards every participating lane holds the full reductions.
template<typename INTF, typename T, size_t COUNT>
__device__ void _waveReduceMultiple(WarpMask mask, T* val)
{
    const int offsetSize = _waveCalcPow2Offset(mask);
    if (offsetSize > 0)
    {
        // Fast path O(log2(activeLanes))
        for (int offset = offsetSize >> 1; offset > 0; offset >>= 1)
        {
            for (size_t i = 0; i < COUNT; ++i)
            {
                val[i] = INTF::doOp(val[i], __shfl_xor_sync(mask, val[i], offset));
            }
        }
    }
    else if (!_waveIsSingleLane(mask))
    {
        // Copy the original, and reset the accumulators
        T originalVal[COUNT];
        for (size_t i = 0; i < COUNT; ++i)
        {
            const T v = val[i];
            originalVal[i] = v;
            val[i] = INTF::getInitial(v);
        }

        // Iterate the lane bits as unsigned. With a signed int, isolating bit 31 via
        // `remaining & -remaining` is signed overflow once only lane 31 remains.
        unsigned remaining = unsigned(mask);
        while (remaining)
        {
            // Get the sourceLane: the lowest remaining set bit (__ffs is 1-based)
            const int srcLane = __ffs(int(remaining)) - 1;
            // Broadcast (can also broadcast to self)
            for (size_t i = 0; i < COUNT; ++i)
            {
                val[i] = INTF::doOp(val[i], __shfl_sync(mask, originalVal[i], srcLane));
            }
            remaining &= remaining - 1u; // clear the lowest set bit
        }
    }
}
3626
// Vector/matrix overload: views `*val` as an array of its scalar element type (per
// ElementTypeTrait<T>) and reduces every element independently, in place.
template<typename INTF, typename T>
__device__ void _waveReduceMultiple(WarpMask mask, T* val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    Scalar* const elements = reinterpret_cast<Scalar*>(val);
    _waveReduceMultiple<INTF, Scalar, sizeof(T) / sizeof(Scalar)>(mask, elements);
}
3633
// Scalar wave reductions over the lanes in `mask`. Each one forwards to _waveReduceScalar with
// the matching WaveOp policy and returns the reduction to every participating lane.
template<typename T>
__inline__ __device__ T _waveOr(WarpMask mask, T val)
{
    typedef WaveOpOr<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}

template<typename T>
__inline__ __device__ T _waveAnd(WarpMask mask, T val)
{
    typedef WaveOpAnd<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}

template<typename T>
__inline__ __device__ T _waveXor(WarpMask mask, T val)
{
    typedef WaveOpXor<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}

template<typename T>
__inline__ __device__ T _waveProduct(WarpMask mask, T val)
{
    typedef WaveOpMul<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}

template<typename T>
__inline__ __device__ T _waveSum(WarpMask mask, T val)
{
    typedef WaveOpAdd<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}

template<typename T>
__inline__ __device__ T _waveMin(WarpMask mask, T val)
{
    typedef WaveOpMin<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}

template<typename T>
__inline__ __device__ T _waveMax(WarpMask mask, T val)
{
    typedef WaveOpMax<T> Op;
    return _waveReduceScalar<Op, T>(mask, val);
}
3675
// Fast-path specializations when CUDA warp reduce operators are available.
// The __reduce_*_sync intrinsics are used only when compiling for compute capability 8.0+, and
// only for the unsigned/int cases specialized here; all other types keep the generic
// shuffle-based implementation.
#if __CUDA_ARCH__ >= 800 // 8.x or higher
template<>
__inline__ __device__ unsigned _waveOr<unsigned>(WarpMask mask, unsigned val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_or_sync(lanes, val);
}

template<>
__inline__ __device__ unsigned _waveAnd<unsigned>(WarpMask mask, unsigned val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_and_sync(lanes, val);
}

template<>
__inline__ __device__ unsigned _waveXor<unsigned>(WarpMask mask, unsigned val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_xor_sync(lanes, val);
}

template<>
__inline__ __device__ unsigned _waveSum<unsigned>(WarpMask mask, unsigned val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_add_sync(lanes, val);
}

template<>
__inline__ __device__ int _waveSum<int>(WarpMask mask, int val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_add_sync(lanes, val);
}

template<>
__inline__ __device__ unsigned _waveMin<unsigned>(WarpMask mask, unsigned val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_min_sync(lanes, val);
}

template<>
__inline__ __device__ int _waveMin<int>(WarpMask mask, int val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_min_sync(lanes, val);
}

template<>
__inline__ __device__ unsigned _waveMax<unsigned>(WarpMask mask, unsigned val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_max_sync(lanes, val);
}

template<>
__inline__ __device__ int _waveMax<int>(WarpMask mask, int val)
{
    const unsigned lanes = unsigned(mask);
    return __reduce_max_sync(lanes, val);
}
#endif
3732
// Multiple
//
// Element-wise wave reductions for vector/matrix values. Each scalar component of `val` (per
// ElementTypeTrait<T>) is reduced independently across the lanes of `mask`, and the reduced
// value is returned.
template<typename T>
__inline__ __device__ T _waveOrMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpOr<Scalar>>(mask, &reduced);
    return reduced;
}

template<typename T>
__inline__ __device__ T _waveAndMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpAnd<Scalar>>(mask, &reduced);
    return reduced;
}

template<typename T>
__inline__ __device__ T _waveXorMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpXor<Scalar>>(mask, &reduced);
    return reduced;
}

template<typename T>
__inline__ __device__ T _waveProductMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpMul<Scalar>>(mask, &reduced);
    return reduced;
}

template<typename T>
__inline__ __device__ T _waveSumMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpAdd<Scalar>>(mask, &reduced);
    return reduced;
}

template<typename T>
__inline__ __device__ T _waveMinMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpMin<Scalar>>(mask, &reduced);
    return reduced;
}

template<typename T>
__inline__ __device__ T _waveMaxMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Scalar;
    T reduced = val;
    _waveReduceMultiple<WaveOpMax<Scalar>>(mask, &reduced);
    return reduced;
}
3790
3791
// True when every lane in `mask` passes the same `val`; the predicate output of
// __match_all_sync reports this.
template<typename T>
__inline__ __device__ bool _waveAllEqual(WarpMask mask, T val)
{
    int allSame;
    __match_all_sync(mask, val, &allSame);
    return allSame != 0;
}

// Component-wise variant: true only if every scalar component of `inVal` (per ElementTypeTrait<T>)
// matches across the lanes in `mask`. Returns false at the first component that differs.
template<typename T>
__inline__ __device__ bool _waveAllEqualMultiple(WarpMask mask, T inVal)
{
    typedef typename ElementTypeTrait<T>::Type ElemType;
    const ElemType* const components = (const ElemType*)&inVal;
    const size_t componentCount = sizeof(T) / sizeof(ElemType);
    for (size_t i = 0; i != componentCount; ++i)
    {
        int allSame;
        __match_all_sync(mask, components[i], &allSame);
        if (!allSame)
        {
            return false;
        }
    }
    return true;
}
3817
// Broadcasts `val` from the lowest-numbered lane in `mask` to every lane in `mask`.
template<typename T>
__inline__ __device__ T _waveReadFirst(WarpMask mask, T val)
{
    // __ffs is 1-based, hence the -1.
    return __shfl_sync(mask, val, __ffs(mask) - 1);
}

// Component-wise variant of _waveReadFirst: each scalar component of `inVal`
// (per ElementTypeTrait<T>) is broadcast from the lowest-numbered lane in `mask`.
template<typename T>
__inline__ __device__ T _waveReadFirstMultiple(WarpMask mask, T inVal)
{
    typedef typename ElementTypeTrait<T>::Type ElemType;
    const int firstLane = __ffs(mask) - 1;
    const ElemType* src = (const ElemType*)&inVal;
    T outVal;
    ElemType* dst = (ElemType*)&outVal;
    for (size_t i = 0, n = sizeof(T) / sizeof(ElemType); i < n; ++i)
    {
        dst[i] = __shfl_sync(mask, src[i], firstLane);
    }
    return outVal;
}
3840
// Component-wise shuffle: every scalar component of `inVal` (per ElementTypeTrait<T>) is read
// from lane `lane` among the lanes in `mask`.
template<typename T>
__inline__ __device__ T _waveShuffleMultiple(WarpMask mask, T inVal, int lane)
{
    typedef typename ElementTypeTrait<T>::Type ElemType;
    const ElemType* src = (const ElemType*)&inVal;
    T outVal;
    ElemType* dst = (ElemType*)&outVal;
    for (size_t i = 0, n = sizeof(T) / sizeof(ElemType); i < n; ++i)
    {
        dst[i] = __shfl_sync(mask, src[i], lane);
    }
    return outVal;
}
3855
3856// Scalar
3857
// Invertible means that when we get to the end of the reduce, we can remove val (to make the
// result exclusive) using the inverse of the op.
//
// Returns, for the calling lane, INTF::doOp folded over the values of all lanes in `mask` with a
// lower lane index, i.e. an exclusive prefix. The calling lane's own value is not included.
template<typename INTF, typename T>
__device__ T _wavePrefixInvertableScalar(WarpMask mask, T val)
{
    const int offsetSize = _waveCalcPow2Offset(mask);

    const int laneId = _getLaneId();
    T result;
    if (offsetSize > 0)
    {
        // Sum is calculated inclusive of this lane's value
        result = val;
        for (int i = 1; i < offsetSize; i += i)
        {
            const T readVal = __shfl_up_sync(mask, result, i, offsetSize);
            if (laneId >= i)
            {
                result = INTF::doOp(result, readVal);
            }
        }
        // Remove val from the result by applying the inverse
        result = INTF::doInverse(result, val);
    }
    else
    {
        result = INTF::getInitial(val);
        if (!_waveIsSingleLane(mask))
        {
            // Iterate the lane bits as unsigned. With a signed int, isolating bit 31 via
            // `remaining & -remaining` is signed overflow once only lane 31 remains.
            unsigned remaining = unsigned(mask);
            while (remaining)
            {
                // Get the sourceLane: the lowest remaining set bit (__ffs is 1-based)
                const int srcLane = __ffs(int(remaining)) - 1;
                // Broadcast (can also broadcast to self)
                const T readValue = __shfl_sync(mask, val, srcLane);
                // Only accumulate if srcLane is less than this lane
                if (srcLane < laneId)
                {
                    result = INTF::doOp(result, readValue);
                }
                remaining &= remaining - 1u; // clear the lowest set bit
            }
        }
    }
    return result;
}
3906
3907
// Prefix scan that needs no inverse operation. It separately tracks the value to be
// propagated (inclusive, `val`) and the final result (`result`, which starts from
// INTF::getInitial(val) and only accumulates values from lower-numbered lanes).
//
// INTF must provide static doOp(a, b) and getInitial(a).
// All lanes in `mask` must call this function (it uses __shfl_*_sync with `mask`).
template<typename INTF, typename T>
__device__ T _wavePrefixScalar(WarpMask mask, T val)
{
    const int offsetSize = _waveCalcPow2Offset(mask);

    const int laneId = _getLaneId();
    T result = INTF::getInitial(val);
    if (offsetSize > 0)
    {
        // The transmitted value is scanned inclusively with this lane's value; the
        // result is not. This costs an extra op per iteration but means no divide is
        // needed at the end, which also removes overflow issues in that scenario.
        for (int i = 1; i < offsetSize; i += i)
        {
            const T readVal = __shfl_up_sync(mask, val, i, offsetSize);
            if (laneId >= i)
            {
                result = INTF::doOp(result, readVal);
                val = INTF::doOp(val, readVal);
            }
        }
    }
    else
    {
        if (!_waveIsSingleLane(mask))
        {
            // Walk the set bits of mask using an unsigned copy: with a signed int,
            // `-remaining` overflows (undefined behavior) once only bit 31 is left.
            unsigned int remaining = (unsigned int)mask;
            while (remaining)
            {
                const unsigned int laneBit = remaining & (0u - remaining);
                // Get the source lane
                const int srcLane = __ffs((int)laneBit) - 1;
                // Broadcast (can also broadcast to self)
                const T readValue = __shfl_sync(mask, val, srcLane);
                // Only accumulate if srcLane is less than this lane
                if (srcLane < laneId)
                {
                    result = INTF::doOp(result, readValue);
                }
                remaining &= ~laneBit;
            }
        }
    }
    return result;
}
3956
3957
// Copies COUNT elements from src to dst. INTF is unused; it is kept so all
// _waveOp* helpers share the same template signature.
// Returns void: the previous `T` return type had no return statement, and flowing
// off the end of a non-void function is undefined behavior in C++.
template<typename INTF, typename T, size_t COUNT>
__device__ void _waveOpCopy(T* dst, const T* src)
{
    for (size_t j = 0; j < COUNT; ++j)
    {
        dst[j] = src[j];
    }
}
3966
3967
// Element-wise inOut[j] = INTF::doInverse(inOut[j], val[j]) for COUNT elements.
// Returns void: the previous `T` return type had no return statement, and flowing
// off the end of a non-void function is undefined behavior in C++.
template<typename INTF, typename T, size_t COUNT>
__device__ void _waveOpDoInverse(T* inOut, const T* val)
{
    for (size_t j = 0; j < COUNT; ++j)
    {
        inOut[j] = INTF::doInverse(inOut[j], val[j]);
    }
}
3976
// Element-wise out[j] = INTF::getInitial(val[j]) for COUNT elements.
// out and val may alias (callers pass the same array for both).
// Returns void: the previous `T` return type had no return statement, and flowing
// off the end of a non-void function is undefined behavior in C++.
template<typename INTF, typename T, size_t COUNT>
__device__ void _waveOpSetInitial(T* out, const T* val)
{
    for (size_t j = 0; j < COUNT; ++j)
    {
        out[j] = INTF::getInitial(val[j]);
    }
}
3985
// In-place exclusive prefix scan over COUNT independent elements per lane, for
// operations that have an inverse (e.g. add, xor). On return val[j] holds the
// combination of val[j] from all lower-numbered lanes in `mask`.
//
// INTF must provide static doOp, doInverse and getInitial for element type T.
// All lanes in `mask` must call this with the same COUNT, because each element
// is exchanged with its own __shfl_*_sync.
// Returns void: the previous `T` return type had no return statement (undefined
// behavior); callers only use the in-place result.
template<typename INTF, typename T, size_t COUNT>
__device__ void _wavePrefixInvertableMultiple(WarpMask mask, T* val)
{
    const int offsetSize = _waveCalcPow2Offset(mask);

    const int laneId = _getLaneId();
    T originalVal[COUNT];
    _waveOpCopy<INTF, T, COUNT>(originalVal, val);

    if (offsetSize > 0)
    {
        // Inclusive log-step scan (offsets 1, 2, 4, ...) including this lane's values.
        for (int i = 1; i < offsetSize; i += i)
        {
            // TODO(JS): Note that here I don't split the laneId outside so it's only tested once.
            // This may be better but it would also mean that there would be shfl between lanes
            // that are on different (albeit identical) instructions. So this seems more likely to
            // work as expected with everything in lock step.
            for (size_t j = 0; j < COUNT; ++j)
            {
                const T readVal = __shfl_up_sync(mask, val[j], i, offsetSize);
                if (laneId >= i)
                {
                    val[j] = INTF::doOp(val[j], readVal);
                }
            }
        }
        // Remove originalVal from the result by applying the inverse.
        _waveOpDoInverse<INTF, T, COUNT>(val, originalVal);
    }
    else
    {
        _waveOpSetInitial<INTF, T, COUNT>(val, val);
        if (!_waveIsSingleLane(mask))
        {
            // Unsigned so that `0u - remaining` is well defined when only bit 31 is left.
            unsigned int remaining = (unsigned int)mask;
            while (remaining)
            {
                const unsigned int laneBit = remaining & (0u - remaining);
                // Get the source lane
                const int srcLane = __ffs((int)laneBit) - 1;

                for (size_t j = 0; j < COUNT; ++j)
                {
                    // Broadcast (can also broadcast to self)
                    const T readValue = __shfl_sync(mask, originalVal[j], srcLane);
                    // Only accumulate if srcLane is less than this lane
                    if (srcLane < laneId)
                    {
                        val[j] = INTF::doOp(val[j], readValue);
                    }
                }
                // Done with this source lane for every element.
                remaining &= ~laneBit;
            }
        }
    }
}
4043
// In-place prefix scan over COUNT independent elements per lane that needs no
// inverse. `work` carries the inclusive value passed between lanes. `val` becomes
// the result: it starts from INTF::getInitial(val[j]) and accumulates only values
// from lower-numbered lanes in `mask`.
//
// INTF must provide static doOp and getInitial for element type T.
// All lanes in `mask` must call this with the same COUNT.
// Returns void: the previous `T` return type had no return statement (undefined
// behavior); callers only use the in-place result.
template<typename INTF, typename T, size_t COUNT>
__device__ void _wavePrefixMultiple(WarpMask mask, T* val)
{
    const int offsetSize = _waveCalcPow2Offset(mask);

    const int laneId = _getLaneId();

    T work[COUNT];
    _waveOpCopy<INTF, T, COUNT>(work, val);
    _waveOpSetInitial<INTF, T, COUNT>(val, val);

    if (offsetSize > 0)
    {
        // The transmitted value is scanned inclusively with this lane's value; the
        // result is not. This costs an extra op per iteration but means no divide is
        // needed at the end, which also removes overflow issues in that scenario.
        for (int i = 1; i < offsetSize; i += i)
        {
            for (size_t j = 0; j < COUNT; ++j)
            {
                const T readVal = __shfl_up_sync(mask, work[j], i, offsetSize);
                if (laneId >= i)
                {
                    work[j] = INTF::doOp(work[j], readVal);
                    val[j] = INTF::doOp(val[j], readVal);
                }
            }
        }
    }
    else
    {
        if (!_waveIsSingleLane(mask))
        {
            // Unsigned so that `0u - remaining` is well defined when only bit 31 is left.
            unsigned int remaining = (unsigned int)mask;
            while (remaining)
            {
                const unsigned int laneBit = remaining & (0u - remaining);
                // Get the source lane
                const int srcLane = __ffs((int)laneBit) - 1;

                for (size_t j = 0; j < COUNT; ++j)
                {
                    // Broadcast (can also broadcast to self)
                    const T readValue = __shfl_sync(mask, work[j], srcLane);
                    // Only accumulate if srcLane is less than this lane
                    if (srcLane < laneId)
                    {
                        val[j] = INTF::doOp(val[j], readValue);
                    }
                }
                remaining &= ~laneBit;
            }
        }
    }
}
4100
// Prefix product over the lanes in `mask`. It uses the inverse-free scan, so no
// division is needed to remove this lane's own value.
template<typename T>
__inline__ __device__ T _wavePrefixProduct(WarpMask mask, T val)
{
    typedef WaveOpMul<T> MulOp;
    const T product = _wavePrefixScalar<MulOp, T>(mask, val);
    return product;
}
4106
// Exclusive prefix sum over the lanes in `mask`. Addition is invertible, so the
// scan subtracts this lane's own value at the end.
template<typename T>
__inline__ __device__ T _wavePrefixSum(WarpMask mask, T val)
{
    typedef WaveOpAdd<T> AddOp;
    const T sum = _wavePrefixInvertableScalar<AddOp, T>(mask, val);
    return sum;
}
4112
// Exclusive prefix xor over the lanes in `mask` (xor is its own inverse).
template<typename T>
__inline__ __device__ T _wavePrefixXor(WarpMask mask, T val)
{
    typedef WaveOpXor<T> XorOp;
    const T folded = _wavePrefixInvertableScalar<XorOp, T>(mask, val);
    return folded;
}
4118
// Prefix bitwise-or over the lanes in `mask` (or has no inverse, so the
// inverse-free scan is used).
template<typename T>
__inline__ __device__ T _wavePrefixOr(WarpMask mask, T val)
{
    typedef WaveOpOr<T> OrOp;
    const T folded = _wavePrefixScalar<OrOp, T>(mask, val);
    return folded;
}
4124
// Prefix bitwise-and over the lanes in `mask` (and has no inverse, so the
// inverse-free scan is used).
template<typename T>
__inline__ __device__ T _wavePrefixAnd(WarpMask mask, T val)
{
    typedef WaveOpAnd<T> AndOp;
    const T folded = _wavePrefixScalar<AndOp, T>(mask, val);
    return folded;
}
4130
4131
// Element-wise prefix product of a composite value (vector/matrix) across the
// lanes in `mask`.
//
// This uses the inverse-free scan, matching the scalar _wavePrefixProduct. The
// invertible scan would remove each lane's own value by dividing, which divides
// by zero whenever a lane holds 0 and loses precision / overflows for integers.
template<typename T>
__inline__ __device__ T _wavePrefixProductMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type ElemType;
    _wavePrefixMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>(
        mask,
        (ElemType*)&val);
    return val;
}
4141
// Element-wise exclusive prefix sum of a composite value across the lanes in `mask`.
// Each scalar element is scanned independently, in place on the local copy.
template<typename T>
__inline__ __device__ T _wavePrefixSumMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixInvertableMultiple<WaveOpAdd<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4151
// Element-wise exclusive prefix xor of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixXorMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixInvertableMultiple<WaveOpXor<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4161
// Element-wise prefix bitwise-or of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixOrMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixMultiple<WaveOpOr<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4171
// Element-wise prefix bitwise-and of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixAndMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixMultiple<WaveOpAnd<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4181
// Prefix minimum over the lanes in `mask` via the inverse-free scan. The file's
// inclusive wrappers (_wavePrefixInclusiveMin) forward to this function.
template<typename T>
__inline__ __device__ T _wavePrefixMin(WarpMask mask, T val)
{
    typedef WaveOpMin<T> MinOp;
    const T lowest = _wavePrefixScalar<MinOp, T>(mask, val);
    return lowest;
}
4187
// Prefix maximum over the lanes in `mask` via the inverse-free scan. The file's
// inclusive wrappers (_wavePrefixInclusiveMax) forward to this function.
template<typename T>
__inline__ __device__ T _wavePrefixMax(WarpMask mask, T val)
{
    typedef WaveOpMax<T> MaxOp;
    const T highest = _wavePrefixScalar<MaxOp, T>(mask, val);
    return highest;
}
4193
// Element-wise prefix minimum of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixMinMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixMultiple<WaveOpMin<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4203
// Element-wise prefix maximum of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixMaxMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixMultiple<WaveOpMax<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4213
// Wrapper structures for exclusive operations that use the overloaded getInitial method.
// WaveOpExclusiveMin forwards doOp to WaveOpMin but starts the scan from
// WaveOpMin::getInitial(a, true) — presumably the identity for min, so the lane's
// own value is not folded in (confirm against WaveOpMin's definition).
// (The `struct` line was missing here, leaving the body without a declarator.)
template<typename T>
struct WaveOpExclusiveMin
{
    __inline__ __device__ static T getInitial(T a) { return WaveOpMin<T>::getInitial(a, true); }
    __inline__ __device__ static T doOp(T a, T b) { return WaveOpMin<T>::doOp(a, b); }
};
4221
// Exclusive counterpart of WaveOpMax: forwards doOp to WaveOpMax but starts the scan
// from WaveOpMax::getInitial(a, true) — presumably the identity for max (confirm).
// (The `struct` line was missing here, leaving the body without a declarator.)
template<typename T>
struct WaveOpExclusiveMax
{
    __inline__ __device__ static T getInitial(T a) { return WaveOpMax<T>::getInitial(a, true); }
    __inline__ __device__ static T doOp(T a, T b) { return WaveOpMax<T>::doOp(a, b); }
};
4228
// Inclusive prefix min/max functions (for WaveMultiPrefixInclusive*).
// Inclusive prefix minimum: forwards to _wavePrefixMin.
template<typename T>
__inline__ __device__ T _wavePrefixInclusiveMin(WarpMask mask, T val)
{
    const T lowest = _wavePrefixMin<T>(mask, val);
    return lowest;
}
4235
// Inclusive prefix maximum: forwards to _wavePrefixMax.
template<typename T>
__inline__ __device__ T _wavePrefixInclusiveMax(WarpMask mask, T val)
{
    const T highest = _wavePrefixMax<T>(mask, val);
    return highest;
}
4241
// Element-wise inclusive prefix minimum: forwards to _wavePrefixMinMultiple.
template<typename T>
__inline__ __device__ T _wavePrefixInclusiveMinMultiple(WarpMask mask, T val)
{
    const T lowest = _wavePrefixMinMultiple<T>(mask, val);
    return lowest;
}
4247
// Element-wise inclusive prefix maximum: forwards to _wavePrefixMaxMultiple.
template<typename T>
__inline__ __device__ T _wavePrefixInclusiveMaxMultiple(WarpMask mask, T val)
{
    const T highest = _wavePrefixMaxMultiple<T>(mask, val);
    return highest;
}
4253
// Explicit exclusive prefix min/max functions (for WaveMultiPrefixExclusive*).
// Exclusive prefix minimum: the inverse-free scan seeded by WaveOpExclusiveMin's
// exclusive initial value.
template<typename T>
__inline__ __device__ T _wavePrefixExclusiveMin(WarpMask mask, T val)
{
    typedef WaveOpExclusiveMin<T> Op;
    const T lowest = _wavePrefixScalar<Op, T>(mask, val);
    return lowest;
}
4260
// Exclusive prefix maximum: the inverse-free scan seeded by WaveOpExclusiveMax's
// exclusive initial value.
template<typename T>
__inline__ __device__ T _wavePrefixExclusiveMax(WarpMask mask, T val)
{
    typedef WaveOpExclusiveMax<T> Op;
    const T highest = _wavePrefixScalar<Op, T>(mask, val);
    return highest;
}
4266
// Element-wise exclusive prefix minimum of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixExclusiveMinMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixMultiple<WaveOpExclusiveMin<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4276
// Element-wise exclusive prefix maximum of a composite value across the lanes in `mask`.
template<typename T>
__inline__ __device__ T _wavePrefixExclusiveMaxMultiple(WarpMask mask, T val)
{
    typedef typename ElementTypeTrait<T>::Type Elem;
    constexpr size_t kElemCount = sizeof(T) / sizeof(Elem);
    Elem* elems = reinterpret_cast<Elem*>(&val);
    _wavePrefixMultiple<WaveOpExclusiveMax<Elem>, Elem, kElemCount>(mask, elems);
    return val;
}
4286
// WaveMatch for a scalar. Returns in .x the bitmask of lanes in `mask` whose `val`
// equals this lane's `val`. The other components are 0 because a warp has at most
// 32 lanes.
//
// __match_any_sync computes exactly that per-lane set. __match_all_sync would be
// wrong here: it returns 0 unless every lane in `mask` holds the same value, yet a
// lane always matches itself. Requires SM70+.
template<typename T>
__inline__ __device__ uint4 _waveMatchScalar(WarpMask mask, T val)
{
    return make_uint4(__match_any_sync(mask, val), 0, 0, 0);
}
4293
// WaveMatch for a composite value. Returns in .x the bitmask of lanes in `mask`
// whose value equals this lane's value in every scalar element. Requires SM70+.
template<typename T>
__inline__ __device__ uint4 _waveMatchMultiple(WarpMask mask, const T& inVal)
{
    typedef typename ElementTypeTrait<T>::Type ElemType;
    const size_t count = sizeof(T) / sizeof(ElemType);
    const ElemType* src = (const ElemType*)&inVal;

    // A lane matches only if every element matches, so intersect the per-element
    // __match_any_sync masks. There is deliberately no lane-dependent early exit:
    // every lane in `mask` must issue the same sequence of *_sync calls.
    uint matchBits = (uint)mask;
    for (size_t i = 0; i < count; ++i)
    {
        matchBits &= __match_any_sync(mask, src[i]);
    }
    return make_uint4(matchBits, 0, 0, 0);
}
4308
// Returns component `b` of a dim3 (0 = x, 1 = y, 2 = z) by indexing from &a.x.
__inline__ __device__ uint getAt(dim3 a, int b)
{
    SLANG_PRELUDE_ASSERT(b >= 0 && b < 3);
    const uint* components = &a.x;
    return components[b];
}
// Component-wise product of a uint3 and a dim3.
__inline__ __device__ uint3 operator*(uint3 a, dim3 b)
{
    uint3 product = {a.x * b.x, a.y * b.y, a.z * b.z};
    return product;
}
4322
// Reinterprets the bytes of `val` as a TResult.
//
// Uses memcpy rather than dereferencing a cast pointer. The pointer cast violated
// strict aliasing, and when TResult was larger than TInput it read past `val`.
// At most sizeof(TInput) bytes are copied; if TResult is larger, its trailing bytes
// are left unset (callers presumably always pass equal-sized types — confirm).
template<typename TResult, typename TInput>
__inline__ __device__ TResult slang_bit_cast(TInput val)
{
    TResult result;
    memcpy(&result, &val, sizeof(TResult) < sizeof(TInput) ? sizeof(TResult) : sizeof(TInput));
    return result;
}
4328
4329/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
4330
4331
/* Opaque forward declarations for the uniform state and the uniform entry point
parameters. Their actual layout depends on the shader's entry point parameters; it can
be obtained via reflection, or the types can be defined so that they match the shader.
*/
struct UniformState;
struct UniformEntryPointParams;
4338
4339// ---------------------- OptiX Ray Payload --------------------------------------
4340#ifdef SLANG_CUDA_ENABLE_OPTIX
4341
// Ray description passed to the optixTrace/optixTraverse wrappers below:
// origin, minimum t, direction, maximum t (same field order as HLSL's RayDesc).
// The `Direction` member had been lost even though the wrappers read Ray.Direction.
struct RayDesc
{
    float3 Origin;
    float TMin;
    float3 Direction;
    float TMax;
};
4349
// Rebuilds a 64-bit pointer from two 32-bit payload registers: i0 supplies the high
// 32 bits and i1 the low 32 bits (the layout packOptiXRayPayloadPointer produces).
static __forceinline__ __device__ void* unpackOptiXRayPayloadPointer(uint32_t i0, uint32_t i1)
{
    const uint64_t high = static_cast<uint64_t>(i0) << 32;
    const uint64_t low = static_cast<uint64_t>(i1);
    return reinterpret_cast<void*>(high | low);
}
4356
// Splits a pointer into two 32-bit payload registers: i0 receives the high 32 bits
// and i1 the low 32 bits.
static __forceinline__ __device__ void packOptiXRayPayloadPointer(
    void* ptr,
    uint32_t& i0,
    uint32_t& i1)
{
    const uint64_t bits = reinterpret_cast<uint64_t>(ptr);
    i0 = static_cast<uint32_t>(bits >> 32);
    i1 = static_cast<uint32_t>(bits & 0xffffffffull);
}
4366
// Recovers the payload pointer that the pointer-based path stored in payload
// registers 0 (high word) and 1 (low word).
static __forceinline__ __device__ void* getOptiXRayPayloadPtr()
{
    const uint32_t highWord = optixGetPayload_0();
    const uint32_t lowWord = optixGetPayload_1();
    void* payload = unpackOptiXRayPayloadPointer(highWord, lowWord);
    return payload;
}
4373
// Largest payload, in 32-bit registers, that is passed to OptiX through payload
// registers: 128 bytes / 4 bytes per register = 32. Larger payloads go by pointer.
static constexpr size_t kMaxOptiXPayloadRegisters = 128 / sizeof(uint32_t);
4376
// Register image of a payload of type T: N 32-bit words (by default sizeof(T)
// rounded up to whole words). pack() copies the payload bytes in and unpack()
// copies them back out; exactly sizeof(T) bytes are copied each way.
template<typename T, size_t N = (sizeof(T) + 3) / 4>
struct PayloadRegisters
{
    // Never zero-sized, even if N is 0.
    static constexpr size_t kStorageCount = (N > 0) ? N : 1;
    uint32_t regs[kStorageCount];

    __forceinline__ __device__ void pack(const T& payload)
    {
        memcpy(&regs[0], &payload, sizeof(T));
    }

    __forceinline__ __device__ void unpack(T& payload)
    {
        memcpy(&payload, &regs[0], sizeof(T));
    }
};
4387
// Internal helper to call optixTrace with the right number of payload register arguments.
//
// optixTrace takes its payload as individual 32-bit lvalue arguments, so the register
// count must be fixed at compile time; this dispatches on N with `if constexpr`.
// optixTrace updates the registers in place, so pr.regs holds the updated payload
// afterwards.
//
// Payloads of 9..16 registers are passed as exactly 16 registers, and 17..32 as
// exactly 32. pr.regs only has N elements, so those cases go through a zero-filled
// local array of the padded size. Indexing pr.regs directly past N would read and
// write out of bounds.
template<typename T, size_t N = (sizeof(T) + 3) / 4>
__forceinline__ __device__ void optixTraceWithRegs(
    OptixTraversableHandle AccelerationStructure,
    float3 Origin,
    float3 Direction,
    float TMin,
    float TMax,
    float Time,
    uint32_t InstanceInclusionMask,
    uint32_t RayFlags,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    PayloadRegisters<T, N>& pr)
{
    // Larger payloads must use the pointer-based path of the optixTrace wrapper.
    // Without this check such an N would match no branch and silently trace nothing.
    static_assert(
        N <= kMaxOptiXPayloadRegisters,
        "optixTraceWithRegs: payload exceeds kMaxOptiXPayloadRegisters");

    if constexpr (N == 0)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex);
    }
    else if constexpr (N == 1)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0]);
    }
    else if constexpr (N == 2)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1]);
    }
    else if constexpr (N == 3)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2]);
    }
    else if constexpr (N == 4)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3]);
    }
    else if constexpr (N == 5)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4]);
    }
    else if constexpr (N == 6)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5]);
    }
    else if constexpr (N == 7)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5],
            pr.regs[6]);
    }
    else if constexpr (N == 8)
    {
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5],
            pr.regs[6], pr.regs[7]);
    }
    else if constexpr (N <= 16)
    {
        // Stage through a 16-entry array so registers N..15 stay in bounds.
        uint32_t r[16] = {};
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            r[k] = pr.regs[k];
        }
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7],
            r[8], r[9], r[10], r[11], r[12], r[13], r[14], r[15]);
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            pr.regs[k] = r[k];
        }
    }
    else
    {
        // 17..kMaxOptiXPayloadRegisters (32): stage through a 32-entry array.
        uint32_t r[32] = {};
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            r[k] = pr.regs[k];
        }
        optixTrace(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7],
            r[8], r[9], r[10], r[11], r[12], r[13], r[14], r[15],
            r[16], r[17], r[18], r[19], r[20], r[21], r[22], r[23],
            r[24], r[25], r[26], r[27], r[28], r[29], r[30], r[31]);
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            pr.regs[k] = r[k];
        }
    }
}
4655
// HLSL-style TraceRay entry point. Payloads of up to kMaxOptiXPayloadRegisters
// 32-bit words travel in OptiX payload registers and are copied back afterwards.
// Larger payloads are passed by address, split across payload registers 0 and 1.
template<typename T>
__forceinline__ __device__ void optixTrace(
    OptixTraversableHandle AccelerationStructure,
    uint32_t RayFlags,
    uint32_t InstanceInclusionMask,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    RayDesc Ray,
    T* Payload)
{
    constexpr size_t kRegCount = (sizeof(T) + 3) / 4;
    constexpr float kRayTime = 0.f; // no motion blur

    if constexpr (kRegCount > kMaxOptiXPayloadRegisters)
    {
        // Pointer-based fallback for large payloads.
        uint32_t ptrHigh, ptrLow;
        packOptiXRayPayloadPointer((void*)Payload, ptrHigh, ptrLow);
        optixTrace(
            AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, kRayTime,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            ptrHigh, ptrLow);
    }
    else
    {
        // Register-based path: copy in, trace (registers updated in place), copy out.
        PayloadRegisters<T> regs;
        regs.pack(*Payload);
        optixTraceWithRegs<T>(
            AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, kRayTime,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            regs);
        regs.unpack(*Payload);
    }
}
4714
// Non-template overload for the empty-payload case.
// When Slang's type legalization eliminates an empty payload struct, the generated
// code calls optixTrace without a payload argument; this traces with no payload
// registers and a ray time of 0.
__forceinline__ __device__ void optixTrace(
    OptixTraversableHandle AccelerationStructure,
    uint32_t RayFlags,
    uint32_t InstanceInclusionMask,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    RayDesc Ray)
{
    const float rayTime = 0.f;
    optixTrace(
        AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, rayTime,
        InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
        MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex);
}
4740
#if (OPTIX_VERSION >= 90000)
// Sphere / linear-swept-sphere (LSS) query helpers. They require OptiX 9.0+ and use
// the overloads of optixGetSphereData / optixGetLinearCurveVertexData (and their
// optixHitObject* counterparts) that take only the output array.
//
// NOTE(review): the `Obj` parameter of the optixHitObject* variants is unused; the
// OptiX hit-object queries act on the current hit object. Presumably it is kept so
// the signature matches Slang's generated calls — confirm.

// Sphere data of the current hit as one float4 (OptiX packs center in xyz, radius in w).
__forceinline__ __device__ float4 optixGetSpherePositionAndRadius()
{
    float4 data[1];
    optixGetSphereData(data);
    return data[0];
}

// Same as optixGetSpherePositionAndRadius, for the current hit object.
__forceinline__ __device__ float4
optixHitObjectGetSpherePositionAndRadius(OptixTraversableHandle* Obj)
{
    float4 data[1];
    optixHitObjectGetSphereData(data);
    return data[0];
}

// The two LSS vertices (position and radius each) of the current hit, as the rows
// of a 2x4 matrix.
__forceinline__ __device__ Matrix<float, 2, 4> optixGetLssPositionsAndRadii()
{
    float4 data[2];
    optixGetLinearCurveVertexData(data);
    return makeMatrix<float, 2, 4>(data[0], data[1]);
}

// Same as optixGetLssPositionsAndRadii, for the current hit object.
// (Its parameter line had been lost, leaving an unterminated declaration.)
__forceinline__ __device__ Matrix<float, 2, 4> optixHitObjectGetLssPositionsAndRadii(
    OptixTraversableHandle* Obj)
{
    float4 data[2];
    optixHitObjectGetLinearCurveVertexData(data);
    return makeMatrix<float, 2, 4>(data[0], data[1]);
}

// True when the current hit is on a sphere primitive.
__forceinline__ __device__ bool optixIsSphereHit()
{
    return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_SPHERE;
}

// True when the hit object's hit kind is a sphere primitive.
__forceinline__ __device__ bool optixHitObjectIsSphereHit(OptixTraversableHandle* Obj)
{
    return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_SPHERE;
}

// True when the current hit is on a round linear (LSS) primitive.
__forceinline__ __device__ bool optixIsLSSHit()
{
    return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
}

// True when the hit object's hit kind is a round linear (LSS) primitive.
__forceinline__ __device__ bool optixHitObjectIsLSSHit(OptixTraversableHandle* Obj)
{
    return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
}
#endif
4806
// Internal helper to call optixTraverse with the right number of payload register arguments.
//
// optixTraverse takes its payload as individual 32-bit lvalue arguments, so the
// register count must be fixed at compile time; this dispatches on N with
// `if constexpr`. The registers are passed by reference, so pr.regs holds any
// updated payload afterwards.
//
// Payloads of 9..16 registers are passed as exactly 16 registers, and 17..32 as
// exactly 32. pr.regs only has N elements, so those cases go through a zero-filled
// local array of the padded size. Indexing pr.regs directly past N would read and
// write out of bounds.
template<typename T, size_t N = (sizeof(T) + 3) / 4>
__forceinline__ __device__ void optixTraverseWithRegs(
    OptixTraversableHandle AccelerationStructure,
    float3 Origin,
    float3 Direction,
    float TMin,
    float TMax,
    float Time,
    uint32_t InstanceInclusionMask,
    uint32_t RayFlags,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    PayloadRegisters<T, N>& pr)
{
    // Larger payloads must use the pointer-based path of the optixTraverse wrappers.
    // Without this check such an N would match no branch and silently do nothing.
    static_assert(
        N <= kMaxOptiXPayloadRegisters,
        "optixTraverseWithRegs: payload exceeds kMaxOptiXPayloadRegisters");

    if constexpr (N == 0)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex);
    }
    else if constexpr (N == 1)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0]);
    }
    else if constexpr (N == 2)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1]);
    }
    else if constexpr (N == 3)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2]);
    }
    else if constexpr (N == 4)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3]);
    }
    else if constexpr (N == 5)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4]);
    }
    else if constexpr (N == 6)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5]);
    }
    else if constexpr (N == 7)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5],
            pr.regs[6]);
    }
    else if constexpr (N == 8)
    {
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5],
            pr.regs[6], pr.regs[7]);
    }
    else if constexpr (N <= 16)
    {
        // Stage through a 16-entry array so registers N..15 stay in bounds.
        uint32_t r[16] = {};
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            r[k] = pr.regs[k];
        }
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7],
            r[8], r[9], r[10], r[11], r[12], r[13], r[14], r[15]);
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            pr.regs[k] = r[k];
        }
    }
    else
    {
        // 17..kMaxOptiXPayloadRegisters (32): stage through a 32-entry array.
        uint32_t r[32] = {};
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            r[k] = pr.regs[k];
        }
        optixTraverse(
            AccelerationStructure, Origin, Direction, TMin, TMax, Time,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            r[0], r[1], r[2], r[3], r[4], r[5], r[6], r[7],
            r[8], r[9], r[10], r[11], r[12], r[13], r[14], r[15],
            r[16], r[17], r[18], r[19], r[20], r[21], r[22], r[23],
            r[24], r[25], r[26], r[27], r[28], r[29], r[30], r[31]);
#pragma unroll
        for (size_t k = 0; k < N; ++k)
        {
            pr.regs[k] = r[k];
        }
    }
}
5074
// Traverses the acceleration structure without invoking hit programs (hit-object
// variant, ray time 0). Payloads of up to kMaxOptiXPayloadRegisters 32-bit words
// travel in payload registers and are copied back afterwards. Larger payloads are
// passed by address, split across payload registers 0 and 1.
// NOTE(review): `hitObj` is not read here; the OptiX hit object is implicit.
template<typename T>
__forceinline__ __device__ void optixTraverse(
    OptixTraversableHandle AccelerationStructure,
    uint32_t RayFlags,
    uint32_t InstanceInclusionMask,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    RayDesc Ray,
    T* Payload,
    OptixTraversableHandle* hitObj)
{
    constexpr size_t kRegCount = (sizeof(T) + 3) / 4;
    constexpr float kRayTime = 0.f; // no motion blur

    if constexpr (kRegCount > kMaxOptiXPayloadRegisters)
    {
        // Pointer-based fallback for large payloads.
        uint32_t ptrHigh, ptrLow;
        packOptiXRayPayloadPointer((void*)Payload, ptrHigh, ptrLow);
        optixTraverse(
            AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, kRayTime,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            ptrHigh, ptrLow);
    }
    else
    {
        // Register-based path: copy in, traverse (registers updated in place), copy out.
        PayloadRegisters<T> regs;
        regs.pack(*Payload);
        optixTraverseWithRegs<T>(
            AccelerationStructure, Ray.Origin, Ray.Direction, Ray.TMin, Ray.TMax, kRayTime,
            InstanceInclusionMask, RayFlags, RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex, MissShaderIndex,
            regs);
        regs.unpack(*Payload);
    }
}
5134
// Trace `Ray` with the native OptiX optixTraverse at motion-blur time `RayTime`.
//
// How the payload travels depends on its size rounded up to 32-bit words:
//   * at most kMaxOptiXPayloadRegisters words: *Payload is packed into PayloadRegisters<T>,
//     passed as individual payload registers, and unpacked back into *Payload after the call;
//   * larger: only the address of *Payload is passed, split over two payload registers by
//     packOptiXRayPayloadPointer; nothing is copied back afterwards.
//
// `hitObj` is not read here. NOTE(review): presumably it only exists so the emitted call matches
// the HitObject-style signature; confirm against the Slang CUDA emitter.
template<typename T>
__forceinline__ __device__ void optixTraverse(
    OptixTraversableHandle AccelerationStructure,
    uint32_t RayFlags,
    uint32_t InstanceInclusionMask,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    RayDesc Ray,
    float RayTime,
    T* Payload,
    OptixTraversableHandle* hitObj)
{
    constexpr size_t payloadWords = (sizeof(T) + sizeof(uint32_t) - 1) / sizeof(uint32_t);

    if constexpr (payloadWords > kMaxOptiXPayloadRegisters)
    {
        // Too large for registers: hand OptiX a pointer to the caller's payload.
        uint32_t ptrWord0;
        uint32_t ptrWord1;
        packOptiXRayPayloadPointer((void*)Payload, ptrWord0, ptrWord1);
        optixTraverse(
            AccelerationStructure,
            Ray.Origin,
            Ray.Direction,
            Ray.TMin,
            Ray.TMax,
            RayTime,
            InstanceInclusionMask,
            RayFlags,
            RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex,
            MissShaderIndex,
            ptrWord0,
            ptrWord1);
    }
    else
    {
        // Fits in registers: pack, trace, then copy the (possibly updated) registers back.
        PayloadRegisters<T> payloadRegs;
        payloadRegs.pack(*Payload);
        optixTraverseWithRegs<T>(
            AccelerationStructure,
            Ray.Origin,
            Ray.Direction,
            Ray.TMin,
            Ray.TMax,
            RayTime,
            InstanceInclusionMask,
            RayFlags,
            RayContributionToHitGroupIndex,
            MultiplierForGeometryContributionToHitGroupIndex,
            MissShaderIndex,
            payloadRegs);
        payloadRegs.unpack(*Payload);
    }
}
5195
// Payload-less optixTraverse overload.
//
// When Slang's type legalization eliminates an empty payload struct, the generated code calls
// optixTraverse with no payload argument. This non-template overload handles that call by
// forwarding to the native OptiX optixTraverse with no payload registers and rayTime = 0.
// `hitObj` is not read.
__forceinline__ __device__ void optixTraverse(
    OptixTraversableHandle AccelerationStructure,
    uint32_t RayFlags,
    uint32_t InstanceInclusionMask,
    uint32_t RayContributionToHitGroupIndex,
    uint32_t MultiplierForGeometryContributionToHitGroupIndex,
    uint32_t MissShaderIndex,
    RayDesc Ray,
    OptixTraversableHandle* hitObj)
{
    const float rayTime = 0.f; // no motion blur
    optixTraverse(
        AccelerationStructure,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        Ray.TMax,
        rayTime,
        InstanceInclusionMask,
        RayFlags,
        RayContributionToHitGroupIndex,
        MultiplierForGeometryContributionToHitGroupIndex,
        MissShaderIndex);
}
5222
#if (OPTIX_VERSION >= 80100)
// HitObject classification queries.
// Each wrapper ignores `hitObj` and forwards to the argument-less OptiX intrinsic of the
// matching name. That intrinsic presumably queries the hit object OptiX tracks implicitly
// for the calling thread.
static __device__ __forceinline__ bool slangOptixHitObjectIsHit(OptixTraversableHandle* hitObj)
{
    return optixHitObjectIsHit();
}

static __device__ __forceinline__ bool slangOptixHitObjectIsMiss(OptixTraversableHandle* hitObj)
{
    return optixHitObjectIsMiss();
}

static __device__ __forceinline__ bool slangOptixHitObjectIsNop(OptixTraversableHandle* hitObj)
{
    return optixHitObjectIsNop();
}
#endif

#if (OPTIX_VERSION >= 90000)
// Cluster ID stored in the current hit object (only compiled for OptiX 9.0+).
// `hitObj` is not read.
static __device__ __forceinline__ uint
slangOptixHitObjectGetClusterId(OptixTraversableHandle* hitObj)
{
    return optixHitObjectGetClusterId();
}
#endif
5251
#if (OPTIX_VERSION >= 80100)
// Build a miss hit object for `Ray` with miss SBT index `MissShaderIndex` at time `CurrentTime`.
// The native call has a trailing rayFlags parameter only from OptiX 9.0 on; it is passed as
// OPTIX_RAY_FLAG_NONE there. `missObj` is not read.
static __forceinline__ __device__ void optixMakeMissHitObject(
    uint MissShaderIndex,
    RayDesc Ray,
    float CurrentTime,
    OptixTraversableHandle* missObj)
{
#if (OPTIX_VERSION >= 90000)
    optixMakeMissHitObject(
        MissShaderIndex,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        Ray.TMax,
        CurrentTime,
        OPTIX_RAY_FLAG_NONE /* rayFlags */);
#else
    optixMakeMissHitObject(
        MissShaderIndex,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        Ray.TMax,
        CurrentTime);
#endif
}

// Same as the overload above, with rayTime fixed at 0 (no motion blur).
static __forceinline__ __device__ void optixMakeMissHitObject(
    uint MissShaderIndex,
    RayDesc Ray,
    OptixTraversableHandle* missObj)
{
    optixMakeMissHitObject(MissShaderIndex, Ray, 0.f, missObj);
}
#endif
5294
#if (OPTIX_VERSION >= 90000)
// Build a hit object using SBT offset/stride addressing (rayTime = 0).
//
// OptiX 9.0 path: the native optixMakeHitObject takes an OptixTraverseData blob. That blob is
// read from the current hit object via optixHitObjectGetTraverseData. Only the acceleration
// structure, ray origin/direction and TMin are forwarded; rayFlags are NONE and there are no
// transforms.
// NOTE(review): InstanceIndex, GeometryIndex, PrimitiveIndex, HitKind, both SBT parameters,
// Ray.TMax, `attr` and `handle` are unused on this path; confirm this is intended.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    uint RayContributionToHitGroupIndex,
    uint MultiplierForGeometryContributionToHitGroupIndex,
    RayDesc Ray,
    T attr,
    OptixTraversableHandle* handle)
{
    OptixTraverseData traverseData{};
    optixHitObjectGetTraverseData(&traverseData);
    optixMakeHitObject(
        AccelerationStructure,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        0.f,                 // rayTime
        OPTIX_RAY_FLAG_NONE, // rayFlags
        traverseData,
        nullptr,             // transforms
        0);                  // numTransforms
}
#elif (OPTIX_VERSION >= 80100)
// OptiX 8.1 path: forward every addressing parameter to the native optixMakeHitObject with
// rayTime = 0 and no transforms. `attr` is not forwarded (empty attribute pack), and `handle`
// is not read.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    uint RayContributionToHitGroupIndex,
    uint MultiplierForGeometryContributionToHitGroupIndex,
    RayDesc Ray,
    T attr,
    OptixTraversableHandle* handle)
{
    optixMakeHitObject(
        AccelerationStructure,                            // handle
        Ray.Origin,                                       // rayOrigin
        Ray.Direction,                                    // rayDirection
        Ray.TMin,                                         // tmin
        Ray.TMax,                                         // tmax
        0.f,                                              // rayTime
        RayContributionToHitGroupIndex,                   // sbtOffset
        MultiplierForGeometryContributionToHitGroupIndex, // sbtStride
        InstanceIndex,                                    // instIdx
        nullptr,                                          // transforms
        0,                                                // numTransforms
        GeometryIndex,                                    // sbtGASIdx
        PrimitiveIndex,                                   // primIdx
        HitKind);                                         // hitKind
}
#endif
5356
#if (OPTIX_VERSION >= 90000)
// Build a hit object that uses an explicit hit-group SBT record index (rayTime = 0).
//
// OptiX 9.0 path: the native optixMakeHitObject takes an OptixTraverseData blob. That blob is
// read from the current hit object via optixHitObjectGetTraverseData. The native call has no
// SBT-record parameter, so the caller's HitGroupRecordIndex is applied afterwards with
// optixHitObjectSetSbtRecordIndex. Previously it was silently dropped on this path, while the
// 8.1 path below honours it.
// NOTE(review): InstanceIndex, GeometryIndex, PrimitiveIndex, HitKind, Ray.TMax, `attr` and
// `handle` are still unused on this path.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    uint HitGroupRecordIndex,
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    RayDesc Ray,
    T attr,
    OptixTraversableHandle* handle)
{
    OptixTraverseData traverseData{};
    optixHitObjectGetTraverseData(&traverseData);
    optixMakeHitObject(
        AccelerationStructure,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        0.f,                 // rayTime
        OPTIX_RAY_FLAG_NONE, // rayFlags
        traverseData,
        nullptr,             // transforms
        0);                  // numTransforms
    // Honour the caller's explicit SBT record index.
    optixHitObjectSetSbtRecordIndex(HitGroupRecordIndex);
}
#elif (OPTIX_VERSION >= 80100)
// OptiX 8.1 path: forward to optixMakeHitObjectWithRecord with rayTime = 0 and no transforms.
// `attr` is not forwarded (empty attribute pack), and `handle` is not read.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    uint HitGroupRecordIndex,
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    RayDesc Ray,
    T attr,
    OptixTraversableHandle* handle)
{
    optixMakeHitObjectWithRecord(
        AccelerationStructure, // handle
        Ray.Origin,            // rayOrigin
        Ray.Direction,         // rayDirection
        Ray.TMin,              // tmin
        Ray.TMax,              // tmax
        0.f,                   // rayTime
        HitGroupRecordIndex,   // sbtRecordIndex
        InstanceIndex,         // instIdx
        nullptr,               // transforms
        0,                     // numTransforms
        GeometryIndex,         // sbtGASIdx
        PrimitiveIndex,        // primIdx
        HitKind);              // hitKind
}
#endif
5415
#if (OPTIX_VERSION >= 90000)
// Build a hit object using SBT offset/stride addressing at motion-blur time `CurrentTime`.
//
// OptiX 9.0 path: the native optixMakeHitObject takes an OptixTraverseData blob. That blob is
// read from the current hit object via optixHitObjectGetTraverseData. Only the acceleration
// structure, ray origin/direction, TMin and CurrentTime are forwarded; rayFlags are NONE and
// there are no transforms.
// NOTE(review): InstanceIndex, GeometryIndex, PrimitiveIndex, HitKind, both SBT parameters,
// Ray.TMax, `attr` and `handle` are unused on this path; confirm this is intended.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    uint RayContributionToHitGroupIndex,
    uint MultiplierForGeometryContributionToHitGroupIndex,
    RayDesc Ray,
    float CurrentTime,
    T attr,
    OptixTraversableHandle* handle)
{
    OptixTraverseData traverseData{};
    optixHitObjectGetTraverseData(&traverseData);
    optixMakeHitObject(
        AccelerationStructure,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        CurrentTime,         // rayTime
        OPTIX_RAY_FLAG_NONE, // rayFlags
        traverseData,
        nullptr,             // transforms
        0);                  // numTransforms
}
#elif (OPTIX_VERSION >= 80100)
// OptiX 8.1 path: forward every addressing parameter to the native optixMakeHitObject with
// rayTime = CurrentTime and no transforms. `attr` is not forwarded (empty attribute pack), and
// `handle` is not read.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    uint RayContributionToHitGroupIndex,
    uint MultiplierForGeometryContributionToHitGroupIndex,
    RayDesc Ray,
    float CurrentTime,
    T attr,
    OptixTraversableHandle* handle)
{
    optixMakeHitObject(
        AccelerationStructure,                            // handle
        Ray.Origin,                                       // rayOrigin
        Ray.Direction,                                    // rayDirection
        Ray.TMin,                                         // tmin
        Ray.TMax,                                         // tmax
        CurrentTime,                                      // rayTime
        RayContributionToHitGroupIndex,                   // sbtOffset
        MultiplierForGeometryContributionToHitGroupIndex, // sbtStride
        InstanceIndex,                                    // instIdx
        nullptr,                                          // transforms
        0,                                                // numTransforms
        GeometryIndex,                                    // sbtGASIdx
        PrimitiveIndex,                                   // primIdx
        HitKind);                                         // hitKind
}
#endif
5479
#if (OPTIX_VERSION >= 90000)
// Build a hit object that uses an explicit hit-group SBT record index, at motion-blur time
// `CurrentTime`.
//
// OptiX 9.0 path: the native optixMakeHitObject takes an OptixTraverseData blob. That blob is
// read from the current hit object via optixHitObjectGetTraverseData. The native call has no
// SBT-record parameter, so the caller's HitGroupRecordIndex is applied afterwards with
// optixHitObjectSetSbtRecordIndex. Previously it was silently dropped on this path, while the
// 8.1 path below honours it.
// NOTE(review): InstanceIndex, GeometryIndex, PrimitiveIndex, HitKind, Ray.TMax, `attr` and
// `handle` are still unused on this path.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    uint HitGroupRecordIndex,
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    RayDesc Ray,
    float CurrentTime,
    T attr,
    OptixTraversableHandle* handle)
{
    OptixTraverseData traverseData{};
    optixHitObjectGetTraverseData(&traverseData);
    optixMakeHitObject(
        AccelerationStructure,
        Ray.Origin,
        Ray.Direction,
        Ray.TMin,
        CurrentTime,         // rayTime
        OPTIX_RAY_FLAG_NONE, // rayFlags
        traverseData,
        nullptr,             // transforms
        0);                  // numTransforms
    // Honour the caller's explicit SBT record index.
    optixHitObjectSetSbtRecordIndex(HitGroupRecordIndex);
}
#elif (OPTIX_VERSION >= 80100)
// OptiX 8.1 path: forward to optixMakeHitObjectWithRecord with rayTime = CurrentTime and no
// transforms. `attr` is not forwarded (empty attribute pack), and `handle` is not read.
template<typename T>
static __forceinline__ __device__ void optixMakeHitObject(
    uint HitGroupRecordIndex,
    OptixTraversableHandle AccelerationStructure,
    uint InstanceIndex,
    uint GeometryIndex,
    uint PrimitiveIndex,
    uint HitKind,
    RayDesc Ray,
    float CurrentTime,
    T attr,
    OptixTraversableHandle* handle)
{
    optixMakeHitObjectWithRecord(
        AccelerationStructure, // handle
        Ray.Origin,            // rayOrigin
        Ray.Direction,         // rayDirection
        Ray.TMin,              // tmin
        Ray.TMax,              // tmax
        CurrentTime,           // rayTime
        HitGroupRecordIndex,   // sbtRecordIndex
        InstanceIndex,         // instIdx
        nullptr,               // transforms
        0,                     // numTransforms
        GeometryIndex,         // sbtGASIdx
        PrimitiveIndex,        // primIdx
        HitKind);              // hitKind
}
#endif
5540
#if (OPTIX_VERSION >= 80100)
// Forward to the argument-less optixMakeNopHitObject intrinsic, which (per its name) makes
// the current hit object a NOP. `Obj` is not read.
static __device__ __forceinline__ void slangOptixMakeNopHitObject(OptixTraversableHandle* Obj)
{
    optixMakeNopHitObject();
}
#endif
5547
#if (OPTIX_VERSION >= 80100)
// Internal helper: call optixInvoke with the payload registers held in `pr`.
//
// optixInvoke receives payload registers as individual (variadic) arguments, so the count must
// be fixed at compile time:
//   * N in [0, 8]                          -> exactly N registers are passed;
//   * N in [9, 16]                         -> 16 registers (pr.regs[0..15]) are passed;
//   * N in [17, kMaxOptiXPayloadRegisters] -> 32 registers (pr.regs[0..31]) are passed.
// The bucketed cases index past N, so they rely on PayloadRegisters<T, N> providing at least
// 16 / 32 entries for those sizes. NOTE(review): confirm against its definition, which is
// outside this view.
// Sizes above kMaxOptiXPayloadRegisters are now rejected at compile time. Previously they fell
// through every branch and silently emitted no optixInvoke call at all.
template<typename T, size_t N = (sizeof(T) + 3) / 4>
__forceinline__ __device__ void optixInvokeWithRegs(PayloadRegisters<T, N>& pr)
{
    if constexpr (N == 0)
    {
        optixInvoke();
    }
    else if constexpr (N == 1)
    {
        optixInvoke(pr.regs[0]);
    }
    else if constexpr (N == 2)
    {
        optixInvoke(pr.regs[0], pr.regs[1]);
    }
    else if constexpr (N == 3)
    {
        optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2]);
    }
    else if constexpr (N == 4)
    {
        optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3]);
    }
    else if constexpr (N == 5)
    {
        optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4]);
    }
    else if constexpr (N == 6)
    {
        optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5]);
    }
    else if constexpr (N == 7)
    {
        optixInvoke(
            pr.regs[0],
            pr.regs[1],
            pr.regs[2],
            pr.regs[3],
            pr.regs[4],
            pr.regs[5],
            pr.regs[6]);
    }
    else if constexpr (N == 8)
    {
        optixInvoke(
            pr.regs[0],
            pr.regs[1],
            pr.regs[2],
            pr.regs[3],
            pr.regs[4],
            pr.regs[5],
            pr.regs[6],
            pr.regs[7]);
    }
    else if constexpr (N <= 16)
    {
        optixInvoke(
            pr.regs[0],
            pr.regs[1],
            pr.regs[2],
            pr.regs[3],
            pr.regs[4],
            pr.regs[5],
            pr.regs[6],
            pr.regs[7],
            pr.regs[8],
            pr.regs[9],
            pr.regs[10],
            pr.regs[11],
            pr.regs[12],
            pr.regs[13],
            pr.regs[14],
            pr.regs[15]);
    }
    else if constexpr (N <= kMaxOptiXPayloadRegisters)
    {
        optixInvoke(
            pr.regs[0],
            pr.regs[1],
            pr.regs[2],
            pr.regs[3],
            pr.regs[4],
            pr.regs[5],
            pr.regs[6],
            pr.regs[7],
            pr.regs[8],
            pr.regs[9],
            pr.regs[10],
            pr.regs[11],
            pr.regs[12],
            pr.regs[13],
            pr.regs[14],
            pr.regs[15],
            pr.regs[16],
            pr.regs[17],
            pr.regs[18],
            pr.regs[19],
            pr.regs[20],
            pr.regs[21],
            pr.regs[22],
            pr.regs[23],
            pr.regs[24],
            pr.regs[25],
            pr.regs[26],
            pr.regs[27],
            pr.regs[28],
            pr.regs[29],
            pr.regs[30],
            pr.regs[31]);
    }
    else
    {
        // Larger payloads must take the pointer-based path; the condition depends on N, so
        // this only fires when such an instantiation is actually requested.
        static_assert(
            N <= kMaxOptiXPayloadRegisters,
            "optixInvokeWithRegs: payload exceeds kMaxOptiXPayloadRegisters registers");
    }
}
5661
// Call the native optixInvoke on the current hit object with *Payload as the payload.
// Presumably this runs the closest-hit or miss program recorded in it.
//
// Payloads of at most kMaxOptiXPayloadRegisters 32-bit words are packed into registers (via
// optixInvokeWithRegs) and unpacked back into *Payload afterwards. Larger payloads are passed as
// a pointer split over two registers by packOptiXRayPayloadPointer, with no copy-back.
// `AccelerationStructure` and `HitOrMiss` are not read.
template<typename T>
static __forceinline__ __device__ void optixInvoke(
    OptixTraversableHandle AccelerationStructure,
    OptixTraversableHandle* HitOrMiss,
    T* Payload)
{
    constexpr size_t payloadWords = (sizeof(T) + sizeof(uint32_t) - 1) / sizeof(uint32_t);

    if constexpr (payloadWords > kMaxOptiXPayloadRegisters)
    {
        // Too large for registers: pass the payload's address instead.
        uint32_t ptrWord0;
        uint32_t ptrWord1;
        packOptiXRayPayloadPointer((void*)Payload, ptrWord0, ptrWord1);
        optixInvoke(ptrWord0, ptrWord1);
    }
    else
    {
        // Fits in registers: pack, invoke, then read the (possibly updated) registers back.
        PayloadRegisters<T> payloadRegs;
        payloadRegs.pack(*Payload);
        optixInvokeWithRegs<T>(payloadRegs);
        payloadRegs.unpack(*Payload);
    }
}
5687
// Payload-less optixInvoke overload, used when type legalization has eliminated an empty payload.
// Calls the native optixInvoke with no payload registers; both arguments are ignored.
static __device__ __forceinline__ void optixInvoke(
    OptixTraversableHandle AccelerationStructure,
    OptixTraversableHandle* HitOrMiss)
{
    optixInvoke();
}
#endif
5697
#if (OPTIX_VERSION >= 80100)
// Rebuild a RayDesc from the world-space ray stored in the current hit object.
// Fields are assigned by name, so the result does not depend on RayDesc's member order.
// `obj` is not read; the optixHitObjectGet* intrinsics take no hit-object argument.
static __forceinline__ __device__ RayDesc optixHitObjectGetRayDesc(OptixTraversableHandle* obj)
{
    RayDesc ray = {};
    ray.Origin = optixHitObjectGetWorldRayOrigin();
    ray.TMin = optixHitObjectGetRayTmin();
    ray.Direction = optixHitObjectGetWorldRayDirection();
    ray.TMax = optixHitObjectGetRayTmax();
    return ray;
}
#endif
5709
#if (OPTIX_VERSION >= 80100)
// Scalar HitObject queries. Each wrapper ignores `Obj` and forwards to the argument-less OptiX
// intrinsic of the matching name.
static __device__ __forceinline__ uint
slangOptixHitObjectGetInstanceIndex(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetInstanceIndex();
}

static __device__ __forceinline__ uint slangOptixHitObjectGetInstanceId(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetInstanceId();
}
#endif

// NOTE(review): this query is guarded by OptiX 8.0 while its neighbours require 8.1; confirm
// optixHitObjectGetRayTime is really available in 8.0.
#if (OPTIX_VERSION >= 80000)
static __device__ __forceinline__ float slangOptixHitObjectGetRayTime(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetRayTime();
}
#endif

#if (OPTIX_VERSION >= 80100)
static __device__ __forceinline__ float slangOptixHitObjectGetRayTmax(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetRayTmax();
}

static __device__ __forceinline__ uint
slangOptixHitObjectGetSbtGASIndex(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetSbtGASIndex();
}

static __device__ __forceinline__ uint
slangOptixHitObjectGetPrimitiveIndex(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetPrimitiveIndex();
}
#endif
5754
#if (OPTIX_VERSION >= 80100)
// Reassemble a value of type T from the attribute registers of the current hit object.
//
// T is read as ceil(sizeof(T) / 4) 32-bit words from optixHitObjectGetAttribute_0 .. _7 and then
// bit-copied into a T. T must therefore fit in 8 words (32 bytes, enforced by static_assert) and
// be default-constructible. `Obj` is not read.
template<typename T>
static __forceinline__ __device__ T optixHitObjectGetAttribute(OptixTraversableHandle* Obj)
{
    constexpr size_t wordCount = (sizeof(T) + 3) / 4;
    static_assert(wordCount <= 8, "Attribute type is too large");

    // Zero-initialised; always at least one element so the array is never zero-sized.
    uint32_t words[wordCount > 0 ? wordCount : 1] = {};

    if constexpr (wordCount >= 1)
        words[0] = optixHitObjectGetAttribute_0();
    if constexpr (wordCount >= 2)
        words[1] = optixHitObjectGetAttribute_1();
    if constexpr (wordCount >= 3)
        words[2] = optixHitObjectGetAttribute_2();
    if constexpr (wordCount >= 4)
        words[3] = optixHitObjectGetAttribute_3();
    if constexpr (wordCount >= 5)
        words[4] = optixHitObjectGetAttribute_4();
    if constexpr (wordCount >= 6)
        words[5] = optixHitObjectGetAttribute_5();
    if constexpr (wordCount >= 7)
        words[6] = optixHitObjectGetAttribute_6();
    if constexpr (wordCount >= 8)
        words[7] = optixHitObjectGetAttribute_7();

    // Bit-copy the words into the requested type (avoids type-punning through a cast).
    T value;
    memcpy(&value, words, sizeof(T));
    return value;
}
#endif
5790
#if (OPTIX_VERSION >= 80100)
// SBT record index stored in the current hit object; `Obj` is not read.
static __device__ __forceinline__ uint
slangOptixHitObjectGetSbtRecordIndex(OptixTraversableHandle* Obj)
{
    return optixHitObjectGetSbtRecordIndex();
}
#endif
5798
5799#if (OPTIX_VERSION >= 90000)
5800static __forceinline__ __device__ void slangOptixHitObjectSetSbtRecordIndex(
5802 uint sbtRecordIndex)
5803{
5804 optixHitObjectSetSbtRecordIndex(sbtRecordIndex);
5805}
5806#endif
5807
// HitObject transform-matrix wrappers for SER (Shader Execution Reordering).
// OptiX writes a 3x4 row-major matrix into a float[12] (row r is m[4r .. 4r+3]). These wrappers
// return its transpose as a Slang Matrix<float, 4, 3>, whose row c is column c of the OptiX matrix.
// Per the OptiX documentation they are usable from RG, CH, MS, CC and DC programs.
// optixHitObjectGetWorldToObjectTransformMatrix / optixHitObjectGetObjectToWorldTransformMatrix
// first appeared in OptiX 9.0 (not available in 8.0 or 8.1). `hitObj` is not read.
#if (OPTIX_VERSION >= 90000)
static __forceinline__ __device__ Matrix<float, 4, 3> slangOptixHitObjectGetWorldToObject(
    OptixTraversableHandle* hitObj)
{
    float m[12];
    optixHitObjectGetWorldToObjectTransformMatrix(m);
    const float3 column0 = make_float3(m[0], m[4], m[8]);
    const float3 column1 = make_float3(m[1], m[5], m[9]);
    const float3 column2 = make_float3(m[2], m[6], m[10]);
    const float3 column3 = make_float3(m[3], m[7], m[11]);
    return makeMatrix<float, 4, 3>(column0, column1, column2, column3);
}
#endif

#if (OPTIX_VERSION >= 90000)
static __forceinline__ __device__ Matrix<float, 4, 3> slangOptixHitObjectGetObjectToWorld(
    OptixTraversableHandle* hitObj)
{
    float m[12];
    optixHitObjectGetObjectToWorldTransformMatrix(m);
    const float3 column0 = make_float3(m[0], m[4], m[8]);
    const float3 column1 = make_float3(m[1], m[5], m[9]);
    const float3 column2 = make_float3(m[2], m[6], m[10]);
    const float3 column3 = make_float3(m[3], m[7], m[11]);
    return makeMatrix<float, 4, 3>(column0, column1, column2, column3);
}
#endif
5842
// OptiX multi-level traversal wrappers.
// optixGetInstance(Inverse)TransformFromHandle returns a pointer to three float4 rows (a 3x4
// row-major matrix); these wrappers copy those rows into a Slang Matrix<float, 3, 4>.
__device__ __forceinline__ Matrix<float, 3, 4> _slang_optixGetInstanceTransformFromHandle(
    ulonglong handle)
{
    const float4* rows = optixGetInstanceTransformFromHandle(handle);
    return makeMatrix<float, 3, 4>(rows[0], rows[1], rows[2]);
}

__device__ __forceinline__ Matrix<float, 3, 4> _slang_optixGetInstanceInverseTransformFromHandle(
    ulonglong handle)
{
    const float4* rows = optixGetInstanceInverseTransformFromHandle(handle);
    return makeMatrix<float, 3, 4>(rows[0], rows[1], rows[2]);
}
5860
// OptiX transformation-matrix wrappers.
// optixGet{ObjectToWorld,WorldToObject}TransformMatrix write a 3x4 row-major matrix into a
// float[12] (row r is m[4r .. 4r+3]). The 3x4 variants return it as-is; the 4x3 variants
// return its transpose (row c of the result is column c of the OptiX matrix).
__device__ __forceinline__ Matrix<float, 3, 4> slangOptixGetObjectToWorldTransformMatrix()
{
    float m[12];
    optixGetObjectToWorldTransformMatrix(m);
    const float4 row0 = make_float4(m[0], m[1], m[2], m[3]);
    const float4 row1 = make_float4(m[4], m[5], m[6], m[7]);
    const float4 row2 = make_float4(m[8], m[9], m[10], m[11]);
    return makeMatrix<float, 3, 4>(row0, row1, row2);
}

__device__ __forceinline__ Matrix<float, 3, 4> slangOptixGetWorldToObjectTransformMatrix()
{
    float m[12];
    optixGetWorldToObjectTransformMatrix(m);
    const float4 row0 = make_float4(m[0], m[1], m[2], m[3]);
    const float4 row1 = make_float4(m[4], m[5], m[6], m[7]);
    const float4 row2 = make_float4(m[8], m[9], m[10], m[11]);
    return makeMatrix<float, 3, 4>(row0, row1, row2);
}

__device__ __forceinline__ Matrix<float, 4, 3> slangOptixGetObjectToWorldTransformMatrix4x3()
{
    float m[12];
    optixGetObjectToWorldTransformMatrix(m);
    const float3 column0 = make_float3(m[0], m[4], m[8]);
    const float3 column1 = make_float3(m[1], m[5], m[9]);
    const float3 column2 = make_float3(m[2], m[6], m[10]);
    const float3 column3 = make_float3(m[3], m[7], m[11]);
    return makeMatrix<float, 4, 3>(column0, column1, column2, column3);
}

__device__ __forceinline__ Matrix<float, 4, 3> slangOptixGetWorldToObjectTransformMatrix4x3()
{
    float m[12];
    optixGetWorldToObjectTransformMatrix(m);
    const float3 column0 = make_float3(m[0], m[4], m[8]);
    const float3 column1 = make_float3(m[1], m[5], m[9]);
    const float3 column2 = make_float3(m[2], m[6], m[10]);
    const float3 column3 = make_float3(m[3], m[7], m[11]);
    return makeMatrix<float, 4, 3>(column0, column1, column2, column3);
}
5908
5909#else
// Without OptiX, still provide OptixTraversableHandle (as a 64-bit unsigned integer) so that
// RaytracingAccelerationStructure can be reflected correctly in non-OptiX code.
using OptixTraversableHandle = unsigned long long;
5913#endif
// Maximum tensor rank supported by the Torch TensorView binding.
static constexpr int kSlangTorchTensorMaxDim = 5;
5915
5916// TensorView
5917// NOTE: If you change this struct's layout, also update the hard-coded size/alignment
5918// in _createTypeLayout() in slang-type-layout.cpp.
5920{
5921 uint8_t* data;
5925
5926 template<typename T>
5927 __device__ T* data_ptr()
5928 {
5929 return reinterpret_cast<T*>(data);
5930 }
5931
5932 template<typename T>
5933 __device__ T* data_ptr_at(uint32_t index)
5934 {
5935 uint64_t offset = strides[0] * index;
5936 return reinterpret_cast<T*>(data + offset);
5937 }
5938
5939 template<typename T>
5940 __device__ T* data_ptr_at(uint2 index)
5941 {
5942 uint64_t offset = strides[0] * index.x + strides[1] * index.y;
5943 return reinterpret_cast<T*>(data + offset);
5944 }
5945
5946 template<typename T>
5947 __device__ T* data_ptr_at(uint3 index)
5948 {
5949 uint64_t offset = strides[0] * index.x + strides[1] * index.y + strides[2] * index.z;
5950 return reinterpret_cast<T*>(data + offset);
5951 }
5952
5953 template<typename T>
5954 __device__ T* data_ptr_at(uint4 index)
5955 {
5956 uint64_t offset = strides[0] * index.x + strides[1] * index.y + strides[2] * index.z +
5957 strides[3] * index.w;
5958 return reinterpret_cast<T*>(data + offset);
5959 }
5960
5961 template<typename T, unsigned int N>
5962 __device__ T* data_ptr_at(uint index[N])
5963 {
5964 uint64_t offset = 0;
5965 for (unsigned int i = 0; i < N; ++i)
5966 {
5967 offset += strides[i] * index[i];
5968 }
5969 return reinterpret_cast<T*>(data + offset);
5970 }
5971
5972 template<typename T>
5973 __device__ T& load(uint32_t x)
5974 {
5975 return *reinterpret_cast<T*>(data + strides[0] * x);
5976 }
5977 template<typename T>
5978 __device__ T& load(uint32_t x, uint32_t y)
5979 {
5980 return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y);
5981 }
5982 template<typename T>
5983 __device__ T& load(uint2 index)
5984 {
5985 return *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y);
5986 }
5987 template<typename T>
5988 __device__ T& load(uint32_t x, uint32_t y, uint32_t z)
5989 {
5990 return *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z);
5991 }
5992 template<typename T>
5993 __device__ T& load(uint3 index)
5994 {
5995 return *reinterpret_cast<T*>(
5996 data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z);
5997 }
5998 template<typename T>
5999 __device__ T& load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
6000 {
6001 return *reinterpret_cast<T*>(
6002 data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w);
6003 }
6004 template<typename T>
6005 __device__ T& load(uint4 index)
6006 {
6007 return *reinterpret_cast<T*>(
6008 data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z +
6009 strides[3] * index.w);
6010 }
6011 template<typename T>
6012 __device__ T& load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
6013 {
6014 return *reinterpret_cast<T*>(
6015 data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 +
6016 strides[4] * i4);
6017 }
6018
6019 // Generic version of load
6020 template<typename T, unsigned int N>
6021 __device__ T& load(uint index[N])
6022 {
6023 uint64_t offset = 0;
6024 for (unsigned int i = 0; i < N; ++i)
6025 {
6026 offset += strides[i] * index[i];
6027 }
6028 return *reinterpret_cast<T*>(data + offset);
6029 }
6030
6031 template<typename T>
6032 __device__ void store(uint32_t x, T val)
6033 {
6034 *reinterpret_cast<T*>(data + strides[0] * x) = val;
6035 }
6036 template<typename T>
6037 __device__ void store(uint32_t x, uint32_t y, T val)
6038 {
6039 *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y) = val;
6040 }
6041 template<typename T>
6042 __device__ void store(uint2 index, T val)
6043 {
6044 *reinterpret_cast<T*>(data + strides[0] * index.x + strides[1] * index.y) = val;
6045 }
6046 template<typename T>
6047 __device__ void store(uint32_t x, uint32_t y, uint32_t z, T val)
6048 {
6049 *reinterpret_cast<T*>(data + strides[0] * x + strides[1] * y + strides[2] * z) = val;
6050 }
6051 template<typename T>
6052 __device__ void store(uint3 index, T val)
6053 {
6054 *reinterpret_cast<T*>(
6055 data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z) = val;
6056 }
6057 template<typename T>
6058 __device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
6059 {
6060 *reinterpret_cast<T*>(
6061 data + strides[0] * x + strides[1] * y + strides[2] * z + strides[3] * w) = val;
6062 }
6063 template<typename T>
6064 __device__ void store(uint4 index, T val)
6065 {
6066 *reinterpret_cast<T*>(
6067 data + strides[0] * index.x + strides[1] * index.y + strides[2] * index.z +
6068 strides[3] * index.w) = val;
6069 }
6070 template<typename T>
6071 __device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
6072 {
6073 *reinterpret_cast<T*>(
6074 data + strides[0] * i0 + strides[1] * i1 + strides[2] * i2 + strides[3] * i3 +
6075 strides[4] * i4) = val;
6076 }
6077
6078 // Generic version
6079 template<typename T, unsigned int N>
6080 __device__ void store(uint index[N], T val)
6081 {
6082 uint64_t offset = 0;
6083 for (unsigned int i = 0; i < N; ++i)
6084 {
6085 offset += strides[i] * index[i];
6086 }
6087 *reinterpret_cast<T*>(data + offset) = val;
6088 }
6089};
6090
6091// Implementations for texture fetch/load functions using tex PTX intrinsics
6092// These are used for read-only texture access with integer coordinates.
6093
6094// 1D is not supported via PTX. Keeping the implementation below in case it ever gets supported.
6095template<typename T>
6097{
6098 // TODO: static_assert(false) can fail on some compilers, even if template is not instantiated.
6099 // We should check for this in hlsl.meta.slang instead.
6100 // static_assert(false, "CUDA does not support fetching from 1D textures");
6101}
6102
#if 0
// Disabled: would specialize tex1Dfetch_int for SCALAR, SCALAR##2 and SCALAR##4 using the PTX
// `tex.level.1d.v4.<type>.s32` instruction. It is kept in case 1D fetches become supported.
// Parameters: SCALAR = element type, PTX_TYPE = PTX type string, OUT_CONSTR = asm output
// constraint for SCALAR.
#define SLANG_TEX1DFETCH_INT_IMPL(SCALAR, PTX_TYPE, OUT_CONSTR) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL SCALAR tex1Dfetch_int(CUtexObject texObj, int x, int mip) \
    { \
        SCALAR value; \
        [[maybe_unused]] SCALAR discard; \
        asm("tex.level.1d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5}], %6;" \
            : OUT_CONSTR(value), OUT_CONSTR(discard), OUT_CONSTR(discard), OUT_CONSTR(discard) \
            : "l"(texObj), "r"(x), "r"(mip)); \
        return value; \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL SCALAR##2 tex1Dfetch_int(CUtexObject texObj, int x, int mip) \
    { \
        SCALAR valueX, valueY; \
        [[maybe_unused]] SCALAR discard; \
        asm("tex.level.1d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5}], %6;" \
            : OUT_CONSTR(valueX), OUT_CONSTR(valueY), OUT_CONSTR(discard), OUT_CONSTR(discard) \
            : "l"(texObj), "r"(x), "r"(mip)); \
        return make_##SCALAR##2(valueX, valueY); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL SCALAR##4 tex1Dfetch_int(CUtexObject texObj, int x, int mip) \
    { \
        SCALAR valueX, valueY, valueZ, valueW; \
        asm("tex.level.1d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5}], %6;" \
            : OUT_CONSTR(valueX), OUT_CONSTR(valueY), OUT_CONSTR(valueZ), OUT_CONSTR(valueW) \
            : "l"(texObj), "r"(x), "r"(mip)); \
        return make_##SCALAR##4(valueX, valueY, valueZ, valueW); \
    }

SLANG_TEX1DFETCH_INT_IMPL(float, "f32", "=f")
SLANG_TEX1DFETCH_INT_IMPL(uint, "u32", "=r")
SLANG_TEX1DFETCH_INT_IMPL(int, "s32", "=r")
#endif
6139
6140template<typename T>
6142
// Integer-coordinate, explicit-mip fetches from 2D textures via PTX `tex.level.2d.v4.<type>.s32`.
// For an element type SCALAR this specializes tex2Dfetch_int for SCALAR, SCALAR##2 and SCALAR##4.
// The instruction always writes four components; the 1- and 2-component versions send the
// unused outputs to a throw-away local.
// Parameters: SCALAR = element type, PTX_TYPE = PTX type string ("f32"/"u32"/"s32"),
// OUT_CONSTR = asm output constraint for SCALAR ("=f" or "=r").
#define SLANG_TEX2DFETCH_INT_IMPL(SCALAR, PTX_TYPE, OUT_CONSTR) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    SCALAR tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip) \
    { \
        SCALAR value; \
        [[maybe_unused]] SCALAR discard; \
        asm("tex.level.2d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
            : OUT_CONSTR(value), OUT_CONSTR(discard), OUT_CONSTR(discard), OUT_CONSTR(discard) \
            : "l"(texObj), "r"(x), "r"(y), "r"(mip)); \
        return value; \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    SCALAR##2 tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip) \
    { \
        SCALAR valueX, valueY; \
        [[maybe_unused]] SCALAR discard; \
        asm("tex.level.2d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
            : OUT_CONSTR(valueX), OUT_CONSTR(valueY), OUT_CONSTR(discard), OUT_CONSTR(discard) \
            : "l"(texObj), "r"(x), "r"(y), "r"(mip)); \
        return make_##SCALAR##2(valueX, valueY); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    SCALAR##4 tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip) \
    { \
        SCALAR valueX, valueY, valueZ, valueW; \
        asm("tex.level.2d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
            : OUT_CONSTR(valueX), OUT_CONSTR(valueY), OUT_CONSTR(valueZ), OUT_CONSTR(valueW) \
            : "l"(texObj), "r"(x), "r"(y), "r"(mip)); \
        return make_##SCALAR##4(valueX, valueY, valueZ, valueW); \
    }

SLANG_TEX2DFETCH_INT_IMPL(float, "f32", "=f")
SLANG_TEX2DFETCH_INT_IMPL(uint, "u32", "=r")
SLANG_TEX2DFETCH_INT_IMPL(int, "s32", "=r")
6179
6180
6181template<typename T>
6183tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip);
6184
// Integer-coordinate, explicit-mip fetches from 3D textures via PTX `tex.level.3d.v4.<type>.s32`.
// For an element type SCALAR this specializes tex3Dfetch_int for SCALAR, SCALAR##2 and SCALAR##4.
// The 3D coordinate operand is a 4-vector whose last lane is ignored, so z is passed twice.
// The instruction always writes four components; the 1- and 2-component versions send the
// unused outputs to a throw-away local.
// Parameters: SCALAR = element type, PTX_TYPE = PTX type string ("f32"/"u32"/"s32"),
// OUT_CONSTR = asm output constraint for SCALAR ("=f" or "=r").
#define SLANG_TEX3DFETCH_INT_IMPL(SCALAR, PTX_TYPE, OUT_CONSTR) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL SCALAR \
    tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip) \
    { \
        SCALAR value; \
        [[maybe_unused]] SCALAR discard; \
        asm("tex.level.3d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
            : OUT_CONSTR(value), OUT_CONSTR(discard), OUT_CONSTR(discard), OUT_CONSTR(discard) \
            : "l"(texObj), "r"(x), "r"(y), "r"(z), "r"(z) /* ignored */, "r"(mip)); \
        return value; \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    SCALAR##2 tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip) \
    { \
        SCALAR valueX, valueY; \
        [[maybe_unused]] SCALAR discard; \
        asm("tex.level.3d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
            : OUT_CONSTR(valueX), OUT_CONSTR(valueY), OUT_CONSTR(discard), OUT_CONSTR(discard) \
            : "l"(texObj), "r"(x), "r"(y), "r"(z), "r"(z) /* ignored */, "r"(mip)); \
        return make_##SCALAR##2(valueX, valueY); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    SCALAR##4 tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip) \
    { \
        SCALAR valueX, valueY, valueZ, valueW; \
        asm("tex.level.3d.v4." PTX_TYPE ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
            : OUT_CONSTR(valueX), OUT_CONSTR(valueY), OUT_CONSTR(valueZ), OUT_CONSTR(valueW) \
            : "l"(texObj), "r"(x), "r"(y), "r"(z), "r"(z) /* ignored */, "r"(mip)); \
        return make_##SCALAR##4(valueX, valueY, valueZ, valueW); \
    }

SLANG_TEX3DFETCH_INT_IMPL(float, "f32", "=f")
SLANG_TEX3DFETCH_INT_IMPL(uint, "u32", "=r")
SLANG_TEX3DFETCH_INT_IMPL(int, "s32", "=r")
6222
// tex1DArrayfetch_int<T>(texObj, x, layer, mip): fetches one texel from layer `layer` of a
// 1D layered texture at integer texel coordinate x and explicit mip level `mip`.
// The primary template is declaration-only; specializations for scalar T, T##2 and T##4
// are generated by SLANG_TEX1DARRAYFETCH_INT_IMPL below.
template<typename T>
SLANG_FORCE_INLINE SLANG_CUDA_CALL T
tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip);

// SLANG_TEX1DARRAYFETCH_INT_IMPL(T, dtype, c) emits the T, T##2 and T##4 specializations.
//   T     - scalar component type; T##2 / T##4 and make_T##2 / make_T##4 must exist.
//   dtype - PTX destination-type suffix ("f32", "u32", "s32").
//   c     - inline-asm output constraint for one T component ("=f" or "=r").
// The PTX `a1d` coordinate operand is the pair {layer, x}. Unneeded result components are
// written into `stub`.
#define SLANG_TEX1DARRAYFETCH_INT_IMPL(T, dtype, c) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T \
    tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip) \
    { \
        T result; \
        [[maybe_unused]] T stub; \
        asm("tex.level.a1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
            : c(result), c(stub), c(stub), c(stub) \
            : "l"(texObj), "r"(layer), "r"(x), "r"(mip)); \
        return result; \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    T##2 tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip) \
    { \
        T result_x, result_y; \
        [[maybe_unused]] T stub; \
        asm("tex.level.a1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
            : c(result_x), c(result_y), c(stub), c(stub) \
            : "l"(texObj), "r"(layer), "r"(x), "r"(mip)); \
        return make_##T##2(result_x, result_y); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    T##4 tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip) \
    { \
        T result_x, result_y, result_z, result_w; \
        asm("tex.level.a1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
            : c(result_x), c(result_y), c(result_z), c(result_w) \
            : "l"(texObj), "r"(layer), "r"(x), "r"(mip)); \
        return make_##T##4(result_x, result_y, result_z, result_w); \
    }

SLANG_TEX1DARRAYFETCH_INT_IMPL(float, "f32", "=f")
SLANG_TEX1DARRAYFETCH_INT_IMPL(uint, "u32", "=r")
SLANG_TEX1DARRAYFETCH_INT_IMPL(int, "s32", "=r")
6264
// tex2DArrayfetch_int<T>(texObj, x, y, layer, mip): fetches one texel from layer `layer` of
// a 2D layered texture at integer texel coordinates (x, y) and explicit mip level `mip`.
// The primary template is declaration-only; specializations for scalar T, T##2 and T##4
// are generated by SLANG_TEX2DARRAYFETCH_INT_IMPL below.
template<typename T>
SLANG_FORCE_INLINE SLANG_CUDA_CALL T
tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip);

// SLANG_TEX2DARRAYFETCH_INT_IMPL(T, dtype, c) emits the T, T##2 and T##4 specializations.
//   T     - scalar component type; T##2 / T##4 and make_T##2 / make_T##4 must exist.
//   dtype - PTX destination-type suffix ("f32", "u32", "s32").
//   c     - inline-asm output constraint for one T component ("=f" or "=r").
// The PTX `a2d` coordinate operand is the 4-vector {layer, x, y, <ignored>}; `layer` is
// repeated to fill the ignored slot. Unneeded result components are written into `stub`.
#define SLANG_TEX2DARRAYFETCH_INT_IMPL(T, dtype, c) \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL T \
    tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip) \
    { \
        T result; \
        [[maybe_unused]] T stub; \
        asm("tex.level.a2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
            : c(result), c(stub), c(stub), c(stub) \
            : "l"(texObj), "r"(layer), "r"(x), "r"(y), "r"(layer) /* ignored */, "r"(mip)); \
        return result; \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    T##2 tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip) \
    { \
        T result_x, result_y; \
        [[maybe_unused]] T stub; \
        asm("tex.level.a2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
            : c(result_x), c(result_y), c(stub), c(stub) \
            : "l"(texObj), "r"(layer), "r"(x), "r"(y), "r"(layer) /* ignored */, "r"(mip)); \
        return make_##T##2(result_x, result_y); \
    } \
    template<> \
    SLANG_FORCE_INLINE SLANG_CUDA_CALL \
    T##4 tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip) \
    { \
        T result_x, result_y, result_z, result_w; \
        asm("tex.level.a2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
            : c(result_x), c(result_y), c(result_z), c(result_w) \
            : "l"(texObj), "r"(layer), "r"(x), "r"(y), "r"(layer) /* ignored */, "r"(mip)); \
        return make_##T##4(result_x, result_y, result_z, result_w); \
    }

SLANG_TEX2DARRAYFETCH_INT_IMPL(float, "f32", "=f")
SLANG_TEX2DARRAYFETCH_INT_IMPL(uint, "u32", "=r")
SLANG_TEX2DARRAYFETCH_INT_IMPL(int, "s32", "=r")
6306
// Wave rotate helper functions (macro-generated overloads).

// Participant mask used by the wave-rotate shuffles: all 32 lanes of the warp.
#define SLANG_WARP_FULL_MASK 0xFFFFFFFF

// SLANG_WAVE_ROTATE_IMPL(T) defines _slang_waveRotate overloads for the vector types
// T##2, T##3 and T##4 (the scalar T overload is not generated here).
// Lane i receives every component from lane (i + delta) % SLANG_CUDA_WARP_SIZE, using one
// __shfl_sync per component. Because the shuffles use the full-warp mask, all lanes of the
// warp are expected to make the call together.
#define SLANG_WAVE_ROTATE_IMPL(T) \
    __device__ __forceinline__ T##2 _slang_waveRotate(T##2 value, unsigned int delta) \
    { \
        const unsigned int srcLane = (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE; \
        const T rx = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, srcLane); \
        const T ry = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, srcLane); \
        return make_##T##2(rx, ry); \
    } \
    __device__ __forceinline__ T##3 _slang_waveRotate(T##3 value, unsigned int delta) \
    { \
        const unsigned int srcLane = (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE; \
        const T rx = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, srcLane); \
        const T ry = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, srcLane); \
        const T rz = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.z, srcLane); \
        return make_##T##3(rx, ry, rz); \
    } \
    __device__ __forceinline__ T##4 _slang_waveRotate(T##4 value, unsigned int delta) \
    { \
        const unsigned int srcLane = (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE; \
        const T rx = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, srcLane); \
        const T ry = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, srcLane); \
        const T rz = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.z, srcLane); \
        const T rw = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.w, srcLane); \
        return make_##T##4(rx, ry, rz, rw); \
    }
6360
6361// Generate wave rotate functions for all standard vector types
// NOTE(review): no SLANG_WAVE_ROTATE_IMPL(...) expansions are visible here, yet the bool
// overloads that follow call _slang_waveRotate on int2/int3/int4, which only such an
// expansion (for int) would provide. Confirm the expansion list, including the
// half-precision one guarded below, is actually present.
6371
6372#ifdef SLANG_CUDA_ENABLE_HALF
6374#endif
6375
// Boolean vector overloads. Each component is widened to int, rotated with the matching
// int vector overload of _slang_waveRotate, and narrowed back to bool (nonzero -> true).
__device__ __forceinline__ bool2 _slang_waveRotate(bool2 value, unsigned int delta)
{
    const int2 rotated = _slang_waveRotate(make_int2((int)value.x, (int)value.y), delta);
    return make_bool2(rotated.x != 0, rotated.y != 0);
}

__device__ __forceinline__ bool3 _slang_waveRotate(bool3 value, unsigned int delta)
{
    const int3 rotated =
        _slang_waveRotate(make_int3((int)value.x, (int)value.y, (int)value.z), delta);
    return make_bool3(rotated.x != 0, rotated.y != 0, rotated.z != 0);
}

__device__ __forceinline__ bool4 _slang_waveRotate(bool4 value, unsigned int delta)
{
    const int4 rotated = _slang_waveRotate(
        make_int4((int)value.x, (int)value.y, (int)value.z, (int)value.w),
        delta);
    return make_bool4(rotated.x != 0, rotated.y != 0, rotated.z != 0, rotated.w != 0);
}

#undef SLANG_WAVE_ROTATE_IMPL
6399
// Quad control operations for CUDA.
// A quad is the group of four lanes that share the base lane (laneId & ~3). Both functions
// read `expr` from all four lanes of the caller's quad with full-warp __shfl_sync calls,
// so every lane of the warp must reach the call. The shuffles are issued unconditionally
// (never short-circuited) so that all lanes participate in each one.

// Returns true if `expr` is true on at least one lane of the caller's quad.
__device__ __forceinline__ bool _slang_quadAny(bool expr)
{
    const unsigned int quadBase = _getLaneId() & 0xFFFFFFFC;
    bool anyTrue = false;
#pragma unroll
    for (unsigned int q = 0; q < 4; ++q)
    {
        const bool laneValue = __shfl_sync(0xFFFFFFFF, expr, quadBase | q);
        anyTrue = anyTrue || laneValue;
    }
    return anyTrue;
}

// Returns true if `expr` is true on all four lanes of the caller's quad.
__device__ __forceinline__ bool _slang_quadAll(bool expr)
{
    const unsigned int quadBase = _getLaneId() & 0xFFFFFFFC;
    bool allTrue = true;
#pragma unroll
    for (unsigned int q = 0; q < 4; ++q)
    {
        const bool laneValue = __shfl_sync(0xFFFFFFFF, expr, quadBase | q);
        allTrue = allTrue && laneValue;
    }
    return allTrue;
}
6420
// Clustered wave rotate operations for CUDA.
// The warp is divided into consecutive clusters of `clusterSize` lanes. Inside each
// cluster, the lane at offset i receives the value from the lane at offset
// (i + delta) % clusterSize of the same cluster. Shuffles use the full-warp mask, so every
// lane of the warp must make the call.
// NOTE(review): clusterSize == 0 divides by zero; presumably callers always pass a nonzero
// cluster size. Confirm.

// Source lane for the calling lane in a clustered rotate by `delta`.
__device__ __forceinline__ unsigned int _slang_clusteredRotateSourceLane(
    unsigned int delta,
    unsigned int clusterSize)
{
    const unsigned int lane = _getLaneId();
    const unsigned int clusterBase = (lane / clusterSize) * clusterSize;
    return clusterBase + ((lane - clusterBase + delta) % clusterSize);
}

// SLANG_WAVE_CLUSTERED_ROTATE_IMPL(T) defines _slang_waveClusteredRotate for T, T##2, T##3
// and T##4, issuing one shuffle per component from the shared source lane.
#define SLANG_WAVE_CLUSTERED_ROTATE_IMPL(T) \
    __device__ __forceinline__ T \
    _slang_waveClusteredRotate(T value, unsigned int delta, unsigned int clusterSize) \
    { \
        const unsigned int srcLane = _slang_clusteredRotateSourceLane(delta, clusterSize); \
        return __shfl_sync(SLANG_WARP_FULL_MASK, value, srcLane); \
    } \
    __device__ __forceinline__ \
    T##2 _slang_waveClusteredRotate(T##2 value, unsigned int delta, unsigned int clusterSize) \
    { \
        const unsigned int srcLane = _slang_clusteredRotateSourceLane(delta, clusterSize); \
        const T rx = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, srcLane); \
        const T ry = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, srcLane); \
        return make_##T##2(rx, ry); \
    } \
    __device__ __forceinline__ \
    T##3 _slang_waveClusteredRotate(T##3 value, unsigned int delta, unsigned int clusterSize) \
    { \
        const unsigned int srcLane = _slang_clusteredRotateSourceLane(delta, clusterSize); \
        const T rx = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, srcLane); \
        const T ry = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, srcLane); \
        const T rz = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.z, srcLane); \
        return make_##T##3(rx, ry, rz); \
    } \
    __device__ __forceinline__ \
    T##4 _slang_waveClusteredRotate(T##4 value, unsigned int delta, unsigned int clusterSize) \
    { \
        const unsigned int srcLane = _slang_clusteredRotateSourceLane(delta, clusterSize); \
        const T rx = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, srcLane); \
        const T ry = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, srcLane); \
        const T rz = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.z, srcLane); \
        const T rw = (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.w, srcLane); \
        return make_##T##4(rx, ry, rz, rw); \
    }
6465
6466// Generate clustered wave rotate functions for all standard types
// NOTE(review): no SLANG_WAVE_CLUSTERED_ROTATE_IMPL(...) expansions are visible here, yet
// the bool overloads that follow call _slang_waveClusteredRotate on int/int2/int3/int4,
// which only such an expansion (for int) would provide. Confirm the expansion list,
// including the half-precision one guarded below, is actually present.
6476
6477#ifdef SLANG_CUDA_ENABLE_HALF
6479#endif
6480
// Boolean overloads of the clustered rotate. bool values cannot be shuffled directly, so
// each component is widened to int, rotated with the matching int overload (presumably
// generated by SLANG_WAVE_CLUSTERED_ROTATE_IMPL(int)), and narrowed back (nonzero -> true).
__device__ __forceinline__ bool _slang_waveClusteredRotate(
    bool value,
    unsigned int delta,
    unsigned int clusterSize)
{
    return _slang_waveClusteredRotate((int)value, delta, clusterSize) != 0;
}

__device__ __forceinline__ bool2
_slang_waveClusteredRotate(bool2 value, unsigned int delta, unsigned int clusterSize)
{
    const int2 rotated =
        _slang_waveClusteredRotate(make_int2((int)value.x, (int)value.y), delta, clusterSize);
    return make_bool2(rotated.x != 0, rotated.y != 0);
}

__device__ __forceinline__ bool3
_slang_waveClusteredRotate(bool3 value, unsigned int delta, unsigned int clusterSize)
{
    const int3 rotated = _slang_waveClusteredRotate(
        make_int3((int)value.x, (int)value.y, (int)value.z),
        delta,
        clusterSize);
    return make_bool3(rotated.x != 0, rotated.y != 0, rotated.z != 0);
}

__device__ __forceinline__ bool4
_slang_waveClusteredRotate(bool4 value, unsigned int delta, unsigned int clusterSize)
{
    const int4 rotated = _slang_waveClusteredRotate(
        make_int4((int)value.x, (int)value.y, (int)value.z, (int)value.w),
        delta,
        clusterSize);
    return make_bool4(rotated.x != 0, rotated.y != 0, rotated.z != 0, rotated.w != 0);
}

#undef SLANG_WAVE_CLUSTERED_ROTATE_IMPL
6517
// ---------------------- OptiX Cooperative Vector Wrappers --------------------------------------
// Thin wrappers over OptiX's optixCoopVecMatMul. They derive the output length (N) and
// input length (K) template arguments from the cooperative-vector types themselves.
#ifdef SLANG_CUDA_ENABLE_OPTIX

#if (OPTIX_VERSION >= 90000)

// OptixCoopVecTraits<V>::size is the element count of an OptiX cooperative-vector type.
// The primary template is deliberately left undefined. Only OptixCoopVec<T, N> receives a
// specialization, so the wrappers reject any other vector type at compile time.
template<typename T>
struct OptixCoopVecTraits;

// The OptixCoopVec<T, N> specialization is compiled only when OPTIX_VERSION > 90000.
// According to the original note, NVRTC builds disable cooperative vectors by setting
// OPTIX_INCLUDE_COOPERATIVE_VECTOR to 0.
// NOTE(review): this guard tests only the OptiX version, not OPTIX_INCLUDE_COOPERATIVE_VECTOR.
// With OPTIX_VERSION == 90000 the traits stay incomplete and the wrappers below cannot be
// instantiated. Confirm this is intended.
#if defined(OPTIX_VERSION) && OPTIX_VERSION > 90000
template<typename T, unsigned int N>
struct OptixCoopVecTraits<OptixCoopVec<T, N>>
{
    static constexpr unsigned int size = N;
};
#endif

// Matrix * vector without bias, with a `transpose` argument (5 runtime parameters).
// NOTE(review): `transpose` is NOT forwarded. The OptiX call always receives `false` as its
// transpose template argument, whatever value the caller passes.
template<
    typename VecTOut,
    typename VecTIn,
    OptixCoopVecElemType inputInterpretation,
    OptixCoopVecElemType matrixInterpretation,
    OptixCoopVecMatrixLayout matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
    const VecTIn& inputVector,
    CUdeviceptr matrix,
    unsigned matrixOffset,
    bool transpose,
    unsigned matrixStride)
{
    (void)transpose; // unused: see the note above
    constexpr unsigned outLength = OptixCoopVecTraits<VecTOut>::size;
    constexpr unsigned inLength = OptixCoopVecTraits<VecTIn>::size;
    constexpr bool kTranspose = false;

    return optixCoopVecMatMul<
        VecTOut,
        VecTIn,
        inputInterpretation,
        matrixLayout,
        kTranspose,
        outLength,
        inLength,
        matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride);
}

// Matrix * vector + bias (6 runtime parameters). The bias vector is read from `bias` at
// `biasOffset` and interpreted as `biasInterpretation`. Transpose is always `false`.
template<
    typename VecTOut,
    typename VecTIn,
    OptixCoopVecElemType inputInterpretation,
    OptixCoopVecElemType matrixInterpretation,
    OptixCoopVecMatrixLayout matrixLayout,
    OptixCoopVecElemType biasInterpretation>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
    const VecTIn& inputVector,
    CUdeviceptr matrix,
    unsigned matrixOffset,
    CUdeviceptr bias,
    unsigned biasOffset,
    unsigned matrixStride)
{
    constexpr unsigned outLength = OptixCoopVecTraits<VecTOut>::size;
    constexpr unsigned inLength = OptixCoopVecTraits<VecTIn>::size;
    constexpr bool kTranspose = false;

    return optixCoopVecMatMul<
        VecTOut,
        VecTIn,
        inputInterpretation,
        matrixLayout,
        kTranspose,
        outLength,
        inLength,
        matrixInterpretation,
        biasInterpretation>(inputVector, matrix, matrixOffset, bias, biasOffset, matrixStride);
}

// Matrix * vector without bias and without a transpose argument (4 runtime parameters;
// the StructuredBuffer variant). Transpose is always `false`.
template<
    typename VecTOut,
    typename VecTIn,
    OptixCoopVecElemType inputInterpretation,
    OptixCoopVecElemType matrixInterpretation,
    OptixCoopVecMatrixLayout matrixLayout>
__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
    const VecTIn& inputVector,
    CUdeviceptr matrix,
    unsigned matrixOffset,
    unsigned matrixStride)
{
    constexpr unsigned outLength = OptixCoopVecTraits<VecTOut>::size;
    constexpr unsigned inLength = OptixCoopVecTraits<VecTIn>::size;
    constexpr bool kTranspose = false;

    return optixCoopVecMatMul<
        VecTOut,
        VecTIn,
        inputInterpretation,
        matrixLayout,
        kTranspose,
        outLength,
        inLength,
        matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride);
}

#endif // (OPTIX_VERSION >= 90000)

#endif // SLANG_CUDA_ENABLE_OPTIX
6630
6631
6632// This implementation can only be enabled on CUDA Toolkit 12.5+
6633#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) || \
6634 (CUDA_VERSION >= 12050)
6635// Cooperative matrix support using mma.sync.aligned.m16n8k16 PTX instructions.
6636// Only the m16n16k16 shape is supported (implemented as 2x m16n8k16 internally).
6637// Load/Store use warp shuffle redistribution instead of wmma.load/wmma.store,
6638// matching SPIRV's cooperative_matrix interface where layout is specified at
6639// load/store time rather than at fragment declaration.
6640namespace Slang_CUDA_WMMA
6641{
6642
// Compile-time boolean carrier shared by the two IsSameType definitions below.
template<bool Result>
struct IsSameTypeResult
{
    static constexpr bool value = Result;
};

// IsSameType<A, B>::value is true exactly when A and B name the same type
// (the same contract as std::is_same<A, B>::value).
template<typename A, typename B>
struct IsSameType : IsSameTypeResult<false>
{
};

// Selected when both arguments are the same type.
template<typename A>
struct IsSameType<A, A> : IsSameTypeResult<true>
{
};
6653
// Enums used to select template specializations.

// Operand role in an mma.sync-style D = A * B + C:
// A and B are the multiplicands, C is the accumulator input and D is the result.
enum MatrixUse : int { MatrixA = 0, MatrixB = 1, MatrixC = 2, MatrixD = 3, };

// Memory order of a tile in its buffer. RowMajor: elements of a row are adjacent
// (row * stride + col). ColMajor: elements of a column are adjacent (col * stride + row).
enum Layout : int { RowMajor = 0, ColMajor = 1 };
6668
// ====================================================================================
// Register Counts for different matrices
// ====================================================================================
// RegisterCount<ElemT, M, N, K, use>::value is the number of 32-bit fragment registers
// each thread holds for a tile with element type ElemT, shape MxNxK and matrix role `use`.
// The primary template is declaration-only, so unsupported combinations fail to compile.
//
// The optional element types are guarded with #ifdef. This is the same test the prelude
// uses when it includes <cuda_fp16.h>, <cuda_bf16.h> and <cuda_fp8.h>. A bare
// `#if SLANG_CUDA_ENABLE_xxx` would be a preprocessing error when the macro is defined
// with an empty value, and it would disagree with the include guards when defined as 0.
template<typename ElemT, int M, int N, int K, MatrixUse use>
struct RegisterCount;

#ifdef SLANG_CUDA_ENABLE_HALF
// Half (f16) - m16n16k16: 4 regs for A/B (mma.sync.m16n8k16 register layout), 4 regs for C/D
template<>
struct RegisterCount<half, 16, 16, 16, MatrixUse::MatrixA>
{
    static constexpr int value = 4;
};
template<>
struct RegisterCount<half, 16, 16, 16, MatrixUse::MatrixB>
{
    static constexpr int value = 4;
};
template<>
struct RegisterCount<half, 16, 16, 16, MatrixUse::MatrixC>
{
    static constexpr int value = 4;
};
template<>
struct RegisterCount<half, 16, 16, 16, MatrixUse::MatrixD>
{
    static constexpr int value = 4;
};
#endif // SLANG_CUDA_ENABLE_HALF

#ifdef SLANG_CUDA_ENABLE_BF16
// bfloat16 (bf16) - m16n16k16: 4 regs for A/B (mma.sync.m16n8k16 register layout).
// Note: bf16 MMA only supports float (f32) accumulators on PTX, so MatrixC/D
// register counts for __nv_bfloat16 are intentionally not defined here.
template<>
struct RegisterCount<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixA>
{
    static constexpr int value = 4;
};
template<>
struct RegisterCount<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixB>
{
    static constexpr int value = 4;
};
#endif // SLANG_CUDA_ENABLE_BF16

// Float (f32) - 8 regs for C/D only (any shape)
template<int M, int N, int K>
struct RegisterCount<float, M, N, K, MatrixUse::MatrixC>
{
    static constexpr int value = 8;
};
template<int M, int N, int K>
struct RegisterCount<float, M, N, K, MatrixUse::MatrixD>
{
    static constexpr int value = 8;
};

// Int32 (s32) - 8 regs for C/D (accumulator for int8 operations, any shape)
template<int M, int N, int K>
struct RegisterCount<int32_t, M, N, K, MatrixUse::MatrixC>
{
    static constexpr int value = 8;
};
template<int M, int N, int K>
struct RegisterCount<int32_t, M, N, K, MatrixUse::MatrixD>
{
    static constexpr int value = 8;
};

// Uint8 (u8) - only the 16x16x16 shape is defined
template<>
struct RegisterCount<unsigned char, 16, 16, 16, MatrixUse::MatrixA>
{
    static constexpr int value = 2;
};
template<>
struct RegisterCount<unsigned char, 16, 16, 16, MatrixUse::MatrixB>
{
    static constexpr int value = 2;
};

// Int8 (s8) - same as u8. NOTE(review): keyed on plain `char`, which is a distinct type
// from `signed char` / int8_t.
template<>
struct RegisterCount<char, 16, 16, 16, MatrixUse::MatrixA>
{
    static constexpr int value = 2;
};
template<>
struct RegisterCount<char, 16, 16, 16, MatrixUse::MatrixB>
{
    static constexpr int value = 2;
};

#ifdef SLANG_CUDA_ENABLE_FP8
// fp8 (e4m3 and e5m2) - same as u8
template<>
struct RegisterCount<__nv_fp8_e4m3, 16, 16, 16, MatrixUse::MatrixA>
{
    static constexpr int value = 2;
};
template<>
struct RegisterCount<__nv_fp8_e4m3, 16, 16, 16, MatrixUse::MatrixB>
{
    static constexpr int value = 2;
};
template<>
struct RegisterCount<__nv_fp8_e5m2, 16, 16, 16, MatrixUse::MatrixA>
{
    static constexpr int value = 2;
};
template<>
struct RegisterCount<__nv_fp8_e5m2, 16, 16, 16, MatrixUse::MatrixB>
{
    static constexpr int value = 2;
};
#endif // SLANG_CUDA_ENABLE_FP8
6786
6787
6788// ====================================================================================
6789// MMA m16n8k16 Load/Store
6790// Uses 128-bit vectorized loads with warp shuffle redistribution for Matrix A,
6791// 64-bit paired loads for f32 C/D, and 32-bit coalesced loads for B.
6792// Falls back to element-wise access for column-major layouts where data is non-contiguous.
6793// ====================================================================================
6794
6795
6796// ====================================================================================
6797// MMA m16n8k16 Matrix A Load (f16, 16x16)
6798//
6799// Uses 128-bit vectorized load + warp shuffle redistribution.
6800// Each thread loads one 128-bit row-half, then 4 rounds of shuffles redistribute
6801// the data so each thread ends up with the correct MMA fragment registers.
6802//
6803// Target MMA fragment layout:
6804//
//        |<----------- columns 0-7 ----------->||<---------- columns 8-15 ----------->|
// R\C    | c0,1     c2,3     c4,5     c6,7     || c8,9     c10,11   c12,13   c14,15   |
// -------+-------------------------------------++-------------------------------------+
// row 0  | T0:a0a1  T1:a0a1  T2:a0a1  T3:a0a1  || T0:a4a5  T1:a4a5  T2:a4a5  T3:a4a5  |
// row 1  | T4:a0a1  T5:a0a1  T6:a0a1  T7:a0a1  || T4:a4a5  T5:a4a5  T6:a4a5  T7:a4a5  |
// ..     | ..       ..       ..       ..       || ..       ..       ..       ..       |
// row 7  | T28:a0a1 T29:a0a1 T30:a0a1 T31:a0a1 || T28:a4a5 T29:a4a5 T30:a4a5 T31:a4a5 |
// -------+-------------------------------------++-------------------------------------+
// row 8  | T0:a2a3  T1:a2a3  T2:a2a3  T3:a2a3  || T0:a6a7  T1:a6a7  T2:a6a7  T3:a6a7  |
// row 9  | T4:a2a3  T5:a2a3  T6:a2a3  T7:a2a3  || T4:a6a7  T5:a6a7  T6:a6a7  T7:a6a7  |
// ..     | ..       ..       ..       ..       || ..       ..       ..       ..       |
// row 15 | T28:a2a3 T29:a2a3 T30:a2a3 T31:a2a3 || T28:a6a7 T29:a6a7 T30:a6a7 T31:a6a7 |
// -------+-------------------------------------++-------------------------------------+
6822//
6823// Initial state after 128-bit load (each thread holds one row-half):
6824//
//        |<----------- columns 0-7 ----------->||<---------- columns 8-15 ----------->|
// R\C    | c0,1     c2,3     c4,5     c6,7     || c8,9     c10,11   c12,13   c14,15   |
// -------+-------------------------------------++-------------------------------------+
// row 0  | T0:a0a1  T0:a2a3  T0:a4a5  T0:a6a7  || T1:a0a1  T1:a2a3  T1:a4a5  T1:a6a7  |
// row 1  | T2:a0a1  T2:a2a3  T2:a4a5  T2:a6a7  || T3:a0a1  T3:a2a3  T3:a4a5  T3:a6a7  |
// ..     | ..       ..       ..       ..       || ..       ..       ..       ..       |
// row 7  | T14:a0a1 T14:a2a3 T14:a4a5 T14:a6a7 || T15:a0a1 T15:a2a3 T15:a4a5 T15:a6a7 |
// -------+-------------------------------------++-------------------------------------+
// row 8  | T16:a0a1 T16:a2a3 T16:a4a5 T16:a6a7 || T17:a0a1 T17:a2a3 T17:a4a5 T17:a6a7 |
// row 9  | T18:a0a1 T18:a2a3 T18:a4a5 T18:a6a7 || T19:a0a1 T19:a2a3 T19:a4a5 T19:a6a7 |
// ..     | ..       ..       ..       ..       || ..       ..       ..       ..       |
// row 15 | T30:a0a1 T30:a2a3 T30:a4a5 T30:a6a7 || T31:a0a1 T31:a2a3 T31:a4a5 T31:a6a7 |
// -------+-------------------------------------++-------------------------------------+
6842//
6843// After Round k=0 (shuffle loaded[0] -- captures col pairs 0,1 and 8,9):
6844//
//        |<----------- columns 0-7 ----------->||<---------- columns 8-15 ----------->|
// R\C    | c0,1     c2,3     c4,5     c6,7     || c8,9     c10,11   c12,13   c14,15   |
// -------+-------------------------------------++-------------------------------------+
// row 0  | T0:a0a1  -        -        -        || T0:a4a5  -        -        -        |
// row 1  | T4:a0a1  -        -        -        || T4:a4a5  -        -        -        |
// ..     | ..       -        -        -        || ..       -        -        -        |
// row 7  | T28:a0a1 -        -        -        || T28:a4a5 -        -        -        |
// -------+-------------------------------------++-------------------------------------+
// row 8  | T0:a2a3  -        -        -        || T0:a6a7  -        -        -        |
// row 9  | T4:a2a3  -        -        -        || T4:a6a7  -        -        -        |
// ..     | ..       -        -        -        || ..       -        -        -        |
// row 15 | T28:a2a3 -        -        -        || T28:a6a7 -        -        -        |
// -------+-------------------------------------++-------------------------------------+
6858//
6859// After Round k=1 (shuffle loaded[1] -- adds col pairs 2,3 and 10,11):
6860//
//        |<----------- columns 0-7 ----------->||<---------- columns 8-15 ----------->|
// R\C    | c0,1     c2,3     c4,5     c6,7     || c8,9     c10,11   c12,13   c14,15   |
// -------+-------------------------------------++-------------------------------------+
// row 0  | T0:a0a1  T1:a0a1  -        -        || T0:a4a5  T1:a4a5  -        -        |
// row 1  | T4:a0a1  T5:a0a1  -        -        || T4:a4a5  T5:a4a5  -        -        |
// ..     | ..       ..       -        -        || ..       ..       -        -        |
// row 7  | T28:a0a1 T29:a0a1 -        -        || T28:a4a5 T29:a4a5 -        -        |
// -------+-------------------------------------++-------------------------------------+
// row 8  | T0:a2a3  T1:a2a3  -        -        || T0:a6a7  T1:a6a7  -        -        |
// row 9  | T4:a2a3  T5:a2a3  -        -        || T4:a6a7  T5:a6a7  -        -        |
// ..     | ..       ..       -        -        || ..       ..       -        -        |
// row 15 | T28:a2a3 T29:a2a3 -        -        || T28:a6a7 T29:a6a7 -        -        |
// -------+-------------------------------------++-------------------------------------+
6874//
6875// After Round k=2 (shuffle loaded[2] -- adds col pairs 4,5 and 12,13):
6876//
//        |<----------- columns 0-7 ----------->||<---------- columns 8-15 ----------->|
// R\C    | c0,1     c2,3     c4,5     c6,7     || c8,9     c10,11   c12,13   c14,15   |
// -------+-------------------------------------++-------------------------------------+
// row 0  | T0:a0a1  T1:a0a1  T2:a0a1  -        || T0:a4a5  T1:a4a5  T2:a4a5  -        |
// row 1  | T4:a0a1  T5:a0a1  T6:a0a1  -        || T4:a4a5  T5:a4a5  T6:a4a5  -        |
// ..     | ..       ..       ..       -        || ..       ..       ..       -        |
// row 7  | T28:a0a1 T29:a0a1 T30:a0a1 -        || T28:a4a5 T29:a4a5 T30:a4a5 -        |
// -------+-------------------------------------++-------------------------------------+
// row 8  | T0:a2a3  T1:a2a3  T2:a2a3  -        || T0:a6a7  T1:a6a7  T2:a6a7  -        |
// row 9  | T4:a2a3  T5:a2a3  T6:a2a3  -        || T4:a6a7  T5:a6a7  T6:a6a7  -        |
// ..     | ..       ..       ..       -        || ..       ..       ..       -        |
// row 15 | T28:a2a3 T29:a2a3 T30:a2a3 -        || T28:a6a7 T29:a6a7 T30:a6a7 -        |
// -------+-------------------------------------++-------------------------------------+
6890//
6891// After Round k=3 (shuffle loaded[3] -- adds col pairs 6,7 and 14,15 -- COMPLETE):
6892// (matches the target layout above)
6893//
6894// ====================================================================================
// Shuffle-only step: redistributes the pre-loaded uint32 words into the MMA fragment
// registers. `loaded` must hold 4 uint32 values laid out exactly as a 128-bit memory load
// of the thread's row-half produces them.
6897
6898// ====================================================================================
6899// MmaLoad: unified template for loading cooperative matrix tiles.
6900// Partial specializations dispatch to sub-tile loaders by matrix role and dimensions.
6901// ====================================================================================
6902
// Primary template: loads one cooperative-matrix tile (Row x Col elements of ElemT stored
// in the given layout) into the per-thread fragment registers for matrix role `use`.
// Declaration only; behavior comes from the partial specializations.
template<typename ElemT, Layout layout, int Row, int Col, MatrixUse use>
struct MMALoadHelper;

// 16x16 MatrixA tile -> m16n8k16 A-fragment registers.
//
// exec() arguments:
//   regs   - output. regs[0..1] are written for 1-byte ElemT, regs[0..3] otherwise.
//   buffer - first element of the tile.
//   stride - element distance between consecutive rows (RowMajor) or columns (ColMajor).
//   laneid - caller's lane in the warp. The 16-bit path uses it to choose the 128-bit slice
//            it loads.
//   gid    - fragment row group (rows gid and gid + 8). Presumably laneid >> 2; confirm
//            against callers.
//   tid    - index inside the group, which selects the columns. Presumably laneid & 3;
//            confirm against callers.
//
// The 16-bit path uses full-warp __shfl_sync, so all 32 lanes must call exec() together.
// NOTE(review): the loads dereference uint4 (16-bit path) and uint32_t (8-bit row-major
// path) pointers, which assumes buffer and stride give suitably aligned addresses.
// Confirm at call sites.
template<typename ElemT, Layout layout>
struct MMALoadHelper<ElemT, layout, 16, 16, MatrixUse::MatrixA>
{
    static __device__ inline void exec(
        uint32_t* regs,
        const ElemT* buffer,
        int stride,
        unsigned laneid,
        unsigned gid,
        unsigned tid)
    {
        if constexpr (sizeof(ElemT) == 1)
            loadBytes(regs, reinterpret_cast<const uint8_t*>(buffer), stride, gid, tid);
        else
            loadHalves(regs, buffer, stride, laneid, gid, tid);
    }

private:
    // 8-bit elements (m16n8k16 .s8/.u8 A fragment):
    //   regs[0] packs row gid,     columns 4*tid .. 4*tid+3 (lowest byte = lowest column)
    //   regs[1] packs row gid + 8, same columns
    static __device__ inline void loadBytes(
        uint32_t* regs,
        const uint8_t* bytes,
        int stride,
        unsigned gid,
        unsigned tid)
    {
        const unsigned firstCol = 4 * tid;
        if constexpr (layout == Layout::RowMajor)
        {
            // The four bytes of each register are adjacent in memory: one 32-bit load each.
            regs[0] = *reinterpret_cast<const uint32_t*>(&bytes[gid * stride + firstCol]);
            regs[1] = *reinterpret_cast<const uint32_t*>(&bytes[(gid + 8) * stride + firstCol]);
        }
        else
        {
            // Each byte lives in a different column of the buffer: gather them one by one.
            uint32_t topRow = 0;
            uint32_t bottomRow = 0;
#pragma unroll
            for (unsigned e = 0; e < 4; ++e)
            {
                const unsigned colStart = (firstCol + e) * stride;
                topRow |= static_cast<uint32_t>(bytes[colStart + gid]) << (8 * e);
                bottomRow |= static_cast<uint32_t>(bytes[colStart + gid + 8]) << (8 * e);
            }
            regs[0] = topRow;
            regs[1] = bottomRow;
        }
    }

    // 16-bit elements: each lane loads 8 consecutive elements (128 bits). Warp shuffles then
    // move every packed pair to the lane that owns it in the m16n8k16 A fragment:
    //   regs[0] = row gid,     cols 2*tid, 2*tid+1     regs[1] = row gid+8, same cols
    //   regs[2] = row gid,     cols 2*tid+8, 2*tid+9   regs[3] = row gid+8, same cols
    // In each register, the lower 16 bits hold the lower column.
    static __device__ inline void loadHalves(
        uint32_t* regs,
        const ElemT* buffer,
        int stride,
        unsigned laneid,
        unsigned gid,
        unsigned tid)
    {
        // Lane L reads line L >> 1 (a row for RowMajor, a column for ColMajor),
        // elements [(L & 1) * 8, (L & 1) * 8 + 8) of that line.
        const unsigned line = laneid >> 1;
        const unsigned sideSel = laneid & 1;
        uint4 chunk = *reinterpret_cast<const uint4*>(&buffer[line * stride + sideSel * 8]);
        const uint32_t* words = reinterpret_cast<const uint32_t*>(&chunk);

        const uint32_t fullMask = 0xFFFFFFFF;
        if constexpr (layout == Layout::RowMajor)
        {
            // Word k of lane 2*r + s holds row r, columns 8*s + 2k and 8*s + 2k + 1.
            // In round k, every lane pulls word k from the lanes owning rows gid / gid + 8
            // and column halves 0-7 / 8-15. Only lanes with tid == k keep the results.
            const unsigned topLane = gid * 2;
            const unsigned bottomLane = (gid + 8) * 2;
#pragma unroll
            for (unsigned k = 0; k < 4; ++k)
            {
                const uint32_t topLeft = __shfl_sync(fullMask, words[k], topLane);
                const uint32_t bottomLeft = __shfl_sync(fullMask, words[k], bottomLane);
                const uint32_t topRight = __shfl_sync(fullMask, words[k], topLane + 1);
                const uint32_t bottomRight = __shfl_sync(fullMask, words[k], bottomLane + 1);
                if (tid == k)
                {
                    regs[0] = topLeft;
                    regs[1] = bottomLeft;
                    regs[2] = topRight;
                    regs[3] = bottomRight;
                }
            }
        }
        else
        {
            // Lane 4*tid + j (j = 0..3) holds column 2*tid + (j >> 1), rows (j & 1)*8 .. +7.
            // Lane 4*tid + 16 + j holds column 2*tid + 8 + (j >> 1), with the same rows.
            // Word w of such a lane packs rows base + 2w (low half) and base + 2w + 1 (high half).
            uint32_t gathered[4][8];
            const unsigned srcBase = tid * 4;
#pragma unroll
            for (unsigned w = 0; w < 4; ++w)
            {
#pragma unroll
                for (unsigned j = 0; j < 4; ++j)
                {
                    gathered[w][j] = __shfl_sync(fullMask, words[w], srcBase + j);
                    gathered[w][j + 4] = __shfl_sync(fullMask, words[w], srcBase + 16 + j);
                }
            }

            // Row gid (or gid + 8) sits in word gid >> 1, in the half selected by gid & 1.
            const unsigned word = gid >> 1;
            const unsigned shift = (gid & 1) * 16;
#pragma unroll
            for (unsigned r = 0; r < 4; ++r)
            {
                // r = 0 -> sources 0 and 2, r = 1 -> 1 and 3, r = 2 -> 4 and 6, r = 3 -> 5 and 7.
                const unsigned src = (r & 1) + (r >> 1) * 4;
                const uint32_t loHalf = (gathered[word][src] >> shift) & 0xFFFFu;
                const uint32_t hiHalf = (gathered[word][src + 2] >> shift) & 0xFFFFu;
                regs[r] = loHalf | (hiHalf << 16);
            }
        }
    }
};
7014
// Loads one 16x8 MatrixB tile (16 rows x 8 columns) into this lane's fragment
// registers in the layout consumed by mma.sync m16n8k16:
//   8-bit ElemT : regs[0] = rows 4*tid .. 4*tid+3 of column gid (4 packed bytes)
//   16-bit ElemT: regs[0] = rows 2*tid, 2*tid+1 of column gid
//                 regs[1] = rows 2*tid+8, 2*tid+9 of column gid
// `stride` is measured in ElemT units; laneid / gid / tid are supplied by mmaLoad
// (gid = laneid >> 2, tid = laneid & 3).
template<typename ElemT, Layout layout>
struct MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixB>
{
    static __device__ inline void exec(
        uint32_t* regs,
        const ElemT* buffer,
        int stride,
        unsigned laneid,
        unsigned gid,
        unsigned tid)
    {
        if constexpr (sizeof(ElemT) == 1)
        {
            // 8-bit MatrixB fragment layout (m16n8k16 with .s8/.u8):
            //   regs[0] = (rows=4*tid..4*tid+3, col=gid) -- 4 elements packed
            // gid here is in 0..7 (the column index inside the 16x8 tile).
            const uint8_t* ubuf = reinterpret_cast<const uint8_t*>(buffer);
            if constexpr (layout == Layout::ColMajor)
            {
                // Col-major: 4 row-adjacent elements in column `gid` are 4
                // contiguous bytes at &buffer[gid*stride + 4*tid].
                regs[0] = *reinterpret_cast<const uint32_t*>(&ubuf[gid * stride + 4 * tid]);
            }
            else
            {
                // Row-major: 4 row-adjacent elements all live in column `gid` of
                // 4 different rows; gather byte-by-byte.
                uint32_t r0 = 0;
#pragma unroll
                for (int e = 0; e < 4; e++)
                {
                    unsigned row = 4 * tid + e;
                    r0 |= ((uint32_t)ubuf[row * stride + gid]) << (e * 8);
                }
                regs[0] = r0;
            }
            return;
        }

        // 16-bit path: each lane performs one 16-byte (8-element) vector load,
        // then the fragment is assembled with warp shuffles.
        uint4 loaded_v;
        if constexpr (layout == Layout::ColMajor)
        {
            // The tile holds 8 columns x 2 halves of 8 rows = 16 distinct 16-byte
            // chunks, so only lanes 0..15 carry useful data. Mask the lane index so
            // lanes 16..31 re-read columns 0..7: plain `laneid >> 1` would address
            // columns 8..15, which lie outside this 16x8 tile (and, for the second
            // sub-tile of a 16x16 load, outside the whole 16x16 tile). The shuffles
            // below only read from lanes 0..15, so the fragment is unchanged.
            unsigned col = (laneid & 15) >> 1;
            unsigned side = laneid & 1;
            loaded_v = *reinterpret_cast<const uint4*>(&buffer[col * stride + side * 8]);
        }
        else
        {
            // Row-major: lane loads columns 0..7 of row (laneid & 15).
            unsigned row = laneid & 15;
            loaded_v = *reinterpret_cast<const uint4*>(&buffer[row * stride]);
        }
        uint32_t* loaded = reinterpret_cast<uint32_t*>(&loaded_v);

        const uint32_t mask = 0xFFFFFFFF;
        if constexpr (layout == Layout::ColMajor)
        {
            // Word k of lane gid*2 (+1) holds rows 2k, 2k+1 (+8) of column gid,
            // which is exactly what tid == k needs.
            uint32_t tmp;
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                tmp = __shfl_sync(mask, loaded[k], gid * 2);
                if (tid == k)
                    regs[0] = tmp;
                tmp = __shfl_sync(mask, loaded[k], gid * 2 + 1);
                if (tid == k)
                    regs[1] = tmp;
            }
        }
        else
        {
            // Fetch rows 2*tid, 2*tid+1, 2*tid+8, 2*tid+9 (all 4 words each), then
            // pick the 16-bit element of column gid out of each row.
            uint32_t s[4][4];
#pragma unroll
            for (int ki = 0; ki < 4; ki++)
            {
                s[ki][0] = __shfl_sync(mask, loaded[ki], tid * 2);
                s[ki][1] = __shfl_sync(mask, loaded[ki], tid * 2 + 1);
                s[ki][2] = __shfl_sync(mask, loaded[ki], tid * 2 + 8);
                s[ki][3] = __shfl_sync(mask, loaded[ki], tid * 2 + 9);
            }

            // Column gid lives in word gid/2, low or high half by gid parity.
            unsigned k = gid >> 1;
            unsigned shift = (gid & 1) * 16;

            uint16_t h0 = (uint16_t)(s[k][0] >> shift);
            uint16_t h1 = (uint16_t)(s[k][1] >> shift);
            regs[0] = (uint32_t)h0 | ((uint32_t)h1 << 16);

            h0 = (uint16_t)(s[k][2] >> shift);
            h1 = (uint16_t)(s[k][3] >> shift);
            regs[1] = (uint32_t)h0 | ((uint32_t)h1 << 16);
        }
    }
};
7107
// A 16x16 MatrixB tile is two 16x8 tiles placed side by side. Each 16x8 half is
// loaded by the 16x8 helper into consecutive register slots.
template<typename ElemT, Layout layout>
struct MMALoadHelper<ElemT, layout, 16, 16, MatrixUse::MatrixB>
{
    static __device__ inline void exec(
        uint32_t* regs,
        const ElemT* buffer,
        int stride,
        unsigned laneid,
        unsigned gid,
        unsigned tid)
    {
        using HalfTile = MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixB>;

        // Registers per 16x8 half: one for 8-bit elements, two for 16-bit ones.
        constexpr int kSubTileRegs = (sizeof(ElemT) == 1) ? 1 : 2;

        // Columns 8..15 start 8 elements along each row in row-major storage, or
        // 8 whole columns (8 * stride elements) further on in column-major storage.
        const ElemT* secondHalf =
            (layout == Layout::RowMajor) ? buffer + 8 : buffer + 8 * stride;

        HalfTile::exec(regs, buffer, stride, laneid, gid, tid);
        HalfTile::exec(regs + kSubTileRegs, secondHalf, stride, laneid, gid, tid);
    }
};
7146
// Loads one 16x8 accumulator (MatrixC) tile into this lane's fragment registers
// in the mma.sync m16n8k16 accumulator layout (gid = laneid >> 2, tid = laneid & 3):
//   32-bit ElemT: regs[0], regs[1] = row gid,   columns 2*tid, 2*tid+1
//                 regs[2], regs[3] = row gid+8, columns 2*tid, 2*tid+1
//   16-bit ElemT: regs[0] = row gid,   columns 2*tid, 2*tid+1 (packed pair)
//                 regs[1] = row gid+8, columns 2*tid, 2*tid+1 (packed pair)
// `stride` is measured in ElemT units.
template<typename ElemT, Layout layout>
struct MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixC>
{
    static __device__ inline void exec(
        uint32_t* regs,
        const ElemT* buffer,
        int stride,
        unsigned laneid,
        unsigned gid,
        unsigned tid)
    {
        if constexpr (sizeof(ElemT) == 4)
        {
            const float* fbuf = reinterpret_cast<const float*>(buffer);
            uint4 loaded_v;
            if constexpr (layout == Layout::RowMajor)
            {
                // Lane loads 4 floats: row laneid/2, columns 4*(laneid&1) .. +3.
                unsigned row = laneid >> 1;
                unsigned side = laneid & 1;
                loaded_v = *reinterpret_cast<const uint4*>(&fbuf[row * stride + side * 4]);
            }
            else
            {
                // Lane loads 4 floats: column laneid/4, rows 4*(laneid&3) .. +3.
                unsigned col = laneid >> 2;
                unsigned chunk = laneid & 3;
                loaded_v = *reinterpret_cast<const uint4*>(&fbuf[col * stride + chunk * 4]);
            }
            uint32_t* loaded = reinterpret_cast<uint32_t*>(&loaded_v);

            const uint32_t mask = 0xFFFFFFFF;
            if constexpr (layout == Layout::RowMajor)
            {
                // Columns 2*tid and 2*tid+1 sit in the 4-column chunk (tid >> 1) at
                // word offsets kb and kb + 1 of the lane that loaded that chunk.
                // Only two source lanes matter per lane (rows gid and gid + 8), so
                // 8 shuffles suffice instead of the 16 a full (k, j) sweep would use.
                unsigned kb = (tid & 1) * 2;
                unsigned srcSide = tid >> 1;
                unsigned srcTop = gid * 2 + srcSide;          // row gid
                unsigned srcBottom = (gid + 8) * 2 + srcSide; // row gid + 8
#pragma unroll
                for (int k = 0; k < 4; k++)
                {
                    // Every lane must take part in each shuffle; keep only the
                    // words this lane actually needs.
                    uint32_t top = __shfl_sync(mask, loaded[k], srcTop);
                    uint32_t bottom = __shfl_sync(mask, loaded[k], srcBottom);
                    if (k == kb)
                    {
                        regs[0] = top;
                        regs[2] = bottom;
                    }
                    if (k == kb + 1)
                    {
                        regs[1] = top;
                        regs[3] = bottom;
                    }
                }
            }
            else
            {
                // Element (row, col) is word (row & 3) of lane col*4 + row/4.
                uint32_t tmp;
                unsigned k = gid & 3;
#pragma unroll
                for (int ki = 0; ki < 4; ki++)
                {
                    tmp = __shfl_sync(mask, loaded[ki], tid * 8 + gid / 4);
                    if (ki == k)
                        regs[0] = tmp;
                    tmp = __shfl_sync(mask, loaded[ki], tid * 8 + 4 + gid / 4);
                    if (ki == k)
                        regs[1] = tmp;
                    tmp = __shfl_sync(mask, loaded[ki], tid * 8 + 2 + gid / 4);
                    if (ki == k)
                        regs[2] = tmp;
                    tmp = __shfl_sync(mask, loaded[ki], tid * 8 + 6 + gid / 4);
                    if (ki == k)
                        regs[3] = tmp;
                }
            }
        }
        else
        {
            uint4 loaded_v;
            if constexpr (layout == Layout::RowMajor)
            {
                // Lane loads columns 0..7 of row (laneid & 15).
                unsigned row = laneid & 15;
                loaded_v = *reinterpret_cast<const uint4*>(&buffer[row * stride]);
            }
            else
            {
                // Lane loads rows 8*(laneid&1) .. +7 of column ((laneid & 15) >> 1).
                unsigned col = (laneid & 15) >> 1;
                unsigned side = laneid & 1;
                loaded_v = *reinterpret_cast<const uint4*>(&buffer[col * stride + side * 8]);
            }
            uint32_t* loaded = reinterpret_cast<uint32_t*>(&loaded_v);

            const uint32_t mask = 0xFFFFFFFF;
            if constexpr (layout == Layout::RowMajor)
            {
                // Word k of row r packs columns 2k, 2k+1, exactly what tid == k needs.
                uint32_t tmp;
#pragma unroll
                for (int k = 0; k < 4; k++)
                {
                    tmp = __shfl_sync(mask, loaded[k], gid);
                    if (tid == k)
                        regs[0] = tmp;
                    tmp = __shfl_sync(mask, loaded[k], gid + 8);
                    if (tid == k)
                        regs[1] = tmp;
                }
            }
            else
            {
                // Fetch both column-halves of columns 2*tid and 2*tid+1, then pick
                // rows gid and gid+8 out of them.
                uint32_t s[4][4];
#pragma unroll
                for (int ki = 0; ki < 4; ki++)
                {
                    s[ki][0] = __shfl_sync(mask, loaded[ki], tid * 4);
                    s[ki][1] = __shfl_sync(mask, loaded[ki], tid * 4 + 1);
                    s[ki][2] = __shfl_sync(mask, loaded[ki], tid * 4 + 2);
                    s[ki][3] = __shfl_sync(mask, loaded[ki], tid * 4 + 3);
                }

                unsigned k = gid >> 1;
                unsigned shift = (gid & 1) * 16;

                uint16_t h0 = (uint16_t)(s[k][0] >> shift);
                uint16_t h1 = (uint16_t)(s[k][2] >> shift);
                regs[0] = (uint32_t)h0 | ((uint32_t)h1 << 16);

                h0 = (uint16_t)(s[k][1] >> shift);
                h1 = (uint16_t)(s[k][3] >> shift);
                regs[1] = (uint32_t)h0 | ((uint32_t)h1 << 16);
            }
        }
    }
};
7281
// A 16x16 accumulator tile is two 16x8 tiles placed side by side. Each half is
// loaded with the 16x8 MatrixC helper into consecutive register slots.
template<typename ElemT, Layout layout>
struct MMALoadHelper<ElemT, layout, 16, 16, MatrixUse::MatrixC>
{
    static __device__ inline void exec(
        uint32_t* regs,
        const ElemT* buffer,
        int stride,
        unsigned laneid,
        unsigned gid,
        unsigned tid)
    {
        using HalfTile = MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixC>;

        // Registers per 16x8 half: four for 32-bit accumulators, two otherwise.
        constexpr int kSubTileRegs = (sizeof(ElemT) == 4) ? 4 : 2;

        // Columns 8..15 begin 8 elements into each row (row-major) or 8 whole
        // columns further on (column-major).
        const ElemT* secondHalf =
            (layout == Layout::RowMajor) ? buffer + 8 : buffer + 8 * stride;

        HalfTile::exec(regs, buffer, stride, laneid, gid, tid);
        HalfTile::exec(regs + kSubTileRegs, secondHalf, stride, laneid, gid, tid);
    }
};
7319
// MatrixD (the mma result) shares MatrixC's accumulator register layout, so
// loading it simply reuses the MatrixC helper of the same shape.
template<typename ElemT, Layout layout, int Row, int Col>
struct MMALoadHelper<ElemT, layout, Row, Col, MatrixUse::MatrixD>
    : public MMALoadHelper<ElemT, layout, Row, Col, MatrixUse::MatrixC>
{
};
7325
// Loads a Row x Col cooperative-matrix tile of role `use` from `ptr` into this
// lane's fragment registers. `stride` is measured in ElemT units.
template<typename ElemT, Layout layout, int Row, int Col, MatrixUse use>
__device__ inline void mmaLoad(uint32_t* regs, const void* ptr, int stride)
{
    // Lane index within the warp, read from the PTX %laneid special register.
    unsigned lane;
    asm("mov.u32 %0, %%laneid;" : "=r"(lane));

    // Fragment coordinates: group id = lane / 4, thread-in-group = lane % 4.
    MMALoadHelper<ElemT, layout, Row, Col, use>::exec(
        regs,
        static_cast<const ElemT*>(ptr),
        stride,
        lane,
        lane >> 2,
        lane & 3);
}
7336
7337// ====================================================================================
7338// MMAStoreHelper: unified template for storing cooperative matrix tiles.
7339// ====================================================================================
7340
// Primary template: declared only. Shapes without a specialization are rejected
// at compile time (incomplete type).
template<typename Elem, Layout L, int Rows, int Cols>
struct MMAStoreHelper;
7343
// Stores one 16x8 fragment held in the mma.sync accumulator layout
// (gid = laneid / 4, tid = laneid % 4):
//   32-bit ElemT: regs[0..1] = row gid, columns 2*tid..2*tid+1; regs[2..3] = row gid+8
//   16-bit ElemT: regs[0] = row gid, packed columns 2*tid..2*tid+1; regs[1] = row gid+8
// Each lane gathers one 16-byte chunk of the destination through warp shuffles
// and writes it with a single vector store. `stride` is in ElemT units.
template<typename ElemT, Layout layout>
struct MMAStoreHelper<ElemT, layout, 16, 8>
{
    static __device__ inline void exec(
        ElemT* buffer,
        const uint32_t* regs,
        int stride,
        unsigned laneid)
    {
        const uint32_t fullWarp = 0xFFFFFFFF;
        uint4 chunk;
        uint32_t* words = reinterpret_cast<uint32_t*>(&chunk);

        if constexpr (sizeof(ElemT) == 4)
        {
            float* dst = reinterpret_cast<float*>(buffer);
            if constexpr (layout == Layout::RowMajor)
            {
                // Lane writes row laneid/2, columns 4*(laneid&1) .. +3. Those columns
                // belong to tids 2*side and 2*side+1 of group (row & 7).
                unsigned row = laneid >> 1;
                unsigned side = laneid & 1;
                unsigned srcA = (row & 7) * 4 + side * 2;
                unsigned srcB = srcA + 1;

                uint32_t top0 = __shfl_sync(fullWarp, regs[0], srcA);
                uint32_t top1 = __shfl_sync(fullWarp, regs[1], srcA);
                uint32_t top2 = __shfl_sync(fullWarp, regs[0], srcB);
                uint32_t top3 = __shfl_sync(fullWarp, regs[1], srcB);
                uint32_t bot0 = __shfl_sync(fullWarp, regs[2], srcA);
                uint32_t bot1 = __shfl_sync(fullWarp, regs[3], srcA);
                uint32_t bot2 = __shfl_sync(fullWarp, regs[2], srcB);
                uint32_t bot3 = __shfl_sync(fullWarp, regs[3], srcB);

                // Rows 0..7 come from regs[0..1], rows 8..15 from regs[2..3].
                const bool upperRows = row < 8;
                words[0] = upperRows ? top0 : bot0;
                words[1] = upperRows ? top1 : bot1;
                words[2] = upperRows ? top2 : bot2;
                words[3] = upperRows ? top3 : bot3;

                *reinterpret_cast<uint4*>(&dst[row * stride + side * 4]) = chunk;
            }
            else
            {
                // Lane writes column laneid/4, rows 4*(laneid&3) .. +3.
                unsigned col = laneid >> 2;
                unsigned quarter = laneid & 3;
                unsigned srcTid = col >> 1;
                unsigned firstGid = (quarter & 1) * 4;
                // Bit 1 selects rows >= 8 (regs[2..3]); bit 0 selects the odd column.
                unsigned pick = ((quarter >> 1) << 1) | (col & 1);

                uint32_t gathered[4][4];
#pragma unroll
                for (int r = 0; r < 4; r++)
                {
#pragma unroll
                    for (int g = 0; g < 4; g++)
                        gathered[r][g] = __shfl_sync(fullWarp, regs[r], (firstGid + g) * 4 + srcTid);
                }

#pragma unroll
                for (int g = 0; g < 4; g++)
                    words[g] = gathered[pick][g];

                *reinterpret_cast<uint4*>(&dst[col * stride + quarter * 4]) = chunk;
            }
        }
        else
        {
            if constexpr (layout == Layout::RowMajor)
            {
                // Lane writes row (laneid & 15), columns 0..7 as 4 packed words.
                // Word k of row r is held by lane (r & 7) * 4 + k: in regs[0] for
                // rows 0..7 and in regs[1] for rows 8..15.
                unsigned row = laneid & 15;
                unsigned srcBase = (row & 7) * 4;

                uint32_t upper[4], lower[4];
#pragma unroll
                for (int k = 0; k < 4; k++)
                {
                    upper[k] = __shfl_sync(fullWarp, regs[0], srcBase + k);
                    lower[k] = __shfl_sync(fullWarp, regs[1], srcBase + k);
                }

                const bool upperRows = row < 8;
#pragma unroll
                for (int k = 0; k < 4; k++)
                    words[k] = upperRows ? upper[k] : lower[k];

                *reinterpret_cast<uint4*>(&buffer[row * stride]) = chunk;
            }
            else
            {
                // Lane writes column ((laneid & 15) >> 1), rows 8*(laneid&1) .. +7 as
                // 4 packed words. Column c is the low (even c) or high (odd c) half
                // of tid c/2's registers; row g within a half comes from group g.
                unsigned col = (laneid & 15) >> 1;
                unsigned side = laneid & 1;
                unsigned srcTid = col >> 1;
                unsigned shift = (col & 1) * 16;

                uint32_t firstRows[8], lastRows[8];
#pragma unroll
                for (int g = 0; g < 8; g++)
                {
                    firstRows[g] = __shfl_sync(fullWarp, regs[0], g * 4 + srcTid);
                    lastRows[g] = __shfl_sync(fullWarp, regs[1], g * 4 + srcTid);
                }

#pragma unroll
                for (int k = 0; k < 4; k++)
                {
                    uint32_t evenRow = (side == 0) ? firstRows[2 * k] : lastRows[2 * k];
                    uint32_t oddRow = (side == 0) ? firstRows[2 * k + 1] : lastRows[2 * k + 1];
                    words[k] = (uint32_t)(uint16_t)(evenRow >> shift) |
                               ((uint32_t)(uint16_t)(oddRow >> shift) << 16);
                }

                *reinterpret_cast<uint4*>(&buffer[col * stride + side * 8]) = chunk;
            }
        }
    }
};
7493
// Stores a 16x16 fragment as two side-by-side 16x8 halves using the 16x8 helper.
template<typename ElemT, Layout layout>
struct MMAStoreHelper<ElemT, layout, 16, 16>
{
    static __device__ inline void exec(
        ElemT* buffer,
        const uint32_t* regs,
        int stride,
        unsigned laneid)
    {
        // 8-bit cooperative-matrix fragments (s8 / u8 / e4m3 / e5m2) only
        // exist as MatrixA / MatrixB on CUDA. `Store` on MatA / MatB is
        // broken in this prelude for every element size: the inner
        // `MMAStoreHelper<...,16,8>` implements only the f32/int32 and
        // f16/bf16 *accumulator* register layouts. For half/bf16 that silently
        // produces wrong data. For 8-bit it would also read past the fragment's
        // `regs[]` array through `regs + kSubTileRegs`: an int8 16x16 MatA
        // fragment holds only 2 registers, so `regs + 2` is undefined
        // behaviour. Reject it at compile time until a MatrixUse-aware Store
        // path exists.
        static_assert(
            sizeof(ElemT) != 1,
            "CUDA `Store` on a 16x16 cooperative-matrix fragment with an 8-bit "
            "element type (s8 / u8 / e4m3 / e5m2) is not implemented: it would "
            "read past the fragment's register array. Use `.equals()` against "
            "a known-expected fragment to verify content (see int8-arith.slang) "
            "until MatrixA / MatrixB Store is rewritten.");

        using HalfTile = MMAStoreHelper<ElemT, layout, 16, 8>;

        // Registers per 16x8 half: four for 32-bit elements, two for 16-bit.
        constexpr int kSubTileRegs = (sizeof(ElemT) == 4) ? 4 : 2;

        // Columns 8..15 begin 8 elements into each row (row-major) or 8 whole
        // columns further on (column-major).
        ElemT* secondHalf = (layout == Layout::RowMajor) ? buffer + 8 : buffer + 8 * stride;

        HalfTile::exec(buffer, regs, stride, laneid);
        HalfTile::exec(secondHalf, regs + kSubTileRegs, stride, laneid);
    }
};
7537
// Writes this lane's fragment registers to `ptr` as a Row x Col tile.
// `stride` is measured in ElemT units.
template<typename ElemT, Layout layout, int Row, int Col>
__device__ inline void mmaStore(void* ptr, const uint32_t* regs, int stride)
{
    // Lane index within the warp, read from the PTX %laneid special register.
    unsigned lane;
    asm("mov.u32 %0, %%laneid;" : "=r"(lane));

    MMAStoreHelper<ElemT, layout, Row, Col>::exec(static_cast<ElemT*>(ptr), regs, stride, lane);
}
7546
7547// ====================================================================================
7548// Packed-pair traits for 16-bit float types (half / bfloat16).
7549//
7550// Both half and bfloat16 fit two elements per 32-bit register. Arithmetic operators
7551// for the corresponding pair types (__half2 / __nv_bfloat162) are already available:
7552// - __half2 ops are defined at global scope earlier in this prelude (when
7553// SLANG_CUDA_ENABLE_HALF is defined; cuda_fp16.h is included with
7554// __CUDA_NO_HALF2_OPERATORS__ so the prelude can supply HLSL-friendly versions).
7555// - __nv_bfloat162 ops come straight from <cuda_bf16.h>.
7556// So this trait only needs to expose the pair type and the per-type broadcast
7557// intrinsic — arithmetic in the WmmaFragment operator overloads can use native
7558// `+`, `-`, `*`, `/`, unary `-` directly.
7559// ====================================================================================
7560
// PackedFp16Traits<T> describes a 16-bit float element type T. It exposes:
//   PairType  - the two-wide type that occupies one 32-bit register
//   broadcast - replicates a scalar into both lanes of a PairType
// It is left undefined for all other element types.
template<typename T>
struct PackedFp16Traits;

#if SLANG_CUDA_ENABLE_HALF
template<>
struct PackedFp16Traits<half>
{
    typedef __half2 PairType;

    static __device__ PairType broadcast(half v)
    {
        return __half2half2(v);
    }
};
#endif

#if SLANG_CUDA_ENABLE_BF16
template<>
struct PackedFp16Traits<__nv_bfloat16>
{
    typedef __nv_bfloat162 PairType;

    static __device__ PairType broadcast(__nv_bfloat16 v)
    {
        return __bfloat162bfloat162(v);
    }
};
#endif
7581
// Compile-time flag that is true only for the 16-bit float element types
// (half, __nv_bfloat16). WmmaFragment's operators test it to select their
// packed pair-arithmetic path.
template<typename T>
struct IsPackedFp16
{
    static constexpr bool value{false};
};

#if SLANG_CUDA_ENABLE_HALF
template<>
struct IsPackedFp16<half>
{
    static constexpr bool value{true};
};
#endif

#if SLANG_CUDA_ENABLE_BF16
template<>
struct IsPackedFp16<__nv_bfloat16>
{
    static constexpr bool value{true};
};
#endif
7601
// Pack32Helper<T>(v) returns a 32-bit word in which every element slot holds the
// bit pattern of `v`: the value itself for 32-bit T, two copies for 16-bit T, and
// four copies for 8-bit T. WmmaFragment::fill writes this word into every
// fragment register.
template<typename T>
inline unsigned __device__ Pack32Helper(T value);

#if SLANG_CUDA_ENABLE_HALF
template<>
inline unsigned __device__ Pack32Helper<half>(half value)
{
    // Widen to unsigned before shifting, as the bf16 specialization does. The
    // ushort would otherwise be promoted to `int`, and `<< 16` would push bits into
    // its sign bit (undefined behaviour before C++14).
    unsigned bits = (unsigned)__half_as_ushort(value);
    return bits | (bits << 16);
}
#endif

#if SLANG_CUDA_ENABLE_BF16
template<>
inline unsigned __device__ Pack32Helper<__nv_bfloat16>(__nv_bfloat16 value)
{
    unsigned bits = (unsigned)__bfloat16_as_ushort(value);
    return bits | (bits << 16);
}
#endif

template<>
inline unsigned __device__ Pack32Helper<float>(float value)
{
    return __float_as_uint(value);
}

template<>
inline unsigned __device__ Pack32Helper<int>(int value)
{
    return (unsigned)value;
}

template<>
inline unsigned __device__ Pack32Helper<char>(char value)
{
    // Cast through unsigned char first so a negative `value` is not
    // sign-extended. Otherwise the OR-chain below would pick up the
    // sign-extension bits and yield 0xFFFFFFFF for any negative `value`.
    unsigned bits = (unsigned)(unsigned char)value;
    return (bits << 24) | (bits << 16) | (bits << 8) | bits;
}

template<>
inline unsigned __device__ Pack32Helper<unsigned char>(unsigned char value)
{
    unsigned bits = (unsigned)value;
    return (bits << 24) | (bits << 16) | (bits << 8) | bits;
}

#if SLANG_CUDA_ENABLE_FP8
template<>
inline unsigned __device__ Pack32Helper<__nv_fp8_e4m3>(__nv_fp8_e4m3 value)
{
    // fp8 types are 1-byte structs; extract the storage byte and replicate it.
    unsigned bits = (unsigned)*reinterpret_cast<const uint8_t*>(&value);
    return (bits << 24) | (bits << 16) | (bits << 8) | bits;
}

template<>
inline unsigned __device__ Pack32Helper<__nv_fp8_e5m2>(__nv_fp8_e5m2 value)
{
    unsigned bits = (unsigned)*reinterpret_cast<const uint8_t*>(&value);
    return (bits << 24) | (bits << 16) | (bits << 8) | bits;
}
#endif
7664
7665
7666// ====================================================================================
7667// WmmaFragment struct
7668// ====================================================================================
7669
7670// The dimensions of the fragment are specified by M, N, K which are totally determined during
7671// compile time, so slang already did the pre-filter on the shape & type combination.
7672template<typename T, int M, int N, int K, MatrixUse R>
7673struct WmmaFragment
7674{
7675 __device__ WmmaFragment() {}
7676 __device__ WmmaFragment(T scalarValue) { fill(scalarValue); }
7677
7678 typedef WmmaFragment<T, M, N, K, R> This;
7679 template<Layout layout>
7680 void __device__ Store(RWStructuredBuffer<T> buffer, uint element, uint stride)
7681 {
7682 Store<layout>(buffer.data, element, stride);
7683 }
7684
7685 template<Layout layout>
7686 static This __device__ Load(StructuredBuffer<T> buffer, uint element, uint stride)
7687 {
7688 return Load<layout>(buffer.data, element, stride);
7689 }
7690
7691 // There is no fill intrinsic in PTX wmma, so it's just 'move' value
7692 // to the fragment registers.
7693 void __device__ fill(T value)
7694 {
7695 unsigned packed = Pack32Helper(value);
7696 constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
7697#pragma unroll
7698 for (int i = 0; i < nregs; i++)
7699 {
7700 regs[i] = packed;
7701 }
7702 }
7703
7704 // Zero-clear all registers using integer zero (enables CSE to single register).
7705 void __device__ clear()
7706 {
7707#pragma unroll
7708 for (int i = 0; i < RegsCount; i++)
7709 regs[i] = 0U;
7710 }
7711
7712 __device__ This operator*(T b)
7713 {
7714 This result;
7715 if constexpr (IsPackedFp16<T>::value)
7716 {
7717 using PairT = typename PackedFp16Traits<T>::PairType;
7718 PairT bv = PackedFp16Traits<T>::broadcast(b);
7719#pragma unroll
7720 for (int i = 0; i < RegsCount; i++)
7721 {
7722 PairT r = *reinterpret_cast<const PairT*>(&regs[i]) * bv;
7723 memcpy(&result.regs[i], &r, 4);
7724 }
7725 }
7726 else
7727 {
7728 for (int i = 0; i < GetLength(); i++)
7729 result.set(i, get(i) * b);
7730 }
7731 return result;
7732 }
7733
7734 __device__ This operator*(const This& b)
7735 {
7736 This result;
7737 if constexpr (IsPackedFp16<T>::value)
7738 {
7739 using PairT = typename PackedFp16Traits<T>::PairType;
7740#pragma unroll
7741 for (int i = 0; i < RegsCount; i++)
7742 {
7743 PairT r = *reinterpret_cast<const PairT*>(&regs[i]) *
7744 *reinterpret_cast<const PairT*>(&b.regs[i]);
7745 memcpy(&result.regs[i], &r, 4);
7746 }
7747 }
7748 else
7749 {
7750 for (int i = 0; i < GetLength(); i++)
7751 result.set(i, get(i) * b.get(i));
7752 }
7753 return result;
7754 }
7755
7756 __device__ This operator/(const This& other)
7757 {
7758 This result;
7759 if constexpr (IsPackedFp16<T>::value)
7760 {
7761 using PairT = typename PackedFp16Traits<T>::PairType;
7762#pragma unroll
7763 for (int i = 0; i < RegsCount; i++)
7764 {
7765 PairT r = *reinterpret_cast<const PairT*>(&regs[i]) /
7766 *reinterpret_cast<const PairT*>(&other.regs[i]);
7767 memcpy(&result.regs[i], &r, 4);
7768 }
7769 }
7770 else
7771 {
7772 for (int i = 0; i < GetLength(); i++)
7773 result.set(i, get(i) / other.get(i));
7774 }
7775 return result;
7776 }
7777
7778 __device__ This operator-(const This& other)
7779 {
7780 This result;
7781 if constexpr (IsPackedFp16<T>::value)
7782 {
7783 using PairT = typename PackedFp16Traits<T>::PairType;
7784#pragma unroll
7785 for (int i = 0; i < RegsCount; i++)
7786 {
7787 PairT r = *reinterpret_cast<const PairT*>(&regs[i]) -
7788 *reinterpret_cast<const PairT*>(&other.regs[i]);
7789 memcpy(&result.regs[i], &r, 4);
7790 }
7791 }
7792 else
7793 {
7794 for (int i = 0; i < GetLength(); i++)
7795 result.set(i, get(i) - other.get(i));
7796 }
7797 return result;
7798 }
7799
7800 __device__ This operator-()
7801 {
7802 This result;
7803 if constexpr (IsPackedFp16<T>::value)
7804 {
7805 using PairT = typename PackedFp16Traits<T>::PairType;
7806#pragma unroll
7807 for (int i = 0; i < RegsCount; i++)
7808 {
7809 PairT r = -*reinterpret_cast<const PairT*>(&regs[i]);
7810 memcpy(&result.regs[i], &r, 4);
7811 }
7812 }
7813 else
7814 {
7815 for (int i = 0; i < GetLength(); i++)
7816 result.set(i, -get(i));
7817 }
7818 return result;
7819 }
7820
7821 __device__ This operator+(const This& other)
7822 {
7823 This result;
7824 if constexpr (IsPackedFp16<T>::value)
7825 {
7826 using PairT = typename PackedFp16Traits<T>::PairType;
7827#pragma unroll
7828 for (int i = 0; i < RegsCount; i++)
7829 {
7830 PairT r = *reinterpret_cast<const PairT*>(&regs[i]) +
7831 *reinterpret_cast<const PairT*>(&other.regs[i]);
7832 memcpy(&result.regs[i], &r, 4);
7833 }
7834 }
7835 else
7836 {
7837 for (int i = 0; i < GetLength(); i++)
7838 result.set(i, get(i) + other.get(i));
7839 }
7840 return result;
7841 }
7842
7843 __device__ This operator%(const This& other)
7844 {
7845 This result;
7846 if constexpr (IsPackedFp16<T>::value)
7847 {
7848 // 16-bit float types (half / bfloat16) have no native `%`. Compute the
7849 // modulo through a float intermediate, which is precision-preserving for
7850 // both half and bfloat16.
7851 for (int i = 0; i < GetLength(); i++)
7852 {
7853 float a = static_cast<float>(get(i));
7854 float b = static_cast<float>(other.get(i));
7855 result.set(i, T(fmodf(a, b)));
7856 }
7857 }
7858 else
7859 {
7860 for (int i = 0; i < GetLength(); i++)
7861 result.set(i, get(i) % other.get(i));
7862 }
7863 return result;
7864 }
7865
7866 // Element-wise equality: true iff every element matches `other`'s element at
7867 // the same index. Used by Slang's `equals` on cooperative matrices.
7868 __device__ bool operator==(const This& other) const
7869 {
7870 for (int i = 0; i < GetLength(); i++)
7871 {
7872 if (get(i) != other.get(i))
7873 return false;
7874 }
7875 return true;
7876 }
7877
7878 // Lexicographic ordering: scan elements in index order; first non-equal pair
7879 // decides. Used by Slang's `lessThan` / `lessThanOrEquals` on cooperative
7880 // matrices.
7881 __device__ bool operator<(const This& other) const
7882 {
7883 for (int i = 0; i < GetLength(); i++)
7884 {
7885 if (get(i) < other.get(i))
7886 return true;
7887 if (get(i) > other.get(i))
7888 return false;
7889 }
7890 return false;
7891 }
7892
7893 __device__ bool operator<=(const This& other) const
7894 {
7895 for (int i = 0; i < GetLength(); i++)
7896 {
7897 if (get(i) < other.get(i))
7898 return true;
7899 if (get(i) > other.get(i))
7900 return false;
7901 }
7902 return true;
7903 }
7904
7905 template<typename U, MatrixUse R2>
7906 __device__ void copyFrom(const WmmaFragment<U, M, N, K, R2>& other)
7907 {
7908 constexpr int OtherRegsCount = WmmaFragment<U, M, N, K, R2>::RegsCount;
7909 if constexpr (IsSameType<T, U>::value && RegsCount == OtherRegsCount)
7910 {
7911 if constexpr (RegsCount == 2)
7912 *reinterpret_cast<uint2*>(regs) = *reinterpret_cast<const uint2*>(other.regs);
7913 else if constexpr (RegsCount == 4)
7914 *reinterpret_cast<uint4*>(regs) = *reinterpret_cast<const uint4*>(other.regs);
7915 else if constexpr (RegsCount == 8)
7916 {
7917 *reinterpret_cast<uint4*>(regs) = *reinterpret_cast<const uint4*>(other.regs);
7918 *reinterpret_cast<uint4*>(regs + 4) =
7919 *reinterpret_cast<const uint4*>(other.regs + 4);
7920 }
7921 else
7922 {
7923#pragma unroll
7924 for (int i = 0; i < RegsCount; i++)
7925 regs[i] = other.regs[i];
7926 }
7927 }
7928 else
7929 {
7930 for (int i = 0; i < GetLength(); i++)
7931 set(i, static_cast<T>(other.get(i)));
7932 }
7933 }
7934
7935 // Get element by index (handles bit-level access for packed types)
7936 // For example: u8/s8 matrices have 4 elements per register (32-bit)
7937 // - index 0: bits [0:7] of regs[0]
7938 // - index 1: bits [8:15] of regs[0]
7939 // - index 2: bits [16:23] of regs[0]
7940 // - index 3: bits [24:31] of regs[0]
7941 __device__ T get(int index) const
7942 {
7943 if constexpr (sizeof(T) == 4)
7944 {
7945 // T is 32-bit (float or int32): 1 element per register
7946 T v;
7947 memcpy(&v, &regs[index], 4);
7948 return v;
7949 }
7950 else if constexpr (sizeof(T) == 2)
7951 {
7952 // T is 16-bit (half): 2 elements per register
7953 // Elements per register: [0:15] and [16:31]
7954 int regIndex = index / 2;
7955 int elementOffset = index % 2;
7956 int bitOffset = elementOffset * 16;
7957 uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFFFF;
7958 uint16_t value16 = static_cast<uint16_t>(extracted);
7959 T v;
7960 memcpy(&v, &value16, 2);
7961 return v;
7962 }
7963 else if constexpr (sizeof(T) == 1)
7964 {
7965 // T is 8-bit (int8_t, uint8_t): 4 elements per register
7966 // Elements per register: [0:7], [8:15], [16:23], [24:31]
7967 int regIndex = index / 4;
7968 int elementOffset = index % 4;
7969 int bitOffset = elementOffset * 8;
7970 uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFF;
7971 uint8_t value8 = static_cast<uint8_t>(extracted);
7972 return *reinterpret_cast<const T*>(&value8);
7973 }
7974 }
7975
7976 // Set element by index (handles bit-level access for packed types)
7977 __device__ void set(int index, T value)
7978 {
7979 if constexpr (sizeof(T) == 4)
7980 {
7981 // T is 32-bit (float or int32): 1 element per register
7982 memcpy(&regs[index], &value, 4);
7983 }
7984 else if constexpr (sizeof(T) == 2)
7985 {
7986 // T is 16-bit (half): 2 elements per register
7987 int regIndex = index / 2;
7988 int elementOffset = index % 2;
7989 int bitOffset = elementOffset * 16;
7990 uint32_t mask = 0xFFFF;
7991 uint16_t value16;
7992 memcpy(&value16, &value, 2);
7993
7994 // Clear the bits at the target position
7995 regs[regIndex] &= ~(mask << bitOffset);
7996
7997 // Set the new value
7998 regs[regIndex] |= (static_cast<uint32_t>(value16) << bitOffset);
7999 }
8000 else if constexpr (sizeof(T) == 1)
8001 {
8002 // T is 8-bit (int8_t, uint8_t): 4 elements per register
8003 int regIndex = index / 4;
8004 int elementOffset = index % 4;
8005 int bitOffset = elementOffset * 8;
8006 uint32_t mask = 0xFF;
8007 uint8_t value8 = *reinterpret_cast<const uint8_t*>(&value);
8008
8009 // Clear the bits at the target position
8010 regs[regIndex] &= ~(mask << bitOffset);
8011
8012 // Set the new value
8013 regs[regIndex] |= (static_cast<uint32_t>(value8) << bitOffset);
8014 }
8015 }
8016
8017 __device__ void FragmentWrite(int regIndex, unsigned value) { regs[regIndex] = value; }
8018 __device__ unsigned FragmentRead(int regIndex) const { return regs[regIndex]; }
8019
8020 // Uses movmatrix.sync.aligned.m8n8.trans.b16 to transpose each 8x8 sub-block
8021 // independently. Does NOT swap off-diagonal blocks — this reinterprets
8022 // row-major as column-major (and vice versa) without a full 16x16 transpose.
8023 //
8024 // Before: reg0=A00, reg1=A10, reg2=A01, reg3=A11
8025 // After: reg0=A00^T, reg1=A10^T, reg2=A01^T, reg3=A11^T
8026 //
8027 // For a full 16x16 transpose, combine with a reg1<->reg2 swap afterwards.
8028 __device__ void ChangeMajor()
8029 {
8030 if constexpr (RegsCount == 4 && (R == MatrixUse::MatrixA || R == MatrixUse::MatrixB))
8031 {
8032 uint32_t t0, t1, t2, t3;
8033 asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(t0) : "r"(regs[0]));
8034 asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(t1) : "r"(regs[1]));
8035 asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(t2) : "r"(regs[2]));
8036 asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" : "=r"(t3) : "r"(regs[3]));
8037 regs[0] = t0;
8038 regs[1] = t1;
8039 regs[2] = t2;
8040 regs[3] = t3;
8041 }
8042 }
8043
// Store this fragment to `buffer` starting at element offset `element`, with
// `stride` given in T elements, using the requested layout.
template<Layout layout>
void __device__ Store(T* buffer, uint element, uint stride)
{
    // Referencing ::value forces RegisterCount to be instantiated for this
    // (T, M, N, K, R); presumably this rejects unsupported fragments at compile time.
    (void)RegisterCount<T, M, N, K, R>::value;
    T* const dst = buffer + element;
    mmaStore<T, layout, M, N>(dst, regs, stride);
}
8050
// Store this fragment to a buffer typed as U, where `stride` is given in U
// elements. The stride is rescaled to T elements before calling mmaStore
// (integer division, so a byte stride that is not a multiple of sizeof(T) truncates).
template<Layout layout, typename U>
void __device__ Store(U* buffer, uint stride)
{
    (void)RegisterCount<T, M, N, K, R>::value;
    const auto strideInT = stride * sizeof(U) / sizeof(T);
    mmaStore<T, layout, M, N>(buffer, regs, strideInT);
}
8057
// Load a fragment from `buffer` starting at element offset `element`, with
// `stride` given in T elements, using the requested layout.
template<Layout layout>
static This __device__ Load(T* buffer, uint element, uint stride)
{
    // Referencing ::value forces RegisterCount to be instantiated for this
    // (T, M, N, K, R); presumably this rejects unsupported fragments at compile time.
    (void)RegisterCount<T, M, N, K, R>::value;
    T* const src = buffer + element;
    WmmaFragment<T, M, N, K, R> result;
    mmaLoad<T, layout, M, N, R>(result.regs, src, stride);
    return result;
}
8066
// Load a fragment from a buffer typed as U, where `stride` is given in U
// elements. The stride is rescaled to T elements before calling mmaLoad
// (integer division, so a byte stride that is not a multiple of sizeof(T) truncates).
template<Layout layout, typename U>
static This __device__ Load(U* buffer, uint stride)
{
    (void)RegisterCount<T, M, N, K, R>::value;
    const auto strideInT = stride * sizeof(U) / sizeof(T);
    WmmaFragment<T, M, N, K, R> result;
    mmaLoad<T, layout, M, N, R>(result.regs, buffer, strideInT);
    return result;
}
8075
// Number of logical T elements held by each thread in this fragment.
static constexpr __device__ uint32_t GetLength()
{
    return elements_per_thread;
}

// Number of packed 32-bit registers backing this fragment per thread.
static constexpr __device__ int GetPackedFragmentCount()
{
    return RegsCount;
}
8078
// Element type and tile dimensions, exposed as class members.
using ElementType = T;
static constexpr int m_M = M;
static constexpr int m_N = N;
static constexpr int m_K = K;

// Register Count requirement
// Number of 32-bit registers per thread for this (T, M, N, K, R), taken from
// RegisterCount. Must be declared before `regs`, which uses it as its array bound.
static constexpr int RegsCount = RegisterCount<T, M, N, K, R>::value;
// Packed per-thread fragment storage, zero-initialized.
unsigned regs[RegsCount] = {};

// Logical elements per thread = registers * elements packed per 32-bit register
// (4 / sizeof(T): 1 for 4-byte T, 2 for 2-byte T, 4 for 1-byte T).
// NOTE(review): evaluates to 0 for 8-byte T.
static constexpr uint32_t elements_per_thread = RegsCount * (4 / sizeof(T));
8089};
8090
// ====================================================================================
// 16-bit Float MMA Helpers - For half x half / bfloat16 x bfloat16 inputs
// Specialized on CType and DType (accumulator/output types).
//
// Uses mma.sync.aligned.m16n8k16 instructions (2x per m16n16k16 tile).
// Only the m16n16k16 shape is supported.
//
// Register layout for m16n16k16 = 2x m16n8k16:
//   A: 4 regs (shared between both calls)
//   B: 4 regs (b[0:1] → lo N-half, b[2:3] → hi N-half)
//   C/D half: 4 regs (2 per sub-tile)
//   C/D float: 8 regs (4 per sub-tile)
//
// The mma<> function template is parameterized on the input element type so half
// and bfloat16 can share the same dispatch shape while emitting different PTX.
// The same template is also specialized further below for 8-bit integer and fp8 inputs.
// ====================================================================================

// One warp-wide mma.sync for a single (InputT, AccumT, M, N, K) combination.
// All operands are packed 32-bit registers: d is written; a, b and c are read.
// The primary template is declared without a body, so only the explicit
// specializations provided below are usable. Any other combination leaves the call
// without a definition and the build fails.
template<typename InputT, typename AccumT, int M, int N, int K>
__device__ inline void mma(uint32_t* d, const uint32_t* a, const uint32_t* b, const uint32_t* c);
8110
8111#if SLANG_CUDA_ENABLE_HALF
8112
// D(16x8, f32) = A(16x16, f16, row-major) * B(16x8, f16, col-major) + C(16x8, f32).
// Per thread: a = 4 regs and b = 2 regs (two f16 values per 32-bit reg).
// c/d = 4 regs each, holding f32 values as raw bit patterns.
// Operand numbering: %0-%3 = d, %4-%7 = a, %8-%9 = b, %10-%13 = c.
// NOTE(review): the PTX ISA lists sm_80 as the minimum target for this f16 form — confirm.
template<>
__device__ inline void mma<half, float, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5, %6, %7}, "
                 "{%8, %9}, "
                 "{%10, %11, %12, %13};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]),
                   "r"(a[1]),
                   "r"(a[2]),
                   "r"(a[3]),
                   "r"(b[0]),
                   "r"(b[1]),
                   "r"(c[0]),
                   "r"(c[1]),
                   "r"(c[2]),
                   "r"(c[3]));
}
8137
// D(16x8, f16) = A(16x16, f16, row-major) * B(16x8, f16, col-major) + C(16x8, f16).
// Per thread: a = 4 regs, b = 2 regs, c/d = 2 regs each (two f16 values per 32-bit reg).
// Operand numbering: %0-%1 = d, %2-%5 = a, %6-%7 = b, %8-%9 = c.
template<>
__device__ inline void mma<half, half, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
        "{%0, %1}, "
        "{%2, %3, %4, %5}, "
        "{%6, %7}, "
        "{%8, %9};\n"
        : "=r"(d[0]), "=r"(d[1])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "r"(c[0]), "r"(c[1]));
}
8154
8155#endif // #if SLANG_CUDA_ENABLE_HALF
8156
8157#if SLANG_CUDA_ENABLE_BF16
8158
// bf16 MMA only supports float (f32) accumulators on PTX.
// D(16x8, f32) = A(16x16, bf16, row-major) * B(16x8, bf16, col-major) + C(16x8, f32).
// Per thread: a = 4 regs and b = 2 regs (two bf16 values per 32-bit reg).
// c/d = 4 regs each, holding f32 bit patterns.
// Operand numbering: %0-%3 = d, %4-%7 = a, %8-%9 = b, %10-%13 = c.
template<>
__device__ inline void mma<__nv_bfloat16, float, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5, %6, %7}, "
                 "{%8, %9}, "
                 "{%10, %11, %12, %13};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]),
                   "r"(a[1]),
                   "r"(a[2]),
                   "r"(a[3]),
                   "r"(b[0]),
                   "r"(b[1]),
                   "r"(c[0]),
                   "r"(c[1]),
                   "r"(c[2]),
                   "r"(c[3]));
}
8184
8185#endif // #if SLANG_CUDA_ENABLE_BF16
8186
// ====================================================================================
// Fp16MMAHelper m16n16k16 specializations (via 2x mma.sync.m16n8k16)
//
// Override the generic WMMA-based Fp16MMAHelper for the m16n16k16 shape.
// Each specialization fixes CType, DType AND M=16,N=16,K=16, making it
// strictly more specialized than the corresponding generic — no ambiguity.
//
// NOTE(review): in this section the primary template below is only declared, with
// no body. If no generic definition exists elsewhere, every shape or CType/DType pair
// without a specialization here is an incomplete type and will not compile — confirm.
// ====================================================================================

// CType: element type of the C (accumulator input) fragment.
// DType: element type of the D (output) fragment.
// A and B are half fragments in every specialization below.
template<typename CType, typename DType, int M, int N, int K>
struct Fp16MMAHelper;
8197
8198#if SLANG_CUDA_ENABLE_HALF
8199
template<>
struct Fp16MMAHelper<half, half, 16, 16, 16>
{
    // D = A * B + C for a 16x16x16 tile with f16 inputs and f16 accumulator/output.
    // Issued as two m16n8k16 mma calls: sub-tile 0 is the low N-half, sub-tile 1 the
    // high N-half. A is shared; B, C and D each advance by 2 registers per sub-tile.
    __device__ static void eval(
        WmmaFragment<half, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<half, 16, 16, 16, MatrixC>& c)
    {
        constexpr int kAccRegsPerSubTile = 2;
        constexpr int kBRegsPerSubTile = 2;
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<half, half, 16, 8, 16>(
                d.regs + sub * kAccRegsPerSubTile,
                a.regs,
                b.regs + sub * kBRegsPerSubTile,
                c.regs + sub * kAccRegsPerSubTile);
        }
    }
};
8213
template<>
struct Fp16MMAHelper<float, float, 16, 16, 16>
{
    // D = A * B + C for a 16x16x16 tile with f16 inputs and f32 accumulator/output.
    // Issued as two m16n8k16 mma calls (low N-half, then high N-half). A is shared;
    // B advances by 2 registers and C/D by 4 registers per sub-tile.
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        constexpr int kAccRegsPerSubTile = 4;
        constexpr int kBRegsPerSubTile = 2;
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<half, float, 16, 8, 16>(
                d.regs + sub * kAccRegsPerSubTile,
                a.regs,
                b.regs + sub * kBRegsPerSubTile,
                c.regs + sub * kAccRegsPerSubTile);
        }
    }
};
8227
template<>
struct Fp16MMAHelper<half, float, 16, 16, 16>
{
    // D(f32) = A(f16) * B(f16) + C(f16) for a 16x16x16 tile.
    // The packed f16 accumulator (4 regs, two halves each) is first widened into an
    // 8-register f32 scratch: element 2r comes from the low 16 bits of reg r, element
    // 2r+1 from the high 16 bits. The f32-accumulating mma form is then issued twice
    // (low N-half, then high N-half).
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<half, 16, 16, 16, MatrixC>& c)
    {
        uint32_t accF32[8];
#pragma unroll
        for (int r = 0; r < 4; ++r)
        {
            const uint32_t packed = c.regs[r];
            const half lo = __ushort_as_half((unsigned short)(packed & 0xFFFF));
            const half hi = __ushort_as_half((unsigned short)(packed >> 16));
            accF32[2 * r] = __float_as_uint(__half2float(lo));
            accF32[2 * r + 1] = __float_as_uint(__half2float(hi));
        }

#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<half, float, 16, 8, 16>(
                d.regs + 4 * sub,
                a.regs,
                b.regs + 2 * sub,
                accF32 + 4 * sub);
        }
    }
};
8250
template<>
struct Fp16MMAHelper<float, half, 16, 16, 16>
{
    // D(f16) = A(f16) * B(f16) + C(f32) for a 16x16x16 tile.
    // The product is accumulated in f32 via two m16n8k16 mma calls into an
    // 8-register scratch. Each f32 result is then narrowed with __float2half and
    // repacked two per register: element 2r goes to the low 16 bits of d.regs[r],
    // element 2r+1 to the high 16 bits.
    __device__ static void eval(
        WmmaFragment<half, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        uint32_t accF32[8];
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<half, float, 16, 8, 16>(
                accF32 + 4 * sub,
                a.regs,
                b.regs + 2 * sub,
                c.regs + 4 * sub);
        }

#pragma unroll
        for (int r = 0; r < 4; ++r)
        {
            const uint32_t loBits = __half_as_ushort(__float2half(__uint_as_float(accF32[2 * r])));
            const uint32_t hiBits =
                __half_as_ushort(__float2half(__uint_as_float(accF32[2 * r + 1])));
            d.regs[r] = loBits | (hiBits << 16);
        }
    }
};
8272
8273#endif // #if SLANG_CUDA_ENABLE_HALF
8274
// ====================================================================================
// Bf16MMAHelper m16n16k16 specializations (via 2x mma.sync.m16n8k16)
//
// bfloat16 MMA only supports float (f32) accumulators on PTX, so only the
// (CType=float, DType=float) combination is provided here.
// ====================================================================================

// CType: element type of the C (accumulator input) fragment.
// DType: element type of the D (output) fragment.
// Declared without a body in this section, so only the specialization below is usable
// unless a definition exists elsewhere.
template<typename CType, typename DType, int M, int N, int K>
struct Bf16MMAHelper;
8284
8285#if SLANG_CUDA_ENABLE_BF16
8286
template<>
struct Bf16MMAHelper<float, float, 16, 16, 16>
{
    // D = A * B + C for a 16x16x16 tile with bf16 inputs and f32 accumulator/output.
    // Issued as two m16n8k16 mma calls (low N-half, then high N-half). A is shared;
    // B advances by 2 registers and C/D by 4 registers per sub-tile.
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        constexpr int kAccRegsPerSubTile = 4;
        constexpr int kBRegsPerSubTile = 2;
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<__nv_bfloat16, float, 16, 8, 16>(
                d.regs + sub * kAccRegsPerSubTile,
                a.regs,
                b.regs + sub * kBRegsPerSubTile,
                c.regs + sub * kAccRegsPerSubTile);
        }
    }
};
8300
8301#endif // #if SLANG_CUDA_ENABLE_BF16
8302
// ====================================================================================
// 8-bit Integer MMA intrinsics (m16n8k16 with .s8/.u8 inputs and .s32 accumulator).
//
// PTX form: mma.sync.aligned.m16n8k16.row.col{.satfinite}.s32.{s8|u8}.{s8|u8}.s32
// - Matrix A: 16x16, 2 b32 regs per thread (each register packs 4 .s8/.u8 elements)
// - Matrix B: 16x8, 1 b32 reg per thread (packs 4 .s8/.u8 elements)
// - Matrix C: 16x8, 4 .s32 regs per thread
//
// PTX integer mma only supports an .s32 accumulator (unlike the fp8 mma at the
// same shape, which adds .f16/.f32 accumulator forms). The `.satfinite` modifier
// clamps overflow to [INT32_MIN, INT32_MAX]; without it, the accumulator wraps.
// ====================================================================================

// Non-saturating: signed 8-bit inputs.
// Plain `char` is mapped to .s8 here.
// NOTE(review): `signed char` (the usual int8_t) is a distinct C++ type and has no
// specialization in this section — confirm how callers spell the signed 8-bit type.
// Operand numbering: %0-%3 = d, %4-%5 = a, %6 = b, %7-%10 = c.
template<>
__device__ inline void mma<char, int32_t, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5}, "
                 "{%6}, "
                 "{%7, %8, %9, %10};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
8332
// Non-saturating: unsigned 8-bit inputs (.u8), wrapping .s32 accumulator.
// Operand numbering: %0-%3 = d, %4-%5 = a, %6 = b, %7-%10 = c.
template<>
__device__ inline void mma<unsigned char, int32_t, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5}, "
                 "{%6}, "
                 "{%7, %8, %9, %10};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
8349
// Saturating variants are exposed via a parallel `mma_sat` template so the
// existing `mma<>` specializations (which all happen to be non-saturating) stay
// untouched. `Int8MMAHelper` selects between the two with `if constexpr`.
//
// Same operand convention as mma<>: d is written; a, b and c are read, all as packed
// 32-bit registers. The primary template has no body; only the 8-bit integer
// specializations that follow are defined in this section.
template<typename InputT, typename AccumT, int M, int N, int K>
__device__ inline void mma_sat(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c);
8359
// Saturating: signed 8-bit inputs (.s8).
// .satfinite clamps the .s32 result to [INT32_MIN, INT32_MAX] instead of wrapping.
// Operand numbering: %0-%3 = d, %4-%5 = a, %6 = b, %7-%10 = c.
template<>
__device__ inline void mma_sat<char, int32_t, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5}, "
                 "{%6}, "
                 "{%7, %8, %9, %10};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
8375
// Saturating: unsigned 8-bit inputs (.u8).
// .satfinite clamps the .s32 result to [INT32_MIN, INT32_MAX] instead of wrapping.
// Operand numbering: %0-%3 = d, %4-%5 = a, %6 = b, %7-%10 = c.
template<>
__device__ inline void mma_sat<unsigned char, int32_t, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.satfinite.s32.u8.u8.s32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5}, "
                 "{%6}, "
                 "{%7, %8, %9, %10};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
8391
// ====================================================================================
// Int8MMAHelper m16n16k16 (via 2x mma.sync.m16n8k16)
//
// Register layout for m16n16k16 = 2x m16n8k16 with 8-bit inputs / 32-bit accum:
//   A: 2 regs per thread (shared between both calls — A is full-width for both N halves)
//   B: 2 regs per thread (b[0] -> lo N-half cols 0..7, b[1] -> hi N-half cols 8..15)
//   C/D: 8 regs per thread (4 per sub-tile)
// ====================================================================================

// AInputT: 8-bit element type of both A and B (char or unsigned char have mma<> /
// mma_sat<> specializations above). CType/DType: accumulator input/output types.
// Saturating: selects the .satfinite (mma_sat) form when true.
// Only the (int32_t, int32_t, 16, 16, 16) partial specialization below is defined.
template<typename AInputT, typename CType, typename DType, int M, int N, int K, bool Saturating>
struct Int8MMAHelper;
8403
template<typename AInputT, bool Saturating>
struct Int8MMAHelper<AInputT, int32_t, int32_t, 16, 16, 16, Saturating>
{
    // D = A * B + C for a 16x16x16 tile with 8-bit inputs and an s32 accumulator.
    // Issued as two m16n8k16 mma calls (N cols 0..7, then 8..15). A (2 regs) is
    // shared; B advances by 1 register and C/D by 4 registers per sub-tile.
    // Saturating == true uses the clamping .satfinite form (mma_sat); otherwise the
    // accumulator wraps (mma).
    __device__ static void eval(
        WmmaFragment<int32_t, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<int32_t, 16, 16, 16, MatrixC>& c)
    {
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            uint32_t* dSub = d.regs + 4 * sub;
            const uint32_t* bSub = b.regs + sub;
            const uint32_t* cSub = c.regs + 4 * sub;
            if constexpr (Saturating)
                mma_sat<AInputT, int32_t, 16, 8, 16>(dSub, a.regs, bSub, cSub);
            else
                mma<AInputT, int32_t, 16, 8, 16>(dSub, a.regs, bSub, cSub);
        }
    }
};
8425
8426// ====================================================================================
8427// 8-bit Float MMA intrinsics (m16n8k16 with .e4m3 / .e5m2 inputs).
8428//
8429// PTX form: mma.sync.aligned.m16n8k16.row.col.{f16|f32}.{e4m3|e5m2}.{e4m3|e5m2}.{f16|f32}
8430// - Matrix A: 16x16, 2 b32 regs per thread (4 e4m3/e5m2 elements per reg)
8431// - Matrix B: 16x8, 1 b32 reg per thread (4 e4m3/e5m2 elements)
8432// - Matrix C: 16x8, 4 .f32 regs per thread (with f32 accumulator) OR
8433// 2 .f16x2 regs per thread (with f16 accumulator)
8434// Same A/B register layout as the integer m16n8k16 path; only the accumulator
8435// differs. Requires SM 8.9+ (Ada Lovelace) at runtime.
8436// ====================================================================================
8437
8438#if SLANG_CUDA_ENABLE_FP8
8439
// D(16x8, f32) = A(16x16, e4m3, row-major) * B(16x8, e4m3, col-major) + C(16x8, f32).
// Per thread: a = 2 regs and b = 1 reg (4 fp8 values per 32-bit reg).
// c/d = 4 regs each, holding f32 bit patterns.
// Operand numbering: %0-%3 = d, %4-%5 = a, %6 = b, %7-%10 = c.
template<>
__device__ inline void mma<__nv_fp8_e4m3, float, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5}, "
                 "{%6}, "
                 "{%7, %8, %9, %10};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
8455
// D(16x8, f32) = A(16x16, e5m2, row-major) * B(16x8, e5m2, col-major) + C(16x8, f32).
// Per thread: a = 2 regs and b = 1 reg (4 fp8 values per 32-bit reg).
// c/d = 4 regs each, holding f32 bit patterns.
// Operand numbering: %0-%3 = d, %4-%5 = a, %6 = b, %7-%10 = c.
template<>
__device__ inline void mma<__nv_fp8_e5m2, float, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.e5m2.e5m2.f32 "
                 "{%0, %1, %2, %3}, "
                 "{%4, %5}, "
                 "{%6}, "
                 "{%7, %8, %9, %10};\n"
                 : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3]));
}
8471
// D(16x8, f16) = A(16x16, e4m3, row-major) * B(16x8, e4m3, col-major) + C(16x8, f16).
// Per thread: a = 2 regs, b = 1 reg, c/d = 2 regs each (two f16 values per reg).
// Operand numbering: %0-%1 = d, %2-%3 = a, %4 = b, %5-%6 = c.
template<>
__device__ inline void mma<__nv_fp8_e4m3, half, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.e4m3.e4m3.f16 "
                 "{%0, %1}, "
                 "{%2, %3}, "
                 "{%4}, "
                 "{%5, %6};\n"
                 : "=r"(d[0]), "=r"(d[1])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
}
8487
// D(16x8, f16) = A(16x16, e5m2, row-major) * B(16x8, e5m2, col-major) + C(16x8, f16).
// Per thread: a = 2 regs, b = 1 reg, c/d = 2 regs each (two f16 values per reg).
// Operand numbering: %0-%1 = d, %2-%3 = a, %4 = b, %5-%6 = c.
template<>
__device__ inline void mma<__nv_fp8_e5m2, half, 16, 8, 16>(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c)
{
    asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.e5m2.e5m2.f16 "
                 "{%0, %1}, "
                 "{%2, %3}, "
                 "{%4}, "
                 "{%5, %6};\n"
                 : "=r"(d[0]), "=r"(d[1])
                 : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]));
}
8503
8504#endif // #if SLANG_CUDA_ENABLE_FP8
8505
// ====================================================================================
// Fp8MMAHelper m16n16k16 (via 2x mma.sync.m16n8k16)
//
// Register layout for m16n16k16 = 2x m16n8k16 with 8-bit FP inputs:
//   A: 2 regs per thread (shared between both calls)
//   B: 2 regs per thread (b[0] -> lo N-half cols 0..7, b[1] -> hi N-half cols 8..15)
//   C/D float (f32 acc): 8 regs per thread (4 per sub-tile)
//   C/D half  (f16 acc): 4 regs per thread (2 per sub-tile)
// AInputT must be either __nv_fp8_e4m3 or __nv_fp8_e5m2. CType==DType, both
// either float (f32 mma form) or half (f16 mma form).
// ====================================================================================

// Declared without a body here; only the (float, float) and (half, half)
// 16x16x16 partial specializations below (when SLANG_CUDA_ENABLE_FP8 is set) are usable.
template<typename AInputT, typename CType, typename DType, int M, int N, int K>
struct Fp8MMAHelper;
8520
8521#if SLANG_CUDA_ENABLE_FP8
8522
template<typename AInputT>
struct Fp8MMAHelper<AInputT, float, float, 16, 16, 16>
{
    // D = A * B + C for a 16x16x16 tile with fp8 inputs and f32 accumulator/output.
    // Issued as two m16n8k16 mma calls (N cols 0..7, then 8..15). A is shared;
    // B advances by 1 register and C/D by 4 registers per sub-tile.
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        constexpr int kAccRegsPerSubTile = 4;
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<AInputT, float, 16, 8, 16>(
                d.regs + sub * kAccRegsPerSubTile,
                a.regs,
                b.regs + sub,
                c.regs + sub * kAccRegsPerSubTile);
        }
    }
};
8536
template<typename AInputT>
struct Fp8MMAHelper<AInputT, half, half, 16, 16, 16>
{
    // D = A * B + C for a 16x16x16 tile with fp8 inputs and f16 accumulator/output.
    // Issued as two m16n8k16 mma calls (N cols 0..7, then 8..15). A is shared;
    // B advances by 1 register and C/D by 2 registers per sub-tile.
    __device__ static void eval(
        WmmaFragment<half, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<half, 16, 16, 16, MatrixC>& c)
    {
        constexpr int kAccRegsPerSubTile = 2;
#pragma unroll
        for (int sub = 0; sub < 2; ++sub)
        {
            mma<AInputT, half, 16, 8, 16>(
                d.regs + sub * kAccRegsPerSubTile,
                a.regs,
                b.regs + sub,
                c.regs + sub * kAccRegsPerSubTile);
        }
    }
};
8550
8551#endif // #if SLANG_CUDA_ENABLE_FP8
8552
// ====================================================================================
// MMA Helper - Primary Template (dispatcher)
//
// Selects between Fp16MMAHelper (half x half), Bf16MMAHelper (bfloat16 x bfloat16),
// Int8MMAHelper (s8/u8 x s32 accumulator), and Fp8MMAHelper (e4m3/e5m2 x f16/f32
// accumulator) based on the input element type. AType is required to equal BType
// — the supported PTX shapes always have matching A/B element types. The
// requirement is checked with a static_assert, so a mismatch is reported at this
// call site rather than as an argument-conversion error inside a helper's eval().
// ====================================================================================

// Returns D = A * B + C for one cooperative-matrix tile.
//   matA, matB: input fragments (same element type).
//   matC:       accumulator input fragment.
//   saturatingAccumulation: only meaningful for 8-bit integer inputs, where it selects
//                           the clamping (.satfinite) mma form; other paths ignore it.
// The fragments are taken by value; the result is returned in a fresh fragment.
template<
    typename AType,
    typename BType,
    typename CType,
    typename DType,
    int M,
    int N,
    int K,
    bool saturatingAccumulation>
WmmaFragment<DType, M, N, K, MatrixC> __device__ coopMatMulAdd(
    WmmaFragment<AType, M, N, K, MatrixUse::MatrixA> matA,
    WmmaFragment<BType, M, N, K, MatrixUse::MatrixB> matB,
    WmmaFragment<CType, M, N, K, MatrixUse::MatrixC> matC)
{
    // Every helper below takes A and B fragments of one shared element type.
    static_assert(
        IsSameType<AType, BType>::value,
        "coopMatMulAdd: A and B fragments must have the same element type");

    WmmaFragment<DType, M, N, K, MatrixC> matD;
    if constexpr (IsSameType<AType, char>::value || IsSameType<AType, unsigned char>::value)
    {
        // 8-bit integer inputs: the only path that honours saturatingAccumulation.
        Int8MMAHelper<AType, CType, DType, M, N, K, saturatingAccumulation>::eval(
            matD,
            matA,
            matB,
            matC);
    }
#if SLANG_CUDA_ENABLE_FP8
    else if constexpr (
        IsSameType<AType, __nv_fp8_e4m3>::value || IsSameType<AType, __nv_fp8_e5m2>::value)
    {
        Fp8MMAHelper<AType, CType, DType, M, N, K>::eval(matD, matA, matB, matC);
    }
#endif
#if SLANG_CUDA_ENABLE_BF16
    else if constexpr (IsSameType<AType, __nv_bfloat16>::value)
    {
        Bf16MMAHelper<CType, DType, M, N, K>::eval(matD, matA, matB, matC);
    }
#endif
    else
    {
        // Fallback: 16-bit float (half) inputs.
        Fp16MMAHelper<CType, DType, M, N, K>::eval(matD, matA, matB, matC);
    }
    return matD;
}
8604
8605} // namespace Slang_CUDA_WMMA
8606#endif // #if (((__CUDACC_VER_MAJOR__ >=12)&&(__CUDACC_VER_MINOR__>=5)) || (CUDA_VERSION >= 12050))
8607
8608#endif
static MaterialRegister< MetalMaterial > R
Definition MetalMaterial.cpp:10
bool operator==(const json_pointer< RefStringTypeLhs > &lhs, const json_pointer< RefStringTypeRhs > &rhs) noexcept
Definition json.hpp:14737
bool operator<(const json_pointer< RefStringTypeLhs > &lhs, const json_pointer< RefStringTypeRhs > &rhs) noexcept
Definition json.hpp:14787
@ value
the parser finished reading a JSON value
auto get(const nlohmann::detail::iteration_proxy_value< IteratorType > &i) -> decltype(i.key())
Definition json.hpp:5342
Direction
Definition Direction.h:18
constexpr FloatOptional operator+(FloatOptional lhs, FloatOptional rhs)
Definition FloatOptional.h:56
uint8_t Type
Definition slang-gfx.h:1359
#define SLANG_FORCE_INLINE
Definition slang-cpp-prelude.h:286
SLANG_FORCE_INLINE int16_t F16_asint(half h)
Definition slang-cpp-scalar-intrinsics.h:732
SLANG_FORCE_INLINE half F16_exp(half f)
Definition slang-cpp-scalar-intrinsics.h:834
SLANG_FORCE_INLINE half F16_acosh(half f)
Definition slang-cpp-scalar-intrinsics.h:804
SLANG_FORCE_INLINE half F16_fmod(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:895
SLANG_FORCE_INLINE half F16_tan(half f)
Definition slang-cpp-scalar-intrinsics.h:764
SLANG_FORCE_INLINE half F16_tanh(half f)
Definition slang-cpp-scalar-intrinsics.h:794
SLANG_FORCE_INLINE half F16_atan2(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:905
SLANG_FORCE_INLINE bool F16_isnan(half f)
Definition slang-cpp-scalar-intrinsics.h:854
SLANG_FORCE_INLINE half F16_frac(half h)
Definition slang-cpp-scalar-intrinsics.h:951
SLANG_FORCE_INLINE bool F16_isinf(half f)
Definition slang-cpp-scalar-intrinsics.h:866
SLANG_FORCE_INLINE half F16_frexp(half x, int *e)
Definition slang-cpp-scalar-intrinsics.h:910
SLANG_FORCE_INLINE bool F16_isfinite(half f)
Definition slang-cpp-scalar-intrinsics.h:860
SLANG_FORCE_INLINE half F16_pow(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:890
SLANG_FORCE_INLINE half F16_cos(half f)
Definition slang-cpp-scalar-intrinsics.h:759
SLANG_FORCE_INLINE half F16_sinh(half f)
Definition slang-cpp-scalar-intrinsics.h:784
SLANG_FORCE_INLINE half F16_asin(half f)
Definition slang-cpp-scalar-intrinsics.h:769
SLANG_FORCE_INLINE half F16_modf(half x, half *ip)
Definition slang-cpp-scalar-intrinsics.h:915
SLANG_FORCE_INLINE half F16_sin(half f)
Definition slang-cpp-scalar-intrinsics.h:754
SLANG_FORCE_INLINE half F16_exp2(half f)
Definition slang-cpp-scalar-intrinsics.h:829
SLANG_FORCE_INLINE half F16_trunc(half f)
Definition slang-cpp-scalar-intrinsics.h:844
SLANG_FORCE_INLINE half F16_log2(half f)
Definition slang-cpp-scalar-intrinsics.h:814
SLANG_FORCE_INLINE int F16_sign(half f)
Definition slang-cpp-scalar-intrinsics.h:943
SLANG_FORCE_INLINE half F16_round(half f)
Definition slang-cpp-scalar-intrinsics.h:749
SLANG_FORCE_INLINE half F16_max(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:881
SLANG_FORCE_INLINE half F16_log(half f)
Definition slang-cpp-scalar-intrinsics.h:819
SLANG_FORCE_INLINE half F16_acos(half f)
Definition slang-cpp-scalar-intrinsics.h:774
SLANG_FORCE_INLINE half F16_atanh(half f)
Definition slang-cpp-scalar-intrinsics.h:809
SLANG_FORCE_INLINE half F16_sqrt(half f)
Definition slang-cpp-scalar-intrinsics.h:849
SLANG_FORCE_INLINE half F16_cosh(half f)
Definition slang-cpp-scalar-intrinsics.h:789
SLANG_FORCE_INLINE half F16_asinh(half f)
Definition slang-cpp-scalar-intrinsics.h:799
SLANG_FORCE_INLINE half F16_abs(half f)
Definition slang-cpp-scalar-intrinsics.h:839
SLANG_FORCE_INLINE uint16_t F16_asuint(half h)
Definition slang-cpp-scalar-intrinsics.h:725
SLANG_FORCE_INLINE half F16_rsqrt(half f)
Definition slang-cpp-scalar-intrinsics.h:938
SLANG_FORCE_INLINE half F16_fma(half a, half b, half c)
Definition slang-cpp-scalar-intrinsics.h:923
SLANG_FORCE_INLINE half F16_min(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:872
SLANG_FORCE_INLINE half F16_atan(half f)
Definition slang-cpp-scalar-intrinsics.h:779
SLANG_FORCE_INLINE half F16_floor(half f)
Definition slang-cpp-scalar-intrinsics.h:744
SLANG_FORCE_INLINE half F16_log10(half f)
Definition slang-cpp-scalar-intrinsics.h:824
SLANG_FORCE_INLINE half F16_remainder(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:900
SLANG_FORCE_INLINE half F16_ceil(half f)
Definition slang-cpp-scalar-intrinsics.h:739
uint32_t uint
Definition slang-cpp-types-core.h:272
SLANG_FORCE_INLINE T _slang_vector_get_element(Vector< T, N > x, int index)
Definition slang-cpp-types-core.h:241
SLANG_FORCE_INLINE const T * _slang_vector_get_element_ptr(const Vector< T, N > *x, int index)
Definition slang-cpp-types-core.h:247
SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_max(uintptr_t a, uintptr_t b)
Definition slang-cuda-prelude.h:2601
SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector< T, n > _slang_vector_reshape(const Vector< OtherT, m > other)
Definition slang-cuda-prelude.h:914
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_fma(double a, double b, double c)
Definition slang-cuda-prelude.h:2341
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_atan2(float a, float b)
Definition slang-cuda-prelude.h:2140
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_remainder(float a, float b)
Definition slang-cuda-prelude.h:2136
#define SLANG_WAVE_ROTATE_IMPL(T)
Definition slang-cuda-prelude.h:6311
SLANG_FORCE_INLINE SLANG_CUDA_CALL intptr_t IPTR_max(intptr_t a, intptr_t b)
Definition slang-cuda-prelude.h:2584
__device__ T _waveOpCopy(T *dst, const T *src)
Definition slang-cuda-prelude.h:3959
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_abs(double f)
Definition slang-cuda-prelude.h:2250
__forceinline__ __device__ WarpMask _getMultiPrefixMask(int mask)
Definition slang-cuda-prelude.h:3132
__device__ T _wavePrefixInvertableScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3861
#define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count)
Definition slang-cuda-prelude.h:117
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_frexp(float x, int *e)
Definition slang-cuda-prelude.h:2145
__inline__ __device__ uint4 _waveMatchScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4288
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_fmod(float a, float b)
Definition slang-cuda-prelude.h:2132
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_tanh(double f)
Definition slang-cuda-prelude.h:2226
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_log10(float f)
Definition slang-cuda-prelude.h:2238
#define SLANG_VECTOR_GET_ELEMENT(T)
Definition slang-cuda-prelude.h:415
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, cudaSurfaceBoundaryMode boundaryMode)
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U64_firstbitlow(uint64_t v)
Definition slang-cuda-prelude.h:2515
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F64_isfinite(double f)
Definition slang-cuda-prelude.h:2279
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_acos(float f)
Definition slang-cuda-prelude.h:2029
__inline__ __device__ T _wavePrefixOr(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4120
__inline__ __device__ uint getAt(dim3 a, int b)
Definition slang-cuda-prelude.h:4309
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I16_countbits(int16_t v)
Definition slang-cuda-prelude.h:2371
__device__ __forceinline__ longlong atomicCAS(longlong *address, longlong compare, longlong val)
Definition slang-cuda-prelude.h:2715
__inline__ __device__ T _wavePrefixMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4189
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_firstbitlow(uint32_t v)
Definition slang-cuda-prelude.h:2417
__inline__ __device__ T _waveAndMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3744
__forceinline__ __device__ WarpMask _getActiveMask()
Definition slang-cuda-prelude.h:3126
__inline__ __device__ T _waveMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3665
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_round(float f)
Definition slang-cuda-prelude.h:2005
#define SLANG_CUDA_WARP_MASK
Definition slang-cuda-prelude.h:63
__inline__ __device__ T _wavePrefixMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4205
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_max(float a, float b)
Definition slang-cuda-prelude.h:2124
#define SLANG_PRELUDE_ASSERT(x)
Definition slang-cuda-prelude.h:57
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool3 make_bool3(bool x, bool y, bool z)
Definition slang-cuda-prelude.h:740
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t U64_abs(uint64_t f)
Definition slang-cuda-prelude.h:2496
__inline__ __device__ T _wavePrefixExclusiveMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4278
__inline__ __device__ T _wavePrefixOrMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4163
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool2 make_bool2(bool x, bool y)
Definition slang-cuda-prelude.h:736
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_max(uint32_t a, uint32_t b)
Definition slang-cuda-prelude.h:2389
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I64_firstbitlow(int64_t v)
Definition slang-cuda-prelude.h:2555
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_asuint(int32_t x)
Definition slang-cuda-prelude.h:2463
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip)
__inline__ __device__ T _waveOrMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3736
SLANG_FORCE_INLINE SLANG_CUDA_CALL double U32_asdouble(uint32_t low, uint32_t hi)
Definition slang-cuda-prelude.h:2405
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_remainder(double a, double b)
Definition slang-cuda-prelude.h:2305
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_sqrt(float f)
Definition slang-cuda-prelude.h:2089
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool4 make_bool4(bool x, bool y, bool z, bool w)
Definition slang-cuda-prelude.h:744
unsigned long long CUsurfObject
Definition slang-cuda-prelude.h:185
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_floor(float f)
Definition slang-cuda-prelude.h:2001
__device__ __forceinline__ bool _slang_quadAll(bool expr)
Definition slang-cuda-prelude.h:6411
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_tanh(float f)
Definition slang-cuda-prelude.h:2045
__device__ __forceinline__ longlong atomicAdd(longlong *address, longlong val)
Definition slang-cuda-prelude.h:2723
#define SLANG_CUDA_CALL
Definition slang-cuda-prelude.h:70
static const int kSlangTorchTensorMaxDim
Definition slang-cuda-prelude.h:5914
#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL)
Definition slang-cuda-prelude.h:3247
#define SLANG_CUDA_VECTOR_FLOAT_OPS(T)
Definition slang-cuda-prelude.h:665
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_atanh(float f)
Definition slang-cuda-prelude.h:2057
size_t NonUniformResourceIndex
Definition slang-cuda-prelude.h:197
__inline__ __device__ T _wavePrefixMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4183
SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_abs(uintptr_t f)
Definition slang-cuda-prelude.h:2591
__inline__ __device__ T _waveSumMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3768
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_tan(double f)
Definition slang-cuda-prelude.h:2202
__inline__ __device__ T _waveProduct(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3653
int64_t longlong
Definition slang-cuda-prelude.h:301
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_cosh(double f)
Definition slang-cuda-prelude.h:2222
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint64_t U64_reversebits(uint64_t v)
Definition slang-cuda-prelude.h:2529
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_log(double f)
Definition slang-cuda-prelude.h:2234
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F64_isnan(double f)
Definition slang-cuda-prelude.h:2275
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_min(float a, float b)
Definition slang-cuda-prelude.h:2120
#define GET_VECTOR_TYPE_IMPL_N(T)
Definition slang-cuda-prelude.h:882
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_exp2(float f)
Definition slang-cuda-prelude.h:2073
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex1Dfetch_int(CUtexObject texObj, int x, int mip)
Definition slang-cuda-prelude.h:6096
__device__ __forceinline__ longlong atomicExch(longlong *address, longlong val)
Definition slang-cuda-prelude.h:2710
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_min(double a, double b)
Definition slang-cuda-prelude.h:2289
__inline__ __device__ T _waveOr(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3635
SLANG_FORCE_INLINE SLANG_CUDA_CALL intptr_t IPTR_min(intptr_t a, intptr_t b)
Definition slang-cuda-prelude.h:2579
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_sincos(double f, double *s, double *c)
Definition slang-cuda-prelude.h:2198
SLANG_FORCE_INLINE SLANG_CUDA_CALL int F32_sign(float f)
Definition slang-cuda-prelude.h:2097
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U16_countbits(uint16_t v)
Definition slang-cuda-prelude.h:2363
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_cos(double f)
Definition slang-cuda-prelude.h:2194
__inline__ __device__ uint4 _waveMatchMultiple(WarpMask mask, const T &inVal)
Definition slang-cuda-prelude.h:4295
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U64_countbits(uint64_t v)
Definition slang-cuda-prelude.h:2510
SLANG_FORCE_INLINE SLANG_CUDA_CALL intptr_t IPTR_abs(intptr_t f)
Definition slang-cuda-prelude.h:2574
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_firstbithigh(uint32_t v)
Definition slang-cuda-prelude.h:2424
__inline__ __device__ T _wavePrefixMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4195
__device__ __forceinline__ void __slang_atomic_reduce_dec(uint32_t *addr, int order)
Definition slang-cuda-prelude.h:2957
#define SLANG_MATRIX_INT_NEG_OP(T)
Definition slang-cuda-prelude.h:1267
__inline__ __device__ T _waveAnd(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3641
SLANG_FORCE_INLINE SLANG_CUDA_CALL float _slang_fmod(float x, float y)
Definition slang-cuda-prelude.h:335
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t F32_asuint(float f)
Definition slang-cuda-prelude.h:2155
#define SLANG_TEX2DFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6143
#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL)
Definition slang-cuda-prelude.h:3240
__inline__ __device__ TResult slang_bit_cast(TInput val)
Definition slang-cuda-prelude.h:4324
__device__ T _waveOpSetInitial(T *out, const T *val)
Definition slang-cuda-prelude.h:3978
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_sinh(float f)
Definition slang-cuda-prelude.h:2037
unsigned char uchar
Definition slang-cuda-prelude.h:307
__inline__ __device__ T _wavePrefixInclusiveMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4243
#define SLANG_FORCE_INLINE
Definition slang-cuda-prelude.h:68
__device__ __forceinline__ bool2 _slang_waveRotate(bool2 value, unsigned int delta)
Definition slang-cuda-prelude.h:6377
#define SLANG_INT_MATRIX_OPS(T)
Definition slang-cuda-prelude.h:1235
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_log2(double f)
Definition slang-cuda-prelude.h:2230
#define SLANG_CUDA_WARP_BITMASK
Definition slang-cuda-prelude.h:65
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_log(float f)
Definition slang-cuda-prelude.h:2065
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_min(uint32_t a, uint32_t b)
Definition slang-cuda-prelude.h:2385
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int y, int z, cudaSurfaceBoundaryMode boundaryMode)
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_tan(float f)
Definition slang-cuda-prelude.h:2021
__inline__ __device__ T _waveReadFirstMultiple(WarpMask mask, T inVal)
Definition slang-cuda-prelude.h:3826
__device__ void _waveReduceMultiple(WarpMask mask, T *val)
Definition slang-cuda-prelude.h:3586
SLANG_FORCE_INLINE SLANG_CUDA_CALL float make_float(T val)
Definition slang-cuda-prelude.h:330
__inline__ __device__ T _wavePrefixInclusiveMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4231
struct __align__(1) bool1
Definition slang-cuda-prelude.h:207
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip)
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_asin(float f)
Definition slang-cuda-prelude.h:2025
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_countbits(uint32_t v)
Definition slang-cuda-prelude.h:2412
__inline__ __device__ uint3 operator*(uint3 a, dim3 b)
Definition slang-cuda-prelude.h:4314
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F32_isfinite(float f)
Definition slang-cuda-prelude.h:2110
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_max(int32_t a, int32_t b)
Definition slang-cuda-prelude.h:2452
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_firstbitlow(int32_t v)
Definition slang-cuda-prelude.h:2479
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_trunc(double f)
Definition slang-cuda-prelude.h:2254
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F32_isinf(float f)
Definition slang-cuda-prelude.h:2114
SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_min(uintptr_t a, uintptr_t b)
Definition slang-cuda-prelude.h:2596
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool1 make_bool1(bool x)
Definition slang-cuda-prelude.h:732
#define SLANG_SELECT_T(T)
Definition slang-cuda-prelude.h:1342
#define SLANG_TEX3DFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6185
__forceinline__ __device__ WarpMask _getLaneLtMask()
Definition slang-cuda-prelude.h:3118
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F32_sincos(float f, float *s, float *c)
Definition slang-cuda-prelude.h:2017
SLANG_FORCE_INLINE SLANG_CUDA_CALL int F64_sign(double f)
Definition slang-cuda-prelude.h:2266
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_pow(float a, float b)
Definition slang-cuda-prelude.h:2128
__inline__ __device__ int _waveCalcPow2Offset(WarpMask mask)
Definition slang-cuda-prelude.h:3151
__device__ __forceinline__ void __slang_atomic_reduce_min(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2828
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_min(int32_t a, int32_t b)
Definition slang-cuda-prelude.h:2448
#define SLANG_FLOAT_MATRIX_OPS(T)
Definition slang-cuda-prelude.h:1248
SamplerStateUnused * SamplerState
Definition slang-cuda-prelude.h:192
__device__ __forceinline__ void __slang_atomic_reduce_or(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2901
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_abs(float f)
Definition slang-cuda-prelude.h:2081
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_modf(double x, double *ip)
Definition slang-cuda-prelude.h:2319
int WarpMask
Definition slang-cuda-prelude.h:3094
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_log2(float f)
Definition slang-cuda-prelude.h:2061
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_max(double a, double b)
Definition slang-cuda-prelude.h:2293
#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn, T, N)
Definition slang-cuda-prelude.h:838
unsigned long long ulonglong
Definition slang-cuda-prelude.h:303
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F64_isinf(double f)
Definition slang-cuda-prelude.h:2283
__device__ T _waveOpDoInverse(T *inOut, const T *val)
Definition slang-cuda-prelude.h:3969
#define SLANG_CUDA_WARP_SIZE
Definition slang-cuda-prelude.h:60
SLANG_FORCE_INLINE SLANG_CUDA_CALL double I32_asdouble(int32_t low, int32_t hi)
Definition slang-cuda-prelude.h:2467
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_round(double f)
Definition slang-cuda-prelude.h:2186
unsigned int uint
Definition slang-cuda-prelude.h:309
__inline__ __device__ bool _waveIsFirstLane()
Definition slang-cuda-prelude.h:3172
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_exp(float f)
Definition slang-cuda-prelude.h:2077
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_abs(int32_t f)
Definition slang-cuda-prelude.h:2442
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_reversebits(uint32_t v)
Definition slang-cuda-prelude.h:2434
__device__ __forceinline__ void __slang_atomic_reduce_and(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2880
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_cos(float f)
Definition slang-cuda-prelude.h:2013
__inline__ __device__ T _wavePrefixExclusiveMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4268
#define SLANG_CUDA_VECTOR_INT_OPS(T)
Definition slang-cuda-prelude.h:571
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_atan(double f)
Definition slang-cuda-prelude.h:2214
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool __ldg(const bool *ptr)
Definition slang-cuda-prelude.h:263
SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_select(bool condition, T v0, T v1)
Definition slang-cuda-prelude.h:1358
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t U64_min(uint64_t a, uint64_t b)
Definition slang-cuda-prelude.h:2501
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_atan(float f)
Definition slang-cuda-prelude.h:2033
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I64_firstbithigh(int64_t v)
Definition slang-cuda-prelude.h:2560
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_firstbithigh(int32_t v)
Definition slang-cuda-prelude.h:2484
__device__ T _wavePrefixMultiple(WarpMask mask, T *val)
Definition slang-cuda-prelude.h:4045
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_rsqrt(double f)
Definition slang-cuda-prelude.h:2262
SLANG_FORCE_INLINE SLANG_CUDA_CALL float U32_asfloat(uint32_t x)
Definition slang-cuda-prelude.h:2394
__inline__ __device__ T _wavePrefixSum(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4108
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_pow(double a, double b)
Definition slang-cuda-prelude.h:2297
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1DLayeredwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int layer, cudaSurfaceBoundaryMode boundaryMode)
Definition slang-cuda-prelude.h:1634
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I8_countbits(int8_t v)
Definition slang-cuda-prelude.h:2356
#define SLANG_SURF3DWRITE_CONVERT_IMPL(T, c)
Definition slang-cuda-prelude.h:1739
__inline__ __device__ T _wavePrefixProductMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4133
unsigned long long OptixTraversableHandle
Definition slang-cuda-prelude.h:5912
__inline__ __device__ T _wavePrefixExclusiveMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4262
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_ceil(float f)
Definition slang-cuda-prelude.h:1997
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I64_countbits(int64_t v)
Definition slang-cuda-prelude.h:2550
#define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes)
Definition slang-cuda-prelude.h:111
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_sinh(double f)
Definition slang-cuda-prelude.h:2218
__inline__ __device__ T _waveSum(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3659
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_min(int64_t a, int64_t b)
Definition slang-cuda-prelude.h:2541
__forceinline__ __device__ uint32_t _getLaneId()
Definition slang-cuda-prelude.h:3076
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_trunc(float f)
Definition slang-cuda-prelude.h:2085
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_asinh(float f)
Definition slang-cuda-prelude.h:2049
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_sin(double f)
Definition slang-cuda-prelude.h:2190
__device__ T _wavePrefixInvertableMultiple(WarpMask mask, T *val)
Definition slang-cuda-prelude.h:3987
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2DLayeredwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int y, int layer, cudaSurfaceBoundaryMode boundaryMode)
Definition slang-cuda-prelude.h:1715
unsigned short ushort
Definition slang-cuda-prelude.h:308
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_sqrt(double f)
Definition slang-cuda-prelude.h:2258
#define SLANG_SURF2DWRITE_CONVERT_IMPL(T, c)
Definition slang-cuda-prelude.h:1656
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_asint(double d, int32_t *low, int32_t *hi)
Definition slang-cuda-prelude.h:2332
__inline__ __device__ T _waveMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3671
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t U64_max(uint64_t a, uint64_t b)
Definition slang-cuda-prelude.h:2505
__inline__ __device__ bool _waveAllEqual(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3793
__inline__ __device__ T _waveShuffleMultiple(WarpMask mask, T inVal, int lane)
Definition slang-cuda-prelude.h:3842
SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix< T, ROWS, COLS > makeMatrix(T scalar)
Definition slang-cuda-prelude.h:944
__device__ __forceinline__ void __slang_atomic_reduce_max(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2854
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t F32_asint(float f)
Definition slang-cuda-prelude.h:2161
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_frac(double f)
Definition slang-cuda-prelude.h:2270
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_acosh(float f)
Definition slang-cuda-prelude.h:2053
__inline__ __device__ T _waveProductMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3760
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_ceil(double f)
Definition slang-cuda-prelude.h:2178
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_max(int64_t a, int64_t b)
Definition slang-cuda-prelude.h:2545
#define SLANG_INFINITY
Definition slang-cuda-prelude.h:53
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_asin(double f)
Definition slang-cuda-prelude.h:2206
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_exp2(double f)
Definition slang-cuda-prelude.h:2242
__device__ __forceinline__ bool _slang_quadAny(bool expr)
Definition slang-cuda-prelude.h:6401
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U64_firstbithigh(uint64_t v)
Definition slang-cuda-prelude.h:2522
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_atan2(double a, double b)
Definition slang-cuda-prelude.h:2309
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_asuint(double d, uint32_t *low, uint32_t *hi)
Definition slang-cuda-prelude.h:2324
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip)
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_rsqrt(float f)
Definition slang-cuda-prelude.h:2093
#define SLANG_SURF1DWRITE_CONVERT_IMPL(T, c)
Definition slang-cuda-prelude.h:1582
SLANG_FORCE_INLINE SLANG_CUDA_CALL float I32_asfloat(int32_t x)
Definition slang-cuda-prelude.h:2457
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_exp(double f)
Definition slang-cuda-prelude.h:2246
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_asint(int32_t x)
Definition slang-cuda-prelude.h:2400
__inline__ __device__ T _waveMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3784
__inline__ __device__ T _waveXor(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3647
__device__ T _waveReduceScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3554
__inline__ __device__ T _wavePrefixInclusiveMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4237
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_floor(double f)
Definition slang-cuda-prelude.h:2182
#define SLANG_CUDA_FLOAT_VECTOR_MOD(T)
Definition slang-cuda-prelude.h:682
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_fmod(double a, double b)
Definition slang-cuda-prelude.h:2301
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int y, cudaSurfaceBoundaryMode boundaryMode)
unsigned long long CUtexObject
Definition slang-cuda-prelude.h:184
__inline__ __device__ T _wavePrefixXor(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4114
__inline__ __device__ T _wavePrefixAnd(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4126
#define SLANG_FLOAT_MATRIX_MOD(T)
Definition slang-cuda-prelude.h:1287
#define SLANG_TEX1DARRAYFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6227
__device__ T _wavePrefixScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3911
__inline__ __device__ T _wavePrefixXorMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4153
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_fma(float a, float b, float c)
Definition slang-cuda-prelude.h:2169
__inline__ __device__ T _waveReadFirst(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3819
__inline__ __device__ T _wavePrefixSumMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4143
#define SLANG_WAVE_CLUSTERED_ROTATE_IMPL(T)
Definition slang-cuda-prelude.h:6423
__device__ __forceinline__ void __slang_atomic_reduce_xor(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2922
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_abs(int64_t f)
Definition slang-cuda-prelude.h:2536
__inline__ __device__ bool _waveAllEqualMultiple(WarpMask mask, T inVal)
Definition slang-cuda-prelude.h:3801
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T)
Definition slang-cuda-prelude.h:780
__device__ __forceinline__ void __slang_atomic_reduce_inc(uint32_t *addr, int order)
Definition slang-cuda-prelude.h:2947
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U8_countbits(uint8_t v)
Definition slang-cuda-prelude.h:2348
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_log10(float f)
Definition slang-cuda-prelude.h:2069
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_frexp(double x, int *e)
Definition slang-cuda-prelude.h:2314
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_abs(uint32_t f)
Definition slang-cuda-prelude.h:2379
#define SLANG_TEX2DARRAYFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6269
#define SLANG_BOUND_CHECK(index, count)
Definition slang-cuda-prelude.h:106
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_frac(float f)
Definition slang-cuda-prelude.h:2101
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_cosh(float f)
Definition slang-cuda-prelude.h:2041
__inline__ __device__ bool _waveIsSingleLane(WarpMask mask)
Definition slang-cuda-prelude.h:3139
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_sin(float f)
Definition slang-cuda-prelude.h:2009
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_modf(float x, float *ip)
Definition slang-cuda-prelude.h:2150
__inline__ __device__ T _wavePrefixInclusiveMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4249
#define SLANG_VECTOR_GET_ELEMENT_PTR(T)
Definition slang-cuda-prelude.h:444
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_reversebits(int64_t v)
Definition slang-cuda-prelude.h:2567
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_reversebits(int32_t v)
Definition slang-cuda-prelude.h:2489
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_acos(double f)
Definition slang-cuda-prelude.h:2210
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F32_isnan(float f)
Definition slang-cuda-prelude.h:2106
__inline__ __device__ T _wavePrefixExclusiveMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4256
__inline__ __device__ T _wavePrefixProduct(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4102
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_countbits(int32_t v)
Definition slang-cuda-prelude.h:2474
__inline__ __device__ T _waveXorMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3752
__inline__ __device__ T _wavePrefixAndMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4173
__device__ __forceinline__ void __slang_atomic_reduce_add(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2751
__inline__ __device__ T _waveMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3776
__device__ __forceinline__ bool _slang_waveClusteredRotate(bool value, unsigned int delta, unsigned int clusterSize)
Definition slang-cuda-prelude.h:6482
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip)
__INTPTR_TYPE__ intptr_t
Definition slang-llvm.h:146
__UINTPTR_TYPE__ uintptr_t
Definition slang-llvm.h:153
__PTRDIFF_TYPE__ ptrdiff_t
Definition slang-llvm.h:29
static const int kSlangTorchTensorMaxDim
Definition slang-torch-prelude.h:70
Definition slang-cpp-types-core.h:82
SLANG_CUDA_CALL T & operator[](size_t index)
Definition slang-cuda-prelude.h:173
SLANG_CUDA_CALL const T & operator[](size_t index) const
Definition slang-cuda-prelude.h:168
T * data
Definition slang-cpp-types-core.h:94
size_t count
Definition slang-cpp-types-core.h:95
Definition slang-cpp-types.h:123
SLANG_CUDA_CALL StructuredBuffer< T > asStructuredBuffer() const
Definition slang-cuda-prelude.h:2693
SLANG_CUDA_CALL uint4 Load4(size_t index) const
Definition slang-cuda-prelude.h:2678
SLANG_CUDA_CALL uint32_t Load(size_t index) const
Definition slang-cuda-prelude.h:2661
SLANG_CUDA_CALL T Load(size_t index) const
Definition slang-cuda-prelude.h:2685
SLANG_CUDA_CALL uint2 Load2(size_t index) const
Definition slang-cuda-prelude.h:2666
SLANG_CUDA_CALL uint3 Load3(size_t index) const
Definition slang-cuda-prelude.h:2672
const uint32_t * data
Definition slang-cpp-types.h:155
SLANG_CUDA_CALL void GetDimensions(uint32_t *outDim) const
Definition slang-cuda-prelude.h:2660
size_t sizeInBytes
Definition slang-cpp-types.h:156
T Type
Definition slang-cuda-prelude.h:3549
char Type
Definition slang-cuda-prelude.h:3440
char Type
Definition slang-cuda-prelude.h:3445
char Type
Definition slang-cuda-prelude.h:3450
char Type
Definition slang-cuda-prelude.h:3326
double Type
Definition slang-cuda-prelude.h:3418
double Type
Definition slang-cuda-prelude.h:3423
double Type
Definition slang-cuda-prelude.h:3428
double Type
Definition slang-cuda-prelude.h:3433
double Type
Definition slang-cuda-prelude.h:3311
float Type
Definition slang-cuda-prelude.h:3397
float Type
Definition slang-cuda-prelude.h:3402
float Type
Definition slang-cuda-prelude.h:3407
float Type
Definition slang-cuda-prelude.h:3412
float Type
Definition slang-cuda-prelude.h:3306
int Type
Definition slang-cuda-prelude.h:3355
int Type
Definition slang-cuda-prelude.h:3360
int Type
Definition slang-cuda-prelude.h:3365
int Type
Definition slang-cuda-prelude.h:3370
int64_t Type
Definition slang-cuda-prelude.h:3321
int Type
Definition slang-cuda-prelude.h:3296
int64_t Type
Definition slang-cuda-prelude.h:3500
int64_t Type
Definition slang-cuda-prelude.h:3505
int64_t Type
Definition slang-cuda-prelude.h:3510
short Type
Definition slang-cuda-prelude.h:3470
short Type
Definition slang-cuda-prelude.h:3475
short Type
Definition slang-cuda-prelude.h:3480
short Type
Definition slang-cuda-prelude.h:3336
uchar Type
Definition slang-cuda-prelude.h:3455
uchar Type
Definition slang-cuda-prelude.h:3460
uchar Type
Definition slang-cuda-prelude.h:3465
uchar Type
Definition slang-cuda-prelude.h:3331
uint Type
Definition slang-cuda-prelude.h:3376
uint Type
Definition slang-cuda-prelude.h:3381
uint Type
Definition slang-cuda-prelude.h:3386
uint Type
Definition slang-cuda-prelude.h:3391
uint64_t Type
Definition slang-cuda-prelude.h:3316
uint Type
Definition slang-cuda-prelude.h:3301
uint64_t Type
Definition slang-cuda-prelude.h:3515
uint64_t Type
Definition slang-cuda-prelude.h:3520
uint64_t Type
Definition slang-cuda-prelude.h:3525
ushort Type
Definition slang-cuda-prelude.h:3485
ushort Type
Definition slang-cuda-prelude.h:3490
ushort Type
Definition slang-cuda-prelude.h:3495
ushort Type
Definition slang-cuda-prelude.h:3341
Definition slang-cuda-prelude.h:3290
Definition slang-cpp-types-core.h:63
SLANG_CUDA_CALL const T & operator[](size_t index) const
Definition slang-cuda-prelude.h:149
T m_data[SIZE]
Definition slang-cpp-types-core.h:75
SLANG_CUDA_CALL T & operator[](size_t index)
Definition slang-cuda-prelude.h:154
Definition slang-cuda-prelude.h:869
Definition slang-cpp-types-core.h:401
SLANG_FORCE_INLINE SLANG_CUDA_CALL const Vector< T, COLS > & operator[](size_t index) const
Definition slang-cuda-prelude.h:936
SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector< T, COLS > & operator[](size_t index)
Definition slang-cuda-prelude.h:931
Vector< T, COLS > rows[ROWS]
Definition slang-cpp-types-core.h:402
Definition slang-cpp-types.h:163
SLANG_CUDA_CALL uint32_t Load(size_t index) const
Definition slang-cuda-prelude.h:2976
SLANG_CUDA_CALL RWStructuredBuffer< T > asStructuredBuffer() const
Definition slang-cuda-prelude.h:3052
SLANG_CUDA_CALL void Store2(size_t index, uint2 v) const
Definition slang-cuda-prelude.h:3013
SLANG_CUDA_CALL void Store(size_t index, uint32_t v) const
Definition slang-cuda-prelude.h:3008
SLANG_CUDA_CALL void GetDimensions(uint32_t *outDim) const
Definition slang-cuda-prelude.h:2974
SLANG_CUDA_CALL uint3 Load3(size_t index) const
Definition slang-cuda-prelude.h:2987
SLANG_CUDA_CALL uint4 Load4(size_t index) const
Definition slang-cuda-prelude.h:2993
SLANG_CUDA_CALL void Store3(size_t index, uint3 v) const
Definition slang-cuda-prelude.h:3020
SLANG_CUDA_CALL uint2 Load2(size_t index) const
Definition slang-cuda-prelude.h:2981
SLANG_CUDA_CALL T * _getPtrAt(size_t index)
Can be used in the core module to gain access
Definition slang-cuda-prelude.h:3046
SLANG_CUDA_CALL T Load(size_t index) const
Definition slang-cuda-prelude.h:3000
size_t sizeInBytes
Definition slang-cpp-types.h:233
SLANG_CUDA_CALL void Store4(size_t index, uint4 v) const
Definition slang-cuda-prelude.h:3028
SLANG_CUDA_CALL void Store(size_t index, T const &value) const
Definition slang-cuda-prelude.h:3038
uint32_t * data
Definition slang-cpp-types.h:232
Definition slang-cpp-types.h:38
SLANG_CUDA_CALL T & operator[](size_t index) const
Definition slang-cuda-prelude.h:2648
T * data
Definition slang-cpp-types.h:55
size_t count
Definition slang-cpp-types.h:56
Definition slang-cpp-types.h:61
size_t count
Definition slang-cpp-types.h:79
SLANG_CUDA_CALL T & Load(size_t index) const
Definition slang-cuda-prelude.h:2623
T * data
Definition slang-cpp-types.h:78
SLANG_CUDA_CALL void GetDimensions(uint32_t *outNumStructs, uint32_t *outStride) const
Definition slang-cuda-prelude.h:2632
SLANG_CUDA_CALL T & operator[](size_t index) const
Definition slang-cuda-prelude.h:2615
Definition slang-cuda-prelude.h:5920
__device__ T * data_ptr_at(uint2 index)
Definition slang-cuda-prelude.h:5940
__device__ void store(uint32_t x, T val)
Definition slang-cuda-prelude.h:6032
__device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
Definition slang-cuda-prelude.h:6058
__device__ void store(uint2 index, T val)
Definition slang-cuda-prelude.h:6042
__device__ T & load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
Definition slang-cuda-prelude.h:5999
__device__ T & load(uint4 index)
Definition slang-cuda-prelude.h:6005
__device__ T * data_ptr_at(uint32_t index)
Definition slang-cuda-prelude.h:5933
__device__ void store(uint32_t x, uint32_t y, T val)
Definition slang-cuda-prelude.h:6037
uint32_t strides[kSlangTorchTensorMaxDim]
Definition slang-cuda-prelude.h:5922
__device__ T & load(uint32_t x, uint32_t y, uint32_t z)
Definition slang-cuda-prelude.h:5988
__device__ T & load(uint32_t x, uint32_t y)
Definition slang-cuda-prelude.h:5978
uint32_t sizes[kSlangTorchTensorMaxDim]
Definition slang-cuda-prelude.h:5923
__device__ T * data_ptr()
Definition slang-cuda-prelude.h:5927
__device__ T * data_ptr_at(uint4 index)
Definition slang-cuda-prelude.h:5954
__device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
Definition slang-cuda-prelude.h:6071
__device__ void store(uint4 index, T val)
Definition slang-cuda-prelude.h:6064
uint8_t * data
Definition slang-cuda-prelude.h:5921
__device__ void store(uint3 index, T val)
Definition slang-cuda-prelude.h:6052
__device__ T & load(uint index[N])
Definition slang-cuda-prelude.h:6021
__device__ T & load(uint32_t x)
Definition slang-cuda-prelude.h:5973
__device__ T * data_ptr_at(uint index[N])
Definition slang-cuda-prelude.h:5962
__device__ T & load(uint2 index)
Definition slang-cuda-prelude.h:5983
__device__ T * data_ptr_at(uint3 index)
Definition slang-cuda-prelude.h:5947
__device__ T & load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
Definition slang-cuda-prelude.h:6012
__device__ void store(uint index[N], T val)
Definition slang-cuda-prelude.h:6080
__device__ void store(uint32_t x, uint32_t y, uint32_t z, T val)
Definition slang-cuda-prelude.h:6047
uint32_t dimensionCount
Definition slang-cuda-prelude.h:5924
__device__ T & load(uint3 index)
Definition slang-cuda-prelude.h:5993
Definition slang-cpp-types-core.h:57
size_t typeSize
Definition slang-cpp-types-core.h:58
Definition slang-cpp-types-core.h:103
Definition slang-cuda-prelude.h:3208
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3209
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3210
__inline__ static __device__ T doInverse(T a, T b)
Definition slang-cuda-prelude.h:3211
Definition slang-cuda-prelude.h:3193
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3194
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3195
Definition slang-cuda-prelude.h:4224
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:4225
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:4226
Definition slang-cuda-prelude.h:4217
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:4219
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:4218
Definition slang-cuda-prelude.h:3227
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3229
__inline__ static __device__ T getInitial(T a, bool exclusive=false)
Definition slang-cuda-prelude.h:3234
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3236
__inline__ static __device__ T getInitial(T a, bool exclusive=false)
Definition slang-cuda-prelude.h:3216
__inline__ static __device__ T doInverse(T a, T b)
Definition slang-cuda-prelude.h:3222
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3217
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3218
Definition slang-cuda-prelude.h:3186
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3188
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3187
Definition slang-cuda-prelude.h:3200
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3202
__inline__ static __device__ T doInverse(T a, T b)
Definition slang-cuda-prelude.h:3203
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3201
Definition slang-cpp-scalar-intrinsics.h:671
Definition slang-cpp-scalar-intrinsics.h:24
uint32_t u
Definition slang-cpp-scalar-intrinsics.h:25
int32_t i
Definition slang-cpp-scalar-intrinsics.h:26
float f
Definition slang-cpp-scalar-intrinsics.h:27
Definition slang-cpp-scalar-intrinsics.h:31
int64_t i
Definition slang-cpp-scalar-intrinsics.h:33
double d
Definition slang-cpp-scalar-intrinsics.h:34
uint64_t u
Definition slang-cpp-scalar-intrinsics.h:32