105 lines
4.1 KiB
C++
105 lines
4.1 KiB
C++
/* Copyright 2018 The OpenXLA Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
// Library for comparing literals without taking a dependency on testing
|
|
// libraries.
|
|
|
|
#ifndef XLA_LITERAL_COMPARISON_H_
|
|
#define XLA_LITERAL_COMPARISON_H_
|
|
|
|
#include <cstdint>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "xla/error_spec.h"
#include "xla/literal.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status.h"
|
|
|
|
namespace xla {
|
|
namespace literal_comparison {
|
|
|
|
// Returns ok if the given shapes have the same rank, dimension sizes, and
|
|
// primitive types.
|
|
Status EqualShapes(const Shape& expected, const Shape& actual);
|
|
|
|
// Returns ok if the given literals share identical dynamic shapes and
|
|
// dimension sizes.
|
|
Status EqualDynamicShapesAndDimensions(const LiteralSlice& expected,
|
|
const LiteralSlice& actual);
|
|
|
|
// Returns ok if the expected and actual literals are (bitwise) equal for all
|
|
// elements in the literal. Also, asserts that the rank, dimensions sizes, and
|
|
// primitive type are equal.
|
|
Status Equal(const LiteralSlice& expected, const LiteralSlice& actual);
|
|
|
|
// Structure that contains the distribution of absolute and relative errors,
// bucketized into five buckets: [0.0001, 0.001, 0.01, 0.1, 1].
// Useful to understand the distribution of errors and set the permissible
// error bounds in an ErrorSpec.
struct ErrorBuckets {
  // The vectors are taken by value and moved into the members. Callers passing
  // temporaries therefore avoid a copy entirely, and callers passing lvalues
  // pay one copy plus a cheap move. The second parameter is deliberately not
  // named `rel_error_buckets` so it does not shadow the member of that name.
  explicit ErrorBuckets(std::vector<int64_t> absolute_error_buckets = {},
                        std::vector<int64_t> relative_error_buckets = {})
      : abs_error_buckets(std::move(absolute_error_buckets)),
        rel_error_buckets(std::move(relative_error_buckets)) {}

  // Distribution of absolute errors over the buckets listed above.
  const std::vector<int64_t> abs_error_buckets;
  // Distribution of relative errors over the buckets listed above.
  const std::vector<int64_t> rel_error_buckets;
};

using MiscompareCallback = std::function<void(
|
|
const LiteralSlice& expected, const LiteralSlice& actual,
|
|
const LiteralSlice& mismatches, const ShapeIndex& shape_index,
|
|
const ErrorBuckets& error_buckets)>;
|
|
|
|
// Inspects whether the expected and actual literals are within the given error
|
|
// bound for all elements. Also, inspects whether the rank, dimensions sizes,
|
|
// and dimension bounds are equivalent.
|
|
//
|
|
// Tuples are matched recursively.
|
|
//
|
|
// When comparing tensors of non-floating-point type, this inspects for exact
|
|
// equality, ignoring the ErrorSpec.
|
|
//
|
|
// If the shape of the literals is neither a complex/floating-point tensor nor a
|
|
// tuple which contains a complex/floating-point tensor, Near() is equivalent to
|
|
// Equal(). We don't raise an error in this case, because we want to allow
|
|
// callers to call Near() even if they have no preconceptions about the shapes
|
|
// being compared.
|
|
//
|
|
// If detailed_message is true, then the error message in the assertion result
|
|
// will contain a more detailed breakdown of mismatches. By default, we display
|
|
// a detailed message only for "large" inputs.
|
|
//
|
|
// If miscompare_callback is nullptr, Near will return an error on the first
|
|
// detected mismatch.
|
|
Status Near(const LiteralSlice& expected, const LiteralSlice& actual,
|
|
const ErrorSpec& error, std::optional<bool> detailed_message,
|
|
const MiscompareCallback& miscompare_callback);
|
|
|
|
// Calling ToString on a literal with over 100 million elements takes around
|
|
// 3 minutes. The utility of printing a literal with >1000 elements is
|
|
// questionable, especially when writing the Literal proto to disk is orders
|
|
// of magnitude faster.
|
|
std::string ToStringTruncated(const LiteralSlice& literal);
|
|
|
|
} // namespace literal_comparison
|
|
} // namespace xla
|
|
|
|
#endif // XLA_LITERAL_COMPARISON_H_
|