Java source code examples: org.datavec.api.split.InputStreamInputSplit
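InputStreamInputSplit wraps a plain InputStream (optionally together with a path) so that it can be passed to RecordReader.initialize(...). Because a stream can only be consumed once, readers initialized this way do not support reset() (see Examples 6, 7 and 14 below). As a minimal sketch of the basic pattern, assuming an illustrative file name "data.csv" and a wrapper class name of our own choosing:

import java.io.FileInputStream;
import java.io.InputStream;
import java.util.List;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;

public class InputStreamInputSplitSketch {
    public static void main(String[] args) throws Exception {
        // "data.csv" is an illustrative file name, not taken from the examples below.
        try (InputStream is = new FileInputStream("data.csv")) {
            RecordReader reader = new CSVRecordReader(0, ',');
            // Wrap the raw stream so the reader can consume it as an InputSplit.
            reader.initialize(new InputStreamInputSplit(is));
            while (reader.hasNext()) {
                List<Writable> record = reader.next();   // one parsed CSV row
                System.out.println(record);
            }
            // reset() would throw here: a stream-backed split cannot be rewound.
        }
    }
}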
Example 1
protected Iterator<List<Writable>> getIterator(int location) {
    Iterator<List<Writable>> iterator = null;
    if (inputSplit instanceof InputStreamInputSplit) {
        InputStream is = ((InputStreamInputSplit) inputSplit).getIs();
        if (is != null) {
            iterator = lineIterator(new InputStreamReader(is));
        }
    } else {
        this.locations = inputSplit.locations();
        if (locations != null && locations.length > 0) {
            InputStream inputStream;
            try {
                inputStream = locations[location].toURL().openStream();
                onLocationOpen(locations[location]);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            iterator = lineIterator(new InputStreamReader(inputStream));
        }
    }
    if (iterator == null)
        throw new UnsupportedOperationException("Unknown input split: " + inputSplit);
    return iterator;
}
Example 2
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 4;  //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
    int numClasses = 3;  //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
    DataSet allData = iterator.next();
    allData.shuffle();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);  //Apply normalization to the training data
    normalizer.transform(testData);      //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
Example 3
private void createDataSource() throws IOException, InterruptedException {
    //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
    int numLinesToSkip = 0;
    String delimiter = ",";
    RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
    recordReader.initialize(new InputStreamInputSplit(dataFile));

    //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in a neural network
    int labelIndex = 11;
    //Regression constructor: the label is the single column at labelIndex (from == to), regression = true
    DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, labelIndex, true);
    DataSet allData = iterator.next();
    SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.80);  //Use 80% of data for training

    trainingData = testAndTrain.getTrain();
    testData = testAndTrain.getTest();

    //We need to normalize our data. We'll use NormalizerStandardize (which gives us mean 0, unit variance):
    DataNormalization normalizer = new NormalizerStandardize();
    normalizer.fit(trainingData);        //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
    normalizer.transform(trainingData);  //Apply normalization to the training data
    normalizer.transform(testData);      //Apply normalization to the test data. This is using statistics calculated from the *training* set
}
Example 4
/**
 * Times reading a single record from the given reader and creating an INDArray from it.
 *
 * @param reader      the record reader to read the record with
 * @param inputStream the input stream to initialize the reader from
 * @param function    the function used to create an INDArray from the record
 * @return the timing statistics (disk read time, memcpy bandwidth, array creation time)
 * @throws Exception if initialization or reading fails
 */
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
                                                   InputStream inputStream,
                                                   INDArrayCreationFunction function) throws Exception {
    reader.initialize(new InputStreamInputSplit(inputStream));
    long startNanos = System.nanoTime();
    List<Writable> next = reader.next();
    long endNanos = System.nanoTime();
    long etlDiff = endNanos - startNanos;

    long startArrCreation = System.nanoTime();
    INDArray arr = function.createFromRecord(next);
    long endArrCreation = System.nanoTime();
    long endCreationDiff = endArrCreation - startArrCreation;

    Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
    val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
    val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.DEVICE_TO_HOST);
    return TimingStatistics.builder()
            .diskReadingTimeNanos(etlDiff)
            .bandwidthNanosHostToDevice(bw)
            .bandwidthDeviceToHost(deviceToHost)
            .ndarrayCreationTimeNanos(endCreationDiff)
            .build();
}
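For context, a hedged sketch of how this helper might be invoked. It assumes INDArrayCreationFunction is a functional interface whose single method matches the createFromRecord(List<Writable>) call above, that org.datavec.api.util.ndarray.RecordConverter is available to build the array, that the helper is in scope as a static method, and that "iris.txt" is an illustrative file name:

// Hedged usage sketch, not from the original source.
try (InputStream is = new FileInputStream("iris.txt")) {   // illustrative file name
    RecordReader reader = new CSVRecordReader(0, ',');
    // RecordConverter::toArray builds an INDArray from the List<Writable> record,
    // assuming INDArrayCreationFunction accepts a matching method reference.
    TimingStatistics stats = timeNDArrayCreation(reader, is, RecordConverter::toArray);
    System.out.println(stats);
}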
Example 5
protected Iterator<String> getIterator(int location) {
    Iterator<String> iterator = null;
    if (inputSplit instanceof StringSplit) {
        StringSplit stringSplit = (StringSplit) inputSplit;
        iterator = Collections.singletonList(stringSplit.getData()).listIterator();
    } else if (inputSplit instanceof InputStreamInputSplit) {
        InputStream is = ((InputStreamInputSplit) inputSplit).getIs();
        if (is != null) {
            iterator = IOUtils.lineIterator(new InputStreamReader(is));
        }
    } else {
        this.locations = inputSplit.locations();
        if (locations != null && locations.length > 0) {
            InputStream inputStream;
            try {
                inputStream = locations[location].toURL().openStream();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
            iterator = IOUtils.lineIterator(new InputStreamReader(inputStream));
        }
    }
    if (iterator == null)
        throw new UnsupportedOperationException("Unknown input split: " + inputSplit);
    return iterator;
}
Example 6
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("iris.dat").getInputStream()));

    int count = 0;
    while (rr.hasNext()) {
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try {
        rr.reset();
        fail("Expected exception");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Example 7
@Test
public void testStreamReset() throws Exception {
    CSVRecordReader rr = new CSVRecordReader(0, ',');
    rr.initialize(new InputStreamInputSplit(new ClassPathResource("datavec-api/iris.dat").getInputStream()));

    int count = 0;
    while (rr.hasNext()) {
        assertNotNull(rr.next());
        count++;
    }
    assertEquals(150, count);

    assertFalse(rr.resetSupported());

    try {
        rr.reset();
        fail("Expected exception");
    } catch (Exception e) {
        String msg = e.getMessage();
        String msg2 = e.getCause().getMessage();
        assertTrue(msg, msg.contains("Error during LineRecordReader reset"));
        assertTrue(msg2, msg2.contains("Reset not supported from streams"));
    }
}
Example 8
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
    File tmpdir = testDir.newFolder();

    File tmp1 = new File(tmpdir, "tmp1.txt.gz");
    OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
    IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
    os.flush();
    os.close();

    InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
    RecordReader reader = new LineRecordReader();
    reader.initialize(split);

    int count = 0;
    while (reader.hasNext()) {
        assertEquals(1, reader.next().size());
        count++;
    }
    assertEquals(9, count);
}
Example 9
/**
 * Convert a traditional sc.binaryFiles RDD
 * into something usable for machine learning.
 * @param binaryFiles the binary files to convert
 * @param reader the reader to use
 * @return the labeled points based on the given RDD
 */
public static JavaRDD<LabeledPoint> fromBinary(JavaPairRDD<String, PortableDataStream> binaryFiles,
                                               final RecordReader reader) {
    JavaRDD<Collection<Writable>> records =
            binaryFiles.map(new Function<Tuple2<String, PortableDataStream>, Collection<Writable>>() {
                @Override
                public Collection<Writable> call(Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2)
                        throws Exception {
                    reader.initialize(new InputStreamInputSplit(stringPortableDataStreamTuple2._2().open(),
                            stringPortableDataStreamTuple2._1()));
                    return reader.next();
                }
            });

    JavaRDD<LabeledPoint> ret = records.map(new Function<Collection<Writable>, LabeledPoint>() {
        @Override
        public LabeledPoint call(Collection<Writable> writables) throws Exception {
            return pointOf(writables);
        }
    });
    return ret;
}
Example 10
@Override
public ArrowWritableRecordBatch convert(Buffer input, ConverterArgs parameters, Map<String, Object> contextData) {
    ArrowRecordReader arrowRecordReader = new ArrowRecordReader();
    arrowRecordReader.initialize(new InputStreamInputSplit(new ByteArrayInputStream(input.getBytes())));
    arrowRecordReader.next();
    return arrowRecordReader.getCurrentBatch();
}
Example 11
@Override
public boolean hasNext() {
    if (inputSplit instanceof InputStreamInputSplit) {
        // A stream-backed split yields a single record, so there is a next
        // record only until the stream has been consumed.
        return !finishedInputStreamSplit;
    }
    if (iter != null) {
        return iter.hasNext();
    } else if (record != null) {
        return !hitImage;
    }
    throw new IllegalStateException("Indeterminant state: record must not be null, or a file iterator must exist");
}
Example 12
@Test
public void testLineReaderWithInputStreamInputSplit() throws Exception {
    String tempDir = System.getProperty("java.io.tmpdir");
    File tmpdir = new File(tempDir, "tmpdir");
    tmpdir.mkdir();

    File tmp1 = new File(tmpdir, "tmp1.txt.gz");
    OutputStream os = new GZIPOutputStream(new FileOutputStream(tmp1, false));
    IOUtils.writeLines(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8", "9"), null, os);
    os.flush();
    os.close();

    InputSplit split = new InputStreamInputSplit(new GZIPInputStream(new FileInputStream(tmp1)));
    RecordReader reader = new LineRecordReader();
    reader.initialize(split);

    int count = 0;
    while (reader.hasNext()) {
        assertEquals(1, reader.next().size());
        count++;
    }
    assertEquals(9, count);

    try {
        FileUtils.deleteDirectory(tmpdir);
    } catch (Exception e) {
        e.printStackTrace();
    }
}
Example 13
@Override
public void initialize(InputSplit split) throws IOException, InterruptedException {
    super.initialize(split);
    if (!(inputSplit instanceof StringSplit || inputSplit instanceof InputStreamInputSplit)) {
        final ArrayList<URI> uris = new ArrayList<>();
        final Iterator<URI> uriIterator = inputSplit.locationsIterator();
        while (uriIterator.hasNext())
            uris.add(uriIterator.next());
        this.locations = uris.toArray(new URI[0]);
    }
    this.iter = getIterator(0);
    this.initialized = true;
}
Example 14
@Test
public void testReadingFromStream() throws Exception {
    for (boolean b : new boolean[]{false, true}) {
        int batchSize = 1;
        int labelIndex = 4;
        int numClasses = 3;
        InputStream dataFile = Resources.asStream("iris.txt");
        RecordReader recordReader = new CSVRecordReader(0, ',');
        recordReader.initialize(new InputStreamInputSplit(dataFile));

        assertTrue(recordReader.hasNext());
        assertFalse(recordReader.resetSupported());

        DataSetIterator iterator;
        if (b) {
            iterator = new RecordReaderDataSetIterator.Builder(recordReader, batchSize)
                    .classification(labelIndex, numClasses)
                    .build();
        } else {
            iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        }
        assertFalse(iterator.resetSupported());

        int count = 0;
        while (iterator.hasNext()) {
            assertNotNull(iterator.next());
            count++;
        }
        assertEquals(150, count);

        try {
            iterator.reset();
            fail("Expected exception");
        } catch (Exception e) {
            //expected
        }
    }
}