forked from pytorch/ignite

High-level library to help with training neural networks in PyTorch


cajanond/ignite


Ignite


Ignite is a high-level library to help with training neural networks in PyTorch.

Documentation

API documentation, examples and tutorials coming soon.

Installation

From Source:

python setup.py install

Getting Started

The Trainer

The main component of Ignite is the Trainer, an abstraction over your training loop. Getting started with the Trainer is easy: the constructor requires only two things:

  • training_data: A collection of training batches allowing repeated iteration (e.g., list or DataLoader)
  • training_update_function: A function that receives a batch, passes it through your model and updates the model's parameters

Optionally, you can also provide validation_data and validation_update_function for evaluating on your validation set.

Given a model, criterion and optimizer, your training_update_function will look something like this:

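model = ...
criterion = ...  # e.g. torch.nn.CrossEntropyLoss()
optimizer = ...  # e.g. torch.optim.SGD(model.parameters(), lr=0.01)

def training_update_function(batch):
    model.train()
    optimizer.zero_grad()
    # Since PyTorch 0.4, tensors can be used directly (no Variable wrapper needed)
    x, y = batch[0], batch[1]
    prediction = model(x)
    loss = criterion(prediction, y)
    loss.backward()
    optimizer.step()
    # Return a plain Python number (on PyTorch < 0.4 this was loss.data[0])
    return loss.item()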

You can then construct your Trainer and train it for a fixed number of epochs (here, 5) as follows:

from ignite.trainer import Trainer

trainer = Trainer(train_dataloader, training_update_function)
trainer.run(max_epochs=5)
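To make the Trainer's role more concrete, here is a minimal pure-Python sketch of the kind of loop it drives for you. `MiniTrainer` is a hypothetical stand-in written only for illustration, not Ignite's implementation: it has no events or exception handling, and it assumes validation runs once at the end of each epoch.

```python
class MiniTrainer:
    """Hypothetical, simplified training-loop abstraction (not Ignite's code)."""

    def __init__(self, training_data, training_update_function,
                 validation_data=None, validation_update_function=None):
        self.training_data = training_data
        self.training_update_function = training_update_function
        self.validation_data = validation_data
        self.validation_update_function = validation_update_function
        self.training_history = []
        self.validation_history = []
        self.current_epoch = 0
        self.should_terminate = False

    def terminate(self):
        # Request that the loop stop after the current iteration.
        self.should_terminate = True

    def run(self, max_epochs=1):
        while self.current_epoch < max_epochs and not self.should_terminate:
            for batch in self.training_data:
                # Keep whatever the update function returns (loss, dict, ...).
                self.training_history.append(self.training_update_function(batch))
                if self.should_terminate:
                    break
            # Optional validation pass at the end of each epoch.
            if self.validation_data is not None and self.validation_update_function is not None:
                for batch in self.validation_data:
                    self.validation_history.append(self.validation_update_function(batch))
            self.current_epoch += 1


# Usage with toy "batches" and a toy update function.
trainer = MiniTrainer([1, 2, 3], lambda batch: batch * 0.5)
trainer.run(max_epochs=2)
```

After two epochs over three batches, `trainer.training_history` holds six entries, one per call to the update function.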

Training & Validation History

The return values of your training and validation update functions are stored in the Trainer's training_history and validation_history members. These can be accessed from event handlers (see below) and used for updating metrics, logging, etc. Importantly, your update functions need not return just the loss: they can return any type (list, tuple, dict, tensors, etc.).
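As an illustration of non-scalar return values, suppose an update function returned a dict per iteration. A handler could then summarize those entries as below. `running_average` and the `"loss"`/`"accuracy"` keys are hypothetical, and a `SimpleNamespace` stands in for a real Trainer so the sketch runs on its own.

```python
from types import SimpleNamespace


def running_average(trainer, key, window=10):
    """Hypothetical helper: mean of `key` over the last `window` training results."""
    recent = trainer.training_history[-window:]
    if not recent:
        return None  # nothing recorded yet
    return sum(entry[key] for entry in recent) / len(recent)


# Stand-in for a Trainer whose update function returned dicts.
trainer = SimpleNamespace(training_history=[
    {"loss": 0.9, "accuracy": 0.50},
    {"loss": 0.7, "accuracy": 0.60},
    {"loss": 0.4, "accuracy": 0.75},
])

avg_loss = running_average(trainer, "loss")
avg_acc = running_average(trainer, "accuracy", window=2)
```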

Events & Event Handlers

The Trainer emits events during the training loop, to which users can attach event handlers. The emitted events are defined in ignite.trainer.TrainingEvents and currently are:

  • EPOCH_STARTED
  • EPOCH_COMPLETED
  • TRAINING_EPOCH_STARTED
  • TRAINING_EPOCH_COMPLETED
  • VALIDATION_STARTING
  • VALIDATION_COMPLETED
  • TRAINING_STARTED
  • TRAINING_COMPLETED
  • TRAINING_ITERATION_STARTED
  • TRAINING_ITERATION_COMPLETED
  • VALIDATION_ITERATION_STARTED
  • VALIDATION_ITERATION_COMPLETED
  • EXCEPTION_RAISED
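The mechanism behind these events can be pictured as a registry mapping each event to a list of handlers plus the extra arguments supplied at registration. The sketch below is a hypothetical, minimal dispatcher (`Events`, `EventEmitter` and `fire_event` are illustrative names, not Ignite's code), using only a small subset of the event names above.

```python
from enum import Enum


class Events(Enum):
    # Hypothetical subset mirroring some of the event names listed above.
    EPOCH_STARTED = "epoch_started"
    TRAINING_ITERATION_COMPLETED = "training_iteration_completed"
    TRAINING_COMPLETED = "training_completed"


class EventEmitter:
    """Hypothetical minimal event dispatcher (not Ignite's implementation)."""

    def __init__(self):
        self._handlers = {event: [] for event in Events}

    def add_event_handler(self, event, handler, *args, **kwargs):
        # Store the extra arguments so they are forwarded on every call.
        self._handlers[event].append((handler, args, kwargs))

    def fire_event(self, event):
        # Handlers run in registration order; the emitter itself is passed first.
        for handler, args, kwargs in self._handlers[event]:
            handler(self, *args, **kwargs)


calls = []


def record(emitter, label, suffix=""):
    calls.append(label + suffix)


emitter = EventEmitter()
emitter.add_event_handler(Events.TRAINING_COMPLETED, record, "done")
emitter.add_event_handler(Events.TRAINING_COMPLETED, record, "done", suffix="!")
emitter.fire_event(Events.TRAINING_COMPLETED)
```

Both handlers attached to the same event run in order, which is what allows several independent concerns (logging, early stopping, etc.) to share one event.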

Users can attach multiple handlers to each of these events, which lets them control aspects of training such as early stopping or learning-rate reduction, as well as tasks such as logging or updating external dashboards like Visdom or TensorBoard.
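As a sketch of the learning-rate case, assuming the handler convention used in this README (the Trainer is passed first, followed by any extra arguments given at registration), a step-decay handler could scale every parameter group's learning rate. `make_step_lr_handler` is a hypothetical factory; `param_groups` (a list of dicts with an `"lr"` key) is the attribute PyTorch optimizers expose, imitated here with a `SimpleNamespace` so the example runs without PyTorch.

```python
from types import SimpleNamespace


def make_step_lr_handler(step_size, gamma=0.1):
    """Hypothetical factory: the returned handler multiplies every param
    group's "lr" by `gamma` once every `step_size` calls."""
    calls = 0

    def handler(trainer, optimizer):
        nonlocal calls
        calls += 1
        if calls % step_size == 0:
            for group in optimizer.param_groups:
                group["lr"] *= gamma

    return handler


# Stand-in optimizer exposing a PyTorch-style param_groups list.
optimizer = SimpleNamespace(param_groups=[{"lr": 0.1}])
step_lr = make_step_lr_handler(step_size=2, gamma=0.5)
for _ in range(4):
    step_lr(None, optimizer)  # None stands in for the Trainer, unused here
```

With Ignite, such a handler would be attached with something like `trainer.add_event_handler(TrainingEvents.EPOCH_COMPLETED, step_lr, optimizer)`. PyTorch's own `torch.optim.lr_scheduler.StepLR` implements the same schedule and could be stepped from a handler instead.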

Event handlers can be any callable whose first argument is an instance of the Trainer. Users can also pass any other positional or keyword arguments to their event handlers. For example, to terminate training once at least 100 iterations have run and the loss has not decreased over the last `lookback` iterations, we could define the following event handler and attach it to the TRAINING_ITERATION_COMPLETED event.

from ignite.trainer import TrainingEvents

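def early_stopping_handler(trainer, min_iterations, lookback=1):
    # Only consider stopping once enough iterations have run.
    if trainer.current_iterations >= min_iterations:
        last_loss = trainer.training_history[-1]
        # The `lookback` losses recorded just before the latest one.
        previous = trainer.training_history[-(lookback + 1):-1]
        # Stop if the latest loss is not lower than any of them.
        if previous and not any(last_loss < x for x in previous):
            trainer.terminate()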

min_iterations = 100
trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
                          early_stopping_handler,
                          min_iterations,
                          lookback=5)
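Because handlers can be any callable, a class instance with `__call__` works too, which is convenient when the handler needs state. The sketch below is a hypothetical patience-based alternative (a different technique from the lookback window above): it tracks the best loss seen so far and stops after `patience` consecutive iterations without a new minimum. It counts its own calls instead of relying on a Trainer attribute, and a small stub replaces the Trainer so it runs standalone.

```python
class PatienceEarlyStopping:
    """Hypothetical stateful handler: stop once the loss has not reached a new
    minimum for `patience` consecutive calls (after `min_iterations` calls)."""

    def __init__(self, patience, min_iterations=0):
        self.patience = patience
        self.min_iterations = min_iterations
        self.best_loss = float("inf")
        self.bad_iterations = 0
        self.seen = 0

    def __call__(self, trainer):
        self.seen += 1
        last_loss = trainer.training_history[-1]
        if last_loss < self.best_loss:
            self.best_loss = last_loss
            self.bad_iterations = 0
        else:
            self.bad_iterations += 1
        if self.seen >= self.min_iterations and self.bad_iterations >= self.patience:
            trainer.terminate()


class StubTrainer:
    """Minimal stand-in exposing training_history and terminate()."""

    def __init__(self):
        self.training_history = []
        self.terminated = False

    def terminate(self):
        self.terminated = True


stub = StubTrainer()
handler = PatienceEarlyStopping(patience=2)
for loss in [1.0, 0.8, 0.9, 0.85]:
    stub.training_history.append(loss)  # what the update function would return
    handler(stub)                       # what TRAINING_ITERATION_COMPLETED would trigger
```

Attached to a real Trainer, this would look like `trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED, PatienceEarlyStopping(patience=10, min_iterations=100))`.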

Examples

Coming soon

Logging

Ignite uses Python's standard library logging module, which means you can integrate Ignite's logs directly into your application logs. To do this, attach a log handler to the ignite logger:

import logging
logger = logging.getLogger('ignite')
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)
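Because Python loggers are hierarchical, any logger whose name starts with `ignite.` propagates its records to the `ignite` logger configured above. The sketch below uses only the standard library: it collects formatted records in memory (handy in tests or notebooks) via a hypothetical `ListHandler`. The child name `ignite.trainer` is an assumption used purely for illustration; it is not necessarily the name Ignite logs under.

```python
import logging


class ListHandler(logging.Handler):
    """Hypothetical handler that keeps formatted log lines in a list."""

    def __init__(self):
        super().__init__()
        self.lines = []

    def emit(self, record):
        self.lines.append(self.format(record))


logger = logging.getLogger("ignite")
handler = ListHandler()
handler.setFormatter(logging.Formatter("%(name)s %(levelname)s: %(message)s"))
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# A child logger inherits the INFO level and propagates to "ignite".
child = logging.getLogger("ignite.trainer")  # assumed name, for illustration
child.info("Epoch 1 complete")
child.debug("dropped: below the INFO level")
```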

How does this compare to Torchnet?

In spirit, Ignite is very similar to torchnet (and was inspired by it).

The main difference from torchnet is the level of abstraction offered to the user. Ignite's higher level of abstraction assumes less about the type of network (or networks) you are training, and requires the user to define the closure run in the training and validation loops. In contrast, torchnet creates this closure internally from the network and optimizer you pass to it. This gives Ignite a great deal more flexibility, such as co-training multiple models (e.g. GANs) and computing/tracking multiple losses and metrics in your training loop.

Ignite also allows for multiple handlers to be attached to events, and a finer granularity of events in the loop.

That being said, there are some things from torchnet we really like and would like to port over, such as the integration with Visdom (and possibly add integration with TensorBoard).

As always, PRs are welcome :)

Contributing

We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.
