Java source code examples: org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer
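
The snippets below are shown without their import statements. Most of them assume the usual DL4J and ND4J imports; a representative sketch (not exhaustive, and exact packages can shift slightly between DL4J versions):

import java.util.concurrent.TimeUnit;

import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;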

Example 1
@Test
public void testEarlyStoppingEveryNEpoch() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                                    .evaluateEveryNEpochs(2).modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);

    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
}
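
Once fit() returns, the best model tracked by the InMemoryModelSaver can be retrieved from the result (as Example 4 does) or from the saver itself; a minimal sketch:

MultiLayerNetwork bestModel = result.getBestModel();
// Equivalently via the saver; this accessor is declared to throw IOException:
// MultiLayerNetwork alsoBest = saver.getBestModel();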
 
Example 2
@Test
public void testBadTuning() {
    //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(5.0)) //Intentionally huge LR
                    .weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                                                    new MaxScoreIterationTerminationCondition(10)) //Initial score is ~2.5
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    EarlyStoppingResult result = trainer.fit();

    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                    result.getTerminationReason());
    String expDetails = new MaxScoreIterationTerminationCondition(10).toString();
    assertEquals(expDetails, result.getTerminationDetails());

    assertEquals(0, result.getBestModelEpoch());
    assertNotNull(result.getBestModel());
}
 
Example 3
@Test
public void testNoImprovementNEpochsTermination() {
    //Idea: terminate training if score (test set loss) does not improve for 5 consecutive epochs
    //Simulate this by setting LR = 0.0

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(100),
                                                    new ScoreImprovementEpochTerminationCondition(5))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                                                    new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    EarlyStoppingResult result = trainer.fit();

    //Expect no score change due to 0 LR -> terminate after 6 total epochs
    assertEquals(6, result.getTotalEpochs());
    assertEquals(0, result.getBestModelEpoch());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    String expDetails = new ScoreImprovementEpochTerminationCondition(5).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
 
Example 4
@Test
public void testEarlyStoppingGetBestModel() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter);

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, mIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);

    MultiLayerNetwork mln = result.getBestModel();

    assertEquals(net.getnLayers(), mln.getnLayers());
    assertEquals(net.conf().getOptimizationAlgo(), mln.conf().getOptimizationAlgo());
    BaseLayer bl = (BaseLayer) net.conf().getLayer();
    assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.conf().getLayer()).getActivationFn().toString());
    assertEquals(bl.getIUpdater(), ((BaseLayer) mln.conf().getLayer()).getIUpdater());
}
 
Example 5
@Test
public void testListeners() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                            .activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                                    .build();

    LoggingEarlyStoppingListener listener = new LoggingEarlyStoppingListener();

    IEarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, irisIter, listener);

    trainer.fit();

    assertEquals(1, listener.onStartCallCount);
    assertEquals(5, listener.onEpochCallCount);
    assertEquals(1, listener.onCompletionCallCount);
}
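
LoggingEarlyStoppingListener is a test helper whose definition is not included above. A minimal sketch, assuming DL4J's EarlyStoppingListener interface (field names follow the assertions; the real class presumably also logs each callback):

// Hypothetical sketch: counts each early stopping callback.
static class LoggingEarlyStoppingListener implements EarlyStoppingListener<MultiLayerNetwork> {
    int onStartCallCount;
    int onEpochCallCount;
    int onCompletionCallCount;

    @Override
    public void onStart(EarlyStoppingConfiguration<MultiLayerNetwork> esConfig, MultiLayerNetwork net) {
        onStartCallCount++;
    }

    @Override
    public void onEpoch(int epochNum, double score, EarlyStoppingConfiguration<MultiLayerNetwork> esConfig,
                    MultiLayerNetwork net) {
        onEpochCallCount++;
    }

    @Override
    public void onCompletion(EarlyStoppingResult<MultiLayerNetwork> esResult) {
        onCompletionCallCount++;
    }
}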
 
Example 6
@Test
public void testClassificationScoreFunctionSimple() throws Exception {

    for(Evaluation.Metric metric : Evaluation.Metric.values()) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(ds);
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new ClassificationScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
    }
}
 
Example 7
@Test
public void testEarlyStoppingListeners() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);

    TestListener tl = new TestListener();
    net.setListeners(tl);

    DataSetIterator irisIter = new IrisDataSetIterator(50, 150);
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculator(irisIter, true)).modelSaver(saver)
                    .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

    assertEquals(5, tl.countEpochStart);
    assertEquals(5, tl.countEpochEnd);
    assertEquals(5 * 150/50, tl.iterCount);

    assertEquals(4, tl.maxEpochStart);
    assertEquals(4, tl.maxEpochEnd);
}
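
TestListener is likewise a helper that is not shown. A plausible minimal sketch, assuming the BaseTrainingListener API of recent DL4J versions (older versions declare iterationDone(Model, int) without the epoch argument; the actual class in the DL4J test suite may differ):

// Hypothetical sketch: counts training callbacks so the assertions above can be verified.
static class TestListener extends BaseTrainingListener {
    int countEpochStart;
    int countEpochEnd;
    int iterCount;
    int maxEpochStart = -1;
    int maxEpochEnd = -1;

    @Override
    public void onEpochStart(Model model) {
        countEpochStart++;
        maxEpochStart = Math.max(maxEpochStart, countEpochStart - 1); // epoch numbers are 0-based
    }

    @Override
    public void onEpochEnd(Model model) {
        countEpochEnd++;
        maxEpochEnd = Math.max(maxEpochEnd, countEpochEnd - 1);
    }

    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        iterCount++;
    }
}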
 
Example 8
public void train() throws IOException {

        MnistDataSetIterator mnistTrain = new MnistDataSetIterator(MINI_BATCH_SIZE, true, 12345);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(SEED)
                .learningRate(LEARNING_RATE)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(Updater.NESTEROVS)
                .list()
                .layer(0, new ConvolutionLayer.Builder(5, 5)
                        .nIn(CHANNELS)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(2, new ConvolutionLayer.Builder(5, 5)
                        .nIn(20)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(4, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(800)
                        .nOut(128).build())
                .layer(5, new DenseLayer.Builder().activation(Activation.RELU)
                        .nIn(128)
                        .nOut(64).build())
                .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(OUTPUT)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(28, 28, 1))
                .backprop(true).pretrain(false).build();

        EarlyStoppingConfiguration earlyStoppingConfiguration = new EarlyStoppingConfiguration.Builder()
                .epochTerminationConditions(new MaxEpochsTerminationCondition(MAX_EPOCHS))
                .scoreCalculator(new AccuracyCalculator(new MnistDataSetIterator(MINI_BATCH_SIZE, false, 12345)))
                .evaluateEveryNEpochs(1)
                .modelSaver(new LocalFileModelSaver(OUT_DIR))
                .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(earlyStoppingConfiguration, conf, mnistTrain);

        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        log.info("Termination reason: " + result.getTerminationReason());
        log.info("Termination details: " + result.getTerminationDetails());
        log.info("Total epochs: " + result.getTotalEpochs());
        log.info("Best epoch number: " + result.getBestModelEpoch());
        log.info("Score at best epoch: " + result.getBestModelScore());
    }
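
This example depends on project constants (MINI_BATCH_SIZE, SEED, LEARNING_RATE, CHANNELS, OUTPUT, MAX_EPOCHS, OUT_DIR) and on an AccuracyCalculator, which is not part of DL4J core. A minimal sketch of such a calculator, assuming the ScoreCalculator interface (newer DL4J versions additionally require a minimizeScore() method):

// Hypothetical sketch: scores a model as (1 - accuracy) on a held-out iterator,
// so that lower scores are better, matching the early stopping convention.
public class AccuracyCalculator implements ScoreCalculator<MultiLayerNetwork> {
    private final DataSetIterator iterator;

    public AccuracyCalculator(DataSetIterator iterator) {
        this.iterator = iterator;
    }

    @Override
    public double calculateScore(MultiLayerNetwork network) {
        Evaluation eval = network.evaluate(iterator);
        iterator.reset();
        return 1.0 - eval.accuracy();
    }

    // On newer DL4J versions, also: @Override public boolean minimizeScore() { return true; }
}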
 
Example 9
public static void main(String... args) throws java.io.IOException {
    // create the data iterators for emnist
    DataSetIterator emnistTrain = new EmnistDataSetIterator(emnistSet, batchSize, true);
    DataSetIterator emnistTest = new EmnistDataSetIterator(emnistSet, batchSize, false);

    int outputNum = EmnistDataSetIterator.numLabels(emnistSet);

    // network configuration (not yet initialized)
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Adam())
            .l2(1e-4)
            .list()
            .layer(new DenseLayer.Builder()
                    .nIn(numRows * numColumns) // Number of input datapoints.
                    .nOut(1000) // Number of output datapoints.
                    .activation(Activation.RELU) // Activation function.
                    .weightInit(WeightInit.XAVIER) // Weight initialization.
                    .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(1000)
                    .nOut(outputNum)
                    .activation(Activation.SOFTMAX)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .pretrain(false).backprop(true)
            .build();

    // create the MLN
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    // pass a training listener that reports score every N iterations
    network.addListeners(new ScoreIterationListener(reportingInterval));

    // here we set up an early stopping trainer
    // early stopping is useful when your trainer runs for
    // a long time or you need to programmatically stop training
    EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(20, TimeUnit.MINUTES))
            .scoreCalculator(new DataSetLossCalculator(emnistTest, true))
            .evaluateEveryNEpochs(1)
            .modelSaver(new LocalFileModelSaver(System.getProperty("user.dir")))
            .build();

    // training
    EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, network, emnistTrain);
    EarlyStoppingResult result = trainer.fit();

    // print out early stopping results
    System.out.println("Termination reason: " + result.getTerminationReason());
    System.out.println("Termination details: " + result.getTerminationDetails());
    System.out.println("Total epochs: " + result.getTotalEpochs());
    System.out.println("Best epoch number: " + result.getBestModelEpoch());
    System.out.println("Score at best epoch: " + result.getBestModelScore());

    // evaluate basic performance
    Evaluation eval = network.evaluate(emnistTest);
    System.out.println(eval.accuracy());
    System.out.println(eval.precision());
    System.out.println(eval.recall());

    // evaluate ROC and calculate the Area Under Curve
    ROCMultiClass roc = network.evaluateROCMultiClass(emnistTest);
    System.out.println(roc.calculateAverageAUC());

    // calculate AUC for a single class
    int classIndex = 0;
    System.out.println(roc.calculateAUC(classIndex));

    // optionally, you can print all stats from the evaluations
    System.out.println(eval.stats());
    System.out.println(roc.stats());
}
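
This example references fields defined outside the method (emnistSet, batchSize, rngSeed, numRows, numColumns, reportingInterval). Plausible values, following the DL4J EMNIST quickstart; these are assumptions, not the original source's values:

// Assumed values; the original source may differ.
private static final EmnistDataSetIterator.Set emnistSet = EmnistDataSetIterator.Set.BALANCED;
private static final int batchSize = 128;
private static final int rngSeed = 123;
private static final int numRows = 28;      // EMNIST images are 28 x 28
private static final int numColumns = 28;
private static final int reportingInterval = 5;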
 
Example 10
@Test
public void testTimeTermination() {
    //test termination after max time

    Nd4j.getRandom().setSeed(12345);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).list()
                    .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                                .activation(Activation.SOFTMAX)
                                .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);

    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                    new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                                    .epochTerminationConditions(new MaxEpochsTerminationCondition(10000))
                                    .iterationTerminationConditions(
                                                    new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS),
                                                    new MaxScoreIterationTerminationCondition(50)) //Initial score is ~8
                                    .scoreCalculator(new DataSetLossCalculator(irisIter, true))
                                    .modelSaver(saver).build();

    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf, net, irisIter);
    long startTime = System.currentTimeMillis();
    EarlyStoppingResult result = trainer.fit();
    long endTime = System.currentTimeMillis();
    int durationSeconds = (int) ((endTime - startTime) / 1000);

    assertTrue(durationSeconds >= 3);
    assertTrue(durationSeconds <= 12);

    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition,
                    result.getTerminationReason());
    String expDetails = new MaxTimeIterationTerminationCondition(3, TimeUnit.SECONDS).toString();
    assertEquals(expDetails, result.getTerminationDetails());
}
 
Example 11
@Test
public void testRegressionScoreFunctionSimple() throws Exception {

    for(Metric metric : new Metric[]{Metric.MSE, Metric.MAE}) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(32).build())
                .layer(new OutputLayer.Builder().nIn(32).nOut(784).activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new RegressionScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example 12
@Test
public void testAEScoreFunctionSimple() throws Exception {

    for(Metric metric : new Metric[]{Metric.MSE, Metric.MAE}) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new AutoEncoder.Builder().nIn(784).nOut(32).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new AutoencoderScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
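
Note that this example (and Examples 13 and 14 below) finishes with trainer.pretrain() rather than trainer.fit(): pretrain() drives unsupervised, layer-wise training under the early stopping configuration, which is the appropriate mode for AutoEncoder and VariationalAutoencoder layers.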
 
Example 13
@Test
public void testVAEScoreFunctionSimple() throws Exception {

    for(Metric metric : new Metric[]{Metric.MSE, Metric.MAE}) {
        log.info("Metric: " + metric);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)
                        .encoderLayerSizes(64)
                        .decoderLayerSizes(64)
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for( int i=0; i<10; i++ ){
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new VAEReconErrorScoreCalculator(metric, iter)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example 14
@Test
public void testVAEScoreFunctionReconstructionProbSimple() throws Exception {

    for(boolean logProb : new boolean[]{false, true}) {

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new VariationalAutoencoder.Builder()
                        .nIn(784).nOut(32)
                        .encoderLayerSizes(64)
                        .decoderLayerSizes(64)
                        .reconstructionDistribution(new BernoulliReconstructionDistribution(Activation.SIGMOID))
                        .build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        DataSetIterator iter = new MnistDataSetIterator(32, false, 12345);

        List<DataSet> l = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            DataSet ds = iter.next();
            l.add(new DataSet(ds.getFeatures(), ds.getFeatures()));
        }

        iter = new ExistingDataSetIterator(l);

        EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
        EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
                new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                        .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                        .iterationTerminationConditions(
                                new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                        .scoreCalculator(new VAEReconProbScoreCalculator(iter, 20, logProb)).modelSaver(saver)
                        .build();

        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, net, iter);
        EarlyStoppingResult<MultiLayerNetwork> result = trainer.pretrain();

        assertNotNull(result.getBestModel());
        assertTrue(result.getBestModelScore() > 0.0);
    }
}
 
Example 15
@Test
public void testEarlyStoppingMaximizeScore() throws Exception {
    Nd4j.getRandom().setSeed(12345);

    int outputs = 2;

    DataSet ds = new DataSet(
            Nd4j.rand(new int[]{3, 10, 50}),
            TestUtils.randomOneHotTimeSeries(3, outputs, 50, 12345));
    DataSetIterator train = new ExistingDataSetIterator(
            Arrays.asList(ds, ds, ds, ds, ds, ds, ds, ds, ds, ds));
    DataSetIterator test = new SingletonDataSetIterator(ds);


    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(123)
            .weightInit(WeightInit.XAVIER)
            .updater(new Adam(0.1))
            .activation(Activation.ELU)
            .l2(1e-5)
            .gradientNormalization(GradientNormalization
                    .ClipElementWiseAbsoluteValue)
            .gradientNormalizationThreshold(1.0)
            .list()
            .layer(0, new LSTM.Builder()
                    .nIn(10)
                    .nOut(10)
                    .activation(Activation.TANH)
                    .gateActivationFunction(Activation.SIGMOID)
                    .dropOut(0.5)
                    .build())
            .layer(1, new RnnOutputLayer.Builder()
                    .nIn(10)
                    .nOut(outputs)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .build();

    File f = testDir.newFolder();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new LocalFileModelSaver(f.getAbsolutePath());
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf =
            new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
                    .epochTerminationConditions(
                            new MaxEpochsTerminationCondition(10),
                            new ScoreImprovementEpochTerminationCondition(1))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(10, TimeUnit.MINUTES))
                    .scoreCalculator(new ClassificationScoreCalculator(Evaluation.Metric.F1, test))
                    .modelSaver(saver)
                    .saveLastModel(true)
                    .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    EarlyStoppingTrainer t = new EarlyStoppingTrainer(esConf, net, train);
    EarlyStoppingResult<MultiLayerNetwork> result = t.fit();

    Map<Integer,Double> map = result.getScoreVsEpoch();
    for( int i=1; i<map.size(); i++ ){
        if(i == map.size() - 1){
            assertTrue(map.get(i) <= map.get(i-1));
        } else {
            assertTrue(map.get(i) > map.get(i-1));
        }
    }
}
 
Example 16
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 200;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .regularization(true).dropOut(0.50)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example 17
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 1e-2;
        int nEpochs = 50;
        int batchSize = 500;

        // Setup training data.
        System.out.println("Please wait, reading MNIST training data.");
        String dir = System.getProperty("user.dir");
        MNISTReader trainingReader = MNIST.loadMNIST(dir, true);
        MNISTReader validationReader = MNIST.loadMNIST(dir, false);

        DataSet trainingSet = trainingReader.getData();
        DataSet validationSet = validationReader.getData();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainingSet.asList(), batchSize);
        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationReader.getNumRows());

        System.out.println("Training set size: " + trainingReader.getNumImages());
        System.out.println("Validation set size: " + validationReader.getNumImages());

        System.out.println(trainingSet.get(0).getFeatures().size(1));
        System.out.println(validationSet.get(0).getFeatures().size(1));

        int numInputs = trainingReader.getNumCols()*trainingReader.getNumRows();
        int numOutputs = 10;
        int numHiddenNodes = 100;

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //.epochTerminationConditions(new MaxEpochsTerminationCondition(10))
                .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(5))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }

}
 
Example 18
/**
 * The main method.
 * @param args Not used.
 */
public static void main(String[] args) {
    try {
        int seed = 43;
        double learningRate = 0.1;
        int splitTrainNum = (int) (150 * .75);

        int numInputs = 4;
        int numOutputs = 3;
        int numHiddenNodes = 50;

        // Setup training data.
        final InputStream istream = LearnIrisBackprop.class.getResourceAsStream("/iris.csv");
        if( istream==null ) {
            System.out.println("Cannot access data set, make sure the resources are available.");
            System.exit(1);
        }
        final NormalizeDataSet ds = NormalizeDataSet.load(istream);
        final CategoryMap species = ds.encodeOneOfN(4); // species is column 4
        istream.close();

        DataSet next = ds.extractSupervised(0, 4, 4, 3);
        next.shuffle();

        // Training and validation data split
        SplitTestAndTrain testAndTrain = next.splitTestAndTrain(splitTrainNum, new Random(seed));
        DataSet trainSet = testAndTrain.getTrain();
        DataSet validationSet = testAndTrain.getTest();

        DataSetIterator trainSetIterator = new ListDataSetIterator(trainSet.asList(), trainSet.numExamples());

        DataSetIterator validationSetIterator = new ListDataSetIterator(validationSet.asList(), validationSet.numExamples());

        // Create neural network.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .iterations(1)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .learningRate(learningRate)
                .updater(Updater.NESTEROVS).momentum(0.9)
                .list(2)
                .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                        .weightInit(WeightInit.XAVIER)
                        .activation("relu")
                        .build())
                .layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                        .weightInit(WeightInit.XAVIER)
                        .activation("softmax")
                        .nIn(numHiddenNodes).nOut(numOutputs).build())
                .pretrain(false).backprop(true).build();


        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        model.setListeners(new ScoreIterationListener(1));

        // Define when we want to stop training.
        EarlyStoppingModelSaver saver = new InMemoryModelSaver();
        EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder()
                //Both conditions in one call: epochTerminationConditions(...) replaces any previously set list
                .epochTerminationConditions(new MaxEpochsTerminationCondition(500), //Max of 500 epochs
                        new ScoreImprovementEpochTerminationCondition(25))
                .evaluateEveryNEpochs(1)
                .scoreCalculator(new DataSetLossCalculator(validationSetIterator, true))     //Calculate test set score
                .modelSaver(saver)
                .build();
        EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConf, conf, trainSetIterator);

        // Train and display result.
        EarlyStoppingResult result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        model = saver.getBestModel();

        // Evaluate
        Evaluation eval = new Evaluation(numOutputs);
        validationSetIterator.reset();

        for (int i = 0; i < validationSet.numExamples(); i++) {
            DataSet t = validationSet.get(i);
            INDArray features = t.getFeatureMatrix();
            INDArray labels = t.getLabels();
            INDArray predicted = model.output(features, false);
            System.out.println(features + ":Prediction("+findSpecies(predicted,species)
                    +"):Actual("+findSpecies(labels,species)+")" + predicted );
            eval.eval(labels, predicted);
        }

        //Print the evaluation statistics
        System.out.println(eval.stats());
    } catch(Exception ex) {
        ex.printStackTrace();
    }
}