Source code for brainstorm.training.trainer

#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals

import sys
import traceback
from collections import OrderedDict

from brainstorm.describable import Describable
from brainstorm.scorers import (aggregate_losses_and_scores,
                                gather_losses_and_scores)
from brainstorm.training.utils import run_network


class Trainer(Describable):
    """
    Trainer objects organize the process of training a network. They can
    employ different training methods (``Steppers``) and call ``Hooks``.
    """
    # Attributes listed here are paired with the values shown; presumably the
    # ``Describable`` machinery leaves them out of descriptions and restores
    # these defaults -- confirm against ``brainstorm.describable``.
    __undescribed__ = {
        'current_epoch_nr': 0,
        'current_update_nr': 0,
        'logs': {},
        'results': {},
        'failed_hooks': {}
    }
    __default_values__ = {'verbose': True}

    def __init__(self, stepper, verbose=True):
        """Create a new Trainer.

        Args:
            stepper (brainstorm.training.steppers.TrainingStepper):
                Stepper that is started once per call to :meth:`train` and
                run once per update.
            verbose (bool):
                Whether progress banners and log values are printed. Also
                the fallback verbosity for logs whose hook does not specify
                one.
        """
        self.stepper = stepper
        self.verbose = verbose
        self.hooks = OrderedDict()
        self.train_scorers = []
        self.current_epoch_nr = 0
        self.current_update_nr = 0
        self.logs = {}
        self.results = {}

    def add_hook(self, hook):
        """Add a hook to this trainer.

        Hooks add a variety of functionality to the trainer and can be called
        after every specified number of parameter updates or epochs.
        See documentation for :class:`Hook` for more details.

        Note:
            During training, hooks will be called in the same order that they
            were added. This should be kept in mind when using a hook which
            relies on another hook having been called.

        Args:
            hook (brainstorm.hooks.Hook):
                Any :class:`Hook` object that should be called by this
                trainer.

        Raises:
            ValueError:
                If a hook with the same name has already been added.
        """
        if hook.__name__ in self.hooks:
            raise ValueError("Hook '{}' already exists.".format(hook.__name__))
        self.hooks[hook.__name__] = hook
        # The new hook is already registered, so its own current priority
        # takes part in the maximum; the result is strictly larger than the
        # priority of every previously added hook.
        hook.priority = max(h.priority for h in self.hooks.values()) + 1

    def train(self, net, training_data_iter, **named_data_iters):
        """
        Train a network using a data iterator and further named data
        iterators.

        Runs epochs until a hook signals that training should stop (by
        raising ``StopIteration``). Hooks are also emitted once before the
        first epoch; if any of them signals a stop, no training is done.

        Args:
            net: The network to train.
            training_data_iter: Data iterator for the training data. Its
                ``data_shapes`` keys must equal the names of the network's
                input layer outputs.
            **named_data_iters: Further data iterators, handed to the hooks'
                ``start`` methods together with the training iterator (under
                the key ``'training_data_iter'``).
        """
        if self.verbose:
            print('\n\n', 10 * '- ', "Before Training", 10 * ' -')
        assert set(training_data_iter.data_shapes.keys()) == set(
            net.buffer.Input.outputs.keys()), \
            "The data names provided by the training data iterator {} do not "\
            "map to the network input names {}".format(
                training_data_iter.data_shapes.keys(),
                net.buffer.Input.outputs.keys())
        self.stepper.start(net)
        named_data_iters['training_data_iter'] = training_data_iter
        self._start_hooks(net, named_data_iters)
        # Initial hook round before any update. ``or`` short-circuits, so the
        # epoch hooks are skipped if an update hook already asked to stop.
        if self._emit_hooks(net, 'update') or self._emit_hooks(net, 'epoch'):
            return

        should_stop = False
        while not should_stop:
            self.current_epoch_nr += 1
            sys.stdout.flush()
            # One list of values per scorer / loss, filled during the epoch.
            train_scores = {s.__name__: [] for s in self.train_scorers}
            train_scores.update({n: [] for n in net.get_loss_values()})

            if self.verbose:
                print('\n\n', 12 * '- ', "Epoch", self.current_epoch_nr,
                      12 * ' -')
            iterator = training_data_iter(handler=net.handler)
            for _ in run_network(net, iterator):
                self.current_update_nr += 1
                self.stepper.run()
                gather_losses_and_scores(net, self.train_scorers,
                                         train_scores)
                net.apply_weight_modifiers()
                if self._emit_hooks(net, 'update'):
                    # Abort mid-epoch; the epoch summary and the epoch hooks
                    # below are still processed once before leaving.
                    should_stop = True
                    break

            self._add_log('rolling_training',
                          aggregate_losses_and_scores(train_scores, net,
                                                      self.train_scorers))

            should_stop |= self._emit_hooks(net, 'epoch')

    def evaluate(self, net, **named_data_iters):
        """Start all hooks, call them once and return the collected results.

        The values returned by the hooks are accumulated in ``self.results``
        instead of ``self.logs``. Note that the hooks themselves still
        receive ``self.logs`` as their ``logs`` argument.

        Args:
            net: The network to evaluate.
            **named_data_iters: Data iterators handed to the hooks' ``start``
                methods.

        Returns:
            dict: ``self.results``.
        """
        self._start_hooks(net, named_data_iters)
        self._emit_hooks(net, 'epoch', logs=self.results)
        self._emit_hooks(net, 'update', logs=self.results)
        return self.results

    def __init_from_description__(self, description):
        """Re-order the hooks by priority and assign their names.

        Args:
            description: Not used by this method.
        """
        # Recover the hooks in order of priority and set their names.
        def get_priority(x):
            return getattr(x[1], 'priority', 0)
        ordered_mon = sorted(self.hooks.items(), key=get_priority)
        self.hooks = OrderedDict()
        for name, mon in ordered_mon:
            self.hooks[name] = mon
            mon.__name__ = name

    def _start_hooks(self, net, named_data_iters):
        """Call the :attr:`start()` methods for all the hooks.

        Hooks without a ``start`` attribute are skipped. If starting a hook
        fails, its name is reported on stderr and the exception propagates.
        """
        for name, hook in self.hooks.items():
            try:
                if hasattr(hook, 'start'):
                    hook.start(net, self.stepper, self.verbose,
                               named_data_iters)
            except Exception:
                print('An error occurred while starting the "{}" hook:'
                      .format(name), file=sys.stderr)
                raise

    def _emit_hooks(self, net, timescale, logs=None):
        """Call the hooks which should be called at this timescale.

        A hook is called if its ``timescale`` matches and the current epoch
        (or update) number is a multiple of its ``interval``.

        Args:
            net: The network passed on to the hooks.
            timescale (str): ``'epoch'`` selects the epoch counter; any other
                value selects the update counter.
            logs (dict): Where to accumulate hook results; defaults to
                ``self.logs``.

        Returns:
            bool: True if at least one hook signalled that training should
            stop.
        """
        should_stop = False
        count = self.current_epoch_nr if timescale == 'epoch' else \
            self.current_update_nr

        for name, hook in self.hooks.items():
            if hook.timescale != timescale or count % hook.interval != 0:
                continue
            hook_log, stop = self._call_hook(hook, net)
            should_stop |= stop
            self._add_log(name, hook_log, hook.verbose, logs=logs)

        return should_stop

    def _call_hook(self, hook, net):
        """Call a hook and check if it raises a stopping signal.

        Returns:
            tuple: ``(result, stop)``. ``stop`` is True if the hook raised
            ``StopIteration``; in that case ``result`` is the exception's
            ``value`` attribute if present, else None.

        Raises:
            Exception: Any other exception raised by the hook is reported on
                stderr (with traceback) and re-raised unchanged.
        """
        try:
            return hook(epoch_nr=self.current_epoch_nr,
                        update_nr=self.current_update_nr,
                        net=net,
                        stepper=self.stepper,
                        logs=self.logs), False
        except StopIteration as err:
            return getattr(err, 'value', None), True
        except Exception:
            print('An error occurred while calling the "{}" hook:'
                  .format(hook.__name__), file=sys.stderr)
            # Send the traceback to stderr as well, so that it stays together
            # with the message above when the streams are redirected.
            traceback.print_exc()
            # Bare ``raise`` keeps the original traceback intact.
            raise

    def _add_log(self, name, val, verbose=None, logs=None, indent=0):
        """Accumulate the logs (possibly a nested dictionary) recursively.

        Args:
            name (str): Key under which the value is stored.
            val: Value to log. ``None`` is ignored; a dict is descended into
                and stored as a nested dict; anything else is appended to the
                list stored under ``name``.
            verbose (bool): Whether to print the entry; defaults to
                ``self.verbose``.
            logs (dict): Dictionary to store into; defaults to ``self.logs``.
            indent (int): Number of spaces to indent the printed entry.
        """
        if val is None:
            return
        verbose = self.verbose if verbose is None else verbose
        logs = self.logs if logs is None else logs
        if isinstance(val, dict):
            if verbose:
                print(" " * indent + name)
            nested = logs.setdefault(name, {})
            for k, v in val.items():
                self._add_log(k, v, verbose, nested, indent + 2)
        else:
            if verbose:
                # Pad the name so that values line up at column 40 regardless
                # of the nesting depth.
                print(" " * indent + ("{0:%d}: {1}" % (40 - indent))
                      .format(name, val))
            logs.setdefault(name, []).append(val)