Java源码示例:org.datavec.api.split.InputStreamInputSplit

示例1
/**
 * Builds the record iterator for the current {@code inputSplit}.
 *
 * <p>For an {@code InputStreamInputSplit} the split's own stream is read. For any other split
 * {@code inputSplit.locations()} is stored in {@code this.locations}, the URI at index
 * {@code location} is opened, and {@code onLocationOpen} is invoked with that URI. In both cases
 * the bytes are decoded explicitly as UTF-8 (instead of the JVM's platform default charset) and
 * passed to {@code lineIterator}.
 *
 * @param location index into {@code inputSplit.locations()}; not used for stream-backed splits
 * @return the iterator; never {@code null}
 * @throws RuntimeException wrapping any {@link IOException} raised while opening the location
 *         or by {@code onLocationOpen}
 * @throws UnsupportedOperationException if the split provides neither a stream nor any location
 */
protected Iterator<List<Writable>> getIterator(int location) {
    Iterator<List<Writable>> iterator = null;

    if (inputSplit instanceof InputStreamInputSplit) {
        InputStream is = ((InputStreamInputSplit) inputSplit).getIs();
        if (is != null) {
            iterator = lineIterator(new InputStreamReader(is, java.nio.charset.StandardCharsets.UTF_8));
        }
    } else {
        this.locations = inputSplit.locations();
        if (locations != null && locations.length > 0) {
            InputStream inputStream;
            try {
                inputStream = locations[location].toURL().openStream();
                onLocationOpen(locations[location]);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            iterator = lineIterator(new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8));
        }
    }
    if (iterator == null) {
        throw new UnsupportedOperationException("Unknown input split: " + inputSplit);
    }
    return iterator;
}
 
示例2
/**
 * Reads the iris CSV from {@code dataFile} and prepares the training and test sets.
 *
 * <p>The rows are loaded into a single {@code DataSet}, shuffled, split 80/20 into
 * {@code trainingData} / {@code testData}, and both sets are standardized (zero mean, unit
 * variance) using statistics fitted on the training portion only.
 *
 * <p>NOTE(review): only the first batch returned by the iterator is used, so this presumably
 * relies on {@code batchSize} being at least the number of rows in the file. Confirm.
 */
private void createDataSource() throws IOException, InterruptedException {
    // Step 1: CSVRecordReader loads and parses the raw rows (no header lines skipped, comma-delimited).
    RecordReader csvReader = new CSVRecordReader(0, ",");
    csvReader.initialize(new InputStreamInputSplit(dataFile));

    // Step 2: convert rows to a DataSet. Each iris.txt row holds 4 input features followed by
    // an integer class label (0, 1 or 2) at column index 4; there are 3 classes.
    final int labelColumn = 4;
    final int classCount = 3;
    DataSetIterator batches = new RecordReaderDataSetIterator(csvReader, batchSize, labelColumn, classCount);
    DataSet everything = batches.next();
    everything.shuffle();

    // Step 3: 80% of the examples go to training, the remainder is held out for testing.
    SplitTestAndTrain split = everything.splitTestAndTrain(0.80);
    trainingData = split.getTrain();
    testData = split.getTest();

    // Step 4: standardize with NormalizerStandardize. fit() only collects mean/stdev from the
    // training data (it does not modify it); transform() then applies those *training*
    // statistics to both sets.
    DataNormalization standardizer = new NormalizerStandardize();
    standardizer.fit(trainingData);
    standardizer.transform(trainingData);
    standardizer.transform(testData);
}
 
示例3
/**
 * Reads the CSV in {@code dataFile} as a regression dataset and prepares train/test sets.
 *
 * <p>Column 11 is used as the label. The loaded rows are shuffled, split 80/20 into
 * {@code trainingData} / {@code testData}, and both sets are standardized (zero mean, unit
 * variance) using statistics fitted on the training portion only.
 *
 * <p>NOTE(review): only the first batch returned by the iterator is used, so this presumably
 * relies on {@code batchSize} being at least the number of rows in the file. Confirm.
 */
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
    int labelIndex = 11;

    // Label range is the single column 11; the trailing argument is presumably the
    // "regression" flag of RecordReaderDataSetIterator (confirm against its signature).
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();
    // Shuffle before splitting so the held-out 20% is not just the tail of the file in file
    // order (assumes splitTestAndTrain(double) does not shuffle on its own; an extra shuffle
    // is harmless if it does).
    allData.shuffle();

    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);     //Apply normalization to the training data
    normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
 
示例4
/**
 * Times reading one record and converting it to an {@link INDArray}, and attaches device-0 bandwidth.
 *
 * <p>The record is read from {@code inputStream}. The bandwidth figures are the ones that
 * {@code PerformanceTracker} currently reports for device 0.
 *
 * @param reader the record reader to use; it is (re)initialized on {@code inputStream}
 * @param inputStream the stream the single record is read from
 * @param function converts the record that was read into an {@link INDArray}
 * @return statistics holding the record-read time and the array-creation time (both elapsed
 *         {@link System#nanoTime()} differences), plus the host-to-device and device-to-host
 *         bandwidth values for device 0
 * @throws Exception if initializing the reader, reading the record, or the conversion fails
 */
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
                                                   InputStream inputStream,
                                                   INDArrayCreationFunction function) throws Exception {
    reader.initialize(new InputStreamInputSplit(inputStream));

    // Time the ETL step: reading and parsing a single record.
    long readStartNanos = System.nanoTime();
    List<Writable> next = reader.next();
    long etlDiff = System.nanoTime() - readStartNanos;

    // Time the conversion of that record; the resulting array itself is not needed.
    long creationStartNanos = System.nanoTime();
    function.createFromRecord(next);
    long endCreationDiff = System.nanoTime() - creationStartNanos;

    Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
    // NOTE(review): this would throw NullPointerException if there is no entry for device 0.
    Map<MemcpyDirection, Long> device0Bandwidth = currentBandwidth.get(0);
    Long hostToDevice = device0Bandwidth.get(MemcpyDirection.HOST_TO_DEVICE);
    // Must read DEVICE_TO_HOST so that bandwidthDeviceToHost reports the reverse direction.
    Long deviceToHost = device0Bandwidth.get(MemcpyDirection.DEVICE_TO_HOST);

    return TimingStatistics.builder()
            .diskReadingTimeNanos(etlDiff)
            .bandwidthNanosHostToDevice(hostToDevice)
            .bandwidthDeviceToHost(deviceToHost)
            .ndarrayCreationTimeNanos(endCreationDiff)
            .build();
}
 
示例5
/**
 * Builds the line iterator for the current {@code inputSplit}.
 *
 * <p>The source depends on the split type:
 * <ul>
 *   <li>A {@code StringSplit} yields its data as a single element.</li>
 *   <li>An {@code InputStreamInputSplit} is read through {@code IOUtils.lineIterator} from its
 *       stream.</li>
 *   <li>Any other split stores {@code inputSplit.locations()} in {@code this.locations} and is
 *       read through {@code IOUtils.lineIterator} from the URI at index {@code location}.</li>
 * </ul>
 * Stream contents are decoded explicitly as UTF-8 instead of the JVM's platform default charset.
 *
 * @param location index into {@code inputSplit.locations()}; only used for location-based splits
 * @return the iterator; never {@code null}
 * @throws RuntimeException wrapping an {@link IOException} if the location cannot be opened
 * @throws UnsupportedOperationException if no iterator could be built for the split
 */
protected Iterator<String> getIterator(int location) {
    Iterator<String> iterator = null;
    if (inputSplit instanceof StringSplit) {
        StringSplit stringSplit = (StringSplit) inputSplit;
        iterator = Collections.singletonList(stringSplit.getData()).listIterator();
    } else if (inputSplit instanceof InputStreamInputSplit) {
        InputStream is = ((InputStreamInputSplit) inputSplit).getIs();
        if (is != null) {
            iterator = IOUtils.lineIterator(new InputStreamReader(is, java.nio.charset.StandardCharsets.UTF_8));
        }
    } else {
        this.locations = inputSplit.locations();
        if (locations != null && locations.length > 0) {
            InputStream inputStream;
            try {
                inputStream = locations[location].toURL().openStream();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            iterator = IOUtils.lineIterator(new InputStreamReader(inputStream, java.nio.charset.StandardCharsets.UTF_8));
        }
    }
    if (iterator == null) {
        throw new UnsupportedOperationException("Unknown input split: " + inputSplit);
    }
    return iterator;
}
 
示例6
/**
 * Tests reading and resetting a stream-backed CSV reader.
 *
 * <p>It reads all 150 records of iris.dat through a stream-backed split. It then checks that
 * the reader reports reset as unsupported and that calling {@code reset()} throws.
 */
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("iris.dat").getInputStream()));

    int count = 0;
    while (rr.hasNext()) {
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try {
        rr.reset();
        fail("Expected exception");
    } catch (Exception expected) {
        // Expected: resetSupported() is false for this stream-backed reader. fail() throws an
        // AssertionError (an Error, not an Exception), so a missing exception is not swallowed here.
    }
}
 
示例7
/**
 * Tests that resetting a stream-backed CSV reader fails with the expected messages.
 *
 * <p>It reads all 150 records of datavec-api/iris.dat through a stream-backed split. It then
 * checks that {@code reset()} is reported as unsupported and fails. The failure must carry the
 * expected message, and its cause must carry the expected message too.
 */
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));

    int count = 0;
    while (rr.hasNext()) {
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try {
        rr.reset();
        fail("Expected exception");
    } catch (Exception e) {
        String msg = e.getMessage();
        assertNotNull("Reset failure should have a message", msg);
        assertTrue(msg, msg.contains("Error during LineRecordReader reset"));

        // Check the cause explicitly so a missing cause fails with an assertion, not an NPE.
        Throwable cause = e.getCause();
        assertNotNull("Reset failure should carry a cause", cause);
        String msg2 = cause.getMessage();
        assertNotNull("Reset failure cause should have a message", msg2);
        assertTrue(msg2, msg2.contains("Reset not supported from streams"));
    }
}
 
示例8
/**
 * Tests that LineRecordReader reads a gzip-compressed stream through InputStreamInputSplit.
 *
 * <p>Nine single-value lines are written to a gzip file. The file is read back through a
 * {@code GZIPInputStream} wrapped in an {@code InputStreamInputSplit}. The test checks that
 * {@code LineRecordReader} yields nine one-element records.
 */
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
    File tmpdir = testDir.newFolder();

    File tmp1 = new File(tmpdir, "tmp1.txt.gz");

    // try-with-resources closes the stream even if writing fails; close() on the gzip stream
    // also finishes the compressed output, so no separate flush() is needed.
    try (OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false))) {
        IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
    }

    // The split does not own the stream, so close it here once reading is done.
    try (GZIPInputStream in = new GZIPInputStream(new FileInputStream(tmp1))) {
        InputSplit split = new InputStreamInputSplit(in);

        RecordReader reader = new LineRecordReader();
        reader.initialize(split);

        int count = 0;
        while (reader.hasNext()) {
            assertEquals(1, reader.next().size());
            count++;
        }

        assertEquals(9, count);
    }
}
 
示例9
/**
 * Converts the output of {@code sc.binaryFiles} into labeled points for machine learning.
 *
 * <p>For every (path, stream) pair, the stream is opened and {@code reader} is re-initialized
 * on it, tagged with the path. The first record is read, and the stream is closed again. Each
 * record is then turned into a {@link LabeledPoint} via {@code pointOf}.
 *
 * @param binaryFiles the (path, stream) pairs to convert
 * @param reader the reader used to parse each file; it is re-initialized for every file
 * @return the labeled points built from the given RDD
 */
public static JavaRDD<LabeledPoint> fromBinary(JavaPairRDD<String, PortableDataStream> binaryFiles,
                final RecordReader reader) {
    JavaRDD<Collection<Writable>> records =
                    binaryFiles.map(new Function<Tuple2<String, PortableDataStream>, Collection<Writable>>() {
                        @Override
                        public Collection<Writable> call(
                                        Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2)
                                        throws Exception {
                            // Per the Spark PortableDataStream.open() docs, the caller must close
                            // the returned stream; without this, one handle leaks per file.
                            try (java.io.InputStream in = stringPortableDataStreamTuple2._2().open()) {
                                reader.initialize(new InputStreamInputSplit(in,
                                                stringPortableDataStreamTuple2._1()));
                                return reader.next();
                            }
                        }
                    });

    JavaRDD<LabeledPoint> ret = records.map(new Function<Collection<Writable>, LabeledPoint>() {
        @Override
        public LabeledPoint call(Collection<Writable> writables) throws Exception {
            return pointOf(writables);
        }
    });
    return ret;
}
 
示例10
/**
 * Decodes the bytes of {@code input} into an Arrow record batch.
 *
 * <p>The bytes are wrapped in an in-memory stream and read through a fresh
 * {@link ArrowRecordReader}. The reader's current batch is returned after one {@code next()}
 * call. {@code parameters} and {@code contextData} are not used.
 *
 * <p>NOTE(review): the reader is intentionally not closed here; closing it may release the
 * batch that is returned. Confirm before adding a close().
 */
@Override
public ArrowWritableRecordBatch convert(Buffer input, ConverterArgs parameters, Map<String, Object> contextData) {
    byte[] payload = input.getBytes();
    InputStreamInputSplit split = new InputStreamInputSplit(new ByteArrayInputStream(payload));

    ArrowRecordReader reader = new ArrowRecordReader();
    reader.initialize(split);
    // The returned record is discarded; next() is called so that the reader has a current batch.
    reader.next();

    ArrowWritableRecordBatch batch = reader.getCurrentBatch();
    return batch;
}
 
示例11
/**
 * Reports whether another record is available.
 *
 * <p>For an {@code InputStreamInputSplit} the {@code finishedInputStreamSplit} flag is returned
 * as-is. NOTE(review): returning a flag named "finished" from hasNext() looks inverted (one would
 * expect {@code !finishedInputStreamSplit}). Confirm how next()/initialize() set it. For other
 * splits, the file iterator decides if one exists; otherwise, when a record is present, the
 * result is {@code !hitImage}.
 *
 * @throws IllegalStateException if there is neither a file iterator nor a record
 */
@Override
public boolean hasNext() {
    boolean streamBacked = inputSplit instanceof InputStreamInputSplit;
    if (streamBacked) {
        return finishedInputStreamSplit;
    }

    boolean haveIterator = iter != null;
    if (!haveIterator && record == null) {
        throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
    }
    return haveIterator ? iter.hasNext() : !hitImage;
}
 
示例12
/**
 * Tests that LineRecordReader reads a gzip-compressed stream through InputStreamInputSplit.
 *
 * <p>Nine single-value lines are written to a gzip file in a fresh temporary directory. The file
 * is read back through a {@code GZIPInputStream} wrapped in an {@code InputStreamInputSplit}.
 * The test checks that {@code LineRecordReader} yields nine one-element records. The temporary
 * directory is removed afterwards, whether or not the test passes.
 */
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
    // A unique directory per run: a fixed "tmpdir" under java.io.tmpdir could collide with other
    // runs, and this test deletes the whole directory afterwards.
    File tmpdir = java.nio.file.Files.createTempDirectory("tmpdir").toFile();
    try {
        File tmp1 = new File(tmpdir, "tmp1.txt.gz");

        // close() on the gzip stream also finishes the compressed output.
        try (OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false))) {
            IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
        }

        // Close the input stream explicitly so the file handle is released before cleanup.
        try (GZIPInputStream in = new GZIPInputStream(new FileInputStream(tmp1))) {
            InputSplit split = new InputStreamInputSplit(in);

            RecordReader reader = new LineRecordReader();
            reader.initialize(split);

            int count = 0;
            while (reader.hasNext()) {
                assertEquals(1, reader.next().size());
                count++;
            }

            assertEquals(9, count);
        }
    } finally {
        // Best-effort cleanup: a deletion failure is reported but must not mask the test result.
        try {
            FileUtils.deleteDirectory(tmpdir);
        } catch (Exception e) {
            System.err.println("Could not delete temporary directory " + tmpdir + ": " + e);
        }
    }
}
 
示例13
/**
 * Reports whether another record is available.
 *
 * <p>Stream-backed splits ({@code InputStreamInputSplit}) answer with the
 * {@code finishedInputStreamSplit} flag unchanged. NOTE(review): a hasNext() that returns a
 * "finished" flag looks inverted. Verify how the flag is set before relying on it. Otherwise
 * the file iterator is consulted when present; failing that, a present record yields
 * {@code !hitImage}.
 *
 * @throws IllegalStateException if neither a file iterator nor a record exists
 */
@Override
public boolean hasNext() {
    if (inputSplit instanceof InputStreamInputSplit) {
        return finishedInputStreamSplit;
    }
    if (iter == null) {
        if (record == null) {
            throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
        }
        return !hitImage;
    }
    return iter.hasNext();
}
 
示例14
/**
 * Initializes the reader on {@code split}.
 *
 * <p>After delegating to {@code super.initialize}, the current {@code inputSplit} is checked.
 * If it is neither a {@code StringSplit} nor an {@code InputStreamInputSplit}, every URI from its
 * {@code locationsIterator()} is copied into {@code locations}. The record iterator is then
 * obtained for location 0, and the reader is marked initialized.
 */
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
    super.initialize(split);

    boolean locationBased = !(inputSplit instanceof StringSplit) && !(inputSplit instanceof InputStreamInputSplit);
    if (locationBased) {
        ArrayList<URI> collected = new ArrayList<>();
        for (Iterator<URI> it = inputSplit.locationsIterator(); it.hasNext(); ) {
            collected.add(it.next());
        }
        this.locations = collected.toArray(new URI[0]);
    }

    this.iter = getIterator(0);
    this.initialized = true;
}
 
示例15
/**
 * Tests a RecordReaderDataSetIterator built on a stream-backed CSV reader.
 *
 * <p>The check is run once with the plain constructor and once with the Builder. The reader and
 * the iterator must both report reset as unsupported. Iterating with batch size 1 must yield 150
 * data sets, and calling {@code reset()} afterwards must throw.
 */
@Test
public void testReadingFromStream() throws Exception {
    final int batchSize = 1;
    final int labelIndex = 4;
    final int numClasses = 3;

    for (boolean useBuilder : new boolean[]{false, true}) {
        // Each iteration opens a fresh classpath stream; close it when the iteration ends.
        try (InputStream dataFile = Resources.asStream("iris.txt")) {
            RecordReader recordReader = new CSVRecordReader(0, ',');
            recordReader.initialize(new InputStreamInputSplit(dataFile));

            assertTrue(recordReader.hasNext());
            assertFalse(recordReader.resetSupported());

            DataSetIterator iterator;
            if (useBuilder) {
                iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize)
                        .classification(labelIndex, numClasses)
                        .build();
            } else {
                iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
            }
            assertFalse(iterator.resetSupported());

            int count = 0;
            while (iterator.hasNext()) {
                assertNotNull(iterator.next());
                count++;
            }

            assertEquals(150, count);

            try {
                iterator.reset();
                fail("Expected exception");
            } catch (Exception expected) {
                // Expected: reset is unsupported for a stream-backed iterator. fail() throws an
                // AssertionError, which this catch does not intercept.
            }
        }
    }
}