Projekt_AI-Automatyczny_saper/venv/Lib/site-packages/caffe2/python/operator_test/given_tensor_fill_op_test.py
2021-06-01 17:38:31 +02:00

48 lines
1.5 KiB
Python

from caffe2.python import core
from hypothesis import given
import hypothesis.strategies as st
import caffe2.python.hypothesis_test_util as hu
import numpy as np
import unittest
class TestGivenTensorFillOps(hu.HypothesisTestCase):
    """Hypothesis tests for the GivenTensor*Fill family of operators.

    Each operator is created with no inputs, a ``shape``, a ``dtype`` and a
    ``values`` argument taken from a random tensor ``X``. The test checks that
    its single output ``Y`` matches ``X`` (via ``assertReferenceChecks``) and
    that results agree across devices (via ``assertDeviceChecks``).
    """

    @given(X=hu.tensor(min_dim=1, max_dim=4, dtype=np.int32),
           t=st.sampled_from([
               # (dtype argument, numpy dtype X is cast to, operator type)
               (core.DataType.BOOL, np.bool_, "GivenTensorFill"),
               (core.DataType.INT32, np.int32, "GivenTensorFill"),
               (core.DataType.FLOAT, np.float32, "GivenTensorFill"),
               (core.DataType.INT16, np.int16, "GivenTensorInt16Fill"),
               (core.DataType.INT32, np.int32, "GivenTensorIntFill"),
               (core.DataType.INT64, np.int64, "GivenTensorInt64Fill"),
               (core.DataType.BOOL, np.bool_, "GivenTensorBoolFill"),
               (core.DataType.DOUBLE, np.double, "GivenTensorDoubleFill"),
               # NOTE(review): dtype=INT32 paired with the double-typed op;
               # presumably checks that GivenTensorDoubleFill's output type
               # does not follow the `dtype` argument -- confirm intent.
               (core.DataType.INT32, np.double, "GivenTensorDoubleFill"),
           ]),
           **hu.gcs)
    def test_given_tensor_fill(self, X, t, gc, dc):
        """Fill ``Y`` from the values of ``X`` and check it reproduces ``X``.

        Args:
            X: random int32 tensor of rank 1-4 generated by hypothesis; it is
                cast to the case's numpy dtype before use.
            t: ``(dtype, np_dtype, op_type)`` triple drawn from the list above.
            gc: device option passed to ``assertReferenceChecks``
                (supplied by ``hu.gcs``).
            dc: device options passed to ``assertDeviceChecks``
                (supplied by ``hu.gcs``).
        """
        dtype, np_dtype, op_type = t
        X = X.astype(np_dtype)
        # No debug print here: on failure hypothesis already reports the
        # falsifying X and t, from which the cast tensor is determined.
        op = core.CreateOperator(
            op_type, [], ["Y"],
            shape=X.shape,
            dtype=dtype,
            # Values are passed as a single (1, N) row; the real dimensions
            # travel separately in `shape`.
            values=X.reshape((1, X.size)),
        )

        def constant_fill(*args, **kw):
            # Reference implementation: the op should output X unchanged.
            return [X]

        self.assertReferenceChecks(gc, op, [], constant_fill)
        # NOTE(review): [0] presumably selects output index 0 (Y) to compare.
        self.assertDeviceChecks(dc, op, [], [0])
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()