# Forked from openkylin/astroid (Python, 189 lines, 5.3 KiB).
# Copyright (c) 2017-2021 hippo91 <guillaume.peillex@gmail.com>
|
|
# Copyright (c) 2017-2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
|
|
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
|
|
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
|
|
# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
|
|
# Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com>
|
|
# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>
|
|
|
|
# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
|
|
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE
|
|
import unittest
|
|
|
|
# Record whether numpy can be imported; the tests below are skipped otherwise.
try:
    import numpy  # pylint: disable=unused-import
except ImportError:
    HAS_NUMPY = False
else:
    HAS_NUMPY = True
|
|
|
|
from astroid import builder, nodes
|
|
from astroid.brain.brain_numpy_utils import (
|
|
NUMPY_VERSION_TYPE_HINTS_SUPPORT,
|
|
numpy_supports_type_hints,
|
|
)
|
|
|
|
|
|
@unittest.skipUnless(HAS_NUMPY, "This test requires the numpy library.")
class NumpyBrainNdarrayTest(unittest.TestCase):
    """Test that calls to numpy functions returning arrays are correctly inferred."""

    # ``pytype()`` values accepted for an inferred ndarray instance. This must be
    # a tuple: membership in a plain string is a *substring* test, which would
    # also accept partial names such as "" or "ndarray".
    licit_array_types = (".ndarray",)

    # ndarray methods that, called without arguments, should infer as an ndarray.
    ndarray_returning_ndarray_methods = (
        "__abs__",
        "__add__",
        "__and__",
        "__array__",
        "__array_wrap__",
        "__copy__",
        "__deepcopy__",
        "__eq__",
        "__floordiv__",
        "__ge__",
        "__gt__",
        "__iadd__",
        "__iand__",
        "__ifloordiv__",
        "__ilshift__",
        "__imod__",
        "__imul__",
        "__invert__",
        "__ior__",
        "__ipow__",
        "__irshift__",
        "__isub__",
        "__itruediv__",
        "__ixor__",
        "__le__",
        "__lshift__",
        "__lt__",
        "__matmul__",
        "__mod__",
        "__mul__",
        "__ne__",
        "__neg__",
        "__or__",
        "__pos__",
        "__pow__",
        "__rshift__",
        "__sub__",
        "__truediv__",
        "__xor__",
        "all",
        "any",
        "argmax",
        "argmin",
        "argpartition",
        "argsort",
        "astype",
        "byteswap",
        "choose",
        "clip",
        "compress",
        "conj",
        "conjugate",
        "copy",
        "cumprod",
        "cumsum",
        "diagonal",
        "dot",
        "flatten",
        "getfield",
        "max",
        "mean",
        "min",
        "newbyteorder",
        "prod",
        "ptp",
        "ravel",
        "repeat",
        "reshape",
        "round",
        "searchsorted",
        "squeeze",
        "std",
        "sum",
        "swapaxes",
        "take",
        "trace",
        "transpose",
        "var",
        "view",
    )

    # ndarray attributes whose inferred value is checked to be an ndarray.
    ndarray_returning_ndarray_attributes = ("real", "imag", "shape", "T")

    @staticmethod
    def _infer_on_test_array(member):
        """Infer the expression ``test_array.<member>`` for a 2x2 ndarray.

        ``member`` is the source text placed after ``test_array.``: an
        attribute name, or a call such as ``"copy()"``. Returns the
        generator produced by ``node.infer()``.
        """
        node = builder.extract_node(
            f"""
        import numpy as np
        test_array = np.ndarray((2, 2))
        test_array.{member:s}
        """
        )
        return node.infer()

    def _inferred_ndarray_method_call(self, func_name):
        """Infer the result of calling ``func_name`` with no arguments on an ndarray."""
        return self._infer_on_test_array(f"{func_name:s}()")

    def _inferred_ndarray_attribute(self, attr_name):
        """Infer the value of attribute ``attr_name`` of an ndarray."""
        return self._infer_on_test_array(f"{attr_name:s}")

    def _assert_single_ndarray(self, name, inferred):
        """Assert that ``inferred`` yields exactly one value and that it is an ndarray.

        ``name`` is the method or attribute under test; it only appears in
        failure messages.
        """
        inferred_values = list(inferred)
        # The message must not claim "too many": it also fires when nothing is inferred.
        self.assertEqual(
            len(inferred_values),
            1,
            msg=f"Expected exactly one inferred value for {name:s}",
        )
        inferred_type = inferred_values[-1].pytype()
        self.assertIn(
            inferred_type,
            self.licit_array_types,
            msg=f"Illicit type for {name:s} ({inferred_type})",
        )

    def test_numpy_function_calls_inferred_as_ndarray(self):
        """Test that some calls to numpy functions are inferred as numpy.ndarray."""
        for func_ in self.ndarray_returning_ndarray_methods:
            with self.subTest(typ=func_):
                self._assert_single_ndarray(
                    func_, self._inferred_ndarray_method_call(func_)
                )

    def test_numpy_ndarray_attribute_inferred_as_ndarray(self):
        """Test that some numpy ndarray attributes are inferred as numpy.ndarray."""
        for attr_ in self.ndarray_returning_ndarray_attributes:
            with self.subTest(typ=attr_):
                self._assert_single_ndarray(
                    attr_, self._inferred_ndarray_attribute(attr_)
                )

    @unittest.skipUnless(
        HAS_NUMPY and numpy_supports_type_hints(),
        f"This test requires the numpy library with a version above {NUMPY_VERSION_TYPE_HINTS_SUPPORT}",
    )
    def test_numpy_ndarray_class_support_type_indexing(self):
        """Test that the numpy ndarray class can be subscripted (type hints)."""
        src = """
        import numpy as np
        np.ndarray[int]
        """
        node = builder.extract_node(src)
        cls_node = node.inferred()[0]
        self.assertIsInstance(cls_node, nodes.ClassDef)
        self.assertEqual(cls_node.name, "ndarray")
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|