Traktor/myenv/Lib/site-packages/torch/include/ATen/cuda/CUDAUtils.h
2024-05-26 05:12:46 +02:00

21 lines
416 B
C++

#pragma once
#include <ATen/cuda/CUDAContext.h>
namespace at::cuda {

/// Reports whether all tensors in `ts` reside on the currently selected
/// CUDA device.
///
/// @param ts  Tensors to inspect.
/// @return    true if `ts` is empty or every element's device() equals
///            Device(kCUDA, current_device()); false as soon as one differs.
inline bool check_device(ArrayRef<Tensor> ts) {
  // An empty list trivially passes; returning here also avoids querying
  // the current CUDA device when there is nothing to compare.
  if (ts.empty()) {
    return true;
  }

  const Device expected(kCUDA, current_device());
  for (const auto& tensor : ts) {
    if (tensor.device() != expected) {
      return false;  // first mismatch decides the result
    }
  }
  return true;
}

} // namespace at::cuda