86 torch::ScalarType targetScalarType,
87 bool requireContiguous)
95 if (!val.device().is_cuda())
96 throw std::runtime_error(
97 std::string(name).append(
": tensor is not on CUDA device.").c_str());
100 if (val.dtype() != targetScalarType)
101 throw std::runtime_error(
102 std::string(name).append(
": tensor is not of the expected type.").c_str());
105 if (requireContiguous && !val.is_contiguous())
106 throw std::runtime_error(std::string(name).append(
": tensor is not contiguous.").c_str());
111 size_t elementSize = 4;
113 switch (val.scalar_type())
118 res.
data = (uint8_t*)val.data_ptr<uint8_t>();
120 case torch::kBFloat16:
122 res.
data = (uint8_t*)val.data_ptr<torch::BFloat16>();
124 case torch::kFloat16:
126 res.
data = (uint8_t*)val.data_ptr<at::Half>();
130 res.
data = (uint8_t*)val.data_ptr<int16_t>();
132 case torch::kFloat32:
134 res.
data = (uint8_t*)val.data_ptr<
float>();
138 res.
data = (uint8_t*)val.data_ptr<int32_t>();
140 case torch::kFloat64:
142 res.
data = (uint8_t*)val.data_ptr<
double>();
146 res.
data = (uint8_t*)val.data_ptr<int64_t>();
150 res.
data = (uint8_t*)val.data_ptr<
bool>();
155 throw std::runtime_error(std::string(name)
156 .append(
": number of dimensions exceeds limit (")
163 bool isEmpty = (val.numel() == 0);
164 for (
int i = 0; i < val.dim(); ++i)
166 res.
sizes[i] = val.size(i);
167 res.
strides[i] = val.stride(i) * elementSize;
168 if (!isEmpty && res.
strides[i] == 0)
169 throw std::runtime_error(
171 .append(
": tensors with broadcasted dimensions are not supported (use "
172 "tensor.contiguous() to make tensor whole)")
176 if (!res.
data && !isEmpty)
177 throw std::runtime_error(std::string(name).append(
": data pointer is invalid.").c_str());
TensorView make_tensor_view(torch::Tensor val, const char *name, torch::ScalarType targetScalarType, bool requireContiguous)
Definition slang-torch-prelude.h:83