Gretel LSTM
Deep learning model that supports tabular, time-series, and natural language text data.
The Gretel LSTM model API is a generative data model that works with any language or character set, and is open-sourced as part of the gretel-synthetics library.
Model creation
The configuration below contains additional options for training a Gretel LSTM model, with the default options displayed.
These params are the same params available in Gretel's open-source synthetics package. When one of these parameters is null, the default value from the open-source package will be used. This helps ensure a similar experience when switching between open source and Gretel Cloud.
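Below is a minimal configuration sketch with these defaults spelled out. It assumes the standard Gretel synthetics YAML layout (schema_version / models / synthetics); the values simply mirror the defaults documented for each parameter below.

```yaml
# Sketch only: assumes the standard Gretel synthetics YAML layout.
# Values mirror the documented defaults; omit a key (or set it to null)
# to fall back to the open-source package default.
schema_version: "1.0"
models:
  - synthetics:
      data_source: __tmp__
      params:
        batch_size: 64
        vocab_size: 20000
        reset_states: false
        learning_rate: 0.01
        rnn_units: 256
        dropout_rate: 0.2
        field_cluster_size: 20
        early_stopping: true
        gen_temp: 1.0
        predict_batch_size: 64
        validation_split: false
        field_delimiter: null
```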
data_source (str, required) - __tmp__ or point to a valid and accessible file in CSV, JSON, or JSONL format.

batch_size (int, optional, defaults to 64) - Number of samples per gradient update. Using larger batch sizes can help make more efficient use of CPU/GPU parallelization, at the cost of memory.

vocab_size (int, optional, defaults to 20000) - The maximum vocabulary size for the tokenizer created by the unsupervised SentencePiece model. Set to 0 to use character-based tokenization.

reset_states (bool, optional, defaults to False) - Reset RNN model states between each generation run. This guarantees more consistent dataset creation over time, at the expense of model accuracy.

learning_rate (float, optional, defaults to 0.01) - The higher the learning rate, the more each update made during training matters.

rnn_units (int, optional, defaults to 256) - Positive integer, dimensionality of the output space for LSTM layers.

dropout_rate (float, optional, defaults to 0.2) - Float between 0 and 1. Fraction of the units to drop for the linear transformation of the inputs. Using dropout can help prevent overfitting by ignoring randomly selected neurons during training. 0.2 (20%) is often a good compromise between retaining model accuracy and preventing overfitting.

field_cluster_size (int, optional, defaults to 20) - The maximum number of fields (columns) to train per model batch.

early_stopping (bool, optional, defaults to True) - Automatically determine when the model is no longer improving and terminate training.

gen_temp (float, optional, defaults to 1.0) - Controls the randomness of predictions by scaling the logits before applying softmax. Low temperatures result in more predictable text; higher temperatures result in more surprising text. Experiment to find the best setting.

predict_batch_size (int, optional, defaults to 64) - How many words to generate in parallel. Higher values may result in increased throughput. The default of 64 should provide reasonable performance for most users.

validation_split (bool, optional, defaults to False) - Use a fraction of the training data as validation data. Using a validation set is recommended, as it helps prevent over-fitting and model memorization. When enabled, 20% of the data will be used for validation.

field_delimiter (string, optional, defaults to null) - Because the LSTM trains on entire rows as text data, a unique character that does not exist in the dataset must be used as a field delimiter. By default, Gretel will attempt to automatically determine a field delimiter from the list below. If Gretel cannot find a suitable delimiter, the model training job will fail with an appropriate error. The following delimiters are tried by the model; you can also specify a custom delimiter: [",", ":", "|", ";", "\t", "+", "-", "^", "@", "#", "*"]
Differential privacy
Differential privacy for the LSTM model is now deprecated. We suggest using the Tabular DP model.
Fallback
Occasionally, powerful deep learning models such as Gretel LSTM are unable to complete training or generate enough valid records within a reasonable timeframe. Instead of failing with an error, Gretel LSTM will by default fall back to generating records using a simpler and much faster statistical model based on Amplify. This behavior increases the chances of success, at the cost of moderately reduced output quality in the event of a fallback. You can use the optional fallback (dict) param to control this behavior.
Smart seeding is not currently compatible with the fallback model. If smart seeding is enabled, the fallback behavior will be automatically turned off.
allow (bool, optional, defaults to True) - If True, fall back to a simpler model if needed to complete Gretel LSTM training and valid record generation.

max_fallback_training_time_seconds (int, optional, defaults to 300) - Maximum duration to spend training the fallback model. Most datasets take less than 2 minutes, although a higher limit can increase the chances of successful training and record generation.

max_fallback_generation_time_seconds (int, optional, defaults to 300) - Maximum duration to allocate to generating records with the fallback model if Gretel LSTM training is not likely to complete within the maximum runtime for model creation (controlled by user limits and the max_runtime_seconds parameter). If Gretel LSTM training does not complete at least max_fallback_generation_time_seconds before the maximum runtime is reached, a fallback model (not Gretel LSTM) will be created and used for record generation.
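As a sketch (same assumed YAML layout as above), the fallback behavior can be tuned through the fallback dict; the values shown are the documented defaults.

```yaml
# Sketch only: fallback settings under the same assumed synthetics layout.
models:
  - synthetics:
      data_source: __tmp__
      fallback:
        allow: true   # set to false to disable the fallback entirely
        max_fallback_training_time_seconds: 300
        max_fallback_generation_time_seconds: 300
```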
Smart seeding
When using conditional data generation (smart seeding), you must provide the field names that you wish to use as seeds for generating records at model creation time. This is done by specifying a seed task at model training time.
Example configuration to enable Smart seeding:
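A minimal sketch, assuming a seed task under the same synthetics YAML layout; seed_column_1 and seed_column_2 are placeholder names for columns in your training data.

```yaml
# Sketch only: seed_column_1 / seed_column_2 are placeholder column names.
models:
  - synthetics:
      data_source: __tmp__
      task:
        type: seed
        attrs:
          fields:
            - seed_column_1
            - seed_column_2
```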
Data generation
generate.num_records (int, optional, defaults to 5000) - The number of text outputs to generate.

generate.max_invalid (int, optional) - The number of records that can fail the Data Validation process before generation is stopped. This setting helps govern a long-running data generation process where the model is not producing optimal data. The default value is five times the number of records being generated.

data_source (str, optional) - Provide a series of seed columns in CSV format to use for conditional data generation, matching the columns provided in the Smart seeding model definition. This will override the num_records parameter, generating one record for each prompt in the data_source param. Must point to a valid and accessible file URL in CSV format.
If your training data contains fewer than 5000 records, we recommend setting the generate.num_records value to null. When this value is null, the number of records generated will be the lesser of 5000 or the total number of records in the training data.
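A sketch of the generate section under the same assumed layout; the values shown reflect the documented defaults.

```yaml
# Sketch only: generation settings under the same assumed synthetics layout.
models:
  - synthetics:
      data_source: __tmp__
      generate:
        num_records: 5000   # set to null to generate min(5000, training record count)
        max_invalid: null   # defaults to 5x num_records when null
```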
Automated validators
The Gretel LSTM model provides automatic data semantic validation when generating synthetic data, which can be configured using the validators tag. When a model is trained, the following validator models are built automatically on a per-field basis:
Character Set: For categorical and other string values, the underlying character sets are learned. For example, if the values in a field are all hexadecimal, then the generated data will only contain [0-9a-f] values. This validator is case sensitive.
String Length: For categorical and other string values, the maximum and minimum string lengths of a field’s values are learned. When generating data, the generated values will be between the minimum and maximum lengths.
Numerical Ranges: For numerical fields, the minimum and maximum values are learned. During generation, numerical values will be between these learned ranges. The following numerical types will be learned and enforced: Float, Base2, Base8, Base10, Base16.
Field Data Types: For fields that are entirely integers, strings, and floats, the generated data will be of the data type assigned to the field. If a field has mixed data types, then the field may be one of any of the data types, but the above value-based validators will still be enforced.
The validators above are composed automatically. This ensures that individual values in the synthetic data stay within the basic semantic constraints of the training data.
Configurable validators
In addition to the built-in validators, Gretel offers advanced validators that can be managed in the configuration:
in_set_count (int, optional, defaults to 10): This validator accumulates all of the unique values in a field. If the cardinality of the field's values is less than or equal to the setting, the validator will enforce that generated values are in the set of training values. If the cardinality of the field's values is greater than the setting, the validator will have no effect. For example, if there is a field called US-State with a cardinality of 50 and in_set_count is set to 50, each value generated for this field must be one of the original values. If in_set_count were set to only 40, the generated values would not be enforced.

pattern_count (int, optional, defaults to 10): This validator builds a pattern mask for each value in a field. Alphanumeric characters are masked, while other special characters are retained. For example, 867-5309 will be masked to ddd-dddd, and f32-sk-39d would mask to add-aa-dda, where a represents any A-Za-z character. Much like the previous validator, if the cardinality of learned patterns is less than or equal to the setting, patterns will be enforced during data generation. If the unique pattern count is above the setting, enforcement will be skipped.

use_numeric_iqr (bool, optional, defaults to True): IQR-based validation for all numeric fields. When enabled, it calculates the IQR for values in the field and uses that range to validate generated values. This validator is useful when the training data may contain undesirable outliers that skew the min and max values in a field. Numeric outliers in the synthetic data can impact both quality and privacy, and outlier values can be exploited by Membership Inference and other adversarial attacks.

open_close_chars (string or list of strings, optional, defaults to null): This validator may be used when values contain specific opening and closing characters around other values. For example, if there is a field named "Age" with a value of 143 (Months), the inclusion of the ( and ) characters around the "Months" string should be enforced. This validator will check for and enforce multiple nested open/close characters. It can check for any 2-tuple combination of open/close characters as well, so a more advanced usage might be to enforce a synthetic value such as Foo [Bar(baz), Fiz(bunch)]. If you wish to use this validator, you have two options: (1) set the value to default, which will automatically look for and enforce the following open/close pairs: "", (), [], {}; or (2) provide a list of strings, where each string must be exactly 2 characters long, defining custom open/close characters. Here is an example of using custom open/close chars: open_close_chars: ["()", "[]", "{}", "$$"]

allow_empty_values (bool, optional, defaults to True): During data generation, determine whether field values are allowed to be empty. If the training data used to create the model contained some empty fields, it is likely the model will generate empty field values as well. When this option is True (the default), empty values will be allowed for any field. If you want to enforce that there are no empty values in the generated data, you may set this option to False, which will enforce non-empty values for every field. NOTE: If a field contained empty values for every record (i.e. a field that was 100% missing or empty in the dataset), this setting will be ignored and empty values for that field will be allowed.
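Pulling the configurable validators together, here is a sketch under the same assumed layout; the values are the documented defaults, except open_close_chars, which reuses the custom list shown above.

```yaml
# Sketch only: configurable validators under the same assumed synthetics layout.
models:
  - synthetics:
      data_source: __tmp__
      validators:
        in_set_count: 10
        pattern_count: 10
        use_numeric_iqr: true
        allow_empty_values: true
        open_close_chars: ["()", "[]", "{}", "$$"]
```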
Data checks
The success rate for training Gretel LSTM models often depends on the quality of the input data. We run pre-flight checks on your data source to identify issues that might lead to model failure. These checks run automatically before training begins and display warnings, if any. They will not interrupt the training job, but they can be configured to do so, or to skip the checks entirely.
Some of the issues we see most frequently are:
Complex entities: Presence of UUIDs and hash values such as MD5, SHA256, etc.
Text data: When present with other fields, text fields increase the complexity and can trigger validation errors. Examples are user agent strings, raw log data, etc. We recommend using the Gretel GPT model for text data.
Large amount of missing data: When columns have a lot of missing data, the model can end up creating too few or too many columns because it has to generate a sequence of contiguous delimiter characters.
High floating point precision: Floats with long precision (e.g. 2.34473832) are difficult to learn and should be reduced, if possible, to 2-4 digits.
Sparse data: When several of the columns contain binary values, this creates challenges. This is usually because of converting data to a “one hot encoded” format. Ideally, these columns should be reverted back to a dense representation.
Very large or complex datasets: Try using our high-dimensionality configuration blueprint for large datasets. The Gretel Console automatically selects the configuration file, and we also provide a Trainer notebook that improves performance.
Whitespaces: Cell values with surrounding white spaces present challenges for our language model. We recommend trimming leading and trailing whitespaces.
Data checks run quickly, within seconds, before model training begins. They can help detect potential problems with the input data, thereby saving you time and making credit consumption more efficient.
If your configuration doesn't contain a data_checks section, the checks will run automatically with the default strategy. Here's a sample configuration which interrupts the training job if warnings are found.
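(A sketch under the same assumed synthetics layout; only the data_checks block is the point of this example.)

```yaml
# Sketch only: interrupt training if any data check warnings are found.
models:
  - synthetics:
      data_source: __tmp__
      data_checks:
        strategy: interrupt
```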
data_checks.strategy (str, optional, defaults to log): Strategy for running data checks before training. Supported values:

log (default) - runs checks and logs any warnings in the user log (these can be viewed in the Gretel Console, in the CLI, or downloaded as a model artifact).

skip - checks are not run at all.

interrupt - runs checks and, if there are any warnings, stops the training. All warnings are logged in the user log. Enable this option if you're using an unfamiliar dataset and are unsure how the model needs to be configured. The data check will help you better clean and prepare your dataset, and increase your chances of success.
Here’s an example of a data check which returned a series of warnings and interrupted training:
Model information
The underlying model used is a Long Short-Term Memory (LSTM) recurrent neural network. This model is initialized from random weights and trained on a dataset as an autoregressive language model, using cross-entropy loss.
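For reference (this is the generic autoregressive objective implied by the description above, not a Gretel-specific formula), training minimizes the cross-entropy of each token given the tokens that precede it in a row:

$$\mathcal{L}(\theta) = -\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$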
Minimum requirements
If running this system in local mode (on-premises), the following instance types are recommended.
CPU: Minimum 4 cores, 32GB RAM.
GPU (required): Minimum NVIDIA T4 or similar CUDA-compliant GPU with 16GB+ RAM is required to run basic language models.
Limitations and Biases
This model is trained entirely on the examples provided in the training dataset and will therefore capture and likely repeat any biases that exist in the training set. We recommend having a human review the dataset used to train models before using them in production.