diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 0ebde34..8fddad4 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -2,14 +2,13 @@ from abc import ABCMeta, abstractmethod import builtins import traceback from collections import OrderedDict -from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence +from typing import ClassVar, Union, Tuple, Optional, Iterable, MutableMapping, MutableSet, MutableSequence from .. import tracer from ..tools import * - __all__ = [ - "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", + "Valish", "Shape", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Array", "ArrayProxy", "Sample", "Past", "Stable", "Rose", "Fell", "Signal", "ClockSignal", "ResetSignal", @@ -21,15 +20,23 @@ __all__ = [ class DUID: """Deterministic Unique IDentifier""" - __next_uid = 0 - def __init__(self): - self.duid = DUID.__next_uid + __next_uid: ClassVar[int] = 0 + def __init__(self) -> None: + self.duid: int = DUID.__next_uid DUID.__next_uid += 1 +Valish = Union['Value', bool, int] + + +Shape = Tuple[int, bool] + + class Value(metaclass=ABCMeta): + src_loc: Optional[tracer.SrcLoc] + @staticmethod - def wrap(obj): + def wrap(obj: Valish) -> 'Value': """Ensures that the passed object is an nMigen value. 
Booleans and integers are automatically wrapped into ``Const``.""" if isinstance(obj, Value): @@ -39,76 +46,76 @@ class Value(metaclass=ABCMeta): else: raise TypeError("Object '{!r}' is not an nMigen value".format(obj)) - def __init__(self, src_loc_at=0): + def __init__(self, src_loc_at: int = 0) -> None: super().__init__() self.src_loc = tracer.get_src_loc(1 + src_loc_at) - def __bool__(self): + def __bool__(self) -> bool: raise TypeError("Attempted to convert nMigen value to boolean") - def __invert__(self): + def __invert__(self) -> 'Value': return Operator("~", [self]) - def __neg__(self): + def __neg__(self) -> 'Value': return Operator("-", [self]) - def __add__(self, other): + def __add__(self, other: Valish) -> 'Value': return Operator("+", [self, other]) - def __radd__(self, other): + def __radd__(self, other: Valish) -> 'Value': return Operator("+", [other, self]) - def __sub__(self, other): + def __sub__(self, other: Valish) -> 'Value': return Operator("-", [self, other]) - def __rsub__(self, other): + def __rsub__(self, other: Valish) -> 'Value': return Operator("-", [other, self]) - def __mul__(self, other): + def __mul__(self, other: Valish) -> 'Value': return Operator("*", [self, other]) - def __rmul__(self, other): + def __rmul__(self, other: Valish) -> 'Value': return Operator("*", [other, self]) - def __mod__(self, other): + def __mod__(self, other: Valish) -> 'Value': return Operator("%", [self, other]) - def __rmod__(self, other): + def __rmod__(self, other: Valish) -> 'Value': return Operator("%", [other, self]) - def __div__(self, other): + def __div__(self, other: Valish) -> 'Value': return Operator("/", [self, other]) - def __rdiv__(self, other): + def __rdiv__(self, other: Valish) -> 'Value': return Operator("/", [other, self]) - def __lshift__(self, other): + def __lshift__(self, other: Valish) -> 'Value': return Operator("<<", [self, other]) - def __rlshift__(self, other): + def __rlshift__(self, other: Valish) -> 'Value': return 
Operator("<<", [other, self]) - def __rshift__(self, other): + def __rshift__(self, other: Valish) -> 'Value': return Operator(">>", [self, other]) - def __rrshift__(self, other): + def __rrshift__(self, other: Valish) -> 'Value': return Operator(">>", [other, self]) - def __and__(self, other): + def __and__(self, other: Valish) -> 'Value': return Operator("&", [self, other]) - def __rand__(self, other): + def __rand__(self, other: Valish) -> 'Value': return Operator("&", [other, self]) - def __xor__(self, other): + def __xor__(self, other: Valish) -> 'Value': return Operator("^", [self, other]) - def __rxor__(self, other): + def __rxor__(self, other: Valish) -> 'Value': return Operator("^", [other, self]) - def __or__(self, other): + def __or__(self, other: Valish) -> 'Value': return Operator("|", [self, other]) - def __ror__(self, other): + def __ror__(self, other: Valish) -> 'Value': return Operator("|", [other, self]) - def __eq__(self, other): + def __eq__(self, other: Valish) -> 'Value': # type: ignore return Operator("==", [self, other]) - def __ne__(self, other): + def __ne__(self, other: Valish) -> 'Value': # type: ignore return Operator("!=", [self, other]) - def __lt__(self, other): + def __lt__(self, other: Valish) -> 'Value': return Operator("<", [self, other]) - def __le__(self, other): + def __le__(self, other: Valish) -> 'Value': return Operator("<=", [self, other]) - def __gt__(self, other): + def __gt__(self, other: Valish) -> 'Value': return Operator(">", [self, other]) - def __ge__(self, other): + def __ge__(self, other: Valish) -> 'Value': return Operator(">=", [self, other]) - def __len__(self): + def __len__(self) -> int: return self.shape()[0] - def __getitem__(self, key): + def __getitem__(self, key: Union[int, slice]) -> 'Value': n = len(self) if isinstance(key, int): if key not in range(-n, n): @@ -124,7 +131,7 @@ class Value(metaclass=ABCMeta): else: raise TypeError("Cannot index value with {}".format(repr(key))) - def bool(self): + def 
bool(self) -> 'Value': """Conversion to boolean. Returns @@ -134,7 +141,7 @@ class Value(metaclass=ABCMeta): """ return Operator("b", [self]) - def implies(premise, conclusion): + def implies(self, conclusion: Valish) -> 'Value': """Implication. Returns @@ -142,9 +149,9 @@ class Value(metaclass=ABCMeta): Value, out ``0`` if ``premise`` is true and ``conclusion`` is not, ``1`` otherwise. """ - return ~premise | conclusion + return ~self | conclusion - def part(self, offset, width): + def part(self, offset: Valish, width: int) -> 'Part': """Indexed part-select. Selects a constant width but variable offset part of a ``Value``. @@ -163,7 +170,7 @@ class Value(metaclass=ABCMeta): """ return Part(self, offset, width) - def eq(self, value): + def eq(self, value: Valish) -> 'Assign': """Assignment. Parameters @@ -179,7 +186,7 @@ class Value(metaclass=ABCMeta): return Assign(self, value) @abstractmethod - def shape(self): + def shape(self) -> Shape: """Bit length and signedness of a value. Returns @@ -197,14 +204,14 @@ class Value(metaclass=ABCMeta): """ pass # :nocov: - def _lhs_signals(self): + def _lhs_signals(self) -> 'ValueSet': raise TypeError("Value {!r} cannot be used in assignments".format(self)) @abstractmethod - def _rhs_signals(self): + def _rhs_signals(self) -> 'ValueSet': pass # :nocov: - def _as_const(self): + def _as_const(self) -> int: raise TypeError("Value {!r} cannot be evaluated as constant".format(self)) @@ -228,7 +235,7 @@ class Const(Value): src_loc = None @staticmethod - def normalize(value, shape): + def normalize(value: int, shape: Shape) -> int: nbits, signed = shape mask = (1 << nbits) - 1 value &= mask @@ -236,7 +243,7 @@ class Const(Value): value |= ~mask return value - def __init__(self, value, shape=None): + def __init__(self, value: int, shape: Optional[Union[Shape, int]] = None) -> None: self.value = int(value) if shape is None: shape = bits_for(self.value), self.value < 0 @@ -247,16 +254,16 @@ class Const(Value): raise TypeError("Width 
must be a non-negative integer, not '{!r}'", self.nbits) self.value = self.normalize(self.value, shape) - def shape(self): + def shape(self) -> Shape: return self.nbits, self.signed - def _rhs_signals(self): + def _rhs_signals(self) -> 'ValueSet': return ValueSet() - def _as_const(self): + def _as_const(self) -> int: return self.value - def __repr__(self): + def __repr__(self) -> str: return "(const {}'{}d{})".format(self.nbits, "s" if self.signed else "", self.value) @@ -264,7 +271,7 @@ C = Const # shorthand class AnyValue(Value, DUID): - def __init__(self, shape): + def __init__(self, shape: Union[Shape, int]) -> None: super().__init__(src_loc_at=0) if isinstance(shape, int): shape = shape, False @@ -272,31 +279,31 @@ class AnyValue(Value, DUID): if not isinstance(self.nbits, int) or self.nbits < 0: raise TypeError("Width must be a non-negative integer, not '{!r}'", self.nbits) - def shape(self): + def shape(self) -> Shape: return self.nbits, self.signed - def _rhs_signals(self): + def _rhs_signals(self) -> 'ValueSet': return ValueSet() class AnyConst(AnyValue): - def __repr__(self): + def __repr__(self) -> str: return "(anyconst {}'{})".format(self.nbits, "s" if self.signed else "") class AnySeq(AnyValue): - def __repr__(self): + def __repr__(self) -> str: return "(anyseq {}'{})".format(self.nbits, "s" if self.signed else "") class Operator(Value): - def __init__(self, op, operands, src_loc_at=0): + def __init__(self, op: str, operands: Iterable[Valish], src_loc_at: int = 0) -> None: super().__init__(src_loc_at=1 + src_loc_at) self.op = op self.operands = [Value.wrap(o) for o in operands] @staticmethod - def _bitwise_binary_shape(a_shape, b_shape): + def _bitwise_binary_shape(a_shape: Shape, b_shape: Shape) -> Shape: a_bits, a_sign = a_shape b_bits, b_sign = b_shape if not a_sign and not b_sign: @@ -312,7 +319,7 @@ class Operator(Value): # first signed, second operand unsigned (add sign bit) return max(a_bits, b_bits + 1), True - def shape(self): + def 
shape(self) -> Shape: op_shapes = list(map(lambda x: x.shape(), self.operands)) if len(op_shapes) == 1: (a_bits, a_sign), = op_shapes @@ -357,8 +364,8 @@ class Operator(Value): raise NotImplementedError("Operator {}/{} not implemented" .format(self.op, len(op_shapes))) # :nocov: - def _rhs_signals(self): - return union(op._rhs_signals() for op in self.operands) + def _rhs_signals(self) -> 'ValueSet': + return ValueSet(signal for op in self.operands for signal in op._rhs_signals()) def __repr__(self): return "({} {})".format(self.op, " ".join(map(repr, self.operands))) @@ -384,7 +391,7 @@ def Mux(sel, val1, val0): class Slice(Value): - def __init__(self, value, start, end): + def __init__(self, value, start, end) -> None: if not isinstance(start, int): raise TypeError("Slice start must be an integer, not '{!r}'".format(start)) if not isinstance(end, int): @@ -421,7 +428,7 @@ class Slice(Value): class Part(Value): - def __init__(self, value, offset, width): + def __init__(self, value, offset, width) -> None: if not isinstance(width, int) or width < 0: raise TypeError("Part width must be a non-negative integer, not '{!r}'".format(width)) @@ -467,7 +474,7 @@ class Cat(Value): Value, inout Resulting ``Value`` obtained by concatentation.
""" - def __init__(self, *args): + def __init__(self, *args) -> None: super().__init__() self.parts = [Value.wrap(v) for v in flatten(args)] @@ -899,7 +906,7 @@ class Statement: class Assign(Statement): - def __init__(self, lhs, rhs): + def __init__(self, lhs, rhs) -> None: self.lhs = Value.wrap(lhs) self.rhs = Value.wrap(rhs) @@ -1070,7 +1077,7 @@ class _MappedKeyDict(MutableMapping, _MappedKeyCollection): class _MappedKeySet(MutableSet, _MappedKeyCollection): - def __init__(self, elements=()): + def __init__(self, elements=()) -> None: self._storage = OrderedDict() for elem in elements: self.add(elem) diff --git a/nmigen/tools.py b/nmigen/tools.py index 25b6893..8eb4021 100644 --- a/nmigen/tools.py +++ b/nmigen/tools.py @@ -1,14 +1,17 @@ import contextlib import functools import warnings -from collections.abc import Iterable from contextlib import contextmanager +from typing import Iterable, Iterator, Generator, Optional, Callable, Union, MutableSet, AbstractSet, TypeVar, Any __all__ = ["flatten", "union", "log2_int", "bits_for", "deprecated"] -def flatten(i): +T = TypeVar('T') + + +def flatten(i: Iterable[Any]) -> Generator[Any, None, None]: for e in i: if isinstance(e, Iterable): yield from flatten(e) @@ -16,17 +19,18 @@ def flatten(i): yield e -def union(i, start=None): +TSet = TypeVar('TSet', bound=MutableSet[Any]) +def union(i: Iterable[TSet], start: Optional[TSet] = None) -> Optional[TSet]: r = start for e in i: if r is None: r = e else: r |= e return r -def log2_int(n, need_pow2=True): +def log2_int(n: int, need_pow2: bool = True) -> int: if n == 0: return 0 r = (n - 1).bit_length() @@ -35,7 +39,7 @@ def log2_int(n, need_pow2=True): return r -def bits_for(n, require_sign_bit=False): +def bits_for(n: int, require_sign_bit: bool = False) -> int: if n > 0: r = log2_int(n + 1, False) else: @@ -46,35 +50,38 @@ def bits_for(n, require_sign_bit=False): return r -def deprecated(message, stacklevel=2): - def decorator(f): +def deprecated(message: Union[str,
Warning], stacklevel: int = 2) -> \ + Callable[[Callable[..., T]], Callable[..., T]]: + def decorator(f: Callable[..., T]) -> Callable[..., T]: @functools.wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> T: warnings.warn(message, DeprecationWarning, stacklevel=stacklevel) return f(*args, **kwargs) return wrapper return decorator -def _ignore_deprecated(f=None): +# TODO: these annotations are Any placeholders; tighten them +def _ignore_deprecated(f: Any = None) -> Any: if f is None: @contextlib.contextmanager - def context_like(): + def context_like() -> Generator[None, None, None]: with warnings.catch_warnings(): warnings.filterwarnings(action="ignore", category=DeprecationWarning) yield return context_like() else: @functools.wraps(f) - def decorator_like(*args, **kwargs): + def decorator_like(*args: Any, **kwargs: Any) -> None: with warnings.catch_warnings(): warnings.filterwarnings(action="ignore", category=DeprecationWarning) f(*args, **kwargs) return decorator_like -def extend(cls): - def decorator(f): +# TODO: these annotations are Any placeholders; tighten them +def extend(cls: Any) -> Any: + def decorator(f: Any) -> None: if isinstance(f, property): name = f.fget.__name__ else: diff --git a/nmigen/tracer.py b/nmigen/tracer.py index 225ee1a..c4f2953 100644 --- a/nmigen/tracer.py +++ b/nmigen/tracer.py @@ -1,20 +1,28 @@ import traceback import inspect from opcode import opname +from typing import Tuple, NewType -__all__ = ["NameNotFound", "get_var_name", "get_src_loc"] +__all__ = ["NameNotFound", "get_var_name", "SrcLoc", "get_src_loc"] class NameNotFound(Exception): pass -def get_var_name(depth=2): +def get_var_name(depth: int = 2) -> str: frame = inspect.currentframe() + + if frame is None: + raise NameNotFound + for _ in range(depth): frame = frame.f_back + if frame is None: + raise NameNotFound + code = frame.f_code call_index = frame.f_lasti call_opc = opname[code.co_code[call_index]] @@ -40,11 +48,13 @@ def get_var_name(depth=2): raise NameNotFound -def get_src_loc(src_loc_at=0): +SrcLoc =
NewType('SrcLoc', Tuple[str, int]) + +def get_src_loc(src_loc_at: int = 0) -> SrcLoc: # n-th frame: get_src_loc() # n-1th frame: caller of get_src_loc() (usually constructor) # n-2th frame: caller of caller (usually user code) # Python returns the stack frames reversed, so it is enough to set limit and grab the very # first one in the array. tb = traceback.extract_stack(limit=3 + src_loc_at) - return (tb[0].filename, tb[0].lineno) + return SrcLoc((tb[0].filename, tb[0].lineno))