1#ifndef SLANG_CUDA_PRELUDE_H
2#define SLANG_CUDA_PRELUDE_H
4#define SLANG_PRELUDE_EXPORT
7#define SLANG_CUDA_RTC 1
9#define SLANG_CUDA_RTC 0
27#ifdef SLANG_CUDA_ENABLE_HALF
30#define __CUDA_NO_HALF2_OPERATORS__
34#ifdef SLANG_CUDA_ENABLE_FP8
38#ifdef SLANG_CUDA_ENABLE_BF16
42#ifdef SLANG_CUDA_ENABLE_OPTIX
47#ifndef SLANG_OFFSET_OF
// Byte offset of `member` inside `type`, obtained by taking the member's address
// relative to a null `type*`. The complete expansion is wrapped in parentheses so
// the result behaves as a single operand inside any surrounding expression.
#define SLANG_OFFSET_OF(type, member) ((size_t)((char*)&(((type*)0)->member) - (char*)0))
53#define SLANG_INFINITY ((float)(1e+300 * 1e+300))
57#define SLANG_PRELUDE_ASSERT(x)
59#ifndef SLANG_CUDA_WARP_SIZE
60#define SLANG_CUDA_WARP_SIZE 32
// One less than the warp size; with a power-of-two warp size this is a mask of
// the low bits of an index.
#define SLANG_CUDA_WARP_MASK (SLANG_CUDA_WARP_SIZE - 1)
// An `int` with every bit set.
#define SLANG_CUDA_WARP_BITMASK (~0)
68#define SLANG_FORCE_INLINE inline
70#define SLANG_CUDA_CALL __device__
72#define SLANG_FORCE_INLINE inline
73#define SLANG_INLINE inline
// Asserts that `index` is strictly below `count`. The trailing `;` is part of the
// macro so it can be followed directly by another macro in a statement sequence.
// Arguments are parenthesized so compound expressions (`i + 1`, `a ? b : c`) are
// compared as a unit.
#define SLANG_BOUND_ASSERT(index, count) SLANG_PRELUDE_ASSERT((index) < (count));
// Asserts that an access of `elemSize` bytes at byte offset `index` lies inside a
// buffer of `sizeInBytes` bytes and that `index` is a multiple of 4.
// `elemSize <= sizeInBytes` is tested first: with unsigned sizes the subtraction
// would otherwise wrap around when the buffer is smaller than one element and the
// range test would accept any offset.
#define SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_PRELUDE_ASSERT( \
        (elemSize) <= (sizeInBytes) && (index) <= ((sizeInBytes) - (elemSize)) && \
        ((index) & 3) == 0);
// Replaces an out-of-range `index` with 0. `index` must be a modifiable lvalue;
// the trailing `;` is part of the macro. Arguments are parenthesized so compound
// expressions are compared as a unit.
#define SLANG_BOUND_ZERO_INDEX(index, count) index = ((index) < (count)) ? (index) : 0;
// Byte-address variant: keeps `index` only when `elemSize` bytes starting at
// `index` fit inside `sizeInBytes`, otherwise resets it to 0.
// `elemSize <= sizeInBytes` is tested first so the unsigned subtraction cannot
// wrap around (which would leave an out-of-range offset untouched) when the
// buffer is smaller than one element.
#define SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    index = ((elemSize) <= (sizeInBytes) && (index) <= ((sizeInBytes) - (elemSize))) \
                ? (index) \
                : 0;
93#ifdef SLANG_ENABLE_BOUND_ZERO_INDEX
94#define SLANG_BOUND_FIX(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
95#define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
96 SLANG_BOUND_ZERO_INDEX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
// Fixed-size arrays use the same clamp-to-zero fix as other indexed accesses.
// The clamp is idempotent, so it is expanded exactly once; a second expansion
// would only re-evaluate the `index` and `count` argument expressions.
#define SLANG_BOUND_FIX_FIXED_ARRAY(index, count) SLANG_BOUND_ZERO_INDEX(index, count)
100#define SLANG_BOUND_FIX(index, count)
101#define SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
102#define SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
105#ifndef SLANG_BOUND_CHECK
// Default bounds check for an indexed access: the assertion followed by the
// configured fix-up. No separator is written between the two because
// SLANG_BOUND_ASSERT supplies its own trailing `;`.
#define SLANG_BOUND_CHECK(index, count) \
    SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX(index, count)
110#ifndef SLANG_BOUND_CHECK_BYTE_ADDRESS
// Default bounds check for a byte-addressed access of `elemSize` bytes at offset
// `index` in a buffer of `sizeInBytes` bytes: the byte-address assertion followed
// by the configured byte-address fix-up.
#define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_BOUND_ASSERT_BYTE_ADDRESS(index, elemSize, sizeInBytes) \
    SLANG_BOUND_FIX_BYTE_ADDRESS(index, elemSize, sizeInBytes)
116#ifndef SLANG_BOUND_CHECK_FIXED_ARRAY
// Default bounds check for fixed-size arrays: the generic index assertion
// followed by the fixed-array fix-up.
#define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count) \
    SLANG_BOUND_ASSERT(index, count) SLANG_BOUND_FIX_FIXED_ARRAY(index, count)
128#ifndef SLANG_CUDA_BOUNDARY_MODE
129#define SLANG_CUDA_BOUNDARY_MODE cudaBoundaryModeZero
138#define SLANG_PTX_BOUNDARY_MODE "zero"
146template<
typename T,
size_t SIZE>
191struct SamplerStateUnused;
200template<
typename T,
int ROWS,
int COLS>
265 return (
bool)(
__ldg((
const char*)ptr));
270 auto val =
__ldg((
const char2*)ptr);
271 return {val.x != 0, val.y != 0};
276 auto val =
__ldg((
const char4*)ptr);
277 return {val.x != 0, val.y != 0, val.z != 0, val.w != 0};
282typedef signed char int8_t;
283typedef short int16_t;
285typedef long long int64_t;
288typedef unsigned char uint8_t;
289typedef unsigned short uint16_t;
290typedef unsigned int uint32_t;
291typedef unsigned long long uint64_t;
311#if SLANG_CUDA_ENABLE_HALF
337 return ::fmodf(x, y);
344#if SLANG_CUDA_ENABLE_HALF
361#if SLANG_CUDA_ENABLE_BF16
370 __nv_bfloat16 x, y, z;
374 __nv_bfloat16 x, y, z, w;
378#if SLANG_CUDA_ENABLE_FP8
391 __nv_fp8_e4m3 x, y, z;
395 __nv_fp8_e4m3 x, y, z, w;
407 __nv_fp8_e5m2 x, y, z;
411 __nv_fp8_e5m2 x, y, z, w;
415#define SLANG_VECTOR_GET_ELEMENT(T) \
416 SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##1 x, int index) \
418 return ((T*)(&x))[index]; \
420 SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##2 x, int index) \
422 return ((T*)(&x))[index]; \
424 SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##3 x, int index) \
426 return ((T*)(&x))[index]; \
428 SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_vector_get_element(T##4 x, int index) \
430 return ((T*)(&x))[index]; \
444#define SLANG_VECTOR_GET_ELEMENT_PTR(T) \
445 SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##1 * x, int index) \
447 return ((T*)(x)) + index; \
449 SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##2 * x, int index) \
451 return ((T*)(x)) + index; \
453 SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##3 * x, int index) \
455 return ((T*)(x)) + index; \
457 SLANG_FORCE_INLINE SLANG_CUDA_CALL T* _slang_vector_get_element_ptr(const T##4 * x, int index) \
459 return ((T*)(x)) + index; \
473#if SLANG_CUDA_ENABLE_HALF
478#if SLANG_CUDA_ENABLE_BF16
483_slang_vector_dot(__nv_bfloat162 v0, __nv_bfloat162 v1)
485 __nv_bfloat16 result = __nv_bfloat16(0.0f);
486 for (
int i = 0; i < 2; i++)
493_slang_vector_dot(__nv_bfloat163 v0, __nv_bfloat163 v1)
495 __nv_bfloat16 result = __nv_bfloat16(0.0f);
496 for (
int i = 0; i < 3; i++)
503_slang_vector_dot(__nv_bfloat164 v0, __nv_bfloat164 v1)
505 __nv_bfloat16 result = __nv_bfloat16(0.0f);
506 for (
int i = 0; i < 4; i++)
514#if SLANG_CUDA_ENABLE_FP8
521#define SLANG_CUDA_VECTOR_BINARY_OP(T, n, op) \
522 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator op(T##n thisVal, T##n other) \
525 for (int i = 0; i < n; i++) \
526 *_slang_vector_get_element_ptr(&result, i) = \
527 _slang_vector_get_element(thisVal, i) op _slang_vector_get_element(other, i); \
530#define SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, op) \
531 SLANG_FORCE_INLINE SLANG_CUDA_CALL bool##n operator op(T##n thisVal, T##n other) \
534 for (int i = 0; i < n; i++) \
535 *_slang_vector_get_element_ptr(&result, i) = \
536 (_slang_vector_get_element(thisVal, i) op _slang_vector_get_element(other, i)); \
539#define SLANG_CUDA_VECTOR_UNARY_OP(T, n, op) \
540 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator op(T##n thisVal) \
543 for (int i = 0; i < n; i++) \
544 *_slang_vector_get_element_ptr(&result, i) = op _slang_vector_get_element(thisVal, i); \
// Declares the full operator set for the integer vector type T##n:
//   binary  + - * / %   ^ & |   && ||   >> <<   (result type T##n)
//   compare > < >= <= == !=                     (result type bool##n)
//   unary   ! - ~
// Each entry expands to one free operator overload via the
// SLANG_CUDA_VECTOR_*_OP helper macros.
#define SLANG_CUDA_VECTOR_INT_OP(T, n) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, +) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, -) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, *) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, /) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, %) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, ^) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, &) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, |) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, &&) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, ||) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, >>) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, <<) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, ==) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, !=) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, !) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, -) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, ~)

// Instantiates the integer operator set for the 2-, 3- and 4-component vectors
// of scalar type T.
#define SLANG_CUDA_VECTOR_INT_OPS(T) \
    SLANG_CUDA_VECTOR_INT_OP(T, 2) \
    SLANG_CUDA_VECTOR_INT_OP(T, 3) \
    SLANG_CUDA_VECTOR_INT_OP(T, 4)
// Declares the operator set for the floating-point vector type T##n:
//   binary  + - * /   && ||            (result type T##n)
//   compare > < >= <= == !=            (result type bool##n)
//   unary   -
// Unlike the integer set there are no bitwise, shift or `%` entries here;
// `%` for float vectors is defined separately by SLANG_CUDA_FLOAT_VECTOR_MOD.
#define SLANG_CUDA_VECTOR_FLOAT_OP(T, n) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, +) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, -) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, *) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, /) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, &&) \
    SLANG_CUDA_VECTOR_BINARY_OP(T, n, ||) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, >=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, <=) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, ==) \
    SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(T, n, !=) \
    SLANG_CUDA_VECTOR_UNARY_OP(T, n, -)
602#define SLANG_CUDA_VECTOR_FLOAT_OP_HALF2 \
603 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator+(const __half2& lh, const __half2& rh) \
605 return __hadd2(lh, rh); \
607 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& lh, const __half2& rh) \
609 return __hsub2(lh, rh); \
611 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator*(const __half2& lh, const __half2& rh) \
613 return __hmul2(lh, rh); \
615 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator/(const __half2& lh, const __half2& rh) \
617 return __h2div(lh, rh); \
619 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 operator-(const __half2& h) \
623 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator+=(__half2& lh, const __half2& rh) \
625 lh = __hadd2(lh, rh); \
628 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator-=(__half2& lh, const __half2& rh) \
630 lh = __hsub2(lh, rh); \
633 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator*=(__half2& lh, const __half2& rh) \
635 lh = __hmul2(lh, rh); \
638 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2& operator/=(__half2& lh, const __half2& rh) \
640 lh = __h2div(lh, rh); \
643 SLANG_CUDA_VECTOR_BINARY_OP(__half, 2, &&) \
644 SLANG_CUDA_VECTOR_BINARY_OP(__half, 2, ||) \
645 SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, >) \
646 SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, <) \
647 SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, >=) \
648 SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, <=) \
649 SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, ==) \
650 SLANG_CUDA_VECTOR_BINARY_COMPARE_OP(__half, 2, !=)
// Per-type operator bundles for 2-, 3- and 4-component float vectors.
// SLANG_CUDA_VECTOR_FLOAT_OPS(T) selects a bundle by pasting the type name onto
// the macro name, so T must be spelled exactly `float`, `double` or `__half`.
#define SLANG_CUDA_VECTOR_FLOAT_OPS_float \
    SLANG_CUDA_VECTOR_FLOAT_OP(float, 2) \
    SLANG_CUDA_VECTOR_FLOAT_OP(float, 3) \
    SLANG_CUDA_VECTOR_FLOAT_OP(float, 4)
#define SLANG_CUDA_VECTOR_FLOAT_OPS_double \
    SLANG_CUDA_VECTOR_FLOAT_OP(double, 2) \
    SLANG_CUDA_VECTOR_FLOAT_OP(double, 3) \
    SLANG_CUDA_VECTOR_FLOAT_OP(double, 4)
// __half2 uses the dedicated SLANG_CUDA_VECTOR_FLOAT_OP_HALF2 bundle; only the
// 3- and 4-component half vectors go through the generic per-type macro.
#define SLANG_CUDA_VECTOR_FLOAT_OPS___half \
    SLANG_CUDA_VECTOR_FLOAT_OP_HALF2 \
    SLANG_CUDA_VECTOR_FLOAT_OP(__half, 3) \
    SLANG_CUDA_VECTOR_FLOAT_OP(__half, 4)
#define SLANG_CUDA_VECTOR_FLOAT_OPS(T) SLANG_CUDA_VECTOR_FLOAT_OPS_##T
669#if SLANG_CUDA_ENABLE_HALF
672#define SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, n) \
673 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n operator%(const T##n& left, const T##n& right) \
676 for (int i = 0; i < n; i++) \
677 *_slang_vector_get_element_ptr(&result, i) = _slang_fmod( \
678 _slang_vector_get_element(left, i), \
679 _slang_vector_get_element(right, i)); \
// Defines `operator%` for the 2-, 3- and 4-component vectors of scalar type T by
// instantiating SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL once per component count.
#define SLANG_CUDA_FLOAT_VECTOR_MOD(T) \
    SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 2) \
    SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 3) \
    SLANG_CUDA_FLOAT_VECTOR_MOD_IMPL(T, 4)
690#if SLANG_CUDA_RTC || SLANG_CUDA_ENABLE_HALF
691#define SLANG_MAKE_VECTOR(T) \
692 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x, T y) \
694 return T##2 {x, y}; \
696 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x, T y, T z) \
698 return T##3 {x, y, z}; \
700 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x, T y, T z, T w) \
702 return T##4 {x, y, z, w}; \
707SLANG_MAKE_VECTOR(
int)
708SLANG_MAKE_VECTOR(
uint)
709SLANG_MAKE_VECTOR(
short)
711SLANG_MAKE_VECTOR(
char)
712SLANG_MAKE_VECTOR(
uchar)
713SLANG_MAKE_VECTOR(
float)
714SLANG_MAKE_VECTOR(
double)
719#if SLANG_CUDA_ENABLE_HALF
720SLANG_MAKE_VECTOR(__half)
723#if SLANG_CUDA_ENABLE_BF16
724SLANG_MAKE_VECTOR(__nv_bfloat16)
727#if SLANG_CUDA_ENABLE_FP8
728SLANG_MAKE_VECTOR(__nv_fp8_e4m3)
729SLANG_MAKE_VECTOR(__nv_fp8_e5m2)
742 return bool3{x, y, z};
746 return bool4{x, y, z, w};
754 return bool3{x, x, x};
758 return bool4{x, x, x, x};
762#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
763 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##1 make_##T##1(T x) \
767 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) \
769 return make_##T##2(x, x); \
771 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) \
773 return make_##T##3(x, x, x); \
775 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) \
777 return make_##T##4(x, x, x, x); \
780#define SLANG_MAKE_VECTOR_FROM_SCALAR(T) \
781 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 make_##T##2(T x) \
783 return make_##T##2(x, x); \
785 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##3 make_##T##3(T x) \
787 return make_##T##3(x, x, x); \
789 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 make_##T##4(T x) \
791 return make_##T##4(x, x, x, x); \
804#if SLANG_CUDA_ENABLE_HALF
813#if SLANG_CUDA_ENABLE_BF16
818 return __nv_bfloat16{x};
823#if SLANG_CUDA_ENABLE_FP8
829 return __nv_fp8_e4m3{x};
833 return __nv_fp8_e5m2{x};
838#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn, T, N) \
839 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##N Fn(T##N* address, T##N val) \
842 for (int i = 0; i < N; i++) \
843 *_slang_vector_get_element_ptr(&result, i) = \
844 Fn(_slang_vector_get_element_ptr(address, i), _slang_vector_get_element(val, i)); \
848#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 900
852#if defined(SLANG_CUDA_ENABLE_HALF) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)
867template<typename T,
int n>
872#define GET_VECTOR_TYPE_IMPL(T, n) \
874 struct GetVectorTypeImpl<T, n> \
877 static SLANG_FORCE_INLINE SLANG_CUDA_CALL T##n fromScalar(T v) \
879 return make_##T##n(v); \
// Instantiates GET_VECTOR_TYPE_IMPL for scalar type T with 1, 2, 3 and 4
// components, i.e. one GetVectorTypeImpl<T, n> definition per component count.
#define GET_VECTOR_TYPE_IMPL_N(T) \
    GET_VECTOR_TYPE_IMPL(T, 1) \
    GET_VECTOR_TYPE_IMPL(T, 2) \
    GET_VECTOR_TYPE_IMPL(T, 3) \
    GET_VECTOR_TYPE_IMPL(T, 4)
899#if SLANG_CUDA_ENABLE_HALF
902#if SLANG_CUDA_ENABLE_BF16
905#if SLANG_CUDA_ENABLE_FP8
910template<
typename T,
int n>
913template<
typename T,
int n,
typename OtherT,
int m>
917 for (
int i = 0; i < n; i++)
919 OtherT otherElement = T(0);
927template<
typename T,
int ROWS,
int COLS>
943template<
typename T,
int ROWS,
int COLS>
947 for (
int i = 0; i < ROWS; i++)
952template<
typename T,
int ROWS,
int COLS>
956 result.
rows[0] = row0;
960template<
typename T,
int ROWS,
int COLS>
966 result.
rows[0] = row0;
967 result.
rows[1] = row1;
971template<
typename T,
int ROWS,
int COLS>
978 result.
rows[0] = row0;
979 result.
rows[1] = row1;
980 result.
rows[2] = row2;
984template<
typename T,
int ROWS,
int COLS>
992 result.
rows[0] = row0;
993 result.
rows[1] = row1;
994 result.
rows[2] = row2;
995 result.
rows[3] = row3;
999template<
typename T,
int ROWS,
int COLS,
typename U,
int otherRow,
int otherCol>
1006 if (minRow > otherRow)
1008 if (minCol > otherCol)
1010 for (
int i = 0; i < minRow; i++)
1011 for (
int j = 0; j < minCol; j++)
1017template<
typename T,
int ROWS,
int COLS>
1028template<
typename T,
int ROWS,
int COLS>
1059template<
typename T,
int ROWS,
int COLS>
1096template<
typename T,
int ROWS,
int COLS>
1121template<
typename T,
int ROWS,
int COLS>
1170template<
typename T,
int ROWS,
int COLS>
1209#define SLANG_MATRIX_BINARY_OP(T, op) \
1210 template<int R, int C> \
1211 SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator op( \
1212 const Matrix<T, R, C>& thisVal, \
1213 const Matrix<T, R, C>& other) \
1215 Matrix<T, R, C> result; \
1216 for (int i = 0; i < R; i++) \
1217 for (int j = 0; j < C; j++) \
1218 *_slang_vector_get_element_ptr(result.rows + i, j) = \
1219 _slang_vector_get_element(thisVal.rows[i], j) \
1220 op _slang_vector_get_element(other.rows[i], j); \
1224#define SLANG_MATRIX_UNARY_OP(T, op) \
1225 template<int R, int C> \
1226 SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator op(const Matrix<T, R, C>& thisVal) \
1228 Matrix<T, R, C> result; \
1229 for (int i = 0; i < R; i++) \
1230 for (int j = 0; j < C; j++) \
1231 *_slang_vector_get_element_ptr(result.rows + i, j) = \
1232 op _slang_vector_get_element(thisVal.rows[i], j); \
// Operator bundle for Matrix<T, R, C> with an integer element type T:
//   binary + - * / & | && || ^ %   and   unary ! ~
// Each entry expands to an operator template over R and C that loops over every
// row/column element. Unary minus is not part of this bundle; it is generated by
// SLANG_MATRIX_INT_NEG_OP.
#define SLANG_INT_MATRIX_OPS(T) \
    SLANG_MATRIX_BINARY_OP(T, +) \
    SLANG_MATRIX_BINARY_OP(T, -) \
    SLANG_MATRIX_BINARY_OP(T, *) \
    SLANG_MATRIX_BINARY_OP(T, /) \
    SLANG_MATRIX_BINARY_OP(T, &) \
    SLANG_MATRIX_BINARY_OP(T, |) \
    SLANG_MATRIX_BINARY_OP(T, &&) \
    SLANG_MATRIX_BINARY_OP(T, ||) \
    SLANG_MATRIX_BINARY_OP(T, ^) \
    SLANG_MATRIX_BINARY_OP(T, %) \
    SLANG_MATRIX_UNARY_OP(T, !) \
    SLANG_MATRIX_UNARY_OP(T, ~)
// Operator bundle for Matrix<T, R, C> with a floating-point element type T:
//   binary + - * /   and   unary -
// `%` is not included here; it is generated by SLANG_FLOAT_MATRIX_MOD.
#define SLANG_FLOAT_MATRIX_OPS(T) \
    SLANG_MATRIX_BINARY_OP(T, +) \
    SLANG_MATRIX_BINARY_OP(T, -) \
    SLANG_MATRIX_BINARY_OP(T, *) \
    SLANG_MATRIX_BINARY_OP(T, /) \
    SLANG_MATRIX_UNARY_OP(T, -)
1264#if SLANG_CUDA_ENABLE_HALF
1267#define SLANG_MATRIX_INT_NEG_OP(T) \
1268 template<int R, int C> \
1269 SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator-(Matrix<T, R, C> thisVal) \
1271 Matrix<T, R, C> result; \
1272 for (int i = 0; i < R; i++) \
1273 for (int j = 0; j < C; j++) \
1274 *_slang_vector_get_element_ptr(result.rows + i, j) = \
1275 0 - _slang_vector_get_element(thisVal.rows[i], j); \
1287#define SLANG_FLOAT_MATRIX_MOD(T) \
1288 template<int R, int C> \
1289 SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix<T, R, C> operator%( \
1290 Matrix<T, R, C> left, \
1291 Matrix<T, R, C> right) \
1293 Matrix<T, R, C> result; \
1294 for (int i = 0; i < R; i++) \
1295 for (int j = 0; j < C; j++) \
1296 *_slang_vector_get_element_ptr(result.rows + i, j) = _slang_fmod( \
1297 _slang_vector_get_element(left.rows[i], j), \
1298 _slang_vector_get_element(right.rows[i], j)); \
1304#if SLANG_CUDA_ENABLE_HALF
1305template<
int R,
int C>
1311 for (
int i = 0; i <
R; i++)
1312 for (
int j = 0; j < C; j++)
// The matrix operator generator macros are only needed while the operators are
// being defined; remove them so they are not visible to code compiled after
// this point. Each macro is undefined exactly once.
#undef SLANG_FLOAT_MATRIX_MOD
#undef SLANG_MATRIX_BINARY_OP
#undef SLANG_MATRIX_UNARY_OP
#undef SLANG_INT_MATRIX_OPS
#undef SLANG_FLOAT_MATRIX_OPS
#undef SLANG_MATRIX_INT_NEG_OP
1327#define SLANG_SELECT_IMPL(T, N) \
1328 SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector<T, N> _slang_select( \
1329 bool##N condition, \
1333 Vector<T, N> result; \
1334 for (int i = 0; i < N; i++) \
1336 *_slang_vector_get_element_ptr(&result, i) = _slang_vector_get_element(condition, i) \
1337 ? _slang_vector_get_element(v0, i) \
1338 : _slang_vector_get_element(v1, i); \
// Defines the `_slang_select` overloads for Vector<T, 2>, Vector<T, 3> and
// Vector<T, 4> by instantiating SLANG_SELECT_IMPL once per component count.
#define SLANG_SELECT_T(T) \
    SLANG_SELECT_IMPL(T, 2) \
    SLANG_SELECT_IMPL(T, 3) \
    SLANG_SELECT_IMPL(T, 4)
1360 return condition ? v0 : v1;
1367#if SLANG_CUDA_ENABLE_HALF
1374 return __halves2half2(__ushort_as_half(i.x), __ushort_as_half(i.y));
1378 return __half3{__ushort_as_half(i.x), __ushort_as_half(i.y), __ushort_as_half(i.z)};
1383 __ushort_as_half(i.x),
1384 __ushort_as_half(i.y),
1385 __ushort_as_half(i.z),
1386 __ushort_as_half(i.w)};
1393 return make_ushort2(__half_as_ushort(i.x), __half_as_ushort(i.y));
1397 return make_ushort3(__half_as_ushort(i.x), __half_as_ushort(i.y), __half_as_ushort(i.z));
1401 return make_ushort4(
1402 __half_as_ushort(i.x),
1403 __half_as_ushort(i.y),
1404 __half_as_ushort(i.z),
1405 __half_as_ushort(i.w));
1416struct __nv_isurf_trait<__half>
1421struct __nv_isurf_trait<__half2>
1426struct __nv_isurf_trait<__half4>
// Expands to its arguments unchanged. It is written next to a parenthesized
// list without parentheses of its own (`SLANG_DROP_PARENS TYPE_ARGS`), so that
// list's parentheses become the macro call and are removed, leaving the bare
// comma-separated contents.
#define SLANG_DROP_PARENS(...) __VA_ARGS__
1433#define SLANG_SURFACE_READ(FUNC_NAME, TYPE_ARGS, ARGS) \
1435 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half FUNC_NAME<__half>( \
1436 cudaSurfaceObject_t surfObj, \
1437 SLANG_DROP_PARENS TYPE_ARGS, \
1438 cudaSurfaceBoundaryMode boundaryMode) \
1440 return __ushort_as_half(FUNC_NAME<ushort>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
1444 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half2 FUNC_NAME<__half2>( \
1445 cudaSurfaceObject_t surfObj, \
1446 SLANG_DROP_PARENS TYPE_ARGS, \
1447 cudaSurfaceBoundaryMode boundaryMode) \
1449 return __ushort_as_half( \
1450 FUNC_NAME<ushort2>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
1454 SLANG_FORCE_INLINE SLANG_CUDA_CALL __half4 FUNC_NAME<__half4>( \
1455 cudaSurfaceObject_t surfObj, \
1456 SLANG_DROP_PARENS TYPE_ARGS, \
1457 cudaSurfaceBoundaryMode boundaryMode) \
1459 return __ushort_as_half( \
1460 FUNC_NAME<ushort4>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
1463SLANG_SURFACE_READ(surf1Dread, (
int x), (x))
1464SLANG_SURFACE_READ(surf2Dread, (
int x,
int y), (x, y))
1465SLANG_SURFACE_READ(surf3Dread, (
int x,
int y,
int z), (x, y, z))
1466SLANG_SURFACE_READ(surf1DLayeredread, (
int x,
int layer), (x, layer))
1467SLANG_SURFACE_READ(surf2DLayeredread, (
int x,
int y,
int layer), (x, y, layer))
1468SLANG_SURFACE_READ(surfCubemapread, (
int x,
int y,
int face), (x, y, face))
1469SLANG_SURFACE_READ(surfCubemapLayeredread, (
int x,
int y,
int layerFace), (x, y, layerFace))
1471#define SLANG_SURFACE_WRITE(FUNC_NAME, TYPE_ARGS, ARGS) \
1473 SLANG_FORCE_INLINE SLANG_CUDA_CALL void FUNC_NAME<__half>( \
1475 cudaSurfaceObject_t surfObj, \
1476 SLANG_DROP_PARENS TYPE_ARGS, \
1477 cudaSurfaceBoundaryMode boundaryMode) \
1479 FUNC_NAME<ushort>(__half_as_ushort(data), surfObj, SLANG_DROP_PARENS ARGS, boundaryMode); \
1483 SLANG_FORCE_INLINE SLANG_CUDA_CALL void FUNC_NAME<__half2>( \
1485 cudaSurfaceObject_t surfObj, \
1486 SLANG_DROP_PARENS TYPE_ARGS, \
1487 cudaSurfaceBoundaryMode boundaryMode) \
1489 FUNC_NAME<ushort2>(__half_as_ushort(data), surfObj, SLANG_DROP_PARENS ARGS, boundaryMode); \
1493 SLANG_FORCE_INLINE SLANG_CUDA_CALL void FUNC_NAME<__half4>( \
1495 cudaSurfaceObject_t surfObj, \
1496 SLANG_DROP_PARENS TYPE_ARGS, \
1497 cudaSurfaceBoundaryMode boundaryMode) \
1499 FUNC_NAME<ushort4>(__half_as_ushort(data), surfObj, SLANG_DROP_PARENS ARGS, boundaryMode); \
1502SLANG_SURFACE_WRITE(surf1Dwrite, (
int x), (x))
1503SLANG_SURFACE_WRITE(surf2Dwrite, (
int x,
int y), (x, y))
1504SLANG_SURFACE_WRITE(surf3Dwrite, (
int x,
int y,
int z), (x, y, z))
1505SLANG_SURFACE_WRITE(surf1DLayeredwrite, (
int x,
int layer), (x, layer))
1506SLANG_SURFACE_WRITE(surf2DLayeredwrite, (
int x,
int y,
int layer), (x, y, layer))
1507SLANG_SURFACE_WRITE(surfCubemapwrite, (
int x,
int y,
int face), (x, y, face))
1508SLANG_SURFACE_WRITE(surfCubemapLayeredwrite, (
int x,
int y,
int layerFace), (x, y, layerFace))
1517#define SLANG_SURFACE_READ_HALF_CONVERT(FUNC_NAME, TYPE_ARGS, ARGS) \
1519 template<typename T> \
1520 SLANG_FORCE_INLINE SLANG_CUDA_CALL T FUNC_NAME##_convert( \
1521 cudaSurfaceObject_t surfObj, \
1522 SLANG_DROP_PARENS TYPE_ARGS, \
1523 cudaSurfaceBoundaryMode boundaryMode); \
1526 SLANG_FORCE_INLINE SLANG_CUDA_CALL float FUNC_NAME##_convert<float>( \
1527 cudaSurfaceObject_t surfObj, \
1528 SLANG_DROP_PARENS TYPE_ARGS, \
1529 cudaSurfaceBoundaryMode boundaryMode) \
1531 return __ushort_as_half( \
1532 FUNC_NAME<uint16_t>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
1536 SLANG_FORCE_INLINE SLANG_CUDA_CALL float2 FUNC_NAME##_convert<float2>( \
1537 cudaSurfaceObject_t surfObj, \
1538 SLANG_DROP_PARENS TYPE_ARGS, \
1539 cudaSurfaceBoundaryMode boundaryMode) \
1542 __ushort_as_half(FUNC_NAME<ushort2>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
1543 return float2{v.x, v.y}; \
1547 SLANG_FORCE_INLINE SLANG_CUDA_CALL float4 FUNC_NAME##_convert<float4>( \
1548 cudaSurfaceObject_t surfObj, \
1549 SLANG_DROP_PARENS TYPE_ARGS, \
1550 cudaSurfaceBoundaryMode boundaryMode) \
1553 __ushort_as_half(FUNC_NAME<ushort4>(surfObj, SLANG_DROP_PARENS ARGS, boundaryMode)); \
1554 return float4{v.x, v.y, v.z, v.w}; \
1557SLANG_SURFACE_READ_HALF_CONVERT(surf1Dread, (
int x), (x))
1558SLANG_SURFACE_READ_HALF_CONVERT(surf2Dread, (
int x,
int y), (x, y))
1559SLANG_SURFACE_READ_HALF_CONVERT(surf3Dread, (
int x,
int y,
int z), (x, y, z))
1578 cudaSurfaceObject_t surfObj,
1580 cudaSurfaceBoundaryMode boundaryMode);
1582#define SLANG_SURF1DWRITE_CONVERT_IMPL(T, c) \
1584 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert<T>( \
1586 cudaSurfaceObject_t surfObj, \
1588 cudaSurfaceBoundaryMode boundaryMode) \
1591 "sust.p.1d.b32." SLANG_PTX_BOUNDARY_MODE " [%0, {%1}], {%2};" ::"l"(surfObj), \
1596 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert<T##2>( \
1598 cudaSurfaceObject_t surfObj, \
1600 cudaSurfaceBoundaryMode boundaryMode) \
1602 const T vx = v.x, vy = v.y; \
1604 "sust.p.1d.v2.b32." SLANG_PTX_BOUNDARY_MODE " [%0, {%1}], {%2, %3};" ::"l"(surfObj), \
1610 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert<T##4>( \
1612 cudaSurfaceObject_t surfObj, \
1614 cudaSurfaceBoundaryMode boundaryMode) \
1616 const T vx = v.x, vy = v.y, vz = v.z, vw = v.w; \
1618 "sust.p.1d.v4.b32." SLANG_PTX_BOUNDARY_MODE \
1619 " [%0, {%1}], {%2, %3, %4, %5};" ::"l"(surfObj), \
1636 cudaSurfaceObject_t surfObj,
1639 cudaSurfaceBoundaryMode boundaryMode)
1651 cudaSurfaceObject_t surfObj,
1654 cudaSurfaceBoundaryMode boundaryMode);
1656#define SLANG_SURF2DWRITE_CONVERT_IMPL(T, c) \
1658 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert<T>( \
1660 cudaSurfaceObject_t surfObj, \
1663 cudaSurfaceBoundaryMode boundaryMode) \
1666 "sust.p.2d.b32." SLANG_PTX_BOUNDARY_MODE " [%0, {%1, %2}], {%3};" ::"l"(surfObj), \
1672 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert<T##2>( \
1674 cudaSurfaceObject_t surfObj, \
1677 cudaSurfaceBoundaryMode boundaryMode) \
1679 const T vx = v.x, vy = v.y; \
1681 "sust.p.2d.v2.b32." SLANG_PTX_BOUNDARY_MODE \
1682 " [%0, {%1, %2}], {%3, %4};" ::"l"(surfObj), \
1689 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert<T##4>( \
1691 cudaSurfaceObject_t surfObj, \
1694 cudaSurfaceBoundaryMode boundaryMode) \
1696 const T vx = v.x, vy = v.y, vz = v.z, vw = v.w; \
1698 "sust.p.2d.v4.b32." SLANG_PTX_BOUNDARY_MODE \
1699 " [%0, {%1, %2}], {%3, %4, %5, %6};" ::"l"(surfObj), \
1717 cudaSurfaceObject_t surfObj,
1721 cudaSurfaceBoundaryMode boundaryMode)
1733 cudaSurfaceObject_t surfObj,
1737 cudaSurfaceBoundaryMode boundaryMode);
1739#define SLANG_SURF3DWRITE_CONVERT_IMPL(T, c) \
1741 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert<T>( \
1743 cudaSurfaceObject_t surfObj, \
1747 cudaSurfaceBoundaryMode boundaryMode) \
1750 "sust.p.3d.b32." SLANG_PTX_BOUNDARY_MODE \
1751 " [%0, {%1, %2, %3, %4}], {%5};" ::"l"(surfObj), \
1759 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert<T##2>( \
1761 cudaSurfaceObject_t surfObj, \
1765 cudaSurfaceBoundaryMode boundaryMode) \
1767 const T vx = v.x, vy = v.y; \
1769 "sust.p.3d.v2.b32." SLANG_PTX_BOUNDARY_MODE \
1770 " [%0, {%1, %2, %3, %4}], {%5, %6};" ::"l"(surfObj), \
1779 SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert<T##4>( \
1781 cudaSurfaceObject_t surfObj, \
1785 cudaSurfaceBoundaryMode boundaryMode) \
1787 const T vx = v.x, vy = v.y, vz = v.z, vw = v.w; \
1789 "sust.p.3d.v4.b32." SLANG_PTX_BOUNDARY_MODE \
1790 " [%0, {%1, %2, %3, %4}], {%5, %6, %7, %8};" ::"l"(surfObj), \
1806#if SLANG_CUDA_ENABLE_HALF
1841 return __float2half(::tanf(__half2float(f)));
1845 return __float2half(::asinf(__half2float(f)));
1849 return __float2half(::acosf(__half2float(f)));
1853 return __float2half(::atanf(__half2float(f)));
1857 return __float2half(::sinhf(__half2float(f)));
1861 return __float2half(::coshf(__half2float(f)));
1865 return __float2half(::tanhf(__half2float(f)));
1869 return __float2half(::asinhf(__half2float(f)));
1873 return __float2half(::acoshf(__half2float(f)));
1877 return __float2half(::atanhf(__half2float(f)));
1917 return (f == __half(0.0f)) ? 0 : ((f < __half(0.0f)) ? -1 : 1);
1931 return !__hisinf(f) && !__hisnan(f);
1941 return __hmin(a, b);
1945 return __hmax(a, b);
1949 return __float2half(::powf(__half2float(a), __half2float(b)));
1953 return __float2half(::fmodf(__half2float(a), __half2float(b)));
1957 return __float2half(::remainderf(__half2float(a), __half2float(b)));
1961 return __float2half(::atan2(__half2float(a), __half2float(b)));
1966 return __float2half(frexpf(__half2float(x), e));
1972 float res = ::modff(__half2float(x), &ipf);
1973 *ip = __float2half(ipf);
1974 return __float2half(res);
1979 return __half_as_ushort(h);
1983 return __half_as_short(h);
1989 return __hfma(a, b, c);
2099 return (f == 0.0f) ? 0 : ((f < 0.0f) ? -1 : 1);
2122 return ::fminf(a, b);
2126 return ::fmaxf(a, b);
2130 return ::powf(a, b);
2134 return ::fmodf(a, b);
2138 return ::remainderf(a, b);
2142 return float(::atan2(a, b));
2147 return frexpf(x, e);
2152 return ::modff(x, ip);
2171 return ::fmaf(a, b, c);
2268 return (f == 0.0) ? 0 : ((f < 0.0) ? -1 : 1);
2291 return ::fmin(a, b);
2295 return ::fmax(a, b);
2303 return ::fmod(a, b);
2307 return ::remainder(a, b);
2311 return ::atan2(a, b);
2316 return ::frexp(x, e);
2321 return ::modf(x, ip);
2328 *low = uint32_t(u.
u);
2329 *hi = uint32_t(u.
u >> 32);
2336 *low = int32_t(u.
u);
2337 *hi = int32_t(u.
u >> 32);
2343 return ::fma(a, b, c);
2351 return __popc(uint32_t(v));
2366 return __popc(uint32_t(v));
2387 return a < b ? a : b;
2391 return a > b ? a : b;
2408 u.
u = (uint64_t(hi) << 32) | low;
2421 return v == 0 ? ~0u : (__ffs(v) - 1);
2431 return 31 - __clz(v);
2444 return (f < 0) ? -f : f;
2450 return a < b ? a : b;
2454 return a > b ? a : b;
2470 u.
u = (uint64_t(hi) << 32) | uint32_t(low);
2503 return a < b ? a : b;
2507 return a > b ? a : b;
2519 return v == 0 ? ~uint32_t(0) : (__ffsll(v) - 1u);
2525 return ~uint32_t(0);
2526 return 63 - __clzll(v);
2538 return (f < 0) ? -f : f;
2543 return a < b ? a : b;
2547 return a > b ? a : b;
2576 return (f < 0) ? -f : f;
2581 return a < b ? a : b;
2586 return a > b ? a : b;
2598 return a < b ? a : b;
2603 return a > b ? a : b;
2617#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
2625#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
2631#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
2634 *outNumStructs = uint32_t(
count);
2635 *outStride = uint32_t(
sizeof(T));
2640#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
2650#ifndef SLANG_CUDA_STRUCTURED_BUFFER_NO_COUNT
2653 return this->
data[index];
2664 return data[index >> 2];
2669 const size_t dataIdx = index >> 2;
2675 const size_t dataIdx = index >> 2;
2681 const size_t dataIdx = index >> 2;
2684 template<
typename T>
2689 memcpy(&
data, ((
const char*)this->data) + index,
sizeof(T));
2692 template<
typename T>
2700 const uint32_t*
data;
2718 (
unsigned long long*)address,
2719 (
unsigned long long)compare,
2720 (
unsigned long long)val);
2725 return (
longlong)
atomicAdd((
unsigned long long*)address, (
unsigned long long)val);
2730__device__ __forceinline__
float atomicCAS(
float* address,
float compare,
float val)
2732 int* addr_as_int = (
int*)address;
2733 int old =
atomicCAS(addr_as_int, __float_as_int(compare), __float_as_int(val));
2734 return __int_as_float(old);
2753 asm volatile(
"red.relaxed.gpu.global.add.s32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2758 asm volatile(
"red.relaxed.gpu.global.add.u32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2763 asm volatile(
"red.relaxed.gpu.global.add.s64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2768 asm volatile(
"red.relaxed.gpu.global.add.u64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2773 asm volatile(
"red.relaxed.gpu.global.add.f32 [%0], %1;" : :
"l"(addr),
"f"(val) :
"memory");
2778 asm volatile(
"red.relaxed.gpu.global.add.f64 [%0], %1;" : :
"l"(addr),
"d"(val) :
"memory");
2781#if SLANG_CUDA_ENABLE_HALF
2784 unsigned short val_as_ushort = *
reinterpret_cast<unsigned short*
>(&val);
2785 asm volatile(
"red.relaxed.gpu.global.add.noftz.f16 [%0], %1;"
2787 :
"l"(addr),
"h"(val_as_ushort)
2793 unsigned int val_as_uint = *
reinterpret_cast<unsigned int*
>(&val);
2794 asm volatile(
"red.relaxed.gpu.global.add.noftz.f16x2 [%0], %1;"
2796 :
"l"(addr),
"r"(val_as_uint)
2801#if SLANG_CUDA_ENABLE_BF16
2803 __nv_bfloat16* addr,
2807 unsigned short val_as_ushort = *
reinterpret_cast<unsigned short*
>(&val);
2808 asm volatile(
"red.relaxed.gpu.global.add.noftz.bf16 [%0], %1;"
2810 :
"l"(addr),
"h"(val_as_ushort)
2815 __nv_bfloat162* addr,
2819 unsigned int val_as_uint = *
reinterpret_cast<unsigned int*
>(&val);
2820 asm volatile(
"red.relaxed.gpu.global.add.noftz.bf16x2 [%0], %1;"
2822 :
"l"(addr),
"r"(val_as_uint)
2830 asm volatile(
"red.relaxed.gpu.global.min.s32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2835 asm volatile(
"red.relaxed.gpu.global.min.u32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2840 asm volatile(
"red.relaxed.gpu.global.min.s64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2845 asm volatile(
"red.relaxed.gpu.global.min.u64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2856 asm volatile(
"red.relaxed.gpu.global.max.s32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2861 asm volatile(
"red.relaxed.gpu.global.max.u32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2866 asm volatile(
"red.relaxed.gpu.global.max.s64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2871 asm volatile(
"red.relaxed.gpu.global.max.u64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2882 asm volatile(
"red.relaxed.gpu.global.and.b32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2887 asm volatile(
"red.relaxed.gpu.global.and.b32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2892 asm volatile(
"red.relaxed.gpu.global.and.b64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2897 asm volatile(
"red.relaxed.gpu.global.and.b64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2903 asm volatile(
"red.relaxed.gpu.global.or.b32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2908 asm volatile(
"red.relaxed.gpu.global.or.b32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2913 asm volatile(
"red.relaxed.gpu.global.or.b64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2918 asm volatile(
"red.relaxed.gpu.global.or.b64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2924 asm volatile(
"red.relaxed.gpu.global.xor.b32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2929 asm volatile(
"red.relaxed.gpu.global.xor.b32 [%0], %1;" : :
"l"(addr),
"r"(val) :
"memory");
2934 asm volatile(
"red.relaxed.gpu.global.xor.b64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2939 asm volatile(
"red.relaxed.gpu.global.xor.b64 [%0], %1;" : :
"l"(addr),
"l"(val) :
"memory");
2949 asm volatile(
"red.relaxed.gpu.global.add.u32 [%0], 1;" : :
"l"(addr) :
"memory");
2954 asm volatile(
"red.relaxed.gpu.global.add.s32 [%0], 1;" : :
"l"(addr) :
"memory");
2959 asm volatile(
"red.relaxed.gpu.global.add.u32 [%0], -1;" : :
"l"(addr) :
"memory");
2964 asm volatile(
"red.relaxed.gpu.global.add.s32 [%0], -1;" : :
"l"(addr) :
"memory");
2979 return data[index >> 2];
2984 const size_t dataIdx = index >> 2;
2990 const size_t dataIdx = index >> 2;
2996 const size_t dataIdx = index >> 2;
2999 template<
typename T>
3004 memcpy(&
data, ((
const char*)this->data) + index,
sizeof(T));
3011 data[index >> 2] = v;
3016 const size_t dataIdx = index >> 2;
3017 data[dataIdx + 0] = v.x;
3018 data[dataIdx + 1] = v.y;
3023 const size_t dataIdx = index >> 2;
3024 data[dataIdx + 0] = v.x;
3025 data[dataIdx + 1] = v.y;
3026 data[dataIdx + 2] = v.z;
3031 const size_t dataIdx = index >> 2;
3032 data[dataIdx + 0] = v.x;
3033 data[dataIdx + 1] = v.y;
3034 data[dataIdx + 2] = v.z;
3035 data[dataIdx + 3] = v.w;
3037 template<
typename T>
3041 memcpy((
char*)
data + index, &value,
sizeof(T));
3045 template<
typename T>
3049 return (T*)(((
char*)
data) + index);
3051 template<
typename T>
3075#ifndef SLANG_USE_ASM_LANE_ID
3078 return ((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x) &
3082__forceinline__ __device__ uint32_t
_getLaneId()
3089 asm volatile(
"mov.u32 %0, %laneid;" :
"=r"(ret));
3128 return __ballot_sync(__activemask(),
true);
3141 return (mask & (mask - 1)) == 0;
3159 if ((mask & (mask + 1)) == 0)
3162 const int offset = 32 - __clz(mask);
3164 if ((offset & (offset - 1)) == 0)
3174 const WarpMask mask = __activemask();
3181 return (mask & 1) || ((__ffs(mask) - 1) ==
_getLaneId());
// Binary combine step for a bitwise-OR wave operation.
// Returns the bitwise OR of `a` and `b`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpOr<T> - confirm.
__inline__ __device__ static T doOp(T a, T b)
{
    const T combined = a | b;
    return combined;
}
// Binary combine step for a bitwise-AND wave operation.
// Returns the bitwise AND of `a` and `b`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpAnd<T> - confirm.
__inline__ __device__ static T doOp(T a, T b)
{
    const T combined = a & b;
    return combined;
}
// Binary combine step for a bitwise-XOR wave operation.
// Returns the bitwise XOR of `a` and `b`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpXor<T> - confirm.
__inline__ __device__ static T doOp(T a, T b)
{
    const T combined = a ^ b;
    return combined;
}
// Inverse of the XOR combine step: XOR-ing with `b` again removes its contribution,
// so the same expression as doOp is used.
__inline__ __device__ static T doInverse(T a, T b)
{
    const T restored = a ^ b;
    return restored;
}
// Binary combine step for an additive wave operation.
// Returns `a + b`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpAdd<T> - confirm.
__inline__ __device__ static T doOp(T a, T b)
{
    const T sum = a + b;
    return sum;
}
// Inverse of the additive combine step: subtracts `b` from `a`.
__inline__ __device__ static T doInverse(T a, T b)
{
    const T difference = a - b;
    return difference;
}
// Binary combine step for a multiplicative wave operation.
// Returns `a * b`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpMul<T> - confirm.
__inline__ __device__ static T doOp(T a, T b)
{
    const T product = a * b;
    return product;
}
// Inverse of the multiplicative combine step: divides `a` by `b`.
// NOTE(review): no guard against `b == 0` is present here; callers must tolerate
// whatever the division yields for a zero divisor.
__inline__ __device__ static T doInverse(T a, T b)
{
    const T quotient = a / b;
    return quotient;
}
// Declaration only: returns the starting value for an accumulation seeded from `a`.
// `exclusive` defaults to false. The per-type definitions visible in this file are
// generated by SLANG_WAVE_MAX_SPEC for WaveOpMax<T> and return EXCL_VAL when
// `exclusive` is true, otherwise `a`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpMax<T>,
// since the adjacent doOp selects the larger operand - confirm.
3228 __inline__ __device__
static T
getInitial(T a,
bool exclusive =
false);
// Binary combine step for a maximum wave operation.
// Returns `a` when `a > b`, otherwise `b` (so `b` is returned when the comparison is false,
// including when the operands compare equal).
__inline__ __device__ static T doOp(T a, T b)
{
    if (a > b)
    {
        return a;
    }
    return b;
}
// Declaration only: returns the starting value for an accumulation seeded from `a`.
// `exclusive` defaults to false. The per-type definitions visible in this file are
// generated by SLANG_WAVE_MIN_SPEC for WaveOpMin<T> and return EXCL_VAL when
// `exclusive` is true, otherwise `a`.
// NOTE(review): the enclosing struct is not visible here; presumably WaveOpMin<T>,
// since the adjacent doOp selects the smaller operand - confirm.
3235 __inline__ __device__
static T
getInitial(T a,
bool exclusive =
false);
// Binary combine step for a minimum wave operation.
// Returns `a` when `a < b`, otherwise `b` (so `b` is returned when the comparison is false,
// including when the operands compare equal).
__inline__ __device__ static T doOp(T a, T b)
{
    if (a < b)
    {
        return a;
    }
    return b;
}
3240#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL) \
3242 __inline__ __device__ T WaveOpMin<T>::getInitial(T a, bool exclusive) \
3244 return exclusive ? (EXCL_VAL) : a; \
3247#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL) \
3249 __inline__ __device__ T WaveOpMax<T>::getInitial(T a, bool exclusive) \
3251 return exclusive ? (EXCL_VAL) : a; \
3266#if SLANG_CUDA_ENABLE_HALF
3282#if SLANG_CUDA_ENABLE_HALF
3286#undef SLANG_WAVE_MIN_SPEC
3287#undef SLANG_WAVE_MAX_SPEC
3343#if SLANG_CUDA_ENABLE_HALF
3347 typedef __half Type;
3527#if SLANG_CUDA_ENABLE_HALF
3531 typedef __half Type;
3536 typedef __half
Type;
3541 typedef __half
Type;
3546template<
typename T,
int ROWS,
int COLS>
3553template<
typename INTF,
typename T>
3560 for (
int offset = offsetSize >> 1; offset > 0; offset >>= 1)
3562 val = INTF::doOp(val, __shfl_xor_sync(mask, val, offset));
3567 T result = INTF::getInitial(val);
3568 int remaining = mask;
3571 const int laneBit = remaining & -remaining;
3573 const int srcLane = __ffs(laneBit) - 1;
3575 result = INTF::doOp(result, __shfl_sync(mask, val, srcLane));
3576 remaining &= ~laneBit;
3585template<
typename INTF,
typename T,
size_t COUNT>
3592 for (
int offset = offsetSize >> 1; offset > 0; offset >>= 1)
3594 for (
size_t i = 0; i < COUNT; ++i)
3596 val[i] = INTF::doOp(val[i], __shfl_xor_sync(mask, val[i], offset));
3603 T originalVal[COUNT];
3604 for (
size_t i = 0; i < COUNT; ++i)
3608 val[i] = INTF::getInitial(v);
3611 int remaining = mask;
3614 const int laneBit = remaining & -remaining;
3616 const int srcLane = __ffs(laneBit) - 1;
3618 for (
size_t i = 0; i < COUNT; ++i)
3620 val[i] = INTF::doOp(val[i], __shfl_sync(mask, originalVal[i], srcLane));
3622 remaining &= ~laneBit;
3627template<
typename INTF,
typename T>
3637 return _waveReduceScalar<WaveOpOr<T>, T>(mask, val);
3643 return _waveReduceScalar<WaveOpAnd<T>, T>(mask, val);
3649 return _waveReduceScalar<WaveOpXor<T>, T>(mask, val);
3655 return _waveReduceScalar<WaveOpMul<T>, T>(mask, val);
3661 return _waveReduceScalar<WaveOpAdd<T>, T>(mask, val);
3667 return _waveReduceScalar<WaveOpMin<T>, T>(mask, val);
3673 return _waveReduceScalar<WaveOpMax<T>, T>(mask, val);
3677#if __CUDA_ARCH__ >= 800
3679__inline__ __device__
unsigned _waveOr<unsigned>(
WarpMask mask,
unsigned val)
3681 return __reduce_or_sync(mask, val);
3685__inline__ __device__
unsigned _waveAnd<unsigned>(
WarpMask mask,
unsigned val)
3687 return __reduce_and_sync(mask, val);
3691__inline__ __device__
unsigned _waveXor<unsigned>(
WarpMask mask,
unsigned val)
3693 return __reduce_xor_sync(mask, val);
3697__inline__ __device__
unsigned _waveSum<unsigned>(
WarpMask mask,
unsigned val)
3699 return __reduce_add_sync(mask, val);
3703__inline__ __device__
int _waveSum<int>(
WarpMask mask,
int val)
3705 return __reduce_add_sync(mask, val);
3709__inline__ __device__
unsigned _waveMin<unsigned>(
WarpMask mask,
unsigned val)
3711 return __reduce_min_sync(mask, val);
3715__inline__ __device__
int _waveMin<int>(
WarpMask mask,
int val)
3717 return __reduce_min_sync(mask, val);
3721__inline__ __device__
unsigned _waveMax<unsigned>(
WarpMask mask,
unsigned val)
3723 return __reduce_max_sync(mask, val);
3727__inline__ __device__
int _waveMax<int>(
WarpMask mask,
int val)
3729 return __reduce_max_sync(mask, val);
3739 _waveReduceMultiple<WaveOpOr<ElemType>>(mask, &val);
3747 _waveReduceMultiple<WaveOpAnd<ElemType>>(mask, &val);
3755 _waveReduceMultiple<WaveOpXor<ElemType>>(mask, &val);
3763 _waveReduceMultiple<WaveOpMul<ElemType>>(mask, &val);
3771 _waveReduceMultiple<WaveOpAdd<ElemType>>(mask, &val);
3779 _waveReduceMultiple<WaveOpMin<ElemType>>(mask, &val);
3787 _waveReduceMultiple<WaveOpMax<ElemType>>(mask, &val);
3796 __match_all_sync(mask, val, &pred);
3804 const size_t count =
sizeof(T) /
sizeof(ElemType);
3806 const ElemType* src = (
const ElemType*)&inVal;
3807 for (
size_t i = 0; i < count; ++i)
3809 __match_all_sync(mask, src[i], &pred);
3821 const int lowestLaneId = __ffs(mask) - 1;
3822 return __shfl_sync(mask, val, lowestLaneId);
3829 const size_t count =
sizeof(T) /
sizeof(ElemType);
3831 const ElemType* src = (
const ElemType*)&inVal;
3832 ElemType* dst = (ElemType*)&outVal;
3833 const int lowestLaneId = __ffs(mask) - 1;
3834 for (
size_t i = 0; i < count; ++i)
3836 dst[i] = __shfl_sync(mask, src[i], lowestLaneId);
3845 const size_t count =
sizeof(T) /
sizeof(ElemType);
3847 const ElemType* src = (
const ElemType*)&inVal;
3848 ElemType* dst = (ElemType*)&outVal;
3849 for (
size_t i = 0; i < count; ++i)
3851 dst[i] = __shfl_sync(mask, src[i], lane);
3860template<
typename INTF,
typename T>
3871 for (
int i = 1; i < offsetSize; i += i)
3873 const T readVal = __shfl_up_sync(mask, result, i, offsetSize);
3876 result = INTF::doOp(result, readVal);
3880 result = INTF::doInverse(result, val);
3884 result = INTF::getInitial(val);
3887 int remaining = mask;
3890 const int laneBit = remaining & -remaining;
3892 const int srcLane = __ffs(laneBit) - 1;
3894 const T readValue = __shfl_sync(mask, val, srcLane);
3896 if (srcLane < laneId)
3898 result = INTF::doOp(result, readValue);
3900 remaining &= ~laneBit;
3910template<
typename INTF,
typename T>
3916 T result = INTF::getInitial(val);
3923 for (
int i = 1; i < offsetSize; i += i)
3925 const T readVal = __shfl_up_sync(mask, val, i, offsetSize);
3928 result = INTF::doOp(result, readVal);
3929 val = INTF::doOp(val, readVal);
3937 int remaining = mask;
3940 const int laneBit = remaining & -remaining;
3942 const int srcLane = __ffs(laneBit) - 1;
3944 const T readValue = __shfl_sync(mask, val, srcLane);
3946 if (srcLane < laneId)
3948 result = INTF::doOp(result, readValue);
3950 remaining &= ~laneBit;
3958template<
typename INTF,
typename T,
size_t COUNT>
3961 for (
size_t j = 0; j < COUNT; ++j)
3968template<
typename INTF,
typename T,
size_t COUNT>
3971 for (
size_t j = 0; j < COUNT; ++j)
3973 inOut[j] = INTF::doInverse(inOut[j], val[j]);
3977template<
typename INTF,
typename T,
size_t COUNT>
3980 for (
size_t j = 0; j < COUNT; ++j)
3982 out[j] = INTF::getInitial(val[j]);
3986template<
typename INTF,
typename T,
size_t COUNT>
3992 T originalVal[COUNT];
3993 _waveOpCopy<INTF, T, COUNT>(originalVal, val);
3998 for (
int i = 1; i < offsetSize; i += i)
4004 for (
size_t j = 0; j < COUNT; ++j)
4006 const T readVal = __shfl_up_sync(mask, val[j], i, offsetSize);
4009 val[j] = INTF::doOp(val[j], readVal);
4014 _waveOpDoInverse<INTF, T, COUNT>(val, originalVal);
4018 _waveOpSetInitial<INTF, T, COUNT>(val, val);
4021 int remaining = mask;
4024 const int laneBit = remaining & -remaining;
4026 const int srcLane = __ffs(laneBit) - 1;
4028 for (
size_t j = 0; j < COUNT; ++j)
4031 const T readValue = __shfl_sync(mask, originalVal[j], srcLane);
4033 if (srcLane < laneId)
4035 val[j] = INTF::doOp(val[j], readValue);
4037 remaining &= ~laneBit;
4044template<
typename INTF,
typename T,
size_t COUNT>
4052 _waveOpCopy<INTF, T, COUNT>(work, val);
4053 _waveOpSetInitial<INTF, T, COUNT>(val, val);
4061 for (
int i = 1; i < offsetSize; i += i)
4063 for (
size_t j = 0; j < COUNT; ++j)
4065 const T readVal = __shfl_up_sync(mask, work[j], i, offsetSize);
4068 work[j] = INTF::doOp(work[j], readVal);
4069 val[j] = INTF::doOp(val[j], readVal);
4078 int remaining = mask;
4081 const int laneBit = remaining & -remaining;
4083 const int srcLane = __ffs(laneBit) - 1;
4085 for (
size_t j = 0; j < COUNT; ++j)
4088 const T readValue = __shfl_sync(mask, work[j], srcLane);
4090 if (srcLane < laneId)
4092 val[j] = INTF::doOp(val[j], readValue);
4095 remaining &= ~laneBit;
4104 return _wavePrefixScalar<WaveOpMul<T>, T>(mask, val);
4110 return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(mask, val);
4116 return _wavePrefixInvertableScalar<WaveOpXor<T>, T>(mask, val);
4122 return _wavePrefixScalar<WaveOpOr<T>, T>(mask, val);
4128 return _wavePrefixScalar<WaveOpAnd<T>, T>(mask, val);
4136 _wavePrefixInvertableMultiple<WaveOpMul<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4146 _wavePrefixInvertableMultiple<WaveOpAdd<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4156 _wavePrefixInvertableMultiple<WaveOpXor<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4166 _wavePrefixMultiple<WaveOpOr<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4176 _wavePrefixMultiple<WaveOpAnd<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4185 return _wavePrefixScalar<WaveOpMin<T>, T>(mask, val);
4191 return _wavePrefixScalar<WaveOpMax<T>, T>(mask, val);
4198 _wavePrefixMultiple<WaveOpMin<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4208 _wavePrefixMultiple<WaveOpMax<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4258 return _wavePrefixScalar<WaveOpExclusiveMin<T>, T>(mask, val);
4264 return _wavePrefixScalar<WaveOpExclusiveMax<T>, T>(mask, val);
4271 _wavePrefixMultiple<WaveOpExclusiveMin<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4281 _wavePrefixMultiple<WaveOpExclusiveMax<ElemType>, ElemType,
sizeof(T) /
sizeof(ElemType)>(
4291 return make_uint4(__match_all_sync(mask, val, &pred), 0, 0, 0);
4298 const size_t count =
sizeof(T) /
sizeof(ElemType);
4300 const ElemType* src = (
const ElemType*)&inVal;
4301 uint matchBits = 0xffffffff;
4302 for (
size_t i = 0; i < count && matchBits; ++i)
4304 matchBits = matchBits & __match_all_sync(mask, src[i], &pred);
4306 return make_uint4(matchBits, 0, 0, 0);
4323template<
typename TResult,
typename TInput>
4326 return *(TResult*)(&val);
4336struct UniformEntryPointParams;
4340#ifdef SLANG_CUDA_ENABLE_OPTIX
4350static __forceinline__ __device__
void* unpackOptiXRayPayloadPointer(uint32_t i0, uint32_t i1)
4352 const uint64_t uptr =
static_cast<uint64_t
>(i0) << 32 | i1;
4353 void* ptr =
reinterpret_cast<void*
>(uptr);
4357static __forceinline__ __device__
void packOptiXRayPayloadPointer(
4362 const uint64_t uptr =
reinterpret_cast<uint64_t
>(ptr);
4364 i1 = uptr & 0x00000000ffffffff;
4367static __forceinline__ __device__
void* getOptiXRayPayloadPtr()
4369 const uint32_t u0 = optixGetPayload_0();
4370 const uint32_t u1 = optixGetPayload_1();
4371 return unpackOptiXRayPayloadPointer(u0, u1);
// Upper bound on the number of 32-bit payload registers handled by the register-based
// paths in this file. Callers compute `(sizeof(T) + 3) / 4` and compare it against this
// constant with `if constexpr` before building a PayloadRegisters<T>.
// NOTE(review): payloads above the limit appear to fall back to passing a packed pointer
// via packOptiXRayPayloadPointer - confirm against the full optixTrace/optixTraverse bodies.
4375static constexpr size_t kMaxOptiXPayloadRegisters = 32;
4378template<
typename T,
size_t N = (
sizeof(T) + 3) / 4>
4379struct PayloadRegisters
4381 uint32_t regs[N > 0 ? N : 1];
// Copies the raw bytes of `payload` (sizeof(T) of them) into the start of `regs`.
// Any bytes of `regs` beyond sizeof(T) are left untouched.
__forceinline__ __device__ void pack(const T& payload)
{
    const void* source = &payload;
    memcpy(regs, source, sizeof(T));
}
// Copies sizeof(T) raw bytes from the start of `regs` back into `payload`,
// overwriting its previous contents.
__forceinline__ __device__ void unpack(T& payload)
{
    void* destination = &payload;
    memcpy(destination, regs, sizeof(T));
}
4389template<
typename T,
size_t N = (
sizeof(T) + 3) / 4>
4390__forceinline__ __device__
void optixTraceWithRegs(
4397 uint32_t InstanceInclusionMask,
4399 uint32_t RayContributionToHitGroupIndex,
4400 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
4401 uint32_t MissShaderIndex,
4402 PayloadRegisters<T, N>& pr)
4405 if constexpr (N == 0)
4408 AccelerationStructure,
4414 InstanceInclusionMask,
4416 RayContributionToHitGroupIndex,
4417 MultiplierForGeometryContributionToHitGroupIndex,
4420 else if constexpr (N == 1)
4423 AccelerationStructure,
4429 InstanceInclusionMask,
4431 RayContributionToHitGroupIndex,
4432 MultiplierForGeometryContributionToHitGroupIndex,
4436 else if constexpr (N == 2)
4439 AccelerationStructure,
4445 InstanceInclusionMask,
4447 RayContributionToHitGroupIndex,
4448 MultiplierForGeometryContributionToHitGroupIndex,
4453 else if constexpr (N == 3)
4456 AccelerationStructure,
4462 InstanceInclusionMask,
4464 RayContributionToHitGroupIndex,
4465 MultiplierForGeometryContributionToHitGroupIndex,
4471 else if constexpr (N == 4)
4474 AccelerationStructure,
4480 InstanceInclusionMask,
4482 RayContributionToHitGroupIndex,
4483 MultiplierForGeometryContributionToHitGroupIndex,
4490 else if constexpr (N == 5)
4493 AccelerationStructure,
4499 InstanceInclusionMask,
4501 RayContributionToHitGroupIndex,
4502 MultiplierForGeometryContributionToHitGroupIndex,
4510 else if constexpr (N == 6)
4513 AccelerationStructure,
4519 InstanceInclusionMask,
4521 RayContributionToHitGroupIndex,
4522 MultiplierForGeometryContributionToHitGroupIndex,
4531 else if constexpr (N == 7)
4534 AccelerationStructure,
4540 InstanceInclusionMask,
4542 RayContributionToHitGroupIndex,
4543 MultiplierForGeometryContributionToHitGroupIndex,
4553 else if constexpr (N == 8)
4556 AccelerationStructure,
4562 InstanceInclusionMask,
4564 RayContributionToHitGroupIndex,
4565 MultiplierForGeometryContributionToHitGroupIndex,
4576 else if constexpr (N <= 16)
4579 AccelerationStructure,
4585 InstanceInclusionMask,
4587 RayContributionToHitGroupIndex,
4588 MultiplierForGeometryContributionToHitGroupIndex,
4607 else if constexpr (N <= kMaxOptiXPayloadRegisters)
4610 AccelerationStructure,
4616 InstanceInclusionMask,
4618 RayContributionToHitGroupIndex,
4619 MultiplierForGeometryContributionToHitGroupIndex,
4657__forceinline__ __device__
void optixTrace(
4660 uint32_t InstanceInclusionMask,
4661 uint32_t RayContributionToHitGroupIndex,
4662 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
4663 uint32_t MissShaderIndex,
4667 constexpr size_t numRegs = (
sizeof(T) + 3) / 4;
4669 if constexpr (numRegs <= kMaxOptiXPayloadRegisters)
4672 PayloadRegisters<T> pr;
4675 optixTraceWithRegs<T>(
4676 AccelerationStructure,
4682 InstanceInclusionMask,
4684 RayContributionToHitGroupIndex,
4685 MultiplierForGeometryContributionToHitGroupIndex,
4691 pr.unpack(*Payload);
4697 packOptiXRayPayloadPointer((
void*)Payload, r0, r1);
4699 AccelerationStructure,
4705 InstanceInclusionMask,
4707 RayContributionToHitGroupIndex,
4708 MultiplierForGeometryContributionToHitGroupIndex,
4718__forceinline__ __device__
void optixTrace(
4721 uint32_t InstanceInclusionMask,
4722 uint32_t RayContributionToHitGroupIndex,
4723 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
4724 uint32_t MissShaderIndex,
4728 AccelerationStructure,
4734 InstanceInclusionMask,
4736 RayContributionToHitGroupIndex,
4737 MultiplierForGeometryContributionToHitGroupIndex,
4741#if (OPTIX_VERSION >= 90000)
4742__forceinline__ __device__
float4 optixGetSpherePositionAndRadius()
4745 optixGetSphereData(data);
4750#if (OPTIX_VERSION >= 90000)
4751__forceinline__ __device__
float4
4755 optixHitObjectGetSphereData(data);
4760#if (OPTIX_VERSION >= 90000)
4764 optixGetLinearCurveVertexData(data);
4765 return makeMatrix<float, 2, 4>(data[0], data[1]);
4769#if (OPTIX_VERSION >= 90000)
4774 optixHitObjectGetLinearCurveVertexData(data);
4775 return makeMatrix<float, 2, 4>(data[0], data[1]);
4779#if (OPTIX_VERSION >= 90000)
4780__forceinline__ __device__
bool optixIsSphereHit()
4782 return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_SPHERE;
4786#if (OPTIX_VERSION >= 90000)
4789 return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_SPHERE;
4793#if (OPTIX_VERSION >= 90000)
4794__forceinline__ __device__
bool optixIsLSSHit()
4796 return optixGetPrimitiveType() == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
4800#if (OPTIX_VERSION >= 90000)
4803 return optixGetPrimitiveType(optixHitObjectGetHitKind()) == OPTIX_PRIMITIVE_TYPE_ROUND_LINEAR;
4808template<
typename T,
size_t N = (
sizeof(T) + 3) / 4>
4809__forceinline__ __device__
void optixTraverseWithRegs(
4816 uint32_t InstanceInclusionMask,
4818 uint32_t RayContributionToHitGroupIndex,
4819 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
4820 uint32_t MissShaderIndex,
4821 PayloadRegisters<T, N>& pr)
4824 if constexpr (N == 0)
4827 AccelerationStructure,
4833 InstanceInclusionMask,
4835 RayContributionToHitGroupIndex,
4836 MultiplierForGeometryContributionToHitGroupIndex,
4839 else if constexpr (N == 1)
4842 AccelerationStructure,
4848 InstanceInclusionMask,
4850 RayContributionToHitGroupIndex,
4851 MultiplierForGeometryContributionToHitGroupIndex,
4855 else if constexpr (N == 2)
4858 AccelerationStructure,
4864 InstanceInclusionMask,
4866 RayContributionToHitGroupIndex,
4867 MultiplierForGeometryContributionToHitGroupIndex,
4872 else if constexpr (N == 3)
4875 AccelerationStructure,
4881 InstanceInclusionMask,
4883 RayContributionToHitGroupIndex,
4884 MultiplierForGeometryContributionToHitGroupIndex,
4890 else if constexpr (N == 4)
4893 AccelerationStructure,
4899 InstanceInclusionMask,
4901 RayContributionToHitGroupIndex,
4902 MultiplierForGeometryContributionToHitGroupIndex,
4909 else if constexpr (N == 5)
4912 AccelerationStructure,
4918 InstanceInclusionMask,
4920 RayContributionToHitGroupIndex,
4921 MultiplierForGeometryContributionToHitGroupIndex,
4929 else if constexpr (N == 6)
4932 AccelerationStructure,
4938 InstanceInclusionMask,
4940 RayContributionToHitGroupIndex,
4941 MultiplierForGeometryContributionToHitGroupIndex,
4950 else if constexpr (N == 7)
4953 AccelerationStructure,
4959 InstanceInclusionMask,
4961 RayContributionToHitGroupIndex,
4962 MultiplierForGeometryContributionToHitGroupIndex,
4972 else if constexpr (N == 8)
4975 AccelerationStructure,
4981 InstanceInclusionMask,
4983 RayContributionToHitGroupIndex,
4984 MultiplierForGeometryContributionToHitGroupIndex,
4995 else if constexpr (N <= 16)
4998 AccelerationStructure,
5004 InstanceInclusionMask,
5006 RayContributionToHitGroupIndex,
5007 MultiplierForGeometryContributionToHitGroupIndex,
5026 else if constexpr (N <= kMaxOptiXPayloadRegisters)
5029 AccelerationStructure,
5035 InstanceInclusionMask,
5037 RayContributionToHitGroupIndex,
5038 MultiplierForGeometryContributionToHitGroupIndex,
5076__forceinline__ __device__
void optixTraverse(
5079 uint32_t InstanceInclusionMask,
5080 uint32_t RayContributionToHitGroupIndex,
5081 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
5082 uint32_t MissShaderIndex,
5087 constexpr size_t numRegs = (
sizeof(T) + 3) / 4;
5089 if constexpr (numRegs <= kMaxOptiXPayloadRegisters)
5092 PayloadRegisters<T> pr;
5095 optixTraverseWithRegs<T>(
5096 AccelerationStructure,
5102 InstanceInclusionMask,
5104 RayContributionToHitGroupIndex,
5105 MultiplierForGeometryContributionToHitGroupIndex,
5111 pr.unpack(*Payload);
5117 packOptiXRayPayloadPointer((
void*)Payload, r0, r1);
5119 AccelerationStructure,
5125 InstanceInclusionMask,
5127 RayContributionToHitGroupIndex,
5128 MultiplierForGeometryContributionToHitGroupIndex,
5136__forceinline__ __device__
void optixTraverse(
5139 uint32_t InstanceInclusionMask,
5140 uint32_t RayContributionToHitGroupIndex,
5141 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
5142 uint32_t MissShaderIndex,
5148 constexpr size_t numRegs = (
sizeof(T) + 3) / 4;
5150 if constexpr (numRegs <= kMaxOptiXPayloadRegisters)
5153 PayloadRegisters<T> pr;
5156 optixTraverseWithRegs<T>(
5157 AccelerationStructure,
5163 InstanceInclusionMask,
5165 RayContributionToHitGroupIndex,
5166 MultiplierForGeometryContributionToHitGroupIndex,
5172 pr.unpack(*Payload);
5178 packOptiXRayPayloadPointer((
void*)Payload, r0, r1);
5180 AccelerationStructure,
5186 InstanceInclusionMask,
5188 RayContributionToHitGroupIndex,
5189 MultiplierForGeometryContributionToHitGroupIndex,
5199__forceinline__ __device__
void optixTraverse(
5202 uint32_t InstanceInclusionMask,
5203 uint32_t RayContributionToHitGroupIndex,
5204 uint32_t MultiplierForGeometryContributionToHitGroupIndex,
5205 uint32_t MissShaderIndex,
5210 AccelerationStructure,
5216 InstanceInclusionMask,
5218 RayContributionToHitGroupIndex,
5219 MultiplierForGeometryContributionToHitGroupIndex,
5223#if (OPTIX_VERSION >= 80100)
5226 return optixHitObjectIsHit();
5230#if (OPTIX_VERSION >= 80100)
5233 return optixHitObjectIsMiss();
5237#if (OPTIX_VERSION >= 80100)
5240 return optixHitObjectIsNop();
5244#if (OPTIX_VERSION >= 90000)
5245static __forceinline__ __device__
uint
5248 return optixHitObjectGetClusterId();
5252#if (OPTIX_VERSION >= 80100)
5253static __forceinline__ __device__
void optixMakeMissHitObject(
5254 uint MissShaderIndex,
5258 optixMakeMissHitObject(
5265#if (OPTIX_VERSION >= 90000)
5273#if (OPTIX_VERSION >= 80100)
5274static __forceinline__ __device__
void optixMakeMissHitObject(
5275 uint MissShaderIndex,
5280 optixMakeMissHitObject(
5287#if (OPTIX_VERSION >= 90000)
5295#if (OPTIX_VERSION >= 90000)
5297static __forceinline__ __device__
void optixMakeHitObject(
5301 uint PrimitiveIndex,
5303 uint RayContributionToHitGroupIndex,
5304 uint MultiplierForGeometryContributionToHitGroupIndex,
5309 OptixTraverseData data{};
5310 optixHitObjectGetTraverseData(&data);
5312 AccelerationStructure,
5317 OPTIX_RAY_FLAG_NONE,
5322#elif (OPTIX_VERSION >= 80100)
5324static __forceinline__ __device__
void optixMakeHitObject(
5328 uint PrimitiveIndex,
5330 uint RayContributionToHitGroupIndex,
5331 uint MultiplierForGeometryContributionToHitGroupIndex,
5338 AccelerationStructure,
5344 RayContributionToHitGroupIndex,
5345 MultiplierForGeometryContributionToHitGroupIndex,
5357#if (OPTIX_VERSION >= 90000)
5359static __forceinline__ __device__
void optixMakeHitObject(
5360 uint HitGroupRecordIndex,
5364 uint PrimitiveIndex,
5370 OptixTraverseData data{};
5371 optixHitObjectGetTraverseData(&data);
5373 AccelerationStructure,
5378 OPTIX_RAY_FLAG_NONE,
5383#elif (OPTIX_VERSION >= 80100)
5385static __forceinline__ __device__
void optixMakeHitObject(
5386 uint HitGroupRecordIndex,
5390 uint PrimitiveIndex,
5397 optixMakeHitObjectWithRecord(
5398 AccelerationStructure,
5404 HitGroupRecordIndex,
5416#if (OPTIX_VERSION >= 90000)
5418static __forceinline__ __device__
void optixMakeHitObject(
5422 uint PrimitiveIndex,
5424 uint RayContributionToHitGroupIndex,
5425 uint MultiplierForGeometryContributionToHitGroupIndex,
5431 OptixTraverseData data{};
5432 optixHitObjectGetTraverseData(&data);
5434 AccelerationStructure,
5439 OPTIX_RAY_FLAG_NONE,
5444#elif (OPTIX_VERSION >= 80100)
5446static __forceinline__ __device__
void optixMakeHitObject(
5450 uint PrimitiveIndex,
5452 uint RayContributionToHitGroupIndex,
5453 uint MultiplierForGeometryContributionToHitGroupIndex,
5461 AccelerationStructure,
5467 RayContributionToHitGroupIndex,
5468 MultiplierForGeometryContributionToHitGroupIndex,
5480#if (OPTIX_VERSION >= 90000)
5482static __forceinline__ __device__
void optixMakeHitObject(
5483 uint HitGroupRecordIndex,
5487 uint PrimitiveIndex,
5494 OptixTraverseData data{};
5495 optixHitObjectGetTraverseData(&data);
5497 AccelerationStructure,
5502 OPTIX_RAY_FLAG_NONE,
5507#elif (OPTIX_VERSION >= 80100)
5509static __forceinline__ __device__
void optixMakeHitObject(
5510 uint HitGroupRecordIndex,
5514 uint PrimitiveIndex,
5522 optixMakeHitObjectWithRecord(
5523 AccelerationStructure,
5529 HitGroupRecordIndex,
5541#if (OPTIX_VERSION >= 80100)
5544 optixMakeNopHitObject();
5548#if (OPTIX_VERSION >= 80100)
5550template<
typename T,
size_t N = (
sizeof(T) + 3) / 4>
5551__forceinline__ __device__
void optixInvokeWithRegs(PayloadRegisters<T, N>& pr)
5553 if constexpr (N == 0)
5557 else if constexpr (N == 1)
5559 optixInvoke(pr.regs[0]);
5561 else if constexpr (N == 2)
5563 optixInvoke(pr.regs[0], pr.regs[1]);
5565 else if constexpr (N == 3)
5567 optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2]);
5569 else if constexpr (N == 4)
5571 optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3]);
5573 else if constexpr (N == 5)
5575 optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4]);
5577 else if constexpr (N == 6)
5579 optixInvoke(pr.regs[0], pr.regs[1], pr.regs[2], pr.regs[3], pr.regs[4], pr.regs[5]);
5581 else if constexpr (N == 7)
5592 else if constexpr (N == 8)
5604 else if constexpr (N <= 16)
5624 else if constexpr (N <= kMaxOptiXPayloadRegisters)
5663static __forceinline__ __device__
void optixInvoke(
5668 constexpr size_t numRegs = (
sizeof(T) + 3) / 4;
5670 if constexpr (numRegs <= kMaxOptiXPayloadRegisters)
5673 PayloadRegisters<T> pr;
5675 optixInvokeWithRegs<T>(pr);
5677 pr.unpack(*Payload);
5683 packOptiXRayPayloadPointer((
void*)Payload, r0, r1);
5684 optixInvoke(r0, r1);
5689static __forceinline__ __device__
void optixInvoke(
5698#if (OPTIX_VERSION >= 80100)
5702 optixHitObjectGetWorldRayOrigin(),
5703 optixHitObjectGetRayTmin(),
5704 optixHitObjectGetWorldRayDirection(),
5705 optixHitObjectGetRayTmax()};
5710#if (OPTIX_VERSION >= 80100)
5711static __forceinline__ __device__
uint
5714 return optixHitObjectGetInstanceIndex();
5718#if (OPTIX_VERSION >= 80100)
5721 return optixHitObjectGetInstanceId();
5725#if (OPTIX_VERSION >= 80000)
5728 return optixHitObjectGetRayTime();
5732#if (OPTIX_VERSION >= 80100)
5735 return optixHitObjectGetRayTmax();
5739#if (OPTIX_VERSION >= 80100)
5740static __forceinline__ __device__
uint
5743 return optixHitObjectGetSbtGASIndex();
5747#if (OPTIX_VERSION >= 80100)
5748static __forceinline__ __device__
uint
5751 return optixHitObjectGetPrimitiveIndex();
5755#if (OPTIX_VERSION >= 80100)
5759 constexpr size_t numInts = (
sizeof(T) +
sizeof(uint32_t) - 1) /
5761 static_assert(numInts <= 8,
"Attribute type is too large");
5764 uint32_t values[numInts == 0 ? 1 : numInts] = {0};
5767 if constexpr (numInts > 0)
5768 values[0] = optixHitObjectGetAttribute_0();
5769 if constexpr (numInts > 1)
5770 values[1] = optixHitObjectGetAttribute_1();
5771 if constexpr (numInts > 2)
5772 values[2] = optixHitObjectGetAttribute_2();
5773 if constexpr (numInts > 3)
5774 values[3] = optixHitObjectGetAttribute_3();
5775 if constexpr (numInts > 4)
5776 values[4] = optixHitObjectGetAttribute_4();
5777 if constexpr (numInts > 5)
5778 values[5] = optixHitObjectGetAttribute_5();
5779 if constexpr (numInts > 6)
5780 values[6] = optixHitObjectGetAttribute_6();
5781 if constexpr (numInts > 7)
5782 values[7] = optixHitObjectGetAttribute_7();
5786 memcpy(&result, values,
sizeof(T));
5791#if (OPTIX_VERSION >= 80100)
5792static __forceinline__ __device__
uint
5795 return optixHitObjectGetSbtRecordIndex();
5799#if (OPTIX_VERSION >= 90000)
5800static __forceinline__ __device__
void slangOptixHitObjectSetSbtRecordIndex(
5802 uint sbtRecordIndex)
5804 optixHitObjectSetSbtRecordIndex(sbtRecordIndex);
5813#if (OPTIX_VERSION >= 90000)
5818 optixHitObjectGetWorldToObjectTransformMatrix(m);
5820 return makeMatrix<float, 4, 3>(
5821 make_float3(m[0], m[4], m[8]),
5822 make_float3(m[1], m[5], m[9]),
5823 make_float3(m[2], m[6], m[10]),
5824 make_float3(m[3], m[7], m[11]));
5828#if (OPTIX_VERSION >= 90000)
5833 optixHitObjectGetObjectToWorldTransformMatrix(m);
5835 return makeMatrix<float, 4, 3>(
5836 make_float3(m[0], m[4], m[8]),
5837 make_float3(m[1], m[5], m[9]),
5838 make_float3(m[2], m[6], m[10]),
5839 make_float3(m[3], m[7], m[11]));
5848 const float4* m = optixGetInstanceTransformFromHandle(handle);
5850 return makeMatrix<float, 3, 4>(m[0], m[1], m[2]);
5853__device__ __forceinline__
Matrix<float, 3, 4> _slang_optixGetInstanceInverseTransformFromHandle(
5856 const float4* m = optixGetInstanceInverseTransformFromHandle(handle);
5858 return makeMatrix<float, 3, 4>(m[0], m[1], m[2]);
5866 optixGetObjectToWorldTransformMatrix(m);
5868 return makeMatrix<float, 3, 4>(
5869 make_float4(m[0], m[1], m[2], m[3]),
5870 make_float4(m[4], m[5], m[6], m[7]),
5871 make_float4(m[8], m[9], m[10], m[11]));
5877 optixGetWorldToObjectTransformMatrix(m);
5879 return makeMatrix<float, 3, 4>(
5880 make_float4(m[0], m[1], m[2], m[3]),
5881 make_float4(m[4], m[5], m[6], m[7]),
5882 make_float4(m[8], m[9], m[10], m[11]));
5885__device__ __forceinline__
Matrix<float, 4, 3> slangOptixGetObjectToWorldTransformMatrix4x3()
5888 optixGetObjectToWorldTransformMatrix(m);
5890 return makeMatrix<float, 4, 3>(
5891 make_float3(m[0], m[4], m[8]),
5892 make_float3(m[1], m[5], m[9]),
5893 make_float3(m[2], m[6], m[10]),
5894 make_float3(m[3], m[7], m[11]));
5897__device__ __forceinline__
Matrix<float, 4, 3> slangOptixGetWorldToObjectTransformMatrix4x3()
5900 optixGetWorldToObjectTransformMatrix(m);
5902 return makeMatrix<float, 4, 3>(
5903 make_float3(m[0], m[4], m[8]),
5904 make_float3(m[1], m[5], m[9]),
5905 make_float3(m[2], m[6], m[10]),
5906 make_float3(m[3], m[7], m[11]));
5926 template<
typename T>
5929 return reinterpret_cast<T*
>(
data);
5932 template<
typename T>
5935 uint64_t offset =
strides[0] * index;
5936 return reinterpret_cast<T*
>(
data + offset);
5939 template<
typename T>
5943 return reinterpret_cast<T*
>(
data + offset);
5946 template<
typename T>
5950 return reinterpret_cast<T*
>(
data + offset);
5953 template<
typename T>
5958 return reinterpret_cast<T*
>(
data + offset);
5961 template<
typename T,
unsigned int N>
5964 uint64_t offset = 0;
5965 for (
unsigned int i = 0; i < N; ++i)
5967 offset +=
strides[i] * index[i];
5969 return reinterpret_cast<T*
>(
data + offset);
5972 template<
typename T>
5975 return *
reinterpret_cast<T*
>(
data +
strides[0] * x);
5977 template<
typename T>
5978 __device__ T&
load(uint32_t x, uint32_t y)
5982 template<
typename T>
5987 template<
typename T>
5988 __device__ T&
load(uint32_t x, uint32_t y, uint32_t z)
5992 template<
typename T>
5995 return *
reinterpret_cast<T*
>(
5998 template<
typename T>
5999 __device__ T&
load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
6001 return *
reinterpret_cast<T*
>(
6004 template<
typename T>
6007 return *
reinterpret_cast<T*
>(
6011 template<
typename T>
6012 __device__ T&
load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
6014 return *
reinterpret_cast<T*
>(
6020 template<
typename T,
unsigned int N>
6023 uint64_t offset = 0;
6024 for (
unsigned int i = 0; i < N; ++i)
6026 offset +=
strides[i] * index[i];
6028 return *
reinterpret_cast<T*
>(
data + offset);
6031 template<
typename T>
6032 __device__
void store(uint32_t x, T val)
6034 *
reinterpret_cast<T*
>(
data +
strides[0] * x) = val;
6036 template<
typename T>
6037 __device__
void store(uint32_t x, uint32_t y, T val)
6041 template<
typename T>
6046 template<
typename T>
6047 __device__
void store(uint32_t x, uint32_t y, uint32_t z, T val)
6051 template<
typename T>
6054 *
reinterpret_cast<T*
>(
6057 template<
typename T>
6058 __device__
void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
6060 *
reinterpret_cast<T*
>(
6063 template<
typename T>
6066 *
reinterpret_cast<T*
>(
6070 template<
typename T>
6071 __device__
void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
6073 *
reinterpret_cast<T*
>(
6079 template<
typename T,
unsigned int N>
6082 uint64_t offset = 0;
6083 for (
unsigned int i = 0; i < N; ++i)
6085 offset +=
strides[i] * index[i];
6087 *
reinterpret_cast<T*
>(
data + offset) = val;
// SLANG_TEX1DFETCH_INT_IMPL(T, dtype, c): stamps out three overloads of
// tex1Dfetch_int(texObj, x, mip), returning T, T##2 and T##4.
//   T     - scalar element type; also used to form T##2 / T##4 and make_##T##N
//   dtype - PTX type suffix spliced into the instruction string
//   c     - inline-asm output constraint applied to every result operand
// All three overloads issue the same 4-component
// `tex.level.1d.v4.<dtype>.s32` instruction with an integer coordinate `x`
// and an explicit mip level. The scalar and 2-component overloads bind the
// components they do not return to the unused local `stub`.
// The macro is instantiated right after its definition for
// (float, "f32", "=f"), (uint, "u32", "=r") and (int, "s32", "=r").
6104#define SLANG_TEX1DFETCH_INT_IMPL(T, dtype, c) \
6106 SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex1Dfetch_int(CUtexObject texObj, int x, int mip) \
6109 [[maybe_unused]] T stub; \
6110 asm("tex.level.1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5}], %6;" \
6111 : c(result), c(stub), c(stub), c(stub) \
6112 : "l"(texObj), "r"(x), "r"(mip)); \
6116 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##2 tex1Dfetch_int(CUtexObject texObj, int x, int mip) \
6118 T result_x, result_y; \
6119 [[maybe_unused]] T stub; \
6120 asm("tex.level.1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5}], %6;" \
6121 : c(result_x), c(result_y), c(stub), c(stub) \
6122 : "l"(texObj), "r"(x), "r"(mip)); \
6123 return make_##T##2(result_x, result_y); \
6126 SLANG_FORCE_INLINE SLANG_CUDA_CALL T##4 tex1Dfetch_int(CUtexObject texObj, int x, int mip) \
6128 T result_x, result_y, result_z, result_w; \
6129 asm("tex.level.1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5}], %6;" \
6130 : c(result_x), c(result_y), c(result_z), c(result_w) \
6131 : "l"(texObj), "r"(x), "r"(mip)); \
6132 return make_##T##4(result_x, result_y, result_z, result_w); \
6135SLANG_TEX1DFETCH_INT_IMPL(
float,
"f32",
"=f")
6136SLANG_TEX1DFETCH_INT_IMPL(
uint, "u32", "=r")
6137SLANG_TEX1DFETCH_INT_IMPL(
int, "s32", "=r")
// SLANG_TEX2DFETCH_INT_IMPL(T, dtype, c): stamps out three overloads of
// tex2Dfetch_int(texObj, x, y, mip), returning T, T##2 and T##4.
//   T     - scalar element type; also used to form T##2 / T##4 and make_##T##N
//   dtype - PTX type suffix spliced into the instruction string
//   c     - inline-asm output constraint applied to every result operand
// Each overload issues `tex.level.2d.v4.<dtype>.s32` with the integer
// coordinate pair {x, y} and an explicit mip level. The scalar and
// 2-component overloads route the components they do not return to `stub`.
6143#define SLANG_TEX2DFETCH_INT_IMPL(T, dtype, c) \
6145 SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip) \
6148 [[maybe_unused]] T stub; \
6149 asm("tex.level.2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
6150 : c(result), c(stub), c(stub), c(stub) \
6151 : "l"(texObj), "r"(x), "r"(y), "r"(mip)); \
6155 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6156 T##2 tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip) \
6158 T result_x, result_y; \
6159 [[maybe_unused]] T stub; \
6160 asm("tex.level.2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
6161 : c(result_x), c(result_y), c(stub), c(stub) \
6162 : "l"(texObj), "r"(x), "r"(y), "r"(mip)); \
6163 return make_##T##2(result_x, result_y); \
6166 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6167 T##4 tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip) \
6169 T result_x, result_y, result_z, result_w; \
6170 asm("tex.level.2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
6171 : c(result_x), c(result_y), c(result_z), c(result_w) \
6172 : "l"(texObj), "r"(x), "r"(y), "r"(mip)); \
6173 return make_##T##4(result_x, result_y, result_z, result_w); \
// SLANG_TEX3DFETCH_INT_IMPL(T, dtype, c): stamps out three overloads of
// tex3Dfetch_int(texObj, x, y, z, mip), returning T, T##2 and T##4.
//   T     - scalar element type; also used to form T##2 / T##4 and make_##T##N
//   dtype - PTX type suffix spliced into the instruction string
//   c     - inline-asm output constraint applied to every result operand
// Each overload issues `tex.level.3d.v4.<dtype>.s32` with a four-slot
// coordinate vector {%5, %6, %7, %8} and an explicit mip level. The slots
// are bound to x, y, z and z again.
// NOTE(review): presumably the fourth slot is padding that the 3d geometry
// ignores, which would make the repeated `z` harmless; confirm against the
// PTX ISA description of `tex`.
6185#define SLANG_TEX3DFETCH_INT_IMPL(T, dtype, c) \
6187 SLANG_FORCE_INLINE SLANG_CUDA_CALL T \
6188 tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip) \
6191 [[maybe_unused]] T stub; \
6192 asm("tex.level.3d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
6193 : c(result), c(stub), c(stub), c(stub) \
6194 : "l"(texObj), "r"(x), "r"(y), "r"(z), "r"(z) , "r"(mip)); \
6198 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6199 T##2 tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip) \
6201 T result_x, result_y; \
6202 [[maybe_unused]] T stub; \
6203 asm("tex.level.3d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
6204 : c(result_x), c(result_y), c(stub), c(stub) \
6205 : "l"(texObj), "r"(x), "r"(y), "r"(z), "r"(z) , "r"(mip)); \
6206 return make_##T##2(result_x, result_y); \
6209 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6210 T##4 tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip) \
6212 T result_x, result_y, result_z, result_w; \
6213 asm("tex.level.3d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
6214 : c(result_x), c(result_y), c(result_z), c(result_w) \
6215 : "l"(texObj), "r"(x), "r"(y), "r"(z), "r"(z) , "r"(mip)); \
6216 return make_##T##4(result_x, result_y, result_z, result_w); \
// SLANG_TEX1DARRAYFETCH_INT_IMPL(T, dtype, c): stamps out three overloads of
// tex1DArrayfetch_int(texObj, x, layer, mip), returning T, T##2 and T##4.
//   T     - scalar element type; also used to form T##2 / T##4 and make_##T##N
//   dtype - PTX type suffix spliced into the instruction string
//   c     - inline-asm output constraint applied to every result operand
// Each overload issues `tex.level.a1d.v4.<dtype>.s32`. Note the operand
// order: the asm coordinate vector is {layer, x}, i.e. the array layer comes
// first even though the C++ signature takes `x` before `layer`.
6227#define SLANG_TEX1DARRAYFETCH_INT_IMPL(T, dtype, c) \
6229 SLANG_FORCE_INLINE SLANG_CUDA_CALL T \
6230 tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip) \
6233 [[maybe_unused]] T stub; \
6234 asm("tex.level.a1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
6235 : c(result), c(stub), c(stub), c(stub) \
6236 : "l"(texObj), "r"(layer), "r"(x), "r"(mip)); \
6240 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6241 T##2 tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip) \
6243 T result_x, result_y; \
6244 [[maybe_unused]] T stub; \
6245 asm("tex.level.a1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
6246 : c(result_x), c(result_y), c(stub), c(stub) \
6247 : "l"(texObj), "r"(layer), "r"(x), "r"(mip)); \
6248 return make_##T##2(result_x, result_y); \
6251 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6252 T##4 tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip) \
6254 T result_x, result_y, result_z, result_w; \
6255 asm("tex.level.a1d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6}], %7;" \
6256 : c(result_x), c(result_y), c(result_z), c(result_w) \
6257 : "l"(texObj), "r"(layer), "r"(x), "r"(mip)); \
6258 return make_##T##4(result_x, result_y, result_z, result_w); \
// SLANG_TEX2DARRAYFETCH_INT_IMPL(T, dtype, c): stamps out three overloads of
// tex2DArrayfetch_int(texObj, x, y, layer, mip), returning T, T##2 and T##4.
//   T     - scalar element type; also used to form T##2 / T##4 and make_##T##N
//   dtype - PTX type suffix spliced into the instruction string
//   c     - inline-asm output constraint applied to every result operand
// Each overload issues `tex.level.a2d.v4.<dtype>.s32` with a four-slot
// coordinate vector bound to {layer, x, y, layer}: the array layer leads,
// and is also supplied for the fourth slot.
// NOTE(review): presumably the fourth slot is padding that the a2d geometry
// ignores; confirm against the PTX ISA description of `tex`.
6269#define SLANG_TEX2DARRAYFETCH_INT_IMPL(T, dtype, c) \
6271 SLANG_FORCE_INLINE SLANG_CUDA_CALL T \
6272 tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip) \
6275 [[maybe_unused]] T stub; \
6276 asm("tex.level.a2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
6277 : c(result), c(stub), c(stub), c(stub) \
6278 : "l"(texObj), "r"(layer), "r"(x), "r"(y), "r"(layer) , "r"(mip)); \
6282 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6283 T##2 tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip) \
6285 T result_x, result_y; \
6286 [[maybe_unused]] T stub; \
6287 asm("tex.level.a2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
6288 : c(result_x), c(result_y), c(stub), c(stub) \
6289 : "l"(texObj), "r"(layer), "r"(x), "r"(y), "r"(layer) , "r"(mip)); \
6290 return make_##T##2(result_x, result_y); \
6293 SLANG_FORCE_INLINE SLANG_CUDA_CALL \
6294 T##4 tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip) \
6296 T result_x, result_y, result_z, result_w; \
6297 asm("tex.level.a2d.v4." dtype ".s32 {%0, %1, %2, %3}, [%4, {%5, %6, %7, %8}], %9;" \
6298 : c(result_x), c(result_y), c(result_z), c(result_w) \
6299 : "l"(texObj), "r"(layer), "r"(x), "r"(y), "r"(layer) , "r"(mip)); \
6300 return make_##T##4(result_x, result_y, result_z, result_w); \
// 32-bit mask with every bit set; supplied as the mask argument by the
// wave-rotate macros that follow.
6308#define SLANG_WARP_FULL_MASK 0xFFFFFFFF
// SLANG_WAVE_ROTATE_IMPL(T): defines _slang_waveRotate(value, delta) for the
// vector types T##2, T##3 and T##4. The result is assembled with
// make_##T##N, one argument group per component. Every group passes
// SLANG_WARP_FULL_MASK and the source lane
// (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE, so lanes wrap around at the
// warp size.
// NOTE(review): the call that consumes each (mask, component, lane) group is
// not visible here; presumably it is a __shfl_sync on value.x/.y/.z/.w.
6311#define SLANG_WAVE_ROTATE_IMPL(T) \
6312 __device__ __forceinline__ T##2 _slang_waveRotate(T##2 value, unsigned int delta) \
6314 return make_##T##2( \
6316 SLANG_WARP_FULL_MASK, \
6318 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE), \
6320 SLANG_WARP_FULL_MASK, \
6322 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE)); \
6324 __device__ __forceinline__ T##3 _slang_waveRotate(T##3 value, unsigned int delta) \
6326 return make_##T##3( \
6328 SLANG_WARP_FULL_MASK, \
6330 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE), \
6332 SLANG_WARP_FULL_MASK, \
6334 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE), \
6336 SLANG_WARP_FULL_MASK, \
6338 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE)); \
6340 __device__ __forceinline__ T##4 _slang_waveRotate(T##4 value, unsigned int delta) \
6342 return make_##T##4( \
6344 SLANG_WARP_FULL_MASK, \
6346 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE), \
6348 SLANG_WARP_FULL_MASK, \
6350 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE), \
6352 SLANG_WARP_FULL_MASK, \
6354 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE), \
6356 SLANG_WARP_FULL_MASK, \
6358 (_getLaneId() + delta) % SLANG_CUDA_WARP_SIZE)); \
6372#ifdef SLANG_CUDA_ENABLE_HALF
6379 int2 intValue = make_int2((
int)value.x, (
int)value.y);
6381 return make_bool2((
bool)result.x, (
bool)result.y);
6386 int3 intValue = make_int3((
int)value.x, (
int)value.y, (
int)value.z);
6388 return make_bool3((
bool)result.x, (
bool)result.y, (
bool)result.z);
6393 int4 intValue = make_int4((
int)value.x, (
int)value.y, (
int)value.z, (
int)value.w);
6395 return make_bool4((
bool)result.x, (
bool)result.y, (
bool)result.z, (
bool)result.w);
6398#undef SLANG_WAVE_ROTATE_IMPL
6404 bool v0 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 0);
6405 bool v1 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 1);
6406 bool v2 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 2);
6407 bool v3 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 3);
6408 return v0 || v1 || v2 || v3;
6414 bool v0 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 0);
6415 bool v1 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 1);
6416 bool v2 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 2);
6417 bool v3 = __shfl_sync(0xFFFFFFFF, expr, (
_getLaneId() & 0xFFFFFFFC) | 3);
6418 return v0 && v1 && v2 && v3;
// SLANG_WAVE_CLUSTERED_ROTATE_IMPL(T): defines
// _slang_waveClusteredRotate(value, delta, clusterSize) for T, T##2, T##3
// and T##4.
//   value       - per-lane value (or vector) to exchange
//   delta       - rotation distance within the cluster
//   clusterSize - number of consecutive lanes forming one cluster
// Each overload computes
//   clusterStart = (laneId / clusterSize) * clusterSize
//   targetLane   = clusterStart + ((laneId - clusterStart + delta) % clusterSize)
// so the source lane always stays inside the caller's own cluster, and then
// reads `value` from targetLane via __shfl_sync with SLANG_WARP_FULL_MASK.
// Vector overloads shuffle each component separately and rebuild the vector
// with make_##T##N.
// NOTE(review): clusterSize is used as a divisor with no guard, so a value
// of 0 is not handled here.
6423#define SLANG_WAVE_CLUSTERED_ROTATE_IMPL(T) \
6424 __device__ __forceinline__ T \
6425 _slang_waveClusteredRotate(T value, unsigned int delta, unsigned int clusterSize) \
6427 unsigned int laneId = _getLaneId(); \
6428 unsigned int clusterStart = (laneId / clusterSize) * clusterSize; \
6429 unsigned int targetLane = clusterStart + ((laneId - clusterStart + delta) % clusterSize); \
6430 return __shfl_sync(SLANG_WARP_FULL_MASK, value, targetLane); \
6432 __device__ __forceinline__ \
6433 T##2 _slang_waveClusteredRotate(T##2 value, unsigned int delta, unsigned int clusterSize) \
6435 unsigned int laneId = _getLaneId(); \
6436 unsigned int clusterStart = (laneId / clusterSize) * clusterSize; \
6437 unsigned int targetLane = clusterStart + ((laneId - clusterStart + delta) % clusterSize); \
6438 return make_##T##2( \
6439 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, targetLane), \
6440 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, targetLane)); \
6442 __device__ __forceinline__ \
6443 T##3 _slang_waveClusteredRotate(T##3 value, unsigned int delta, unsigned int clusterSize) \
6445 unsigned int laneId = _getLaneId(); \
6446 unsigned int clusterStart = (laneId / clusterSize) * clusterSize; \
6447 unsigned int targetLane = clusterStart + ((laneId - clusterStart + delta) % clusterSize); \
6448 return make_##T##3( \
6449 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, targetLane), \
6450 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, targetLane), \
6451 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.z, targetLane)); \
6453 __device__ __forceinline__ \
6454 T##4 _slang_waveClusteredRotate(T##4 value, unsigned int delta, unsigned int clusterSize) \
6456 unsigned int laneId = _getLaneId(); \
6457 unsigned int clusterStart = (laneId / clusterSize) * clusterSize; \
6458 unsigned int targetLane = clusterStart + ((laneId - clusterStart + delta) % clusterSize); \
6459 return make_##T##4( \
6460 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.x, targetLane), \
6461 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.y, targetLane), \
6462 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.z, targetLane), \
6463 (T)__shfl_sync(SLANG_WARP_FULL_MASK, value.w, targetLane)); \
6477#ifdef SLANG_CUDA_ENABLE_HALF
6485 unsigned int clusterSize)
6487 int intValue = (int)value;
6489 return (
bool)result;
6492__device__ __forceinline__ bool2
6495 int2 intValue = make_int2((
int)value.x, (
int)value.y);
6497 return make_bool2((
bool)result.x, (
bool)result.y);
6500__device__ __forceinline__ bool3
6503 int3 intValue = make_int3((
int)value.x, (
int)value.y, (
int)value.z);
6505 return make_bool3((
bool)result.x, (
bool)result.y, (
bool)result.z);
6508__device__ __forceinline__ bool4
6511 int4 intValue = make_int4((
int)value.x, (
int)value.y, (
int)value.z, (
int)value.w);
6513 return make_bool4((
bool)result.x, (
bool)result.y, (
bool)result.z, (
bool)result.w);
6516#undef SLANG_WAVE_CLUSTERED_ROTATE_IMPL
6519#ifdef SLANG_CUDA_ENABLE_OPTIX
6521#if (OPTIX_VERSION >= 90000)
6526struct OptixCoopVecTraits;
6531#if defined(OPTIX_VERSION) && OPTIX_VERSION > 90000
6532template<
typename T,
unsigned int N>
6533struct OptixCoopVecTraits<OptixCoopVec<T, N>>
6535 static constexpr unsigned int size = N;
6542 OptixCoopVecElemType inputInterpretation,
6543 OptixCoopVecElemType matrixInterpretation,
6544 OptixCoopVecMatrixLayout matrixLayout>
6545__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
6546 const VecTIn& inputVector,
6548 unsigned matrixOffset,
6550 unsigned matrixStride)
6552 constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size;
6553 constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size;
6555 return optixCoopVecMatMul<
6558 inputInterpretation,
6563 matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride);
6570 OptixCoopVecElemType inputInterpretation,
6571 OptixCoopVecElemType matrixInterpretation,
6572 OptixCoopVecMatrixLayout matrixLayout,
6573 OptixCoopVecElemType biasInterpretation>
6574__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
6575 const VecTIn& inputVector,
6577 unsigned matrixOffset,
6579 unsigned biasOffset,
6580 unsigned matrixStride)
6582 constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size;
6583 constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size;
6586 return optixCoopVecMatMul<
6589 inputInterpretation,
6594 matrixInterpretation,
6595 biasInterpretation>(inputVector, matrix, matrixOffset, bias, biasOffset, matrixStride);
6603 OptixCoopVecElemType inputInterpretation,
6604 OptixCoopVecElemType matrixInterpretation,
6605 OptixCoopVecMatrixLayout matrixLayout>
6606__forceinline__ __device__ VecTOut slangOptixCoopVecMatMul(
6607 const VecTIn& inputVector,
6609 unsigned matrixOffset,
6610 unsigned matrixStride)
6612 constexpr unsigned N = OptixCoopVecTraits<VecTOut>::size;
6613 constexpr unsigned K = OptixCoopVecTraits<VecTIn>::size;
6616 return optixCoopVecMatMul<
6619 inputInterpretation,
6624 matrixInterpretation>(inputVector, matrix, matrixOffset, matrixStride);
6633#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) || \
6634 (CUDA_VERSION >= 12050)
6640namespace Slang_CUDA_WMMA
6643template<
typename A,
typename B>
6646 static constexpr bool value =
false;
6649struct IsSameType<A, A>
6651 static constexpr bool value =
true;
// RegisterCount<ElemT, M, N, K, use>::value is a compile-time count looked up
// per element type, tile shape and matrix role. WmmaFragment::fill uses it as
// the bound of its register loop. Only the specializations below exist; the
// primary template is declared but not defined.
6672template<
typename ElemT,
int M,
int N,
int K, MatrixUse use>
6673struct RegisterCount;
// half, 16x16x16: 4 for every role (A, B, C, D).
// NOTE(review): this section tests the feature macro with `#if`, while other
// parts of the file test the same macro with `#ifdef`.
6675#if SLANG_CUDA_ENABLE_HALF
6678struct RegisterCount<
half, 16, 16, 16, MatrixUse::MatrixA>
6680 static constexpr int value = 4;
6683struct RegisterCount<
half, 16, 16, 16, MatrixUse::MatrixB>
6685 static constexpr int value = 4;
6688struct RegisterCount<
half, 16, 16, 16, MatrixUse::MatrixC>
6690 static constexpr int value = 4;
6693struct RegisterCount<
half, 16, 16, 16, MatrixUse::MatrixD>
6695 static constexpr int value = 4;
// bfloat16, 16x16x16: 4 for the A and B roles only.
6699#if SLANG_CUDA_ENABLE_BF16
6704struct RegisterCount<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixA>
6706 static constexpr int value = 4;
6709struct RegisterCount<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixB>
6711 static constexpr int value = 4;
// float accumulators (C / D): 8 for any M, N, K.
6716template<
int M,
int N,
int K>
6717struct RegisterCount<float, M, N, K, MatrixUse::MatrixC>
6719 static constexpr int value = 8;
6721template<
int M,
int N,
int K>
6722struct RegisterCount<float, M, N, K, MatrixUse::MatrixD>
6724 static constexpr int value = 8;
// int32_t accumulators (C / D): 8 for any M, N, K.
6728template<
int M,
int N,
int K>
6729struct RegisterCount<int32_t, M, N, K, MatrixUse::MatrixC>
6731 static constexpr int value = 8;
6733template<
int M,
int N,
int K>
6734struct RegisterCount<int32_t, M, N, K, MatrixUse::MatrixD>
6736 static constexpr int value = 8;
// 8-bit integer operands (unsigned char / char), 16x16x16: 2 for A and B.
6741struct RegisterCount<unsigned char, 16, 16, 16, MatrixUse::MatrixA>
6743 static constexpr int value = 2;
6746struct RegisterCount<unsigned char, 16, 16, 16, MatrixUse::MatrixB>
6748 static constexpr int value = 2;
6753struct RegisterCount<char, 16, 16, 16, MatrixUse::MatrixA>
6755 static constexpr int value = 2;
6758struct RegisterCount<char, 16, 16, 16, MatrixUse::MatrixB>
6760 static constexpr int value = 2;
// FP8 operands (e4m3 / e5m2), 16x16x16: 2 for A and B.
6763#if SLANG_CUDA_ENABLE_FP8
6766struct RegisterCount<__nv_fp8_e4m3, 16, 16, 16, MatrixUse::MatrixA>
6768 static constexpr int value = 2;
6771struct RegisterCount<__nv_fp8_e4m3, 16, 16, 16, MatrixUse::MatrixB>
6773 static constexpr int value = 2;
6776struct RegisterCount<__nv_fp8_e5m2, 16, 16, 16, MatrixUse::MatrixA>
6778 static constexpr int value = 2;
6781struct RegisterCount<__nv_fp8_e5m2, 16, 16, 16, MatrixUse::MatrixB>
6783 static constexpr int value = 2;
6903template<
typename ElemT, Layout layout,
int Row,
int Col, MatrixUse use>
6904struct MMALoadHelper;
6906template<
typename ElemT, Layout layout>
6907struct MMALoadHelper<ElemT, layout, 16, 16, MatrixUse::MatrixA>
6909 static __device__
inline void exec(
6911 const ElemT* buffer,
6917 if constexpr (
sizeof(ElemT) == 1)
6922 const uint8_t* ubuf =
reinterpret_cast<const uint8_t*
>(buffer);
6923 if constexpr (layout == Layout::RowMajor)
6927 regs[0] = *
reinterpret_cast<const uint32_t*
>(&ubuf[gid * stride + 4 * tid]);
6928 regs[1] = *
reinterpret_cast<const uint32_t*
>(&ubuf[(gid + 8) * stride + 4 * tid]);
6937 for (
int e = 0; e < 4; e++)
6939 unsigned col = 4 * tid + e;
6940 uint32_t b0 = (uint32_t)ubuf[col * stride + gid];
6941 uint32_t b1 = (uint32_t)ubuf[col * stride + (gid + 8)];
6942 r0 |= b0 << (e * 8);
6943 r1 |= b1 << (e * 8);
6950 unsigned row = laneid >> 1;
6951 unsigned side = laneid & 1;
6952 uint4 loaded_v = *
reinterpret_cast<const uint4*
>(&buffer[row * stride + side * 8]);
6953 uint32_t* loaded =
reinterpret_cast<uint32_t*
>(&loaded_v);
6955 const uint32_t mask = 0xFFFFFFFF;
6956 if constexpr (layout == Layout::RowMajor)
6960 for (
int k = 0; k < 4; k++)
6962 tmp = __shfl_sync(mask, loaded[k], gid * 2);
6965 tmp = __shfl_sync(mask, loaded[k], (gid + 8) * 2);
6968 tmp = __shfl_sync(mask, loaded[k], gid * 2 + 1);
6971 tmp = __shfl_sync(mask, loaded[k], (gid + 8) * 2 + 1);
6978 unsigned k = gid >> 1;
6979 unsigned half_sel = gid & 1;
6983 for (
int ki = 0; ki < 4; ki++)
6985 s[ki][0] = __shfl_sync(mask, loaded[ki], tid * 4);
6986 s[ki][1] = __shfl_sync(mask, loaded[ki], tid * 4 + 1);
6987 s[ki][2] = __shfl_sync(mask, loaded[ki], tid * 4 + 2);
6988 s[ki][3] = __shfl_sync(mask, loaded[ki], tid * 4 + 3);
6989 s[ki][4] = __shfl_sync(mask, loaded[ki], tid * 4 + 16);
6990 s[ki][5] = __shfl_sync(mask, loaded[ki], tid * 4 + 17);
6991 s[ki][6] = __shfl_sync(mask, loaded[ki], tid * 4 + 18);
6992 s[ki][7] = __shfl_sync(mask, loaded[ki], tid * 4 + 19);
6995 unsigned shift = half_sel * 16;
6996 uint16_t h0 = (uint16_t)(s[k][0] >> shift);
6997 uint16_t h1 = (uint16_t)(s[k][2] >> shift);
6998 regs[0] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7000 h0 = (uint16_t)(s[k][1] >> shift);
7001 h1 = (uint16_t)(s[k][3] >> shift);
7002 regs[1] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7004 h0 = (uint16_t)(s[k][4] >> shift);
7005 h1 = (uint16_t)(s[k][6] >> shift);
7006 regs[2] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7008 h0 = (uint16_t)(s[k][5] >> shift);
7009 h1 = (uint16_t)(s[k][7] >> shift);
7010 regs[3] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7015template<
typename ElemT, Layout layout>
7016struct MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixB>
7018 static __device__
inline void exec(
7020 const ElemT* buffer,
7026 if constexpr (
sizeof(ElemT) == 1)
7031 const uint8_t* ubuf =
reinterpret_cast<const uint8_t*
>(buffer);
7032 if constexpr (layout == Layout::ColMajor)
7036 regs[0] = *
reinterpret_cast<const uint32_t*
>(&ubuf[gid * stride + 4 * tid]);
7044 for (
int e = 0; e < 4; e++)
7046 unsigned row = 4 * tid + e;
7047 r0 |= ((uint32_t)ubuf[row * stride + gid]) << (e * 8);
7054 if constexpr (layout == Layout::ColMajor)
7056 unsigned col = laneid >> 1;
7057 unsigned side = laneid & 1;
7058 loaded_v = *
reinterpret_cast<const uint4*
>(&buffer[col * stride + side * 8]);
7062 unsigned row = laneid & 15;
7063 loaded_v = *
reinterpret_cast<const uint4*
>(&buffer[row * stride]);
7065 uint32_t* loaded =
reinterpret_cast<uint32_t*
>(&loaded_v);
7067 const uint32_t mask = 0xFFFFFFFF;
7068 if constexpr (layout == Layout::ColMajor)
7072 for (
int k = 0; k < 4; k++)
7074 tmp = __shfl_sync(mask, loaded[k], gid * 2);
7077 tmp = __shfl_sync(mask, loaded[k], gid * 2 + 1);
7086 for (
int ki = 0; ki < 4; ki++)
7088 s[ki][0] = __shfl_sync(mask, loaded[ki], tid * 2);
7089 s[ki][1] = __shfl_sync(mask, loaded[ki], tid * 2 + 1);
7090 s[ki][2] = __shfl_sync(mask, loaded[ki], tid * 2 + 8);
7091 s[ki][3] = __shfl_sync(mask, loaded[ki], tid * 2 + 9);
7094 unsigned k = gid >> 1;
7095 unsigned shift = (gid & 1) * 16;
7097 uint16_t h0 = (uint16_t)(s[k][0] >> shift);
7098 uint16_t h1 = (uint16_t)(s[k][1] >> shift);
7099 regs[0] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7101 h0 = (uint16_t)(s[k][2] >> shift);
7102 h1 = (uint16_t)(s[k][3] >> shift);
7103 regs[1] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7108template<
typename ElemT, Layout layout>
7109struct MMALoadHelper<ElemT, layout, 16, 16, MatrixUse::MatrixB>
7111 static __device__
inline void exec(
7113 const ElemT* buffer,
7120 constexpr int regsPerSubTile = (
sizeof(ElemT) == 1) ? 1 : 2;
7121 MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixB>::exec(
7128 if constexpr (layout == Layout::RowMajor)
7129 MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixB>::exec(
7130 regs + regsPerSubTile,
7137 MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixB>::exec(
7138 regs + regsPerSubTile,
7139 buffer + 8 * stride,
7147template<
typename ElemT, Layout layout>
7148struct MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixC>
7150 static __device__
inline void exec(
7152 const ElemT* buffer,
7158 if constexpr (
sizeof(ElemT) == 4)
7160 const float* fbuf =
reinterpret_cast<const float*
>(buffer);
7162 if constexpr (layout == Layout::RowMajor)
7164 unsigned row = laneid >> 1;
7165 unsigned side = laneid & 1;
7166 loaded_v = *
reinterpret_cast<const uint4*
>(&fbuf[row * stride + side * 4]);
7170 unsigned col = laneid >> 2;
7171 unsigned chunk = laneid & 3;
7172 loaded_v = *
reinterpret_cast<const uint4*
>(&fbuf[col * stride + chunk * 4]);
7174 uint32_t* loaded =
reinterpret_cast<uint32_t*
>(&loaded_v);
7176 const uint32_t mask = 0xFFFFFFFF;
7177 if constexpr (layout == Layout::RowMajor)
7180 unsigned kb = (tid & 1) * 2;
7181 unsigned sb = (tid >> 1) * 2;
7183 for (
int k = 0; k < 4; k++)
7186 for (
int j = 0; j < 4; j++)
7188 unsigned srcLane = (j < 2) ? ((j == 0) ? gid * 2 : (gid + 8) * 2)
7189 : ((j == 2) ? gid * 2 + 1 : (gid + 8) * 2 + 1);
7190 tmp = __shfl_sync(mask, loaded[k], srcLane);
7191 if (k == kb && j == sb)
7193 if (k == kb + 1 && j == sb)
7195 if (k == kb && j == sb + 1)
7197 if (k == kb + 1 && j == sb + 1)
7205 unsigned k = gid & 3;
7207 for (
int ki = 0; ki < 4; ki++)
7209 tmp = __shfl_sync(mask, loaded[ki], tid * 8 + gid / 4);
7212 tmp = __shfl_sync(mask, loaded[ki], tid * 8 + 4 + gid / 4);
7215 tmp = __shfl_sync(mask, loaded[ki], tid * 8 + 2 + gid / 4);
7218 tmp = __shfl_sync(mask, loaded[ki], tid * 8 + 6 + gid / 4);
7227 if constexpr (layout == Layout::RowMajor)
7229 unsigned row = laneid & 15;
7230 loaded_v = *
reinterpret_cast<const uint4*
>(&buffer[row * stride]);
7234 unsigned col = (laneid & 15) >> 1;
7235 unsigned side = laneid & 1;
7236 loaded_v = *
reinterpret_cast<const uint4*
>(&buffer[col * stride + side * 8]);
7238 uint32_t* loaded =
reinterpret_cast<uint32_t*
>(&loaded_v);
7240 const uint32_t mask = 0xFFFFFFFF;
7241 if constexpr (layout == Layout::RowMajor)
7245 for (
int k = 0; k < 4; k++)
7247 tmp = __shfl_sync(mask, loaded[k], gid);
7250 tmp = __shfl_sync(mask, loaded[k], gid + 8);
7259 for (
int ki = 0; ki < 4; ki++)
7261 s[ki][0] = __shfl_sync(mask, loaded[ki], tid * 4);
7262 s[ki][1] = __shfl_sync(mask, loaded[ki], tid * 4 + 1);
7263 s[ki][2] = __shfl_sync(mask, loaded[ki], tid * 4 + 2);
7264 s[ki][3] = __shfl_sync(mask, loaded[ki], tid * 4 + 3);
7267 unsigned k = gid >> 1;
7268 unsigned shift = (gid & 1) * 16;
7270 uint16_t h0 = (uint16_t)(s[k][0] >> shift);
7271 uint16_t h1 = (uint16_t)(s[k][2] >> shift);
7272 regs[0] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7274 h0 = (uint16_t)(s[k][1] >> shift);
7275 h1 = (uint16_t)(s[k][3] >> shift);
7276 regs[1] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7282template<
typename ElemT, Layout layout>
7283struct MMALoadHelper<ElemT, layout, 16, 16, MatrixUse::MatrixC>
7285 static __device__
inline void exec(
7287 const ElemT* buffer,
7293 constexpr int regsPerSubTile = (
sizeof(ElemT) == 4) ? 4 : 2;
7294 MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixC>::exec(
7301 if constexpr (layout == Layout::RowMajor)
7302 MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixC>::exec(
7303 regs + regsPerSubTile,
7310 MMALoadHelper<ElemT, layout, 16, 8, MatrixUse::MatrixC>::exec(
7311 regs + regsPerSubTile,
7312 buffer + 8 * stride,
7320template<
typename ElemT, Layout layout,
int Row,
int Col>
7321struct MMALoadHelper<ElemT, layout,
Row, Col, MatrixUse::MatrixD>
7322 : MMALoadHelper<ElemT, layout, Row, Col, MatrixUse::MatrixC>
// mmaLoad: entry point that fills a fragment's register array from memory.
// It reads the calling thread's lane id, derives the two lane sub-indices
// the helpers expect, and forwards to the MMALoadHelper specialization
// selected by <ElemT, layout, Row, Col, use>.
//   regs   - destination register array, passed through to the helper
//   ptr    - untyped source pointer, reinterpreted as const ElemT*
//   stride - passed through unchanged; the helpers multiply a row or column
//            index by it when indexing the buffer
7326template<
typename ElemT, Layout layout,
int Row,
int Col, MatrixUse use>
7327__device__
inline void mmaLoad(uint32_t* regs,
const void* ptr,
int stride)
7329 const ElemT* buffer =
static_cast<const ElemT*
>(ptr);
// Read the lane id straight from the PTX %laneid special register.
7331 asm(
"mov.u32 %0, %%laneid;" :
"=r"(laneid));
// gid = laneid / 4, tid = laneid % 4.
7332 unsigned gid = laneid >> 2;
7333 unsigned tid = laneid & 3;
7334 MMALoadHelper<ElemT, layout, Row, Col, use>::exec(regs, buffer, stride, laneid, gid, tid);
7341template<
typename ElemT, Layout layout,
int Row,
int Col>
7342struct MMAStoreHelper;
7344template<
typename ElemT, Layout layout>
7345struct MMAStoreHelper<ElemT, layout, 16, 8>
7347 static __device__
inline void exec(
7349 const uint32_t* regs,
7353 if constexpr (
sizeof(ElemT) == 4)
7355 float* fbuf =
reinterpret_cast<float*
>(buffer);
7356 if constexpr (layout == Layout::RowMajor)
7358 unsigned write_row = laneid >> 1;
7359 unsigned write_side = laneid & 1;
7360 unsigned source_gid = (write_row < 8) ? write_row : (write_row - 8);
7361 unsigned src0 = source_gid * 4 + write_side * 2;
7362 unsigned src1 = src0 + 1;
7364 const uint32_t mask = 0xFFFFFFFF;
7365 uint32_t r0_s0 = __shfl_sync(mask, regs[0], src0);
7366 uint32_t r1_s0 = __shfl_sync(mask, regs[1], src0);
7367 uint32_t r0_s1 = __shfl_sync(mask, regs[0], src1);
7368 uint32_t r1_s1 = __shfl_sync(mask, regs[1], src1);
7369 uint32_t r2_s0 = __shfl_sync(mask, regs[2], src0);
7370 uint32_t r3_s0 = __shfl_sync(mask, regs[3], src0);
7371 uint32_t r2_s1 = __shfl_sync(mask, regs[2], src1);
7372 uint32_t r3_s1 = __shfl_sync(mask, regs[3], src1);
7375 uint32_t* out =
reinterpret_cast<uint32_t*
>(&out_v);
7391 *
reinterpret_cast<uint4*
>(&fbuf[write_row * stride + write_side * 4]) = out_v;
7395 unsigned write_col = laneid >> 2;
7396 unsigned write_chunk = laneid & 3;
7397 unsigned source_tid_val = write_col >> 1;
7398 unsigned gid_base = (write_chunk & 1) * 4;
7399 unsigned reg_idx = (write_chunk < 2) ? (write_col & 1) : (2 + (write_col & 1));
7401 const uint32_t mask = 0xFFFFFFFF;
7404 for (
int r = 0; r < 4; r++)
7406 s[r][0] = __shfl_sync(mask, regs[r], (gid_base + 0) * 4 + source_tid_val);
7407 s[r][1] = __shfl_sync(mask, regs[r], (gid_base + 1) * 4 + source_tid_val);
7408 s[r][2] = __shfl_sync(mask, regs[r], (gid_base + 2) * 4 + source_tid_val);
7409 s[r][3] = __shfl_sync(mask, regs[r], (gid_base + 3) * 4 + source_tid_val);
7413 uint32_t* out =
reinterpret_cast<uint32_t*
>(&out_v);
7414 out[0] = s[reg_idx][0];
7415 out[1] = s[reg_idx][1];
7416 out[2] = s[reg_idx][2];
7417 out[3] = s[reg_idx][3];
7419 *
reinterpret_cast<uint4*
>(&fbuf[write_col * stride + write_chunk * 4]) = out_v;
7424 if constexpr (layout == Layout::RowMajor)
7426 unsigned write_row = laneid & 15;
7428 const uint32_t mask = 0xFFFFFFFF;
7431 for (
int k = 0; k < 4; k++)
7433 s[k][0] = __shfl_sync(mask, regs[0], write_row * 4 + k);
7434 s[k][1] = __shfl_sync(mask, regs[1], write_row * 4 + k);
7438 uint32_t* out =
reinterpret_cast<uint32_t*
>(&out_v);
7454 *
reinterpret_cast<uint4*
>(&buffer[write_row * stride]) = out_v;
7458 unsigned write_col = (laneid & 15) >> 1;
7459 unsigned write_side = laneid & 1;
7460 unsigned source_tid_val = write_col >> 1;
7461 unsigned col_shift = (write_col & 1) * 16;
7463 const uint32_t mask = 0xFFFFFFFF;
7464 uint32_t from_r0[8], from_r1[8];
7466 for (
int k = 0; k < 4; k++)
7468 unsigned src_even = (2 * k) * 4 + source_tid_val;
7469 unsigned src_odd = (2 * k + 1) * 4 + source_tid_val;
7470 from_r0[2 * k] = __shfl_sync(mask, regs[0], src_even);
7471 from_r0[2 * k + 1] = __shfl_sync(mask, regs[0], src_odd);
7472 from_r1[2 * k] = __shfl_sync(mask, regs[1], src_even);
7473 from_r1[2 * k + 1] = __shfl_sync(mask, regs[1], src_odd);
7477 uint32_t* out =
reinterpret_cast<uint32_t*
>(&out_v);
7479 for (
int k = 0; k < 4; k++)
7481 uint32_t val_even = (write_side == 0) ? from_r0[2 * k] : from_r1[2 * k];
7482 uint32_t val_odd = (write_side == 0) ? from_r0[2 * k + 1] : from_r1[2 * k + 1];
7483 uint16_t h0 = (uint16_t)(val_even >> col_shift);
7484 uint16_t h1 = (uint16_t)(val_odd >> col_shift);
7485 out[k] = (uint32_t)h0 | ((uint32_t)h1 << 16);
7488 *
reinterpret_cast<uint4*
>(&buffer[write_col * stride + write_side * 8]) = out_v;
7494template<
typename ElemT, Layout layout>
7495struct MMAStoreHelper<ElemT, layout, 16, 16>
7497 static __device__
inline void exec(
7499 const uint32_t* regs,
7516 "CUDA `Store` on a 16x16 cooperative-matrix fragment with an 8-bit "
7517 "element type (s8 / u8 / e4m3 / e5m2) is not implemented: it would "
7518 "read past the fragment's register array. Use `.equals()` against "
7519 "a known-expected fragment to verify content (see int8-arith.slang) "
7520 "until MatrixA / MatrixB Store is rewritten.");
7521 constexpr int regsPerSubTile = (
sizeof(ElemT) == 4) ? 4 : 2;
7522 MMAStoreHelper<ElemT, layout, 16, 8>::exec(buffer, regs, stride, laneid);
7523 if constexpr (layout == Layout::RowMajor)
7524 MMAStoreHelper<ElemT, layout, 16, 8>::exec(
7526 regs + regsPerSubTile,
7530 MMAStoreHelper<ElemT, layout, 16, 8>::exec(
7531 buffer + 8 * stride,
7532 regs + regsPerSubTile,
// mmaStore: entry point that writes a fragment's register array to memory.
// It reads the calling thread's lane id and forwards to the MMAStoreHelper
// specialization selected by <ElemT, layout, Row, Col>.
//   ptr    - untyped destination pointer, reinterpreted as ElemT*
//   regs   - source register array (read-only), passed through to the helper
//   stride - passed through unchanged; the helpers multiply a row or column
//            index by it when indexing the buffer
7538template<
typename ElemT, Layout layout,
int Row,
int Col>
7539__device__
inline void mmaStore(
void* ptr,
const uint32_t* regs,
int stride)
7541 ElemT* buffer =
static_cast<ElemT*
>(ptr);
// Read the lane id straight from the PTX %laneid special register.
7543 asm(
"mov.u32 %0, %%laneid;" :
"=r"(laneid));
7544 MMAStoreHelper<ElemT, layout, Row, Col>::exec(buffer, regs, stride, laneid);
7562struct PackedFp16Traits;
7564#if SLANG_CUDA_ENABLE_HALF
// PackedFp16Traits<half>: names the two-element packed type for `half`
// (__half2) and provides broadcast(), which duplicates one scalar into both
// elements of the pair via __half2half2.
7566struct PackedFp16Traits<
half>
7568 using PairType = __half2;
7569 static __device__ PairType broadcast(
half v) {
return __half2half2(v); }
7573#if SLANG_CUDA_ENABLE_BF16
// PackedFp16Traits<__nv_bfloat16>: names the two-element packed type for
// bfloat16 (__nv_bfloat162) and provides broadcast(), which duplicates one
// scalar into both elements of the pair via __bfloat162bfloat162.
7575struct PackedFp16Traits<__nv_bfloat16>
7577 using PairType = __nv_bfloat162;
7578 static __device__ PairType broadcast(__nv_bfloat16 v) {
return __bfloat162bfloat162(v); }
7585 static constexpr bool value =
false;
7587#if SLANG_CUDA_ENABLE_HALF
// IsPackedFp16<half>: marks `half` as a type for which WmmaFragment's
// arithmetic takes the `if constexpr (IsPackedFp16<T>::value)` branch.
7589struct IsPackedFp16<
half>
7591 static constexpr bool value =
true;
7594#if SLANG_CUDA_ENABLE_BF16
// IsPackedFp16<__nv_bfloat16>: marks bfloat16 as a type for which
// WmmaFragment's arithmetic takes the `if constexpr (IsPackedFp16<T>::value)`
// branch.
7596struct IsPackedFp16<__nv_bfloat16>
7598 static constexpr bool value =
true;
7603inline unsigned __device__ Pack32Helper(T value);
7605#if SLANG_CUDA_ENABLE_HALF
// Pack32Helper<half>: returns a 32-bit word whose low and high 16-bit halves
// both hold the bit pattern of `value`.
7607inline unsigned __device__ Pack32Helper<half>(
half value)
// Widen to unsigned before shifting. The raw `unsigned short` would be
// promoted to signed int, and a half with bit 15 set (any negative value)
// would then be shifted into the sign bit. This matches the explicit unsigned
// arithmetic used by the __nv_bfloat16 specialization.
7609 return (unsigned)__half_as_ushort(value) | ((unsigned)__half_as_ushort(value) << 16);
7613#if SLANG_CUDA_ENABLE_BF16
// Pack32Helper<__nv_bfloat16>: returns a 32-bit word whose low and high
// 16-bit halves both hold the bit pattern of `value`.
7615inline unsigned __device__ Pack32Helper<__nv_bfloat16>(__nv_bfloat16 value)
7617 unsigned short bits = __bfloat16_as_ushort(value);
// Cast to unsigned before the shift so the arithmetic stays unsigned.
7618 return (
unsigned)bits | ((unsigned)bits << 16);
// Pack32Helper<float>: a float already occupies 32 bits, so the result is
// simply its bit pattern reinterpreted as unsigned (no numeric conversion).
7623inline unsigned __device__ Pack32Helper<float>(
float value)
7625 return __float_as_uint(value);
// Pack32Helper<int>: converts the int to unsigned with a plain cast.
7629inline unsigned __device__ Pack32Helper<int>(
int value)
7631 return (
unsigned)
value;
// Pack32Helper<char>: replicates the 8-bit pattern of `value` into all four
// bytes of a 32-bit word.
7634inline unsigned __device__ Pack32Helper<char>(
char value)
// Go through unsigned char first so a negative char is not sign-extended
// when it is widened to unsigned.
7639 unsigned bits = (unsigned)(
unsigned char)
value;
7640 return (bits << 24) | (bits << 16) | (bits << 8) | bits;
// Pack32Helper<unsigned char>: replicates the 8-bit value into all four bytes
// of a 32-bit word.
7643inline unsigned __device__ Pack32Helper<unsigned char>(
unsigned char value)
7645 unsigned bits = (unsigned)value;
7646 return (bits << 24) | (bits << 16) | (bits << 8) | bits;
#if SLANG_CUDA_ENABLE_FP8
// fp8 (e4m3): the first byte of the object is replicated into all four byte lanes.
template<>
inline unsigned __device__ Pack32Helper<__nv_fp8_e4m3>(__nv_fp8_e4m3 value)
{
    const uint8_t* const raw = reinterpret_cast<const uint8_t*>(&value);
    const unsigned byte = static_cast<unsigned>(*raw);
    return byte * 0x01010101u;
}

// fp8 (e5m2): the first byte of the object is replicated into all four byte lanes.
template<>
inline unsigned __device__ Pack32Helper<__nv_fp8_e5m2>(__nv_fp8_e5m2 value)
{
    const uint8_t* const raw = reinterpret_cast<const uint8_t*>(&value);
    const unsigned byte = static_cast<unsigned>(*raw);
    return byte * 0x01010101u;
}
#endif
7672template<
typename T,
int M,
int N,
int K, MatrixUse R>
7675 __device__ WmmaFragment() {}
7676 __device__ WmmaFragment(T scalarValue) { fill(scalarValue); }
7678 typedef WmmaFragment<T, M, N, K, R> This;
7679 template<Layout layout>
7682 Store<layout>(buffer.
data, element, stride);
7685 template<Layout layout>
7688 return Load<layout>(buffer.
data, element, stride);
// Fills the fragment from a single scalar.
// `value` is first expanded to a 32-bit word by Pack32Helper, then a loop
// runs over the RegisterCount<T, M, N, K, R>::value registers.
// NOTE(review): the loop body is not visible in this listing; presumably it
// stores `packed` into each register - confirm against the full source.
7693 void __device__ fill(T value)
7695 unsigned packed = Pack32Helper(value);
7696 constexpr int nregs = RegisterCount<T, M, N, K, R>::value;
7698 for (
int i = 0; i < nregs; i++)
// Loops over all RegsCount registers of the fragment.
// NOTE(review): the loop body is not visible in this listing; by its name the
// function presumably zeroes each register - confirm against the full source.
7705 void __device__ clear()
7708 for (
int i = 0; i < RegsCount; i++)
7715 if constexpr (IsPackedFp16<T>::value)
7717 using PairT =
typename PackedFp16Traits<T>::PairType;
7718 PairT bv = PackedFp16Traits<T>::broadcast(b);
7720 for (
int i = 0; i < RegsCount; i++)
7722 PairT r = *
reinterpret_cast<const PairT*
>(®s[i]) * bv;
7723 memcpy(&result.regs[i], &r, 4);
7728 for (
int i = 0; i < GetLength(); i++)
7729 result.set(i,
get(i) * b);
// Element-wise product of this fragment and `b`; returns a new fragment.
//
// For 16-bit floating-point element types each 32-bit register holds two
// elements, so the product is computed one register (one packed pair) at a
// time. All other element types are multiplied element by element.
__device__ This operator*(const This& b)
{
    This result;
    if constexpr (IsPackedFp16<T>::value)
    {
        using PairT = typename PackedFp16Traits<T>::PairType;
#pragma unroll
        for (int i = 0; i < RegsCount; i++)
        {
            // Move register bits in and out of the pair type with memcpy
            // rather than casting the register pointer: same bits, but no
            // read of an `unsigned` object through an unrelated type.
            PairT lhs, rhs;
            memcpy(&lhs, &regs[i], 4);
            memcpy(&rhs, &b.regs[i], 4);
            PairT r = lhs * rhs;
            memcpy(&result.regs[i], &r, 4);
        }
    }
    else
    {
        for (int i = 0; i < GetLength(); i++)
            result.set(i, get(i) * b.get(i));
    }
    return result;
}
// Element-wise quotient of this fragment by `other`; returns a new fragment.
//
// For 16-bit floating-point element types each 32-bit register holds two
// elements, so the division is computed one register (one packed pair) at a
// time. All other element types are divided element by element.
__device__ This operator/(const This& other)
{
    This result;
    if constexpr (IsPackedFp16<T>::value)
    {
        using PairT = typename PackedFp16Traits<T>::PairType;
#pragma unroll
        for (int i = 0; i < RegsCount; i++)
        {
            // memcpy instead of a pointer cast: same bits, but no read of an
            // `unsigned` object through an unrelated type.
            PairT lhs, rhs;
            memcpy(&lhs, &regs[i], 4);
            memcpy(&rhs, &other.regs[i], 4);
            PairT r = lhs / rhs;
            memcpy(&result.regs[i], &r, 4);
        }
    }
    else
    {
        for (int i = 0; i < GetLength(); i++)
            result.set(i, get(i) / other.get(i));
    }
    return result;
}
// Element-wise difference (this - other); returns a new fragment.
//
// For 16-bit floating-point element types each 32-bit register holds two
// elements, so the subtraction is computed one register (one packed pair) at
// a time. All other element types are subtracted element by element.
__device__ This operator-(const This& other)
{
    This result;
    if constexpr (IsPackedFp16<T>::value)
    {
        using PairT = typename PackedFp16Traits<T>::PairType;
#pragma unroll
        for (int i = 0; i < RegsCount; i++)
        {
            // memcpy instead of a pointer cast: same bits, but no read of an
            // `unsigned` object through an unrelated type.
            PairT lhs, rhs;
            memcpy(&lhs, &regs[i], 4);
            memcpy(&rhs, &other.regs[i], 4);
            PairT r = lhs - rhs;
            memcpy(&result.regs[i], &r, 4);
        }
    }
    else
    {
        for (int i = 0; i < GetLength(); i++)
            result.set(i, get(i) - other.get(i));
    }
    return result;
}
// Element-wise negation; returns a new fragment.
//
// For 16-bit floating-point element types each 32-bit register holds two
// elements, so the negation is applied one register (one packed pair) at a
// time. All other element types are negated element by element.
__device__ This operator-()
{
    This result;
    if constexpr (IsPackedFp16<T>::value)
    {
        using PairT = typename PackedFp16Traits<T>::PairType;
#pragma unroll
        for (int i = 0; i < RegsCount; i++)
        {
            // memcpy instead of a pointer cast: same bits, but no read of an
            // `unsigned` object through an unrelated type.
            PairT operand;
            memcpy(&operand, &regs[i], 4);
            PairT r = -operand;
            memcpy(&result.regs[i], &r, 4);
        }
    }
    else
    {
        for (int i = 0; i < GetLength(); i++)
            result.set(i, -get(i));
    }
    return result;
}
// Element-wise sum of this fragment and `other`; returns a new fragment.
//
// For 16-bit floating-point element types each 32-bit register holds two
// elements, so the addition is computed one register (one packed pair) at a
// time. All other element types are added element by element.
__device__ This operator+(const This& other)
{
    This result;
    if constexpr (IsPackedFp16<T>::value)
    {
        using PairT = typename PackedFp16Traits<T>::PairType;
#pragma unroll
        for (int i = 0; i < RegsCount; i++)
        {
            // memcpy instead of a pointer cast: same bits, but no read of an
            // `unsigned` object through an unrelated type.
            PairT lhs, rhs;
            memcpy(&lhs, &regs[i], 4);
            memcpy(&rhs, &other.regs[i], 4);
            PairT r = lhs + rhs;
            memcpy(&result.regs[i], &r, 4);
        }
    }
    else
    {
        for (int i = 0; i < GetLength(); i++)
            result.set(i, get(i) + other.get(i));
    }
    return result;
}
// Element-wise remainder (this % other); returns a new fragment.
//
// 16-bit floating-point element types are converted to float, reduced with
// fmodf and converted back to T. Every other element type uses the built-in
// % operator.
__device__ This operator%(const This& other)
{
    This result;
    for (int i = 0; i < GetLength(); i++)
    {
        if constexpr (IsPackedFp16<T>::value)
        {
            const float dividend = static_cast<float>(get(i));
            const float divisor = static_cast<float>(other.get(i));
            result.set(i, T(fmodf(dividend, divisor)));
        }
        else
        {
            result.set(i, get(i) % other.get(i));
        }
    }
    return result;
}
// Equality comparison against another fragment of the same type.
// Visits the GetLength() per-thread elements in index order and tests each
// pair with !=.
// NOTE(review): the return statements are not visible in this listing;
// presumably a mismatch yields false and a completed loop yields true -
// confirm against the full source.
7868 __device__
bool operator==(
const This& other)
const
7870 for (
int i = 0; i < GetLength(); i++)
7872 if (
get(i) != other.get(i))
// Ordering comparison against another fragment of the same type.
// Visits the GetLength() per-thread elements in index order; for each index
// it first tests "this < other" and then "this > other".
// NOTE(review): the return statements are not visible in this listing, so the
// result of each branch (and of a fully equal fragment) must be confirmed
// against the full source.
7881 __device__
bool operator<(
const This& other)
const
7883 for (
int i = 0; i < GetLength(); i++)
7885 if (
get(i) < other.get(i))
7887 if (
get(i) > other.get(i))
// Ordering comparison against another fragment of the same type.
// Visits the GetLength() per-thread elements in index order; for each index
// it first tests "this < other" and then "this > other".
// NOTE(review): the return statements are not visible in this listing, so the
// result of each branch (and of a fully equal fragment) must be confirmed
// against the full source.
7893 __device__
bool operator<=(
const This& other)
const
7895 for (
int i = 0; i < GetLength(); i++)
7897 if (
get(i) < other.get(i))
7899 if (
get(i) > other.get(i))
// Copies the contents of `other` into this fragment.
//
// `other` has the same M/N/K shape but may differ in element type (U) and
// matrix use (R2).
//  - Same element type and same register count: the registers are copied
//    verbatim.
//  - Otherwise: each element is read from `other`, converted with
//    static_cast<T>, and written with set().
template<typename U, MatrixUse R2>
__device__ void copyFrom(const WmmaFragment<U, M, N, K, R2>& other)
{
    constexpr int OtherRegsCount = WmmaFragment<U, M, N, K, R2>::RegsCount;
    if constexpr (IsSameType<T, U>::value && RegsCount == OtherRegsCount)
    {
        // Copy register by register. `regs` is a plain `unsigned` array and is
        // therefore only guaranteed to be 4-byte aligned, so it must not be
        // accessed through uint2*/uint4* (which require 8/16-byte alignment).
        // The trip count is a compile-time constant, so the loop fully unrolls.
#pragma unroll
        for (int i = 0; i < RegsCount; i++)
            regs[i] = other.regs[i];
    }
    else
    {
        for (int i = 0; i < GetLength(); i++)
            set(i, static_cast<T>(other.get(i)));
    }
}
// Returns the element at position `index` of this thread's share of the fragment.
//
// Elements are packed into 32-bit registers, lowest bits first:
//   4-byte T: one element per register,
//   2-byte T: two per register (16-bit fields),
//   1-byte T: four per register (8-bit fields).
// `index` is not range-checked.
__device__ T get(int index) const
{
    // Any other element size used to fall off the end without returning a value.
    static_assert(
        sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1,
        "WmmaFragment::get supports 1-, 2- and 4-byte element types only");

    T v;
    if constexpr (sizeof(T) == 4)
    {
        memcpy(&v, &regs[index], 4);
    }
    else if constexpr (sizeof(T) == 2)
    {
        int regIndex = index / 2;
        int elementOffset = index % 2;
        int bitOffset = elementOffset * 16;
        uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFFFF;
        uint16_t value16 = static_cast<uint16_t>(extracted);
        memcpy(&v, &value16, 2);
    }
    else
    {
        int regIndex = index / 4;
        int elementOffset = index % 4;
        int bitOffset = elementOffset * 8;
        uint32_t extracted = (regs[regIndex] >> bitOffset) & 0xFF;
        uint8_t value8 = static_cast<uint8_t>(extracted);
        // Same byte as before, copied with memcpy like the other branches
        // instead of reading the uint8_t through a T pointer.
        memcpy(&v, &value8, 1);
    }
    return v;
}
// Writes `value` to position `index` of this thread's share of the fragment.
//
// Elements are packed into 32-bit registers, lowest bits first:
//   4-byte T: one element per register,
//   2-byte T: two per register (16-bit fields),
//   1-byte T: four per register (8-bit fields).
// For packed types only the addressed field of the register is changed.
// `index` is not range-checked.
__device__ void set(int index, T value)
{
    // Any other element size used to be silently ignored.
    static_assert(
        sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1,
        "WmmaFragment::set supports 1-, 2- and 4-byte element types only");

    if constexpr (sizeof(T) == 4)
    {
        memcpy(&regs[index], &value, 4);
    }
    else if constexpr (sizeof(T) == 2)
    {
        int regIndex = index / 2;
        int elementOffset = index % 2;
        int bitOffset = elementOffset * 16;
        uint32_t mask = 0xFFFF;
        uint16_t value16;
        memcpy(&value16, &value, 2);

        // Clear the target field, then merge the new bits in.
        regs[regIndex] &= ~(mask << bitOffset);
        regs[regIndex] |= (static_cast<uint32_t>(value16) << bitOffset);
    }
    else
    {
        int regIndex = index / 4;
        int elementOffset = index % 4;
        int bitOffset = elementOffset * 8;
        uint32_t mask = 0xFF;
        uint8_t value8 = *reinterpret_cast<const uint8_t*>(&value);

        // Clear the target field, then merge the new bits in.
        regs[regIndex] &= ~(mask << bitOffset);
        regs[regIndex] |= (static_cast<uint32_t>(value8) << bitOffset);
    }
}
// Overwrites the raw 32-bit register `regIndex` with `value` (no range check).
__device__ void FragmentWrite(int regIndex, unsigned value)
{
    regs[regIndex] = value;
}
// Returns the raw 32-bit register `regIndex` (no range check).
__device__ unsigned FragmentRead(int regIndex) const
{
    const unsigned raw = regs[regIndex];
    return raw;
}
// Applies the PTX instruction movmatrix.sync.aligned.m8n8.trans.b16 to each
// of the four fragment registers, producing t0..t3.
// The work is compiled in only when the fragment has exactly four registers
// and is used as matrix A or matrix B; for any other fragment the visible
// code does nothing.
// The instruction is a .sync.aligned warp operation.
// NOTE(review): presumably the whole warp must call this together - confirm.
// NOTE(review): the statements that consume t0..t3 are not visible in this
// listing; confirm how the results are written back against the full source.
8028 __device__
void ChangeMajor()
8030 if constexpr (RegsCount == 4 && (
R == MatrixUse::MatrixA ||
R == MatrixUse::MatrixB))
8032 uint32_t t0, t1, t2, t3;
// One movmatrix per register: regs[k] is the source, tk the destination.
8033 asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" :
"=r"(t0) :
"r"(regs[0]));
8034 asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" :
"=r"(t1) :
"r"(regs[1]));
8035 asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" :
"=r"(t2) :
"r"(regs[2]));
8036 asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" :
"=r"(t3) :
"r"(regs[3]));
// Stores this fragment to `buffer`, starting `element` elements into it, via
// mmaStore. `stride` is forwarded unchanged.
template<Layout layout>
void __device__ Store(T* buffer, uint element, uint stride)
{
    // Names RegisterCount for this configuration; the value itself is unused.
    (void)RegisterCount<T, M, N, K, R>::value;

    T* const tileStart = buffer + element;
    mmaStore<T, layout, M, N>(tileStart, regs, stride);
}
// Stores this fragment through a pointer to a different element type U.
// `stride` is given in units of U and rescaled to units of T before being
// passed to mmaStore (multiply first, then integer-divide).
template<Layout layout, typename U>
void __device__ Store(U* buffer, uint stride)
{
    // Names RegisterCount for this configuration; the value itself is unused.
    (void)RegisterCount<T, M, N, K, R>::value;

    const auto strideInT = stride * sizeof(U) / sizeof(T);
    mmaStore<T, layout, M, N>(buffer, regs, strideInT);
}
// Builds a fragment by loading from `buffer`, starting `element` elements
// into it, via mmaLoad. `stride` is forwarded unchanged.
template<Layout layout>
static This __device__ Load(T* buffer, uint element, uint stride)
{
    // Names RegisterCount for this configuration; the value itself is unused.
    (void)RegisterCount<T, M, N, K, R>::value;

    This fragment;
    T* const tileStart = buffer + element;
    mmaLoad<T, layout, M, N, R>(fragment.regs, tileStart, stride);
    return fragment;
}
// Builds a fragment by loading through a pointer to a different element type U.
// `stride` is given in units of U and rescaled to units of T before being
// passed to mmaLoad (multiply first, then integer-divide).
template<Layout layout, typename U>
static This __device__ Load(U* buffer, uint stride)
{
    // Names RegisterCount for this configuration; the value itself is unused.
    (void)RegisterCount<T, M, N, K, R>::value;

    This fragment;
    const auto strideInT = stride * sizeof(U) / sizeof(T);
    mmaLoad<T, layout, M, N, R>(fragment.regs, buffer, strideInT);
    return fragment;
}
// Number of elements of type T held by one thread for this fragment.
static constexpr __device__ uint32_t GetLength()
{
    return This::elements_per_thread;
}
// Number of 32-bit registers held by one thread for this fragment.
static constexpr __device__ int GetPackedFragmentCount()
{
    return RegsCount;
}
8079 using ElementType = T;
8080 static constexpr int m_M = M;
8081 static constexpr int m_N = N;
8082 static constexpr int m_K = K;
8085 static constexpr int RegsCount = RegisterCount<T, M, N, K, R>::value;
8086 unsigned regs[RegsCount] = {};
8088 static constexpr uint32_t elements_per_thread = RegsCount * (4 /
sizeof(T));
// Primary declaration of the single-instruction matrix multiply-accumulate
// wrapper, specialized per input type, accumulator type and M x N x K shape.
//
//   d - destination (result) registers
//   a - registers of the A operand
//   b - registers of the B operand
//   c - registers of the accumulator input
//
// There is no generic definition here; explicit specializations provide the
// implementations.
template<typename InputT, typename AccumT, int M, int N, int K>
__device__ inline void mma(
    uint32_t* d,
    const uint32_t* a,
    const uint32_t* b,
    const uint32_t* c);
8111#if SLANG_CUDA_ENABLE_HALF
8114__device__
inline void mma<half, float, 16, 8, 16>(
8120 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
8121 "{%0, %1, %2, %3}, "
8122 "{%4, %5, %6, %7}, "
8124 "{%10, %11, %12, %13};\n"
8125 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8139__device__
inline void mma<half, half, 16, 8, 16>(
8146 "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "
8148 "{%2, %3, %4, %5}, "
8151 :
"=r"(d[0]),
"=r"(d[1])
8152 :
"r"(a[0]),
"r"(a[1]),
"r"(a[2]),
"r"(a[3]),
"r"(b[0]),
"r"(b[1]),
"r"(c[0]),
"r"(c[1]));
8157#if SLANG_CUDA_ENABLE_BF16
8161__device__
inline void mma<__nv_bfloat16, float, 16, 8, 16>(
8167 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
8168 "{%0, %1, %2, %3}, "
8169 "{%4, %5, %6, %7}, "
8171 "{%10, %11, %12, %13};\n"
8172 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
// Computes a 16x16x16 multiply-accumulate with half-precision A and B
// operands. The template arguments select the element type of the
// accumulator input C (CType) and of the result D (DType). Every
// specialization issues two m16n8k16 `mma` calls; the second call uses the
// same A registers and advances the B, C and D register pointers.
template<typename CType, typename DType, int M, int N, int K>
struct Fp16MMAHelper;

#if SLANG_CUDA_ENABLE_HALF

// C = half, D = half.
template<>
struct Fp16MMAHelper<half, half, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<half, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<half, 16, 16, 16, MatrixC>& c)
    {
        // Half accumulators: two registers of B, C and D per call.
#pragma unroll
        for (int part = 0; part < 2; ++part)
        {
            const int off = 2 * part;
            mma<half, half, 16, 8, 16>(d.regs + off, a.regs, b.regs + off, c.regs + off);
        }
    }
};

// C = float, D = float.
template<>
struct Fp16MMAHelper<float, float, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        // Float accumulators: four registers of C and D, two of B per call.
#pragma unroll
        for (int part = 0; part < 2; ++part)
        {
            mma<half, float, 16, 8, 16>(
                d.regs + 4 * part,
                a.regs,
                b.regs + 2 * part,
                c.regs + 4 * part);
        }
    }
};

// C = half, D = float: C is widened to float before the float-accumulating mma.
template<>
struct Fp16MMAHelper<half, float, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<half, 16, 16, 16, MatrixC>& c)
    {
        // Each of the four C registers holds two halves (low half-word first);
        // expand them into eight float bit patterns.
        uint32_t widenedC[8];
        for (int i = 0; i < 4; i++)
        {
            const uint32_t packed = c.regs[i];
            const half lo = __ushort_as_half(static_cast<unsigned short>(packed & 0xFFFF));
            const half hi = __ushort_as_half(static_cast<unsigned short>(packed >> 16));
            widenedC[2 * i] = __float_as_uint(__half2float(lo));
            widenedC[2 * i + 1] = __float_as_uint(__half2float(hi));
        }

        mma<half, float, 16, 8, 16>(d.regs, a.regs, b.regs, widenedC);
        mma<half, float, 16, 8, 16>(d.regs + 4, a.regs, b.regs + 2, widenedC + 4);
    }
};

// C = float, D = half: accumulate in float, then narrow the result to half.
template<>
struct Fp16MMAHelper<float, half, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<half, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<half, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        uint32_t floatD[8];
        mma<half, float, 16, 8, 16>(floatD, a.regs, b.regs, c.regs);
        mma<half, float, 16, 8, 16>(floatD + 4, a.regs, b.regs + 2, c.regs + 4);

        // Pack each pair of float results into one register, low half-word first.
        for (int i = 0; i < 4; i++)
        {
            const half lo = __float2half(__uint_as_float(floatD[2 * i]));
            const half hi = __float2half(__uint_as_float(floatD[2 * i + 1]));
            const uint32_t loBits = (uint32_t)__half_as_ushort(lo);
            const uint32_t hiBits = (uint32_t)__half_as_ushort(hi);
            d.regs[i] = loBits | (hiBits << 16);
        }
    }
};

#endif
// Computes a 16x16x16 multiply-accumulate with bfloat16 A and B operands.
// The template arguments select the element type of the accumulator input C
// (CType) and of the result D (DType).
template<typename CType, typename DType, int M, int N, int K>
struct Bf16MMAHelper;

#if SLANG_CUDA_ENABLE_BF16

// C = float, D = float.
template<>
struct Bf16MMAHelper<float, float, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<__nv_bfloat16, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        // Two m16n8k16 calls sharing the A registers; the second advances
        // C and D by four registers and B by two.
#pragma unroll
        for (int part = 0; part < 2; ++part)
        {
            mma<__nv_bfloat16, float, 16, 8, 16>(
                d.regs + 4 * part,
                a.regs,
                b.regs + 2 * part,
                c.regs + 4 * part);
        }
    }
};

#endif
8318__device__
inline void mma<char, int32_t, 16, 8, 16>(
8324 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
8325 "{%0, %1, %2, %3}, "
8328 "{%7, %8, %9, %10};\n"
8329 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8330 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
8335__device__
inline void mma<unsigned char, int32_t, 16, 8, 16>(
8341 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 "
8342 "{%0, %1, %2, %3}, "
8345 "{%7, %8, %9, %10};\n"
8346 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8347 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
8353template<
typename InputT,
typename AccumT,
int M,
int N,
int K>
8354__device__
inline void mma_sat(
8361__device__
inline void mma_sat<char, int32_t, 16, 8, 16>(
8367 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 "
8368 "{%0, %1, %2, %3}, "
8371 "{%7, %8, %9, %10};\n"
8372 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8373 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
8377__device__
inline void mma_sat<unsigned char, int32_t, 16, 8, 16>(
8383 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.satfinite.s32.u8.u8.s32 "
8384 "{%0, %1, %2, %3}, "
8387 "{%7, %8, %9, %10};\n"
8388 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8389 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
// Computes a 16x16x16 multiply-accumulate with 8-bit integer A and B
// operands. `Saturating` selects between the `mma_sat` and the plain `mma`
// wrapper.
template<typename AInputT, typename CType, typename DType, int M, int N, int K, bool Saturating>
struct Int8MMAHelper;

// C = int32_t, D = int32_t, for any 8-bit input type AInputT.
template<typename AInputT, bool Saturating>
struct Int8MMAHelper<AInputT, int32_t, int32_t, 16, 16, 16, Saturating>
{
    __device__ static void eval(
        WmmaFragment<int32_t, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<int32_t, 16, 16, 16, MatrixC>& c)
    {
        // Two m16n8k16 calls sharing the A registers; the second advances
        // C and D by four registers and B by one.
#pragma unroll
        for (int part = 0; part < 2; ++part)
        {
            uint32_t* const dPart = d.regs + 4 * part;
            const uint32_t* const bPart = b.regs + part;
            const uint32_t* const cPart = c.regs + 4 * part;
            if constexpr (Saturating)
            {
                mma_sat<AInputT, int32_t, 16, 8, 16>(dPart, a.regs, bPart, cPart);
            }
            else
            {
                mma<AInputT, int32_t, 16, 8, 16>(dPart, a.regs, bPart, cPart);
            }
        }
    }
};
8438#if SLANG_CUDA_ENABLE_FP8
8441__device__
inline void mma<__nv_fp8_e4m3, float, 16, 8, 16>(
8447 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 "
8448 "{%0, %1, %2, %3}, "
8451 "{%7, %8, %9, %10};\n"
8452 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8453 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
8457__device__
inline void mma<__nv_fp8_e5m2, float, 16, 8, 16>(
8463 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.e5m2.e5m2.f32 "
8464 "{%0, %1, %2, %3}, "
8467 "{%7, %8, %9, %10};\n"
8468 :
"=r"(d[0]),
"=r"(d[1]),
"=r"(d[2]),
"=r"(d[3])
8469 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
8473__device__
inline void mma<__nv_fp8_e4m3, half, 16, 8, 16>(
8479 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.e4m3.e4m3.f16 "
8484 :
"=r"(d[0]),
"=r"(d[1])
8485 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]));
8489__device__
inline void mma<__nv_fp8_e5m2, half, 16, 8, 16>(
8495 asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f16.e5m2.e5m2.f16 "
8500 :
"=r"(d[0]),
"=r"(d[1])
8501 :
"r"(a[0]),
"r"(a[1]),
"r"(b[0]),
"r"(c[0]),
"r"(c[1]));
// Computes a 16x16x16 multiply-accumulate with 8-bit floating-point A and B
// operands of type AInputT. The remaining template arguments select the
// element type of the accumulator input C (CType) and of the result D (DType).
template<typename AInputT, typename CType, typename DType, int M, int N, int K>
struct Fp8MMAHelper;

#if SLANG_CUDA_ENABLE_FP8

// C = float, D = float.
template<typename AInputT>
struct Fp8MMAHelper<AInputT, float, float, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<float, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<float, 16, 16, 16, MatrixC>& c)
    {
        // Two m16n8k16 calls sharing the A registers; the second advances
        // C and D by four registers and B by one.
#pragma unroll
        for (int part = 0; part < 2; ++part)
        {
            mma<AInputT, float, 16, 8, 16>(
                d.regs + 4 * part,
                a.regs,
                b.regs + part,
                c.regs + 4 * part);
        }
    }
};

// C = half, D = half.
template<typename AInputT>
struct Fp8MMAHelper<AInputT, half, half, 16, 16, 16>
{
    __device__ static void eval(
        WmmaFragment<half, 16, 16, 16, MatrixC>& d,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixA>& a,
        const WmmaFragment<AInputT, 16, 16, 16, MatrixUse::MatrixB>& b,
        const WmmaFragment<half, 16, 16, 16, MatrixC>& c)
    {
        // Two m16n8k16 calls sharing the A registers; the second advances
        // C and D by two registers and B by one.
#pragma unroll
        for (int part = 0; part < 2; ++part)
        {
            mma<AInputT, half, 16, 8, 16>(
                d.regs + 2 * part,
                a.regs,
                b.regs + part,
                c.regs + 2 * part);
        }
    }
};

#endif
8570 bool saturatingAccumulation>
8571WmmaFragment<DType, M, N, K, MatrixC> __device__ coopMatMulAdd(
8572 WmmaFragment<AType, M, N, K, MatrixUse::MatrixA> matA,
8573 WmmaFragment<BType, M, N, K, MatrixUse::MatrixB> matB,
8574 WmmaFragment<CType, M, N, K, MatrixUse::MatrixC> matC)
8576 WmmaFragment<DType, M, N, K, MatrixC> matD;
8577 if constexpr (IsSameType<AType, char>::value || IsSameType<AType, unsigned char>::value)
8579 Int8MMAHelper<AType, CType, DType, M, N, K, saturatingAccumulation>::eval(
8585#if SLANG_CUDA_ENABLE_FP8
8587 IsSameType<AType, __nv_fp8_e4m3>::value || IsSameType<AType, __nv_fp8_e5m2>::value)
8589 Fp8MMAHelper<AType, CType, DType, M, N, K>::eval(matD, matA, matB, matC);
8592#if SLANG_CUDA_ENABLE_BF16
8593 else if constexpr (IsSameType<AType, __nv_bfloat16>::value)
8595 Bf16MMAHelper<CType, DType, M, N, K>::eval(matD, matA, matB, matC);
8600 Fp16MMAHelper<CType, DType, M, N, K>::eval(matD, matA, matB, matC);
bool operator==(const json_pointer< RefStringTypeLhs > &lhs, const json_pointer< RefStringTypeRhs > &rhs) noexcept
Definition json.hpp:14737
bool operator<(const json_pointer< RefStringTypeLhs > &lhs, const json_pointer< RefStringTypeRhs > &rhs) noexcept
Definition json.hpp:14787
@ value
the parser finished reading a JSON value
auto get(const nlohmann::detail::iteration_proxy_value< IteratorType > &i) -> decltype(i.key())
Definition json.hpp:5342
Direction
Definition Direction.h:18
constexpr FloatOptional operator+(FloatOptional lhs, FloatOptional rhs)
Definition FloatOptional.h:56
uint8_t Type
Definition slang-gfx.h:1359
#define SLANG_FORCE_INLINE
Definition slang-cpp-prelude.h:286
SLANG_FORCE_INLINE int16_t F16_asint(half h)
Definition slang-cpp-scalar-intrinsics.h:732
SLANG_FORCE_INLINE half F16_exp(half f)
Definition slang-cpp-scalar-intrinsics.h:834
SLANG_FORCE_INLINE half F16_acosh(half f)
Definition slang-cpp-scalar-intrinsics.h:804
SLANG_FORCE_INLINE half F16_fmod(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:895
SLANG_FORCE_INLINE half F16_tan(half f)
Definition slang-cpp-scalar-intrinsics.h:764
SLANG_FORCE_INLINE half F16_tanh(half f)
Definition slang-cpp-scalar-intrinsics.h:794
SLANG_FORCE_INLINE half F16_atan2(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:905
SLANG_FORCE_INLINE bool F16_isnan(half f)
Definition slang-cpp-scalar-intrinsics.h:854
SLANG_FORCE_INLINE half F16_frac(half h)
Definition slang-cpp-scalar-intrinsics.h:951
SLANG_FORCE_INLINE bool F16_isinf(half f)
Definition slang-cpp-scalar-intrinsics.h:866
SLANG_FORCE_INLINE half F16_frexp(half x, int *e)
Definition slang-cpp-scalar-intrinsics.h:910
SLANG_FORCE_INLINE bool F16_isfinite(half f)
Definition slang-cpp-scalar-intrinsics.h:860
SLANG_FORCE_INLINE half F16_pow(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:890
SLANG_FORCE_INLINE half F16_cos(half f)
Definition slang-cpp-scalar-intrinsics.h:759
SLANG_FORCE_INLINE half F16_sinh(half f)
Definition slang-cpp-scalar-intrinsics.h:784
SLANG_FORCE_INLINE half F16_asin(half f)
Definition slang-cpp-scalar-intrinsics.h:769
SLANG_FORCE_INLINE half F16_modf(half x, half *ip)
Definition slang-cpp-scalar-intrinsics.h:915
SLANG_FORCE_INLINE half F16_sin(half f)
Definition slang-cpp-scalar-intrinsics.h:754
SLANG_FORCE_INLINE half F16_exp2(half f)
Definition slang-cpp-scalar-intrinsics.h:829
SLANG_FORCE_INLINE half F16_trunc(half f)
Definition slang-cpp-scalar-intrinsics.h:844
SLANG_FORCE_INLINE half F16_log2(half f)
Definition slang-cpp-scalar-intrinsics.h:814
SLANG_FORCE_INLINE int F16_sign(half f)
Definition slang-cpp-scalar-intrinsics.h:943
SLANG_FORCE_INLINE half F16_round(half f)
Definition slang-cpp-scalar-intrinsics.h:749
SLANG_FORCE_INLINE half F16_max(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:881
SLANG_FORCE_INLINE half F16_log(half f)
Definition slang-cpp-scalar-intrinsics.h:819
SLANG_FORCE_INLINE half F16_acos(half f)
Definition slang-cpp-scalar-intrinsics.h:774
SLANG_FORCE_INLINE half F16_atanh(half f)
Definition slang-cpp-scalar-intrinsics.h:809
SLANG_FORCE_INLINE half F16_sqrt(half f)
Definition slang-cpp-scalar-intrinsics.h:849
SLANG_FORCE_INLINE half F16_cosh(half f)
Definition slang-cpp-scalar-intrinsics.h:789
SLANG_FORCE_INLINE half F16_asinh(half f)
Definition slang-cpp-scalar-intrinsics.h:799
SLANG_FORCE_INLINE half F16_abs(half f)
Definition slang-cpp-scalar-intrinsics.h:839
SLANG_FORCE_INLINE uint16_t F16_asuint(half h)
Definition slang-cpp-scalar-intrinsics.h:725
SLANG_FORCE_INLINE half F16_rsqrt(half f)
Definition slang-cpp-scalar-intrinsics.h:938
SLANG_FORCE_INLINE half F16_fma(half a, half b, half c)
Definition slang-cpp-scalar-intrinsics.h:923
SLANG_FORCE_INLINE half F16_min(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:872
SLANG_FORCE_INLINE half F16_atan(half f)
Definition slang-cpp-scalar-intrinsics.h:779
SLANG_FORCE_INLINE half F16_floor(half f)
Definition slang-cpp-scalar-intrinsics.h:744
SLANG_FORCE_INLINE half F16_log10(half f)
Definition slang-cpp-scalar-intrinsics.h:824
SLANG_FORCE_INLINE half F16_remainder(half a, half b)
Definition slang-cpp-scalar-intrinsics.h:900
SLANG_FORCE_INLINE half F16_ceil(half f)
Definition slang-cpp-scalar-intrinsics.h:739
uint32_t uint
Definition slang-cpp-types-core.h:272
SLANG_FORCE_INLINE T _slang_vector_get_element(Vector< T, N > x, int index)
Definition slang-cpp-types-core.h:241
SLANG_FORCE_INLINE const T * _slang_vector_get_element_ptr(const Vector< T, N > *x, int index)
Definition slang-cpp-types-core.h:247
SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_max(uintptr_t a, uintptr_t b)
Definition slang-cuda-prelude.h:2601
SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector< T, n > _slang_vector_reshape(const Vector< OtherT, m > other)
Definition slang-cuda-prelude.h:914
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_fma(double a, double b, double c)
Definition slang-cuda-prelude.h:2341
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_atan2(float a, float b)
Definition slang-cuda-prelude.h:2140
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_remainder(float a, float b)
Definition slang-cuda-prelude.h:2136
#define SLANG_WAVE_ROTATE_IMPL(T)
Definition slang-cuda-prelude.h:6311
SLANG_FORCE_INLINE SLANG_CUDA_CALL intptr_t IPTR_max(intptr_t a, intptr_t b)
Definition slang-cuda-prelude.h:2584
__device__ T _waveOpCopy(T *dst, const T *src)
Definition slang-cuda-prelude.h:3959
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_abs(double f)
Definition slang-cuda-prelude.h:2250
__forceinline__ __device__ WarpMask _getMultiPrefixMask(int mask)
Definition slang-cuda-prelude.h:3132
__device__ T _wavePrefixInvertableScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3861
#define SLANG_BOUND_CHECK_FIXED_ARRAY(index, count)
Definition slang-cuda-prelude.h:117
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_frexp(float x, int *e)
Definition slang-cuda-prelude.h:2145
__inline__ __device__ uint4 _waveMatchScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4288
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_fmod(float a, float b)
Definition slang-cuda-prelude.h:2132
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_tanh(double f)
Definition slang-cuda-prelude.h:2226
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_log10(float f)
Definition slang-cuda-prelude.h:2238
#define SLANG_VECTOR_GET_ELEMENT(T)
Definition slang-cuda-prelude.h:415
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1Dwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, cudaSurfaceBoundaryMode boundaryMode)
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U64_firstbitlow(uint64_t v)
Definition slang-cuda-prelude.h:2515
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F64_isfinite(double f)
Definition slang-cuda-prelude.h:2279
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_acos(float f)
Definition slang-cuda-prelude.h:2029
__inline__ __device__ T _wavePrefixOr(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4120
__inline__ __device__ uint getAt(dim3 a, int b)
Definition slang-cuda-prelude.h:4309
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I16_countbits(int16_t v)
Definition slang-cuda-prelude.h:2371
__device__ __forceinline__ longlong atomicCAS(longlong *address, longlong compare, longlong val)
Definition slang-cuda-prelude.h:2715
__inline__ __device__ T _wavePrefixMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4189
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_firstbitlow(uint32_t v)
Definition slang-cuda-prelude.h:2417
__inline__ __device__ T _waveAndMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3744
__forceinline__ __device__ WarpMask _getActiveMask()
Definition slang-cuda-prelude.h:3126
__inline__ __device__ T _waveMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3665
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_round(float f)
Definition slang-cuda-prelude.h:2005
#define SLANG_CUDA_WARP_MASK
Definition slang-cuda-prelude.h:63
__inline__ __device__ T _wavePrefixMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4205
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_max(float a, float b)
Definition slang-cuda-prelude.h:2124
#define SLANG_PRELUDE_ASSERT(x)
Definition slang-cuda-prelude.h:57
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool3 make_bool3(bool x, bool y, bool z)
Definition slang-cuda-prelude.h:740
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t U64_abs(uint64_t f)
Definition slang-cuda-prelude.h:2496
__inline__ __device__ T _wavePrefixExclusiveMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4278
__inline__ __device__ T _wavePrefixOrMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4163
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool2 make_bool2(bool x, bool y)
Definition slang-cuda-prelude.h:736
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_max(uint32_t a, uint32_t b)
Definition slang-cuda-prelude.h:2389
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I64_firstbitlow(int64_t v)
Definition slang-cuda-prelude.h:2555
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_asuint(int32_t x)
Definition slang-cuda-prelude.h:2463
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex1DArrayfetch_int(CUtexObject texObj, int x, int layer, int mip)
__inline__ __device__ T _waveOrMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3736
SLANG_FORCE_INLINE SLANG_CUDA_CALL double U32_asdouble(uint32_t low, uint32_t hi)
Definition slang-cuda-prelude.h:2405
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_remainder(double a, double b)
Definition slang-cuda-prelude.h:2305
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_sqrt(float f)
Definition slang-cuda-prelude.h:2089
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool4 make_bool4(bool x, bool y, bool z, bool w)
Definition slang-cuda-prelude.h:744
unsigned long long CUsurfObject
Definition slang-cuda-prelude.h:185
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_floor(float f)
Definition slang-cuda-prelude.h:2001
__device__ __forceinline__ bool _slang_quadAll(bool expr)
Definition slang-cuda-prelude.h:6411
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_tanh(float f)
Definition slang-cuda-prelude.h:2045
__device__ __forceinline__ longlong atomicAdd(longlong *address, longlong val)
Definition slang-cuda-prelude.h:2723
#define SLANG_CUDA_CALL
Definition slang-cuda-prelude.h:70
static const int kSlangTorchTensorMaxDim
Definition slang-cuda-prelude.h:5914
#define SLANG_WAVE_MAX_SPEC(T, EXCL_VAL)
Definition slang-cuda-prelude.h:3247
#define SLANG_CUDA_VECTOR_FLOAT_OPS(T)
Definition slang-cuda-prelude.h:665
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_atanh(float f)
Definition slang-cuda-prelude.h:2057
size_t NonUniformResourceIndex
Definition slang-cuda-prelude.h:197
__inline__ __device__ T _wavePrefixMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4183
SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_abs(uintptr_t f)
Definition slang-cuda-prelude.h:2591
__inline__ __device__ T _waveSumMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3768
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_tan(double f)
Definition slang-cuda-prelude.h:2202
__inline__ __device__ T _waveProduct(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3653
int64_t longlong
Definition slang-cuda-prelude.h:301
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_cosh(double f)
Definition slang-cuda-prelude.h:2222
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint64_t U64_reversebits(uint64_t v)
Definition slang-cuda-prelude.h:2529
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_log(double f)
Definition slang-cuda-prelude.h:2234
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F64_isnan(double f)
Definition slang-cuda-prelude.h:2275
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_min(float a, float b)
Definition slang-cuda-prelude.h:2120
#define GET_VECTOR_TYPE_IMPL_N(T)
Definition slang-cuda-prelude.h:882
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_exp2(float f)
Definition slang-cuda-prelude.h:2073
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex1Dfetch_int(CUtexObject texObj, int x, int mip)
Definition slang-cuda-prelude.h:6096
__device__ __forceinline__ longlong atomicExch(longlong *address, longlong val)
Definition slang-cuda-prelude.h:2710
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_min(double a, double b)
Definition slang-cuda-prelude.h:2289
__inline__ __device__ T _waveOr(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3635
SLANG_FORCE_INLINE SLANG_CUDA_CALL intptr_t IPTR_min(intptr_t a, intptr_t b)
Definition slang-cuda-prelude.h:2579
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_sincos(double f, double *s, double *c)
Definition slang-cuda-prelude.h:2198
SLANG_FORCE_INLINE SLANG_CUDA_CALL int F32_sign(float f)
Definition slang-cuda-prelude.h:2097
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U16_countbits(uint16_t v)
Definition slang-cuda-prelude.h:2363
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_cos(double f)
Definition slang-cuda-prelude.h:2194
__inline__ __device__ uint4 _waveMatchMultiple(WarpMask mask, const T &inVal)
Definition slang-cuda-prelude.h:4295
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U64_countbits(uint64_t v)
Definition slang-cuda-prelude.h:2510
SLANG_FORCE_INLINE SLANG_CUDA_CALL intptr_t IPTR_abs(intptr_t f)
Definition slang-cuda-prelude.h:2574
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_firstbithigh(uint32_t v)
Definition slang-cuda-prelude.h:2424
__inline__ __device__ T _wavePrefixMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4195
__device__ __forceinline__ void __slang_atomic_reduce_dec(uint32_t *addr, int order)
Definition slang-cuda-prelude.h:2957
#define SLANG_MATRIX_INT_NEG_OP(T)
Definition slang-cuda-prelude.h:1267
__inline__ __device__ T _waveAnd(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3641
SLANG_FORCE_INLINE SLANG_CUDA_CALL float _slang_fmod(float x, float y)
Definition slang-cuda-prelude.h:335
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t F32_asuint(float f)
Definition slang-cuda-prelude.h:2155
#define SLANG_TEX2DFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6143
#define SLANG_WAVE_MIN_SPEC(T, EXCL_VAL)
Definition slang-cuda-prelude.h:3240
__inline__ __device__ TResult slang_bit_cast(TInput val)
Definition slang-cuda-prelude.h:4324
__device__ T _waveOpSetInitial(T *out, const T *val)
Definition slang-cuda-prelude.h:3978
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_sinh(float f)
Definition slang-cuda-prelude.h:2037
unsigned char uchar
Definition slang-cuda-prelude.h:307
__inline__ __device__ T _wavePrefixInclusiveMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4243
#define SLANG_FORCE_INLINE
Definition slang-cuda-prelude.h:68
__device__ __forceinline__ bool2 _slang_waveRotate(bool2 value, unsigned int delta)
Definition slang-cuda-prelude.h:6377
#define SLANG_INT_MATRIX_OPS(T)
Definition slang-cuda-prelude.h:1235
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_log2(double f)
Definition slang-cuda-prelude.h:2230
#define SLANG_CUDA_WARP_BITMASK
Definition slang-cuda-prelude.h:65
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_log(float f)
Definition slang-cuda-prelude.h:2065
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_min(uint32_t a, uint32_t b)
Definition slang-cuda-prelude.h:2385
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf3Dwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int y, int z, cudaSurfaceBoundaryMode boundaryMode)
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_tan(float f)
Definition slang-cuda-prelude.h:2021
__inline__ __device__ T _waveReadFirstMultiple(WarpMask mask, T inVal)
Definition slang-cuda-prelude.h:3826
__device__ void _waveReduceMultiple(WarpMask mask, T *val)
Definition slang-cuda-prelude.h:3586
SLANG_FORCE_INLINE SLANG_CUDA_CALL float make_float(T val)
Definition slang-cuda-prelude.h:330
__inline__ __device__ T _wavePrefixInclusiveMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4231
struct __align__(1) bool1
Definition slang-cuda-prelude.h:207
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex2Dfetch_int(CUtexObject texObj, int x, int y, int mip)
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_asin(float f)
Definition slang-cuda-prelude.h:2025
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_countbits(uint32_t v)
Definition slang-cuda-prelude.h:2412
__inline__ __device__ uint3 operator*(uint3 a, dim3 b)
Definition slang-cuda-prelude.h:4314
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F32_isfinite(float f)
Definition slang-cuda-prelude.h:2110
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_max(int32_t a, int32_t b)
Definition slang-cuda-prelude.h:2452
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_firstbitlow(int32_t v)
Definition slang-cuda-prelude.h:2479
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_trunc(double f)
Definition slang-cuda-prelude.h:2254
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F32_isinf(float f)
Definition slang-cuda-prelude.h:2114
SLANG_FORCE_INLINE SLANG_CUDA_CALL uintptr_t UPTR_min(uintptr_t a, uintptr_t b)
Definition slang-cuda-prelude.h:2596
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool1 make_bool1(bool x)
Definition slang-cuda-prelude.h:732
#define SLANG_SELECT_T(T)
Definition slang-cuda-prelude.h:1342
#define SLANG_TEX3DFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6185
__forceinline__ __device__ WarpMask _getLaneLtMask()
Definition slang-cuda-prelude.h:3118
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F32_sincos(float f, float *s, float *c)
Definition slang-cuda-prelude.h:2017
SLANG_FORCE_INLINE SLANG_CUDA_CALL int F64_sign(double f)
Definition slang-cuda-prelude.h:2266
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_pow(float a, float b)
Definition slang-cuda-prelude.h:2128
__inline__ __device__ int _waveCalcPow2Offset(WarpMask mask)
Definition slang-cuda-prelude.h:3151
__device__ __forceinline__ void __slang_atomic_reduce_min(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2828
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_min(int32_t a, int32_t b)
Definition slang-cuda-prelude.h:2448
#define SLANG_FLOAT_MATRIX_OPS(T)
Definition slang-cuda-prelude.h:1248
SamplerStateUnused * SamplerState
Definition slang-cuda-prelude.h:192
__device__ __forceinline__ void __slang_atomic_reduce_or(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2901
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_abs(float f)
Definition slang-cuda-prelude.h:2081
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_modf(double x, double *ip)
Definition slang-cuda-prelude.h:2319
int WarpMask
Definition slang-cuda-prelude.h:3094
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_log2(float f)
Definition slang-cuda-prelude.h:2061
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_max(double a, double b)
Definition slang-cuda-prelude.h:2293
#define SLANG_CUDA_VECTOR_ATOMIC_BINARY_IMPL(Fn, T, N)
Definition slang-cuda-prelude.h:838
unsigned long long ulonglong
Definition slang-cuda-prelude.h:303
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F64_isinf(double f)
Definition slang-cuda-prelude.h:2283
__device__ T _waveOpDoInverse(T *inOut, const T *val)
Definition slang-cuda-prelude.h:3969
#define SLANG_CUDA_WARP_SIZE
Definition slang-cuda-prelude.h:60
SLANG_FORCE_INLINE SLANG_CUDA_CALL double I32_asdouble(int32_t low, int32_t hi)
Definition slang-cuda-prelude.h:2467
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_round(double f)
Definition slang-cuda-prelude.h:2186
unsigned int uint
Definition slang-cuda-prelude.h:309
__inline__ __device__ bool _waveIsFirstLane()
Definition slang-cuda-prelude.h:3172
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_exp(float f)
Definition slang-cuda-prelude.h:2077
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_abs(int32_t f)
Definition slang-cuda-prelude.h:2442
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_reversebits(uint32_t v)
Definition slang-cuda-prelude.h:2434
__device__ __forceinline__ void __slang_atomic_reduce_and(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2880
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_cos(float f)
Definition slang-cuda-prelude.h:2013
__inline__ __device__ T _wavePrefixExclusiveMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4268
#define SLANG_CUDA_VECTOR_INT_OPS(T)
Definition slang-cuda-prelude.h:571
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_atan(double f)
Definition slang-cuda-prelude.h:2214
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool __ldg(const bool *ptr)
Definition slang-cuda-prelude.h:263
SLANG_FORCE_INLINE SLANG_CUDA_CALL T _slang_select(bool condition, T v0, T v1)
Definition slang-cuda-prelude.h:1358
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t U64_min(uint64_t a, uint64_t b)
Definition slang-cuda-prelude.h:2501
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_atan(float f)
Definition slang-cuda-prelude.h:2033
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I64_firstbithigh(int64_t v)
Definition slang-cuda-prelude.h:2560
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_firstbithigh(int32_t v)
Definition slang-cuda-prelude.h:2484
__device__ T _wavePrefixMultiple(WarpMask mask, T *val)
Definition slang-cuda-prelude.h:4045
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_rsqrt(double f)
Definition slang-cuda-prelude.h:2262
SLANG_FORCE_INLINE SLANG_CUDA_CALL float U32_asfloat(uint32_t x)
Definition slang-cuda-prelude.h:2394
__inline__ __device__ T _wavePrefixSum(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4108
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_pow(double a, double b)
Definition slang-cuda-prelude.h:2297
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf1DLayeredwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int layer, cudaSurfaceBoundaryMode boundaryMode)
Definition slang-cuda-prelude.h:1634
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I8_countbits(int8_t v)
Definition slang-cuda-prelude.h:2356
#define SLANG_SURF3DWRITE_CONVERT_IMPL(T, c)
Definition slang-cuda-prelude.h:1739
__inline__ __device__ T _wavePrefixProductMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4133
unsigned long long OptixTraversableHandle
Definition slang-cuda-prelude.h:5912
__inline__ __device__ T _wavePrefixExclusiveMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4262
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_ceil(float f)
Definition slang-cuda-prelude.h:1997
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I64_countbits(int64_t v)
Definition slang-cuda-prelude.h:2550
#define SLANG_BOUND_CHECK_BYTE_ADDRESS(index, elemSize, sizeInBytes)
Definition slang-cuda-prelude.h:111
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_sinh(double f)
Definition slang-cuda-prelude.h:2218
__inline__ __device__ T _waveSum(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3659
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_min(int64_t a, int64_t b)
Definition slang-cuda-prelude.h:2541
__forceinline__ __device__ uint32_t _getLaneId()
Definition slang-cuda-prelude.h:3076
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_trunc(float f)
Definition slang-cuda-prelude.h:2085
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_asinh(float f)
Definition slang-cuda-prelude.h:2049
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_sin(double f)
Definition slang-cuda-prelude.h:2190
__device__ T _wavePrefixInvertableMultiple(WarpMask mask, T *val)
Definition slang-cuda-prelude.h:3987
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2DLayeredwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int y, int layer, cudaSurfaceBoundaryMode boundaryMode)
Definition slang-cuda-prelude.h:1715
unsigned short ushort
Definition slang-cuda-prelude.h:308
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_sqrt(double f)
Definition slang-cuda-prelude.h:2258
#define SLANG_SURF2DWRITE_CONVERT_IMPL(T, c)
Definition slang-cuda-prelude.h:1656
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_asint(double d, int32_t *low, int32_t *hi)
Definition slang-cuda-prelude.h:2332
__inline__ __device__ T _waveMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3671
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t U64_max(uint64_t a, uint64_t b)
Definition slang-cuda-prelude.h:2505
__inline__ __device__ bool _waveAllEqual(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3793
__inline__ __device__ T _waveShuffleMultiple(WarpMask mask, T inVal, int lane)
Definition slang-cuda-prelude.h:3842
SLANG_FORCE_INLINE SLANG_CUDA_CALL Matrix< T, ROWS, COLS > makeMatrix(T scalar)
Definition slang-cuda-prelude.h:944
__device__ __forceinline__ void __slang_atomic_reduce_max(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2854
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t F32_asint(float f)
Definition slang-cuda-prelude.h:2161
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_frac(double f)
Definition slang-cuda-prelude.h:2270
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_acosh(float f)
Definition slang-cuda-prelude.h:2053
__inline__ __device__ T _waveProductMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3760
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_ceil(double f)
Definition slang-cuda-prelude.h:2178
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_max(int64_t a, int64_t b)
Definition slang-cuda-prelude.h:2545
#define SLANG_INFINITY
Definition slang-cuda-prelude.h:53
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_asin(double f)
Definition slang-cuda-prelude.h:2206
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_exp2(double f)
Definition slang-cuda-prelude.h:2242
__device__ __forceinline__ bool _slang_quadAny(bool expr)
Definition slang-cuda-prelude.h:6401
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U64_firstbithigh(uint64_t v)
Definition slang-cuda-prelude.h:2522
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_atan2(double a, double b)
Definition slang-cuda-prelude.h:2309
SLANG_FORCE_INLINE SLANG_CUDA_CALL void F64_asuint(double d, uint32_t *low, uint32_t *hi)
Definition slang-cuda-prelude.h:2324
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex3Dfetch_int(CUtexObject texObj, int x, int y, int z, int mip)
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_rsqrt(float f)
Definition slang-cuda-prelude.h:2093
#define SLANG_SURF1DWRITE_CONVERT_IMPL(T, c)
Definition slang-cuda-prelude.h:1582
SLANG_FORCE_INLINE SLANG_CUDA_CALL float I32_asfloat(int32_t x)
Definition slang-cuda-prelude.h:2457
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_exp(double f)
Definition slang-cuda-prelude.h:2246
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_asint(int32_t x)
Definition slang-cuda-prelude.h:2400
__inline__ __device__ T _waveMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3784
__inline__ __device__ T _waveXor(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3647
__device__ T _waveReduceScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3554
__inline__ __device__ T _wavePrefixInclusiveMax(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4237
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_floor(double f)
Definition slang-cuda-prelude.h:2182
#define SLANG_CUDA_FLOAT_VECTOR_MOD(T)
Definition slang-cuda-prelude.h:682
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_fmod(double a, double b)
Definition slang-cuda-prelude.h:2301
SLANG_FORCE_INLINE SLANG_CUDA_CALL void surf2Dwrite_convert(T v, cudaSurfaceObject_t surfObj, int x, int y, cudaSurfaceBoundaryMode boundaryMode)
unsigned long long CUtexObject
Definition slang-cuda-prelude.h:184
__inline__ __device__ T _wavePrefixXor(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4114
__inline__ __device__ T _wavePrefixAnd(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4126
#define SLANG_FLOAT_MATRIX_MOD(T)
Definition slang-cuda-prelude.h:1287
#define SLANG_TEX1DARRAYFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6227
__device__ T _wavePrefixScalar(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3911
__inline__ __device__ T _wavePrefixXorMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4153
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_fma(float a, float b, float c)
Definition slang-cuda-prelude.h:2169
__inline__ __device__ T _waveReadFirst(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3819
__inline__ __device__ T _wavePrefixSumMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4143
#define SLANG_WAVE_CLUSTERED_ROTATE_IMPL(T)
Definition slang-cuda-prelude.h:6423
__device__ __forceinline__ void __slang_atomic_reduce_xor(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2922
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_abs(int64_t f)
Definition slang-cuda-prelude.h:2536
__inline__ __device__ bool _waveAllEqualMultiple(WarpMask mask, T inVal)
Definition slang-cuda-prelude.h:3801
#define SLANG_MAKE_VECTOR_FROM_SCALAR(T)
Definition slang-cuda-prelude.h:780
__device__ __forceinline__ void __slang_atomic_reduce_inc(uint32_t *addr, int order)
Definition slang-cuda-prelude.h:2947
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U8_countbits(uint8_t v)
Definition slang-cuda-prelude.h:2348
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_log10(float f)
Definition slang-cuda-prelude.h:2069
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_frexp(double x, int *e)
Definition slang-cuda-prelude.h:2314
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t U32_abs(uint32_t f)
Definition slang-cuda-prelude.h:2379
#define SLANG_TEX2DARRAYFETCH_INT_IMPL(T, dtype, c)
Definition slang-cuda-prelude.h:6269
#define SLANG_BOUND_CHECK(index, count)
Definition slang-cuda-prelude.h:106
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_frac(float f)
Definition slang-cuda-prelude.h:2101
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_cosh(float f)
Definition slang-cuda-prelude.h:2041
__inline__ __device__ bool _waveIsSingleLane(WarpMask mask)
Definition slang-cuda-prelude.h:3139
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_sin(float f)
Definition slang-cuda-prelude.h:2009
SLANG_FORCE_INLINE SLANG_CUDA_CALL float F32_modf(float x, float *ip)
Definition slang-cuda-prelude.h:2150
__inline__ __device__ T _wavePrefixInclusiveMaxMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4249
#define SLANG_VECTOR_GET_ELEMENT_PTR(T)
Definition slang-cuda-prelude.h:444
SLANG_FORCE_INLINE SLANG_CUDA_CALL int64_t I64_reversebits(int64_t v)
Definition slang-cuda-prelude.h:2567
SLANG_FORCE_INLINE SLANG_CUDA_CALL int32_t I32_reversebits(int32_t v)
Definition slang-cuda-prelude.h:2489
SLANG_FORCE_INLINE SLANG_CUDA_CALL double F64_acos(double f)
Definition slang-cuda-prelude.h:2210
SLANG_FORCE_INLINE SLANG_CUDA_CALL bool F32_isnan(float f)
Definition slang-cuda-prelude.h:2106
__inline__ __device__ T _wavePrefixExclusiveMin(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4256
__inline__ __device__ T _wavePrefixProduct(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4102
SLANG_FORCE_INLINE SLANG_CUDA_CALL uint32_t I32_countbits(int32_t v)
Definition slang-cuda-prelude.h:2474
__inline__ __device__ T _waveXorMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3752
__inline__ __device__ T _wavePrefixAndMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:4173
__device__ __forceinline__ void __slang_atomic_reduce_add(int32_t *addr, int32_t val, int order)
Definition slang-cuda-prelude.h:2751
__inline__ __device__ T _waveMinMultiple(WarpMask mask, T val)
Definition slang-cuda-prelude.h:3776
__device__ __forceinline__ bool _slang_waveClusteredRotate(bool value, unsigned int delta, unsigned int clusterSize)
Definition slang-cuda-prelude.h:6482
SLANG_FORCE_INLINE SLANG_CUDA_CALL T tex2DArrayfetch_int(CUtexObject texObj, int x, int y, int layer, int mip)
__INTPTR_TYPE__ intptr_t
Definition slang-llvm.h:146
__UINTPTR_TYPE__ uintptr_t
Definition slang-llvm.h:153
__PTRDIFF_TYPE__ ptrdiff_t
Definition slang-llvm.h:29
static const int kSlangTorchTensorMaxDim
Definition slang-torch-prelude.h:70
Definition slang-cpp-types-core.h:82
SLANG_CUDA_CALL T & operator[](size_t index)
Definition slang-cuda-prelude.h:173
SLANG_CUDA_CALL const T & operator[](size_t index) const
Definition slang-cuda-prelude.h:168
T * data
Definition slang-cpp-types-core.h:94
size_t count
Definition slang-cpp-types-core.h:95
Definition slang-cpp-types.h:123
SLANG_CUDA_CALL StructuredBuffer< T > asStructuredBuffer() const
Definition slang-cuda-prelude.h:2693
SLANG_CUDA_CALL uint4 Load4(size_t index) const
Definition slang-cuda-prelude.h:2678
SLANG_CUDA_CALL uint32_t Load(size_t index) const
Definition slang-cuda-prelude.h:2661
SLANG_CUDA_CALL T Load(size_t index) const
Definition slang-cuda-prelude.h:2685
SLANG_CUDA_CALL uint2 Load2(size_t index) const
Definition slang-cuda-prelude.h:2666
SLANG_CUDA_CALL uint3 Load3(size_t index) const
Definition slang-cuda-prelude.h:2672
const uint32_t * data
Definition slang-cpp-types.h:155
SLANG_CUDA_CALL void GetDimensions(uint32_t *outDim) const
Definition slang-cuda-prelude.h:2660
size_t sizeInBytes
Definition slang-cpp-types.h:156
T Type
Definition slang-cuda-prelude.h:3549
char Type
Definition slang-cuda-prelude.h:3440
char Type
Definition slang-cuda-prelude.h:3445
char Type
Definition slang-cuda-prelude.h:3450
char Type
Definition slang-cuda-prelude.h:3326
double Type
Definition slang-cuda-prelude.h:3418
double Type
Definition slang-cuda-prelude.h:3423
double Type
Definition slang-cuda-prelude.h:3428
double Type
Definition slang-cuda-prelude.h:3433
double Type
Definition slang-cuda-prelude.h:3311
float Type
Definition slang-cuda-prelude.h:3397
float Type
Definition slang-cuda-prelude.h:3402
float Type
Definition slang-cuda-prelude.h:3407
float Type
Definition slang-cuda-prelude.h:3412
float Type
Definition slang-cuda-prelude.h:3306
int Type
Definition slang-cuda-prelude.h:3355
int Type
Definition slang-cuda-prelude.h:3360
int Type
Definition slang-cuda-prelude.h:3365
int Type
Definition slang-cuda-prelude.h:3370
int64_t Type
Definition slang-cuda-prelude.h:3321
int Type
Definition slang-cuda-prelude.h:3296
int64_t Type
Definition slang-cuda-prelude.h:3500
int64_t Type
Definition slang-cuda-prelude.h:3505
int64_t Type
Definition slang-cuda-prelude.h:3510
short Type
Definition slang-cuda-prelude.h:3470
short Type
Definition slang-cuda-prelude.h:3475
short Type
Definition slang-cuda-prelude.h:3480
short Type
Definition slang-cuda-prelude.h:3336
uchar Type
Definition slang-cuda-prelude.h:3455
uchar Type
Definition slang-cuda-prelude.h:3460
uchar Type
Definition slang-cuda-prelude.h:3465
uchar Type
Definition slang-cuda-prelude.h:3331
uint Type
Definition slang-cuda-prelude.h:3376
uint Type
Definition slang-cuda-prelude.h:3381
uint Type
Definition slang-cuda-prelude.h:3386
uint Type
Definition slang-cuda-prelude.h:3391
uint64_t Type
Definition slang-cuda-prelude.h:3316
uint Type
Definition slang-cuda-prelude.h:3301
uint64_t Type
Definition slang-cuda-prelude.h:3515
uint64_t Type
Definition slang-cuda-prelude.h:3520
uint64_t Type
Definition slang-cuda-prelude.h:3525
ushort Type
Definition slang-cuda-prelude.h:3485
ushort Type
Definition slang-cuda-prelude.h:3490
ushort Type
Definition slang-cuda-prelude.h:3495
ushort Type
Definition slang-cuda-prelude.h:3341
Definition slang-cuda-prelude.h:3290
Definition slang-cpp-types-core.h:63
SLANG_CUDA_CALL const T & operator[](size_t index) const
Definition slang-cuda-prelude.h:149
T m_data[SIZE]
Definition slang-cpp-types-core.h:75
SLANG_CUDA_CALL T & operator[](size_t index)
Definition slang-cuda-prelude.h:154
Definition slang-cuda-prelude.h:869
Definition slang-cpp-types-core.h:401
SLANG_FORCE_INLINE SLANG_CUDA_CALL const Vector< T, COLS > & operator[](size_t index) const
Definition slang-cuda-prelude.h:936
SLANG_FORCE_INLINE SLANG_CUDA_CALL Vector< T, COLS > & operator[](size_t index)
Definition slang-cuda-prelude.h:931
Vector< T, COLS > rows[ROWS]
Definition slang-cpp-types-core.h:402
Definition slang-cpp-types.h:163
SLANG_CUDA_CALL uint32_t Load(size_t index) const
Definition slang-cuda-prelude.h:2976
SLANG_CUDA_CALL RWStructuredBuffer< T > asStructuredBuffer() const
Definition slang-cuda-prelude.h:3052
SLANG_CUDA_CALL void Store2(size_t index, uint2 v) const
Definition slang-cuda-prelude.h:3013
SLANG_CUDA_CALL void Store(size_t index, uint32_t v) const
Definition slang-cuda-prelude.h:3008
SLANG_CUDA_CALL void GetDimensions(uint32_t *outDim) const
Definition slang-cuda-prelude.h:2974
SLANG_CUDA_CALL uint3 Load3(size_t index) const
Definition slang-cuda-prelude.h:2987
SLANG_CUDA_CALL uint4 Load4(size_t index) const
Definition slang-cuda-prelude.h:2993
SLANG_CUDA_CALL void Store3(size_t index, uint3 v) const
Definition slang-cuda-prelude.h:3020
SLANG_CUDA_CALL uint2 Load2(size_t index) const
Definition slang-cuda-prelude.h:2981
SLANG_CUDA_CALL T * _getPtrAt(size_t index)
Can be used in the core module to gain access
Definition slang-cuda-prelude.h:3046
SLANG_CUDA_CALL T Load(size_t index) const
Definition slang-cuda-prelude.h:3000
size_t sizeInBytes
Definition slang-cpp-types.h:233
SLANG_CUDA_CALL void Store4(size_t index, uint4 v) const
Definition slang-cuda-prelude.h:3028
SLANG_CUDA_CALL void Store(size_t index, T const &value) const
Definition slang-cuda-prelude.h:3038
uint32_t * data
Definition slang-cpp-types.h:232
Definition slang-cpp-types.h:38
SLANG_CUDA_CALL T & operator[](size_t index) const
Definition slang-cuda-prelude.h:2648
T * data
Definition slang-cpp-types.h:55
size_t count
Definition slang-cpp-types.h:56
Definition slang-cpp-types.h:61
size_t count
Definition slang-cpp-types.h:79
SLANG_CUDA_CALL T & Load(size_t index) const
Definition slang-cuda-prelude.h:2623
T * data
Definition slang-cpp-types.h:78
SLANG_CUDA_CALL void GetDimensions(uint32_t *outNumStructs, uint32_t *outStride) const
Definition slang-cuda-prelude.h:2632
SLANG_CUDA_CALL T & operator[](size_t index) const
Definition slang-cuda-prelude.h:2615
Definition slang-cuda-prelude.h:5920
__device__ T * data_ptr_at(uint2 index)
Definition slang-cuda-prelude.h:5940
__device__ void store(uint32_t x, T val)
Definition slang-cuda-prelude.h:6032
__device__ void store(uint32_t x, uint32_t y, uint32_t z, uint32_t w, T val)
Definition slang-cuda-prelude.h:6058
__device__ void store(uint2 index, T val)
Definition slang-cuda-prelude.h:6042
__device__ T & load(uint32_t x, uint32_t y, uint32_t z, uint32_t w)
Definition slang-cuda-prelude.h:5999
__device__ T & load(uint4 index)
Definition slang-cuda-prelude.h:6005
__device__ T * data_ptr_at(uint32_t index)
Definition slang-cuda-prelude.h:5933
__device__ void store(uint32_t x, uint32_t y, T val)
Definition slang-cuda-prelude.h:6037
uint32_t strides[kSlangTorchTensorMaxDim]
Definition slang-cuda-prelude.h:5922
__device__ T & load(uint32_t x, uint32_t y, uint32_t z)
Definition slang-cuda-prelude.h:5988
__device__ T & load(uint32_t x, uint32_t y)
Definition slang-cuda-prelude.h:5978
uint32_t sizes[kSlangTorchTensorMaxDim]
Definition slang-cuda-prelude.h:5923
__device__ T * data_ptr()
Definition slang-cuda-prelude.h:5927
__device__ T * data_ptr_at(uint4 index)
Definition slang-cuda-prelude.h:5954
__device__ void store(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4, T val)
Definition slang-cuda-prelude.h:6071
__device__ void store(uint4 index, T val)
Definition slang-cuda-prelude.h:6064
uint8_t * data
Definition slang-cuda-prelude.h:5921
__device__ void store(uint3 index, T val)
Definition slang-cuda-prelude.h:6052
__device__ T & load(uint index[N])
Definition slang-cuda-prelude.h:6021
__device__ T & load(uint32_t x)
Definition slang-cuda-prelude.h:5973
__device__ T * data_ptr_at(uint index[N])
Definition slang-cuda-prelude.h:5962
__device__ T & load(uint2 index)
Definition slang-cuda-prelude.h:5983
__device__ T * data_ptr_at(uint3 index)
Definition slang-cuda-prelude.h:5947
__device__ T & load(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3, uint32_t i4)
Definition slang-cuda-prelude.h:6012
__device__ void store(uint index[N], T val)
Definition slang-cuda-prelude.h:6080
__device__ void store(uint32_t x, uint32_t y, uint32_t z, T val)
Definition slang-cuda-prelude.h:6047
uint32_t dimensionCount
Definition slang-cuda-prelude.h:5924
__device__ T & load(uint3 index)
Definition slang-cuda-prelude.h:5993
Definition slang-cpp-types-core.h:57
size_t typeSize
Definition slang-cpp-types-core.h:58
Definition slang-cpp-types-core.h:103
Definition slang-cuda-prelude.h:3208
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3209
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3210
__inline__ static __device__ T doInverse(T a, T b)
Definition slang-cuda-prelude.h:3211
Definition slang-cuda-prelude.h:3193
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3194
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3195
Definition slang-cuda-prelude.h:4224
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:4225
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:4226
Definition slang-cuda-prelude.h:4217
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:4219
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:4218
Definition slang-cuda-prelude.h:3227
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3229
__inline__ static __device__ T getInitial(T a, bool exclusive=false)
Definition slang-cuda-prelude.h:3234
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3236
__inline__ static __device__ T getInitial(T a, bool exclusive=false)
Definition slang-cuda-prelude.h:3216
__inline__ static __device__ T doInverse(T a, T b)
Definition slang-cuda-prelude.h:3222
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3217
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3218
Definition slang-cuda-prelude.h:3186
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3188
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3187
Definition slang-cuda-prelude.h:3200
__inline__ static __device__ T doOp(T a, T b)
Definition slang-cuda-prelude.h:3202
__inline__ static __device__ T doInverse(T a, T b)
Definition slang-cuda-prelude.h:3203
__inline__ static __device__ T getInitial(T a)
Definition slang-cuda-prelude.h:3201
Definition slang-cpp-scalar-intrinsics.h:671
Definition slang-cpp-scalar-intrinsics.h:24
uint32_t u
Definition slang-cpp-scalar-intrinsics.h:25
int32_t i
Definition slang-cpp-scalar-intrinsics.h:26
float f
Definition slang-cpp-scalar-intrinsics.h:27
Definition slang-cpp-scalar-intrinsics.h:31
int64_t i
Definition slang-cpp-scalar-intrinsics.h:33
double d
Definition slang-cpp-scalar-intrinsics.h:34
uint64_t u
Definition slang-cpp-scalar-intrinsics.h:32