1"""
  2Type annotation analyzer for Python codebases.
  3
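Analyzes Python files to determine the percentage of functions/methods
that have complete type annotations (parameters and return types), and
tallies typing escape hatches: ``# type: ignore`` comments, ``cast()``
calls, and ``assert`` statements.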
  6"""
  7
from __future__ import annotations

import ast
import os
import re
import tokenize
from dataclasses import dataclass, field
from fnmatch import fnmatch
from pathlib import Path
 16
 17
 18@dataclass
 19class FunctionInfo:
 20    """Information about a function/method for type checking."""
 21
 22    name: str
 23    file: str
 24    line: int
 25    is_method: bool = False
 26    has_return_type: bool = False
 27    total_params: int = 0
 28    typed_params: int = 0
 29    is_property: bool = False
 30
 31    @property
 32    def is_fully_typed(self) -> bool:
 33        """Check if function has all type annotations."""
 34        return self.has_return_type and (self.typed_params == self.total_params)
 35
 36
 37@dataclass
 38class FileStats:
 39    """Statistics for a single Python file."""
 40
 41    path: str
 42    functions: list[FunctionInfo] = field(default_factory=list)
 43    ignore_comments: int = 0
 44    cast_calls: int = 0
 45    assert_statements: int = 0
 46
 47    @property
 48    def total_functions(self) -> int:
 49        return len(self.functions)
 50
 51    @property
 52    def fully_typed_functions(self) -> int:
 53        return sum(1 for f in self.functions if f.is_fully_typed)
 54
 55    @property
 56    def missing_functions(self) -> int:
 57        return self.total_functions - self.fully_typed_functions
 58
 59
 60class TypeAnnotationAnalyzer(ast.NodeVisitor):
 61    """AST visitor to analyze type annotations in Python code."""
 62
 63    def __init__(self, file_path: str) -> None:
 64        self.file_path = file_path
 65        self.functions: list[FunctionInfo] = []
 66        self.class_stack: list[str] = []
 67
 68    def visit_ClassDef(self, node: ast.ClassDef) -> None:
 69        """Track when we enter/exit a class."""
 70        self.class_stack.append(node.name)
 71        self.generic_visit(node)
 72        self.class_stack.pop()
 73
 74    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
 75        """Analyze function definitions."""
 76        self._analyze_function(node)
 77        self.generic_visit(node)
 78
 79    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
 80        """Analyze async function definitions."""
 81        self._analyze_function(node)
 82        self.generic_visit(node)
 83
 84    def _analyze_function(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
 85        """Analyze a function/method for type annotations."""
 86        # Skip __init__ return type check (it's always None implicitly)
 87        is_init = node.name == "__init__"
 88
 89        # Check if it's a method (inside a class)
 90        is_method = bool(self.class_stack)
 91
 92        # Check decorators
 93        is_property = any(
 94            (isinstance(d, ast.Name) and d.id == "property")
 95            or (isinstance(d, ast.Attribute) and d.attr == "property")
 96            for d in node.decorator_list
 97        )
 98
 99        # Create function info
100        func_info = FunctionInfo(
101            name=node.name,
102            file=self.file_path,
103            line=node.lineno,
104            is_method=is_method,
105            is_property=is_property,
106        )
107
108        # Check return type (not needed for __init__)
109        if not is_init:
110            func_info.has_return_type = node.returns is not None
111        else:
112            func_info.has_return_type = True
113
114        def handle_param(arg: ast.arg) -> None:
115            if is_method and arg.arg in {"self", "cls"}:
116                return
117
118            func_info.total_params += 1
119            if arg.annotation is not None:
120                func_info.typed_params += 1
121
122        # Analyze parameters
123        for arg in node.args.posonlyargs:
124            handle_param(arg)
125
126        for arg in node.args.args:
127            handle_param(arg)
128
129        for arg in node.args.kwonlyargs:
130            handle_param(arg)
131
132        # Check *args and **kwargs
133        if node.args.vararg:
134            func_info.total_params += 1
135            if node.args.vararg.annotation is not None:
136                func_info.typed_params += 1
137
138        if node.args.kwarg:
139            func_info.total_params += 1
140            if node.args.kwarg.annotation is not None:
141                func_info.typed_params += 1
142
143        self.functions.append(func_info)
144
145
def count_ignore_comments(content: str) -> int:
    """Return the number of lines carrying a ``# type: ignore`` marker.

    Matching is case-insensitive and tolerates whitespace after ``#`` and
    ``:``. A line is counted at most once, even if the marker repeats on it.
    """
    marker = re.compile(r"#\s*type:\s*ignore", re.IGNORECASE)
    return sum(1 for text_line in content.split("\n") if marker.search(text_line))
156
157
def count_cast_calls(content: str) -> int:
    """Count ``cast(...)`` call sites in the file's text.

    Bare ``cast(`` and qualified forms such as ``typing.cast(`` or
    ``t.cast(`` are each counted exactly once. This is a textual scan, so
    occurrences inside comments or string literals are counted too.

    Args:
        content: Full source text of the file.

    Returns:
        Number of ``cast(`` occurrences found.
    """
    # One pattern suffices: ``\b`` holds between the ``.`` and ``cast`` in
    # ``typing.cast(``, so qualified calls are already matched. A second
    # ``typing.cast`` pattern would count every qualified call twice.
    pattern = re.compile(r"\bcast\s*\(")
    return sum(len(pattern.findall(line)) for line in content.split("\n"))
172
173
class AssertCounter(ast.NodeVisitor):
    """Walk an AST and tally every ``assert`` statement encountered.

    The running total is exposed through the public ``count`` attribute.
    """

    def __init__(self) -> None:
        # Incremented once per ``ast.Assert`` node visited.
        self.count = 0

    def visit_Assert(self, node: ast.Assert) -> None:
        """Record one assert, then continue into its child nodes."""
        self.count = self.count + 1
        self.generic_visit(node)
183
184
def count_assert_statements(tree: ast.AST) -> int:
    """Return how many ``assert`` statements occur anywhere within *tree*.

    Every node reachable from *tree* is inspected, including nodes inside
    nested functions and classes.
    """
    return sum(1 for node in ast.walk(tree) if isinstance(node, ast.Assert))
190
191
def analyze_file(file_path: Path) -> FileStats | None:
    """Analyze a single Python file for type annotations.

    The file is decoded via ``tokenize.open`` (PEP 263 coding cookie or
    UTF-8 BOM, defaulting to UTF-8), parsed, and scanned for function
    annotations, ``# type: ignore`` comments, ``cast()`` calls and
    ``assert`` statements.

    Args:
        file_path: Path of the ``.py`` file to analyze.

    Returns:
        The file's statistics, or ``None`` when the file cannot be opened,
        decoded, or parsed. Such files are skipped rather than aborting
        the whole scan.
    """
    try:
        # tokenize.open decodes the file the way the interpreter would,
        # honouring coding cookies and stripping a UTF-8 BOM, which a
        # plain utf-8 open would leave in the text for ast.parse to reject.
        with tokenize.open(file_path) as f:
            content = f.read()
        tree = ast.parse(content, filename=str(file_path))
    except (OSError, SyntaxError, ValueError):
        # OSError: unreadable file or a dangling symlink found by os.walk.
        # SyntaxError: invalid source or an invalid coding cookie.
        # ValueError: includes UnicodeDecodeError, plus null bytes in the
        # source on Python versions where ast.parse reports a ValueError.
        return None

    analyzer = TypeAnnotationAnalyzer(str(file_path))
    analyzer.visit(tree)

    return FileStats(
        path=str(file_path),
        functions=analyzer.functions,
        ignore_comments=count_ignore_comments(content),
        cast_calls=count_cast_calls(content),
        assert_statements=count_assert_statements(tree),
    )
217
218
def find_python_files(
    directory: Path, exclude_patterns: list[str] | None = None
) -> list[Path]:
    """Find all Python files under *directory*, skipping excluded paths.

    Patterns are shell-style globs (``fnmatch``). How a pattern is matched
    depends on whether it contains a path separator:

    * Without a separator, it is matched against the file or directory
      *name* only. So ``test_*.py`` excludes ``pkg/test_mod.py`` but not
      ``test_utils/mod.py``, because fnmatch's ``*`` would otherwise cross
      ``/``.
    * With a separator, it is matched against the path relative to
      *directory* and against the full path.

    An excluded directory is not descended into.

    Args:
        directory: Root directory to walk.
        exclude_patterns: Extra patterns added to the built-in defaults
            (VCS, cache, virtualenv and build directories, plus test files).

    Returns:
        The ``.py`` files found, in deterministic top-down, name-sorted order.
    """
    default_patterns = [
        "__pycache__",
        ".git",
        ".venv",
        "venv",
        "env",
        ".tox",
        "build",
        "dist",
        "*.egg-info",
        ".mypy_cache",
        ".pytest_cache",
        "node_modules",
        # Exclude test files from annotation metrics
        "test_*.py",
        "*_test.py",
        "tests",
        "test",
    ]

    patterns = [*default_patterns, *(exclude_patterns or [])]

    # Classify each pattern once, instead of on every path check.
    separators = {"/", os.sep}
    name_patterns: list[str] = []
    path_patterns: list[str] = []
    for pattern in patterns:
        if any(sep in pattern for sep in separators):
            path_patterns.append(pattern)
        else:
            name_patterns.append(pattern)

    def should_exclude(path: Path) -> bool:
        """Return True if *path* matches any exclusion pattern."""
        if any(fnmatch(path.name, pattern) for pattern in name_patterns):
            return True
        if not path_patterns:
            return False
        try:
            relative = path.relative_to(directory).as_posix()
        except ValueError:
            relative = path.as_posix()
        candidates = (relative, path.as_posix())
        return any(
            fnmatch(candidate, pattern)
            for pattern in path_patterns
            for candidate in candidates
        )

    python_files: list[Path] = []
    for root, dirs, files in os.walk(directory):
        root_path = Path(root)
        # Prune in place so os.walk never descends into excluded dirs.
        # Sorting makes the traversal, and hence the result, deterministic.
        dirs[:] = sorted(d for d in dirs if not should_exclude(root_path / d))
        python_files.extend(
            root_path / name
            for name in sorted(files)
            if name.endswith(".py") and not should_exclude(root_path / name)
        )

    return python_files
274
275
@dataclass
class AnnotationResult:
    """Aggregated annotation statistics across all analyzed files."""

    # Functions/methods found across all files.
    total_functions: int
    # Of those, how many are fully annotated.
    fully_typed_functions: int
    # Functions lacking at least one annotation.
    missing_count: int
    # Summed ``# type: ignore`` line counts.
    total_ignores: int
    # Summed ``cast(...)`` call counts.
    total_casts: int
    # Summed ``assert`` statement counts.
    total_asserts: int
    # Per-file breakdown.
    file_stats: list[FileStats]

    @property
    def coverage_percentage(self) -> float:
        """Share of fully typed functions, as a percentage (0-100).

        With no functions at all, nothing is missing, so 100.0 is reported.
        """
        total = self.total_functions
        if not total:
            return 100.0
        ratio = self.fully_typed_functions / total
        return ratio * 100
293
294
def check_annotations(
    path: str, exclude_patterns: list[str] | None = None
) -> AnnotationResult:
    """Check type annotations in the given path.

    Args:
        path: A ``.py`` file, or a directory to scan recursively.
        exclude_patterns: Extra exclusion patterns, forwarded to
            ``find_python_files``. Only used when *path* is a directory.

    Returns:
        Statistics aggregated over every file that ``analyze_file`` could
        process. A non-``.py`` file, or a path that is neither a file nor a
        directory, yields an all-zero result with no file stats.
    """
    target = Path(path)

    python_files: list[Path]
    if target.is_file():
        # An explicitly named file is analyzed only if it is Python source.
        python_files = [target] if target.suffix == ".py" else []
    elif target.is_dir():
        python_files = find_python_files(target, exclude_patterns)
    else:
        python_files = []

    # Aggregating over no files naturally produces the all-zero result.
    all_stats = [stats for stats in map(analyze_file, python_files) if stats]

    function_count = sum(s.total_functions for s in all_stats)
    typed_count = sum(s.fully_typed_functions for s in all_stats)

    return AnnotationResult(
        total_functions=function_count,
        fully_typed_functions=typed_count,
        missing_count=function_count - typed_count,
        total_ignores=sum(s.ignore_comments for s in all_stats),
        total_casts=sum(s.cast_calls for s in all_stats),
        total_asserts=sum(s.assert_statements for s in all_stats),
        file_stats=all_stats,
    )