
Prediction

airt.client.Prediction (ProgressStatus)

A class to run predictions on the data source.

The Prediction class is automatically instantiated by calling the Model.predict method of a Model instance. Currently, this is the only way to instantiate the class. The returned object has utility methods, such as converting the prediction results into a pandas DataFrame and pushing the prediction results into one of the supported data sources.

For more information on the supported data sources, please refer to the documentation of the DataSource class.

__init__(self, prediction_id, datasource_id=None) special

Constructs a new Prediction instance

Warning

Do not construct this object directly by calling the constructor; instead, use the Model.predict method of a Model instance.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| prediction_id | int | ID of the prediction in the airt service | required |
| datasource_id | Optional[int] | The data ID used for running the prediction | None |
Source code in airt/client.py
def __init__(self, prediction_id: int, datasource_id: Optional[int] = None):
    """Constructs a new `Prediction` instance

    Warning:
        Do not construct this object directly by calling the constructor; instead, use the
        `Model.predict` method of a `Model` instance.

    Args:
        prediction_id: ID of the prediction in the airt service
        datasource_id: The data ID used for running the prediction
    """
    self.prediction_id = prediction_id
    self.datasource_id = datasource_id
    ProgressStatus.__init__(self, relative_url=f"/prediction/{self.prediction_id}")
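Because Prediction subclasses ProgressStatus and passes `/prediction/{prediction_id}` as `relative_url`, every prediction object is also its own progress handle. The sketch below reproduces only that wiring with two hypothetical stand-in classes (`ProgressStatusSketch`, `PredictionSketch`); it is not the airt implementation and makes no network calls.

```python
from typing import Optional


class ProgressStatusSketch:
    """Hypothetical stand-in for ProgressStatus: stores the URL to poll."""

    def __init__(self, relative_url: str):
        self.relative_url = relative_url


class PredictionSketch(ProgressStatusSketch):
    """Hypothetical stand-in for Prediction, wired the same way as the source above."""

    def __init__(self, prediction_id: int, datasource_id: Optional[int] = None):
        self.prediction_id = prediction_id
        self.datasource_id = datasource_id
        # The prediction's own resource path doubles as its progress URL
        ProgressStatusSketch.__init__(self, relative_url=f"/prediction/{prediction_id}")


p = PredictionSketch(prediction_id=42)
print(p.relative_url)  # /prediction/42
```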

delete(id) staticmethod

Delete a prediction from airt service

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| id | int | The prediction id in airt service. | required |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A pandas DataFrame encapsulating the details of the deleted prediction id. |

Exceptions:

| Type | Description |
|------|-------------|
| ConnectionError | If the server address is invalid or not reachable. |

An example to delete a prediction (id=1) from airt service:

Prediction.delete(id=1)
Source code in airt/client.py
@staticmethod
def delete(id: int) -> pd.DataFrame:
    """Delete a prediction from airt service

    Args:
        id: The prediction id in airt service.

    Returns:
        A pandas DataFrame encapsulating the details of the deleted prediction id.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to delete a prediction (id=1) from airt service:

    ```python
    Prediction.delete(id=1)
    ```
    """
    predictions = Client.delete_data(relative_url=f"/prediction/{id}")

    columns = Prediction._get_columns()

    predictions_df = pd.DataFrame(predictions, index=[0])[columns]

    return add_ready_column(predictions_df)

details(id) staticmethod

Return details of a prediction

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| id | int | The id of the prediction in the airt service. | required |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A pandas dataframe with the details of the prediction. |

Exceptions:

| Type | Description |
|------|-------------|
| ConnectionError | If the server address is invalid or not reachable. |

An example to get details of a prediction (id=1) from airt service:

Prediction.details(id=1)
Source code in airt/client.py
@staticmethod
def details(id: int) -> pd.DataFrame:
    """Return details of a prediction

    Args:
        id: The id of the prediction in the airt service.

    Returns:
        A pandas dataframe with the details of the prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to get details of a prediction (id=1) from airt service:

    ```python
    Prediction.details(id=1)
    ```
    """
    details = Client.get_data(relative_url=f"/prediction/{id}")

    additional_cols = ["model_id", "datasource_id", "error"]

    columns = Prediction._get_columns() + additional_cols

    details_df = pd.DataFrame(details, index=[0])[columns]

    return add_ready_column(details_df)
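Both `delete` and `details` turn a single JSON record into a one-row DataFrame and then keep a fixed list of columns, with `details` appending `model_id`, `datasource_id` and `error`. The pure-Python sketch below mirrors that selection step. The base column names are hypothetical placeholders, since `Prediction._get_columns()` is not shown on this page, and `add_ready_column` is left out because its behaviour is not documented here.

```python
from typing import Any, Dict, List

# Hypothetical placeholders: the real list comes from Prediction._get_columns()
BASE_COLUMNS = ["id", "total_steps", "completed_steps"]
# The extra columns that `details` appends, as in the source above
DETAIL_COLUMNS = BASE_COLUMNS + ["model_id", "datasource_id", "error"]


def select_columns(record: Dict[str, Any], columns: List[str]) -> Dict[str, Any]:
    """Keep only `columns` from one response record, in the given order.

    Like `pd.DataFrame(record, index=[0])[columns]`, a missing column raises KeyError.
    """
    return {name: record[name] for name in columns}


# Hypothetical server response with one field that is not selected
record = {
    "id": 1,
    "total_steps": 3,
    "completed_steps": 3,
    "model_id": 5,
    "datasource_id": 2,
    "error": None,
    "internal_field": "dropped",
}
row = select_columns(record, DETAIL_COLUMNS)
print(list(row))  # ['id', 'total_steps', 'completed_steps', 'model_id', 'datasource_id', 'error']
```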

is_ready(self) inherited

A function to check if the method's progress is completed.

Returns:

| Type | Description |
|------|-------------|
| bool | True if the progress is completed, else False. |

Source code in airt/client.py
def is_ready(self) -> bool:
    """A function to check if the method's progress is completed.

    Returns:
        True if the progress is completed, else False.
    """
    response = Client.get_data(relative_url=self.relative_url)
    return response["completed_steps"] == response["total_steps"]
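`is_ready` makes one request and compares two counters from the response. The sketch below isolates that comparison by passing the request function in as a parameter; `fake_server` is a hypothetical stand-in for `Client.get_data` that serves canned responses, so the snippet runs without an airt account.

```python
from typing import Callable, Dict


def is_ready(get_data: Callable[[str], Dict[str, int]], relative_url: str) -> bool:
    """True when the remote operation reports every step as completed."""
    response = get_data(relative_url)
    return response["completed_steps"] == response["total_steps"]


# Hypothetical canned responses standing in for Client.get_data
fake_server = {
    "/prediction/1": {"completed_steps": 2, "total_steps": 5},
    "/prediction/2": {"completed_steps": 5, "total_steps": 5},
}

print(is_ready(fake_server.__getitem__, "/prediction/1"))  # False
print(is_ready(fake_server.__getitem__, "/prediction/2"))  # True
```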

ls(offset=0, limit=100, disabled=False, completed=False) staticmethod

Display the list of available predictions in airt service.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| offset | int | The number of rows to offset at the beginning of the predictions list from the server. If None, then the default value 0 will be used. | 0 |
| limit | int | The maximum number of rows to return from the server. If None, then the default value 100 will be used. | 100 |
| disabled | bool | If set to True, then only the deleted predictions will be displayed. Else, the default value False will be used to display only the list of active predictions. | False |
| completed | bool | If set to True, then only the predictions that are successfully completed in airt server will be displayed. Else, the default value False will be used to display all the predictions including the ones that are yet to finish the prediction. | False |

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A pandas dataframe with the list of available predictions. |

Exceptions:

| Type | Description |
|------|-------------|
| ConnectionError | If the server address is invalid or not reachable. |

An example to list the available predictions:

Prediction.ls()
Source code in airt/client.py
@staticmethod
def ls(
    offset: int = 0,
    limit: int = 100,
    disabled: bool = False,
    completed: bool = False,
) -> pd.DataFrame:
    """Display the list of available predictions in airt service.

    Args:
        offset: The number of rows to offset at the beginning of the predictions
            list from the server. If **None**, then the default value **0** will be used.
        limit: The maximum number of rows to return from the server. If **None**,
            then the default value **100** will be used.
        disabled: If set to **True**, then only the deleted predictions will be displayed.
            Else, the default value **False** will be used to display only the list
            of active predictions.
        completed: If set to **True**, then only the predictions that are successfully completed
            in airt server will be displayed. Else, the default value **False** will be used
            to display all the predictions including the ones that are yet to finish the prediction.

    Returns:
        A pandas dataframe with the list of available predictions.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to list the available predictions:

    ```python
    Prediction.ls()
    ```
    """
    predictions = Client.get_data(
        relative_url=f"/prediction/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
    )

    columns = Prediction._get_columns()

    predictions_df = generate_df(predictions, columns)

    return add_ready_column(predictions_df)
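The `offset` and `limit` parameters support ordinary page-by-page iteration. The sketch below rebuilds the same query string with `urllib.parse.urlencode` and walks all pages until a short page comes back. `fetch` is a hypothetical stand-in for `Client.get_data`, and the sketch assumes the endpoint returns a list of records, which the use of `generate_df` suggests but this page does not state.

```python
from typing import Any, Callable, Dict, Iterator, List
from urllib.parse import parse_qs, urlencode, urlparse


def ls_url(offset: int = 0, limit: int = 100, disabled: bool = False, completed: bool = False) -> str:
    # urlencode calls str() on each value, so booleans become "True"/"False",
    # matching the f-string used by Prediction.ls
    query = urlencode(
        {"disabled": disabled, "completed": completed, "offset": offset, "limit": limit}
    )
    return f"/prediction/?{query}"


def iter_all(fetch: Callable[[str], List[Dict[str, Any]]], limit: int = 100) -> Iterator[Dict[str, Any]]:
    """Yield every record, advancing `offset` by `limit` until a short page is returned."""
    offset = 0
    while True:
        page = fetch(ls_url(offset=offset, limit=limit))
        yield from page
        if len(page) < limit:  # fewer rows than requested: this was the last page
            return
        offset += limit


# Hypothetical in-memory server holding seven records
records = [{"id": i} for i in range(7)]


def fake_fetch(url: str) -> List[Dict[str, Any]]:
    query = parse_qs(urlparse(url).query)
    offset, limit = int(query["offset"][0]), int(query["limit"][0])
    return records[offset : offset + limit]


print(ls_url())  # /prediction/?disabled=False&completed=False&offset=0&limit=100
print([r["id"] for r in iter_all(fake_fetch, limit=3)])  # [0, 1, 2, 3, 4, 5, 6]
```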

progress_bar(self, sleep_for=5, timeout=0) inherited

Blocks execution while waiting for remote action to be completed and displays a progress bar indicating the completion status.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sleep_for | Union[int, float] | The time interval in seconds between successive API calls to ping the server for fetching the completed steps. | 5 |
| timeout | int | The maximum time allowed in seconds for the asynchronous call to complete the process. If the timeout is exceeded before the process completes, a TimeoutError is raised. A value of 0 disables the timeout. | 0 |

Exceptions:

| Type | Description |
|------|-------------|
| TimeoutError | in case of timeout |

Source code in airt/client.py
def progress_bar(self, sleep_for: Union[int, float] = 5, timeout: int = 0):
    """Blocks execution while waiting for remote action to be completed and displays a progress bar indicating the completion status.

    Args:
        sleep_for: The time interval in seconds between successive API calls to ping the server for fetching the completed steps.
        timeout: The maximum time allowed in seconds for the asynchronous call to complete the process. If the timeout
            is exceeded before the process completes, a `TimeoutError` is raised. A value of 0 disables the timeout.

    Raises:
        TimeoutError: in case of timeout

    """
    total_steps = Client.get_data(relative_url=self.relative_url)["total_steps"]
    with tqdm(total=total_steps) as pbar:
        started_at = datetime.now()
        while True:
            if (0 < timeout) and (datetime.now() - started_at) > timedelta(
                seconds=timeout
            ):
                raise TimeoutError()

            response = Client.get_data(relative_url=self.relative_url)
            completed_steps = response["completed_steps"]

            # tqdm's update() adds its argument to the counter, so advance only by the new steps
            pbar.update(completed_steps - pbar.n)

            if completed_steps == total_steps:
                break

            sleep(sleep_for)
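`tqdm`'s `update(n)` increments the bar by `n`; it does not set an absolute position. Passing the cumulative `completed_steps` on every poll therefore re-adds the steps that were already counted. The sketch below shows the difference using `CounterBar`, a hypothetical minimal stand-in exposing the same `n` counter and `update` method as a `tqdm` bar, so it runs without tqdm installed.

```python
from typing import Iterable


class CounterBar:
    """Hypothetical stand-in for a tqdm bar: `n` is the current count, update() adds to it."""

    def __init__(self, total: int):
        self.total = total
        self.n = 0

    def update(self, n: int = 1) -> None:
        self.n += n


def track_naive(polls: Iterable[int], total: int) -> int:
    bar = CounterBar(total)
    for completed in polls:
        bar.update(completed)  # re-adds steps that were already counted
    return bar.n


def track_delta(polls: Iterable[int], total: int) -> int:
    bar = CounterBar(total)
    for completed in polls:
        bar.update(completed - bar.n)  # advance only by newly completed steps
    return bar.n


# Cumulative completed_steps values from four successive polls
polls = [1, 1, 3, 5]
print(track_naive(polls, total=5))  # 10
print(track_delta(polls, total=5))  # 5
```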

push(self, data_source)

A function to push the prediction results into the target data source.

For more information on the supported data sources, please refer to the documentation on DataSource class

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| data_source | airt.components.datasource.DataSource | An instance of the DataSource class that encapsulates the data. | required |

Returns:

| Type | Description |
|------|-------------|
| ProgressStatus | An instance of the ProgressStatus class. ProgressStatus is a base class for querying the status of a remote operation. For more information, please refer to the ProgressStatus class documentation. |

Exceptions:

| Type | Description |
|------|-------------|
| ValueError | If the input parameters to the API are invalid. |
| ConnectionError | If the server address is invalid or not reachable. |

The example below illustrates pushing the prediction results to an S3 bucket:

from datetime import timedelta

Client.get_token()

data_source_s3 = DataSource.s3(
    uri="s3://test-airt-service/ecommerce_behavior"
)
data_source_s3.pull().progress_bar()

model = data_source_s3.train(
    client_column="user_id",
    target_column="event_type",
    target="*purchase",
    predict_after=timedelta(hours=3),
)

predictions = model.predict()
predictions.progress_bar()

data_source_pred = DataSource.s3(
    uri="s3://target-bucket"
)

progress = predictions.push(data_source_pred)
progress.progress_bar()
Source code in airt/client.py
@patch
def push(self: Prediction, data_source: "airt.components.datasource.DataSource") -> ProgressStatus:  # type: ignore
    """A function to push the prediction results into the target data source.

    For more information on the supported data sources, please refer to the documentation on `DataSource` class

    Args:
        data_source: An instance of the `DataSource` class that encapsulates the data.

    Returns:
        An instance of `ProgressStatus` class. `ProgressStatus` is a base class for querying status of a remote operation. For more information
        please refer to `ProgressStatus` class documentation.

    Raises:
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.


    The example below illustrates pushing the prediction results to an S3 bucket:

    ```python
    from datetime import timedelta

    Client.get_token()

    data_source_s3 = DataSource.s3(
        uri="s3://test-airt-service/ecommerce_behavior"
    )
    data_source_s3.pull().progress_bar()

    model = data_source_s3.train(
        client_column="user_id",
        target_column="event_type",
        target="*purchase",
        predict_after=timedelta(hours=3),
    )

    predictions = model.predict()
    predictions.progress_bar()

    data_source_pred = DataSource.s3(
        uri="s3://target-bucket"
    )

    progress = predictions.push(data_source_pred)
    progress.progress_bar()
    ```
    """
    response = Client.post_data(
        relative_url=f"/prediction/{self.prediction_id}/push",
        data=dict(data_id=data_source.id),
    )
    return ProgressStatus(relative_url=f"/prediction/push/{int(response['id'])}")
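Pushing starts a second remote job: the POST goes to `/prediction/{prediction_id}/push` with the target's `data_id`, and the returned `id` identifies the push job, whose progress is polled under `/prediction/push/{id}` rather than under the prediction's own URL. The sketch below traces that flow with `fake_post`, a hypothetical stand-in for `Client.post_data` that records the request and returns a canned job id.

```python
from typing import Any, Callable, Dict


def push_prediction(
    post_data: Callable[[str, Dict[str, Any]], Dict[str, Any]],
    prediction_id: int,
    data_id: int,
) -> str:
    """Start a push and return the relative URL for polling the push job."""
    response = post_data(f"/prediction/{prediction_id}/push", {"data_id": data_id})
    # The push job has its own id, distinct from the prediction id
    return f"/prediction/push/{int(response['id'])}"


# Hypothetical server stub: remembers each request and answers with job id 17
requests = []


def fake_post(url: str, data: Dict[str, Any]) -> Dict[str, Any]:
    requests.append((url, data))
    return {"id": 17}


status_url = push_prediction(fake_post, prediction_id=3, data_id=8)
print(status_url)  # /prediction/push/17
print(requests)  # [('/prediction/3/push', {'data_id': 8})]
```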

to_pandas(self)

A function to convert the predicted results into a Pandas DataFrame object.

Returns:

| Type | Description |
|------|-------------|
| DataFrame | A Pandas DataFrame that contains the prediction results from the model. |

Exceptions:

| Type | Description |
|------|-------------|
| ValueError | If the input parameters to the API are invalid. |
| ConnectionError | If the server address is invalid or not reachable. |

The example below illustrates the usage of the to_pandas function:

from datetime import timedelta

Client.get_token()
data_source_s3 = DataSource.s3(
    uri="s3://test-airt-service/ecommerce_behavior"
)
data_source_s3.pull().progress_bar()
model = data_source_s3.train(
    client_column="user_id",
    target_column="event_type",
    target="*purchase",
    predict_after=timedelta(hours=3),
)

predictions = model.predict()
predictions.progress_bar()
predictions.to_pandas()
Source code in airt/client.py
@patch
def to_pandas(self: Prediction) -> pd.DataFrame:
    """A function to convert the predicted results into a Pandas DataFrame object.

    Returns:
        A Pandas DataFrame that contains the prediction results from the model.

    Raises:
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    The example below illustrates the usage of the `to_pandas` function:

    ```python
    from datetime import timedelta

    Client.get_token()
    data_source_s3 = DataSource.s3(
        uri="s3://test-airt-service/ecommerce_behavior"
    )
    data_source_s3.pull().progress_bar()
    model = data_source_s3.train(
        client_column="user_id",
        target_column="event_type",
        target="*purchase",
        predict_after=timedelta(hours=3),
    )

    predictions = model.predict()
    predictions.progress_bar()
    predictions.to_pandas()
    ```
    """
    response = Client.get_data(relative_url=f"/prediction/{self.prediction_id}/pandas")
    keys = list(response.keys())
    keys.remove("Score")
    index_name = keys[0]
    return (
        pd.DataFrame(response)
        .set_index(index_name)
        .sort_values("Score", ascending=False)
    )
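`to_pandas` treats the first key other than `Score` as the index and sorts by `Score` in descending order, so the highest-scoring clients come first. The sketch below mirrors those steps in plain Python so the ordering rule is easy to check. The response shape (one identifier column plus a `Score` column, each a list) is an assumption inferred from the code, and `user_id` is simply the client column used in the examples above.

```python
from typing import Any, Dict, List, Tuple


def rank_predictions(response: Dict[str, List[Any]]) -> List[Tuple[Any, float]]:
    """Pair each identifier with its Score, highest Score first."""
    keys = [key for key in response if key != "Score"]
    index_name = keys[0]  # depends on key order, as in the source
    rows = list(zip(response[index_name], response["Score"]))
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Hypothetical column-oriented response
response = {"user_id": [101, 102, 103], "Score": [0.2, 0.9, 0.5]}
print(rank_predictions(response))  # [(102, 0.9), (103, 0.5), (101, 0.2)]
```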

wait(self, sleep_for=1, timeout=0) inherited

Blocks execution while waiting for remote action to be completed.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| sleep_for | Union[int, float] | The time interval in seconds between successive API calls to ping the server for fetching the completed steps. | 1 |
| timeout | int | The maximum time allowed in seconds for the asynchronous call to complete the process. If the timeout is exceeded before the process completes, a TimeoutError is raised. A value of 0 disables the timeout. | 0 |

Exceptions:

| Type | Description |
|------|-------------|
| TimeoutError | in case of timeout |

Source code in airt/client.py
@patch
def wait(self: ProgressStatus, sleep_for: Union[int, float] = 1, timeout: int = 0):
    """Blocks execution while waiting for remote action to be completed.

    Args:
        sleep_for: The time interval in seconds between successive API calls to ping the server for fetching the completed steps.
        timeout: The maximum time allowed in seconds for the asynchronous call to complete the process. If the timeout
            is exceeded before the process completes, a `TimeoutError` is raised. A value of 0 disables the timeout.

    Raises:
        TimeoutError: in case of timeout
    """
    started_at = datetime.now()
    while True:
        if (0 < timeout) and (datetime.now() - started_at) > timedelta(seconds=timeout):
            raise TimeoutError()

        if self.is_ready():
            return

        sleep(sleep_for)
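`wait` is a plain polling loop: check the deadline, ask `is_ready()`, sleep, repeat, with `timeout=0` meaning no deadline. The sketch below reproduces that loop as a standalone function that takes the readiness check as a callable. It deliberately swaps `datetime.now()` for `time.monotonic()`, which is not affected by system clock changes; that substitution is a design choice of this sketch, not something the airt source does.

```python
import time
from typing import Callable, Union


def wait_until(
    is_ready: Callable[[], bool],
    sleep_for: Union[int, float] = 1,
    timeout: Union[int, float] = 0,
) -> None:
    """Poll `is_ready` until it returns True; a timeout of 0 waits indefinitely."""
    started_at = time.monotonic()
    while True:
        # Same order as the source: deadline check first, then readiness
        if 0 < timeout and time.monotonic() - started_at > timeout:
            raise TimeoutError(f"not ready after {timeout} seconds")
        if is_ready():
            return
        time.sleep(sleep_for)


# Usage: a hypothetical operation that becomes ready on the third poll
state = {"calls": 0}


def ready() -> bool:
    state["calls"] += 1
    return state["calls"] >= 3


wait_until(ready, sleep_for=0.01)
print(state["calls"])  # 3
```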