Code Generation¶
Rejig provides tools for generating boilerplate code: dunder methods, properties, test stubs, and class conversions.
Dunder Method Generation¶
Generate common special methods for classes.
Generate __init__¶
from rejig import Rejig
rj = Rejig("src/")
cls = rj.find_class("Person")
# Generate __init__ from class attributes
cls.generate_init()
# Before:
# class Person:
# name: str
# age: int
# email: str | None
# After:
# class Person:
# name: str
# age: int
# email: str | None
#
# def __init__(self, name: str, age: int, email: str | None = None) -> None:
# self.name = name
# self.age = age
# self.email = email
Generate __repr__¶
cls.generate_repr()
# Generates:
# def __repr__(self) -> str:
# return f"Person(name={self.name!r}, age={self.age!r}, email={self.email!r})"
Generate __eq__¶
cls.generate_eq()
# Generates:
# def __eq__(self, other: object) -> bool:
# if not isinstance(other, Person):
# return NotImplemented
# return (self.name, self.age, self.email) == (other.name, other.age, other.email)
Generate __hash__¶
cls.generate_hash()
# Generates:
# def __hash__(self) -> int:
# return hash((self.name, self.age, self.email))
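Generate All Common Dunders¶
Generate the common dunder methods in a single call instead of invoking each generator separately:
cls.generate_all_dunders()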
Overwriting Existing Dunders¶
By default, generation skips a dunder that already exists. Pass
overwrite=True to replace it:
cls.generate_init(overwrite=True)
cls.generate_repr(overwrite=True)
cls.generate_eq(overwrite=True)
cls.generate_hash(overwrite=True)
cls.generate_all_dunders(overwrite=True)
Property Generation¶
Convert Attribute to Property¶
cls = rj.find_class("User")
# Convert an attribute to a property with getter/setter
cls.convert_attribute_to_property("email")
# Before:
# class User:
# email: str
# After:
# class User:
# _email: str
#
# @property
# def email(self) -> str:
# return self._email
#
# @email.setter
# def email(self, value: str) -> None:
# self._email = value
Add Property with Validation¶
add_property takes the getter as a source string (an expression, as in the examples
below), an optional setter body that receives the new value as value, and an optional return type:
cls.add_property(
"age",
getter="self._age",
setter="""
if value < 0:
raise ValueError("Age cannot be negative")
self._age = value
""",
return_type="int",
)
Read-Only Property¶
cls.add_property(
"full_name",
getter="f'{self.first_name} {self.last_name}'",
return_type="str",
) # No setter argument means the property is read-only
Class Conversions¶
Convert to Dataclass¶
cls = rj.find_class("Config")
cls.convert_to_dataclass()
# Before:
# class Config:
# host: str
# port: int
# debug: bool
#
# def __init__(self, host: str, port: int, debug: bool = False):
# self.host = host
# self.port = port
# self.debug = debug
# After:
# from dataclasses import dataclass
#
# @dataclass
# class Config:
# host: str
# port: int
# debug: bool = False
Dataclass Options¶
cls.convert_to_dataclass(
frozen=True, # Immutable (frozen=True)
slots=True, # Use __slots__ (Python 3.10+)
)
Convert to NamedTuple¶
cls.convert_to_named_tuple()
# Before:
# class Point:
# x: int
# y: int
# After:
# from typing import NamedTuple
#
# class Point(NamedTuple):
# x: int
# y: int
Convert to TypedDict¶
cls.convert_to_typed_dict()
# Before:
# class UserData:
# name: str
# age: int
# email: str | None
# After:
# from typing import TypedDict
#
# class UserData(TypedDict):
# name: str
# age: int
# email: str | None
Convert from Dataclass¶
Protocol and ABC Extraction¶
Extract Protocol¶
cls = rj.find_class("DatabaseConnection")
# Extract a Protocol from the class's public interface
protocol = cls.extract_protocol("DatabaseProtocol")
# Creates:
# from typing import Protocol
#
# class DatabaseProtocol(Protocol):
# def connect(self) -> None: ...
# def execute(self, query: str) -> list[dict]: ...
# def close(self) -> None: ...
Extract Abstract Base Class¶
result = cls.extract_abstract_base("AbstractDatabase")
# Creates:
# from abc import ABC, abstractmethod
#
# class AbstractDatabase(ABC):
# @abstractmethod
# def connect(self) -> None: ...
#
# @abstractmethod
# def execute(self, query: str) -> list[dict]: ...
#
# @abstractmethod
# def close(self) -> None: ...
Inheritance Management¶
Add Base Class¶
cls = rj.find_class("MyModel")
cls.add_base_class("BaseModel")
# Before:
# class MyModel:
# After:
# class MyModel(BaseModel):
Add Mixin¶
cls.add_mixin("TimestampMixin")
# Before:
# class MyModel(BaseModel):
# After:
# class MyModel(TimestampMixin, BaseModel):
Remove Base Class¶
Test Generation¶
Test stubs are generated by calling methods on targets (the objects returned by
find_function, find_class, and so on). Each such method returns a Result; the generated
code is also available in result.data.
Generate a Test Stub for a Function¶
func = rj.find_function("process_data")
# Writes tests/test_<module>.py with a stub for the function
result = func.generate_test_stub()
print(result.data)
# def test_process_data():
# # Arrange
# items = []
#
# # Act
# result = process_data(items)
#
# # Assert
# assert result is not None # TODO: add specific assertions
Generate a Test File for a Class¶
cls = rj.find_class("UserService")
# Write a complete pytest test file with stubs for all public methods
result = cls.generate_test_file("tests/test_user_service.py")
# Or write into the default tests/ directory mirroring the source layout
cls.generate_test_stub()
Generate a Test for a Single Method¶
method = rj.find_class("UserService").find_method("create_user")
# Simple stub
method.generate_test()
# Parameterized test from explicit cases
method.generate_test(
test_cases=[
{"input": {"data": "valid"}, "expected": True},
{"input": {"data": ""}, "expected": False},
]
)
Generate a Test Class Scaffold¶
# Create a standalone test class (no source class required)
rj.generate_test_class("MyClass", include_setup=True)
Low-Level Generator¶
For finer control, use TestGenerator from rejig.generation directly. It
works on FunctionSignature objects (extracted via SignatureExtractor)
rather than targets:
from rejig.generation import TestGenerator
generator = TestGenerator()
# generator.generate_function_test_stub(signature)
# generator.generate_class_test_file(class_name, method_signatures)
# generator.generate_parameterized_test(signature, test_cases)
unittest to pytest Conversion¶
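UnittestToPytestConverter is a LibCST transformer that rewrites
self.assertX(...) calls into plain assert statements, or into pytest helpers
(pytest.raises, pytest.approx) where a bare assert cannot express the check.
Apply it with Rejig.transform_file: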
from rejig.generation import UnittestToPytestConverter
rj.transform_file(
rj.root / "tests/test_utils.py",
UnittestToPytestConverter(),
)
# Before: self.assertEqual(sum(self.data), 6)
# After: assert sum(self.data) == 6
Batch Conversion¶
# Convert every test file in the working set
for file in rj.find_files("**/test_*.py"):
rj.transform_file(file.path, UnittestToPytestConverter())
Assertion Conversions¶
The converter handles these assertion patterns:
| unittest | pytest |
|---|---|
| self.assertEqual(a, b) | assert a == b |
| self.assertNotEqual(a, b) | assert a != b |
| self.assertTrue(x) | assert x |
| self.assertFalse(x) | assert not x |
| self.assertIs(a, b) | assert a is b |
| self.assertIsNone(x) | assert x is None |
| self.assertIn(a, b) | assert a in b |
| self.assertRaises(E) | pytest.raises(E) |
| self.assertAlmostEqual(a, b) | assert a == pytest.approx(b) |
Doctest Extraction¶
Extract Doctests to pytest¶
Convert the doctest examples in a function's docstring into pytest tests:
func = rj.find_function("add")
result = func.generate_tests_from_doctest()
print(result.data)
# def test_add_doctest_1():
# """Test from doctest example."""
# result = add(2, 3)
# assert result == 5
#
# def test_add_doctest_2():
# """Test from doctest example."""
# result = add(-1, 1)
# assert result == 0
# Or write to a specific file
func.generate_tests_from_doctest("tests/test_doctests.py")
Batch Generation¶
Generate for Multiple Classes¶
Generation methods skip elements that already exist (unless overwrite=True),
so they are safe to call across every class:
# Generate __repr__ for all classes that lack one
for cls in rj.find_classes():
cls.generate_repr()
# Generate test stubs for all public functions
for func in rj.find_functions():
if not func.name.startswith("_"):
func.generate_test_stub()
Add Missing Dunders¶
# generate_* is a no-op when the dunder already exists
for cls in rj.find_classes():
cls.generate_repr()
cls.generate_eq()
cls.generate_hash()
# Or all at once
cls.generate_all_dunders()