
DataSource

airt.client.DataSource

A class for encapsulating the data from sources like CSV files, databases, or AWS S3 buckets.

The DataSource class is instantiated by calling the csv, db, or s3 static methods of the DataSource class. Currently, this is the only way to instantiate the class.

Currently, we support reading data from and pushing data to:

  • a local CSV file,

  • a MySQL database, and

  • an AWS S3 bucket in the Parquet file format.

We plan to add other databases and storage mediums in the future.

For reading from and pushing to a local CSV file, a relative or absolute path to the target CSV file is required.

To establish a connection to a MySQL database, parameters such as the host, port, database name, and table name are required.

To establish a connection to an S3 bucket, the URI of the target Parquet file is required.

If access to the database requires authentication, the username and password are read either from the environment variables AIRT_CLIENT_DB_USERNAME and AIRT_CLIENT_DB_PASSWORD or from the username and password parameters passed to the db method of the DataSource class.
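For example, a minimal sketch of supplying the credentials through environment variables before calling DataSource.db (the host, database, and table names below are placeholders):

import os

# placeholder credentials; replace with your own
os.environ["AIRT_CLIENT_DB_USERNAME"] = "db_username"
os.environ["AIRT_CLIENT_DB_PASSWORD"] = "db_password"

# placeholder connection details
data_source_db = DataSource.db(
    host="db.example.com",
    database="test",
    table="events"
)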

All function calls to the library are asynchronous and return immediately. To track completion, the methods of the returned object return a status object, which exposes a flag for checking the status and a method for displaying an interactive progress bar.

Below are code examples for accessing the above methods:

An example to check the status flag of an s3 connection:

data_source_s3 = DataSource.s3(
    uri="s3://bucket/events.parquet"
)
status = data_source_s3.pull()
status.is_ready()
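
To block until the pull completes, one option is to poll the status flag; a minimal sketch using only the is_ready method shown above:

import time

# poll the status flag until the pull has finished
status = data_source_s3.pull()
while not status.is_ready():
    time.sleep(1)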

An example to display an interactive progress bar for an s3 connection:

data_source_s3 = DataSource.s3(
    uri="s3://bucket/events.parquet"
)
data_source_s3.pull().progress_bar()

dtypes: DataFrame property readonly

Return the data type of each column for a data source.

Returns:

Type Description
DataFrame

A pandas dataframe that contains the column names and their data types.

Exceptions:

Type Description
ValueError

If the input parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

An example to check the dtypes of the connected datasource:

data_source_s3 = DataSource.s3(
    uri="s3://bucket/events.parquet"
)
data_source_s3.pull().progress_bar()
data_source_s3.dtypes

__init__(self, data_id, data_type=None, csv_pull_params={}) special

Constructs a new DataSource instance.

Warning

Do not construct this object directly by calling the constructor; please use the DataSource.csv, DataSource.db, or DataSource.s3 functions instead.

Parameters:

Name Type Description Default
data_id int

ID of the data source in the airt service.

required
data_type Optional[str]

The type of the data source in the airt service.

None
csv_pull_params dict

Additional parameters for processing the csv file in airt service. This includes deduplicate_data, index_column, sort_by, blocksize, kwargs_json.

{}
Source code in airt/client.py
def __init__(
    self, data_id: int, data_type: Optional[str] = None, csv_pull_params: dict = {}
):
    """Constructs a new DataSource instance.

    Warning:
        Do not construct this object directly by calling the constructor; please use the `DataSource.csv`, `DataSource.db`, or `DataSource.s3` functions instead.

    Args:
        data_id: ID of the data source in the airt service.
        data_type: The type of the data source in the airt service.
        csv_pull_params: Additional parameters for processing the csv file in airt service. This includes
            deduplicate_data, index_column, sort_by, blocksize, kwargs_json.
    """
    self.id = data_id
    self.data_type = data_type
    self.csv_pull_params = csv_pull_params

csv(file_path, index_column, sort_by, tag_name=None, deduplicate_data=False, blocksize='256MB', kwargs_json=None) staticmethod

Create and return an object that encapsulates the data from a local CSV file.

Parameters:

Name Type Description Default
file_path Union[str, pathlib.Path]

The relative or absolute path of the local CSV file as a string or a Path object.

required
index_column str

The column name to use as the row labels.

required
sort_by str

The column name to sort the data by.

required
tag_name Optional[str]

The tag name for the datasource.

None
deduplicate_data bool

A boolean flag to handle duplicate rows in the CSV file. If set to True (default value False), the duplicate rows will be removed from the CSV file before uploading it to the airt server.

False
blocksize str

The number of bytes by which to cut up larger files. If None, the default value 256MB will be used for each file.

'256MB'
kwargs_json Optional[str]

Any additional parameters for the Pandas read_csv method, as a JSON string.

None

Returns:

Type Description
DataSource

An instance of the DataSource class. For more information on the methods that are available in the returned object, please check the documentation of the DataSource class.

Exceptions:

Type Description
ValueError

If the CSV upload to S3 is unsuccessful.

ConnectionError

If the server address is invalid or not reachable.

An example function call to the DataSource.csv:

    data_source_csv = DataSource.csv(
        file_path=csv_file_path,
        index_column=index_column_name,
        sort_by=sort_by_column_name
    )
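
The kwargs_json parameter forwards extra options to the pandas read_csv call as a JSON string; a minimal sketch, assuming a semicolon-separated file (the file path and column names are placeholders):

    import json

    # hypothetical file and column names; "sep" is forwarded to pandas read_csv
    data_source_csv = DataSource.csv(
        file_path="events.csv",
        index_column="user_id",
        sort_by="timestamp",
        kwargs_json=json.dumps({"sep": ";"})
    )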
Source code in airt/client.py
@staticmethod
def csv(
    file_path: Union[str, Path],
    index_column: str,
    sort_by: str,
    tag_name: Optional[str] = None,
    deduplicate_data: bool = False,
    blocksize: str = "256MB",
    kwargs_json: Optional[str] = None,
) -> "DataSource":
    """Create and return an object that encapsulates the data from a local CSV file.

    Args:
        file_path: The relative or absolute path of the local CSV file as a string or a Path object.
        index_column: The column name to use as the row labels.
        sort_by: The column name to sort the data by.
        tag_name: The tag name for the datasource.
        deduplicate_data: A boolean flag to handle the duplicate rows in the CSV file. If set to **True** (default value **False**),
            the duplicate rows will be removed from the CSV file before uploading to airt server.
        blocksize: The number of bytes by which to cut up larger files. If None, the default value **256MB** will be used for each file.
        kwargs_json: Any additional parameters for Pandas read_csv method as a JSON string.

    Returns:
        An instance of the `DataSource` class. For more information on the methods that are available in
        the returned object, please check the documentation of the `DataSource` class.

    Raises:
        ValueError: If the csv upload to s3 is unsuccessful.
        ConnectionError: If the server address is invalid or not reachable.

    An example function call to the DataSource.csv:

    ```python
        data_source_csv = DataSource.csv(
            file_path=csv_file_path,
            index_column=index_column_name,
            sort_by=sort_by_column_name
        )
    ```
    """
    # Step 1: get presigned URL
    response = Client.post_data(relative_url=f"/data/csv", data=dict(tag=tag_name))

    # Step 2: upload the csv to the s3 bucket
    files = {"file": open(Path(file_path), "rb")}

    s3_response = requests.post(
        response["presigned"]["url"],
        files=files,
        data=response["presigned"]["fields"],
    )

    if not s3_response.status_code == 204:
        raise ValueError(s3_response.text)

    csv_pull_params = dict(
        deduplicate_data=deduplicate_data,
        index_column=index_column,
        sort_by=sort_by,
        blocksize=blocksize,
        kwargs_json=kwargs_json,
    )
    return DataSource(
        data_id=response["id"],
        data_type=response["type"],
        csv_pull_params=csv_pull_params,
    )

db(*, host, database, table, port=3306, engine='mysql', username=None, password=None, tag=None) staticmethod

Create and return an object that encapsulates the data from a database.

A static method that creates and returns an object that encapsulates the data from a database. If access to the database requires authentication, the username and password will be read either from the arguments or from the airt environment variables.

The objects created by calling this method won't establish the connection yet.

Parameters:

Name Type Description Default
host str

The name of the remote database host machine.

required
database str

The logical name of the database to establish the connection.

required
table str

The name of the table in the database.

required
port int

The port for the database server. If the value is not passed then the default port number will be used (e.g. for MySQL, 3306 will be used).

3306
engine str

The name of the database engine. If the value is not passed then the default database engine for MySQL will be used.

'mysql'
username Optional[str]

A valid database username. If not set, the value of the environment variable AIRT_CLIENT_DB_USERNAME will be used; if that is also unset, the default value "root" will be used.

None
password Optional[str]

The password for the specified user. If not set, the value of the environment variable AIRT_CLIENT_DB_PASSWORD will be used; if that is also unset, the default value "" (empty string) will be used.

None
tag Optional[str]

The tag name for the data source. If None (default value), then the tag latest will be assigned to the data source.

None

Returns:

Type Description
DataSource

An instance of the DataSource class. For more information on the methods that are available in the returned object, please check the documentation of the DataSource class.

Exceptions:

Type Description
ValueError

If the required parameters are empty or None.

ValueError

If the required parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

An example function call to the DataSource.db:

data_source = DataSource.db(
    host="db.staging.airt.ai",
    database="test",
    table="events"
)
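
If the credentials cannot be provided through environment variables, they can be passed directly; a sketch with placeholder values for username and password:

# placeholder credentials; replace with a valid username and password
data_source_db = DataSource.db(
    host="db.staging.airt.ai",
    database="test",
    table="events",
    username="db_username",
    password="db_password"
)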
Source code in airt/client.py
@staticmethod
def db(
    *,
    host: str,
    database: str,
    table: str,
    port: int = 3306,
    engine: str = "mysql",
    username: Optional[str] = None,
    password: Optional[str] = None,
    tag: Optional[str] = None,
) -> "DataSource":
    """Create and return an object that encapsulates the data from a database.

    A static method that creates and returns an object that encapsulates the data
    from a database. If access to the database requires authentication,
    the username and password will be read either from the arguments or from the
    airt environment variables.

    The objects created by calling this method won't establish the connection yet.

    Args:
        host: The name of the remote database host machine.
        database: The logical name of the database to establish the connection.
        table: The name of the table in the database.
        port: The port for the database server. If the value is not passed then the
            default port number will be used (e.g. for MySQL, 3306 will be used).
        engine: The name of the database engine. If the value is not passed
            then the default database engine for MySQL will be used.
        username: A valid database username. If not set, the value of the environment
            variable AIRT_CLIENT_DB_USERNAME will be used, falling back to "root".
        password: The password for the specified user. If not set, the value of the
            environment variable AIRT_CLIENT_DB_PASSWORD will be used, falling back to "" (empty string).
        tag: The tag name for the data source. If "None" (default value), then the tag **latest**
            will be assigned to the data source.

    Returns:
        An instance of the `DataSource` class. For more information on the methods that
        are available in the returned object, please check the documentation of the
        `DataSource` class.

    Raises:
        ValueError: If the required parameters are empty or None.
        ValueError: If the required parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    An example function call to the DataSource.db:

    ```python
    data_source = DataSource.db(
        host="db.staging.airt.ai",
        database="test",
        table="events"
    )
    ```
    """
    username = (
        username
        if username is not None
        else os.environ.get("AIRT_CLIENT_DB_USERNAME", "root")
    )

    password = (
        password
        if password is not None
        else os.environ.get("AIRT_CLIENT_DB_PASSWORD", "")
    )

    _body = dict(
        host=host,
        port=port,
        username=username,
        password=password,
        database=database,
        table=table,
        database_server=engine,
        tag=tag,
    )

    response = Client.post_data(relative_url=f"/data/db", data=_body)

    return DataSource(data_id=response["id"], data_type=response["type"])

delete(data_id) staticmethod

Delete a datasource from the airt service.

Parameters:

Name Type Description Default
data_id int

The id of the data in the airt service.

required

Returns:

Type Description
DataFrame

A pandas DataFrame encapsulating the details of the deleted data id.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to delete a data source (id=1) from airt service:

DataSource.delete(data_id=1)
Source code in airt/client.py
@staticmethod
def delete(data_id: int) -> pd.DataFrame:
    """Delete a datasource from airt service

    Args:
        data_id: The id of the data in the airt service.

    Returns:
        A pandas DataFrame encapsulating the details of the deleted data id

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to delete a data source (id=1) from airt service:

    ```python
    DataSource.delete(data_id=1)
    ```
    """

    response = Client.delete_data(relative_url=f"/data/{data_id}")

    response["tags"] = get_tag_str(response["tags"])

    columns = DataSource._get_columns()

    datasource_df = pd.DataFrame(response, index=[0])[columns]

    return add_ready_column(datasource_df)

details(data_id) staticmethod

Return the details of a data source.

Parameters:

Name Type Description Default
data_id int

The id of the data in the airt service.

required

Returns:

Type Description
DataFrame

A pandas dataframe with the details of the data source.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to get details of a data source (id=1) from airt service:

DataSource.details(data_id=1)
Source code in airt/client.py
@staticmethod
def details(data_id: int) -> pd.DataFrame:
    """Return details of a data source

    Args:
        data_id: The id of the data in the airt service.

    Returns:
        A pandas dataframe with the details of the data source.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to get details of a data source (id=1) from airt service:

    ```python
    DataSource.details(data_id=1)
    ```
    """

    details = Client.get_data(relative_url=f"/data/{data_id}")

    details["tags"] = get_tag_str(details["tags"])

    additional_cols = ["user_id", "error", "disabled"]

    columns = DataSource._get_columns() + additional_cols

    details_df = pd.DataFrame(details, index=[0])[columns]

    return add_ready_column(details_df)

head(self)

A function to display the first few records of the data source.

After successfully pulling the data into the server, the head function can be used to display the first few records of the downloaded data.

Returns:

Type Description
DataFrame

A pandas dataframe that displays the first few records of the connected data source.

Exceptions:

Type Description
ValueError

If the input parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

An example to show the first few records of the data source:

Client.get_token()
data_source_s3 = DataSource.s3(
    uri="s3://test-airt-service/ecommerce_behavior"
)
data_source_s3.pull().progress_bar()

data_source_s3.head()
Source code in airt/client.py
@patch
def head(self: DataSource) -> pd.DataFrame:
    """A function to display the first few records of the data source.

    After successfully pulling the data into the server, the head function can be used
    to display the first few records of the downloaded data.

    Returns:
        A pandas dataframe that displays the first few records of the connected data source.

    Raises:
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    An example to show the first few records of the data source:

    ```python

    Client.get_token()
    data_source_s3 = DataSource.s3(
        uri="s3://test-airt-service/ecommerce_behavior"
    )
    data_source_s3.pull().progress_bar()

    data_source_s3.head()
    ```
    """
    response = Client.get_data(relative_url=f"/data/{int(self.id)}/head")
    return pd.DataFrame(response)

ls(offset=0, limit=100, disabled=False, completed=False) staticmethod

Display the list of available datasources.

Parameters:

Name Type Description Default
offset int

The number of rows to offset at the beginning of the datasource list from the server. If None, then the default value 0 will be used.

0
limit int

The maximum number of rows to return from the server. If None, then the default value 100 will be used.

100
disabled bool

If set to True, then only the deleted datasources will be displayed. Else, the default value False will be used to display only the list of active datasources.

False
completed bool

If set to True, then only the datasources that are successfully downloaded to airt server will be displayed. Else, the default value False will be used to display all the datasources including the ones that are created but not yet pulled into the airt server.

False

Returns:

Type Description
DataFrame

A pandas dataframe with the list of available datasources.

Exceptions:

Type Description
ConnectionError

If the server address is invalid or not reachable.

An example to list the available datasources:

DataSource.ls()
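
The offset, limit, disabled, and completed parameters can be combined to page through and filter the list; a minimal sketch:

# second page of 10 datasources, restricted to fully pulled ones
DataSource.ls(offset=10, limit=10, completed=True)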
Source code in airt/client.py
@staticmethod
def ls(
    offset: int = 0,
    limit: int = 100,
    disabled: bool = False,
    completed: bool = False,
) -> pd.DataFrame:
    """Display the list of available datasources.

    Args:
        offset: The number of rows to offset at the beginning of the datasource
            list from the server. If **None**, then the default value **0** will be used.
        limit: The maximum number of rows to return from the server. If **None**,
            then the default value **100** will be used.
        disabled: If set to **True**, then only the deleted datasources will be displayed.
            Else, the default value **False** will be used to display only the list
            of active datasources.
        completed: If set to **True**, then only the datasources that are successfully downloaded
            to airt server will be displayed. Else, the default value **False** will be used to
            display all the datasources including the ones that are created but not yet pulled
            into the airt server.

    Returns:
        A pandas dataframe with the list of available datasources.

    Raises:
        ConnectionError: If the server address is invalid or not reachable.

    An example to list the available datasources:

    ```python
    DataSource.ls()
    ```
    """
    lists = Client.get_data(
        relative_url=f"/data/?disabled={disabled}&completed={completed}&offset={offset}&limit={limit}"
    )

    for _list in lists:
        _list["tags"] = get_tag_str(_list["tags"])

    columns = DataSource._get_columns()

    lists_df = generate_df(lists, columns)

    return add_ready_column(lists_df)

pull(self)

A function to establish the connection with the data source.

The pull method establishes the connection with the specified DataSource and pulls the data into the server for further processing. The call to this method is asynchronous and the progress of the connection can be checked using the progress bar or the status flag. Please refer to DataSource class documentation for more information.

Returns:

Type Description
ProgressStatus

An instance of ProgressStatus class. ProgressStatus is a base class for querying status of a remote operation. For more information please refer to ProgressStatus class documentation.

Exceptions:

Type Description
ValueError

If the input parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

The example below shows establishing a connection with an s3 bucket:

Client.get_token()
data_source_s3 = DataSource.s3(
    uri="s3://test-airt-service/ecommerce_behavior"
)

data_source_s3.pull().progress_bar()
Source code in airt/client.py
@patch
def pull(self: DataSource) -> ProgressStatus:
    """A function to establish the connection with the data source.

    The pull method establishes the connection with the specified `DataSource` and pulls the
    data into the server for further processing. The call to this method is asynchronous and the progress of the connection can be checked
    using the progress bar or the status flag. Please refer to `DataSource` class documentation for more information.

    Returns:
        An instance of `ProgressStatus` class. `ProgressStatus` is a base class for querying status of a remote operation. For more information
        please refer to `ProgressStatus` class documentation.

    Raises:
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    Below example shows establishing a connection with the s3 bucket:

    ```python

    Client.get_token()
    data_source_s3 = DataSource.s3(
        uri="s3://test-airt-service/ecommerce_behavior"
    )

    data_source_s3.pull().progress_bar()
    ```
    """
    if self.data_type == "csv":
        Client.post_data(
            relative_url=f"/data/{int(self.id)}/csv/pull", data=self.csv_pull_params
        )
    else:
        Client.get_data(relative_url=f"/data/{int(self.id)}/pull")

    return ProgressStatus(relative_url=f"/data/{int(self.id)}")

push(self, predictions)

A function to push the prediction results into the target data source.

For more information on the supported datasources, please refer to the documentation on DataSource class.

Parameters:

Name Type Description Default
predictions Prediction

An instance of the Prediction class.

required

Exceptions:

Type Description
ValueError

If the input parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

The example below illustrates pushing the prediction results to a target S3 bucket:

from datetime import timedelta

Client.get_token()

data_source_s3 = DataSource.s3(
    uri="s3://test-airt-service/ecommerce_behavior"
)
data_source_s3.pull().progress_bar()

model = data_source_s3.train(
    client_column="user_id",
    target_column="event_type",
    target="*purchase",
    predict_after=timedelta(hours=3),
)

predictions = model.predict()

data_source_pred = DataSource.s3(
    uri="s3://target-bucket"
)

progress = data_source_pred.push(predictions)
progress.progress_bar()
Source code in airt/client.py
@patch
def push(self: DataSource, predictions: Prediction):
    """A function to push the prediction results into the target data source.

    For more information on the supported datasources, please refer to the documentation on `DataSource` class.

    Args:
        predictions: An instance of the `Prediction` class.

    Raises:
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    The example below illustrates pushing the prediction results to a target S3 bucket:

    ```python
    from datetime import timedelta

    Client.get_token()

    data_source_s3 = DataSource.s3(
        uri="s3://test-airt-service/ecommerce_behavior"
    )
    data_source_s3.pull().progress_bar()

    model = data_source_s3.train(
        client_column="user_id",
        target_column="event_type",
        target="*purchase",
        predict_after=timedelta(hours=3),
    )

    predictions = model.predict()

    data_source_pred = DataSource.s3(
        uri="s3://target-bucket"
    )

    progress = data_source_pred.push(predictions)
    progress.progress_bar()
    ```
    """
    return predictions.push(data_source=self)

s3(*, uri, access_key=None, secret_key=None, tag=None) staticmethod

Create and return an object that encapsulates the data from an AWS S3 bucket.

Parameters:

Name Type Description Default
uri str

The AWS S3 bucket location of the Parquet files as a string.

required
access_key Optional[str]

The access key for the S3 bucket. If None (default value), then the value of environment variable AWS_ACCESS_KEY_ID is used.

None
secret_key Optional[str]

The secret key for the S3 bucket. If None (default value), then the value of environment variable AWS_SECRET_ACCESS_KEY is used.

None
tag Optional[str]

The tag name for the data source. If None (default value), then the tag latest will be assigned to the data source.

None

Returns:

Type Description
DataSource

An instance of the DataSource class. For more information on the methods that are available in the returned object, please check the documentation of the DataSource class.

Exceptions:

Type Description
ValueError

If the required parameter uri is empty or None.

ValueError

If the input parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

An example function call to the DataSource.s3:

    data_source_s3 = DataSource.s3(
        uri="s3://bucket/events.parquet"
    )
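
If the AWS credentials are not available as environment variables, they can be passed explicitly, together with a tag; a sketch with placeholder key values:

    # placeholder credentials and tag; replace with valid values
    data_source_s3 = DataSource.s3(
        uri="s3://bucket/events.parquet",
        access_key="<access-key>",
        secret_key="<secret-key>",
        tag="v1.0"
    )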
Source code in airt/client.py
@staticmethod
def s3(
    *,
    uri: str,
    access_key: Optional[str] = None,
    secret_key: Optional[str] = None,
    tag: Optional[str] = None,
) -> "DataSource":
    """Create and return an object that encapsulates the data from a AWS S3 bucket.

    Args:
        uri: The AWS S3 bucket location of the Parquet files as a string.
        access_key: The access key for the S3 bucket. If **None** (default value), then the value
            of environment variable AWS_ACCESS_KEY_ID is used.
        secret_key: The secret key for the S3 bucket. If **None** (default value), then the value
            of environment variable AWS_SECRET_ACCESS_KEY is used.
        tag: The tag name for the data source. If **None** (default value), then the tag **latest**
            will be assigned to the data source.

    Returns:
        An instance of the `DataSource` class. For more information on the methods that are available in
        the returned object, please check the documentation of the `DataSource` class.

    Raises:
        ValueError: If the required parameter uri is empty or None.
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    An example function call to the DataSource.s3:

    ```python
        data_source_s3 = DataSource.s3(
            uri="s3://bucket/events.parquet"
        )
    ```
    """
    access_key = (
        access_key if access_key is not None else os.environ["AWS_ACCESS_KEY_ID"]
    )
    secret_key = (
        secret_key
        if secret_key is not None
        else os.environ["AWS_SECRET_ACCESS_KEY"]
    )

    response = Client.post_data(
        relative_url="/data/s3",
        data=dict(uri=uri, access_key=access_key, secret_key=secret_key, tag=tag),
    )

    return DataSource(data_id=response["id"], data_type=response["type"])

tag(data_id, name) staticmethod

Tag an existing datasource in airt service.

Parameters:

Name Type Description Default
data_id int

The id of the datasource in airt service for tagging.

required
name str

The tag name for the datasource.

required

Returns:

Type Description
DataFrame

A pandas dataframe with the details of the newly tagged datasource.

Exceptions:

Type Description
ValueError

In case of an invalid data_id.

ConnectionError

If the server address is invalid or not reachable.

An example to tag an existing data source (id=1) in the airt service:

DataSource.tag(data_id=1, name="v1.0")
Source code in airt/client.py
@staticmethod
def tag(data_id: int, name: str) -> pd.DataFrame:
    """Tag an existing datasource in airt service.

    Args:
        data_id: The id of the datasource in airt service for tagging.
        name: The tag name for the datasource.

    Returns:
        A pandas dataframe with the details of the newly tagged datasource.

    Raises:
        ValueError: In case of an invalid data_id.
        ConnectionError: If the server address is invalid or not reachable.

    An example to tag an existing data source (id=1) in the airt service:

    ```python
    DataSource.tag(data_id=1, name="v1.0")
    ```
    """
    response = Client.post_data(
        relative_url=f"/data/{data_id}/tag", data=dict(name=name)
    )

    response["tags"] = get_tag_str(response["tags"])

    columns = DataSource._get_columns()

    tag_df = pd.DataFrame(response, index=[0])[columns]

    return add_ready_column(tag_df)

train(self, *, client_column, timestamp_column=None, target_column, target, predict_after)

A method to train the ML model on the connected DataSource.

This method trains the model to predict which clients are most likely to have a specified event in the future. The call to this method is asynchronous and the progress of the training can be checked using the progress bar method or the status flag attribute available in the DataSource class. For more information on the model, please check the documentation of the Model class.

Parameters:

Name Type Description Default
client_column str

The name of the column that uniquely identifies the users/clients, as a string.

required
timestamp_column Optional[str]

The name of the timestamp column specifying the time of an occurred event, as a string. If the value is not passed, then None will be used.

None
target_column str

The name of the target column that captures the type of event, as a string. This will be used for training the model as well as for making predictions for our target event.

required
target str

The name of the target event for which the model needs to be trained to make predictions. You can also pass regular expressions to this parameter to make predictions for more than one event. For example, passing "*checkout" will train a model to predict which users will perform any kind of checkout event.

required
predict_after timedelta

The time period after which the target event is expected, specified as a timedelta.

required

Returns:

Type Description
Model

An instance of the Model class.

Exceptions:

Type Description
ValueError

If any of the required parameters are empty or None.

ValueError

If the input parameters to the API are invalid.

ConnectionError

If the server address is invalid or not reachable.

Below is an example of training a model to predict which users will perform a purchase event ("*purchase") 3 hours before they actually do it:

from datetime import timedelta

model = data_source_s3.train(
    client_column="user_id",
    target_column="event_type",
    target="*purchase",
    predict_after=timedelta(hours=3)
)

model.progress_bar()
Source code in airt/client.py
@patch
def train(
    self: DataSource,
    *,
    client_column: str,
    timestamp_column: Optional[str] = None,
    target_column: str,
    target: str,
    predict_after: timedelta,
) -> Model:
    """A method to train the ML model on the connected `DataSource`.

    This method trains the model to predict which clients are most likely to have a specified
    event in the future. The call to this method is asynchronous and the progress of the training
    can be checked using the progress bar method or the status flag attribute available in the `DataSource` class.
    For more information on the model, please check the documentation of the `Model` class.

    Args:
        client_column: The name of the column that uniquely identifies the users/clients, as a string.
        timestamp_column: The name of the timestamp column specifying the time of an
            occurred event, as a string. If the value is not passed, then None will be used.
        target_column: The name of the target column that captures the type of event, as a string. This will
            be used for training the model as well as for making predictions for our target event.
        target: The name of the target event for which the model needs to be trained to make predictions.
            You can also pass regular expressions to this parameter to make predictions for more than one event.
            For example, passing "*checkout" will train a model to predict which users will perform any kind of
            checkout event.
        predict_after: The time period after which the target event is expected, specified as a timedelta.

    Returns:
        An instance of the `Model` class.

    Raises:
        ValueError: If any of the required parameters are empty or None.
        ValueError: If the input parameters to the API are invalid.
        ConnectionError: If the server address is invalid or not reachable.

    Below is an example of training a model to predict which users will perform a purchase event ("*purchase") 3 hours before they actually do it:

    ```python
    from datetime import timedelta

    model = data_source_s3.train(
        client_column="user_id",
        target_column="event_type",
        target="*purchase",
        predict_after=timedelta(hours=3)
    )

    model.progress_bar()
    ```
    """
    response = Client.post_data(
        relative_url=f"/model/train",
        data=dict(
            data_id=int(self.id),
            client_column=client_column,
            target_column=target_column,
            target=target,
            predict_after=int(predict_after.total_seconds()),
        ),
    )

    return Model(model_id=response["id"])