Skip to content

GreatAI reference

from great_ai import *

Core

GreatAI

Bases: Generic[T, V]

Wrapper for a prediction function providing the implementation of SE4ML best practices.

Provides caching (with argument freezing), a TracingContext during execution, the scaffolding of HTTP endpoints using FastAPI and a dashboard using Dash.

IMPORTANT: when a request is served from cache, no new trace is created. Thus, the same trace can be returned multiple times. If this is undesirable, turn off caching using configure(prediction_cache_size=0).

Supports wrapping async and synchronous functions while also maintaining correct typing.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| app | | FastAPI instance wrapping the scaffolded endpoints and the Dash app. |
| version | | SemVer derived from the app's version and the model names and versions registered through use_model. |

Source code in great_ai/deploy/great_ai.py
class GreatAI(Generic[T, V]):
    """Wrapper for a prediction function providing the implementation of SE4ML best practices.

    Provides caching (with argument freezing), a TracingContext during execution, the
    scaffolding of HTTP endpoints using FastAPI and a dashboard using Dash.

    IMPORTANT: when a request is served from cache, no new trace is created. Thus, the
    same trace can be returned multiple times. If this is undesirable, turn off caching
    using `configure(prediction_cache_size=0)`.

    Supports wrapping async and synchronous functions while also maintaining correct
    typing.

    Type parameters:
        T: What calling the instance returns: `Trace[V]` for synchronous functions and
            `Awaitable[Trace[V]]` for async ones (see the `create` overloads).
        V: The return type of the wrapped prediction function.

    Attributes:
        app: FastAPI instance wrapping the scaffolded endpoints and the Dash app.
        version: SemVer derived from the app's version and the model names and versions
            registered through use_model.
    """

    __name__: str  # help for MyPy
    __doc__: str  # help for MyPy

    def __init__(
        self,
        func: Callable[..., Union[V, Awaitable[V]]],
    ) -> None:
        """Do not call this function directly, use GreatAI.create instead.

        Args:
            func: The (synchronous or async) prediction function to wrap.
        """

        # Wrap the parameters that are not already covered by @parameter/@use_model.
        func = automatically_decorate_parameters(func)
        # From now on, @parameter and @use_model refuse to decorate `func`
        # (both call assert_function_is_not_finalised).
        get_function_metadata_store(func).is_finalised = True

        # Tracing + LRU-cache layer; also handed to the meta endpoints below.
        self._cached_func = self._get_cached_traced_function(func)
        # Public callable: argument freezing in front of the cache, with `func`'s
        # metadata copied over. NOTE(review): freeze_arguments presumably turns the
        # arguments into a hashable form so the LRU cache can key on them — confirm.
        self._wrapped_func = wraps(func)(freeze_arguments(self._cached_func))

        # Copy `func`'s __name__, __doc__, etc. onto this instance; the code below
        # reads self.__name__ and self.__doc__, so this has to happen first.
        wraps(func)(self)
        self.__doc__ = (
            f"GreatAI wrapper for interacting with the `{self.__name__}` "
            + f"function.\n\n{dedent(self.__doc__ or '')}"
        )

        self.version = str(get_context().version)
        # Registered models are appended as SemVer build metadata,
        # e.g. "0.0.1+my_model-v3". NOTE(review): `model_versions` is assumed to be
        # the module-level collection of (key, version) pairs registered via
        # use_model — confirm.
        flat_model_versions = ".".join(f"{k}-v{v}" for k, v in model_versions)
        if flat_model_versions:
            self.version += f"+{flat_model_versions}"

        self.app = FastAPI(
            title=snake_case_to_text(self.__name__),
            version=self.version,
            description=self.__doc__
            + f"\n\nFind out more in the [dashboard]({DASHBOARD_PATH}).",
            # FastAPI's built-in docs routes are disabled; documentation routes are
            # added (if enabled) by bootstrap_docs_endpoints in _bootstrap_rest_api.
            docs_url=None,
            redoc_url=None,
        )

        self._bootstrap_rest_api()

    @overload
    @staticmethod
    def create(  # type: ignore
        # Overloaded function signatures 1 and 2 overlap with incompatible return types
        # https://github.com/python/mypy/issues/12759
        func: Callable[..., Awaitable[V]],
    ) -> "GreatAI[Awaitable[Trace[V]], V]":
        ...

    @overload
    @staticmethod
    def create(
        func: Callable[..., V],
    ) -> "GreatAI[Trace[V], V]":
        ...

    @staticmethod
    def create(
        func: Union[Callable[..., Awaitable[V]], Callable[..., V]],
    ) -> Union["GreatAI[Awaitable[Trace[V]], V]", "GreatAI[Trace[V], V]"]:
        """Decorate a function by wrapping it in a GreatAI instance.

        The function can be typed, synchronous or async. If it has
        unwrapped parameters (parameters not affected by a
        [@parameter][great_ai.parameter] or [@use_model][great_ai.use_model] decorator),
        those will be automatically wrapped.

        The return value is replaced by a Trace (or Awaitable[Trace]),
        while the original return value is available under the `.output`
        property.

        For configuration options, see [great_ai.configure][].

        Examples:
            >>> @GreatAI.create
            ... def my_function(a):
            ...     return a + 2
            >>> my_function(3).output
            5

            >>> @GreatAI.create
            ... def my_function(a: int) -> int:
            ...     return a + 2
            >>> my_function(3)
            Trace[int]...

            >>> my_function('3').output
            Traceback (most recent call last):
                ...
            typeguard.TypeCheckError: str is not an instance of int

        Args:
            func: The prediction function that needs to be decorated.

        Returns:
            A GreatAI instance wrapping `func`.
        """

        # The runtime object is the same for sync and async functions; the overloads
        # above only refine the static return type.
        return GreatAI[Trace[V], V](
            func,
        )

    def __call__(self, *args: Any, **kwargs: Any) -> T:
        """Call the wrapped function through argument freezing, caching and tracing."""
        return self._wrapped_func(*args, **kwargs)

    @overload
    def process_batch(
        self,
        batch: Sequence[Tuple],
        *,
        concurrency: Optional[int] = None,
        unpack_arguments: Literal[True],
        do_not_persist_traces: bool = ...,
    ) -> List[Trace[V]]:
        ...

    @overload
    def process_batch(
        self,
        batch: Sequence,
        *,
        concurrency: Optional[int] = None,
        unpack_arguments: Literal[False] = ...,
        do_not_persist_traces: bool = ...,
    ) -> List[Trace[V]]:
        ...

    def process_batch(
        self,
        batch: Sequence,
        *,
        concurrency: Optional[int] = None,
        unpack_arguments: bool = False,
        do_not_persist_traces: bool = False,
    ) -> List[Trace[V]]:
        """Map the wrapped function over a list of input_values (`batch`).

        A wrapper over [parallel_map][great_ai.utilities.parallel_map.parallel_map.parallel_map]
        providing type-safety and a progress bar through tqdm.

        Args:
            batch: A list of arguments for the original (wrapped) function. If the
                function expects multiple arguments, provide a list of tuples and set
                `unpack_arguments=True`.
            concurrency: Number of processes to start. Don't set it too much higher than
                the number of available CPU cores.
            unpack_arguments: Expect a list of tuples and unpack the tuples before
                giving them to the wrapped function.
            do_not_persist_traces: Don't save the traces in the database. Useful for
                evaluations run as part of CI.

        Returns:
            The traces produced for the elements of `batch`, collected into a list.
        """

        # Capture only the callable (not `self`) in the closures handed to parallel_map.
        wrapped_function = self._wrapped_func

        def inner(value: Any) -> T:
            return (
                wrapped_function(*value, do_not_persist_traces=do_not_persist_traces)
                if unpack_arguments
                else wrapped_function(
                    value, do_not_persist_traces=do_not_persist_traces
                )
            )

        async def inner_async(value: Any) -> T:
            # For async prediction functions the wrapped call returns an awaitable.
            return await cast(
                Awaitable,
                (
                    wrapped_function(
                        *value, do_not_persist_traces=do_not_persist_traces
                    )
                    if unpack_arguments
                    else wrapped_function(
                        value, do_not_persist_traces=do_not_persist_traces
                    )
                ),
            )

        return list(
            tqdm(
                parallel_map(
                    # Pick the coroutine-aware mapper when the wrapped function is async.
                    inner_async
                    if get_function_metadata_store(self).is_asynchronous
                    else inner,
                    batch,
                    concurrency=concurrency,
                ),
                total=len(batch),
            )
        )

    @staticmethod
    def _get_cached_traced_function(
        func: Callable[..., Union[V, Awaitable[V]]]
    ) -> Callable[..., T]:
        """Wrap `func` in a TracingContext and put an LRU cache in front of it.

        Both a synchronous and an async variant are defined; the one matching the
        `is_asynchronous` flag of `func`'s metadata store is returned. The cache size
        is read from the global configuration (`prediction_cache_size`) when this
        method runs, i.e. when the GreatAI instance is created.

        Because the TracingContext is created inside the cached function, a cache hit
        returns the previously created trace instead of a new one.
        `do_not_persist_traces` is an argument of the cached functions, so it is part
        of the cache key.
        """

        @lru_cache(maxsize=get_context().prediction_cache_size)
        def func_in_tracing_context_sync(
            *args: Any,
            do_not_persist_traces: bool = False,
            **kwargs: Any,
        ) -> T:
            with TracingContext[V](
                func.__name__, do_not_persist_traces=do_not_persist_traces
            ) as t:
                result = func(*args, **kwargs)
                return cast(T, t.finalise(output=result))

        @alru_cache(maxsize=get_context().prediction_cache_size)
        async def func_in_tracing_context_async(
            *args: Any,
            do_not_persist_traces: bool = False,
            **kwargs: Any,
        ) -> T:
            with TracingContext[V](
                func.__name__, do_not_persist_traces=do_not_persist_traces
            ) as t:
                result = await cast(Callable[..., Awaitable], func)(*args, **kwargs)
                return cast(T, t.finalise(output=result))

        return cast(
            Callable[..., T],
            (
                func_in_tracing_context_async
                if get_function_metadata_store(func).is_asynchronous
                else func_in_tracing_context_sync
            ),
        )

    def _bootstrap_rest_api(
        self,
    ) -> None:
        """Register the HTTP routes enabled in the global RouteConfig on `self.app`."""
        route_config = get_context().route_config

        if route_config.prediction_endpoint_enabled:
            bootstrap_prediction_endpoint(self.app, self._wrapped_func)

        if route_config.docs_endpoints_enabled:
            bootstrap_docs_endpoints(self.app)

        if route_config.dashboard_enabled:
            bootstrap_dashboard(
                self.app,
                function_name=self.__name__,
                documentation=self.__doc__,
            )

        if route_config.trace_endpoints_enabled:
            bootstrap_trace_endpoints(self.app)

        if route_config.feedback_endpoints_enabled:
            bootstrap_feedback_endpoints(self.app)

        if route_config.meta_endpoints_enabled:
            bootstrap_meta_endpoints(
                self.app,
                self._cached_func,
                ApiMetadata(
                    name=self.__name__,
                    version=self.version,
                    documentation=self.__doc__,
                    configuration=get_context().to_flat_dict(),
                ),
            )

create(func) staticmethod

Decorate a function by wrapping it in a GreatAI instance.

The function can be typed, synchronous or async. If it has unwrapped parameters (parameters not affected by a @parameter or @use_model decorator), those will be automatically wrapped.

The return value is replaced by a Trace (or Awaitable[Trace]), while the original return value is available under the .output property.

For configuration options, see great_ai.configure.

Examples:

>>> @GreatAI.create
... def my_function(a):
...     return a + 2
>>> my_function(3).output
5
>>> @GreatAI.create
... def my_function(a: int) -> int:
...     return a + 2
>>> my_function(3)
Trace[int]...
>>> my_function('3').output
Traceback (most recent call last):
    ...
typeguard.TypeCheckError: str is not an instance of int

Parameters:

Name Type Description Default
func Union[Callable[..., Awaitable[V]], Callable[..., V]]

The prediction function that needs to be decorated.

required

Returns:

Type Description
Union[GreatAI[Awaitable[Trace[V]], V], GreatAI[Trace[V], V]]

A GreatAI instance wrapping func.

Source code in great_ai/deploy/great_ai.py
@staticmethod
def create(
    func: Union[Callable[..., Awaitable[V]], Callable[..., V]],
) -> Union["GreatAI[Awaitable[Trace[V]], V]", "GreatAI[Trace[V], V]"]:
    """Turn a prediction function into a GreatAI instance (meant to be used as a decorator).

    Both synchronous and async functions are accepted, typed or untyped. Parameters
    that are not already handled by a [@parameter][great_ai.parameter] or
    [@use_model][great_ai.use_model] decorator get wrapped automatically.

    Calling the result yields a Trace (or an Awaitable[Trace] for async functions);
    the value returned by the original function is exposed as its `.output`.

    Configuration options are described in [great_ai.configure][].

    Examples:
        >>> @GreatAI.create
        ... def my_function(a):
        ...     return a + 2
        >>> my_function(3).output
        5

        >>> @GreatAI.create
        ... def my_function(a: int) -> int:
        ...     return a + 2
        >>> my_function(3)
        Trace[int]...

        >>> my_function('3').output
        Traceback (most recent call last):
            ...
        typeguard.TypeCheckError: str is not an instance of int

    Args:
        func: The prediction function to decorate.

    Returns:
        The GreatAI instance that wraps `func`.
    """
    # One parametrisation serves both cases at runtime; only the static type differs.
    wrapper_class = GreatAI[Trace[V], V]
    instance = wrapper_class(func)
    return instance

process_batch(batch, *, concurrency=None, unpack_arguments=False, do_not_persist_traces=False)

Map the wrapped function over a list of input_values (batch).

A wrapper over parallel_map providing type-safety and a progress bar through tqdm.

Parameters:

Name Type Description Default
batch Sequence

A list of arguments for the original (wrapped) function. If the function expects multiple arguments, provide a list of tuples and set unpack_arguments=True.

required
concurrency Optional[int]

Number of processes to start. Don't set it too much higher than the number of available CPU cores.

None
unpack_arguments bool

Expect a list of tuples and unpack the tuples before giving them to the wrapped function.

False
do_not_persist_traces bool

Don't save the traces in the database. Useful for evaluations run as part of CI.

False
Source code in great_ai/deploy/great_ai.py
def process_batch(
    self,
    batch: Sequence,
    *,
    concurrency: Optional[int] = None,
    unpack_arguments: bool = False,
    do_not_persist_traces: bool = False,
) -> List[Trace[V]]:
    """Apply the wrapped function to every element of `batch`.

    Thin, type-safe layer over
    [parallel_map][great_ai.utilities.parallel_map.parallel_map.parallel_map] that
    also displays a tqdm progress bar.

    Args:
        batch: The inputs of the original (wrapped) function. For functions taking
            several arguments, pass tuples and set `unpack_arguments=True`.
        concurrency: Number of processes to start. Don't set it much higher than the
            number of available CPU cores.
        unpack_arguments: Treat each element of `batch` as a tuple of positional
            arguments and unpack it before calling the wrapped function.
        do_not_persist_traces: Don't save the traces in the database. Useful for
            evaluations run as part of CI.

    Returns:
        The traces produced for the elements of `batch`, collected into a list.
    """

    # Only the callable is captured by the closures below, never `self`.
    call = self._wrapped_func

    def inner(value: Any) -> T:
        positional = value if unpack_arguments else (value,)
        return call(*positional, do_not_persist_traces=do_not_persist_traces)

    async def inner_async(value: Any) -> T:
        positional = value if unpack_arguments else (value,)
        pending = call(*positional, do_not_persist_traces=do_not_persist_traces)
        return await cast(Awaitable, pending)

    if get_function_metadata_store(self).is_asynchronous:
        mapper = inner_async
    else:
        mapper = inner

    results = parallel_map(mapper, batch, concurrency=concurrency)
    return list(tqdm(results, total=len(batch)))

configure(*, version='0.0.1', log_level=DEBUG, seed=42, tracing_database_factory=None, large_file_implementation=None, should_log_exception_stack=None, prediction_cache_size=512, disable_se4ml_banner=False, dashboard_table_size=50, route_config=RouteConfig())

Set the global configuration used by the great-ai library.

You must call configure before calling (or decorating with) any other great-ai function.

If tracing_database_factory or large_file_implementation is not specified, their default value is determined based on which TracingDatabase and LargeFile have been configured (e.g.: LargeFileS3.configure_credentials_from_file('s3.ini')), or whether there is any file named s3.ini or mongo.ini in the working directory.

Examples:

>>> configure(prediction_cache_size=0)

Parameters:

Name Type Description Default
version Union[int, str]

The version of your application (using SemVer is recommended).

'0.0.1'
log_level int

Set the default logging level of logging.

DEBUG
seed int

Set seed of random (and numpy if installed) for reproducibility.

42
tracing_database_factory Optional[Type[TracingDatabaseDriver]]

Specify a different TracingDatabaseDriver than the one already configured.

None
large_file_implementation Optional[Type[LargeFileBase]]

Specify a different LargeFile than the one already configured.

None
should_log_exception_stack Optional[bool]

Log the traces of unhandled exceptions.

None
prediction_cache_size int

Size of the LRU cache applied over the prediction functions.

512
disable_se4ml_banner bool

Turn off the warning about the importance of SE4ML best practices.

False
dashboard_table_size int

Number of rows to display in the dashboard's table.

50
route_config RouteConfig

Enable or disable specific HTTP API endpoints.

RouteConfig()
Source code in great_ai/context.py
def configure(
    *,
    version: Union[int, str] = "0.0.1",
    log_level: int = DEBUG,
    seed: int = 42,
    tracing_database_factory: Optional[Type[TracingDatabaseDriver]] = None,
    large_file_implementation: Optional[Type[LargeFileBase]] = None,
    should_log_exception_stack: Optional[bool] = None,
    prediction_cache_size: int = 512,
    disable_se4ml_banner: bool = False,
    dashboard_table_size: int = 50,
    route_config: Optional[RouteConfig] = None,
) -> None:
    """Set the global configuration used by the great-ai library.

    You must call `configure` before calling (or decorating with) any other great-ai
    function. Calling it again overwrites the existing configuration and logs an error.

    If `tracing_database_factory` or `large_file_implementation` is not specified, their
    default value is determined based on which TracingDatabase and LargeFile have been
    configured (e.g.: LargeFileS3.configure_credentials_from_file('s3.ini')), or whether
    there is any file named s3.ini or mongo.ini in the working directory.

    Examples:
        >>> configure(prediction_cache_size=0)

    Arguments:
        version: The version of your application (using SemVer is recommended).
        log_level: Set the default logging level of `logging`.
        seed: Set seed of `random` (and `numpy` if installed) for reproducibility.
        tracing_database_factory: Specify a different TracingDatabaseDriver than the one
            already configured.
        large_file_implementation: Specify a different LargeFile than the one already
            configured.
        should_log_exception_stack: Log the traces of unhandled exceptions. When not
            given, it is enabled outside production mode and disabled in production.
        prediction_cache_size: Size of the LRU cache applied over the prediction
            functions.
        disable_se4ml_banner: Turn off the warning about the importance of SE4ML best
            practices.
        dashboard_table_size: Number of rows to display in the dashboard's table.
        route_config: Enable or disable specific HTTP API endpoints. When omitted, a
            new default `RouteConfig()` is created for this call.
    """

    global _context
    logger = get_logger("great_ai", level=log_level)

    if route_config is None:
        # Build the default per call: a `RouteConfig()` default argument would be a
        # single instance shared by every configuration, so a mutation of one
        # context's route config would leak into later ones.
        route_config = RouteConfig()

    if _context is not None:
        logger.error(
            "Configuration has been already initialised, overwriting.\n"
            + "Make sure to call `configure()` before importing your application code."
        )

    is_production = _is_in_production_mode(logger=logger)

    _set_seed(seed)

    tracing_database_factory = _initialize_tracing_database(
        tracing_database_factory, logger=logger
    )
    tracing_database = tracing_database_factory()

    if not tracing_database.is_production_ready:
        message = f"""The selected tracing database ({
            tracing_database_factory.__name__
        }) is not recommended for production"""

        # A non-production-ready database is an error in production, a warning otherwise.
        if is_production:
            logger.error(message)
        else:
            logger.warning(message)

    _context = Context(
        version=version,
        tracing_database=tracing_database,
        large_file_implementation=_initialize_large_file(
            large_file_implementation, logger=logger
        ),
        is_production=is_production,
        logger=logger,
        should_log_exception_stack=not is_production
        if should_log_exception_stack is None
        else should_log_exception_stack,
        prediction_cache_size=prediction_cache_size,
        dashboard_table_size=dashboard_table_size,
        route_config=route_config,
    )

    logger.info(f"GreatAI (v{__version__}): configured ✅")
    for k, v in get_context().to_flat_dict().items():
        logger.info(f"{LIST_ITEM_PREFIX}{k}: {v}")

    if not is_production and not disable_se4ml_banner:
        logger.warning(
            "You still need to check whether you follow all best practices before "
            "trusting your deployment."
        )
        logger.warning(f"> Find out more at {SE4ML_WEBSITE}")

save_model(model, key, *, keep_last_n=None)

Save (and optionally serialise) a model so that it can be loaded by use_model.

The model can be a Path or string representing a path in which case the local file/folder is read and saved using the current LargeFile implementation. In case model is an object, it is serialised using dill before uploading it.

Examples:

>>> from great_ai import use_model
>>> save_model(3, 'my_number')
'my_number:...'
>>> @use_model('my_number')
... def my_function(a, model):
...     return a + model
>>> my_function(4)
7

Parameters:

Name Type Description Default
model Union[Path, str, object]

The object or path to be uploaded.

required
key str

The model's name.

required
keep_last_n Optional[int]

If specified, remove old models and only keep the latest n. Directly passed to LargeFile.

None

Returns: The key and version of the saved model separated by a colon. Example: "key:version"

Source code in great_ai/models/save_model.py
def save_model(
    model: Union[Path, str, object], key: str, *, keep_last_n: Optional[int] = None
) -> str:
    """Persist a model so that `use_model` can load it later.

    If `model` is a `Path` or a string, it is interpreted as the path of a local file
    or folder, which is uploaded through the currently configured LargeFile
    implementation. Any other object is first serialised using `dill` and then
    uploaded.

    Examples:
        >>> from great_ai import use_model
        >>> save_model(3, 'my_number')
        'my_number:...'

        >>> @use_model('my_number')
        ... def my_function(a, model):
        ...     return a + model
        >>> my_function(4)
        7

    Args:
        model: The object, or the path of the file/folder, to upload.
        key: The model's name.
        keep_last_n: If specified, remove old models and only keep the latest n.
            Directly passed to LargeFile.

    Returns:
        The key and version of the saved model separated by a colon, e.g.
        "key:version".
    """
    large_file = get_context().large_file_implementation(
        name=key, mode="wb", keep_last_n=keep_last_n
    )

    if isinstance(model, (Path, str)):
        # A path: upload the local file/folder as-is.
        large_file.push(model)
    else:
        # Any other object: serialise it into the large file.
        with large_file as handle:
            dump(model, handle)

    get_context().logger.info(f"Model {key} uploaded with version {large_file.version}")

    return f"{key}:{large_file.version}"

use_model(key, *, version='latest', model_kwarg_name='model')

Inject a model into a function.

Load a model specified by key and version using the currently active LargeFile implementation. If it's a single object, it is deserialised using dill. If it's a directory of files, a pathlib.Path instance is given.

By default, the function's model parameter is replaced by the loaded model. This can be customised by changing model_kwarg_name. Multiple models can be loaded by decorating the same function with use_model multiple times.

Examples:

>>> from great_ai import save_model
>>> save_model(3, 'my_number')
'my_number:...'
>>> @use_model('my_number')
... def my_function(a, model):
...     return a + model
>>> my_function(4)
7

Parameters:

Name Type Description Default
key str

The model's name as stored by the LargeFile implementation.

required
version Union[int, Literal['latest']]

The model's version as stored by the LargeFile implementation.

'latest'
model_kwarg_name str

Name of the parameter through which the loaded model is injected.

'model'

Returns: A decorator for model injection.

Source code in great_ai/models/use_model.py
def use_model(
    key: str,
    *,
    version: Union[int, Literal["latest"]] = "latest",
    model_kwarg_name: str = "model",
) -> Callable[[F], F]:
    """Inject a model into a function.

    Load a model specified by `key` and `version` using the currently active `LargeFile`
    implementation. If it's a single object, it is deserialised using `dill`. If it's a
    directory of files, a `pathlib.Path` instance is given.

    The model is loaded once, when `use_model(...)` is called, and that same object is
    injected into every call of the decorated function.

    By default, the function's `model` parameter is replaced by the loaded model. This
    can be customised by changing `model_kwarg_name`. Multiple models can be loaded by
    decorating the same function with `use_model` multiple times.

    Examples:
        >>> from great_ai import save_model
        >>> save_model(3, 'my_number')
        'my_number:...'
        >>> @use_model('my_number')
        ... def my_function(a, model):
        ...     return a + model
        >>> my_function(4)
        7

    Args:
        key: The model's name as stored by the LargeFile implementation.
        version: The model's version as stored by the LargeFile implementation, or
            `"latest"` for the most recent one.
        model_kwarg_name: Name of the parameter through which the loaded model is
            injected.

    Returns:
        A decorator for model injection.

    Raises:
        AssertionError: If `version` is neither an integer nor `"latest"`.
    """

    # Raised explicitly instead of via `assert` so that the check is not stripped under
    # `python -O`; AssertionError is kept for backward compatibility.
    if not (isinstance(version, int) or version == "latest"):
        raise AssertionError(
            "Only integers or the string literal `latest` is allowed as a version"
        )

    model, actual_version = _load_model(
        key=key,
        version=None if version == "latest" else version,
    )

    def decorator(func: F) -> F:
        # Decorators cannot be added once GreatAI has finalised the function.
        assert_function_is_not_finalised(func)

        store = get_function_metadata_store(func)
        store.model_parameter_names.append(model_kwarg_name)

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Record which model version served this call in the active trace, if any.
            tracing_context = TracingContext.get_current_tracing_context()
            if tracing_context:
                tracing_context.log_model(Model(key=key, version=actual_version))
            return func(*args, **kwargs, **{model_kwarg_name: model})

        return cast(F, wrapper)

    return decorator

parameter(parameter_name, *, validate=lambda _: True, disable_logging=False)

Control the validation and logging of function parameters.

Basically, a parameter decorator. Unfortunately, Python does not have that concept, thus, it's a method decorator that expects the name of the to-be-decorated parameter.

Examples:

>>> @parameter('a')
... def my_function(a: int):
...     return a + 2
>>> my_function(4)
6
>>> my_function('3')
Traceback (most recent call last):
    ...
typeguard.TypeCheckError: str is not an instance of int
>>> @parameter('positive_a', validate=lambda v: v > 0)
... def my_function(positive_a: int):
...     return positive_a + 2
>>> my_function(-1)
Traceback (most recent call last):
    ...
great_ai.errors.argument_validation_error.ArgumentValidationError: ...

Parameters:

Name Type Description Default
parameter_name str

Name of parameter to consider.

required
validate Callable[[Any], bool]

Optional validation function to run against the concrete argument. ArgumentValidationError is thrown when the return value is False.

lambda _: True
disable_logging bool

Do not save the value in any active TracingContext.

False

Returns: A decorator for argument validation.

Source code in great_ai/parameters/parameter.py
def parameter(
    parameter_name: str,
    *,
    validate: Callable[[Any], bool] = lambda _: True,
    disable_logging: bool = False,
) -> Callable[[F], F]:
    """Control the validation and logging of function parameters.

    Basically, a parameter decorator. Unfortunately, Python does not have that concept,
    thus, it's a function decorator that expects the name of the to-be-decorated
    parameter.

    On each call, the argument is type-checked against the parameter's annotation (if
    it has one), passed to `validate`, and, unless `disable_logging` is set, logged
    into the active TracingContext (for strings, their length is logged too).

    Examples:
        >>> @parameter('a')
        ... def my_function(a: int):
        ...     return a + 2
        >>> my_function(4)
        6
        >>> my_function('3')
        Traceback (most recent call last):
            ...
        typeguard.TypeCheckError: str is not an instance of int

        >>> @parameter('positive_a', validate=lambda v: v > 0)
        ... def my_function(positive_a: int):
        ...     return positive_a + 2
        >>> my_function(-1)
        Traceback (most recent call last):
            ...
        great_ai.errors.argument_validation_error.ArgumentValidationError: ...

    Args:
        parameter_name: Name of parameter to consider.
        validate: Optional predicate run against the concrete argument.
            ArgumentValidationError is thrown when the return value is False.
        disable_logging: Do not save the value in any active TracingContext.

    Returns:
        A decorator for argument validation.
    """

    def decorator(func: F) -> F:
        # Check first (as use_model does) so that rejecting an already finalised
        # function does not leave a stray entry in its metadata store.
        assert_function_is_not_finalised(func)
        get_function_metadata_store(func).input_parameter_names.append(parameter_name)

        actual_name = f"arg:{parameter_name}"

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            arguments = get_arguments(func, args, kwargs)
            argument = arguments.get(parameter_name)

            # Only annotated parameters are type-checked.
            expected_type = func.__annotations__.get(parameter_name)

            if expected_type is not None:
                check_type(argument, expected_type)

            if not validate(argument):
                raise ArgumentValidationError(
                    f"""Argument {parameter_name} in {
                        func.__name__
                    } did not pass validation"""
                )

            context = TracingContext.get_current_tracing_context()
            if context and not disable_logging:
                context.log_value(name=f"{actual_name}:value", value=argument)
                if isinstance(argument, str):
                    context.log_value(name=f"{actual_name}:length", value=len(argument))

            return func(*args, **kwargs)

        return cast(F, wrapper)

    return decorator

log_metric(argument_name, value, disable_logging=False)

Log a key (argument_name)-value pair that is persisted inside the trace.

The name of the function from where this is called is also stored.

Parameters:

Name Type Description Default
argument_name str

The key for storing the value.

required
value Any

Value to log. Must be JSON-serialisable.

required
disable_logging bool

If True, only persist the value in the trace without also logging it to the console.

False
Source code in great_ai/parameters/log_metric.py
def log_metric(argument_name: str, value: Any, disable_logging: bool = False) -> None:
    """Log a key (argument_name)-value pair that is persisted inside the trace.

    The name of the function from where this is called is also stored.

    Args:
        argument_name: The key for storing the value.
        value: Value to log. Must be JSON-serialisable.
        disable_logging: If True, only persist in trace but don't show in console
    """

    tracing_context = TracingContext.get_current_tracing_context()
    try:
        # Only the caller's function name is needed: `context=0` skips loading source
        # lines for every frame on the stack, which is slow and presumably the part
        # that breaks in notebooks.
        caller = inspect.stack(context=0)[1].function
        actual_name = f"metric:{caller}:{argument_name}"
    except Exception:
        # inspect might not work in notebooks; fall back to a key without the caller.
        # (Not a bare `except:` so KeyboardInterrupt/SystemExit still propagate.)
        actual_name = f"metric:{argument_name}"

    if tracing_context:
        tracing_context.log_value(name=actual_name, value=value)

    if not disable_logging:
        get_context().logger.info(f"{actual_name}={value}")

Remote calls#

call_remote_great_ai(base_uri, data, retry_count=4, timeout_in_seconds=300, model_class=None) #

Communicate with a GreatAI object through an HTTP request.

Wrapper over call_remote_great_ai_async making it synchronous. For more info, see call_remote_great_ai_async.

Source code in great_ai/remote/call_remote_great_ai.py
def call_remote_great_ai(
    base_uri: str,
    data: Mapping[str, Any],
    retry_count: int = 4,
    timeout_in_seconds: Optional[int] = 300,
    model_class: Optional[Type[T]] = None,
) -> Trace[T]:
    """Communicate with a GreatAI object through an HTTP request.

    Wrapper over `call_remote_great_ai_async` making it synchronous. For more info, see
    `call_remote_great_ai_async`.

    Raises:
        RuntimeError: If called while an event loop is already running in this thread;
            await `call_remote_great_ai_async` there instead.
        RemoteCallError: Propagated from `call_remote_great_ai_async`.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop in this thread: it is safe to start one with `asyncio.run`.
        pass
    else:
        # `asyncio.run` cannot be used from inside a running event loop.
        raise RuntimeError(
            f"Already running in an event loop, you have to call `{call_remote_great_ai_async.__name__}`"
        )

    return asyncio.run(
        call_remote_great_ai_async(
            base_uri=base_uri,
            data=data,
            retry_count=retry_count,
            timeout_in_seconds=timeout_in_seconds,
            model_class=model_class,
        )
    )

call_remote_great_ai_async(base_uri, data, retry_count=4, timeout_in_seconds=300, model_class=None) async #

Communicate with a GreatAI object through an HTTP request.

Send a POST request using httpx to implement a remote call. Error-handling and retries are provided by httpx.

The return value is inflated into a Trace. If model_class is specified, the original output is deserialised.

Parameters:

Name Type Description Default
base_uri str

Address of the remote instance, example: 'http://localhost:6060'

required
data Mapping[str, Any]

The input sent as a json to the remote instance.

required
retry_count int

Number of retries performed by the httpx transport. httpx only retries failures to establish a connection, not error responses.

4
timeout_in_seconds Optional[int]

Timeout in seconds passed to httpx, which applies it to each phase of the request (connect, read, write, pool). None means no timeout.

300
model_class Optional[Type[T]]

A subtype of BaseModel to be used for deserialising the .output of the trace.

None
Source code in great_ai/remote/call_remote_great_ai_async.py
async def call_remote_great_ai_async(
    base_uri: str,
    data: Mapping[str, Any],
    retry_count: int = 4,
    timeout_in_seconds: Optional[int] = 300,
    model_class: Optional[Type[T]] = None,
) -> Trace[T]:
    """Communicate with a GreatAI object through an HTTP request.

    Send a POST request using [httpx](https://www.python-httpx.org/) to implement a
    remote call. Error-handling and retries are provided by httpx.

    The return value is inflated into a Trace. If `model_class` is specified, the
    original output is deserialised.

    Args:
        base_uri: Address of the remote instance, example: 'http://localhost:6060'
            (the `/predict` suffix is appended when missing).
        data: The input sent as a json to the remote instance.
        retry_count: Number of retries performed by the httpx transport. httpx only
            retries failures to establish a connection, not error responses.
        timeout_in_seconds: Timeout passed to httpx, which applies it to each phase of
            the request (connect, read, write, pool). `None` means no timeout.
        model_class: A subtype of BaseModel to be used for deserialising the `.output`
            of the trace.

    Raises:
        RemoteCallError: If the request fails, the server responds with an error status,
            or the response cannot be parsed into a Trace.
    """

    base_uri = base_uri.rstrip("/")
    if not base_uri.endswith("/predict"):
        base_uri = f"{base_uri}/predict"

    transport = httpx.AsyncHTTPTransport(retries=retry_count)

    try:
        async with httpx.AsyncClient(
            transport=transport, timeout=timeout_in_seconds
        ) as client:
            response = await client.post(base_uri, json=data)
    except Exception as e:
        raise RemoteCallError from e

    # Checked outside the transport `try` above: previously this error (and its
    # informative message) was caught there and re-raised as a message-less error.
    try:
        response.raise_for_status()
    except Exception as e:
        raise RemoteCallError(
            f"Unexpected status code, reason: {response.text}"
        ) from e

    try:
        trace = response.json()
    except Exception as e:
        raise RemoteCallError(
            f"JSON parsing failed {response.text}",
        ) from e

    try:
        if model_class is not None:
            trace["output"] = model_class.parse_obj(trace["output"])
        return Trace.parse_obj(trace)
    except Exception as e:
        raise RemoteCallError("Could not parse Trace") from e

Ground-truth#

add_ground_truth(inputs, expected_outputs, *, tags=[], train_split_ratio=1, test_split_ratio=0, validation_split_ratio=0) #

Add training data (with optional train-test splitting).

Add and tag data points, wrapping them into traces. The inputs are available via the .input property, while the expected_outputs are available under both the .output and .feedback properties.

All generated traces are tagged with ground_truth by default. Additional tags can also be provided. Using the split_ratio arguments, split tags are assigned randomly with a user-defined distribution. Only the ratios of the splits matter; they don't have to add up to 1.

Examples:

>>> add_ground_truth(
...    [1, 2, 3],
...    ['odd', 'even', 'odd'],
...    tags='my_tag',
...    train_split_ratio=1,
...    test_split_ratio=1,
...    validation_split_ratio=0.5,
... )
>>> add_ground_truth(
...    [1, 2],
...    ['odd', 'even', 'odd'],
...    tags='my_tag',
...    train_split_ratio=1,
...    test_split_ratio=1,
...    validation_split_ratio=0.5,
... )
Traceback (most recent call last):
    ...
AssertionError: The length of the inputs and expected_outputs must be equal

Parameters:

Name Type Description Default
inputs Iterable[Any]

The inputs. (X in scikit-learn)

required
expected_outputs Iterable[T]

The ground-truth values corresponding to the inputs. (y in scikit-learn)

required
tags Union[List[str], str]

A single tag or a list of tags to append to each generated trace's tags.

[]
train_split_ratio float

The probability-weight of giving each trace the train tag.

1
test_split_ratio float

The probability-weight of giving each trace the test tag.

0
validation_split_ratio float

The probability-weight of giving each trace the validation tag.

0
Source code in great_ai/tracing/add_ground_truth.py
def add_ground_truth(
    inputs: Iterable[Any],
    expected_outputs: Iterable[T],
    *,
    tags: Union[List[str], str] = [],
    train_split_ratio: float = 1,
    test_split_ratio: float = 0,
    validation_split_ratio: float = 0,
) -> None:
    """Add training data (with optional train-test splitting).

    Add and tag data-points, wrapping them into traces. The `inputs` are available via
    the `.input` property, while the `expected_outputs` are available under both the
    `.output` and `.feedback` properties.

    All generated traces are tagged with `ground_truth` by default. Additional tags can
    also be provided. Using the `split_ratio` arguments, tags can be given randomly with
    a user-defined distribution. Only the ratios of the splits matter; they don't have
    to add up to 1.

    Examples:
        >>> add_ground_truth(
        ...    [1, 2, 3],
        ...    ['odd', 'even', 'odd'],
        ...    tags='my_tag',
        ...    train_split_ratio=1,
        ...    test_split_ratio=1,
        ...    validation_split_ratio=0.5,
        ... )

        >>> add_ground_truth(
        ...    [1, 2],
        ...    ['odd', 'even', 'odd'],
        ...    tags='my_tag',
        ...    train_split_ratio=1,
        ...    test_split_ratio=1,
        ...    validation_split_ratio=0.5,
        ... )
        Traceback (most recent call last):
            ...
        AssertionError: The length of the inputs and expected_outputs must be equal

    Args:
        inputs: The inputs. (X in scikit-learn)
        expected_outputs: The ground-truth values corresponding to the inputs. (y in
            scikit-learn)
        tags: A single tag or a list of tags to append to each generated trace's tags.
        train_split_ratio: The probability-weight of giving each trace the `train` tag.
        test_split_ratio: The probability-weight of giving each trace the `test` tag.
        validation_split_ratio: The probability-weight of giving each trace the
            `validation` tag.

    Raises:
        AssertionError: If `inputs` and `expected_outputs` differ in length, or if the
            split ratios do not sum to a positive number.
    """

    inputs = list(inputs)
    expected_outputs = list(expected_outputs)
    assert len(inputs) == len(
        expected_outputs
    ), "The length of the inputs and expected_outputs must be equal"

    # `tags` is rebound, never mutated, so the shared `[]` default is not at risk.
    tags = tags if isinstance(tags, list) else [tags]

    sum_ratio = train_split_ratio + test_split_ratio + validation_split_ratio
    assert sum_ratio > 0, "The sum of the split ratios must be a positive number"

    if not inputs:
        # Nothing to store. Also avoids handing an empty batch to the database driver:
        # MongoDB's `insert_many` rejects an empty document list.
        return

    train_split_ratio /= sum_ratio
    test_split_ratio /= sum_ratio
    validation_split_ratio /= sum_ratio

    values = list(zip(inputs, expected_outputs))
    shuffle(values)

    # Every count is rounded up, so there are at least len(inputs) split tags; the
    # `zip` below pairs each value with one tag and drops the surplus.
    split_tags = (
        [TRAIN_SPLIT_TAG_NAME] * ceil(train_split_ratio * len(inputs))
        + [TEST_SPLIT_TAG_NAME] * ceil(test_split_ratio * len(inputs))
        + [VALIDATION_SPLIT_TAG_NAME] * ceil(validation_split_ratio * len(inputs))
    )
    shuffle(split_tags)

    created = datetime.utcnow().isoformat()
    traces = [
        cast(
            Trace[T],
            Trace(  # avoid ValueError: "Trace" object has no field "__orig_class__"
                trace_id=str(uuid4()),
                created=created,
                original_execution_time_ms=0,
                logged_values=X if isinstance(X, dict) else {"input": X},
                models=[],
                output=y,
                feedback=y,
                exception=None,
                tags=[GROUND_TRUTH_TAG_NAME, split_tag, *tags],
            ),
        )
        for ((X, y), split_tag) in zip(values, split_tags)
    ]

    get_context().tracing_database.save_batch(traces)

query_ground_truth(conjunctive_tags=[], *, since=None, until=None, return_max_count=None) #

Return training samples.

Combines, filters, and returns data points that have either been added by add_ground_truth or were the result of a prediction whose trace later received feedback through the REST API's /traces/{trace_id}/feedback endpoint (end-to-end feedback).

Filtering can be used to only return points matching all given tags (or the single given tag) and by time of creation.

Examples:

>>> query_ground_truth()
[...]

Parameters:

Name Type Description Default
conjunctive_tags Union[List[str], str]

Single tag or a list of tags which the returned traces have to match. The relationship between the tags is conjunctive (AND).

[]
since Optional[datetime]

Only return traces created after the given timestamp. None means no filtering.

None
until Optional[datetime]

Only return traces created before the given timestamp. None means no filtering.

None
return_max_count Optional[int]

Return at most this many traces (take, limit).

None
Source code in great_ai/tracing/query_ground_truth.py
def query_ground_truth(
    conjunctive_tags: Union[List[str], str] = [],
    *,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
    return_max_count: Optional[int] = None,
) -> List[Trace]:
    """Return training samples.

    Collects the traces that carry feedback: both those created by `add_ground_truth`
    and predictions whose trace later received feedback through the REST API's
    `/traces/{trace_id}/feedback` endpoint (end-to-end feedback).

    The result can be narrowed to traces having all of the given tags (or the single
    given tag) and to a creation-time window.

    Examples:
        >>> query_ground_truth()
        [...]

    Args:
        conjunctive_tags: Single tag or a list of tags which the returned traces have to
            match. The relationship between the tags is conjunctive (AND).
        since: Only return traces created after the given timestamp. `None` means no
            filtering.
        until: Only return traces created before the given timestamp. `None` means no
            filtering.
        return_max_count: Return at most this many traces. (take, limit)

    Returns:
        The matching traces, as returned by the configured tracing database.
    """

    if isinstance(conjunctive_tags, list):
        required_tags = conjunctive_tags
    else:
        required_tags = [conjunctive_tags]

    # The driver also reports the total match count; only the page is needed here.
    matching_traces, _total_count = get_context().tracing_database.query(
        conjunctive_tags=required_tags,
        since=since,
        until=until,
        take=return_max_count,
        has_feedback=True,
    )
    return matching_traces

delete_ground_truth(conjunctive_tags=[], *, since=None, until=None) #

Delete traces matching the given criteria.

Takes the same arguments as query_ground_truth but instead of returning them, it simply deletes them.

You can rely on the efficiency of the delete's implementation.

Examples:

>>> delete_ground_truth(['train', 'test', 'validation'])

Parameters:

Name Type Description Default
conjunctive_tags Union[List[str], str]

Single tag or a list of tags which the deleted traces have to match. The relationship between the tags is conjunctive (AND).

[]
since Optional[datetime]

Only delete traces created after the given timestamp. None means no filtering.

None
until Optional[datetime]

Only delete traces created before the given timestamp. None means no filtering.

None
Source code in great_ai/tracing/delete_ground_truth.py
def delete_ground_truth(
    conjunctive_tags: Union[List[str], str] = [],
    *,
    since: Optional[datetime] = None,
    until: Optional[datetime] = None,
) -> None:
    """Delete traces matching the given criteria.

    Accepts the same filters as `query_ground_truth`, but removes the matching
    feedback-carrying traces instead of returning them.

    The matching traces are looked up once and then removed by id with a single
    `delete_batch` call on the tracing database.

    Examples:
        >>> delete_ground_truth(['train', 'test', 'validation'])

    Args:
        conjunctive_tags: Single tag or a list of tags which the deleted traces have to
            match. The relationship between the tags is conjunctive (AND).
        since: Only delete traces created after the given timestamp. `None` means no
            filtering.
        until: Only delete traces created before the given timestamp. `None` means no
            filtering.
    """

    required_tags = (
        [conjunctive_tags]
        if not isinstance(conjunctive_tags, list)
        else conjunctive_tags
    )
    database = get_context().tracing_database

    doomed, _total_count = database.query(
        conjunctive_tags=required_tags, until=until, since=since, has_feedback=True
    )

    doomed_ids = [trace.trace_id for trace in doomed]
    database.delete_batch(doomed_ids)

Tracing databases#

TracingDatabaseDriver #

Bases: ABC

Interface expected from a database to be used for storing traces.

Source code in great_ai/persistence/tracing_database_driver.py
class TracingDatabaseDriver(ABC):
    """Interface expected from a database to be used for storing traces.

    Credentials are configured on the class (via `configure_credentials` or
    `configure_credentials_from_file`), which flips the class-level `initialized`
    flag. Concrete drivers declare `is_production_ready` and implement the abstract
    persistence methods.
    """

    # Whether the backend is suitable for production deployments.
    is_production_ready: bool
    # Set to True by `configure_credentials`.
    initialized: bool = False

    @classmethod
    def configure_credentials_from_file(
        cls,
        secrets: Union[Path, str, ConfigFile],
    ) -> None:
        """Configure the driver from a config file (or an already loaded ConfigFile).

        Every key is lower-cased and passed as a keyword argument to
        `configure_credentials`.
        """
        config = secrets if isinstance(secrets, ConfigFile) else ConfigFile(secrets)
        options = {key.lower(): setting for key, setting in config.items()}
        cls.configure_credentials(**options)

    @classmethod
    def configure_credentials(
        cls,
    ) -> None:
        """Mark the driver class as initialized; subclasses accept their own options."""
        cls.initialized = True

    @abstractmethod
    def save(self, document: Trace) -> str:
        """Store a single trace and return the backend's identifier for it."""

    @abstractmethod
    def save_batch(
        self,
        documents: List[Trace],
    ) -> List[str]:
        """Store several traces and return the backend's identifiers for them."""

    @abstractmethod
    def get(self, id: str) -> Optional[Trace]:
        """Return the trace with the given trace id, or None when it is not stored."""

    @abstractmethod
    def query(
        self,
        *,
        skip: int = 0,
        take: Optional[int] = None,
        conjunctive_filters: Sequence[Filter] = [],
        conjunctive_tags: Sequence[str] = [],
        until: Optional[datetime] = None,
        since: Optional[datetime] = None,
        has_feedback: Optional[bool] = None,
        sort_by: Sequence[SortBy] = [],
    ) -> Tuple[List[Trace], int]:
        """Return a page of matching traces and the total number of matches."""

    @abstractmethod
    def update(self, id: str, new_version: Trace) -> None:
        """Replace the stored trace having the given id with `new_version`."""

    @abstractmethod
    def delete(self, id: str) -> None:
        """Remove the trace with the given id."""

    @abstractmethod
    def delete_batch(
        self,
        ids: List[str],
    ) -> None:
        """Remove every trace whose id is listed."""

MongoDbDriver #

Bases: TracingDatabaseDriver

TracingDatabaseDriver implementation using MongoDB as a backend.

A production-ready database driver suitable for efficiently handling semi-structured data.

Check out MongoDB Atlas for a hosted MongoDB solution.

Source code in great_ai/persistence/mongodb_driver.py
class MongoDbDriver(TracingDatabaseDriver):
    """TracingDatabaseDriver implementation using MongoDB as a backend.

    A production-ready database driver suitable for efficiently handling semi-structured
    data.

    Check out [MongoDB Atlas](https://www.mongodb.com/cloud/atlas/register) for a hosted
    MongoDB solution.

    Traces are stored flattened in the `traces` collection, using their `trace_id` as
    the document `_id`.
    """

    is_production_ready = True

    # Assigned on the class by `configure_credentials`; absent until then.
    mongo_connection_string: str
    mongo_database: str

    def __init__(self) -> None:
        """Validate the configuration and ensure the `traces` index exists.

        Raises:
            ValueError: If `configure_credentials` has not been called.
        """
        super().__init__()
        # The attributes above are only annotated, never assigned a default, so a plain
        # attribute access would raise AttributeError instead of the intended error.
        if (
            getattr(self, "mongo_connection_string", None) is None
            or getattr(self, "mongo_database", None) is None
        ):
            raise ValueError(
                "Please configure the MongoDB access options by calling "
                "MongoDbDriver.configure_credentials"
            )

        with MongoClient[Any](self.mongo_connection_string) as client:
            # Compound index on tags (ascending) and creation time (descending).
            client[self.mongo_database].traces.create_index(
                [("tags", ASCENDING), ("created", DESCENDING)], background=True
            )

    @classmethod
    def configure_credentials(  # type: ignore
        cls,
        *,
        mongo_connection_string: str,
        mongo_database: str,
        **_: Any,
    ) -> None:
        """Configure the connection details for MongoDB.

        Unknown keyword arguments are accepted and ignored.

        Args:
            mongo_connection_string: For example:
                'mongodb://my_user:my_pass@localhost:27017'
            mongo_database: Name of the database to use. If it doesn't exist, it is
                created and initialised.
        """
        cls.mongo_connection_string = mongo_connection_string
        cls.mongo_database = mongo_database
        super().configure_credentials()

    def save(self, trace: Trace) -> str:
        """Insert one trace; returns its `_id` (the trace id)."""
        serialized = trace.to_flat_dict()
        serialized["_id"] = trace.trace_id

        with MongoClient[Any](self.mongo_connection_string) as client:
            return client[self.mongo_database].traces.insert_one(serialized).inserted_id

    def save_batch(self, documents: List[Trace]) -> List[str]:
        """Insert several traces; returns their `_id`s (the trace ids)."""
        if not documents:
            # `insert_many` raises on an empty document list.
            return []

        serialized = [d.to_flat_dict() for d in documents]
        for s in serialized:
            s["_id"] = s["trace_id"]

        with MongoClient[Any](self.mongo_connection_string) as client:
            # `ordered=False` lets the server continue past individual failed inserts.
            return (
                client[self.mongo_database]
                .traces.insert_many(serialized, ordered=False)
                .inserted_ids
            )

    def get(self, id: str) -> Optional[Trace]:
        """Return the trace whose `_id` is `id`, or None if it is not stored."""
        with MongoClient[Any](self.mongo_connection_string) as client:
            value = client[self.mongo_database].traces.find_one(id)

        if value:
            value = Trace.parse_obj(value)

        return value

    def _get_operator(self, filter: Filter) -> str:
        """Map a filter's operator to its MongoDB operator.

        `contains` on a non-string value falls back to the equality operator.
        """
        if filter.operator == "contains" and not isinstance(filter.value, str):
            return operator_mapping["="]
        return operator_mapping[filter.operator]

    def query(
        self,
        *,
        skip: int = 0,
        take: Optional[int] = None,
        conjunctive_filters: Sequence[Filter] = [],
        conjunctive_tags: Sequence[str] = [],
        since: Optional[datetime] = None,
        until: Optional[datetime] = None,
        has_feedback: Optional[bool] = None,
        sort_by: Sequence[SortBy] = [],
    ) -> Tuple[List[Trace], int]:
        """Return one page of matching traces and the total number of matches.

        The count is taken before `skip`/`take` are applied. A falsy `take` (None or 0)
        means no limit.
        """
        query: Dict[str, Any] = {
            "filter": {},
        }

        and_query: List[Dict[str, Any]] = []
        and_query.extend({"tags": tag} for tag in conjunctive_tags)
        and_query.extend(
            {f.property: {self._get_operator(f): f.value}} for f in conjunctive_filters
        )
        if not and_query:
            # `$and` must not be empty; `{}` matches every document.
            and_query.append({})

        # NOTE(review): `created` is compared against datetime objects here. If traces
        # store `created` as an ISO string (as `add_ground_truth` builds it), MongoDB
        # does not match values of different BSON types. Confirm the stored type.
        if since:
            and_query.append({"created": {"$gte": since}})

        if until:
            and_query.append({"created": {"$lte": until}})

        if has_feedback is not None:
            and_query.append(
                {"feedback": {"$ne": None}} if has_feedback else {"feedback": None}
            )
        query["filter"]["$and"] = and_query

        with MongoClient[Any](self.mongo_connection_string) as client:
            count = client[self.mongo_database].traces.count_documents(**query)

            if skip:
                query["skip"] = skip

            if take:
                query["limit"] = take

            query["sort"] = [
                (col.column_id, 1 if col.direction == "asc" else -1) for col in sort_by
            ]

            with client[self.mongo_database].traces.find(**query) as cursor:
                documents = [Trace[Any].parse_obj(t) for t in cursor]
        return documents, count

    def update(self, id: str, new_version: Trace) -> None:
        """Replace the stored trace having `_id == id` with `new_version`."""
        serialized = new_version.to_flat_dict()
        serialized["_id"] = new_version.trace_id

        with MongoClient[Any](self.mongo_connection_string) as client:
            # `update_one` only accepts `$`-operator update documents and raises
            # ValueError for a plain document; whole-document replacement is
            # `replace_one`.
            client[self.mongo_database].traces.replace_one({"_id": id}, serialized)

    def delete(self, id: str) -> None:
        """Remove the trace whose `_id` is `id`."""
        with MongoClient[Any](self.mongo_connection_string) as client:
            client[self.mongo_database].traces.delete_one({"_id": id})

    def delete_batch(self, ids: List[str]) -> None:
        """Remove every trace whose `_id` is in `ids`, in chunks of 10000 ids."""
        with MongoClient[Any](self.mongo_connection_string) as client:
            for c in chunk(
                ids, chunk_size=10000
            ):  # avoid: 'delete' command document too large
                delete_filter = {"_id": {"$in": c}}
                client[self.mongo_database].traces.delete_many(delete_filter)

configure_credentials(*, mongo_connection_string, mongo_database, **_) classmethod #

Configure the connection details for MongoDB.

Parameters:

Name Type Description Default
mongo_connection_string str

For example: 'mongodb://my_user:my_pass@localhost:27017'

required
mongo_database str

Name of the database to use. If it doesn't exist, it is created and initialised.

required
Source code in great_ai/persistence/mongodb_driver.py
@classmethod
def configure_credentials(  # type: ignore
    cls,
    *,
    mongo_connection_string: str,
    mongo_database: str,
    **_: Any,
) -> None:
    """Configure the connection details for MongoDB.

    The values are stored on the class, so drivers constructed afterwards use them.
    Unknown keyword arguments are accepted and ignored.

    Args:
        mongo_connection_string: For example:
            'mongodb://my_user:my_pass@localhost:27017'
        mongo_database: Name of the database to use. If it doesn't exist, it is
            created and initialised.
    """
    # Class-level assignment: every later instance reads these values.
    cls.mongo_connection_string = mongo_connection_string
    cls.mongo_database = mongo_database
    # The base driver's implementation marks the class as initialized.
    super().configure_credentials()

ParallelTinyDbDriver #

Bases: TracingDatabaseDriver

TracingDatabaseDriver with TinyDB as a backend.

Saves the database as JSON into a single file. Highly inefficient at inserting; not advised for production use.

A multiprocessing lock protects the database file to avoid parallelisation issues.

Source code in great_ai/persistence/parallel_tinydb_driver.py
class ParallelTinyDbDriver(TracingDatabaseDriver):
    """TracingDatabaseDriver with TinyDB as a backend.

    Saves the database as JSON into a single file. Highly inefficient at inserting;
    not advised for production use.

    A multiprocessing lock protects the database file to avoid parallelisation issues.
    """

    is_production_ready = False
    path_to_db = Path(DEFAULT_TRACING_DB_FILENAME)

    def save(self, trace: Trace) -> str:
        """Insert one trace; returns TinyDB's identifier for the new document."""
        return self._safe_execute(lambda db: db.insert(trace.dict()))

    def save_batch(self, documents: List[Trace]) -> List[str]:
        """Insert several traces; returns TinyDB's identifiers for them."""
        traces = [d.dict() for d in documents]
        return self._safe_execute(lambda db: db.insert_multiple(traces))

    def get(self, id: str) -> Optional[Trace]:
        """Return the trace whose `trace_id` is `id`, or None if it is not stored."""
        value = self._safe_execute(lambda db: db.get(lambda d: d["trace_id"] == id))
        if value:
            value = Trace.parse_obj(value)
        return value

    def query(
        self,
        *,
        skip: int = 0,
        take: Optional[int] = None,
        conjunctive_filters: Sequence[Filter] = [],
        conjunctive_tags: Sequence[str] = [],
        since: Optional[datetime] = None,
        until: Optional[datetime] = None,
        has_feedback: Optional[bool] = None,
        sort_by: Sequence[SortBy] = [],
    ) -> Tuple[List[Trace], int]:
        """Return one page of matching traces and the total number of matches.

        Tag, time-window and feedback criteria are evaluated by TinyDB. Property
        filters and sorting are then applied to a pandas DataFrame of the flattened
        traces. The count is taken before `skip`/`take` are applied.
        """

        def does_match(d: Dict[str, Any]) -> bool:
            return (
                not set(conjunctive_tags) - set(d["tags"])
                and (since is None or datetime.fromisoformat(d["created"]) >= since)
                and (until is None or datetime.fromisoformat(d["created"]) <= until)
                and (
                    has_feedback is None or has_feedback == (d["feedback"] is not None)
                )
            )

        documents = self._safe_execute(lambda db: db.search(does_match))
        if not documents:
            return [], 0

        df = pd.DataFrame([Trace.parse_obj(d).to_flat_dict() for d in documents])

        for f in conjunctive_filters:
            operator = f.operator.lower()
            if operator in operator_mapping:
                # Index with the same normalised key that was checked for membership;
                # using the raw `f.operator` raised KeyError for mixed-case operators.
                df = df.loc[
                    getattr(df[f.property], operator_mapping[operator])(f.value)
                ]
            elif operator == "contains":
                df = df.loc[
                    df[f.property].str.contains(
                        str(int(f.value)) if isinstance(f.value, float) else f.value,
                        case=False,
                        # Missing/non-string cells would otherwise yield NaN, which
                        # cannot be used as a boolean mask; treat them as non-matching.
                        na=False,
                    )
                ]

        if sort_by:
            df.sort_values(
                [col.column_id for col in sort_by],
                ascending=[col.direction == "asc" for col in sort_by],
                inplace=True,
            )

        count = len(df)
        result = df.iloc[skip:] if take is None else df.iloc[skip : skip + take]
        return [Trace.parse_obj(trace) for _, trace in result.iterrows()], count

    def update(self, id: str, new_version: Trace) -> None:
        """Overwrite the fields of the trace whose `trace_id` is `id`."""
        self._safe_execute(
            lambda db: db.update(new_version.dict(), lambda d: d["trace_id"] == id)
        )

    def delete(self, id: str) -> None:
        """Remove the trace whose `trace_id` is `id`."""
        self._safe_execute(lambda db: db.remove(lambda d: d["trace_id"] == id))

    def delete_batch(self, ids: List[str]) -> None:
        """Remove every trace whose `trace_id` is in `ids`."""
        if not ids:
            return

        # A single `remove` pass with O(1) set membership, instead of one full-table
        # scan (and one rewrite of the JSON file) per id.
        ids_to_delete = set(ids)
        self._safe_execute(
            lambda db: db.remove(lambda d: d["trace_id"] in ids_to_delete)
        )

    def _safe_execute(self, func: Callable[[TinyDB], Any]) -> Any:
        """Run `func` on a freshly opened database while holding the shared lock."""
        with lock:
            with TinyDB(self.path_to_db) as db:
                return func(db)

Last update: July 11, 2022