Applied 236 return type annotations to functions with no return values across 100+ files (core modules, tests, scripts, simulations). Added Phase 4 to python_style_refactor track for remaining 597 items (untyped params, vars, and functions with return values). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
122 lines
2.4 KiB
Python
122 lines
2.4 KiB
Python
import pytest
|
|
import textwrap
|
|
from scripts.ai_style_formatter import format_code
|
|
|
|
def test_basic_indentation() -> None:
    """Nested blocks are re-indented to one space per nesting level."""
    raw = textwrap.dedent("""\
        def hello():
            print("world")
            if True:
                print("nested")
        """)
    want_lines = [
        "def hello():",
        ' print("world")',
        " if True:",
        '  print("nested")',
    ]
    # Every emitted line, including the last, is newline-terminated.
    want = "\n".join(want_lines) + "\n"
    got = format_code(raw)
    assert got == want
|
|
|
|
def test_top_level_blank_lines() -> None:
    """A run of blank lines between top-level definitions collapses to exactly one."""
    raw = textwrap.dedent("""\
        def a():
            pass


        def b():
            pass
        """)
    # The empty element yields the single separating blank line.
    want = "\n".join([
        "def a():",
        " pass",
        "",
        "def b():",
        " pass",
    ]) + "\n"
    got = format_code(raw)
    assert got == want
|
|
|
|
def test_inner_blank_lines() -> None:
    """Blank lines inside a function body are removed entirely."""
    raw = textwrap.dedent("""\
        def a():
            print("start")

            print("end")
        """)
    want = "\n".join([
        "def a():",
        ' print("start")',
        ' print("end")',
    ]) + "\n"
    got = format_code(raw)
    assert got == want
|
|
|
|
def test_multiline_string_safety() -> None:
    """Lines inside a triple-quoted string must survive formatting verbatim.

    Only the statement that opens the string is re-indented (to one space,
    like any other statement at depth 1). The lines of string content are
    data, not code, so they must keep their original indentation and order.
    """
    source = textwrap.dedent("""\
        def a():
            '''
            This is a multiline
            string that should
            not be reformatted
            inside.
            '''
            pass
        """)
    # After dedent, both the opening ''' and the string content sit at 4 spaces.
    string_body = (
        "    This is a multiline\n"
        "    string that should\n"
        "    not be reformatted\n"
        "    inside.\n"
    )
    # Guard the fixture: if the source above is ever re-indented, fail here
    # instead of reporting a (false) formatter regression below.
    assert string_body in source

    result = format_code(source)

    # Every content line is preserved byte-for-byte and stays contiguous;
    # checking only the first line would miss corruption of the others.
    assert string_body in result
    # The opening ''' is itself a statement and gets the 1-space body indent.
    assert result.startswith("def a():\n '''")
|
|
|
|
def test_continuation_indentation() -> None:
    """Lines inside open brackets sit one level deeper than their statement.

    The closing bracket returns to the statement's own level.
    """
    raw = textwrap.dedent("""\
        def long_func(
            a,
            b
        ):
            return (
                a +
                b
            )
        """)
    want = "\n".join([
        "def long_func(",
        " a,",
        " b",
        "):",
        " return (",
        "  a +",
        "  b",
        " )",
    ]) + "\n"
    got = format_code(raw)
    assert got == want
|
|
|
|
def test_multiple_top_level_definitions() -> None:
    """Class members nest correctly, and blank lines follow per-level rules.

    The blank line between methods is dropped, and the run of blank lines
    before the next top-level definition collapses to one.
    """
    raw = textwrap.dedent("""\
        class MyClass:
            def __init__(self):
                self.x = 1

            def method(self):
                pass


        def top_level():
            pass
        """)
    want = "\n".join([
        "class MyClass:",
        " def __init__(self):",
        "  self.x = 1",
        " def method(self):",
        "  pass",
        "",
        "def top_level():",
        " pass",
    ]) + "\n"
    got = format_code(raw)
    assert got == want
|