Train and Generate Jobs

Methods for submitting jobs to Gretel workers

from gretel_client import Gretel

gretel = Gretel(project_name="sdk-docs", api_key="prompt")

With the Gretel object instance ready to go, you can use its submit_* methods to submit model training and data generation jobs. Behind the scenes, Gretel will spin up workers with the necessary compute resources, set up the model with your desired configuration, and perform the submitted task.

Submit a model training job

The submit_train method submits a model training job based on the given model configuration. The data source for the training job is passed in using the data_source argument and may be a file path or pandas DataFrame:

trained = gretel.submit_train("tabular-actgan", data_source="data.csv")

We trained an ACTGAN model by setting base_config="tabular-actgan". You can replace this base config with the path to a custom config file, or you can select any of the config names listed here (excluding the .yml extension). The returned trained object is a dataclass that contains the training job results such as the Gretel model object, synthetic data quality report, training logs, and the final model configuration.
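For example, you might inspect the results like this (a minimal sketch; the report, quality_scores, and model_config attribute names are assumptions based on the dataclass description above and may differ across SDK versions):

print(trained.model_id)               # ID of the trained Gretel model
print(trained.report.quality_scores)  # summary scores from the quality report (assumed attribute)
print(trained.model_config)           # final model configuration (assumed attribute)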

Dynamically modify model configs using keyword arguments

The base configuration can be modified using keyword arguments with the following rules:

  • Nested model settings can be passed as keyword arguments in the submit_train method, where the keyword is the name of the config subsection and the value is a dictionary with the desired subsection's parameter settings. For example, this is how you update settings in ACTGAN's params and privacy_filters subsections, where epochs, discriminator_dim, similarity, and outliers are nested settings:

trained = gretel.submit_train(
    base_config="tabular-actgan",
    data_source="tabular-dataset.csv",
    params={"epochs": 800, "discriminator_dim": [1024, 1024, 1024]},
    privacy_filters={"similarity": "high", "outliers": None},
)
  • Non-nested model settings can be passed directly as keyword arguments in the submit_train method. For example, this is how you update Gretel GPT's pretrained_model and column_name, which are not nested within a subsection (note that nested settings such as params and generate can be mixed into the same call):

trained = gretel.submit_train(
    base_config="natural-language",
    data_source="text-dataset.csv",
    pretrained_model="gretelai/gpt-auto",
    column_name="name_of_text_column",
    params={"batch_size": 16, "steps": 500},
    generate={"num_records": 100, "temperature": 0.8}
)

Submit a synthetic data generation job

Once you have models in your Gretel Project, you can use any of them to generate synthetic data using the submit_generate method:

generated = gretel.submit_generate(trained.model_id, num_records=100)

Above we use the model_id attribute of a completed training job, but you are free to use the model_id of any model within the current project. If the model has additional generate settings (e.g., temperature when generating text), you can pass them as keyword arguments to the submit_generate method. The returned generated object is a dataclass that contains results from the generation job, including the generated synthetic data.
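For example, the generated data can be pulled out as a pandas DataFrame (a minimal sketch; the synthetic_data attribute name is an assumption based on the description above):

df = generated.synthetic_data  # generated records as a DataFrame (assumed attribute)
print(df.head())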

A model's model_id can be extracted from the model's URL in the Console: {base_console_url}/{project_id}/models/{model_id}.

Conditional data generation

In the previous example, we unconditionally generated num_records records. To conditionally generate synthetic data, use the seed_data argument:

import pandas as pd

generated = gretel.submit_generate(
    model_id=trained.model_id,
    seed_data=pd.DataFrame({"field": ["seed"] * 50})
)

The above code will conditionally generate 50 records in which the field column has the value "seed".
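Assuming the same synthetic_data attribute as in the earlier sketch, you could verify that every generated record carries the seeded value:

# Every record should have field == "seed" (synthetic_data is an assumed attribute)
print(generated.synthetic_data["field"].unique())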

⏳ To wait or not to wait?

If you do not want to wait for a job to complete, you can set wait=False when calling submit_train or submit_generate. In this case, the method will return immediately after the job starts:

trained = gretel.submit_train(
    "tabular-actgan",
    data_source="data.csv",
    wait=False
)

Some things to know if you use this option (a combined sketch follows the list):

  • You can still monitor the job progress in the Gretel Console.

  • You can check the job status using the job_status attribute of the returned object: print(trained.job_status).

  • You can continue waiting for the job to complete by calling the wait_for_completion method of the returned object: trained.wait_for_completion().

  • If you are not waiting when the job completes, you must call the refresh method of the returned object to fetch the job results: trained.refresh().
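Putting these together, a non-blocking run might look like this (a minimal sketch that uses only the attributes and methods listed above):

trained = gretel.submit_train("tabular-actgan", data_source="data.csv", wait=False)

# The call returns as soon as the job starts; check in on it later
print(trained.job_status)

# Option 1: resume blocking until the job completes
trained.wait_for_completion()

# Option 2: if the job already finished while you were not waiting,
# fetch the results explicitly
trained.refresh()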

Submit a transforms job

Our transforms product allows you to remove PII from data, and you can submit these transform jobs from the high-level SDK. The default behavior is to use a model to classify the data and replace the detected entities with fake values.

transform = gretel.submit_transform(
    config="transform/default",
    data_source="data.csv",
)

Fetching previous job results

You can fetch results from previous training and generation jobs using the fetch_*_job_results methods:

trained = gretel.fetch_train_job_results(model_id)

generated = gretel.fetch_generate_job_results(model_id, record_id)

The record_id of a generation job can be found in the Records & Downloads section of the associated model's Console page. In the Records table, the records are listed as {model name}@{record_id}.

To fetch transform results, use the fetch_transform_results method. The returned object also exposes the transformed data as a DataFrame:

transform_result = gretel.fetch_transform_results(model_id)

print(transform_result.transformed_df)

Submit an Evaluate job

An Evaluate job analyzes the quality of synthetic data and generates the Data Quality Report.

The submit_evaluate method submits an evaluate job based on the given evaluate configuration. The synthetic data source is passed in using the data_source argument, and the original data source is passed with ref_data; either may be a file path or a pandas DataFrame:

evaluate_result = gretel.submit_evaluate(
    "evaluate/default",
    data_source="data.csv",
    ref_data="train.csv",
)

The test (holdout) data source for membership inference attack (MIA) scoring is passed with the optional test_data argument; it may also be a file path or a pandas DataFrame:

evaluate_result = gretel.submit_evaluate(
    "evaluate/default",
    data_source="data.csv",
    ref_data="train.csv",
    test_data="test.csv",
)
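Because both data sources accept DataFrames, an equivalent call can pass them in directly (a minimal sketch; the CSV file names are the placeholders from the examples above):

import pandas as pd

synthetic_df = pd.read_csv("data.csv")  # synthetic data to evaluate
train_df = pd.read_csv("train.csv")     # original training data

evaluate_result = gretel.submit_evaluate(
    "evaluate/default",
    data_source=synthetic_df,
    ref_data=train_df,
)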
