
📊 TensorFlow 2.X implementation of Conditional Tabular Generative Adversarial Network.

ljk423/ctgan-tf

TensorFlow CTGAN

TensorFlow 2.1 implementation of Conditional Tabular GAN.

TensorFlow 2.1 implementation of a Conditional Tabular Generative Adversarial Network. CTGAN is a GAN-based data synthesizer that can "generate synthetic tabular data with high fidelity".

This model was originally designed by the Data to AI Lab team at MIT and published in their NeurIPS paper, Modeling Tabular data using Conditional GAN.

For more information about this work, and to access the original PyTorch implementation provided by the authors, please refer to their GitHub repository and documentation.

Install

Requirements

At the moment, ctgan-tf has only been tested on Python 3.7 and TensorFlow 2.2.

  • tensorflow (<2.3,>=2.1.0)
  • tensorflow-probability (<0.11.0,>=0.9.0)
  • scikit-learn (<0.23,>=0.21)
  • numpy (<2,>=1.17.4)
  • pandas (<1.0.2,>=1.0)
  • tqdm (<4.44,>=4.43)

Install

You can either install ctgan-tf through the PyPI package:

pip3 install ctgan-tf

Alternatively, clone this repository and either copy the ctgan folder into your project folder or run:

make install

Data Format

CTGAN expects the input data to be a table given as either a numpy.ndarray or a pandas.DataFrame object with two types of columns:

  • Continuous columns: Columns that contain numerical values and which can take any value.
  • Discrete columns: Columns that only contain a finite number of possible values, whether these are string values or not.
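As an illustration, the toy table below mixes both column types. The helper infer_discrete_columns is hypothetical and not part of ctgan-tf; it applies a simple dtype heuristic, treating bool and non-numeric (string or category) columns as discrete and everything else as continuous. Integer-coded categories would be missed, so the result should still be reviewed by hand.

```python
import pandas as pd
from pandas.api.types import is_bool_dtype, is_numeric_dtype


def infer_discrete_columns(df: pd.DataFrame) -> list:
    """Hypothetical helper (not part of ctgan-tf): guess discrete columns.

    Bool and non-numeric columns are treated as discrete; numeric
    columns are assumed continuous.
    """
    return [
        col for col in df.columns
        if is_bool_dtype(df[col]) or not is_numeric_dtype(df[col])
    ]


# A toy table with two continuous and two discrete columns.
data = pd.DataFrame({
    "age": [39, 50, 38, 53],                     # continuous
    "hours-per-week": [40.0, 13.0, 40.0, 40.0],  # continuous
    "workclass": ["State-gov", "Self-emp-not-inc", "Private", "Private"],
    "sex": ["Male", "Male", "Male", "Female"],
})

discrete_columns = infer_discrete_columns(data)
print(discrete_columns)  # ['workclass', 'sex']
```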

Quickstart

Before using CTGAN, you will need to prepare your data as specified above.

For this example, we will load some data using the ctgan.utils.load_demo function.

from ctgan.utils import load_demo

data, discrete_columns = load_demo()

The demo already returns a list of discrete columns alongside the data. With your own data, however, you will need to create a list with the names of the discrete columns yourself. For the demo data, it looks like this:

discrete_columns = [
    'workclass',
    'education',
    'marital-status',
    'occupation',
    'relationship',
    'race',
    'sex',
    'native-country',
    'income'
]
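A typo in this list is easy to make, so it can be worth checking the names against the table before training. The sketch below uses a hypothetical helper, check_discrete_columns, that is not part of ctgan-tf; it assumes the data is a pandas.DataFrame and raises a ValueError naming any listed column that does not exist.

```python
import pandas as pd


def check_discrete_columns(data: pd.DataFrame, discrete_columns: list) -> None:
    """Hypothetical pre-training check: every listed name must be a column."""
    missing = [name for name in discrete_columns if name not in data.columns]
    if missing:
        raise ValueError(f"Discrete columns not found in data: {missing}")


# Toy table standing in for the demo data.
data = pd.DataFrame({
    "age": [39, 50],
    "workclass": ["State-gov", "Private"],
    "income": ["<=50K", ">50K"],
})

check_discrete_columns(data, ["workclass", "income"])  # passes silently

try:
    check_discrete_columns(data, ["workclass", "incom"])  # typo
except ValueError as err:
    print(err)  # Discrete columns not found in data: ['incom']
```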

Once you have the data ready, import the CTGANSynthesizer class, create an instance, and train it by calling its train method with your data and the list of discrete columns.

from ctgan.synthesizer import CTGANSynthesizer

ctgan = CTGANSynthesizer()
ctgan.train(data, discrete_columns)

Once training has finished, call the sample method of your CTGANSynthesizer instance with the number of rows that you want to generate.

samples = ctgan.sample(1000)

The output will be a table with the exact same format as the input and filled with the synthetic data generated by the model.
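One way to check this on your own run is to compare the two tables. The helper compare_format below is hypothetical and not part of ctgan-tf; it assumes both tables are pandas.DataFrame objects, checks that column names and order match, and reports for each discrete column any synthetic values that never occur in the real data. Continuous columns are only checked for presence. The toy samples table stands in for the output of ctgan.sample.

```python
import pandas as pd


def compare_format(real: pd.DataFrame, synthetic: pd.DataFrame,
                   discrete_columns: list) -> dict:
    """Hypothetical sanity check for a synthetic table.

    Returns a dict mapping each discrete column to the set of synthetic
    values not seen in the real data (empty sets mean no surprises).
    """
    if list(real.columns) != list(synthetic.columns):
        raise ValueError("Column names or order differ between tables")
    unseen = {}
    for col in discrete_columns:
        seen = set(real[col].unique())
        unseen[col] = set(synthetic[col].unique()) - seen
    return unseen


# Toy stand-ins: `samples` would normally come from ctgan.sample(...).
data = pd.DataFrame({"age": [39, 50, 38], "sex": ["Male", "Female", "Male"]})
samples = pd.DataFrame({"age": [41.2, 47.9], "sex": ["Female", "Male"]})

print(compare_format(data, samples, ["sex"]))  # {'sex': set()}
```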

For a more in-depth guide and the API specification, see the project documentation.
