slot 0.0.1
A real-time UI rendering framework
载入中...
搜索中...
未找到
slang-torch-prelude.h
浏览该文件的文档.
1// Prelude for PyTorch cpp binding.
2
3// clang-format off
4#include <torch/extension.h>
5// clang-format on
6
7#include <ATen/cuda/CUDAContext.h>
8#include <ATen/cuda/CUDAUtils.h>
9#include <stdexcept>
10#include <string>
11#include <vector>
12
13#ifdef SLANG_LLVM
14#include "slang-llvm.h"
15#else // SLANG_LLVM
16#if SLANG_GCC_FAMILY && __GNUC__ < 6
17#include <cmath>
18#define SLANG_PRELUDE_STD std::
19#else
20#include <math.h>
21#define SLANG_PRELUDE_STD
22#endif
23
24#include <assert.h>
25#include <stdint.h>
26#include <stdlib.h>
27#include <string.h>
28#endif // SLANG_LLVM
29
30#include "../source/core/slang-string.h"
31
// Attribute used to export prelude symbols from the compiled shared library.
#if defined(_MSC_VER)
#define SLANG_PRELUDE_SHARED_LIB_EXPORT __declspec(dllexport)
#else
#define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__((__visibility__("default")))
// Disabled alternative kept for reference:
// # define SLANG_PRELUDE_SHARED_LIB_EXPORT __attribute__ ((dllexport))
#endif

// C-linkage helpers: they expand to `extern "C"` forms only when compiled as C++,
// and to nothing when compiled as C.
#ifdef __cplusplus
#define SLANG_PRELUDE_EXTERN_C extern "C"
#define SLANG_PRELUDE_EXTERN_C_START \
    extern "C"                       \
    {
#define SLANG_PRELUDE_EXTERN_C_END }
#else
#define SLANG_PRELUDE_EXTERN_C
#define SLANG_PRELUDE_EXTERN_C_START
#define SLANG_PRELUDE_EXTERN_C_END
#endif

// Expands to nothing (prelude symbols presumably live in the global namespace).
#define SLANG_PRELUDE_NAMESPACE

// Fallback definitions: each is only defined here if nothing earlier defined it,
// so a value supplied before this header takes precedence.
#ifndef SLANG_NO_THROW
#define SLANG_NO_THROW
#endif
#ifndef SLANG_STDCALL
#define SLANG_STDCALL
#endif
#ifndef SLANG_MCALL
#define SLANG_MCALL SLANG_STDCALL
#endif
#ifndef SLANG_FORCE_INLINE
#define SLANG_FORCE_INLINE inline
#endif
68
69
// Maximum tensor rank a TensorView can describe; it is the fixed length of the
// `strides` and `sizes` arrays below.
static const int kSlangTorchTensorMaxDim = 5;

// Non-owning description of a tensor's memory (filled in by make_tensor_view()).
// NOTE: If you change this struct's layout, also update the hard-coded size/alignment
// in _createTypeLayout() in slang-type-layout.cpp.
struct TensorView
{
    uint8_t* data;                             // Raw pointer to the first element; not owned.
    uint32_t strides[kSlangTorchTensorMaxDim]; // Per-dimension stride, in bytes.
    uint32_t sizes[kSlangTorchTensorMaxDim];   // Per-dimension extent, in elements.
    uint32_t dimensionCount;                   // Number of valid entries in strides/sizes.
};
81
82
84 torch::Tensor val,
85 const char* name,
86 torch::ScalarType targetScalarType,
87 bool requireContiguous)
88{
89 // We're currently not trying to implicitly cast or transfer to device for two reasons:
90 // 1. There appears to be a bug with .to() where successive calls after the first one fail.
91 // 2. Silent casts like this can cause large memory allocations & unexpected overheads.
92 // It's better to be explicit.
93
94 // Expect tensors to be on CUDA device
95 if (!val.device().is_cuda())
96 throw std::runtime_error(
97 std::string(name).append(": tensor is not on CUDA device.").c_str());
98
99 // Expect tensors to be the right type.
100 if (val.dtype() != targetScalarType)
101 throw std::runtime_error(
102 std::string(name).append(": tensor is not of the expected type.").c_str());
103
104 // Check that the tensor is contiguous
105 if (requireContiguous && !val.is_contiguous())
106 throw std::runtime_error(std::string(name).append(": tensor is not contiguous.").c_str());
107
108 TensorView res = {};
109 res.dimensionCount = val.dim();
110 res.data = nullptr;
111 size_t elementSize = 4;
112
113 switch (val.scalar_type())
114 {
115 case torch::kInt8:
116 case torch::kUInt8:
117 elementSize = 1;
118 res.data = (uint8_t*)val.data_ptr<uint8_t>();
119 break;
120 case torch::kBFloat16:
121 elementSize = 2;
122 res.data = (uint8_t*)val.data_ptr<torch::BFloat16>();
123 break;
124 case torch::kFloat16:
125 elementSize = 2;
126 res.data = (uint8_t*)val.data_ptr<at::Half>();
127 break;
128 case torch::kInt16:
129 elementSize = 2;
130 res.data = (uint8_t*)val.data_ptr<int16_t>();
131 break;
132 case torch::kFloat32:
133 elementSize = 4;
134 res.data = (uint8_t*)val.data_ptr<float>();
135 break;
136 case torch::kInt32:
137 elementSize = 4;
138 res.data = (uint8_t*)val.data_ptr<int32_t>();
139 break;
140 case torch::kFloat64:
141 elementSize = 8;
142 res.data = (uint8_t*)val.data_ptr<double>();
143 break;
144 case torch::kInt64:
145 elementSize = 8;
146 res.data = (uint8_t*)val.data_ptr<int64_t>();
147 break;
148 case torch::kBool:
149 elementSize = 1;
150 res.data = (uint8_t*)val.data_ptr<bool>();
151 break;
152 }
153
154 if (val.dim() > kSlangTorchTensorMaxDim)
155 throw std::runtime_error(std::string(name)
156 .append(": number of dimensions exceeds limit (")
157 .append(std::to_string(kSlangTorchTensorMaxDim))
158 .append(")")
159 .c_str());
160
161 // A tensor can have zero elements even if some dimensions are non-zero
162 // (e.g. shape (10, 0)). Emptiness must be based on numel().
163 bool isEmpty = (val.numel() == 0);
164 for (int i = 0; i < val.dim(); ++i)
165 {
166 res.sizes[i] = val.size(i);
167 res.strides[i] = val.stride(i) * elementSize;
168 if (!isEmpty && res.strides[i] == 0)
169 throw std::runtime_error(
170 std::string(name)
171 .append(": tensors with broadcasted dimensions are not supported (use "
172 "tensor.contiguous() to make tensor whole)")
173 .c_str());
174 }
175
176 if (!res.data && !isEmpty)
177 throw std::runtime_error(std::string(name).append(": data pointer is invalid.").c_str());
178
179 return res;
180}
181
182#define SLANG_PRELUDE_EXPORT
TensorView make_tensor_view(torch::Tensor val, const char *name, torch::ScalarType targetScalarType, bool requireContiguous)
Definition slang-torch-prelude.h:83
static const int kSlangTorchTensorMaxDim
Definition slang-torch-prelude.h:70
Definition slang-cuda-prelude.h:5920
uint32_t strides[kSlangTorchTensorMaxDim]
Definition slang-cuda-prelude.h:5922
uint32_t sizes[kSlangTorchTensorMaxDim]
Definition slang-cuda-prelude.h:5923
uint8_t * data
Definition slang-cuda-prelude.h:5921
uint32_t dimensionCount
Definition slang-cuda-prelude.h:5924