# Copyright (c) 2006-2007, 2009-2014 LOGILAB S.A. (Paris, FRANCE) # Copyright (c) 2012 FELD Boris # Copyright (c) 2013-2021 Claudiu Popa # Copyright (c) 2014 Google, Inc. # Copyright (c) 2014 Eevee (Alex Munroe) # Copyright (c) 2015-2016 Ceridwen # Copyright (c) 2015 Florian Bruhin # Copyright (c) 2016 Jakub Wilk # Copyright (c) 2017 rr- # Copyright (c) 2017 Derek Gustafson # Copyright (c) 2018 Serhiy Storchaka # Copyright (c) 2018 brendanator # Copyright (c) 2018 Bryce Guinta # Copyright (c) 2018 Anthony Sottile # Copyright (c) 2019-2021 Ashley Whetter # Copyright (c) 2019 Alex Hall # Copyright (c) 2019 Hugo van Kemenade # Copyright (c) 2020 David Gilman # Copyright (c) 2021 Pierre Sassoulas # Copyright (c) 2021 Tushar Sadhwani <86737547+tushar-deepsource@users.noreply.github.com> # Copyright (c) 2021 Nick Drozd # Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com> # Copyright (c) 2021 Daniël van Noord <13665637+DanielNoord@users.noreply.github.com> # Copyright (c) 2021 René Fritze <47802+renefritze@users.noreply.github.com> # Copyright (c) 2021 Federico Bond # Copyright (c) 2021 hippo91 # Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html # For details: https://github.com/PyCQA/astroid/blob/main/LICENSE """tests for specific behaviour of astroid nodes """ import copy import os import platform import sys import textwrap import unittest from typing import Any, Optional import pytest import astroid from astroid import ( Uninferable, bases, builder, nodes, parse, test_utils, transforms, util, ) from astroid.const import PY38_PLUS, PY310_PLUS, Context from astroid.context import InferenceContext from astroid.exceptions import ( AstroidBuildingError, AstroidSyntaxError, AttributeInferenceError, ParentMissingError, StatementMissing, ) from astroid.nodes.node_classes import ( AssignAttr, AssignName, Attribute, Call, ImportFrom, Tuple, ) from astroid.nodes.scoped_nodes import ClassDef, FunctionDef, GeneratorExp, Module from . import resources abuilder = builder.AstroidBuilder() try: import typed_ast # pylint: disable=unused-import HAS_TYPED_AST = True except ImportError: # typed_ast merged in `ast` in Python 3.8 HAS_TYPED_AST = PY38_PLUS class AsStringTest(resources.SysPathSetup, unittest.TestCase): def test_tuple_as_string(self) -> None: def build(string: str) -> Tuple: return abuilder.string_build(string).body[0].value self.assertEqual(build("1,").as_string(), "(1, )") self.assertEqual(build("1, 2, 3").as_string(), "(1, 2, 3)") self.assertEqual(build("(1, )").as_string(), "(1, )") self.assertEqual(build("1, 2, 3").as_string(), "(1, 2, 3)") def test_func_signature_issue_185(self) -> None: code = textwrap.dedent( """ def test(a, b, c=42, *, x=42, **kwargs): print(a, b, c, args) """ ) node = parse(code) self.assertEqual(node.as_string().strip(), code.strip()) def test_as_string_for_list_containing_uninferable(self) -> None: node = builder.extract_node( """ def foo(): bar = [arg] * 1 """ ) binop = node.body[0].value inferred = next(binop.infer()) self.assertEqual(inferred.as_string(), "[Uninferable]") self.assertEqual(binop.as_string(), "[arg] * 1") def test_frozenset_as_string(self) -> None: ast_nodes = builder.extract_node( """ frozenset((1, 2, 3)) #@ frozenset({1, 2, 3}) #@ frozenset([1, 2, 3,]) #@ frozenset(None) #@ frozenset(1) #@ """ ) ast_nodes = [next(node.infer()) for node in ast_nodes] assert isinstance(ast_nodes, list) self.assertEqual(ast_nodes[0].as_string(), "frozenset((1, 2, 3))") self.assertEqual(ast_nodes[1].as_string(), "frozenset({1, 2, 3})") self.assertEqual(ast_nodes[2].as_string(), "frozenset([1, 2, 3])") self.assertNotEqual(ast_nodes[3].as_string(), "frozenset(None)") self.assertNotEqual(ast_nodes[4].as_string(), "frozenset(1)") def test_varargs_kwargs_as_string(self) -> None: ast = abuilder.string_build("raise_string(*args, **kwargs)").body[0] self.assertEqual(ast.as_string(), "raise_string(*args, **kwargs)") def test_module_as_string(self) -> None: """check as_string on a whole module prepared to be returned identically""" module = resources.build_file("data/module.py", "data.module") with open(resources.find("data/module.py"), encoding="utf-8") as fobj: self.assertMultiLineEqual(module.as_string(), fobj.read()) def test_module2_as_string(self) -> None: """check as_string on a whole module prepared to be returned identically""" module2 = resources.build_file("data/module2.py", "data.module2") with open(resources.find("data/module2.py"), encoding="utf-8") as fobj: self.assertMultiLineEqual(module2.as_string(), fobj.read()) def test_as_string(self) -> None: """check as_string for python syntax >= 2.7""" code = """one_two = {1, 2} b = {v: k for (k, v) in enumerate('string')} cdd = {k for k in b}\n\n""" ast = abuilder.string_build(code) self.assertMultiLineEqual(ast.as_string(), code) def test_3k_as_string(self) -> None: """check as_string for python 3k syntax""" code = """print() def function(var): nonlocal counter try: hello except NameError as nexc: (*hell, o) = b'hello' raise AttributeError from nexc \n""" ast = abuilder.string_build(code) self.assertEqual(ast.as_string(), code) def test_3k_annotations_and_metaclass(self) -> None: code = ''' def function(var: int): nonlocal counter class Language(metaclass=Natural): """natural language""" ''' code_annotations = textwrap.dedent(code) expected = '''\ def function(var: int): nonlocal counter class Language(metaclass=Natural): """natural language"""''' ast = abuilder.string_build(code_annotations) self.assertEqual(ast.as_string().strip(), expected) def test_ellipsis(self) -> None: ast = abuilder.string_build("a[...]").body[0] self.assertEqual(ast.as_string(), "a[...]") def test_slices(self) -> None: for code in ( "a[0]", "a[1:3]", "a[:-1:step]", "a[:, newaxis]", "a[newaxis, :]", "del L[::2]", "del A[1]", "del Br[:]", ): ast = abuilder.string_build(code).body[0] self.assertEqual(ast.as_string(), code) def test_slice_and_subscripts(self) -> None: code = """a[:1] = bord[2:] a[:1] = bord[2:] del bree[3:d] bord[2:] del av[d::f], a[df:] a[:1] = bord[2:] del SRC[::1, newaxis, 1:] tous[vals] = 1010 del thousand[key] del a[::2], a[:-1:step] del Fee.form[left:] aout.vals = miles.of_stuff del (ccok, (name.thing, foo.attrib.value)), Fee.form[left:] if all[1] == bord[0:]: pass\n\n""" ast = abuilder.string_build(code) self.assertEqual(ast.as_string(), code) def test_int_attribute(self) -> None: code = """ x = (-3).real y = (3).imag """ ast = abuilder.string_build(code) self.assertEqual(ast.as_string().strip(), code.strip()) def test_operator_precedence(self) -> None: with open(resources.find("data/operator_precedence.py"), encoding="utf-8") as f: for code in f: self.check_as_string_ast_equality(code) @staticmethod def check_as_string_ast_equality(code: str) -> None: """ Check that as_string produces source code with exactly the same semantics as the source it was originally parsed from """ pre = builder.parse(code) post = builder.parse(pre.as_string()) pre_repr = pre.repr_tree() post_repr = post.repr_tree() assert pre_repr == post_repr assert pre.as_string().strip() == code.strip() def test_class_def(self) -> None: code = """ import abc from typing import Tuple class A: pass class B(metaclass=A, x=1): pass class C(B): pass class D(metaclass=abc.ABCMeta): pass def func(param: Tuple): pass """ ast = abuilder.string_build(code) self.assertEqual(ast.as_string().strip(), code.strip()) # This test is disabled on PyPy because we cannot get a release that has proper # support for f-strings (we need 7.2 at least) @pytest.mark.skipif( platform.python_implementation() == "PyPy", reason="Needs f-string support.", ) def test_f_strings(self): code = r''' a = f"{'a'}" b = f'{{b}}' c = f""" "{'c'}" """ d = f'{d!r} {d!s} {d!a}' e = f'{e:.3}' f = f'{f:{x}.{y}}' n = f'\n' everything = f""" " \' \r \t \\ {{ }} {'x' + x!r:a} {["'"]!s:{a}}""" ''' ast = abuilder.string_build(code) self.assertEqual(ast.as_string().strip(), code.strip()) @staticmethod def test_as_string_unknown() -> None: assert nodes.Unknown().as_string() == "Unknown.Unknown()" assert nodes.Unknown(lineno=1, col_offset=0).as_string() == "Unknown.Unknown()" class _NodeTest(unittest.TestCase): """test transformation of If Node""" CODE = "" @property def astroid(self) -> Module: try: return self.__class__.__dict__["CODE_Astroid"] except KeyError: module = builder.parse(self.CODE) self.__class__.CODE_Astroid = module return module class IfNodeTest(_NodeTest): """test transformation of If Node""" CODE = """ if 0: print() if True: print() else: pass if "": print() elif []: raise if 1: print() elif True: print() elif func(): pass else: raise """ def test_if_elif_else_node(self) -> None: """test transformation for If node""" self.assertEqual(len(self.astroid.body), 4) for stmt in self.astroid.body: self.assertIsInstance(stmt, nodes.If) self.assertFalse(self.astroid.body[0].orelse) # simple If self.assertIsInstance(self.astroid.body[1].orelse[0], nodes.Pass) # If / else self.assertIsInstance(self.astroid.body[2].orelse[0], nodes.If) # If / elif self.assertIsInstance(self.astroid.body[3].orelse[0].orelse[0], nodes.If) def test_block_range(self) -> None: # XXX ensure expected values self.assertEqual(self.astroid.block_range(1), (0, 22)) self.assertEqual(self.astroid.block_range(10), (0, 22)) # XXX (10, 22) ? self.assertEqual(self.astroid.body[1].block_range(5), (5, 6)) self.assertEqual(self.astroid.body[1].block_range(6), (6, 6)) self.assertEqual(self.astroid.body[1].orelse[0].block_range(7), (7, 8)) self.assertEqual(self.astroid.body[1].orelse[0].block_range(8), (8, 8)) @staticmethod @pytest.mark.filterwarnings("ignore:.*is_sys_guard:DeprecationWarning") def test_if_sys_guard() -> None: code = builder.extract_node( """ import sys if sys.version_info > (3, 8): #@ pass if sys.version_info[:2] > (3, 8): #@ pass if sys.some_other_function > (3, 8): #@ pass """ ) assert isinstance(code, list) and len(code) == 3 assert isinstance(code[0], nodes.If) assert code[0].is_sys_guard() is True assert isinstance(code[1], nodes.If) assert code[1].is_sys_guard() is True assert isinstance(code[2], nodes.If) assert code[2].is_sys_guard() is False @staticmethod @pytest.mark.filterwarnings("ignore:.*is_typing_guard:DeprecationWarning") def test_if_typing_guard() -> None: code = builder.extract_node( """ import typing import typing as t from typing import TYPE_CHECKING if typing.TYPE_CHECKING: #@ pass if t.TYPE_CHECKING: #@ pass if TYPE_CHECKING: #@ pass if typing.SOME_OTHER_CONST: #@ pass """ ) assert isinstance(code, list) and len(code) == 4 assert isinstance(code[0], nodes.If) assert code[0].is_typing_guard() is True assert isinstance(code[1], nodes.If) assert code[1].is_typing_guard() is True assert isinstance(code[2], nodes.If) assert code[2].is_typing_guard() is True assert isinstance(code[3], nodes.If) assert code[3].is_typing_guard() is False class TryExceptNodeTest(_NodeTest): CODE = """ try: print ('pouet') except IOError: pass except UnicodeError: print() else: print() """ def test_block_range(self) -> None: # XXX ensure expected values self.assertEqual(self.astroid.body[0].block_range(1), (1, 8)) self.assertEqual(self.astroid.body[0].block_range(2), (2, 2)) self.assertEqual(self.astroid.body[0].block_range(3), (3, 8)) self.assertEqual(self.astroid.body[0].block_range(4), (4, 4)) self.assertEqual(self.astroid.body[0].block_range(5), (5, 5)) self.assertEqual(self.astroid.body[0].block_range(6), (6, 6)) self.assertEqual(self.astroid.body[0].block_range(7), (7, 7)) self.assertEqual(self.astroid.body[0].block_range(8), (8, 8)) class TryFinallyNodeTest(_NodeTest): CODE = """ try: print ('pouet') finally: print ('pouet') """ def test_block_range(self) -> None: # XXX ensure expected values self.assertEqual(self.astroid.body[0].block_range(1), (1, 4)) self.assertEqual(self.astroid.body[0].block_range(2), (2, 2)) self.assertEqual(self.astroid.body[0].block_range(3), (3, 4)) self.assertEqual(self.astroid.body[0].block_range(4), (4, 4)) class TryExceptFinallyNodeTest(_NodeTest): CODE = """ try: print('pouet') except Exception: print ('oops') finally: print ('pouet') """ def test_block_range(self) -> None: # XXX ensure expected values self.assertEqual(self.astroid.body[0].block_range(1), (1, 6)) self.assertEqual(self.astroid.body[0].block_range(2), (2, 2)) self.assertEqual(self.astroid.body[0].block_range(3), (3, 4)) self.assertEqual(self.astroid.body[0].block_range(4), (4, 4)) self.assertEqual(self.astroid.body[0].block_range(5), (5, 5)) self.assertEqual(self.astroid.body[0].block_range(6), (6, 6)) class ImportNodeTest(resources.SysPathSetup, unittest.TestCase): def setUp(self) -> None: super().setUp() self.module = resources.build_file("data/module.py", "data.module") self.module2 = resources.build_file("data/module2.py", "data.module2") def test_import_self_resolve(self) -> None: myos = next(self.module2.igetattr("myos")) self.assertTrue(isinstance(myos, nodes.Module), myos) self.assertEqual(myos.name, "os") self.assertEqual(myos.qname(), "os") self.assertEqual(myos.pytype(), "builtins.module") def test_from_self_resolve(self) -> None: namenode = next(self.module.igetattr("NameNode")) self.assertTrue(isinstance(namenode, nodes.ClassDef), namenode) self.assertEqual(namenode.root().name, "astroid.nodes.node_classes") self.assertEqual(namenode.qname(), "astroid.nodes.node_classes.Name") self.assertEqual(namenode.pytype(), "builtins.type") abspath = next(self.module2.igetattr("abspath")) self.assertTrue(isinstance(abspath, nodes.FunctionDef), abspath) self.assertEqual(abspath.root().name, "os.path") self.assertEqual(abspath.pytype(), "builtins.function") if sys.platform != "win32": # Not sure what is causing this check to fail on Windows. # For some reason the abspath() inference returns a different # path than expected: # AssertionError: 'os.path._abspath_fallback' != 'os.path.abspath' self.assertEqual(abspath.qname(), "os.path.abspath") def test_real_name(self) -> None: from_ = self.module["NameNode"] self.assertEqual(from_.real_name("NameNode"), "Name") imp_ = self.module["os"] self.assertEqual(imp_.real_name("os"), "os") self.assertRaises(AttributeInferenceError, imp_.real_name, "os.path") imp_ = self.module["NameNode"] self.assertEqual(imp_.real_name("NameNode"), "Name") self.assertRaises(AttributeInferenceError, imp_.real_name, "Name") imp_ = self.module2["YO"] self.assertEqual(imp_.real_name("YO"), "YO") self.assertRaises(AttributeInferenceError, imp_.real_name, "data") def test_as_string(self) -> None: ast = self.module["modutils"] self.assertEqual(ast.as_string(), "from astroid import modutils") ast = self.module["NameNode"] self.assertEqual( ast.as_string(), "from astroid.nodes.node_classes import Name as NameNode" ) ast = self.module["os"] self.assertEqual(ast.as_string(), "import os.path") code = """from . import here from .. import door from .store import bread from ..cave import wine\n\n""" ast = abuilder.string_build(code) self.assertMultiLineEqual(ast.as_string(), code) def test_bad_import_inference(self) -> None: # Explication of bug """When we import PickleError from nonexistent, a call to the infer method of this From node will be made by unpack_infer. inference.infer_from will try to import this module, which will fail and raise a InferenceException (by mixins.do_import_module). The infer_name will catch this exception and yield and Uninferable instead. """ code = """ try: from pickle import PickleError except ImportError: from nonexistent import PickleError try: pass except PickleError: pass """ module = builder.parse(code) handler_type = module.body[1].handlers[0].type excs = list(nodes.unpack_infer(handler_type)) # The number of returned object can differ on Python 2 # and Python 3. In one version, an additional item will # be returned, from the _pickle module, which is not # present in the other version. self.assertIsInstance(excs[0], nodes.ClassDef) self.assertEqual(excs[0].name, "PickleError") self.assertIs(excs[-1], util.Uninferable) def test_absolute_import(self) -> None: module = resources.build_file("data/absimport.py") ctx = InferenceContext() # will fail if absolute import failed ctx.lookupname = "message" next(module["message"].infer(ctx)) ctx.lookupname = "email" m = next(module["email"].infer(ctx)) self.assertFalse(m.file.startswith(os.path.join("data", "email.py"))) def test_more_absolute_import(self) -> None: module = resources.build_file("data/module1abs/__init__.py", "data.module1abs") self.assertIn("sys", module.locals) _pickle_names = ("dump",) # "dumps", "load", "loads") def test_conditional(self) -> None: module = resources.build_file("data/conditional_import/__init__.py") ctx = InferenceContext() for name in self._pickle_names: ctx.lookupname = name some = list(module[name].infer(ctx)) assert Uninferable not in some, name def test_conditional_import(self) -> None: module = resources.build_file("data/conditional.py") ctx = InferenceContext() for name in self._pickle_names: ctx.lookupname = name some = list(module[name].infer(ctx)) assert Uninferable not in some, name class CmpNodeTest(unittest.TestCase): def test_as_string(self) -> None: ast = abuilder.string_build("a == 2").body[0] self.assertEqual(ast.as_string(), "a == 2") class ConstNodeTest(unittest.TestCase): def _test(self, value: Any) -> None: node = nodes.const_factory(value) self.assertIsInstance(node._proxied, nodes.ClassDef) self.assertEqual(node._proxied.name, value.__class__.__name__) self.assertIs(node.value, value) self.assertTrue(node._proxied.parent) self.assertEqual(node._proxied.root().name, value.__class__.__module__) with self.assertRaises(AttributeError): with pytest.warns(DeprecationWarning) as records: node.statement() assert len(records) == 1 with self.assertRaises(StatementMissing): node.statement(future=True) with self.assertRaises(AttributeError): with pytest.warns(DeprecationWarning) as records: node.frame() assert len(records) == 1 with self.assertRaises(ParentMissingError): node.frame(future=True) def test_none(self) -> None: self._test(None) def test_bool(self) -> None: self._test(True) def test_int(self) -> None: self._test(1) def test_float(self) -> None: self._test(1.0) def test_complex(self) -> None: self._test(1.0j) def test_str(self) -> None: self._test("a") def test_unicode(self) -> None: self._test("a") @pytest.mark.skipif( not PY38_PLUS, reason="kind attribute for ast.Constant was added in 3.8" ) def test_str_kind(self): node = builder.extract_node( """ const = u"foo" """ ) assert isinstance(node.value, nodes.Const) assert node.value.value == "foo" assert node.value.kind, "u" def test_copy(self) -> None: """ Make sure copying a Const object doesn't result in infinite recursion """ const = copy.copy(nodes.Const(1)) assert const.value == 1 class NameNodeTest(unittest.TestCase): def test_assign_to_true(self) -> None: """Test that True and False assignments don't crash""" code = """ True = False def hello(False): pass del True """ with self.assertRaises(AstroidBuildingError): builder.parse(code) @pytest.mark.skipif(not PY38_PLUS, reason="needs assignment expressions") class TestNamedExprNode: """Tests for the NamedExpr node""" @staticmethod def test_frame() -> None: """Test if the frame of NamedExpr is correctly set for certain types of parent nodes. """ module = builder.parse( """ def func(var_1): pass def func_two(var_2, var_2 = (named_expr_1 := "walrus")): pass class MyBaseClass: pass class MyInheritedClass(MyBaseClass, var_3=(named_expr_2 := "walrus")): pass VAR = lambda y = (named_expr_3 := "walrus"): print(y) def func_with_lambda( var_5 = ( named_expr_4 := lambda y = (named_expr_5 := "walrus"): y ) ): pass COMPREHENSION = [y for i in (1, 2) if (y := i ** 2)] """ ) function = module.body[0] assert function.args.frame() == function assert function.args.frame(future=True) == function function_two = module.body[1] assert function_two.args.args[0].frame() == function_two assert function_two.args.args[0].frame(future=True) == function_two assert function_two.args.args[1].frame() == function_two assert function_two.args.args[1].frame(future=True) == function_two assert function_two.args.defaults[0].frame() == module assert function_two.args.defaults[0].frame(future=True) == module inherited_class = module.body[3] assert inherited_class.keywords[0].frame() == inherited_class assert inherited_class.keywords[0].frame(future=True) == inherited_class assert inherited_class.keywords[0].value.frame() == module assert inherited_class.keywords[0].value.frame(future=True) == module lambda_assignment = module.body[4].value assert lambda_assignment.args.args[0].frame() == lambda_assignment assert lambda_assignment.args.args[0].frame(future=True) == lambda_assignment assert lambda_assignment.args.defaults[0].frame() == module assert lambda_assignment.args.defaults[0].frame(future=True) == module lambda_named_expr = module.body[5].args.defaults[0] assert lambda_named_expr.value.args.defaults[0].frame() == module assert lambda_named_expr.value.args.defaults[0].frame(future=True) == module comprehension = module.body[6].value assert comprehension.generators[0].ifs[0].frame() == module assert comprehension.generators[0].ifs[0].frame(future=True) == module @staticmethod def test_scope() -> None: """Test if the scope of NamedExpr is correctly set for certain types of parent nodes. """ module = builder.parse( """ def func(var_1): pass def func_two(var_2, var_2 = (named_expr_1 := "walrus")): pass class MyBaseClass: pass class MyInheritedClass(MyBaseClass, var_3=(named_expr_2 := "walrus")): pass VAR = lambda y = (named_expr_3 := "walrus"): print(y) def func_with_lambda( var_5 = ( named_expr_4 := lambda y = (named_expr_5 := "walrus"): y ) ): pass COMPREHENSION = [y for i in (1, 2) if (y := i ** 2)] """ ) function = module.body[0] assert function.args.scope() == function function_two = module.body[1] assert function_two.args.args[0].scope() == function_two assert function_two.args.args[1].scope() == function_two assert function_two.args.defaults[0].scope() == module inherited_class = module.body[3] assert inherited_class.keywords[0].scope() == inherited_class assert inherited_class.keywords[0].value.scope() == module lambda_assignment = module.body[4].value assert lambda_assignment.args.args[0].scope() == lambda_assignment assert lambda_assignment.args.defaults[0].scope() lambda_named_expr = module.body[5].args.defaults[0] assert lambda_named_expr.value.args.defaults[0].scope() == module comprehension = module.body[6].value assert comprehension.generators[0].ifs[0].scope() == module class AnnAssignNodeTest(unittest.TestCase): def test_primitive(self) -> None: code = textwrap.dedent( """ test: int = 5 """ ) assign = builder.extract_node(code) self.assertIsInstance(assign, nodes.AnnAssign) self.assertEqual(assign.target.name, "test") self.assertEqual(assign.annotation.name, "int") self.assertEqual(assign.value.value, 5) self.assertEqual(assign.simple, 1) def test_primitive_without_initial_value(self) -> None: code = textwrap.dedent( """ test: str """ ) assign = builder.extract_node(code) self.assertIsInstance(assign, nodes.AnnAssign) self.assertEqual(assign.target.name, "test") self.assertEqual(assign.annotation.name, "str") self.assertEqual(assign.value, None) def test_complex(self) -> None: code = textwrap.dedent( """ test: Dict[List[str]] = {} """ ) assign = builder.extract_node(code) self.assertIsInstance(assign, nodes.AnnAssign) self.assertEqual(assign.target.name, "test") self.assertIsInstance(assign.annotation, astroid.Subscript) self.assertIsInstance(assign.value, astroid.Dict) def test_as_string(self) -> None: code = textwrap.dedent( """ print() test: int = 5 test2: str test3: List[Dict[str, str]] = [] """ ) ast = abuilder.string_build(code) self.assertEqual(ast.as_string().strip(), code.strip()) class ArgumentsNodeTC(unittest.TestCase): @pytest.mark.skip( "FIXME http://bugs.python.org/issue10445 (no line number on function args)" ) def test_linenumbering(self) -> None: ast = builder.parse( """ def func(a, b): pass x = lambda x: None """ ) self.assertEqual(ast["func"].args.fromlineno, 2) self.assertFalse(ast["func"].args.is_statement) xlambda = next(ast["x"].infer()) self.assertEqual(xlambda.args.fromlineno, 4) self.assertEqual(xlambda.args.tolineno, 4) self.assertFalse(xlambda.args.is_statement) def test_kwoargs(self) -> None: ast = builder.parse( """ def func(*, x): pass """ ) args = ast["func"].args self.assertTrue(args.is_argument("x")) @test_utils.require_version(minver="3.8") def test_positional_only(self): ast = builder.parse( """ def func(x, /, y): pass """ ) args = ast["func"].args self.assertTrue(args.is_argument("x")) self.assertTrue(args.is_argument("y")) index, node = args.find_argname("x") self.assertEqual(index, 0) self.assertIsNotNone(node) class UnboundMethodNodeTest(unittest.TestCase): def test_no_super_getattr(self) -> None: # This is a test for issue # https://bitbucket.org/logilab/astroid/issue/91, which tests # that UnboundMethod doesn't call super when doing .getattr. ast = builder.parse( """ class A(object): def test(self): pass meth = A.test """ ) node = next(ast["meth"].infer()) with self.assertRaises(AttributeInferenceError): node.getattr("__missssing__") name = node.getattr("__name__")[0] self.assertIsInstance(name, nodes.Const) self.assertEqual(name.value, "test") class BoundMethodNodeTest(unittest.TestCase): def test_is_property(self) -> None: ast = builder.parse( """ import abc def cached_property(): # Not a real decorator, but we don't care pass def reify(): # Same as cached_property pass def lazy_property(): pass def lazyproperty(): pass def lazy(): pass class A(object): @property def builtin_property(self): return 42 @abc.abstractproperty def abc_property(self): return 42 @cached_property def cached_property(self): return 42 @reify def reified(self): return 42 @lazy_property def lazy_prop(self): return 42 @lazyproperty def lazyprop(self): return 42 def not_prop(self): pass @lazy def decorated_with_lazy(self): return 42 cls = A() builtin_property = cls.builtin_property abc_property = cls.abc_property cached_p = cls.cached_property reified = cls.reified not_prop = cls.not_prop lazy_prop = cls.lazy_prop lazyprop = cls.lazyprop decorated_with_lazy = cls.decorated_with_lazy """ ) for prop in ( "builtin_property", "abc_property", "cached_p", "reified", "lazy_prop", "lazyprop", "decorated_with_lazy", ): inferred = next(ast[prop].infer()) self.assertIsInstance(inferred, nodes.Const, prop) self.assertEqual(inferred.value, 42, prop) inferred = next(ast["not_prop"].infer()) self.assertIsInstance(inferred, bases.BoundMethod) class AliasesTest(unittest.TestCase): def setUp(self) -> None: self.transformer = transforms.TransformVisitor() def parse_transform(self, code: str) -> Module: module = parse(code, apply_transforms=False) return self.transformer.visit(module) def test_aliases(self) -> None: def test_from(node: ImportFrom) -> ImportFrom: node.names = node.names + [("absolute_import", None)] return node def test_class(node: ClassDef) -> ClassDef: node.name = "Bar" return node def test_function(node: FunctionDef) -> FunctionDef: node.name = "another_test" return node def test_callfunc(node: Call) -> Optional[Call]: if node.func.name == "Foo": node.func.name = "Bar" return node return None def test_assname(node: AssignName) -> Optional[AssignName]: if node.name == "foo": return nodes.AssignName( "bar", node.lineno, node.col_offset, node.parent ) return None def test_assattr(node: AssignAttr) -> AssignAttr: if node.attrname == "a": node.attrname = "b" return node return None def test_getattr(node: Attribute) -> Attribute: if node.attrname == "a": node.attrname = "b" return node return None def test_genexpr(node: GeneratorExp) -> GeneratorExp: if node.elt.value == 1: node.elt = nodes.Const(2, node.lineno, node.col_offset, node.parent) return node return None self.transformer.register_transform(nodes.ImportFrom, test_from) self.transformer.register_transform(nodes.ClassDef, test_class) self.transformer.register_transform(nodes.FunctionDef, test_function) self.transformer.register_transform(nodes.Call, test_callfunc) self.transformer.register_transform(nodes.AssignName, test_assname) self.transformer.register_transform(nodes.AssignAttr, test_assattr) self.transformer.register_transform(nodes.Attribute, test_getattr) self.transformer.register_transform(nodes.GeneratorExp, test_genexpr) string = """ from __future__ import print_function class Foo: pass def test(a): return a foo = Foo() foo.a = test(42) foo.a (1 for _ in range(0, 42)) """ module = self.parse_transform(string) self.assertEqual(len(module.body[0].names), 2) self.assertIsInstance(module.body[0], nodes.ImportFrom) self.assertEqual(module.body[1].name, "Bar") self.assertIsInstance(module.body[1], nodes.ClassDef) self.assertEqual(module.body[2].name, "another_test") self.assertIsInstance(module.body[2], nodes.FunctionDef) self.assertEqual(module.body[3].targets[0].name, "bar") self.assertIsInstance(module.body[3].targets[0], nodes.AssignName) self.assertEqual(module.body[3].value.func.name, "Bar") self.assertIsInstance(module.body[3].value, nodes.Call) self.assertEqual(module.body[4].targets[0].attrname, "b") self.assertIsInstance(module.body[4].targets[0], nodes.AssignAttr) self.assertIsInstance(module.body[5], nodes.Expr) self.assertEqual(module.body[5].value.attrname, "b") self.assertIsInstance(module.body[5].value, nodes.Attribute) self.assertEqual(module.body[6].value.elt.value, 2) self.assertIsInstance(module.body[6].value, nodes.GeneratorExp) class Python35AsyncTest(unittest.TestCase): def test_async_await_keywords(self) -> None: async_def, async_for, async_with, await_node = builder.extract_node( """ async def func(): #@ async for i in range(10): #@ f = __(await i) async with test(): #@ pass """ ) self.assertIsInstance(async_def, nodes.AsyncFunctionDef) self.assertIsInstance(async_for, nodes.AsyncFor) self.assertIsInstance(async_with, nodes.AsyncWith) self.assertIsInstance(await_node, nodes.Await) self.assertIsInstance(await_node.value, nodes.Name) def _test_await_async_as_string(self, code: str) -> None: ast_node = parse(code) self.assertEqual(ast_node.as_string().strip(), code.strip()) def test_await_as_string(self) -> None: code = textwrap.dedent( """ async def function(): await 42 await x[0] (await x)[0] await (x + y)[0] """ ) self._test_await_async_as_string(code) def test_asyncwith_as_string(self) -> None: code = textwrap.dedent( """ async def function(): async with 42: pass """ ) self._test_await_async_as_string(code) def test_asyncfor_as_string(self) -> None: code = textwrap.dedent( """ async def function(): async for i in range(10): await 42 """ ) self._test_await_async_as_string(code) def test_decorated_async_def_as_string(self) -> None: code = textwrap.dedent( """ @decorator async def function(): async for i in range(10): await 42 """ ) self._test_await_async_as_string(code) class ContextTest(unittest.TestCase): def test_subscript_load(self) -> None: node = builder.extract_node("f[1]") self.assertIs(node.ctx, Context.Load) def test_subscript_del(self) -> None: node = builder.extract_node("del f[1]") self.assertIs(node.targets[0].ctx, Context.Del) def test_subscript_store(self) -> None: node = builder.extract_node("f[1] = 2") subscript = node.targets[0] self.assertIs(subscript.ctx, Context.Store) def test_list_load(self) -> None: node = builder.extract_node("[]") self.assertIs(node.ctx, Context.Load) def test_list_del(self) -> None: node = builder.extract_node("del []") self.assertIs(node.targets[0].ctx, Context.Del) def test_list_store(self) -> None: with self.assertRaises(AstroidSyntaxError): builder.extract_node("[0] = 2") def test_tuple_load(self) -> None: node = builder.extract_node("(1, )") self.assertIs(node.ctx, Context.Load) def test_tuple_store(self) -> None: with self.assertRaises(AstroidSyntaxError): builder.extract_node("(1, ) = 3") def test_starred_load(self) -> None: node = builder.extract_node("a = *b") starred = node.value self.assertIs(starred.ctx, Context.Load) def test_starred_store(self) -> None: node = builder.extract_node("a, *b = 1, 2") starred = node.targets[0].elts[1] self.assertIs(starred.ctx, Context.Store) def test_unknown() -> None: """Test Unknown node""" assert isinstance(next(nodes.Unknown().infer()), type(util.Uninferable)) assert isinstance(nodes.Unknown().name, str) assert isinstance(nodes.Unknown().qname(), str) @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_comments_with() -> None: module = builder.parse( """ with a as b: # type: int pass with a as b: # type: ignore pass """ ) node = module.body[0] ignored_node = module.body[1] assert isinstance(node.type_annotation, astroid.Name) assert ignored_node.type_annotation is None @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_comments_for() -> None: module = builder.parse( """ for a, b in [1, 2, 3]: # type: List[int] pass for a, b in [1, 2, 3]: # type: ignore pass """ ) node = module.body[0] ignored_node = module.body[1] assert isinstance(node.type_annotation, astroid.Subscript) assert node.type_annotation.as_string() == "List[int]" assert ignored_node.type_annotation is None @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_coments_assign() -> None: module = builder.parse( """ a, b = [1, 2, 3] # type: List[int] a, b = [1, 2, 3] # type: ignore """ ) node = module.body[0] ignored_node = module.body[1] assert isinstance(node.type_annotation, astroid.Subscript) assert node.type_annotation.as_string() == "List[int]" assert ignored_node.type_annotation is None @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_comments_invalid_expression() -> None: module = builder.parse( """ a, b = [1, 2, 3] # type: something completely invalid a, b = [1, 2, 3] # typeee: 2*+4 a, b = [1, 2, 3] # type: List[int """ ) for node in module.body: assert node.type_annotation is None @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_comments_invalid_function_comments() -> None: module = builder.parse( """ def func(): # type: something completely invalid pass def func1(): # typeee: 2*+4 pass def func2(): # type: List[int pass """ ) for node in module.body: assert node.type_comment_returns is None assert node.type_comment_args is None @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_comments_function() -> None: module = builder.parse( """ def func(): # type: (int) -> str pass def func1(): # type: (int, int, int) -> (str, str) pass def func2(): # type: (int, int, str, List[int]) -> List[int] pass """ ) expected_annotations = [ (["int"], astroid.Name, "str"), (["int", "int", "int"], astroid.Tuple, "(str, str)"), (["int", "int", "str", "List[int]"], astroid.Subscript, "List[int]"), ] for node, (expected_args, expected_returns_type, expected_returns_string) in zip( module.body, expected_annotations ): assert node.type_comment_returns is not None assert node.type_comment_args is not None for expected_arg, actual_arg in zip(expected_args, node.type_comment_args): assert actual_arg.as_string() == expected_arg assert isinstance(node.type_comment_returns, expected_returns_type) assert node.type_comment_returns.as_string() == expected_returns_string @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_type_comments_arguments() -> None: module = builder.parse( """ def func( a, # type: int ): # type: (...) -> str pass def func1( a, # type: int b, # type: int c, # type: int ): # type: (...) -> (str, str) pass def func2( a, # type: int b, # type: int c, # type: str d, # type: List[int] ): # type: (...) -> List[int] pass """ ) expected_annotations = [ ["int"], ["int", "int", "int"], ["int", "int", "str", "List[int]"], ] for node, expected_args in zip(module.body, expected_annotations): assert len(node.type_comment_args) == 1 assert isinstance(node.type_comment_args[0], astroid.Const) assert node.type_comment_args[0].value == Ellipsis assert len(node.args.type_comment_args) == len(expected_args) for expected_arg, actual_arg in zip(expected_args, node.args.type_comment_args): assert actual_arg.as_string() == expected_arg @pytest.mark.skipif( not PY38_PLUS, reason="needs to be able to parse positional only arguments" ) def test_type_comments_posonly_arguments() -> None: module = builder.parse( """ def f_arg_comment( a, # type: int b, # type: int /, c, # type: Optional[int] d, # type: Optional[int] *, e, # type: float f, # type: float ): # type: (...) -> None pass """ ) expected_annotations = [ [["int", "int"], ["Optional[int]", "Optional[int]"], ["float", "float"]] ] for node, expected_types in zip(module.body, expected_annotations): assert len(node.type_comment_args) == 1 assert isinstance(node.type_comment_args[0], astroid.Const) assert node.type_comment_args[0].value == Ellipsis type_comments = [ node.args.type_comment_posonlyargs, node.args.type_comment_args, node.args.type_comment_kwonlyargs, ] for expected_args, actual_args in zip(expected_types, type_comments): assert len(expected_args) == len(actual_args) for expected_arg, actual_arg in zip(expected_args, actual_args): assert actual_arg.as_string() == expected_arg @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_correct_function_type_comment_parent() -> None: data = """ def f(a): # type: (A) -> A pass """ parsed_data = builder.parse(data) f = parsed_data.body[0] assert f.type_comment_args[0].parent is f assert f.type_comment_returns.parent is f def test_is_generator_for_yield_assignments() -> None: node = astroid.extract_node( """ class A: def test(self): a = yield while True: print(a) yield a a = A() a.test """ ) inferred = next(node.infer()) assert isinstance(inferred, astroid.BoundMethod) assert bool(inferred.is_generator()) class AsyncGeneratorTest: def test_async_generator(self): node = astroid.extract_node( """ async def a_iter(n): for i in range(1, n + 1): yield i await asyncio.sleep(1) a_iter(2) #@ """ ) inferred = next(node.infer()) assert isinstance(inferred, bases.AsyncGenerator) assert inferred.getattr("__aiter__") assert inferred.getattr("__anext__") assert inferred.pytype() == "builtins.async_generator" assert inferred.display_type() == "AsyncGenerator" def test_async_generator_is_generator_on_older_python(self): node = astroid.extract_node( """ async def a_iter(n): for i in range(1, n + 1): yield i await asyncio.sleep(1) a_iter(2) #@ """ ) inferred = next(node.infer()) assert isinstance(inferred, bases.Generator) assert inferred.getattr("__iter__") assert inferred.getattr("__next__") assert inferred.pytype() == "builtins.generator" assert inferred.display_type() == "Generator" def test_f_string_correct_line_numbering() -> None: """Test that we generate correct line numbers for f-strings""" node = astroid.extract_node( """ def func_foo(arg_bar, arg_foo): dict_foo = {} f'{arg_bar.attr_bar}' #@ """ ) assert node.lineno == 5 assert node.last_child().lineno == 5 assert node.last_child().last_child().lineno == 5 @pytest.mark.skipif(not PY38_PLUS, reason="needs assignment expressions") def test_assignment_expression() -> None: code = """ if __(a := 1): pass if __(b := test): pass """ first, second = astroid.extract_node(code) assert isinstance(first.target, nodes.AssignName) assert first.target.name == "a" assert isinstance(first.value, nodes.Const) assert first.value.value == 1 assert first.as_string() == "a := 1" assert isinstance(second.target, nodes.AssignName) assert second.target.name == "b" assert isinstance(second.value, nodes.Name) assert second.value.name == "test" assert second.as_string() == "b := test" @pytest.mark.skipif(not PY38_PLUS, reason="needs assignment expressions") def test_assignment_expression_in_functiondef() -> None: code = """ def function(param = (assignment := "walrus")): def inner_function(inner_param = (inner_assign := "walrus")): pass pass class MyClass(attr = (assignment_two := "walrus")): pass VAR = lambda y = (assignment_three := "walrus"): print(y) def func_with_lambda( param=(named_expr_four := lambda y=(assignment_four := "walrus"): y), ): pass COMPREHENSION = [y for i in (1, 2) if (assignment_five := i ** 2)] def func(): var = lambda y = (assignment_six := 2): print(y) VAR_TWO = [ func(assignment_seven := 2) for _ in (1,) ] LAMBDA = lambda x: print(assignment_eight := x ** 2) class SomeClass: (assignment_nine := 2**2) """ module = astroid.parse(code) assert "assignment" in module.locals assert isinstance(module.locals.get("assignment")[0], nodes.AssignName) function = module.body[0] assert "inner_assign" in function.locals assert "inner_assign" not in module.locals assert isinstance(function.locals.get("inner_assign")[0], nodes.AssignName) assert "assignment_two" in module.locals assert isinstance(module.locals.get("assignment_two")[0], nodes.AssignName) assert "assignment_three" in module.locals assert isinstance(module.locals.get("assignment_three")[0], nodes.AssignName) assert "assignment_four" in module.locals assert isinstance(module.locals.get("assignment_four")[0], nodes.AssignName) assert "assignment_five" in module.locals assert isinstance(module.locals.get("assignment_five")[0], nodes.AssignName) func = module.body[5] assert "assignment_six" in func.locals assert "assignment_six" not in module.locals assert isinstance(func.locals.get("assignment_six")[0], nodes.AssignName) assert "assignment_seven" in module.locals assert isinstance(module.locals.get("assignment_seven")[0], nodes.AssignName) lambda_assign = module.body[7] assert "assignment_eight" in lambda_assign.value.locals assert "assignment_eight" not in module.locals assert isinstance( lambda_assign.value.locals.get("assignment_eight")[0], nodes.AssignName ) class_assign = module.body[8] assert "assignment_nine" in class_assign.locals assert "assignment_nine" not in module.locals assert isinstance(class_assign.locals.get("assignment_nine")[0], nodes.AssignName) def test_get_doc() -> None: node = astroid.extract_node( """ def func(): "Docstring" return 1 """ ) assert node.doc == "Docstring" node = astroid.extract_node( """ def func(): ... return 1 """ ) assert node.doc is None @test_utils.require_version(minver="3.8") def test_parse_fstring_debug_mode() -> None: node = astroid.extract_node('f"{3=}"') assert isinstance(node, nodes.JoinedStr) assert node.as_string() == "f'3={3!r}'" @pytest.mark.skipif(not HAS_TYPED_AST, reason="requires typed_ast") def test_parse_type_comments_with_proper_parent() -> None: code = """ class D: #@ @staticmethod def g( x # type: np.array ): pass """ node = astroid.extract_node(code) func = node.getattr("g")[0] type_comments = func.args.type_comment_args assert len(type_comments) == 1 type_comment = type_comments[0] assert isinstance(type_comment, astroid.Attribute) assert isinstance(type_comment.parent, astroid.Expr) assert isinstance(type_comment.parent.parent, astroid.Arguments) def test_const_itered() -> None: code = 'a = "string"' node = astroid.extract_node(code).value assert isinstance(node, astroid.Const) itered = node.itered() assert len(itered) == 6 assert [elem.value for elem in itered] == list("string") def test_is_generator_for_yield_in_while() -> None: code = """ def paused_iter(iterable): while True: # Continue to yield the same item until `next(i)` or `i.send(False)` while (yield value): pass """ node = astroid.extract_node(code) assert bool(node.is_generator()) def test_is_generator_for_yield_in_if() -> None: code = """ import asyncio def paused_iter(iterable): if (yield from asyncio.sleep(0.01)): pass return """ node = astroid.extract_node(code) assert bool(node.is_generator()) def test_is_generator_for_yield_in_aug_assign() -> None: code = """ def test(): buf = '' while True: buf += yield """ node = astroid.extract_node(code) assert bool(node.is_generator()) @pytest.mark.skipif(not PY310_PLUS, reason="pattern matching was added in PY310") class TestPatternMatching: @staticmethod def test_match_simple(): code = textwrap.dedent( """ match status: case 200: pass case 401 | 402 | 403: pass case None: pass case _: pass """ ).strip() node = builder.extract_node(code) assert node.as_string() == code assert isinstance(node, nodes.Match) assert isinstance(node.subject, nodes.Name) assert node.subject.name == "status" assert isinstance(node.cases, list) and len(node.cases) == 4 case0, case1, case2, case3 = node.cases assert list(node.get_children()) == [node.subject, *node.cases] assert isinstance(case0.pattern, nodes.MatchValue) assert ( isinstance(case0.pattern.value, astroid.Const) and case0.pattern.value.value == 200 ) assert list(case0.pattern.get_children()) == [case0.pattern.value] assert case0.guard is None assert isinstance(case0.body[0], astroid.Pass) assert list(case0.get_children()) == [case0.pattern, case0.body[0]] assert isinstance(case1.pattern, nodes.MatchOr) assert ( isinstance(case1.pattern.patterns, list) and len(case1.pattern.patterns) == 3 ) for i in range(3): match_value = case1.pattern.patterns[i] assert isinstance(match_value, nodes.MatchValue) assert isinstance(match_value.value, nodes.Const) assert match_value.value.value == (401, 402, 403)[i] assert list(case1.pattern.get_children()) == case1.pattern.patterns assert isinstance(case2.pattern, nodes.MatchSingleton) assert case2.pattern.value is None assert not list(case2.pattern.get_children()) assert isinstance(case3.pattern, nodes.MatchAs) assert case3.pattern.name is None assert case3.pattern.pattern is None assert not list(case3.pattern.get_children()) @staticmethod def test_match_sequence(): code = textwrap.dedent( """ match status: case [x, 2, _, *rest] as y if x > 2: pass """ ).strip() node = builder.extract_node(code) assert node.as_string() == code assert isinstance(node, nodes.Match) assert isinstance(node.cases, list) and len(node.cases) == 1 case = node.cases[0] assert isinstance(case.pattern, nodes.MatchAs) assert isinstance(case.pattern.name, nodes.AssignName) assert case.pattern.name.name == "y" assert list(case.pattern.get_children()) == [ case.pattern.pattern, case.pattern.name, ] assert isinstance(case.guard, nodes.Compare) assert isinstance(case.body[0], nodes.Pass) assert list(case.get_children()) == [case.pattern, case.guard, case.body[0]] pattern_seq = case.pattern.pattern assert isinstance(pattern_seq, nodes.MatchSequence) assert isinstance(pattern_seq.patterns, list) and len(pattern_seq.patterns) == 4 assert ( isinstance(pattern_seq.patterns[0], nodes.MatchAs) and isinstance(pattern_seq.patterns[0].name, nodes.AssignName) and pattern_seq.patterns[0].name.name == "x" and pattern_seq.patterns[0].pattern is None ) assert ( isinstance(pattern_seq.patterns[1], nodes.MatchValue) and isinstance(pattern_seq.patterns[1].value, nodes.Const) and pattern_seq.patterns[1].value.value == 2 ) assert ( isinstance(pattern_seq.patterns[2], nodes.MatchAs) and pattern_seq.patterns[2].name is None ) assert ( isinstance(pattern_seq.patterns[3], nodes.MatchStar) and isinstance(pattern_seq.patterns[3].name, nodes.AssignName) and pattern_seq.patterns[3].name.name == "rest" ) assert list(pattern_seq.patterns[3].get_children()) == [ pattern_seq.patterns[3].name ] assert list(pattern_seq.get_children()) == pattern_seq.patterns @staticmethod def test_match_mapping(): code = textwrap.dedent( """ match status: case {0: x, 1: _}: pass case {**rest}: pass """ ).strip() node = builder.extract_node(code) assert node.as_string() == code assert isinstance(node, nodes.Match) assert isinstance(node.cases, list) and len(node.cases) == 2 case0, case1 = node.cases assert isinstance(case0.pattern, nodes.MatchMapping) assert case0.pattern.rest is None assert isinstance(case0.pattern.keys, list) and len(case0.pattern.keys) == 2 assert ( isinstance(case0.pattern.patterns, list) and len(case0.pattern.patterns) == 2 ) for i in range(2): key = case0.pattern.keys[i] assert isinstance(key, nodes.Const) assert key.value == i pattern = case0.pattern.patterns[i] assert isinstance(pattern, nodes.MatchAs) if i == 0: assert isinstance(pattern.name, nodes.AssignName) assert pattern.name.name == "x" elif i == 1: assert pattern.name is None assert list(case0.pattern.get_children()) == [ *case0.pattern.keys, *case0.pattern.patterns, ] assert isinstance(case1.pattern, nodes.MatchMapping) assert isinstance(case1.pattern.rest, nodes.AssignName) assert case1.pattern.rest.name == "rest" assert isinstance(case1.pattern.keys, list) and len(case1.pattern.keys) == 0 assert ( isinstance(case1.pattern.patterns, list) and len(case1.pattern.patterns) == 0 ) assert list(case1.pattern.get_children()) == [case1.pattern.rest] @staticmethod def test_match_class(): code = textwrap.dedent( """ match x: case Point2D(0, a): pass case Point3D(x=0, y=1, z=b): pass """ ).strip() node = builder.extract_node(code) assert node.as_string() == code assert isinstance(node, nodes.Match) assert isinstance(node.cases, list) and len(node.cases) == 2 case0, case1 = node.cases assert isinstance(case0.pattern, nodes.MatchClass) assert isinstance(case0.pattern.cls, nodes.Name) assert case0.pattern.cls.name == "Point2D" assert ( isinstance(case0.pattern.patterns, list) and len(case0.pattern.patterns) == 2 ) match_value = case0.pattern.patterns[0] assert ( isinstance(match_value, nodes.MatchValue) and isinstance(match_value.value, nodes.Const) and match_value.value.value == 0 ) match_as = case0.pattern.patterns[1] assert ( isinstance(match_as, nodes.MatchAs) and match_as.pattern is None and isinstance(match_as.name, nodes.AssignName) and match_as.name.name == "a" ) assert list(case0.pattern.get_children()) == [ case0.pattern.cls, *case0.pattern.patterns, ] assert isinstance(case1.pattern, nodes.MatchClass) assert isinstance(case1.pattern.cls, nodes.Name) assert case1.pattern.cls.name == "Point3D" assert ( isinstance(case1.pattern.patterns, list) and len(case1.pattern.patterns) == 0 ) assert ( isinstance(case1.pattern.kwd_attrs, list) and len(case1.pattern.kwd_attrs) == 3 ) assert ( isinstance(case1.pattern.kwd_patterns, list) and len(case1.pattern.kwd_patterns) == 3 ) for i in range(2): assert case1.pattern.kwd_attrs[i] == ("x", "y")[i] kwd_pattern = case1.pattern.kwd_patterns[i] assert isinstance(kwd_pattern, nodes.MatchValue) assert isinstance(kwd_pattern.value, nodes.Const) assert kwd_pattern.value.value == i assert case1.pattern.kwd_attrs[2] == "z" kwd_pattern = case1.pattern.kwd_patterns[2] assert ( isinstance(kwd_pattern, nodes.MatchAs) and kwd_pattern.pattern is None and isinstance(kwd_pattern.name, nodes.AssignName) and kwd_pattern.name.name == "b" ) assert list(case1.pattern.get_children()) == [ case1.pattern.cls, *case1.pattern.kwd_patterns, ] if __name__ == "__main__": unittest.main()