astra.contrib.classifier.model

Module Contents

Functions

train(network_factory, training_spectra, training_labels, validation_spectra, validation_labels, test_spectra, test_labels, class_names=None, learning_rate=0.0001, weight_decay=1e-05, n_epochs=200, batch_size=100, task=None) – Train a neural network to classify stellar sources.
predict_classes(network, spectra) – Predict an object class given a trained network and some spectra.
astra.contrib.classifier.model.device
astra.contrib.classifier.model.CUDA_AVAILABLE
astra.contrib.classifier.model.train(network_factory, training_spectra, training_labels, validation_spectra, validation_labels, test_spectra, test_labels, class_names=None, learning_rate=0.0001, weight_decay=1e-05, n_epochs=200, batch_size=100, task=None)

Train a neural network to classify stellar sources.

Parameters:
  • network_factory – The neural network class to use. This should be one of the classes in astra.contrib.classifier.networks.
  • training_spectra – An array of training set fluxes with shape (N, P), where N is the number of sources and P is the number of pixels.
  • training_labels – An array of training set labels with shape (N, L), where L is the number of labels.
  • validation_spectra – An array of validation set fluxes to use.
  • validation_labels – An array of validation set labels to use.
  • test_spectra – An array of test set fluxes to use.
  • test_labels – An array of test set labels to use.
  • class_names – (optional) A tuple of names for the object classes.
  • learning_rate – (optional) The learning rate to use during training (default: 1e-4).
  • weight_decay – (optional) The weight decay to use during training (default: 1e-5).
  • n_epochs – (optional) The number of epochs to use during training (default: 200).
  • batch_size – (optional) The number of sources to use per batch (default: 100).
  • task – (optional) If supplied, then progress messages will be sent back via this task.
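The shape requirements above, together with batch_size, can be illustrated with a small mini-batching sketch. This is not the implementation used by train: iterate_minibatches is a hypothetical NumPy helper, and shuffling the source order with a random generator is an assumption about how batches might be drawn.

```python
import numpy as np


def iterate_minibatches(spectra, labels, batch_size=100, rng=None):
    """Yield (spectra, labels) mini-batches that together cover every source once.

    Hypothetical helper, not part of astra. Assumes spectra has shape (N, P)
    and labels has shape (N, L), as described for train().
    """
    spectra = np.asarray(spectra)
    labels = np.asarray(labels)
    if spectra.ndim != 2 or labels.ndim != 2:
        raise ValueError("expected spectra of shape (N, P) and labels of shape (N, L)")
    if spectra.shape[0] != labels.shape[0]:
        raise ValueError("spectra and labels must contain the same number of sources N")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    n_sources = spectra.shape[0]
    # Optionally shuffle the source order (an assumption; a fixed order also works).
    order = np.arange(n_sources) if rng is None else rng.permutation(n_sources)
    for start in range(0, n_sources, batch_size):
        index = order[start:start + batch_size]
        # The final batch holds the remaining N mod batch_size sources, if any.
        yield spectra[index], labels[index]


# Example: 250 sources, 8 pixels each, 3 labels per source.
rng = np.random.default_rng(0)
spectra = rng.normal(size=(250, 8))
labels = np.zeros((250, 3))
for batch_spectra, batch_labels in iterate_minibatches(spectra, labels, batch_size=100, rng=rng):
    print(batch_spectra.shape, batch_labels.shape)
```

With N = 250 and batch_size=100, one pass over the data (one epoch) yields batches of 100, 100 and 50 sources, so each epoch makes ceil(N / batch_size) updates.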
astra.contrib.classifier.model.predict_classes(network, spectra)

Predict an object class given a trained network and some spectra.
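The return value of predict_classes is not documented here. If, as is common for classifiers, the network outputs one unnormalised score per class, those scores could be turned into predictions with a softmax followed by an argmax, as in this sketch. scores_to_classes and the example class names are hypothetical, and whether predict_classes itself returns probabilities, indices, or names is an open assumption.

```python
import numpy as np


def scores_to_classes(scores, class_names=None):
    """Hypothetical post-processing of network outputs (not part of astra).

    scores: array of shape (N, C) holding one unnormalised score (logit) per class.
    Returns (indices, probabilities, names); names is None if class_names is None.
    """
    scores = np.asarray(scores, dtype=float)
    # Softmax over the class axis; subtracting the row maximum avoids overflow.
    shifted = scores - scores.max(axis=1, keepdims=True)
    probabilities = np.exp(shifted)
    probabilities /= probabilities.sum(axis=1, keepdims=True)
    # The predicted class is the one with the highest probability.
    indices = probabilities.argmax(axis=1)
    names = None if class_names is None else [class_names[i] for i in indices]
    return indices, probabilities, names


scores = [[2.0, 0.5, -1.0],
          [0.1, 0.2, 3.0]]
indices, probabilities, names = scores_to_classes(scores, class_names=("class_a", "class_b", "class_c"))
print(indices.tolist(), names)  # [0, 2] ['class_a', 'class_c']
```

Subtracting the row maximum leaves the softmax unchanged but keeps np.exp from overflowing when the network produces large scores.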

astra.contrib.classifier.model.model_path = cnn_nir.model