# Source code for icontract._types

"""Define data structures shared among the modules."""

import enum
import inspect
import reprlib
from typing import (
    Callable,
    Optional,
    Union,
    Set,
    List,
    Any,
    Type,
    cast,
)  # pylint: disable=unused-import

import icontract._globals

from icontract._globals import ExceptionT


class Contract:
    """Represent a contract to be enforced as a precondition, postcondition or as an invariant."""

    def __init__(
        self,
        condition: Callable[..., Any],
        description: Optional[str] = None,
        a_repr: reprlib.Repr = icontract._globals.aRepr,
        error: Optional[
            Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException]
        ] = None,
        location: Optional[str] = None,
    ) -> None:
        """
        Initialize.

        The signatures of ``condition`` and (if applicable) ``error`` are inspected once here
        so that the argument names are available as both an ordered list and a set.

        :param condition: condition predicate
        :param description: textual description of the contract
        :param a_repr: representation instance that defines how the values are represented
        :param error:
            if given as a callable, ``error`` is expected to accept a subset of function arguments
            (*e.g.*, also including ``result`` for postconditions, only ``self`` for invariants *etc.*) and return
            an exception. The ``error`` is called on contract violation and the resulting exception is raised.

            Otherwise, it is expected to denote an Exception class which is instantiated with the violation message
            and raised on contract violation.

            The argument names of ``error`` are recorded in ``error_args`` and ``error_arg_set`` only
            if ``error`` is a plain function or a method; they are left ``None`` in all the other cases.
        :param location: indicate where the contract was defined (*e.g.*, path and line number)
        """
        self.condition = condition

        signature = inspect.signature(condition)

        # All argument names of the condition
        self.condition_args = list(signature.parameters.keys())  # type: List[str]
        self.condition_arg_set = set(self.condition_args)  # type: Set[str]

        # Names of the mandatory arguments of the condition
        #
        # ``inspect.Parameter.empty`` is a sentinel and must be compared by identity.
        # An equality comparison would dispatch to ``__eq__`` of the user-supplied default value,
        # which may return a non-boolean (*e.g.*, array-like defaults) or have side effects.
        self.mandatory_args = [
            name
            for name, param in signature.parameters.items()
            if param.default is inspect.Parameter.empty
        ]

        self.description = description
        self._a_repr = a_repr

        self.error = error
        self.error_args = None  # type: Optional[List[str]]
        self.error_arg_set = None  # type: Optional[Set[str]]

        # Only functions and methods are introspected; exception classes and instances are not.
        if error is not None and (inspect.isfunction(error) or inspect.ismethod(error)):
            error_as_callable = cast(Callable[..., ExceptionT], error)
            self.error_args = list(
                inspect.signature(error_as_callable).parameters.keys()
            )
            self.error_arg_set = set(self.error_args)

        self.location = location


class Snapshot:
    """Describe a value captured *before* the function runs so that a postcondition can refer to it later."""

    def __init__(
        self,
        capture: Callable[..., Any],
        name: Optional[str] = None,
        location: Optional[str] = None,
    ) -> None:
        """
        Initialize.

        :param capture:
            function which computes the snapshot; its parameter names are recorded
            in ``args`` (ordered) and ``arg_set``

        :param name:
            name under which the captured value is made available in OLD.
            If omitted, ``capture`` must declare exactly one parameter and that parameter's
            name is used instead.
        :param location: indicate where the snapshot was defined (*e.g.*, path and line number)
        :raise ValueError:
            if ``name`` is omitted and ``capture`` declares no parameter or more than one parameter

        """
        self.capture = capture

        parameter_names = list(
            inspect.signature(capture).parameters.keys()
        )  # type: List[str]

        if name is None:
            # Without an explicit name, the name can only be inferred from a sole parameter.
            if not parameter_names:
                raise ValueError(
                    "You must name a snapshot if no argument was given in the capture function."
                )

            if len(parameter_names) > 1:
                raise ValueError(
                    "You must name a snapshot if multiple arguments were given in the capture function."
                )

            (name,) = parameter_names

        assert (
            name is not None
        ), "Expected ``name`` to be set in the preceding execution paths."

        self.name = name
        self.args = parameter_names
        self.arg_set = set(parameter_names)
        self.location = location


class InvariantCheckEvent(enum.Flag):
    """
    Define when an invariant should be checked.

    The members are flags and can be combined with ``|``.
    """

    #: Evaluate the invariant before and after all calls to a method.
    CALL = enum.auto()

    #: Evaluate the invariant before and after all the calls to ``__setattr__``.
    SETATTR = enum.auto()

    #: Always evaluate the invariant, *i.e.*, both on calls and on attributes set.
    ALL = CALL | SETATTR
class Invariant(Contract):
    """Represent a contract which is checked on all or some of the class operations."""

    # NOTE (mristin):
    # The class ``Invariant`` inherits from ``Contract`` so that we can maintain
    # the backwards compatibility with the integrators after introducing
    # the ``check_on`` feature.

    def __init__(
        self,
        check_on: InvariantCheckEvent,
        condition: Callable[..., Any],
        description: Optional[str] = None,
        a_repr: reprlib.Repr = icontract._globals.aRepr,
        error: Optional[
            Union[Callable[..., ExceptionT], Type[ExceptionT], BaseException]
        ] = None,
        location: Optional[str] = None,
    ) -> None:
        """
        Initialize with the given values.

        All arguments except ``check_on`` are forwarded unchanged to ``Contract.__init__``.

        :param check_on: events on which the invariant should be checked
        :param condition: condition predicate
        :param description: textual description of the contract
        :param a_repr: representation instance that defines how the values are represented
        :param error: exception factory, exception class or exception to be used on violation
        :param location: indicate where the contract was defined (*e.g.*, path and line number)
        """
        # NOTE (review): presumably guards against a clash with an attribute of the same name
        # defined elsewhere in the class hierarchy -- confirm the intent.
        assert not hasattr(self, "check_on")
        self.check_on = check_on

        super().__init__(
            condition=condition,
            description=description,
            a_repr=a_repr,
            error=error,
            location=location,
        )