Train and Generate Jobs

Methods for submitting jobs to Gretel workers

from gretel_client import Gretel

gretel = Gretel(project_name="sdk-docs", api_key="prompt")

With the Gretel object instance ready to go, you can use its submit_* methods to submit model training and data generation jobs. Behind the scenes, Gretel will spin up workers with the necessary compute resources, set up the model with your desired configuration, and perform the submitted task.

Submit a model training job

The submit_train method submits a model training job based on the given model configuration. The data source for the training job is passed in using the data_source argument and may be a file path or pandas DataFrame:

trained = gretel.submit_train("tabular-actgan", data_source="data.csv")

This example trains an ACTGAN model by passing "tabular-actgan" as the base_config argument (given positionally above). You can replace this base config with the path to a custom config file, or you can select any of the config names listed here (excluding the .yml extension). The returned trained object is a dataclass that contains the training job results, such as the Gretel model object, the synthetic data quality report, the training logs, and the final model configuration.

Dynamically modify model configs using keyword arguments

The base configuration can be modified using keyword arguments with the following rules:

  • Nested model settings can be passed as keyword arguments in the submit_train method, where the keyword is the name of the config subsection and the value is a dictionary with the desired subsection's parameter settings. For example, this is how you update settings in ACTGAN's params and privacy_filters subsections, where epochs, discriminator_dim, similarity, and outliers are nested settings:

trained = gretel.submit_train(
    base_config="tabular-actgan",
    data_source="tabular-dataset.csv",
    params={"epochs": 800, "discriminator_dim": [1024, 1024, 1024]},
    privacy_filters={"similarity": "high", "outliers": None},
)

  • Non-nested model settings can be passed directly as keyword arguments in the submit_train method. For example, this is how you update Gretel GPT's pretrained_model and column_name, which are not nested within a subsection:

trained = gretel.submit_train(
    base_config="natural-language",
    data_source="text-dataset.csv",
    pretrained_model="gretelai/mpt-7b",
    column_name="name_of_text_column",
    params={"batch_size": 16, "steps": 500},
    generate={"num_records": 100, "temperature": 0.8}
)
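
To make the two rules concrete, here is a minimal pure-Python sketch of how keyword overrides could be layered onto a base config. It illustrates the rules only and is not Gretel's actual implementation: the apply_overrides helper and the sample base dictionary are hypothetical, and the sketch assumes that a subsection dictionary is merged into the existing subsection rather than replacing it.

```python
from copy import deepcopy


def apply_overrides(base_config: dict, **overrides) -> dict:
    """Return a copy of base_config with keyword overrides applied.

    Hypothetical helper: a dictionary value updates the matching
    subsection, any other value sets the top-level key directly.
    """
    config = deepcopy(base_config)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(config.get(key), dict):
            # Nested rule: merge the given settings into the subsection.
            config[key].update(value)
        else:
            # Non-nested rule: set the value directly.
            config[key] = value
    return config


# Made-up base config for illustration; real Gretel configs have more fields.
base = {
    "pretrained_model": "default-model",
    "column_name": None,
    "params": {"batch_size": 4, "steps": 750},
}

updated = apply_overrides(
    base,
    pretrained_model="gretelai/mpt-7b",
    params={"steps": 500},
)
print(updated["params"])  # {'batch_size': 4, 'steps': 500}
```

Settings that are not overridden (batch_size here) keep their base values, and the base dictionary itself is left unchanged.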

Submit a synthetic data generation job

Once you have models in your Gretel Project, you can use any of them to generate synthetic data using the submit_generate method:

generated = gretel.submit_generate(trained.model_id, num_records=100)

Above we use the model_id attribute of a completed training job, but you are free to use the model_id of any model within the current project. If the model has additional generate settings (e.g., temperature when generating text), you can pass them as keyword arguments to the submit_generate method. The returned generated object is a dataclass that contains results from the generation job, including the generated synthetic data.

A model's model_id can be extracted from the model's URL in the Console: {base_console_url}/{project_id}/models/{model_id}.
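
If you have the Console URL on hand, the model_id can be pulled out with the standard library. The snippet below is a small sketch that assumes only the URL pattern shown above; the helper name model_id_from_console_url and the example host and IDs are invented for illustration.

```python
from urllib.parse import urlparse


def model_id_from_console_url(url: str) -> str:
    """Extract the model ID from a URL ending in /{project_id}/models/{model_id}."""
    # Split the path into non-empty segments, ignoring query strings.
    segments = [s for s in urlparse(url).path.split("/") if s]
    if "models" not in segments:
        raise ValueError(f"no 'models' segment in URL: {url}")
    index = segments.index("models")
    if index + 1 >= len(segments):
        raise ValueError(f"no model ID after 'models' in URL: {url}")
    return segments[index + 1]


# Placeholder URL; the host and IDs are not real.
example_url = "https://console.example.com/proj_123/models/abc123def456"
print(model_id_from_console_url(example_url))
```

The returned string can then be passed as the first argument to submit_generate.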

Conditional data generation

In the previous example, we unconditionally generated num_records records. To conditionally generate synthetic data, use the seed_data argument:

import pandas as pd

generated = gretel.submit_generate(
    model_id=trained.model_id, 
    seed_data=pd.DataFrame({"field": ["seed"] * 50})
)

The above code conditionally generates 50 records, one per row of seed_data, each with the field column set to "seed".
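
Seed data does not have to repeat a single value. As a hedged sketch, the helper below builds seed rows for several values of one column; build_seed_rows and the values "seed_a" and "seed_b" are hypothetical, and whether a given model supports conditioning on a particular column depends on the model and how it was trained.

```python
def build_seed_rows(column: str, counts: dict) -> list:
    """Repeat each seed value the requested number of times.

    Returns a list of single-key dicts, one per seed row.
    """
    rows = []
    for value, n in counts.items():
        rows.extend({column: value} for _ in range(n))
    return rows


# 30 rows seeded with "seed_a" and 20 with "seed_b" (made-up values).
rows = build_seed_rows("field", {"seed_a": 30, "seed_b": 20})
print(len(rows))  # 50
```

A list of dicts like this converts directly to a DataFrame with pd.DataFrame(rows), which can then be passed as seed_data.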

⏳ To wait or not to wait?

If you do not want to wait for a job to complete, you can set wait=False when calling submit_train or submit_generate. In this case, the method will return immediately after the job starts:

trained = gretel.submit_train(
    "tabular-actgan", 
    data_source="data.csv",
    wait=False
)

Some things to know if you use this option:

  • You can still monitor the job progress in the Gretel Console.

  • You can check the job status using the job_status attribute of the returned object: print(trained.job_status).

  • You can continue waiting for the job to complete by calling the wait_for_completion method of the returned object: trained.wait_for_completion().

  • If the job finishes while you are not waiting on it, call the refresh method of the returned object to fetch the job results: trained.refresh().

Fetching previous job results

You can fetch results from previous training and generation jobs using the fetch_*_job_results methods:

trained = gretel.fetch_train_job_results(model_id)

generated = gretel.fetch_generate_job_results(model_id, record_id)

The record_id of a generation job can be found in the Records & Downloads section of the associated model's Console page. In the Records table, the records are listed as {model name}@{record_id}.
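
If you copy an entry from the Records table, the record_id is the part after the "@". The snippet below is a minimal sketch of that split; split_record_label and the sample label are hypothetical, and it assumes that record IDs never contain an "@" character.

```python
def split_record_label(label: str) -> tuple:
    """Split '{model name}@{record_id}' into (model_name, record_id)."""
    # Split on the last "@" so a model name containing "@" stays intact.
    model_name, sep, record_id = label.rpartition("@")
    if not sep or not record_id:
        raise ValueError(f"expected '{{model name}}@{{record_id}}', got: {label!r}")
    return model_name, record_id


# Made-up label for illustration.
model_name, record_id = split_record_label("my-actgan-model@rec_0123")
print(record_id)
```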
