forked from openkylin/astroid
212 lines
8.0 KiB
Python
212 lines
8.0 KiB
Python
# Copyright (c) 2019-2021 hippo91 <guillaume.peillex@gmail.com>
|
|
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
|
|
# Copyright (c) 2020 Claudiu Popa <pcmanticore@gmail.com>
|
|
# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
|
|
# Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
|
|
# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
|
|
|
|
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
|
|
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
|
|
import unittest
|
|
|
|
try:
|
|
import numpy # pylint: disable=unused-import
|
|
|
|
HAS_NUMPY = True
|
|
except ImportError:
|
|
HAS_NUMPY = False
|
|
|
|
from astroid import builder
|
|
|
|
|
|
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class BrainNumpyCoreMultiarrayTest(unittest.TestCase):
    """
    Test the numpy core multiarray brain module.

    Each ``numpy_functions_returning_*`` tuple lists entries of the form
    ``(function_name, *argument_snippets)``. Every argument snippet is a
    source string that is pasted verbatim into the call that gets inferred.
    """

    numpy_functions_returning_array = (
        ("array", "[1, 2]"),
        ("bincount", "[1, 2]"),
        ("busday_count", "('2011-01', '2011-02')"),
        ("busday_offset", "'2012-03', -1, roll='forward'"),
        ("concatenate", "([1, 2], [1, 2])"),
        ("datetime_as_string", "['2012-02', '2012-03']"),
        ("dot", "[1, 2]", "[1, 2]"),
        ("empty_like", "[1, 2]"),
        ("inner", "[1, 2]", "[1, 2]"),
        ("is_busday", "['2011-07-01', '2011-07-02', '2011-07-18']"),
        ("lexsort", "(('toto', 'tutu'), ('riri', 'fifi'))"),
        ("packbits", "np.array([1, 2])"),
        ("unpackbits", "np.array([[1], [2], [3]], dtype=np.uint8)"),
        ("vdot", "[1, 2]", "[1, 2]"),
        ("where", "[True, False]", "[1, 2]", "[2, 1]"),
        ("empty", "[1, 2]"),
        ("zeros", "[1, 2]"),
    )

    numpy_functions_returning_bool = (
        ("can_cast", "np.int32, np.int64"),
        ("may_share_memory", "np.array([1, 2])", "np.array([3, 4])"),
        ("shares_memory", "np.array([1, 2])", "np.array([3, 4])"),
    )

    numpy_functions_returning_dtype = (
        # TODO: enable these once the brain infers a ``numpy.dtype`` result.
        # ("min_scalar_type", "10"),
        # ("result_type", "'i4'", "'c8'"),
    )

    numpy_functions_returning_none = (("copyto", "([1, 2], [1, 3])"),)

    numpy_functions_returning_tuple = (
        (
            "unravel_index",
            "[22, 33, 44]",
            "(6, 7)",
        ),
    )

    @staticmethod
    def _infer_module_func_call(import_stmt, module_name, func_name, func_args):
        """Build ``<module_name>.<func_name>(*func_args)`` after *import_stmt* and infer it.

        :param import_stmt: statement that brings numpy into scope.
        :param module_name: name numpy is bound to by *import_stmt*.
        :param func_name: attribute of numpy that is called.
        :param func_args: source snippets joined with ``,`` as the call arguments.
        :returns: the inference generator of the extracted call node.
        """
        node = builder.extract_node(
            f"""
        {import_stmt}
        func = {module_name}.{func_name:s}
        func({','.join(func_args):s})
        """
        )
        return node.infer()

    def _inferred_numpy_func_call(self, func_name, *func_args):
        """Infer ``np.<func_name>(*func_args)`` with numpy imported as ``np``."""
        return self._infer_module_func_call(
            "import numpy as np", "np", func_name, func_args
        )

    def _inferred_numpy_no_alias_func_call(self, func_name, *func_args):
        """Infer ``numpy.<func_name>(*func_args)`` with numpy imported unaliased."""
        return self._infer_module_func_call(
            "import numpy", "numpy", func_name, func_args
        )

    def _assert_functions_inferred_as(self, functions, expected_pytype):
        """Check that every entry of *functions* infers to one value of *expected_pytype*.

        Each entry is checked both through the ``np`` alias and through the
        bare ``numpy`` import. Each check runs in its own sub-test, so one
        failing function does not hide the others.
        """
        if not functions:
            # An empty table would otherwise make the test pass vacuously.
            self.skipTest(f"No numpy function listed for pytype {expected_pytype!r}")
        for infer_wrapper in (
            self._inferred_numpy_func_call,
            self._inferred_numpy_no_alias_func_call,
        ):
            for func_ in functions:
                # Record the wrapper too, so aliased/unaliased failures are distinguishable.
                with self.subTest(typ=func_, wrapper=infer_wrapper.__name__):
                    inferred_values = list(infer_wrapper(*func_))
                    self.assertEqual(
                        len(inferred_values),
                        1,
                        msg=(
                            f"Expected exactly one inferred value for {func_[0]:s}, "
                            f"got {inferred_values}"
                        ),
                    )
                    self.assertEqual(
                        inferred_values[-1].pytype(),
                        expected_pytype,
                        msg=f"Illicit type for {func_[0]:s} ({inferred_values[-1].pytype()})",
                    )

    def test_numpy_function_calls_inferred_as_ndarray(self):
        """
        Test that calls to numpy functions are inferred as numpy.ndarray
        """
        self._assert_functions_inferred_as(
            self.numpy_functions_returning_array, ".ndarray"
        )

    def test_numpy_function_calls_inferred_as_bool(self):
        """
        Test that calls to numpy functions are inferred as bool
        """
        self._assert_functions_inferred_as(
            self.numpy_functions_returning_bool, "builtins.bool"
        )

    def test_numpy_function_calls_inferred_as_dtype(self):
        """
        Test that calls to numpy functions are inferred as numpy.dtype
        """
        self._assert_functions_inferred_as(
            self.numpy_functions_returning_dtype, "numpy.dtype"
        )

    def test_numpy_function_calls_inferred_as_none(self):
        """
        Test that calls to numpy functions are inferred as None
        """
        self._assert_functions_inferred_as(
            self.numpy_functions_returning_none, "builtins.NoneType"
        )

    def test_numpy_function_calls_inferred_as_tuple(self):
        """
        Test that calls to numpy functions are inferred as tuple
        """
        self._assert_functions_inferred_as(
            self.numpy_functions_returning_tuple, "builtins.tuple"
        )
|
|
|
|
|
|
# Allow running this test module directly as a script.
if "__main__" == __name__:
    unittest.main()
|