Prediction
airt.client.Prediction (ProgressStatus)
A class to manage and download the predictions.
The Prediction
class is automatically instantiated by calling the Model.predict
method of a Model
instance.
Currently, it is the only way to instantiate this class.
__init__(self, id, datasource_id=None, model_id=None, created=None, total_steps=None, completed_steps=None, error=None, disabled=None)
special
Constructs a new Prediction
instance
Warning
Do not construct this object directly by calling the constructor, instead please use
Model.predict
method of the Model instance.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id |
int |
Prediction ID in the server. |
required |
datasource_id |
Optional[int] |
DataSource ID in the server. |
None |
model_id |
Optional[int] |
Model ID in the server. |
None |
created |
Optional[pandas._libs.tslibs.timestamps.Timestamp] |
Prediction creation date. |
None |
total_steps |
Optional[int] |
Number of steps needed to complete the model prediction. |
None |
completed_steps |
Optional[int] |
Number of steps completed so far in the model prediction. |
None |
error |
Optional[str] |
Error message while making the prediction. |
None |
disabled |
Optional[bool] |
Flag to indicate the status of the prediction. |
None |
Source code in airt/client.py
def __init__(
    self,
    id: int,
    datasource_id: Optional[int] = None,
    model_id: Optional[int] = None,
    created: Optional[pd.Timestamp] = None,
    total_steps: Optional[int] = None,
    completed_steps: Optional[int] = None,
    error: Optional[str] = None,
    disabled: Optional[bool] = None,
):
    """Constructs a new `Prediction` instance.

    Warning:
        Do not construct this object directly by calling the constructor, instead please use
        `Model.predict` method of the Model instance.

    Args:
        id: Prediction ID in the server.
        datasource_id: DataSource ID in the server.
        model_id: Model ID in the server.
        created: Prediction creation date.
        total_steps: Number of steps needed to complete the model prediction.
        completed_steps: Number of steps completed so far in the model prediction.
        error: Error message while making the prediction.
        disabled: Flag to indicate the status of the prediction.
    """
    # Identifiers of the prediction and the objects it was produced from.
    self.id, self.datasource_id, self.model_id = id, datasource_id, model_id
    # Remaining metadata and progress counters, stored as given.
    self.created = created
    self.total_steps, self.completed_steps = total_steps, completed_steps
    self.error, self.disabled = error, disabled
    # Progress for this prediction is polled from its own endpoint.
    status_url = f"/prediction/{self.id}"
    ProgressStatus.__init__(self, relative_url=status_url)
as_df(predx)
staticmethod
Return the details of prediction instances as a pandas dataframe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
predx |
List[Prediction] |
List of prediction instances. |
required |
Returns:
Type | Description |
---|---|
DataFrame |
Details of all the predictions in a dataframe. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to get the details of the available predictions:
predictionx = Prediction.ls()
Prediction.as_df(predictionx)
Source code in airt/client.py
@staticmethod
def as_df(predx: List["Prediction"]) -> pd.DataFrame:
    """Return the details of prediction instances as a pandas dataframe.

    Args:
        predx: List of prediction instances.

    Returns:
        Details of all the predictions in a dataframe.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to get the details of the available predictions:

    ```python
    predictionx = Prediction.ls()
    Prediction.as_df(predictionx)
    ```
    """
    # The same column selection drives both attribute extraction and frame layout.
    columns = Prediction.BASIC_PRED_COLS
    attribute_rows = get_attributes_from_instances(predx, columns)  # type: ignore
    frame = generate_df(attribute_rows, columns)
    return add_ready_column(frame)
delete(self)
Delete a prediction from the server.
Returns:
Type | Description |
---|---|
DataFrame |
A pandas DataFrame encapsulating the details of the deleted prediction. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to delete a prediction:
prediction.delete()
Source code in airt/client.py
@patch
def delete(self: Prediction) -> pd.DataFrame:
    """Delete a prediction from the server.

    Returns:
        A pandas DataFrame encapsulating the details of the deleted prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to delete a prediction:

    ```python
    prediction.delete()
    ```
    """
    endpoint = f"/prediction/{self.id}"
    deleted_record = Client._delete_data(relative_url=endpoint)
    # The server returns a single record; wrap it in a one-row frame.
    single_row = pd.DataFrame(deleted_record, index=[0])
    return add_ready_column(single_row[Prediction.BASIC_PRED_COLS])
details(self)
Return the details of a prediction.
Returns:
Type | Description |
---|---|
DataFrame |
A pandas DataFrame encapsulating the details of the prediction. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to get details of a prediction:
prediction.details()
Source code in airt/client.py
@patch
def details(self: Prediction) -> pd.DataFrame:
    """Return the details of a prediction.

    Returns:
        A pandas DataFrame encapsulating the details of the prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to get details of a prediction:

    ```python
    prediction.details()
    ```
    """
    endpoint = f"/prediction/{self.id}"
    record = Client._get_data(relative_url=endpoint)
    # Single server record -> one-row frame restricted to the full column set.
    single_row = pd.DataFrame(record, index=[0])
    return add_ready_column(single_row[Prediction.ALL_PRED_COLS])
is_ready(self)
inherited
Check if the method's progress is complete.
Returns:
Type | Description |
---|---|
bool |
True if the progress is completed, else False. |
Source code in airt/client.py
def is_ready(self) -> bool:
    """Check if the method's progress is complete.

    Returns:
        **True** if the progress is completed, else **False**.
    """
    status = Client._get_data(relative_url=self.relative_url)
    done, needed = status["completed_steps"], status["total_steps"]
    return done == needed
ls(offset=0, limit=100, disabled=False, completed=False)
staticmethod
Return the list of Prediction instances available in the server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
offset |
int |
The number of predictions to offset at the beginning. If None, then the default value 0 will be used. |
0 |
limit |
int |
The maximum number of predictions to return from the server. If None, then the default value 100 will be used. |
100 |
disabled |
bool |
If set to True, then only the deleted predictions will be returned. Else, the default value False will be used to return only the list of active predictions. |
False |
completed |
bool |
If set to True, then only the predictions that are successfully processed in server will be returned. Else, the default value False will be used to return all the predictions. |
False |
Returns:
Type | Description |
---|---|
List[Prediction] |
A list of Prediction instances available in the server. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to list the available predictions:
Prediction.ls()
Source code in airt/client.py
@staticmethod
def ls(
    offset: int = 0,
    limit: int = 100,
    disabled: bool = False,
    completed: bool = False,
) -> List["Prediction"]:
    """Return the list of Prediction instances available in the server.

    Args:
        offset: The number of predictions to offset at the beginning. If None, then the default value **0** will be used.
        limit: The maximum number of predictions to return from the server. If None,
            then the default value **100** will be used.
        disabled: If set to **True**, then only the deleted predictions will be returned. Else, the default value
            **False** will be used to return only the list of active predictions.
        completed: If set to **True**, then only the predictions that are successfully processed in server will be returned.
            Else, the default value **False** will be used to return all the predictions.

    Returns:
        A list of Prediction instances available in the server.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to list the available predictions:

    ```python
    Prediction.ls()
    ```
    """
    query = f"disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
    records = Client._get_data(relative_url=f"/prediction/?{query}")
    # Each server record carries exactly these keys, which map 1:1 onto
    # the Prediction constructor's keyword arguments.
    field_names = (
        "id",
        "model_id",
        "datasource_id",
        "created",
        "total_steps",
        "completed_steps",
        "error",
        "disabled",
    )
    return [
        Prediction(**{name: record[name] for name in field_names})
        for record in records
    ]
progress_bar(self)
inherited
Blocks the execution and displays a progress bar showing the remote action progress.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
TimeoutError |
in case of connection timeout. |
Source code in airt/client.py
def progress_bar(self):
    """Blocks the execution and displays a progress bar showing the remote action progress.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
        TimeoutError: in case of connection timeout.
    """
    total_steps = Client._get_data(relative_url=self.relative_url)["total_steps"]
    with tqdm(total=total_steps) as pbar:
        started_at = datetime.now()
        while True:
            # A non-positive timeout disables the deadline.
            if (0 < self.timeout) and (datetime.now() - started_at) > timedelta(
                seconds=self.timeout
            ):
                raise TimeoutError()
            response = Client._get_data(relative_url=self.relative_url)
            completed_steps = response["completed_steps"]
            # `tqdm.update(n)` *adds* n to the counter, whereas the server reports a
            # cumulative count; advance only by the delta since the last poll so the
            # bar does not overshoot `total_steps`.
            pbar.update(completed_steps - pbar.n)
            if completed_steps == total_steps:
                break
            sleep(self.sleep_for)
to_clickhouse(self, *, host, database, table, protocol, port=0, username=None, password=None)
Push the prediction results to a clickhouse database.
If the database requires authentication, pass the username/password as parameters or store it in the CLICKHOUSE_USERNAME and CLICKHOUSE_PASSWORD environment variables.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
host |
str |
Remote database host name. |
required |
database |
str |
Database name. |
required |
table |
str |
Table name. |
required |
protocol |
str |
Protocol to use (native/http). |
required |
port |
int |
Host port number. If not passed, then the default value 0 will be used. |
0 |
username |
Optional[str] |
Database username. If not passed, then the value set in the environment variable CLICKHOUSE_USERNAME will be used else the default value "root" will be used. |
None |
password |
Optional[str] |
Database password. If not passed, then the value set in the environment variable CLICKHOUSE_PASSWORD will be used else the default value "" will be used. |
None |
Returns:
Type | Description |
---|---|
ProgressStatus |
An instance of ProgressStatus class. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to push the prediction results to a clickhouse database:
status = prediction.to_clickhouse(
host="host_name",
database="database_name",
table="table_name",
protocol="native",
)
status.progress_bar()
Source code in airt/client.py
@patch
def to_clickhouse(
    self: Prediction,
    *,
    host: str,
    database: str,
    table: str,
    protocol: str,
    port: int = 0,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> ProgressStatus:
    """Push the prediction results to a clickhouse database.

    If the database requires authentication, pass the username/password as parameters or store it in
    the **CLICKHOUSE_USERNAME** and **CLICKHOUSE_PASSWORD** environment variables.

    Args:
        host: Remote database host name.
        database: Database name.
        table: Table name.
        protocol: Protocol to use (native/http).
        port: Host port number. If not passed, then the default value **0** will be used.
        username: Database username. If not passed, then the value set in the environment variable
            **CLICKHOUSE_USERNAME** will be used else the default value "root" will be used.
        password: Database password. If not passed, then the value set in the environment variable
            **CLICKHOUSE_PASSWORD** will be used else the default value "" will be used.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to push the prediction results to a clickhouse database:

    ```python
    status = prediction.to_clickhouse(
        host="host_name",
        database="database_name",
        table="table_name",
        protocol="native",
    )
    status.progress_bar()
    ```
    """
    # Explicit arguments win; otherwise fall back to the environment, then to defaults.
    if username is None:
        username = os.environ.get("CLICKHOUSE_USERNAME", "root")
    if password is None:
        password = os.environ.get("CLICKHOUSE_PASSWORD", "")
    payload = {
        "host": host,
        "database": database,
        "table": table,
        "protocol": protocol,
        "port": port,
        "username": username,
        "password": password,
    }
    push_job = Client._post_data(
        relative_url=f"/prediction/{self.id}/to_clickhouse", data=payload
    )
    # The push runs asynchronously on the server; track it via its own id.
    return ProgressStatus(relative_url=f"/prediction/push/{push_job['id']}")
to_local(self, path, show_progress=True)
Download the prediction results to a local directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path |
Union[str, pathlib.Path] |
Local directory path. |
required |
show_progress |
Optional[bool] |
Flag to set the progressbar visibility. If not passed, then the default value True will be used. |
True |
Exceptions:
Type | Description |
---|---|
FileNotFoundError |
If the path is invalid. |
HTTPError |
If the presigned AWS s3 uri to download the prediction results is invalid or not reachable. |
An example to download the predictions to local:
prediction.to_local(
path=Path('path-to-local-directory')
)
Source code in airt/client.py
@patch
def to_local(
    self: Prediction,
    path: Union[str, Path],
    show_progress: Optional[bool] = True,
) -> None:
    """Download the prediction results to a local directory.

    Args:
        path: Local directory path.
        show_progress: Flag to set the progressbar visibility. If not passed, then the default value **True** will be used.

    Raises:
        FileNotFoundError: If the **path** is invalid.
        HTTPError: If the presigned AWS s3 uri to download the prediction results is invalid or not reachable.

    An example to download the predictions to local:

    ```python
    prediction.to_local(
        path=Path('path-to-local-directory')
    )
    ```
    """
    # Mapping of result file name -> download URL.
    response = Client._get_data(relative_url=f"/prediction/{self.id}/to_local")
    # Loop-invariant: convert the destination once rather than per file.
    target_dir = Path(path)
    # Context manager guarantees the bar is closed even if a download raises.
    with tqdm(total=len(response), disable=not show_progress) as t:
        for file_name, url in response.items():
            Prediction._download_prediction_file_to_local(file_name, url, target_dir)
            t.update()
to_mysql(self, *, host, database, table, port=3306, username=None, password=None)
Push the prediction results to a mysql database.
If the database requires authentication, pass the username/password as parameters or store it in the AIRT_CLIENT_DB_USERNAME and AIRT_CLIENT_DB_PASSWORD environment variables.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
host |
str |
Database host name. |
required |
database |
str |
Database name. |
required |
table |
str |
Table name. |
required |
port |
int |
Host port number. If not passed, then the default value 3306 will be used. |
3306 |
username |
Optional[str] |
Database username. If not passed, then the value set in the environment variable AIRT_CLIENT_DB_USERNAME will be used else the default value "root" will be used. |
None |
password |
Optional[str] |
Database password. If not passed, then the value set in the environment variable AIRT_CLIENT_DB_PASSWORD will be used else the default value "" will be used. |
None |
Returns:
Type | Description |
---|---|
ProgressStatus |
An instance of ProgressStatus class. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to push the prediction results to a mysql database:
status = prediction.to_mysql(
host="host_name",
database="database_name",
table="table_name"
)
status.progress_bar()
Source code in airt/client.py
@patch
def to_mysql(
    self: Prediction,
    *,
    host: str,
    database: str,
    table: str,
    port: int = 3306,
    username: Optional[str] = None,
    password: Optional[str] = None,
) -> ProgressStatus:
    """Push the prediction results to a mysql database.

    If the database requires authentication, pass the username/password as parameters or store it in
    the **AIRT_CLIENT_DB_USERNAME** and **AIRT_CLIENT_DB_PASSWORD** environment variables.

    Args:
        host: Database host name.
        database: Database name.
        table: Table name.
        port: Host port number. If not passed, then the default value **3306** will be used.
        username: Database username. If not passed, then the value set in the environment variable
            **AIRT_CLIENT_DB_USERNAME** will be used else the default value "root" will be used.
        password: Database password. If not passed, then the value set in the environment variable
            **AIRT_CLIENT_DB_PASSWORD** will be used else the default value "" will be used.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to push the prediction results to a mysql database:

    ```python
    status = prediction.to_mysql(
        host="host_name",
        database="database_name",
        table="table_name"
    )
    status.progress_bar()
    ```
    """
    # Explicit arguments win; otherwise read the env vars named by the module constants.
    if username is None:
        username = os.environ.get(CLIENT_DB_USERNAME, "root")
    if password is None:
        password = os.environ.get(CLIENT_DB_PASSWORD, "")
    payload = {
        "host": host,
        "port": port,
        "username": username,
        "password": password,
        "database": database,
        "table": table,
    }
    push_job = Client._post_data(
        relative_url=f"/prediction/{self.id}/to_mysql", data=payload
    )
    # The push runs asynchronously on the server; track it via its own id.
    return ProgressStatus(relative_url=f"/prediction/push/{push_job['id']}")
to_pandas(self)
Return the prediction results as a pandas DataFrame.
Returns:
Type | Description |
---|---|
DataFrame |
A pandas DataFrame encapsulating the results of the prediction. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to convert the prediction results into a pandas DataFrame:
prediction.to_pandas()
Source code in airt/client.py
@patch
def to_pandas(self: Prediction) -> pd.DataFrame:
    """Return the prediction results as a pandas DataFrame.

    Returns:
        A pandas DataFrame encapsulating the results of the prediction.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to convert the prediction results into a pandas DataFrame:

    ```python
    prediction.to_pandas()
    ```
    """
    response = Client._get_data(relative_url=f"/prediction/{self.id}/pandas")
    # The first non-"Score" column is used as the index of the result.
    other_columns = list(response)
    other_columns.remove("Score")
    index_column = other_columns[0]
    frame = pd.DataFrame(response).set_index(index_column)
    # Highest scores first.
    return frame.sort_values("Score", ascending=False)
to_s3(self, uri, access_key=None, secret_key=None)
Push the prediction results to the target AWS S3 bucket.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
uri |
str |
Target S3 bucket uri. |
required |
access_key |
Optional[str] |
Access key for the target S3 bucket. If None (default value), then the value from AWS_ACCESS_KEY_ID environment variable is used. |
None |
secret_key |
Optional[str] |
Secret key for the target S3 bucket. If None (default value), then the value from AWS_SECRET_ACCESS_KEY environment variable is used. |
None |
Returns:
Type | Description |
---|---|
ProgressStatus |
An instance of ProgressStatus class. |
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
An example to push the prediction results to an AWS s3 bucket:
status = prediction.to_s3(
uri="s3://target-bucket/"
)
status.progress_bar()
Source code in airt/client.py
@patch
def to_s3(
    self: Prediction,
    uri: str,
    access_key: Optional[str] = None,
    secret_key: Optional[str] = None,
) -> ProgressStatus:
    """Push the prediction results to the target AWS S3 bucket.

    Args:
        uri: Target S3 bucket uri.
        access_key: Access key for the target S3 bucket. If **None** (default value), then the value
            from **AWS_ACCESS_KEY_ID** environment variable is used.
        secret_key: Secret key for the target S3 bucket. If **None** (default value), then the value
            from **AWS_SECRET_ACCESS_KEY** environment variable is used.

    Returns:
        An instance of `ProgressStatus` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to push the prediction results to an AWS s3 bucket:

    ```python
    status = prediction.to_s3(
        uri="s3://target-bucket/"
    )
    status.progress_bar()
    ```
    """
    # Missing credentials fall back to the environment; os.environ[...] raises
    # KeyError if the corresponding variable is not set either.
    if access_key is None:
        access_key = os.environ["AWS_ACCESS_KEY_ID"]
    if secret_key is None:
        secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
    payload = {"uri": uri, "access_key": access_key, "secret_key": secret_key}
    push_job = Client._post_data(
        relative_url=f"/prediction/{self.id}/to_s3",
        data=payload,
    )
    # The push runs asynchronously on the server; track it via its own id.
    return ProgressStatus(relative_url=f"/prediction/push/{push_job['id']}")
wait(self)
inherited
Blocks execution while waiting for the remote action to complete.
Exceptions:
Type | Description |
---|---|
ConnectionError |
If the server address is invalid or not reachable. |
TimeoutError |
in case of timeout. |
Source code in airt/client.py
@patch
def wait(self: ProgressStatus):
    """Blocks execution while waiting for the remote action to complete.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
        TimeoutError: in case of timeout.
    """
    started_at = datetime.now()

    def _raise_if_expired() -> None:
        # A non-positive timeout disables the deadline entirely.
        if self.timeout > 0 and datetime.now() - started_at > timedelta(
            seconds=self.timeout
        ):
            raise TimeoutError()

    # Sequence per round: deadline check, readiness poll, sleep.
    _raise_if_expired()
    while not self.is_ready():
        sleep(self.sleep_for)
        _raise_if_expired()