
Model

airt.client.Model (ProgressStatus)

A class for querying the model training, evaluation, and prediction status.

The Model class is instantiated automatically when the DataSource.train method is called on a datasource. Currently, it is the only way to instantiate the Model class.

The model is trained to predict a specific future event. The input data is assumed to have:

  • a column identifying a client (client_column), e.g. person, car, business, etc.;
  • a column specifying the type of event to predict (target_column), e.g. buy, checkout, etc.;
  • a timestamp column (timestamp_column) specifying when an event occurred.

Along with the mandatory columns above, the input data can have additional columns of any type (int, category, float, datetime, etc.). These additional columns are used in model training to make predictions more accurate.
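
As a rough illustration of the layout described above, the sketch below builds a few event rows and checks that the three mandatory columns are present. The column names (user_id, event_type, event_time), the extra price column, and the helper missing_mandatory_columns are assumptions made for this example; they are not part of the airt API.

```python
from datetime import datetime

# Hypothetical column names; use whatever your own dataset calls them.
CLIENT_COLUMN = "user_id"
TARGET_COLUMN = "event_type"
TIMESTAMP_COLUMN = "event_time"

# Each row is one event; "price" is an optional additional column.
rows = [
    {"user_id": "u1", "event_type": "view", "event_time": datetime(2022, 1, 1, 9, 0), "price": 0.0},
    {"user_id": "u1", "event_type": "checkout", "event_time": datetime(2022, 1, 1, 9, 5), "price": 19.9},
    {"user_id": "u2", "event_type": "view", "event_time": datetime(2022, 1, 2, 14, 30), "price": 0.0},
]


def missing_mandatory_columns(rows, mandatory):
    """Return the mandatory column names that are absent from at least one row."""
    return sorted({name for row in rows for name in mandatory if name not in row})


mandatory = [CLIENT_COLUMN, TARGET_COLUMN, TIMESTAMP_COLUMN]
missing = missing_mandatory_columns(rows, mandatory)
```

An empty result from the helper means every row carries the client, target, and timestamp columns.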

Finally, we need to know how far ahead we wish to make predictions. This lead time varies widely by use case: it can be minutes for a webshop or several weeks for a banking product such as a loan.

Model training and prediction are asynchronous and can take a few hours to finish, depending on the size of your dataset. You can check progress by calling the is_ready method on the Model instance. Alternatively, you can call the progress_bar method to monitor the status interactively.
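
Because the process is asynchronous, a caller that wants to block until it finishes has to poll. The following is a minimal sketch of such a loop, assuming only that the object exposes an is_ready() method returning a boolean. The helper wait_until_ready and the FakeJob stand-in are hypothetical and exist only to make the example runnable without a server.

```python
import time


def wait_until_ready(job, poll_seconds=30.0, timeout_seconds=None, sleep=time.sleep):
    """Poll job.is_ready() until it returns True or the timeout elapses.

    Returns True if the job became ready, False if the timeout was reached.
    """
    waited = 0.0
    while not job.is_ready():
        if timeout_seconds is not None and waited >= timeout_seconds:
            return False
        sleep(poll_seconds)
        waited += poll_seconds
    return True


class FakeJob:
    """Stand-in for a Model instance: reports ready on the n-th check."""

    def __init__(self, ready_after=3):
        self.checks = 0
        self.ready_after = ready_after

    def is_ready(self):
        self.checks += 1
        return self.checks >= self.ready_after


job = FakeJob()
# A no-op sleep keeps the demonstration instant.
became_ready = wait_until_ready(job, poll_seconds=1.0, sleep=lambda seconds: None)
```

With a real Model instance, the same helper would be called as wait_until_ready(model), using the default time.sleep between checks.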

__init__(self, uuid, datasource=None, client_column=None, target_column=None, target=None, predict_after=None, timestamp_column=None, total_steps=None, completed_steps=None, region=None, cloud_provider=None, error=None, disabled=None, created=None, user=None) special

Constructs a new Model instance.

Warning

Do not construct this object directly by calling the constructor; use the DataSource.train method instead.

Parameters:

Name Type Description Default
uuid str

Model uuid.

required
datasource Optional[str]

DataSource uuid.

None
client_column Optional[str]

The column name that uniquely identifies the users/clients.

None
target_column Optional[str]

Target column name that indicates the type of the event.

None
target Optional[str]

Target event name to train and make predictions. You can pass the target event as a string or as a regular expression for predicting more than one event. For example, passing *checkout will train a model to predict any checkout event.

None
predict_after Optional[str]

Time delta in hours of the expected target event.

None
timestamp_column Optional[str]

The timestamp column indicating the time of an event. If not passed, then the default value None will be used.

None
total_steps Optional[int]

Number of steps needed to complete the model training.

None
completed_steps Optional[int]

Number of steps completed so far in the model training.

None
region Optional[str]

AWS bucket region.

None
cloud_provider Optional[str]

The name of the cloud storage provider where the model is stored.

None
error Optional[str]

Error message while processing the model.

None
disabled Optional[bool]

Flag indicating whether the model is disabled (deleted).

None
created Optional[pandas._libs.tslibs.timestamps.Timestamp]

Model creation date.

None
user Optional[str]

The uuid of the user who created the model.

None
Source code in airt/client.py
def __init__(
    self,
    uuid: str,
    datasource: Optional[str] = None,
    client_column: Optional[str] = None,
    target_column: Optional[str] = None,
    target: Optional[str] = None,
    predict_after: Optional[str] = None,
    timestamp_column: Optional[str] = None,
    total_steps: Optional[int] = None,
    completed_steps: Optional[int] = None,
    region: Optional[str] = None,
    cloud_provider: Optional[str] = None,
    error: Optional[str] = None,
    disabled: Optional[bool] = None,
    created: Optional[pd.Timestamp] = None,
    user: Optional[str] = None,
):
    """Constructs a new `Model` instance

    Warning:
        Do not construct this object directly by calling the constructor; use the
        `DataSource.train` method instead.

    Args:
        uuid: Model uuid.
        datasource: DataSource uuid.
        client_column: The column name that uniquely identifies the users/clients.
        target_column: Target column name that indicates the type of the event.
        target: Target event name to train and make predictions. You can pass the target event as a string or as a
            regular expression for predicting more than one event. For example, passing ***checkout** will
            train a model to predict any checkout event.
        predict_after: Time delta in hours of the expected target event.
        timestamp_column: The timestamp column indicating the time of an event. If not passed,
            then the default value **None** will be used.
        total_steps: Number of steps needed to complete the model training.
        completed_steps: Number of steps completed so far in the model training.
        region: AWS bucket region.
        cloud_provider: The name of the cloud storage provider where the model is stored
        error: Error message while processing the model.
        disabled: Flag to indicate the status of the model.
        created: Model creation date.
        user: The uuid of the user who created the model.
    """
    self.uuid = uuid
    self.datasource = datasource
    self.client_column = client_column
    self.target_column = target_column
    self.target = target
    self.predict_after = predict_after
    self.timestamp_column = timestamp_column
    self.total_steps = total_steps
    self.completed_steps = completed_steps
    self.region = region
    self.cloud_provider = cloud_provider
    self.error = error
    self.disabled = disabled
    self.created = created
    self.user = user
    ProgressStatus.__init__(self, relative_url=f"/model/{self.uuid}")
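
The total_steps and completed_steps attributes stored by the constructor are enough to derive a progress fraction. The sketch below shows one way to do that; the helper progress_fraction is hypothetical, and this page does not state whether is_ready or progress_bar compute progress this way.

```python
from typing import Optional


def progress_fraction(completed_steps: Optional[int], total_steps: Optional[int]) -> Optional[float]:
    """Return completed/total clamped to [0, 1], or None when it cannot be computed."""
    # Both attributes default to None, and a total of 0 would divide by zero.
    if completed_steps is None or not total_steps:
        return None
    return max(0.0, min(1.0, completed_steps / total_steps))
```

For a model object, the call would be progress_fraction(model.completed_steps, model.total_steps).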

as_df(mx) staticmethod

Return the details of Model instances as a pandas dataframe.

Parameters:

Name Type Description Default
mx List[Model]

List of Model instances.

required

Returns:

Type Description
DataFrame

Details of all the models in a dataframe.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to get the details of available models:

mx = Model.ls()
Model.as_df(mx)
Source code in airt/client.py
@staticmethod
def as_df(mx: List["Model"]) -> pd.DataFrame:
    """Return the details of Model instances as a pandas dataframe.

    Args:
        mx: List of Model instances.

    Returns:
        Details of all the models in a dataframe.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to get the details of available models:

    ```python
    mx = Model.ls()
    Model.as_df(mx)
    ```
    """
    model_lists = get_attributes_from_instances(mx, Model.BASIC_MODEL_COLS)  # type: ignore

    df = generate_df(model_lists, Model.BASIC_MODEL_COLS)

    df = df.rename(columns=Model.COLS_TO_RENAME)

    return add_ready_column(df)
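
The listing above relies on helpers (get_attributes_from_instances, generate_df, add_ready_column) whose implementations are not shown on this page. As a rough idea of the first step only, the sketch below collects named attributes from a list of objects into one dictionary per object. The function attributes_as_rows, the SimpleNamespace stand-ins, and the chosen column names are assumptions for illustration and may differ from what the library actually does.

```python
from types import SimpleNamespace


def attributes_as_rows(instances, columns):
    """Collect the named attributes of each instance into one dict per instance."""
    return [{name: getattr(obj, name, None) for name in columns} for obj in instances]


# Stand-ins for Model instances; only a few attributes are filled in.
models = [
    SimpleNamespace(uuid="model-a", target="checkout", completed_steps=5, total_steps=5),
    SimpleNamespace(uuid="model-b", target="buy", completed_steps=2, total_steps=5),
]

rows = attributes_as_rows(models, ["uuid", "target", "total_steps"])
```

A list of dictionaries like this can be passed directly to the pandas DataFrame constructor to obtain one row per model.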

delete(self)

Delete a model from the server.

Returns:

Type Description
DataFrame

A pandas DataFrame encapsulating the details of the deleted model.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to delete a model from the server:

model.delete()
Source code in airt/client.py
@patch
def delete(self: Model) -> pd.DataFrame:
    """Delete a model from the server.

    Returns:
        A pandas DataFrame encapsulating the details of the deleted model.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to delete a model from the server:

    ```python
    model.delete()
    ```
    """

    response = Client._delete_data(relative_url=f"/model/{self.uuid}")

    df = pd.DataFrame(response, index=[0])[Model.BASIC_MODEL_COLS]

    df = df.rename(columns=Model.COLS_TO_RENAME)

    return add_ready_column(df)

details(self)

Return the details of a model.

Returns:

Type Description
DataFrame

A pandas DataFrame encapsulating the details of the model.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to get the details of a model:

model.details()
Source code in airt/client.py
@patch
def details(self: Model) -> pd.DataFrame:
    """Return the details of a model.

    Returns:
        A pandas DataFrame encapsulating the details of the model.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to get the details of a model:

    ```python
    model.details()
    ```
    """

    response = Client._get_data(relative_url=f"/model/{self.uuid}")

    df = pd.DataFrame(response, index=[0])[Model.ALL_MODEL_COLS]

    df = df.rename(columns=Model.COLS_TO_RENAME)

    return add_ready_column(df)

evaluate(self)

Return the evaluation metrics of the trained model.

Currently, this function returns the accuracy, precision, and recall of the model. More performance metrics will be added in the future.

Returns:

Type Description
DataFrame

The performance metrics of the trained model as a pandas DataFrame.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

Source code in airt/client.py
@patch
def evaluate(self: Model) -> pd.DataFrame:
    """Return the evaluation metrics of the trained model.

    Currently, this function returns the accuracy, precision, and recall of the model. More
    performance metrics will be added in the future.

    Returns:
        The performance metrics of the trained model as a pandas DataFrame.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.
    """
    model_evaluate = Client._get_data(relative_url=f"/model/{self.uuid}/evaluate")
    return pd.DataFrame(dict(model_evaluate), index=[0]).T.rename(columns={0: "eval"})

ls(offset=0, limit=100, disabled=False, completed=False) staticmethod

Return the list of Model instances available in the server.

Parameters:

Name Type Description Default
offset int

The number of models to offset at the beginning. If None, then the default value 0 will be used.

0
limit int

The maximum number of models to return from the server. If None, then the default value 100 will be used.

100
disabled bool

If set to True, then only the deleted models will be returned. Else, the default value False will be used to return only the list of active models.

False
completed bool

If set to True, then only the models that are successfully processed on the server will be returned. Else, the default value False will be used to return all the models.

False

Returns:

Type Description
List[Model]

A list of Model instances available in the server.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to list the available models:

Model.ls()
Source code in airt/client.py
@staticmethod
def ls(
    offset: int = 0,
    limit: int = 100,
    disabled: bool = False,
    completed: bool = False,
) -> List["Model"]:
    """Return the list of Model instances available in the server.

    Args:
        offset: The number of models to offset at the beginning. If None, then the default value **0** will be used.
        limit: The maximum number of models to return from the server. If None,
            then the default value **100** will be used.
        disabled: If set to **True**, then only the deleted models will be returned. Else, the default value
            **False** will be used to return only the list of active models.
        completed: If set to **True**, then only the models that are successfully processed in server will be returned.
            Else, the default value **False** will be used to return all the models.

    Returns:
        A list of Model instances available in the server.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to list the available models:

    ```python
    Model.ls()
    ```
    """
    lists = Client._get_data(
        relative_url=f"/model/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
    )

    mx = [
        Model(
            uuid=model["uuid"],
            datasource=model["datasource"],
            client_column=model["client_column"],
            target_column=model["target_column"],
            target=model["target"],
            predict_after=model["predict_after"],
            timestamp_column=model["timestamp_column"],
            total_steps=model["total_steps"],
            completed_steps=model["completed_steps"],
            region=model["region"],
            cloud_provider=model["cloud_provider"],
            error=model["error"],
            disabled=model["disabled"],
            created=model["created"],
            user=model["user"],
        )
        for model in lists
    ]

    return mx
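
Since a single call returns at most limit models, retrieving everything means paging with offset. The following is a minimal sketch of that loop, assuming the server skips offset items and returns fewer than limit items only on the last page. The helper iter_all and the in-memory fake_ls are hypothetical stand-ins so the example runs without a server.

```python
def iter_all(fetch_page, page_size=100):
    """Yield every item by requesting consecutive pages until a short page arrives."""
    offset = 0
    while True:
        page = fetch_page(offset=offset, limit=page_size)
        yield from page
        # A page shorter than page_size (including an empty one) is the last page.
        if len(page) < page_size:
            break
        offset += page_size


# Stand-in for Model.ls: serves slices of an in-memory list.
ALL_ITEMS = [f"model-{i}" for i in range(7)]


def fake_ls(offset=0, limit=100):
    return ALL_ITEMS[offset : offset + limit]


collected = list(iter_all(fake_ls, page_size=3))
```

Because Model.ls accepts offset and limit as keyword arguments, it could be passed in place of fake_ls; filtering by disabled or completed would need a wrapper such as functools.partial.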

predict(self, data_uuid=0)

Run predictions against the trained model.

Progress can be checked by calling the is_ready method on the Model instance. Alternatively, you can call the progress_bar method to monitor the status interactively.

Parameters:

Name Type Description Default
data_uuid Optional[int]

The datasource uuid to run the predictions on. If not set, then the datasource used for training the model will be used for prediction as well.

0

Returns:

Type Description
Prediction

An instance of the Prediction class.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to run the prediction:

model.predict()
Source code in airt/client.py
@patch
def predict(self: Model, data_uuid: Optional[int] = 0) -> Prediction:
    """Run predictions against the trained model.

    Progress can be checked by calling the `is_ready` method on the `Model` instance.
    Alternatively, you can call the `progress_bar` method to monitor the status interactively.

    Args:
        data_uuid: The datasource uuid to run the predictions on. If not set, then the datasource used for training
            the model will be used for prediction as well.

    Returns:
        An instance of the `Prediction` class.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to run the prediction:

    ```python
    model.predict()
    ```
    """

    req_json = dict(data_uuid=data_uuid) if data_uuid else None

    response = Client._post_data(
        relative_url=f"/model/{self.uuid}/predict", json=req_json
    )

    return Prediction(uuid=response["uuid"], datasource=response["datasource"])
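
The request body in the listing above depends on a truthiness check: any falsy data_uuid (the default 0, or None) results in no body being sent, which is how "not set" is signalled. The sketch below isolates that one expression; build_predict_payload is a hypothetical name used only for this illustration.

```python
from typing import Optional


def build_predict_payload(data_uuid: Optional[int] = 0) -> Optional[dict]:
    """Mirror the falsy check in predict: build a body only when data_uuid is set."""
    return {"data_uuid": data_uuid} if data_uuid else None
```

Calling build_predict_payload() or build_predict_payload(None) yields None, while build_predict_payload(42) yields a dictionary carrying the datasource uuid.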