mirror of https://github.com/python/cpython.git
389 lines
13 KiB
Python
389 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import pkgutil
|
|
import sys
|
|
import token
|
|
import tokenize
|
|
from io import StringIO
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from itertools import chain
|
|
from tokenize import TokenInfo
|
|
|
|
# The guard below is always false when this module executes, so the typing
# names are never imported at runtime.  They are only referenced from
# annotations, which `from __future__ import annotations` leaves unevaluated.
# NOTE(review): static type checkers presumably treat a variable spelled
# TYPE_CHECKING as true and therefore do see the import - confirm.
TYPE_CHECKING = False

if TYPE_CHECKING:
    from typing import Any, Iterable, Iterator, Mapping
|
|
|
|
|
|
def make_default_module_completer() -> ModuleCompleter:
    """Return a ModuleCompleter whose namespace has '__package__' set to None."""
    # Inside pyrepl, __package__ is set to None by default
    default_namespace = {'__package__': None}
    return ModuleCompleter(namespace=default_namespace)
|
|
|
|
|
|
class ModuleCompleter:
    """A completer for Python import statements.

    Examples:
        - import <tab>
        - import foo<tab>
        - import foo.<tab>
        - import foo as bar, baz<tab>

        - from <tab>
        - from foo<tab>
        - from foo import <tab>
        - from foo import bar<tab>
        - from foo import (bar as baz, qux<tab>
    """

    def __init__(self, namespace: Mapping[str, Any] | None = None) -> None:
        """Create a completer.

        'namespace' is consulted for '__package__' when a relative import
        (one starting with '.') has to be resolved; a falsy value is
        replaced by an empty dict.
        """
        self.namespace = namespace or {}
        # Cached result of pkgutil.iter_modules(); filled lazily by the
        # 'global_cache' property.
        self._global_cache: list[pkgutil.ModuleInfo] = []
        # Snapshot of sys.path taken when the cache was last (re)built.
        self._curr_sys_path: list[str] = sys.path[:]

    def get_completions(self, line: str) -> list[str] | None:
        """Return the next possible import completions for 'line'.

        Returns None when 'line' does not parse as an (incomplete) import
        statement, and an empty list when computing the completions fails.
        """
        result = ImportParser(line).parse()
        if not result:
            return None
        try:
            return self.complete(*result)
        except Exception:
            # Some unexpected error occurred, make it look like
            # no completions are available
            return []

    def complete(self, from_name: str | None, name: str | None) -> list[str]:
        """Return completions for a parsed (from_name, name) pair.

        - from_name is None: complete the dotted name of 'import x.y.z<tab>'.
        - name is None: complete the dotted name of 'from x.y.z<tab>'.
        - otherwise: complete 'name' among the submodules of 'from_name'.
        """
        if from_name is None:
            # import x.y.z<tab>
            assert name is not None
            path, prefix = self.get_path_and_prefix(name)
            modules = self.find_modules(path, prefix)
            return [self.format_completion(path, module) for module in modules]

        if name is None:
            # from x.y.z<tab>
            path, prefix = self.get_path_and_prefix(from_name)
            modules = self.find_modules(path, prefix)
            return [self.format_completion(path, module) for module in modules]

        # from x.y import z<tab>
        return self.find_modules(from_name, name)

    def find_modules(self, path: str, prefix: str) -> list[str]:
        """Find all modules under 'path' that start with 'prefix'."""
        modules = self._find_modules(path, prefix)
        # Filter out invalid module names
        # (for example those containing dashes that cannot be imported with 'import')
        return [mod for mod in modules if mod.isidentifier()]

    def _find_modules(self, path: str, prefix: str) -> list[str]:
        """Return the unfiltered module names under 'path' matching 'prefix'.

        An empty 'path' means a top-level lookup.  A 'path' starting with
        '.' is resolved against the namespace's '__package__'; when that
        is impossible an empty list is returned.
        """
        if not path:
            # Top-level import (e.g. `import foo<tab>` or `from foo<tab>`)
            builtin_modules = [name for name in sys.builtin_module_names
                               if self.is_suggestion_match(name, prefix)]
            third_party_modules = [module.name for module in self.global_cache
                                   if self.is_suggestion_match(module.name, prefix)]
            return sorted(builtin_modules + third_party_modules)

        if path.startswith('.'):
            # Convert relative path to absolute path.  '__package__' may be
            # absent, empty or explicitly None (the pyrepl default); all of
            # these mean there is nothing to resolve against.
            package = self.namespace.get('__package__') or ''
            resolved = self.resolve_relative_name(path, package)
            if resolved is None:
                return []
            path = resolved

        modules: Iterable[pkgutil.ModuleInfo] = self.global_cache
        # Walk down the dotted path one package at a time.
        for segment in path.split('.'):
            modules = [mod_info for mod_info in modules
                       if mod_info.ispkg and mod_info.name == segment]
            modules = self.iter_submodules(modules)
        return [module.name for module in modules
                if self.is_suggestion_match(module.name, prefix)]

    def is_suggestion_match(self, module_name: str, prefix: str) -> bool:
        """Return True if 'module_name' should be suggested for 'prefix'."""
        if prefix:
            return module_name.startswith(prefix)
        # For consistency with attribute completion, which
        # does not suggest private attributes unless requested.
        return not module_name.startswith("_")

    def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
        """Iterate over all submodules of the given parent modules."""
        specs = [info.module_finder.find_spec(info.name, None)
                 for info in parent_modules if info.ispkg]
        # A spec may be missing, or carry no search locations at all;
        # both cases simply contribute nothing.
        search_locations = set(chain.from_iterable(
            getattr(spec, 'submodule_search_locations', None) or ()
            for spec in specs if spec
        ))
        return pkgutil.iter_modules(search_locations)

    def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
        """
        Split a dotted name into an import path and a
        final prefix that is to be completed.

        Examples:
            'foo.bar' -> 'foo', 'bar'
            'foo.' -> 'foo', ''
            '.foo' -> '.', 'foo'
        """
        if '.' not in dotted_name:
            return '', dotted_name
        if dotted_name.startswith('.'):
            # Keep the leading dots (the relative-import level) attached
            # to the path part.
            stripped = dotted_name.lstrip('.')
            dots = '.' * (len(dotted_name) - len(stripped))
            if '.' not in stripped:
                return dots, stripped
            path, prefix = stripped.rsplit('.', 1)
            return dots + path, prefix
        path, prefix = dotted_name.rsplit('.', 1)
        return path, prefix

    def format_completion(self, path: str, module: str) -> str:
        """Join 'path' and 'module' into the text offered as a completion."""
        if path == '' or path.endswith('.'):
            return f'{path}{module}'
        return f'{path}.{module}'

    def resolve_relative_name(self, name: str, package: str | None) -> str | None:
        """Resolve a relative module name to an absolute name.

        Returns None when there is no package to resolve against or when
        'name' has more leading dots than 'package' has levels.

        Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'
        """
        if not package:
            # No enclosing package: a relative name cannot be resolved.
            return None
        # taken from importlib._bootstrap
        level = 0
        for character in name:
            if character != '.':
                break
            level += 1
        bits = package.rsplit('.', level - 1)
        if len(bits) < level:
            return None
        base = bits[0]
        name = name[level:]
        return f'{base}.{name}' if name else base

    @property
    def global_cache(self) -> list[pkgutil.ModuleInfo]:
        """Global module cache"""
        # Rebuild when the cache is empty or sys.path differs from the
        # snapshot taken at the previous rebuild.
        if not self._global_cache or self._curr_sys_path != sys.path:
            self._curr_sys_path = sys.path[:]
            self._global_cache = list(pkgutil.iter_modules())
        return self._global_cache
|
|
|
|
|
|
class ImportParser:
    """
    Parses incomplete import statements that are
    suitable for autocomplete suggestions.

    Examples:
        - import foo          -> Result(from_name=None, name='foo')
        - import foo.         -> Result(from_name=None, name='foo.')
        - from foo            -> Result(from_name='foo', name=None)
        - from foo import bar -> Result(from_name='foo', name='bar')
        - from .foo import (  -> Result(from_name='.foo', name='')

    Note that the parser works in reverse order, starting from the
    last token in the input string. This makes the parser more robust
    when parsing multiple statements.
    """
    _ignored_tokens = {
        token.INDENT, token.DEDENT, token.COMMENT,
        token.NL, token.NEWLINE, token.ENDMARKER
    }
    _keywords = {'import', 'from', 'as'}

    def __init__(self, code: str) -> None:
        """Tokenize 'code' and store the significant tokens in reverse order."""
        self.code = code
        tokens = []
        try:
            for t in tokenize.generate_tokens(StringIO(code).readline):
                if t.type not in self._ignored_tokens:
                    tokens.append(t)
        except tokenize.TokenError as e:
            if 'unexpected EOF' not in str(e):
                # unexpected EOF is fine, since we're parsing an
                # incomplete statement, but other errors are not
                # because we may not have all the tokens so it's
                # safer to bail out
                tokens = []
        except SyntaxError:
            tokens = []
        # Reversed so that the queue yields the last token first.
        self.tokens = TokenQueue(tokens[::-1])

    def parse(self) -> tuple[str | None, str | None] | None:
        """Return (from_name, name) for the statement, or None if it is
        not a recognisable incomplete import."""
        if not (res := self._parse()):
            return None
        return res.from_name, res.name

    def _parse(self) -> Result | None:
        # save_state() swallows ParseError and rewinds the queue, so a
        # failed 'from ... import' attempt falls through to plain 'import'.
        with self.tokens.save_state():
            return self.parse_from_import()
        with self.tokens.save_state():
            return self.parse_import()
        return None

    def _ends_with_keyword(self, keyword: str) -> bool:
        """Return True if the cursor sits after 'keyword' plus a space.

        The textual check alone would also accept identifiers that merely
        end in the keyword (e.g. 'foo_import '), so the last significant
        token must be the keyword as well.
        """
        return (self.code.endswith(' ')
                and self.code.rstrip().endswith(keyword)
                and self.tokens.peek_string(keyword))

    def parse_import(self) -> Result:
        """Parse 'import a.b, c as d, e<tab>' backwards from the cursor.

        Raises ParseError if the tokens do not form such a statement.
        """
        if self._ends_with_keyword('import'):
            return Result(name='')
        if self.tokens.peek_string(','):
            name = ''
        else:
            if self.code.endswith(' '):
                raise ParseError('parse_import')
            name = self.parse_dotted_name()
        if name.startswith('.'):
            # Relative names are only valid in 'from' imports.
            raise ParseError('parse_import')
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_dotted_as_name()
        if self.tokens.peek_string('import'):
            return Result(name=name)
        raise ParseError('parse_import')

    def parse_from_import(self) -> Result:
        """Parse 'from a.b import c, d<tab>' (or a prefix of it) backwards.

        Raises ParseError if the tokens do not form such a statement.
        """
        if self._ends_with_keyword('import'):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self._ends_with_keyword('from'):
            return Result(from_name='')
        if self.tokens.peek_string('(') or self.tokens.peek_string(','):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self.code.endswith(' '):
            raise ParseError('parse_from_import')
        name = self.parse_dotted_name()
        if '.' in name:
            # Names after 'import' cannot contain dots, so this dotted
            # name must be the module following 'from'.
            self.tokens.pop_string('from')
            return Result(from_name=name)
        if self.tokens.peek_string('from'):
            return Result(from_name=name)
        from_name = self.parse_empty_from_import()
        return Result(from_name=from_name, name=name)

    def parse_empty_from_import(self) -> str:
        """Consume '[, names] [(] import <module> from' and return <module>."""
        if self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_names()
        if self.tokens.peek_string('('):
            self.tokens.pop()
        self.tokens.pop_string('import')
        return self.parse_from()

    def parse_from(self) -> str:
        """Consume '<module> from' and return <module>."""
        from_name = self.parse_dotted_name()
        self.tokens.pop_string('from')
        return from_name

    def parse_dotted_as_name(self) -> str | None:
        """Consume one 'a.b [as c]' item of an import list.

        Returns the dotted name that precedes the first consumed name, or
        None when no further dotted name is present (the attempt is then
        rewound by save_state()).
        """
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
        with self.tokens.save_state():
            return self.parse_dotted_name()
        return None

    def parse_dotted_name(self) -> str:
        """Consume a dotted name such as 'a.b', '.a' or 'a.' and return it.

        Raises ParseError if the next token is neither a '.' nor a
        non-keyword name.
        """
        name = []
        if self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        if (self.tokens.peek_name()
            and (tok := self.tokens.peek())
            and tok.string not in self._keywords):
            name.append(self.tokens.pop_name())
        if not name:
            raise ParseError('parse_dotted_name')
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
            if (self.tokens.peek_name()
                and (tok := self.tokens.peek())
                and tok.string not in self._keywords):
                name.append(self.tokens.pop_name())
            else:
                break

        # Any remaining dots are the leading dots of a relative name.
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        # Parts were collected back to front.
        return ''.join(name[::-1])

    def parse_as_names(self) -> None:
        """Consume a comma-separated list of 'name [as alias]' items."""
        self.parse_as_name()
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_name()

    def parse_as_name(self) -> None:
        """Consume a single 'name [as alias]' item."""
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
            self.tokens.pop_name()
|
|
|
|
|
|
class ParseError(Exception):
    """Raised when the tokens do not match the construct being parsed."""
|
|
|
|
|
|
@dataclass(frozen=True)
class Result:
    """Immutable outcome of parsing an incomplete import statement."""
    # Module named after 'from'; None for a plain 'import' statement.
    from_name: str | None = None
    # Name being completed after 'import'; None while the 'from' module
    # itself is still being typed.
    name: str | None = None
|
|
|
|
|
|
class TokenQueue:
    """Provides helper functions for working with a sequence of tokens."""

    def __init__(self, tokens: list[TokenInfo]) -> None:
        self.tokens: list[TokenInfo] = tokens
        # Position of the next token to be returned by peek()/pop().
        self.index: int = 0
        # Saved positions of the currently active save_state() blocks.
        self.stack: list[int] = []

    @contextmanager
    def save_state(self) -> Iterator[None]:
        """Context manager that rewinds the queue on ParseError.

        A ParseError raised inside the block is suppressed and 'index' is
        reset to its value on entry.  Any other exception propagates
        unchanged; in every case the saved position is removed from
        'stack' so that the stack stays balanced.
        """
        self.stack.append(self.index)
        try:
            yield
        except ParseError:
            self.index = self.stack[-1]
        finally:
            self.stack.pop()

    def __bool__(self) -> bool:
        """Return True while there are tokens left to consume."""
        return self.index < len(self.tokens)

    def peek(self) -> TokenInfo | None:
        """Return the next token without consuming it, or None if exhausted."""
        if not self:
            return None
        return self.tokens[self.index]

    def peek_name(self) -> bool:
        """Return True if the next token exists and is a NAME token."""
        if not (tok := self.peek()):
            return False
        return tok.type == token.NAME

    def pop_name(self) -> str:
        """Consume the next token and return its string.

        Raises ParseError if the queue is exhausted or the token is not
        a NAME token.
        """
        tok = self.pop()
        if tok.type != token.NAME:
            raise ParseError('pop_name')
        return tok.string

    def peek_string(self, string: str) -> bool:
        """Return True if the next token exists and its text equals 'string'."""
        if not (tok := self.peek()):
            return False
        return tok.string == string

    def pop_string(self, string: str) -> str:
        """Consume the next token and return its string.

        Raises ParseError if the queue is exhausted or the token's text
        differs from 'string'.
        """
        tok = self.pop()
        if tok.string != string:
            raise ParseError('pop_string')
        return tok.string

    def pop(self) -> TokenInfo:
        """Consume and return the next token.

        Raises ParseError if the queue is exhausted.
        """
        if not self:
            raise ParseError('pop')
        tok = self.tokens[self.index]
        self.index += 1
        return tok
|