Model
airt.client.Model (ProgressStatus)
A class for querying the model training, evaluation, and prediction status.
The Model class is instantiated automatically when the DataSource.train method is called on a datasource. Currently, this is the only way to instantiate the Model class.
The model is trained to predict a specific future event, and the input data is assumed to have:
- a column identifying a client (client_column), e.g., a person, car, or business;
- a column specifying the type of event to predict (target_column), e.g., buy or checkout;
- a timestamp column (timestamp_column) recording the time at which each event occurred.
In addition to these mandatory columns, the input data can have additional columns of any type (int, category, float, datetime, etc.). These additional columns are used during training to make more accurate predictions.
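Before calling DataSource.train, it can help to confirm that every row actually carries the three mandatory columns. The sketch below is a hypothetical pre-flight check, not part of airt; the helper missing_columns and the column names user_id, event, and timestamp are illustrative assumptions.

```python
from typing import Iterable, List, Mapping


def missing_columns(
    rows: Iterable[Mapping[str, object]],
    client_column: str,
    target_column: str,
    timestamp_column: str,
) -> List[str]:
    """Return mandatory column names absent from at least one row (hypothetical helper)."""
    required = [client_column, target_column, timestamp_column]
    missing: List[str] = []
    for row in rows:
        for col in required:
            # Record each missing column once, in the order they are required.
            if col not in row and col not in missing:
                missing.append(col)
    return missing


rows = [
    # Mandatory columns plus one extra feature column ("amount").
    {"user_id": 1, "event": "checkout", "timestamp": "2021-05-01T10:00", "amount": 9.5},
    {"user_id": 2, "event": "view", "timestamp": "2021-05-01T10:05"},
]
print(missing_columns(rows, "user_id", "event", "timestamp"))  # prints []
```

Extra columns such as amount are simply ignored by the check, matching the statement that additional columns of any type are allowed.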
Finally, we need to know how far ahead we wish to make predictions. This lead time varies widely between use cases: it can be a few minutes for a webshop or several weeks for a banking product such as a loan.
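The parameter table below describes predict_after as a time delta in hours, so a lead time expressed in minutes or weeks has to be converted first. A minimal sketch, assuming hours are the unit; the helper lead_time_hours is hypothetical, and the exact value type the server expects (predict_after is typed Optional[str]) should be checked against DataSource.train.

```python
from datetime import timedelta


def lead_time_hours(delta: timedelta) -> float:
    """Convert a lead time into hours (hypothetical helper)."""
    return delta.total_seconds() / 3600


# A webshop might predict 30 minutes ahead; a loan product several weeks ahead.
print(lead_time_hours(timedelta(minutes=30)))  # prints 0.5
print(lead_time_hours(timedelta(weeks=3)))  # prints 504.0
```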
Model training and prediction are asynchronous processes and can take a few hours to finish, depending on the size of your dataset. You can check their progress by calling the is_ready method on the Model instance. Alternatively, call the progress_bar method to monitor the status interactively.
__init__(self, id, datasource_id=None, client_column=None, target_column=None, target=None, predict_after=None, timestamp_column=None, total_steps=None, completed_steps=None, error=None, disabled=None, created=None, user_id=None)
special
Constructs a new Model instance.
Warning
Do not construct this object directly by calling the constructor; use the DataSource.train method instead.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
id | int | Model id. | required |
datasource_id | Optional[int] | DataSource id. | None |
client_column | Optional[str] | The column name that uniquely identifies the users/clients. | None |
target_column | Optional[str] | Target column name that indicates the type of the event. | None |
target | Optional[str] | Target event name to train and make predictions. You can pass the target event as a string or as a regular expression for predicting more than one event. For example, passing *checkout will train a model to predict any checkout event. | None |
predict_after | Optional[str] | Time delta in hours of the expected target event. | None |
timestamp_column | Optional[str] | The timestamp column indicating the time of an event. If not passed, the default value None will be used. | None |
total_steps | Optional[int] | Number of steps needed to complete the model training. | None |
completed_steps | Optional[int] | Number of steps completed so far in the model training. | None |
error | Optional[str] | Error message produced while processing the model. | None |
disabled | Optional[bool] | Flag to indicate the status of the model. | None |
created | Optional[pandas._libs.tslibs.timestamps.Timestamp] | Model creation date. | None |
user_id | Optional[int] | The id of the user who created the model. | None |
Source code in airt/client.py
def __init__(
    self,
    id: int,
    datasource_id: Optional[int] = None,
    client_column: Optional[str] = None,
    target_column: Optional[str] = None,
    target: Optional[str] = None,
    predict_after: Optional[str] = None,
    timestamp_column: Optional[str] = None,
    total_steps: Optional[int] = None,
    completed_steps: Optional[int] = None,
    error: Optional[str] = None,
    disabled: Optional[bool] = None,
    created: Optional[pd.Timestamp] = None,
    user_id: Optional[int] = None,
):
    """Constructs a new `Model` instance.
    Warning:
        Do not construct this object directly by calling the constructor; use the
        `DataSource.train` method instead.
    Args:
        id: Model id.
        datasource_id: DataSource id.
        client_column: The column name that uniquely identifies the users/clients.
        target_column: Target column name that indicates the type of the event.
        target: Target event name to train and make predictions. You can pass the target event as a string or as a
            regular expression for predicting more than one event. For example, passing ***checkout** will
            train a model to predict any checkout event.
        predict_after: Time delta in hours of the expected target event.
        timestamp_column: The timestamp column indicating the time of an event. If not passed,
            then the default value **None** will be used.
        total_steps: Number of steps needed to complete the model training.
        completed_steps: Number of steps completed so far in the model training.
        error: Error message produced while processing the model.
        disabled: Flag to indicate the status of the model.
        created: Model creation date.
        user_id: The id of the user who created the model.
    """
    self.id = id
    self.datasource_id = datasource_id
    self.client_column = client_column
    self.target_column = target_column
    self.target = target
    self.predict_after = predict_after
    self.timestamp_column = timestamp_column
    self.total_steps = total_steps
    self.completed_steps = completed_steps
    self.error = error
    self.disabled = disabled
    self.created = created
    self.user_id = user_id
    ProgressStatus.__init__(self, relative_url=f"/model/{self.id}")
as_df(mx)
staticmethod
Return the details of Model instances as a pandas dataframe.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mx | List[Model] | List of Model instances. | required |
Returns:
Type | Description |
---|---|
DataFrame | Details of all the models in a dataframe. |
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
An example to get the details of the available models:
mx = Model.ls()
Model.as_df(mx)
Source code in airt/client.py
@staticmethod
def as_df(mx: List["Model"]) -> pd.DataFrame:
    """Return the details of Model instances as a pandas dataframe.
    Args:
        mx: List of Model instances.
    Returns:
        Details of all the models in a dataframe.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    An example to get the details of the available models:
    ```python
    mx = Model.ls()
    Model.as_df(mx)
    ```
    """
    model_lists = get_attributes_from_instances(mx, Model.BASIC_MODEL_COLS)  # type: ignore
    df = generate_df(model_lists, Model.BASIC_MODEL_COLS)
    return add_ready_column(df)
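as_df delegates to the helpers get_attributes_from_instances, generate_df, and add_ready_column, whose bodies are not shown in this reference. The stdlib-only sketch below illustrates what such a pipeline plausibly does, assuming the ready flag means completed_steps == total_steps (the same test is_ready applies). StubModel and models_as_rows are hypothetical stand-ins, not airt code.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence


@dataclass
class StubModel:
    # Hypothetical stand-in exposing a few of Model's attributes.
    id: int
    total_steps: Optional[int]
    completed_steps: Optional[int]


def models_as_rows(mx: Sequence[StubModel], cols: Sequence[str]) -> List[Dict[str, object]]:
    rows = []
    for m in mx:
        # Pick the requested attributes from each instance.
        row = {col: getattr(m, col) for col in cols}
        # Assumed semantics of add_ready_column: ready once all steps are done.
        row["ready"] = m.completed_steps == m.total_steps
        rows.append(row)
    return rows


mx = [StubModel(1, 5, 5), StubModel(2, 5, 3)]
for row in models_as_rows(mx, ["id", "completed_steps"]):
    print(row)
```

The real method returns a pandas DataFrame; plain dictionaries are used here only to keep the sketch dependency-free.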
delete(self)
Delete a model from the server.
Returns:
Type | Description |
---|---|
DataFrame | A pandas DataFrame encapsulating the details of the deleted model. |
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
An example to delete a model from the server:
model.delete()
Source code in airt/client.py
@patch
def delete(self: Model) -> pd.DataFrame:
    """Delete a model from the server.
    Returns:
        A pandas DataFrame encapsulating the details of the deleted model.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    An example to delete a model from the server:
    ```python
    model.delete()
    ```
    """
    models = Client._delete_data(relative_url=f"/model/{self.id}")
    models_df = pd.DataFrame(models, index=[0])[Model.BASIC_MODEL_COLS]
    return add_ready_column(models_df)
details(self)
Return the details of a model.
Returns:
Type | Description |
---|---|
DataFrame | A pandas DataFrame encapsulating the details of the model. |
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
An example to get the details of a model:
model.details()
Source code in airt/client.py
@patch
def details(self: Model) -> pd.DataFrame:
    """Return the details of a model.
    Returns:
        A pandas DataFrame encapsulating the details of the model.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    An example to get the details of a model:
    ```python
    model.details()
    ```
    """
    details = Client._get_data(relative_url=f"/model/{self.id}")
    details_df = pd.DataFrame(details, index=[0])[Model.ALL_MODEL_COLS]
    return add_ready_column(details_df)
evaluate(self)
Return the evaluation metrics of the trained model.
Currently, this function returns the accuracy, precision, and recall of the model. More performance metrics will be added in the future.
Returns:
Type | Description |
---|---|
DataFrame | The performance metrics of the trained model as a pandas DataFrame, with one row per metric in a single eval column. |
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
Source code in airt/client.py
@patch
def evaluate(self: Model) -> pd.DataFrame:
    """Return the evaluation metrics of the trained model.
    Currently, this function returns the accuracy, precision, and recall of the model. More
    performance metrics will be added in the future.
    Returns:
        The performance metrics of the trained model as a pandas DataFrame with a single `eval` column.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    model_evaluate = Client._get_data(relative_url=f"/model/{self.id}/evaluate")
    return pd.DataFrame(dict(model_evaluate), index=[0]).T.rename(columns={0: "eval"})
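evaluate reports accuracy, precision, and recall, but the server-side computation is not shown in this reference. As an aid to reading those numbers, the sketch below applies the standard definitions to a binary label (1 = the target event occurred); binary_metrics is a hypothetical helper and makes no claim about how airt computes its metrics.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[int], y_pred: Sequence[int]) -> Dict[str, float]:
    """Standard accuracy, precision, and recall for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # correctly predicted events
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # predicted events that did not happen
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # events the model missed
    correct = sum(1 for t, p in pairs if t == p)
    return {
        "accuracy": correct / len(pairs),
        # Guard against division by zero when there are no positive predictions or labels.
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }


print(binary_metrics([1, 0, 1, 1, 0], [1, 0, 0, 1, 1]))
```

Precision answers "of the clients flagged, how many really triggered the event", while recall answers "of the clients who triggered it, how many were flagged".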
is_ready(self)
inherited
Check if the method's progress is complete.
Returns:
Type | Description |
---|---|
bool | True if the progress is complete, else False. |
Source code in airt/client.py
def is_ready(self) -> bool:
    """Check if the method's progress is complete.
    Returns:
        **True** if the progress is complete, else **False**.
    """
    response = Client._get_data(relative_url=self.relative_url)
    return response["completed_steps"] == response["total_steps"]
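is_ready fetches the object's status from its relative_url and compares two fields of the response. The sketch below works on a status mapping of that assumed shape and adds a fraction-complete helper, which can be handy for custom logging; both function names are hypothetical, and treating missing step counts as zero is an assumption of the sketch.

```python
from typing import Mapping, Optional


def progress_fraction(status: Mapping[str, Optional[int]]) -> float:
    """Fraction of steps completed; 0.0 while total_steps is unknown or zero."""
    total = status.get("total_steps") or 0
    done = status.get("completed_steps") or 0
    return done / total if total else 0.0


def is_complete(status: Mapping[str, Optional[int]]) -> bool:
    """Same comparison as is_ready: completed_steps equals total_steps."""
    return status["completed_steps"] == status["total_steps"]


status = {"completed_steps": 3, "total_steps": 4}
print(progress_fraction(status), is_complete(status))  # prints 0.75 False
```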
ls(offset=0, limit=100, disabled=False, completed=False)
staticmethod
Return the list of Model instances available in the server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
offset | int | The number of models to skip at the beginning. If None, the default value 0 will be used. | 0 |
limit | int | The maximum number of models to return from the server. If None, the default value 100 will be used. | 100 |
disabled | bool | If set to True, only the deleted models will be returned. Otherwise, the default value False returns only the active models. | False |
completed | bool | If set to True, only the models that were successfully processed on the server will be returned. Otherwise, the default value False returns all the models. | False |
Returns:
Type | Description |
---|---|
List[Model] | A list of Model instances available in the server. |
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
An example to list the available models:
Model.ls()
Source code in airt/client.py
@staticmethod
def ls(
    offset: int = 0,
    limit: int = 100,
    disabled: bool = False,
    completed: bool = False,
) -> List["Model"]:
    """Return the list of Model instances available in the server.
    Args:
        offset: The number of models to skip at the beginning. If None, then the default value **0** will be used.
        limit: The maximum number of models to return from the server. If None,
            then the default value **100** will be used.
        disabled: If set to **True**, then only the deleted models will be returned. Else, the default value
            **False** will be used to return only the list of active models.
        completed: If set to **True**, then only the models that were successfully processed on the server will
            be returned. Else, the default value **False** will be used to return all the models.
    Returns:
        A list of Model instances available in the server.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    An example to list the available models:
    ```python
    Model.ls()
    ```
    """
    lists = Client._get_data(
        relative_url=f"/model/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
    )
    mx = [
        Model(
            id=model["id"],
            datasource_id=model["datasource_id"],
            client_column=model["client_column"],
            target_column=model["target_column"],
            target=model["target"],
            predict_after=model["predict_after"],
            timestamp_column=model["timestamp_column"],
            total_steps=model["total_steps"],
            completed_steps=model["completed_steps"],
            error=model["error"],
            disabled=model["disabled"],
            created=model["created"],
            user_id=model["user_id"],
        )
        for model in lists
    ]
    return mx
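ls returns at most limit models per call, starting at offset, so collecting every model means paging until a short page comes back. The sketch below shows that loop with a hypothetical fetch_page(offset, limit) callable standing in for the server, so it runs offline; list_url rebuilds the same query string as ls using urllib.parse.urlencode, which renders booleans as False/True just like the f-string does. Both helpers are illustrative, not airt functions.

```python
from typing import Callable, List, TypeVar
from urllib.parse import urlencode

T = TypeVar("T")


def list_url(offset: int, limit: int, disabled: bool = False, completed: bool = False) -> str:
    # Same parameters and order as ls; urlencode calls str() on non-string values.
    query = urlencode(
        {"disabled": disabled, "completed": completed, "offset": offset, "limit": limit}
    )
    return f"/model/?{query}"


def fetch_all(fetch_page: Callable[[int, int], List[T]], limit: int = 100) -> List[T]:
    """Call fetch_page(offset, limit) until it returns fewer than `limit` items."""
    items: List[T] = []
    offset = 0
    while True:
        page = fetch_page(offset, limit)
        items.extend(page)
        if len(page) < limit:
            return items
        offset += limit


# Fake server holding 250 model ids.
ids = list(range(250))
print(len(fetch_all(lambda off, lim: ids[off:off + lim], limit=100)))  # prints 250
print(list_url(0, 100))
```

Against a live server, the callable could be `lambda off, lim: Model.ls(offset=off, limit=lim)`, using only the parameters documented above.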
predict(self, data_id=0)
Run predictions against the trained model.
The progress can be checked by calling the is_ready method on the Model instance. Alternatively, you can call the progress_bar method to monitor the status interactively.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
data_id | Optional[int] | The datasource id to run the predictions on. If not set, the datasource used for training the model will be used for prediction as well. | 0 |
Returns:
Type | Description |
---|---|
Prediction | An instance of the Prediction class. |
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
An example to run the prediction:
model.predict()
Source code in airt/client.py
@patch
def predict(self: Model, data_id: Optional[int] = 0) -> Prediction:
    """Run predictions against the trained model.
    The progress can be checked by calling the `is_ready` method on the `Model` instance.
    Alternatively, you can call the `progress_bar` method to monitor the status interactively.
    Args:
        data_id: The datasource id to run the predictions on. If not set, then the datasource used for training
            the model will be used for prediction as well.
    Returns:
        An instance of the `Prediction` class.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    An example to run the prediction:
    ```python
    model.predict()
    ```
    """
    request_body = dict(data_id=data_id) if data_id else None
    response = Client._post_data(
        relative_url=f"/model/{self.id}/predict", data=request_body
    )
    return Prediction(id=response["id"], datasource_id=response["datasource_id"])
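The request body in predict depends only on whether data_id is truthy: with the default 0 (or None) no body is sent, which per the docstring means the training datasource is reused. The branch is isolated below as a pure function; the name predict_request_body is hypothetical. One consequence worth knowing: because 0 is falsy, a datasource whose id were literally 0 could not be selected explicitly.

```python
from typing import Dict, Optional


def predict_request_body(data_id: Optional[int] = 0) -> Optional[Dict[str, int]]:
    # 0 and None are both falsy, so either one means "reuse the training datasource".
    return dict(data_id=data_id) if data_id else None


print(predict_request_body())  # prints None
print(predict_request_body(42))  # prints {'data_id': 42}
```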
progress_bar(self)
inherited
Blocks execution and displays a progress bar showing the progress of the remote action.
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
TimeoutError | If the remote action does not complete within the timeout. |
Source code in airt/client.py
def progress_bar(self):
    """Blocks execution and displays a progress bar showing the progress of the remote action.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
        TimeoutError: If the remote action does not complete within the timeout.
    """
    total_steps = Client._get_data(relative_url=self.relative_url)["total_steps"]
    with tqdm(total=total_steps) as pbar:
        started_at = datetime.now()
        while True:
            if (0 < self.timeout) and (datetime.now() - started_at) > timedelta(
                seconds=self.timeout
            ):
                raise TimeoutError()
            response = Client._get_data(relative_url=self.relative_url)
            completed_steps = response["completed_steps"]
            # tqdm.update() adds to the current count, so pass only the newly completed steps
            pbar.update(completed_steps - pbar.n)
            if completed_steps == total_steps:
                break
            sleep(self.sleep_for)
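tqdm's update(n) increments the bar's counter by n rather than setting it, so a loop that polls an absolute completed_steps value has to pass only the difference since the previous poll. The stdlib-only sketch below shows that bookkeeping; increments and the polled values are hypothetical stand-ins for successive server responses.

```python
from typing import Iterable, List


def increments(polled_completed_steps: Iterable[int]) -> List[int]:
    """Turn absolute step counts from successive polls into per-poll increments."""
    shown = 0  # what the bar currently displays (tqdm keeps this in pbar.n)
    deltas = []
    for completed in polled_completed_steps:
        delta = completed - shown
        deltas.append(delta)
        shown += delta
    return deltas


polls = [0, 1, 1, 3, 5]  # completed_steps seen on each poll, with total_steps = 5
print(increments(polls))  # prints [0, 1, 0, 2, 2]
```

The increments sum to the final completed_steps, so the bar ends exactly at total_steps instead of overshooting.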
wait(self)
inherited
Blocks execution while waiting for the remote action to complete.
Exceptions:
Type | Description |
---|---|
ConnectionError | If the server address is invalid or not reachable. |
TimeoutError | If the remote action does not complete within the timeout. |
Source code in airt/client.py
@patch
def wait(self: ProgressStatus):
    """Blocks execution while waiting for the remote action to complete.
    Raises:
        ConnectionError: If the server address is invalid or not reachable.
        TimeoutError: If the remote action does not complete within the timeout.
    """
    started_at = datetime.now()
    while True:
        if (0 < self.timeout) and (datetime.now() - started_at) > timedelta(
            seconds=self.timeout
        ):
            raise TimeoutError()
        if self.is_ready():
            return
        sleep(self.sleep_for)
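The wait loop follows a common poll-with-deadline pattern: a timeout of zero or less disables the deadline, otherwise TimeoutError is raised once the elapsed time exceeds it. A generic sketch of the same rule follows; wait_until and its default values are hypothetical (the real timeout and sleep_for attributes come from ProgressStatus, which is not shown here). It uses time.monotonic instead of datetime.now, a design choice that keeps the deadline correct if the system clock is adjusted mid-wait.

```python
import time
from typing import Callable


def wait_until(
    is_ready: Callable[[], bool], timeout: float = 0, sleep_for: float = 0.1
) -> None:
    """Poll is_ready() until it returns True.

    A timeout of 0 or less means wait forever; otherwise raise TimeoutError
    once more than `timeout` seconds have elapsed.
    """
    started_at = time.monotonic()  # unaffected by wall-clock changes
    while True:
        if timeout > 0 and time.monotonic() - started_at > timeout:
            raise TimeoutError()
        if is_ready():
            return
        time.sleep(sleep_for)


# Hypothetical readiness check that succeeds on the third poll.
calls = {"n": 0}


def ready_on_third_call() -> bool:
    calls["n"] += 1
    return calls["n"] >= 3


wait_until(ready_on_third_call, timeout=5, sleep_for=0.01)
print(calls["n"])  # prints 3
```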