/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Generally useful utility functions that are common to (not specific to any // given part of) the XLA code base. #ifndef XLA_UTIL_H_ #define XLA_UTIL_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "absl/algorithm/container.h" #include "absl/base/macros.h" #include "absl/base/thread_annotations.h" #include "absl/container/inlined_vector.h" #include "absl/memory/memory.h" #include "absl/numeric/bits.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/types/span.h" #include "Eigen/Core" // from @eigen_archive #include "xla/status.h" #include "xla/status_macros.h" #include "xla/xla_data.pb.h" #include "tsl/lib/math/math_util.h" #include "tsl/platform/bfloat16.h" #include "tsl/platform/casts.h" #include "tsl/platform/errors.h" // IWYU pragma: keep #include "tsl/platform/logging.h" #include "tsl/platform/ml_dtypes.h" namespace xla { // Converts the unsigned integer n into a mixed-radix representation with the // given bounds (radices). More precisely, if there are K radices, then the // returned vector digits has K entries and satisfies // // 0 <= digits[i] < bounds[i], for i = 0, ..., K - 1 // // and FromMixedRadix(digits) == n. The mixed radix representation is unique // modulo the product of the entries of bounds. std::vector ToMixedRadix(int64_t n, absl::Span bounds); // Logs the provided status message with a backtrace. // // For use by Status-factories, logs a backtrace at the point where the status // is created, such that we can use --vmodule=util=1 to see all status // creation backtraces. Status WithLogBacktrace(const Status& status); // Ranks greater than 6 are very rare, so use InlinedVector to store // the bounds and indices. And for the rare cases of ranks greater than 6, // the InlinedVector will just behave like an std::vector<> and allocate the // memory to store its values. inline constexpr int InlineRank() { return 6; } using DimensionVector = absl::InlinedVector; using DimLevelTypeVector = absl::InlinedVector; // RAII timer that logs with a given label the wall clock time duration in human // readable form. This differs from base's ElapsedTimer primarily in that it // spits out the human-readable duration form. // // Keeps track of global maximum and cumulative times across all invocations. // // By default, the timing traces are only printed at VLOG(1) and above: // // XLA_SCOPED_LOGGING_TIMER("fooing bar"); // nop if !VLOG_IS_ON(1). // // but you can control this via: // // XLA_SCOPED_LOGGING_TIMER_LEVEL("fooing bar", 2); // nop if !VLOG_IS_ON(2) // #define XLA_SCOPED_LOGGING_TIMER(label) \ XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__, /*condition=*/true) #define XLA_SCOPED_LOGGING_TIMER_LEVEL(label, level) \ XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, __COUNTER__, /*condition=*/true) // The timer trace is only printed if the condition is true. #define XLA_SCOPED_LOGGING_TIMER_IF(label, condition) \ XLA_SCOPED_LOGGING_TIMER_HELPER(label, 1, __COUNTER__, (condition)) // Helper for implementing macros above. Do not use directly. // // Forces the evaluation of "counter", which we expect is equal to __COUNTER__. #define XLA_SCOPED_LOGGING_TIMER_HELPER(label, level, counter, condition) \ XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter, (condition)) // Helper for macros above. Don't use directly. #define XLA_SCOPED_LOGGING_TIMER_HELPER2(label, level, counter, condition) \ static ::xla::TimerStats XLA_TimerStats##counter; \ ::xla::ScopedLoggingTimer XLA_ScopedLoggingTimerInstance##counter( \ label, /*enabled=*/VLOG_IS_ON(level) && (condition), __FILE__, __LINE__, \ &XLA_TimerStats##counter); struct TimerStats { absl::Mutex stats_mutex; double cumulative_secs ABSL_GUARDED_BY(stats_mutex) = 0; double max_secs ABSL_GUARDED_BY(stats_mutex) = 0; uint64_t times_called ABSL_GUARDED_BY(stats_mutex) = 0; }; // RAII timer for XLA_SCOPED_LOGGING_TIMER and XLA_SCOPED_LOGGING_TIMER_LEVEL // macros above. Recommended usage is via the macros so you don't have to give // the timer a name or worry about calling VLOG_IS_ON yourself. class ScopedLoggingTimer { public: // label: Label to display for logging. // enabled: Whether this timer should do anything at all. // file: Filename to display in logging. // line: Line number to display in logging. // `timer_stats`: unowned non-null pointer which is used to populate the // global timer statistics. ScopedLoggingTimer(absl::string_view label, bool enabled, const char* file, int line, TimerStats* timer_stats); // Stop the timer and log the tracked time. Timer is disabled after this // function is called. void StopAndLog(); ~ScopedLoggingTimer(); private: const std::string label_; const char* const file_; const int line_; TimerStats* const timer_stats_; uint64_t start_micros_; bool enabled_; }; // Turns an immutable slice of type T into an immutable slice of bytes with the // same byte size. template absl::Span CastToByteSlice(absl::Span slice) { return absl::Span( reinterpret_cast(slice.data()), slice.size() * sizeof(T)); } // Casts a byte slice to a non-byte type T, checking that the original slice // length is a multiple of sizeof(T). template absl::Span CastByteSlice(absl::Span slice) { CHECK_EQ(0, slice.size() % sizeof(T)); return absl::Span(reinterpret_cast(slice.data()), slice.size() / sizeof(T)); } // Compares two containers for equality. Returns true iff the two containers // have the same size and all their elements compare equal using their // operator==. Like std::equal, but forces size equality. template bool ContainersEqual(const Container1T& c1, std::initializer_list il) { absl::Span c2{il}; return absl::c_equal(c1, c2); } #if defined(__cpp_lib_to_underlying) && __cpp_lib_to_underlying >= 202102L using to_underlying = std::to_underlying; #else // Helper function which implements C++23's std::to_underlying. template constexpr std::underlying_type_t to_underlying(T value) noexcept { return static_cast>(value); } #endif // Performs a copy of count values from src to dest, using different strides for // source and destination. The source starting index is src_base, while the // destination one is dest_base. template void StridedCopy(D* dest, int64_t dest_stride, const S* src, int64_t src_stride, int64_t count) { const S* src_end = src + count * src_stride; DCHECK_LT(src, src_end); for (; src < src_end; dest += dest_stride, src += src_stride) { *dest = static_cast(*src); } } // Adds some context information to the error message in a // Status. This is useful as Statuses are // propagated upwards. Status AddStatus(Status prior, absl::string_view context); Status AppendStatus(Status prior, absl::string_view context); // This macro defines the arguments to be used as an error // message to be passed to absl::StrFormat, and returns a status in the // canonical error space. #define DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(error_type) \ template \ Status error_type(const absl::FormatSpec& format, \ const Args&... args) { \ return WithLogBacktrace( \ tsl::errors::error_type(absl::StrFormat(format, args...))); \ } DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(InvalidArgument); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(Unimplemented); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(Internal); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(FailedPrecondition); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(Cancelled); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(ResourceExhausted); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(NotFound); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(Unavailable); DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE(Unknown); #undef DEFINE_XLA_ERROR_WITH_STRFORMAT_WITH_BACKTRACE // The following three macros define a common set of code for creating // absl::Status errors with the given error_type, with the addition of adding // absl::SourceLocation if it's available (PLATFORM_GOOGLE). They're a // complicated by the need to use #ifdefs within the code. This would be the // equivalent code for ResourceExhausted if a #define macro could have embedded // #ifdef directives: // // template // struct ResourceExhaustedStrCat { // Status status; // #if defined(PLATFORM_GOOGLE) // // NOLINTNEXTLINE(google-explicit-constructor) // ResourceExhaustedStrCat(Args&&... concat, absl::SourceLocation loc = // absl::SourceLocation::current()) // : status(WithLogBacktrace( // tsl::errors::ResourceExhausted(std::forward(concat)...) // .WithSourceLocation(loc))) {} // #else // ResourceExhaustedStrCat(Args&&... concat) // : status(WithLogBacktrace( // tsl::errors::ResourceExhausted(std::forward(concat)...))) // {} // #endif // // // NOLINTNEXTLINE(google-explicit-constructor) // operator Status() const { return status; } // }; // #define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX(error_type) \ template \ struct error_type##StrCat { \ Status status; \ /* NOLINTNEXTLINE(google-explicit-constructor) */ #define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX(error_type) \ /* NOLINTNEXTLINE(google-explicit-constructor) */ \ operator Status() const { return status; } \ } \ ; \ /*Deduction guide to make variadic arguments play nice with default */ \ /* absl::SourceLocation argument. */ \ template \ error_type##StrCat(Args&&...)->error_type##StrCat; #if defined(PLATFORM_GOOGLE) #define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(error_type) \ XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX(error_type) \ error_type##StrCat(Args&&... concat, absl::SourceLocation loc = \ absl::SourceLocation::current()) \ : status(WithLogBacktrace( \ tsl::errors::error_type(std::forward(concat)...) \ .WithSourceLocation(loc))) {} \ XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX(error_type) #else #define XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(error_type) \ XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX(error_type) \ error_type##StrCat(Args&&... concat) \ : status(WithLogBacktrace( \ tsl::errors::error_type(std::forward(concat)...))) {} \ XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX(error_type) #endif XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(ResourceExhausted); XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(InvalidArgument); XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(Unimplemented); XLA_ERROR_WITH_STRCAT_AND_BACKTRACE(Internal); #undef XLA_ERROR_WITH_STRCAT_AND_BACKTRACE #undef XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_PREFIX #undef XLA_ERROR_WITH_STRCAT_AND_BACKTRACE_SUFFIX // Splits the lines of the original, replaces leading whitespace with the prefix // given by "indentation", and returns the string joined by newlines again. As a // side effect, any additional trailing whitespace is removed. // // Note: even different amounts of leading whitespace on different lines will be // uniformly replaced with "indentation". std::string Reindent(absl::string_view original, absl::string_view indentation); template int64_t PositionInContainer(const Container& container, int64_t value) { return std::distance(container.begin(), absl::c_find(container, value)); } // Formats the container as a comma-separated string. StrAppend must support // appending the elements of the container. Prefix is prepended and suffix is // appended to the returned string. template std::string CommaSeparatedString(const Container& c, const char* prefix = "", const char* suffix = "") { // Not using Join() since the implementation here is simple anyway and this // avoids copying the string to append prefix. std::string comma_separated = prefix; const char* separator = ""; for (const auto& entry : c) { absl::StrAppend(&comma_separated, separator, entry); separator = ", "; } comma_separated += suffix; return comma_separated; } // Overload needed to allow the container to be an initializer list. The default // type for T makes an empty initializer list work as well. template std::string CommaSeparatedString(const std::initializer_list& c, const char* prefix = "", const char* suffix = "") { return CommaSeparatedString>(c, prefix, suffix); } // Formats the container in the mathematical notation for a vector, e.g. (1, 3, // 7). StrAppend must support appending the elements of c. template std::string VectorString(const Container& c) { return CommaSeparatedString(c, "(", ")"); } // Overload needed to allow the container to be an initializer list. The default // type for T makes an empty initializer list work as well. template std::string VectorString(const std::initializer_list& c) { return VectorString>(c); } // Returns a string which can losslessly round trip to a float8 E5M2. std::string RoundTripFpToString(tsl::float8_e5m2 value); // Returns a string which can losslessly round trip to a float8 E4M3. std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. std::string RoundTripFpToString(tsl::float8_e4m3b11 value); // Returns a string which can losslessly round trip to a float8 E5M2FNUZ. std::string RoundTripFpToString(tsl::float8_e5m2fnuz value); // Returns a string which can losslessly round trip to a float8 E4M3FNUZ. std::string RoundTripFpToString(tsl::float8_e4m3fnuz value); // Returns a string which can losslessly round trip to a bfloat. std::string RoundTripFpToString(tsl::bfloat16 value); // Returns a string which can losslessly round trip to a fp16. std::string RoundTripFpToString(Eigen::half value); // Returns a string which can losslessly round trip to a float. std::string RoundTripFpToString(float value); // Returns a string which can losslessly round trip to a double. std::string RoundTripFpToString(double value); // Returns a PaddingConfig object that represents no padding for the given rank. PaddingConfig MakeNoPaddingConfig(int64_t rank); // Returns a PaddingConfig object where 'padding' contains // (low edge padding, high edge padding) pairs for each dimension. PaddingConfig MakeEdgePaddingConfig( absl::Span> padding); // Returns true if the padding configuration has at least one dimension with // non-zero interior padding. bool HasInteriorPadding(const PaddingConfig& config); // Imports the templated FloorOfRatio math function from the TensorFlow // namespace, as it is very commonly used. template constexpr T FloorOfRatio(T dividend, T divisor) { return tsl::MathUtil::FloorOfRatio(dividend, divisor); } // Imports the templated CeilOfRatio math function from the TensorFlow // namespace, as it is very commonly used. template constexpr T CeilOfRatio(T dividend, T divisor) { return tsl::MathUtil::CeilOfRatio(dividend, divisor); } // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio // then multiplying by the divisor. For example: RoundUpTo(13, 8) => 16 template constexpr T RoundUpTo(T value, T divisor) { return CeilOfRatio(value, divisor) * divisor; } // Rounds the value down to a multiple of the divisor by first calling // FloorOfRatio then multiplying by the divisor. For example: // RoundDownTo(13, 8) => 8 template constexpr T RoundDownTo(T value, T divisor) { return FloorOfRatio(value, divisor) * divisor; } template struct DivMod { T quotient; T modulo; }; // Divide `dividend` by `divisor` such that the quotient is rounded towards // negative infinity. The remainder will have the same sign as `divisor`. template constexpr DivMod FloorDivMod(T dividend, T divisor) { DivMod div_mod; div_mod.quotient = FloorOfRatio(dividend, divisor); div_mod.modulo = dividend - div_mod.quotient * divisor; return div_mod; } // Given a number of flops executed in an amount of time, produces a string that // represents the throughput; // e.g. HumanReadableNumFlops(1e9, 1e9) => 1.00GFLOP/s. std::string HumanReadableNumFlops(double flops, double nanoseconds); // Given a number of transcendental ops executed in an amount of time, produces // a string that represents the throughput; // e.g. HumanReadableNumTranscendentalOps(1e9, 1e9) => 1.00GTROP/s. std::string HumanReadableNumTranscendentalOps(double trops, double nanoseconds); // Split the text into multiple lines and log each line with the given // severity, filename, and line number. void LogLines(int sev, absl::string_view text, const char* fname, int lineno); // Returns a mask with "width" number of least significant bits set. template constexpr inline T LsbMask(int width) { static_assert(std::is_unsigned::value, "T should be an unsigned integer type"); ABSL_ASSERT(width >= 0); ABSL_ASSERT(width <= std::numeric_limits::digits); return width == 0 ? 0 : static_cast(-1) >> (std::numeric_limits::digits - width); } // Return floor(log2(n)) for positive integer n. Returns -1 iff n == 0. template constexpr inline int Log2Floor(T x) { static_assert(std::is_unsigned::value, "T should be an unsigned integer type"); return absl::bit_width(x) - 1; } // Return ceiling(log2(n)) for positive integer n. Returns -1 iff n == 0. template constexpr inline int Log2Ceiling(T x) { static_assert(std::is_unsigned::value, "T should be an unsigned integer type"); return x == 0 ? -1 : absl::bit_width(x - 1); } // Return the number of sign bits (i.e. the number of leading ones for negative // numbers and the number of leading zeros for non-negative numbers). template constexpr inline int CountLeadingSignBits(T x) { static_assert(std::is_signed::value, "T should be a signed integer type"); using UnsignedType = std::make_unsigned_t; return x < T{0} ? absl::countl_one(x) : absl::countl_zero(x); } // Returns `value` with the low `width` bits set and the remaining bits set to // zero. template constexpr inline T KeepLowerBits(T value, int width) { return value & LsbMask(width); } // Returns `base` multiplied by itself `exponent` number of times. // // Note: returns 1 when `exponent` is zero. // Precondition: `exponent` is non-negative for integral `T`. template constexpr T IPow(T base, ExpType exponent) { static_assert(std::numeric_limits::is_integer); if constexpr (std::numeric_limits::is_integer) { // A negative `exponent` is indicative of a logic bug for integral `base`. // We disallow it for floating-point types for symmetry. ABSL_ASSERT(exponent >= 0); } const bool take_reciprocal = exponent < 0; // We use the right-to-left binary exponentiation algorithm. T result(1); for (;;) { if ((exponent & 1) != 0) { result *= base; } exponent /= 2; if (exponent == 0) { break; } base *= base; } if constexpr (std::numeric_limits::is_signed) { if (take_reciprocal) { return T(1) / result; } } return result; } // UnsignedIntegerTypeForSize gets an unsigned integer with the given size in // bytes. template struct UnsignedIntegerTypeForSize; template <> struct UnsignedIntegerTypeForSize<1> { using type = uint8_t; }; template <> struct UnsignedIntegerTypeForSize<2> { using type = uint16_t; }; template <> struct UnsignedIntegerTypeForSize<4> { using type = uint32_t; }; template <> struct UnsignedIntegerTypeForSize<8> { using type = uint64_t; }; template using UnsignedIntegerTypeForSizeType = typename UnsignedIntegerTypeForSize::type; template using SignedIntegerTypeForSizeType = std::make_signed_t>; template auto SignAndMagnitude(T x) { using BitType = UnsignedIntegerTypeForSizeType; BitType x_abs_bits = Eigen::numext::bit_cast(Eigen::numext::abs(x)); const BitType x_bits = Eigen::numext::bit_cast(x); const BitType x_sign = x_bits ^ x_abs_bits; if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { // f8e4m3b11, f8e4m3fnuz, and f8e5m2fnuz don't support -0, adjust negative // numbers to fill in the gap. if (x_sign) { x_abs_bits -= 1; } } return std::make_pair(x_sign, x_abs_bits); } template auto SignAndMagnitudeToTwosComplement(T sign, T magnitude) { static_assert(!std::numeric_limits::is_signed); using SignedType = std::make_signed_t; return static_cast(magnitude) ^ (static_cast(sign) < 0 ? SignedType{-1} : SignedType{0}); } // Returns the signed magnitude of T. template auto ToSignMagnitude(T input) { auto [sign, magnitude] = SignAndMagnitude(input); return SignAndMagnitudeToTwosComplement(sign, magnitude); } template constexpr int NanPayloadBits() { // Floating point types with signaling NaNs have payloads. if constexpr (!std::numeric_limits::has_signaling_NaN) { return 0; } return std::numeric_limits::digits - 1; } template constexpr uint64_t QuietNanWithoutPayload() { constexpr int bits = NanPayloadBits(); if constexpr (bits > 0) { return uint64_t{1} << (bits - 1); } return 0; } template constexpr uint64_t NanPayloadBitMask() { constexpr int bits = NanPayloadBits(); if constexpr (bits > 0) { return LsbMask(bits); } return 0; } template T NanWithSignAndPayload(bool sign, uint64_t nan_payload) { static_assert(NanPayloadBits() > 0); using RepT = UnsignedIntegerTypeForSizeType; // Clear the sign bit. T val = Eigen::numext::abs(std::numeric_limits::quiet_NaN()); // Conditionally set the sign bit. if (sign) { val = -val; } auto rep = absl::bit_cast(val); rep |= uint64_t{sign} << (std::numeric_limits::digits - 1); constexpr int kPayloadBits = NanPayloadBits(); if (kPayloadBits > 0) { // Clear rep's NaN payload. rep &= ~NanPayloadBitMask(); CHECK_NE(nan_payload, 0); rep |= nan_payload; } return absl::bit_cast(rep); } // Utility for performing a down_cast<> on a std::unique_ptr<>. template std::unique_ptr unique_ptr_down_cast(std::unique_ptr ptr) { return absl::WrapUnique(tensorflow::down_cast(ptr.release())); } int64_t Product(absl::Span xs); // Returns the start indices of consecutive non-overlapping subsequences of `a` // and `b` with the same product, i.e. `(i, j)` so // • a = {a[0 = i_0], ..., a[i_1 - 1], a[i_1], ... , a[i_2 - 1], ...} // • b = {b[0 = j_0], ..., b[j_1 - 1], b[j_1], ... , b[j_2 - 1], ...} // • ∀ k . 0 <= k < CommonFactors(a, b).size - 1 => // a[i_k] × a[i_k + 1] × ... × a[i_(k+1) - 1] = // b[j_k] × b[j_k + 1] × ... × b[j_(k+1) - 1] // where `CommonFactors(a, b)[CommonFactors(a, b).size - 1] = (a.size, b.size)` // // If input and output are the same, return {(0, 0), {1, 1}, ... {a.size, // b.size}}, otherwise if the given shapes have non-zero size, returns the // bounds of the shortest possible such subsequences; else, returns `{(0, 0), // (a.size, b.size)}`. absl::InlinedVector, 8> CommonFactors( absl::Span a, absl::Span b); struct ConvertedDimensionNumbers { DimensionVector transformed_from_dimensions; DimensionVector untransformed_from_dimensions; DimensionVector to_dimensions; DimensionVector split_from_dimensions; DimensionVector split_from_sizes; DimensionVector split_to_dimensions; }; // Convert and unsorted list of dimensions from one shapes dimension sizes to // another shapes dimensions sizes. ConvertedDimensionNumbers ConvertDimensionNumbers( absl::Span from_dimensions, absl::Span from_sizes, absl::Span to_sizes); // Removes illegal characters from filenames. std::string SanitizeFileName(std::string file_name); // Check that a sequence of distinct numbers can form a continuous interval. bool DistinctNumbersAreConsecutiveIfSorted(absl::Span); template int64_t FindIndex(const C& c, Value&& value) { auto it = absl::c_find(c, std::forward(value)); return std::distance(c.begin(), it); } template void InsertAt(C* c, int64_t index, Value&& value) { c->insert(c->begin() + index, std::forward(value)); } template void EraseAt(C* c, int64_t index) { c->erase(c->begin() + index); } template std::vector SpanToVector(absl::Span slice) { return std::vector(slice.begin(), slice.end()); } template std::vector InlinedVectorToVector( const absl::InlinedVector& inlined_vector) { return std::vector(inlined_vector.begin(), inlined_vector.end()); } // Returns true if `x` fits in 32-bits. template bool IsInt32(T x) { // Following conversion rules: "the value is unchanged if it can be // represented in the destination type (and bit-field width); otherwise, the // value is implementation-defined." return static_cast(x) == x; } template Status EraseElementFromVector(std::vector* container, const T& value) { // absl::c_find returns a const_iterator which does not seem to work on // gcc 4.8.4, and this breaks the ubuntu/xla_gpu build bot. auto it = std::find(container->begin(), container->end(), value); TF_RET_CHECK(it != container->end()); container->erase(it); return OkStatus(); } // Takes a sequence of unpacked int4 values, such that every byte stores one // int4 value in the low-order four bits, and packs them so every byte stores // two int4 values. 'input' should have num_elements bytes; 'output' should have // (num_elements+1)/2 bytes. The high-order four bits of each byte in 'input' // are ignored. void PackInt4(absl::Span input, absl::Span output); // Takes a sequence of packed int4 values, such that every byte stores two // int4 values, and unpacks them so every byte stores one int4 value in the // low-order four bits. 'input' should have (num_elements+1)/2 bytes; 'output' // should have num_elements bytes. The high-order 4-bits in each output are // zero. void UnpackInt4(absl::Span input, absl::Span output); class HloInstruction; class HloModule; // A predicate over HLO instruction. using HloPredicate = std::function; using HloModulePredicate = std::function; inline bool HloPredicateTrue(const HloInstruction*) { return true; } inline bool HloPredicateFalse(const HloInstruction*) { return false; } using Vector2 = std::array; using Vector3 = std::array; } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ ::xla::LogLines(SEV, STRING, __FILE__, __LINE__) #define XLA_VLOG_LINES(LEVEL, STRING) \ do { \ if (VLOG_IS_ON(LEVEL)) XLA_LOG_LINES(::tsl::INFO, STRING); \ } while (false); // Utility macro that performs the equivalent of what one would expect // LOG_LINES(FATAL, X) to do but can be used at the end of a function that // returns a value without getting a compiler warning that no value is returned. #define XLA_FATAL_LOG(X) \ XLA_LOG_LINES(::tsl::ERROR, X); \ LOG(FATAL) << "Aborting in " << __FUNCTION__ << " due to previous errors."; #endif // XLA_UTIL_H_