Java source code examples: org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer
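All of the examples below follow the same pattern: build an EarlyStoppingConfiguration (termination conditions, a score calculator, a model saver), combine it with a network and a training DataSetIterator in an EarlyStoppingTrainer, and call fit() (or pretrain() for unsupervised layers such as autoencoders). The sketch below distills that pattern from the examples; it assumes net is an initialized MultiLayerNetwork and trainIter/testIter are existing DataSetIterators, and is meant as orientation rather than a definitive recipe.
// Minimal early-stopping sketch (assumes net, trainIter and testIter as described above):
// stop after at most 5 epochs or 1 minute of training, whichever comes first.
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
        new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                .scoreCalculator(new DataSetLossCalculator(testIter, true)) // average test-set loss after each epoch
                .modelSaver(new InMemoryModelSaver<>())                     // keep the best model in memory
                .build();
IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, trainIter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
MultiLayerNetwork bestModel = result.getBestModel(); // model from the best-scoring epoch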
Example 1
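// Evaluates the test-set loss only every 2 epochs (evaluateEveryNEpochs) and stops
// once the 5-epoch maximum is reached.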
@Test
public void testEarlyStoppingEveryNEpoch() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.scoreCalculator(new DataSetLossCalculator(irisIter, true))
.evaluateEveryNEpochs(2).modelSaver(saver).build();
IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
System.out.println(result);
assertEquals(5, result.getTotalEpochs());
assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
Example 2
@Test
public void testBadTuning() {
//Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(5.0)) //Intentionally huge LR
.weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5
.scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
.build();
IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
EarlyStoppingResult result = trainer.fit();
assertTrue(result.getTotalEpochs() < 5);
assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
result.getTerminationReason());
String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
assertEquals(expDetails, result.getTerminationDetails());
assertEquals(0, result.getBestModelEpoch());
assertNotNull(result.getBestModel());
}
Example 3
@Test
public void testNoImprovementNEpochsTermination() {
//Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
//Simulate this by setting LR = 0.0
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(100),
new ScoreImprovementEpochTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
.scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
.build();
IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
EarlyStoppingResult result = trainer.fit();
//Expect no score change due to 0 LR -> terminate after 6 total epochs
assertEquals(6, result.getTotalEpochs());
assertEquals(0, result.getBestModelEpoch());
assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
assertEquals(expDetails, result.getTerminationDetails());
}
Example 4
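// Checks that the best model returned by early stopping matches the original network's
// structure: layer count, optimization algorithm, activation function and updater.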
@Test
public void testEarlyStoppingGetBestModel() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
.build();
IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
System.out.println(result);
MultiLayerNetwork mln = result.getBestModel();
assertEquals(net.getnLayers(), mln.getnLayers());
assertEquals(net.conf().getOptimizationAlgo(), mln.conf().getOptimizationAlgo());
BaseLayer bl = (BaseLayer) net.conf().getLayer();
assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.conf().getLayer()).getActivationFn().toString());
assertEquals(bl.getIUpdater(), ((BaseLayer) mln.conf().getLayer()).getIUpdater());
}
Example 5
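// Verifies that an early stopping listener is invoked once on start, once per epoch
// (5 times), and once on completion.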
@Test
public void testListeners() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
.build();
LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener();
IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter, listener);
trainer.fit();
assertEquals(1, listener.onStartCallCount);
assertEquals(5, listener.onEpochCallCount);
assertEquals(1, listener.onCompletionCallCount);
}
Example 6
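// Runs early stopping once for each classification metric, scoring a small in-memory
// MNIST subset with ClassificationScoreCalculator.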
@Test
public void testClassificationScoreFunctionSimple() throws Exception {
for(Evaluation.Metric metric : Evaluation.Metric.values()) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
.layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);
List<DataSet> l = new ArrayList<>();
for( int i=0; i<10; i++ ){
DataSet ds = iter.next();
l.add(ds);
}
iter = new ExistingDataSetIterator(l);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new ClassificationScoreCalculator(metric, iter)).modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
assertNotNull(result.getBestModel());
}
}
Example 7
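// Verifies that the network's own training listener sees the expected number of epoch
// and iteration events: 5 epochs of 150/50 = 3 minibatches each.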
@Test
public void testEarlyStoppingListeners() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
TestListener tl = new TestListener();
net.setListeners(tl);
DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
.build();
IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
assertEquals(5, tl.countEpochStart);
assertEquals(5, tl.countEpochEnd);
assertEquals(5 * 150/50, tl.iterCount);
assertEquals(4, tl.maxEpochStart);
assertEquals(4, tl.maxEpochEnd);
}
Example 8
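// Trains a LeNet-style convolutional network on MNIST; the best model (by test-set
// accuracy, recomputed every epoch) is written to disk via LocalFileModelSaver.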
public void train() throws IOException {
MnistDataSetIterator mnistTrain = new MnistDataSetIterator(MINI_BATCH_SIZE, true, 12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(SEED)
.learningRate(LEARNING_RATE)
.weightInit(WeightInit.XAVIER)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Updater.NESTEROVS)
.list()
.layer(0, new ConvolutionLayer.Builder(5, 5)
.nIn(CHANNELS)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build())
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(2, new ConvolutionLayer.Builder(5, 5)
.nIn(20)
.stride(1, 1)
.nOut(50)
.activation(Activation.IDENTITY)
.build())
.layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build())
.layer(4, new DenseLayer.Builder().activation(Activation.RELU)
.nIn(800)
.nOut(128).build())
.layer(5, new DenseLayer.Builder().activation(Activation.RELU)
.nIn(128)
.nOut(64).build())
.layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nOut(OUTPUT)
.activation(Activation.SOFTMAX)
.build())
.setInputType(InputType.convolutionalFlat(28, 28, 1))
.backprop(true).pretrain(false).build();
EarlyStoppingConfiguration earlyStoppingConfiguration = new EarlyStoppingConfiguration.Builder()
.epochTerminationConditions(new MaxEpochsTerminationCondition(MAX_EPOCHS))
.scoreCalculator(new AccuracyCalculator(new MnistDataSetIterator(MINI_BATCH_SIZE, false, 12345)))
.evaluateEveryNEpochs(1)
.modelSaver(new LocalFileModelSaver(OUT_DIR))
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfiguration, conf, mnistTrain);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
log.info("Termination reason: " + result.getTerminationReason());
log.info("Termination details: " + result.getTerminationDetails());
log.info("Total epochs: " + result.getTotalEpochs());
log.info("Best epoch number: " + result.getBestModelEpoch());
log.info("Score at best epoch: " + result.getBestModelScore());
}
Example 9
public static void main(String... args) throws java.io.IOException {
// create the data iterators for emnist
DataSetIterator emnistTrain = new EmnistDataSetIterator(emnistSet, batchSize, true);
DataSetIterator emnistTest = new EmnistDataSetIterator(emnistSet, batchSize, false);
int outputNum = EmnistDataSetIterator.numLabels(emnistSet);
// network configuration (not yet initialized)
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Adam())
.l2(1e-4)
.list()
.layer(new DenseLayer.Builder()
.nIn(numRows * numColumns) // Number of input datapoints.
.nOut(1000) // Number of output datapoints.
.activation(Activation.RELU) // Activation function.
.weightInit(WeightInit.XAVIER) // Weight initialization.
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.nIn(1000)
.nOut(outputNum)
.activation(Activation.SOFTMAX)
.weightInit(WeightInit.XAVIER)
.build())
.pretrain(false).backprop(true)
.build();
// create the MLN
MultiLayerNetwork network = new MultiLayerNetwork(conf);
network.init();
// pass a training listener that reports score every N iterations
network.addListeners(new ScoreIterationListener(reportingInterval));
// here we set up an early stopping trainer
// early stopping is useful when your trainer runs for
// a long time or you need to programmatically stop training
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
.scoreCalculator(new DataSetLossCalculator(emnistTest, true))
.evaluateEveryNEpochs(1)
.modelSaver(new LocalFileModelSaver(System.getProperty("user.dir")))
.build();
// training
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, network, emnistTrain);
EarlyStoppingResult result = trainer.fit();
// print out early stopping results
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());
// evaluate basic performance
Evaluation eval = network.evaluate(emnistTest);
System.out.println(eval.accuracy());
System.out.println(eval.precision());
System.out.println(eval.recall());
// evaluate ROC and calculate the Area Under Curve
ROCMultiClass roc = network.evaluateROCMultiClass(emnistTest);
System.out.println(roc.calculateAverageAUC());
// calculate AUC for a single class
int classIndex = 0;
System.out.println(roc.calculateAUC(classIndex));
// optionally, you can print all stats from the evaluations
System.out.println(eval.stats());
System.out.println(roc.stats());
}
Example 10
@Test
public void testTimeTermination() {
//test termination after max time
Nd4j.getRandom().setSeed(12345);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).list()
.layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.setListeners(new ScoreIterationListener(1));
DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
.scoreCalculator(new DataSetLossCalculator(irisIter, true))
.modelSaver(saver).build();
IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
long startTime = System.currentTimeMillis();
EarlyStoppingResult result = trainer.fit();
long endTime = System.currentTimeMillis();
int durationSeconds = (int) (endTime - startTime) / 1000;
assertTrue(durationSeconds >= 3);
assertTrue(durationSeconds <= 12);
assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
result.getTerminationReason());
String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
assertEquals(expDetails, result.getTerminationDetails());
}
Example 11
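// Treats MNIST features as their own regression targets and drives early stopping
// with RegressionScoreCalculator for the MSE and MAE metrics.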
@Test
public void testRegressionScoreFunctionSimple() throws Exception {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
.layer(new OutputLayer.Builder().nIn(32).nOut(784).activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);
List<DataSet> l = new ArrayList<>();
for( int i=0; i<10; i++ ){
DataSet ds = iter.next();
l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
}
iter = new ExistingDataSetIterator(l);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new RegressionScoreCalculator(metric, iter)).modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
assertNotNull(result.getBestModel());
assertTrue(result.getBestModelScore() > 0.0);
}
}
Example 12
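// Early stopping for unsupervised layerwise training: the AutoEncoder layer is scored
// by reconstruction error, and pretrain() is called instead of fit().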
@Test
public void testAEScoreFunctionSimple() throws Exception {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new AutoEncoder.Builder().nIn(784).nOut(32).build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);
List<DataSet> l = new ArrayList<>();
for( int i=0; i<10; i++ ){
DataSet ds = iter.next();
l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
}
iter = new ExistingDataSetIterator(l);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new AutoencoderScoreCalculator(metric, iter)).modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();
assertNotNull(result.getBestModel());
assertTrue(result.getBestModelScore() > 0.0);
}
}
Example 13
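// The same unsupervised setup for a variational autoencoder, scored by reconstruction
// error via VAEReconErrorScoreCalculator.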
@Test
public void testVAEScoreFunctionSimple() throws Exception {
for(Metric metric : new Metric[]{Metric.MSE,
Metric.MAE}) {
log.info("Metric: " + metric);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new VariationalAutoencoder.Builder()
.nIn(784).nOut(32)
.encoderLayerSizes(64)
.decoderLayerSizes(64)
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);
List<DataSet> l = new ArrayList<>();
for( int i=0; i<10; i++ ){
DataSet ds = iter.next();
l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
}
iter = new ExistingDataSetIterator(l);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new VAEReconErrorScoreCalculator(metric, iter)).modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();
assertNotNull(result.getBestModel());
assertTrue(result.getBestModelScore() > 0.0);
}
}
Example 14
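// Scores the VAE by reconstruction probability (or log probability when logProb is
// true) rather than reconstruction error, again through pretrain().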
@Test
public void testVAEScoreFunctionReconstructionProbSimple() throws Exception {
for(boolean logProb : new boolean[]{false, true}) {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.list()
.layer(new VariationalAutoencoder.Builder()
.nIn(784).nOut(32)
.encoderLayerSizes(64)
.decoderLayerSizes(64)
.reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
.build())
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);
List<DataSet> l = new ArrayList<>();
for (int i = 0; i < 10; i++) {
DataSet ds = iter.next();
l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
}
iter = new ExistingDataSetIterator(l);
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(new MaxEpochsTerminationCondition(5))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
.scoreCalculator(new VAEReconProbScoreCalculator(iter, 20, logProb)).modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();
assertNotNull(result.getBestModel());
assertTrue(result.getBestModelScore() > 0.0);
}
}
Example 15
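// Uses a score to be maximized (F1) with ScoreImprovementEpochTerminationCondition(1):
// every recorded score must improve on the previous one, except possibly the last.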
@Test
public void testEarlyStoppingMaximizeScore() throws Exception {
Nd4j.getRandom().setSeed(12345);
int outputs = 2;
DataSet ds = new DataSet(
Nd4j.rand(new int[]{3, 10, 50}),
TestUtils.randomOneHotTimeSeries(3, outputs, 50, 12345));
DataSetIterator train = new ExistingDataSetIterator(
Arrays.asList(ds, ds, ds, ds, ds, ds, ds, ds, ds, ds));
DataSetIterator test = new SingletonDataSetIterator(ds);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(123)
.weightInit(WeightInit.XAVIER)
.updater(new Adam(0.1))
.activation(Activation.ELU)
.l2(1e-5)
.gradientNormalization(GradientNormalization
.ClipElementWiseAbsoluteValue)
.gradientNormalizationThreshold(1.0)
.list()
.layer(0, new LSTM.Builder()
.nIn(10)
.nOut(10)
.activation(Activation.TANH)
.gateActivationFunction(Activation.SIGMOID)
.dropOut(0.5)
.build())
.layer(1, new RnnOutputLayer.Builder()
.nIn(10)
.nOut(outputs)
.activation(Activation.SOFTMAX)
.lossFunction(LossFunctions.LossFunction.MCXENT)
.build())
.build();
File f = testDir.newFolder();
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new LocalFileModelSaver(f.getAbsolutePath());
EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
.epochTerminationConditions(
new MaxEpochsTerminationCondition(10),
new ScoreImprovementEpochTerminationCondition(1))
.iterationTerminationConditions(
new MaxTimeIterationTerminationCondition(10, TimeUnit.MINUTES))
.scoreCalculator(new ClassificationScoreCalculator(Evaluation.Metric.F1, test))
.modelSaver(saver)
.saveLastModel(true)
.build();
MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();
EarlyStoppingTrainer t = new EarlyStoppingTrainer(esConf, net, train);
EarlyStoppingResult<MultiLayerNetwork> result = t.fit();
Map<Integer,Double> map = result.getScoreVsEpoch();
for( int i=1; i<map.size(); i++ ){
if(i == map.size() - 1){
assertTrue(map.get(i) <= map.get(i-1));
} else {
assertTrue(map.get(i) > map.get(i-1));
}
}
}
Example 16
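// Trains an MNIST multilayer perceptron with dropout; training stops once the
// validation-set loss has not improved for 5 consecutive epochs.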
/**
* The main method.
* @param args Not used.
*/
public static void main(String[] args) {
try {
int seed = 43;
double learningRate = 1e-2;
int nEpochs = 50;
int batchSize = 500;
// Setup training data.
System.out.println("Please wait, reading MNIST training data.");
String dir = System.getProperty("user.dir");
MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
MNISTReader validationReader = MNIST.loadMNIST(dir, false);
DataSet trainingSet = trainingReader.getData();
DataSet validationSet = validationReader.getData();
DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());
System.out.println("Training set size: " + trainingReader.getNumImages());
System.out.println("Validation set size: " + validationReader.getNumImages());
System.out.println(trainingSet.get(0).getFeatures().size(1));
System.out.println(validationSet.get(0).getFeatures().size(1));
int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
int numOutputs = 10;
int numHiddenNodes = 200;
// Create neural network.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(1)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(learningRate)
.updater(Updater.NESTEROVS).momentum(0.9)
.regularization(true).dropOut(0.50)
.list(2)
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
.weightInit(WeightInit.XAVIER)
.activation("relu")
.build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.weightInit(WeightInit.XAVIER)
.activation("softmax")
.nIn(numHiddenNodes).nOut(numOutputs).build())
.pretrain(false).backprop(true).build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
// Define when we want to stop training.
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
//.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
.epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
.evaluateEveryNEpochs(1)
.scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
.modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);
// Train and display result.
EarlyStoppingResult result = trainer.fit();
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());
model = saver.getBestModel();
// Evaluate
Evaluation eval = new Evaluation(numOutputs);
validationSetIterator.reset();
for (int i = 0; i < validationSet.numExamples(); i++) {
DataSet t = validationSet.get(i);
INDArray features = t.getFeatureMatrix();
INDArray labels = t.getLabels();
INDArray predicted = model.output(features, false);
eval.eval(labels, predicted);
}
//Print the evaluation statistics
System.out.println(eval.stats());
} catch(Exception ex) {
ex.printStackTrace();
}
}
Example 17
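// The same MNIST setup as the previous example, but without dropout and with a
// smaller hidden layer (100 nodes instead of 200).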
/**
* The main method.
* @param args Not used.
*/
public static void main(String[] args) {
try {
int seed = 43;
double learningRate = 1e-2;
int nEpochs = 50;
int batchSize = 500;
// Setup training data.
System.out.println("Please wait, reading MNIST training data.");
String dir = System.getProperty("user.dir");
MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
MNISTReader validationReader = MNIST.loadMNIST(dir, false);
DataSet trainingSet = trainingReader.getData();
DataSet validationSet = validationReader.getData();
DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());
System.out.println("Training set size: " + trainingReader.getNumImages());
System.out.println("Validation set size: " + validationReader.getNumImages());
System.out.println(trainingSet.get(0).getFeatures().size(1));
System.out.println(validationSet.get(0).getFeatures().size(1));
int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
int numOutputs = 10;
int numHiddenNodes = 100;
// Create neural network.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(1)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(learningRate)
.updater(Updater.NESTEROVS).momentum(0.9)
.list(2)
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
.weightInit(WeightInit.XAVIER)
.activation("relu")
.build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.weightInit(WeightInit.XAVIER)
.activation("softmax")
.nIn(numHiddenNodes).nOut(numOutputs).build())
.pretrain(false).backprop(true).build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
// Define when we want to stop training.
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
//.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
.epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
.evaluateEveryNEpochs(1)
.scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
.modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);
// Train and display result.
EarlyStoppingResult result = trainer.fit();
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());
model = saver.getBestModel();
// Evaluate
Evaluation eval = new Evaluation(numOutputs);
validationSetIterator.reset();
for (int i = 0; i < validationSet.numExamples(); i++) {
DataSet t = validationSet.get(i);
INDArray features = t.getFeatureMatrix();
INDArray labels = t.getLabels();
INDArray predicted = model.output(features, false);
eval.eval(labels, predicted);
}
//Print the evaluation statistics
System.out.println(eval.stats());
} catch(Exception ex) {
ex.printStackTrace();
}
}
Example 18
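// Trains an Iris classifier on a 75/25 train/validation split; training stops when the
// validation loss fails to improve for 25 consecutive epochs, then the best model is
// evaluated prediction by prediction.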
/**
* The main method.
* @param args Not used.
*/
public static void main(String[] args) {
try {
int seed = 43;
double learningRate = 0.1;
int splitTrainNum = (int) (150 * .75);
int numInputs = 4;
int numOutputs = 3;
int numHiddenNodes = 50;
// Setup training data.
final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
if( istream==null ) {
System.out.println("Cannot access data set, make sure the resources are available.");
System.exit(1);
}
final NormalizeDataSet ds = NormalizeDataSet.load(istream);
final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
istream.close();
DataSet next = ds.extractSupervised(0, 4, 4, 3);
next.shuffle();
// Training and validation data split
SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
DataSet trainSet = testAndTrain.getTrain();
DataSet validationSet = testAndTrain.getTest();
DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());
DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples());
// Create neural network.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.iterations(1)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.learningRate(learningRate)
.updater(Updater.NESTEROVS).momentum(0.9)
.list(2)
.layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
.weightInit(WeightInit.XAVIER)
.activation("relu")
.build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
.weightInit(WeightInit.XAVIER)
.activation("softmax")
.nIn(numHiddenNodes).nOut(numOutputs).build())
.pretrain(false).backprop(true).build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.setListeners(new ScoreIterationListener(1));
// Define when we want to stop training.
EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
.epochTerminationConditions(new MaxEpochsTerminationCondition(500), //Max of 500 epochs
new ScoreImprovementEpochTerminationCondition(25))
.evaluateEveryNEpochs(1)
.scoreCalculator(new DataSetLossCalculator(validationSetIterator, true)) //Calculate test set score
.modelSaver(saver)
.build();
EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);
// Train and display result.
EarlyStoppingResult result = trainer.fit();
System.out.println("Termination reason: " + result.getTerminationReason());
System.out.println("Termination details: " + result.getTerminationDetails());
System.out.println("Total epochs: " + result.getTotalEpochs());
System.out.println("Best epoch number: " + result.getBestModelEpoch());
System.out.println("Score at best epoch: " + result.getBestModelScore());
model = saver.getBestModel();
// Evaluate
Evaluation eval = new Evaluation(numOutputs);
validationSetIterator.reset();
for (int i = 0; i < validationSet.numExamples(); i++) {
DataSet t = validationSet.get(i);
INDArray features = t.getFeatureMatrix();
INDArray labels = t.getLabels();
INDArray predicted = model.output(features, false);
System.out.println(features + ":Prediction("+findSpecies(labels,species)
+"):Actual("+findSpecies(predicted,species)+")" + predicted );
eval.eval(labels, predicted);
}
//Print the evaluation statistics
System.out.println(eval.stats());
} catch(Exception ex) {
ex.printStackTrace();
}
}