Java源码示例:org.nd4j.linalg.api.iter.NdIndexIterator
示例1
/**
 * Checks that LAPACK gesvd factorizes a random M x N matrix A so that U * diag(S) * VT reproduces A.
 *
 * @param M           number of rows of A
 * @param N           number of columns of A
 * @param matrixOrder memory order ('c' or 'f') used for A, U and VT
 */
void testSvd(int M, int N, char matrixOrder) {
    INDArray A = Nd4j.rand(M, N, matrixOrder);
    // Keep an untouched copy for the comparison (LAPACK gesvd may overwrite its input matrix)
    INDArray Aorig = A.dup();
    INDArray U = Nd4j.create(M, M, matrixOrder);
    INDArray S = Nd4j.create(N, matrixOrder);
    INDArray VT = Nd4j.create(N, N, matrixOrder);
    Nd4j.getBlasWrapper().lapack().gesvd(A, S, U, VT);
    // Expand the singular values into an M x N matrix with S on its main diagonal
    INDArray SS = Nd4j.create(M, N);
    for (int i = 0; i < Math.min(M, N); i++) {
        SS.put(i, i, S.getDouble(i));
    }
    INDArray AA = U.mmul(SS).mmul(VT);
    NdIndexIterator iter = new NdIndexIterator(AA.shape());
    while (iter.hasNext()) {
        // NdIndexIterator yields long[] positions
        long[] pos = iter.next();
        // JUnit 4 convention: expected (original matrix) first, actual (reconstruction) second
        assertEquals("SVD did not factorize properly", Aorig.getDouble(pos), AA.getDouble(pos), 1e-5);
    }
}
示例2
/**
 * Fills {@code ret} with normally distributed values.
 * <p>
 * If the RNG exposes a native state pointer, sampling is delegated to the {@code GaussianDistribution} op,
 * using the per-element {@code means} array when it is set and the scalar {@code mean} otherwise. Without a
 * state pointer, values are drawn one at a time as {@code standardDeviation * nextGaussian() + mean}. Positions
 * are visited in the index iterator's logical order, so the result does not depend on c vs. f ordering.
 *
 * @param ret array to fill
 * @return the array returned by the executioner, or {@code ret} itself on the element-by-element path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        if (means != null) {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, means, standardDeviation), random);
        } else {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, mean, standardDeviation), random);
        }
    } else {
        Iterator<long[]> idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // long counters: an int counter compared against a long length would overflow and never
        // terminate for arrays with more than Integer.MAX_VALUE elements
        if (means != null) {
            for (long i = 0; i < len; i++) {
                long[] idx = idxIter.next();
                ret.putScalar(idx, standardDeviation * random.nextGaussian() + means.getDouble(idx));
            }
        } else {
            for (long i = 0; i < len; i++) {
                ret.putScalar(idxIter.next(), standardDeviation * random.nextGaussian() + mean);
            }
        }
        return ret;
    }
}
示例3
/**
 * Compares {@code actual} against {@code expected} element by element, using an absolute-error floor and a
 * relative-error limit.
 * <p>
 * A pair of values matches when:
 * <ul>
 *   <li>the values are identical or both are NaN; or</li>
 *   <li>their absolute difference is below {@code minAbsoluteError}; or</li>
 *   <li>their relative error |d1-d2| / (|d1|+|d2|) does not exceed {@code maxRelativeError}.</li>
 * </ul>
 * Any other pair involving NaN or Infinity is a mismatch.
 *
 * @param actual array to check; must have the same shape as {@code expected}
 * @return {@code null} if all elements match, otherwise a description of the first mismatch
 * @throws IllegalStateException if the shapes differ
 */
@Override
public String apply(INDArray actual) {
    //TODO switch to binary relative error ops
    if (!Arrays.equals(expected.shape(), actual.shape())) {
        throw new IllegalStateException("Shapes differ! " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
    }
    NdIndexIterator iter = new NdIndexIterator(expected.shape());
    while (iter.hasNext()) {
        long[] next = iter.next();
        double d1 = expected.getDouble(next);
        double d2 = actual.getDouble(next);
        // Identical values (including 0.0 vs -0.0 and equal infinities) and NaN/NaN pairs match
        if (d1 == d2 || (Double.isNaN(d1) && Double.isNaN(d2))) {
            continue;
        }
        // Any remaining non-finite value is a mismatch. Without this check, NaN/Infinity produce a NaN relative
        // error, which never compares greater than maxRelativeError and would therefore silently pass.
        if (!Double.isFinite(d1) || !Double.isFinite(d2)) {
            return "Failed on non-finite value at position " + Arrays.toString(next) + " - values (" + d1 + "," + d2 + ")";
        }
        double absError = Math.abs(d1 - d2);
        if (absError < minAbsoluteError) {
            continue;
        }
        double re = absError / (Math.abs(d1) + Math.abs(d2));
        if (re > maxRelativeError) {
            return "Failed on relative error at position " + Arrays.toString(next) + ": relativeError=" + re + ", maxRE=" + maxRelativeError + ", absError=" +
                    absError + ", minAbsError=" + minAbsoluteError + " - values (" + d1 + "," + d2 + ")";
        }
    }
    return null;
}
示例4
/**
 * Fills {@code ret} with normally distributed values.
 * <p>
 * If the RNG exposes a native state pointer, sampling is delegated to the {@code GaussianDistribution} op,
 * using the per-element {@code means} array when it is set and the scalar {@code mean} otherwise. Without a
 * state pointer, values are drawn one at a time as {@code standardDeviation * nextGaussian() + mean}. Positions
 * are visited in the index iterator's logical order, so the result does not depend on c vs. f ordering.
 *
 * @param ret array to fill
 * @return the array returned by the executioner, or {@code ret} itself on the element-by-element path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        if (means != null) {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, means, standardDeviation), random);
        } else {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, mean, standardDeviation), random);
        }
    } else {
        Iterator<long[]> idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // long counters: an int counter compared against a long length would overflow and never
        // terminate for arrays with more than Integer.MAX_VALUE elements
        if (means != null) {
            for (long i = 0; i < len; i++) {
                long[] idx = idxIter.next();
                ret.putScalar(idx, standardDeviation * random.nextGaussian() + means.getDouble(idx));
            }
        } else {
            for (long i = 0; i < len; i++) {
                ret.putScalar(idxIter.next(), standardDeviation * random.nextGaussian() + mean);
            }
        }
        return ret;
    }
}
示例5
/**
 * Demonstrates reading, writing and iterating over the elements of a 2 x 6 INDArray.
 *
 * @param args unused
 */
public static void main(String[] args) {
    INDArray nd = Nd4j.create(new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, new int[]{2, 6});
    System.out.println("打印原有数组");
    System.out.println(nd);
    // Read the value at index [0, 3]
    System.out.println("获取数组下标为0, 3的值");
    double value = nd.getDouble(0, 3);
    System.out.println(value);
    // Overwrite the value at index [0, 3]; "scalar" here means a single element
    System.out.println("修改数组下标为0, 3的值");
    nd.putScalar(0, 3, 100);
    System.out.println(nd);
    // Visit every element with an index iterator in c order (last dimension varies fastest).
    // The iterator shape is taken from the array itself so it cannot drift from the data's shape.
    System.out.println("使用索引迭代器遍历ndarray");
    NdIndexIterator iter = new NdIndexIterator(nd.shape());
    while (iter.hasNext()) {
        long[] nextIndex = iter.next();
        double nextVal = nd.getDouble(nextIndex);
        System.out.println(nextVal);
    }
}
示例6
/**
 * Reports whether every element of {@code array} compares equal to zero.
 * Elements are read one at a time with {@code getDouble}. A NaN element yields {@code false}, and an array
 * with no elements yields {@code true}.
 *
 * @param array array to inspect
 * @return {@code true} if no element differs from 0
 */
private static boolean allZeros(INDArray array) {
    for (NdIndexIterator positions = new NdIndexIterator(array.shape()); positions.hasNext(); ) {
        if (array.getDouble(positions.next()) != 0) {
            return false;
        }
    }
    return true;
}
示例7
/**
 * Checks LAPACK syev on a random symmetric N x N matrix.
 * <p>
 * After {@code syev('V', 'U', A, V)}, the test treats A as holding eigenvectors (as columns) and V as holding
 * the matching eigenvalues. It requires {@code Aorig * A == A * diag(V)} element-wise.
 *
 * @param N           matrix size
 * @param matrixOrder memory order ('c' or 'f') for the random input
 */
void testEv(int N, char matrixOrder) {
    INDArray A = Nd4j.rand(N, N, matrixOrder);
    // Make A symmetric by copying the lower triangle onto the upper triangle
    for (int r = 1; r < N; r++) {
        for (int c = 0; c < r; c++) {
            double v = A.getDouble(r, c);
            A.putScalar(c, r, v);
        }
    }
    INDArray Aorig = A.dup();
    INDArray V = Nd4j.create(N);
    // LAPACK syev: 'V' = compute eigenvectors as well as eigenvalues, 'U' = use the upper triangle of A
    Nd4j.getBlasWrapper().lapack().syev('V', 'U', A, V);
    INDArray VV = Nd4j.create(N, N);
    for (int i = 0; i < N; i++) {
        VV.put(i, i, V.getDouble(i));
    }
    INDArray L = Aorig.mmul(A);
    INDArray R = A.mmul(VV);
    NdIndexIterator iter = new NdIndexIterator(L.shape());
    while (iter.hasNext()) {
        // NdIndexIterator yields long[] positions
        long[] pos = iter.next();
        assertEquals("Eigen decomposition did not factorize properly", L.getDouble(pos), R.getDouble(pos), 1e-5);
    }
}
示例8
/**
 * Checks that NdIndexIterator over a 2 x 2 shape yields [0,0], [0,1], [1,0], [1,1] in that order (c order).
 */
@Test
public void testIterate() {
    val shapeIter = new NdIndexIterator(2, 2);
    val expectedOrder = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};
    for (long[] expectedIdx : expectedOrder) {
        assertArrayEquals(expectedIdx, shapeIter.next());
    }
}
示例9
/**
 * Fills {@code target} one element at a time with values from {@link #sample()}.
 * Positions are visited in the index iterator's logical order, so the result does not depend on c vs. f ordering.
 *
 * @param target array to fill in place
 * @return {@code target}
 */
@Override
public INDArray sample(INDArray target) {
    final Iterator<long[]> positions = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering
    for (long remaining = target.length(); remaining > 0; remaining--) {
        target.putScalar(positions.next(), sample());
    }
    return target;
}
示例10
/**
 * Fills {@code ret} with uniformly distributed values in [lower, upper].
 * <p>
 * If the RNG exposes a native state pointer, sampling is delegated to the {@code UniformDistribution} op.
 * Otherwise each element is filled from {@link #sample()}, visiting positions in the index iterator's logical
 * order so the result does not depend on c vs. f ordering.
 *
 * @param ret array to fill
 * @return the array returned by the executioner, or {@code ret} itself on the element-by-element path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        return Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(
                ret, lower, upper), random);
    } else {
        val idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // long counter: an int counter compared against a long length would overflow and never
        // terminate for arrays with more than Integer.MAX_VALUE elements
        for (long i = 0; i < len; i++) {
            ret.putScalar(idxIter.next(), sample());
        }
        return ret;
    }
}
示例11
/**
 * Checks that EvaluationCalibration on 3d arrays gives the same result (equality and stats) as evaluating the
 * same data flattened to 2d. Each 2d row is the dimension-1 vector at one (dim 0, dim 2) position.
 */
@Test
public void testEvaluationCalibration3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0 and 2; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    EvaluationCalibration e3d = new EvaluationCalibration();
    EvaluationCalibration e2d = new EvaluationCalibration();
    e3d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    System.out.println(e2d.stats());
    assertEquals(e2d, e3d);
    assertEquals(e2d.stats(), e3d.stats());
}
示例12
/**
 * Checks that EvaluationCalibration on 3d arrays with a 2d mask equals evaluation on a 2d array built only from
 * the unmasked positions. The mask is indexed by (dim 0, dim 2) of the data, and a non-zero entry keeps a row.
 */
@Test
public void testEvaluationCalibration3dMasking() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    //Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape
    INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10);
    // Walk every mask position; bounds come from the mask itself so they always match what getDouble indexes
    NdIndexIterator iter = new NdIndexIterator(mask2d.size(0), mask2d.size(1));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        if (mask2d.getDouble(idx[0], idx[1]) != 0.0) {
            INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
            rowsP.add(prediction.get(idxs));
            rowsL.add(label.get(idxs));
        }
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    EvaluationCalibration e3d_m2d = new EvaluationCalibration();
    EvaluationCalibration e2d_m2d = new EvaluationCalibration();
    e3d_m2d.eval(label, prediction, mask2d);
    e2d_m2d.eval(l2d, p2d);
    assertEquals(e3d_m2d, e2d_m2d);
}
示例13
/**
 * Checks that RegressionEvaluation on 3d arrays gives the same score for every Metric as evaluating the same
 * data flattened to 2d. Each 2d row is the dimension-1 vector at one (dim 0, dim 2) position.
 */
@Test
public void testRegressionEval3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0 and 2; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    RegressionEvaluation e3d = new RegressionEvaluation();
    RegressionEvaluation e2d = new RegressionEvaluation();
    e3d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    for (Metric m : Metric.values()) {
        double d1 = e3d.scoreForMetric(m);
        double d2 = e2d.scoreForMetric(m);
        assertEquals(m.toString(), d2, d1, 1e-6);
    }
}
示例14
/**
 * Checks that RegressionEvaluation on 4d arrays gives the same score for every Metric as evaluating the same
 * data flattened to 2d. Each 2d row is the dimension-1 vector at one (dim 0, dim 2, dim 3) position.
 */
@Test
public void testRegressionEval4d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0, 2 and 3; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2), prediction.size(3));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    RegressionEvaluation e4d = new RegressionEvaluation();
    RegressionEvaluation e2d = new RegressionEvaluation();
    e4d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    for (Metric m : Metric.values()) {
        double d1 = e4d.scoreForMetric(m);
        double d2 = e2d.scoreForMetric(m);
        assertEquals(m.toString(), d2, d1, 1e-5);
    }
}
示例15
/**
 * Checks that EvaluationBinary on 3d arrays gives the same per-output score for every Metric as evaluating the
 * same data flattened to 2d. Dimension 1 is the output index; each 2d row is one (dim 0, dim 2) position.
 */
@Test
public void testEvaluationBinary3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0 and 2; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    EvaluationBinary e3d = new EvaluationBinary();
    EvaluationBinary e2d = new EvaluationBinary();
    e3d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    int numOutputs = (int) prediction.size(1);
    for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) {
        for (int i = 0; i < numOutputs; i++) {
            double d1 = e3d.scoreForMetric(m, i);
            double d2 = e2d.scoreForMetric(m, i);
            assertEquals(m.toString(), d2, d1, 1e-6);
        }
    }
}
示例16
/**
 * Checks that EvaluationBinary on 4d arrays gives the same per-output score for every Metric as evaluating the
 * same data flattened to 2d. Dimension 1 is the output index; each 2d row is one (dim 0, dim 2, dim 3) position.
 */
@Test
public void testEvaluationBinary4d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0, 2 and 3; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2), prediction.size(3));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    EvaluationBinary e4d = new EvaluationBinary();
    EvaluationBinary e2d = new EvaluationBinary();
    e4d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    int numOutputs = (int) prediction.size(1);
    for (EvaluationBinary.Metric m : EvaluationBinary.Metric.values()) {
        for (int i = 0; i < numOutputs; i++) {
            double d1 = e4d.scoreForMetric(m, i);
            double d2 = e2d.scoreForMetric(m, i);
            assertEquals(m.toString(), d2, d1, 1e-6);
        }
    }
}
示例17
/**
 * Checks that ROCBinary on 3d arrays gives the same per-output score for every Metric as evaluating the same
 * data flattened to 2d. Dimension 1 is the output index; each 2d row is one (dim 0, dim 2) position.
 */
@Test
public void testROCBinary3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0 and 2; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    ROCBinary e3d = new ROCBinary();
    ROCBinary e2d = new ROCBinary();
    e3d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    int numOutputs = (int) prediction.size(1);
    for (ROCBinary.Metric m : ROCBinary.Metric.values()) {
        for (int i = 0; i < numOutputs; i++) {
            double d1 = e3d.scoreForMetric(m, i);
            double d2 = e2d.scoreForMetric(m, i);
            assertEquals(m.toString(), d2, d1, 1e-6);
        }
    }
}
示例18
/**
 * Checks that ROCBinary on 4d arrays gives the same per-output score for every Metric as evaluating the same
 * data flattened to 2d. Dimension 1 is the output index; each 2d row is one (dim 0, dim 2, dim 3) position.
 */
@Test
public void testROCBinary4d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    List<INDArray> rowsP = new ArrayList<>();
    List<INDArray> rowsL = new ArrayList<>();
    // Iterate over dimensions 0, 2 and 3; bounds come from the array so they stay in sync with the shape above
    NdIndexIterator iter = new NdIndexIterator(prediction.size(0), prediction.size(2), prediction.size(3));
    while (iter.hasNext()) {
        long[] idx = iter.next();
        INDArrayIndex[] idxs = new INDArrayIndex[]{NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[1]), NDArrayIndex.point(idx[2])};
        rowsP.add(prediction.get(idxs));
        rowsL.add(label.get(idxs));
    }
    INDArray p2d = Nd4j.vstack(rowsP);
    INDArray l2d = Nd4j.vstack(rowsL);
    ROCBinary e4d = new ROCBinary();
    ROCBinary e2d = new ROCBinary();
    e4d.eval(label, prediction);
    e2d.eval(l2d, p2d);
    int numOutputs = (int) prediction.size(1);
    for (ROCBinary.Metric m : ROCBinary.Metric.values()) {
        for (int i = 0; i < numOutputs; i++) {
            double d1 = e4d.scoreForMetric(m, i);
            double d2 = e2d.scoreForMetric(m, i);
            assertEquals(m.toString(), d2, d1, 1e-6);
        }
    }
}
示例19
/**
 * Checks that NdIndexIterator over a 2 x 2 shape yields [0,0], [0,1], [1,0], [1,1] in that order (c order).
 */
@Test
public void testIterate() {
    val shapeIter = new NdIndexIterator(2, 2);
    val expectedOrder = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};
    for (long[] expectedIdx : expectedOrder) {
        assertArrayEquals(expectedIdx, shapeIter.next());
    }
}
示例20
/**
 * Fills {@code target} one element at a time with values from {@link #sample()}.
 * Positions are visited in the index iterator's logical order, so the result does not depend on c vs. f ordering.
 *
 * @param target array to fill in place
 * @return {@code target}
 */
@Override
public INDArray sample(INDArray target) {
    final Iterator<long[]> positions = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering
    for (long remaining = target.length(); remaining > 0; remaining--) {
        target.putScalar(positions.next(), sample());
    }
    return target;
}
示例21
/**
 * Fills {@code ret} with uniformly distributed values in [lower, upper].
 * <p>
 * If the RNG exposes a native state pointer, sampling is delegated to the {@code UniformDistribution} op.
 * Otherwise each element is filled from {@link #sample()}, visiting positions in the index iterator's logical
 * order so the result does not depend on c vs. f ordering.
 *
 * @param ret array to fill
 * @return the array returned by the executioner, or {@code ret} itself on the element-by-element path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        return Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(
                ret, lower, upper), random);
    } else {
        val idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // long counter: an int counter compared against a long length would overflow and never
        // terminate for arrays with more than Integer.MAX_VALUE elements
        for (long i = 0; i < len; i++) {
            ret.putScalar(idxIter.next(), sample());
        }
        return ret;
    }
}
示例22
/**
 * Writes {@code element} into this array at the positions described by {@code indices}.
 * <p>
 * The interpretation depends on the shape of {@code indices}:
 * <ul>
 *   <li>If {@code indices} has one row per dimension of this array, each column is a full coordinate. The
 *       coordinates receive successive values of {@code element}, read in the index iterator's order.</li>
 *   <li>If it is a matrix or column vector, every entry is a slice index. {@code element} is assigned to each
 *       selected slice.</li>
 *   <li>If it is a row vector, every entry is likewise a slice index receiving {@code element}.</li>
 * </ul>
 *
 * @param indices coordinate matrix or slice indices (rank at most 2)
 * @param element values to write
 * @return this array
 * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    if (indices.rank() > 2) {
        throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
    }
    if (indices.rows() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        for (int i = 0; i < indices.columns(); i++) {
            int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
            putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next()));
        }
    } else if (indices.isMatrix() || indices.isColumnVector()) {
        for (int i = 0; i < indices.rows(); i++) {
            INDArray row = indices.getRow(i);
            for (int j = 0; j < row.length(); j++) {
                INDArray slice = slice(row.getInt(j));
                Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
            }
        }
    } else if (indices.isRowVector()) {
        // A vector of slice indices: assign element to each selected slice, as the matrix branch does
        // (this branch previously only collected the slices into an unused list and wrote nothing)
        for (int i = 0; i < indices.length(); i++) {
            INDArray slice = slice(indices.getInt(i));
            Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
        }
    }
    return this;
}
示例23
/**
 * Copies the elements of {@code arr} into a new float array in c (row-major) order.
 *
 * @param arr source array
 * @return a new array with {@code arr.length()} elements
 * @throws ND4JArraySizeException if {@code arr} has more than Integer.MAX_VALUE elements
 */
public static float[] asFloat(INDArray arr) {
    final long length = arr.length();
    if (length > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }
    final float[] result = new float[(int) length];
    final NdIndexIterator cOrder = new NdIndexIterator('c', arr.shape());
    int pos = 0;
    while (pos < result.length) {
        result[pos] = arr.getFloat(cOrder.next());
        pos++;
    }
    return result;
}
示例24
/**
 * Orders NDArrayWritables by increasingly expensive criteria:
 * <ol>
 *   <li>a null array sorts first;</li>
 *   <li>then by rank;</li>
 *   <li>then by length;</li>
 *   <li>then dimension by dimension on shape;</li>
 *   <li>finally element-wise on contents, in c order, using {@link Double#compare}.</li>
 * </ol>
 * Contents are only compared when everything else is equal.
 *
 * @param o another NDArrayWritable (the object is cast, so other types cause a ClassCastException)
 * @return negative, zero or positive as this is less than, equal to, or greater than {@code o}
 */
@Override
public int compareTo(@NotNull Object o) {
    NDArrayWritable other = (NDArrayWritable) o;
    //Conventions used here for ordering NDArrays: x.compareTo(y): -ve if x < y, 0 if x == y, +ve if x > y
    if (this.array == null) {
        return other.array == null ? 0 : -1;
    }
    if (other.array == null) {
        return 1;
    }
    int cmp = Integer.compare(this.array.rank(), other.array.rank());
    if (cmp != 0) {
        return cmp;
    }
    cmp = Long.compare(this.array.length(), other.array.length());
    if (cmp != 0) {
        return cmp;
    }
    final int rank = this.array.rank();
    for (int dim = 0; dim < rank; dim++) {
        cmp = Long.compare(this.array.size(dim), other.array.size(dim));
        if (cmp != 0) {
            return cmp;
        }
    }
    //Same rank, length and shape: fall back to comparing contents
    for (NdIndexIterator it = new NdIndexIterator('c', this.array.shape()); it.hasNext(); ) {
        long[] pos = it.next();
        cmp = Double.compare(this.array.getDouble(pos), other.array.getDouble(pos));
        if (cmp != 0) {
            return cmp;
        }
    }
    //Same rank, length, shape and contents: must be equal
    return 0;
}
示例25
/**
 * Checks that variance over dimensions 1, 2 and 3 of a [100, 5, 10, 10] array equals variance over dimension 1
 * of the same data viewed as [100, 500]. This is checked for population variance and for bias-corrected
 * variance, and population variance is also compared against a manual per-row calculation.
 * <p>
 * Runs with the global data type forced to DOUBLE. The previous type is restored in a finally block, so a
 * failing assertion cannot leak the setting into later tests.
 */
@Test
public void testVarianceSingleVsMultipleDimensions() {
    // this test should always run in double
    DataBuffer.Type type = Nd4j.dataType();
    DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
    try {
        Nd4j.getRandom().setSeed(12345);
        //Generate C order random numbers. Strides: [500,100,10,1]
        INDArray fourd = Nd4j.rand('c', new int[] {100, 5, 10, 10}).muli(10);
        INDArray twod = Shape.newShapeNoCopy(fourd, new int[] {100, 5 * 10 * 10}, false);
        //Population variance. These two should be identical
        INDArray var4 = fourd.var(false, 1, 2, 3);
        INDArray var2 = twod.var(false, 1);
        //Manual calculation of population variance, not bias corrected
        //https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Na.C3.AFve_algorithm
        // Accumulators are sized from dimension 0 so they cannot disagree with the array shape
        int numRows = (int) fourd.size(0);
        double[] sums = new double[numRows];
        double[] sumSquares = new double[numRows];
        NdIndexIterator iter = new NdIndexIterator(fourd.shape());
        while (iter.hasNext()) {
            val next = iter.next();
            double d = fourd.getDouble(next);
            // next[0] < numRows, which itself fits in an int, so this cast is safe
            int row = (int) next[0];
            sums[row] += d;
            sumSquares[row] += d * d;
        }
        double[] manualVariance = new double[numRows];
        long n = fourd.length() / numRows; // elements per row
        for (int i = 0; i < numRows; i++) {
            manualVariance[i] = (sumSquares[i] - (sums[i] * sums[i]) / n) / n;
        }
        INDArray var4bias = fourd.var(true, 1, 2, 3);
        INDArray var2bias = twod.var(true, 1);
        assertArrayEquals(var2.data().asDouble(), var4.data().asDouble(), 1e-5);
        assertArrayEquals(manualVariance, var2.data().asDouble(), 1e-5);
        assertArrayEquals(var2bias.data().asDouble(), var4bias.data().asDouble(), 1e-5);
    } finally {
        DataTypeUtil.setDTypeForContext(type);
    }
}
示例26
/**
 * Checks that the static Shape helpers (rank, shape, size, stride, offset, getOffset) agree between the
 * {@code shapeInfo()} representation and the {@code shapeInfoDataBuffer()} representation. Covers ranks 2 to 6
 * and every array variant produced by NDArrayCreationUtil.
 */
@Test
public void testBufferToIntShapeStrideMethods() {
    //Specifically: Shape.shape(IntBuffer), Shape.shape(DataBuffer)
    //.isRowVectorShape(DataBuffer), .isRowVectorShape(IntBuffer)
    //Shape.size(DataBuffer,int), Shape.size(IntBuffer,int)
    //Also: Shape.stride(IntBuffer), Shape.stride(DataBuffer)
    //Shape.stride(DataBuffer,int), Shape.stride(IntBuffer,int)
    List<List<Pair<INDArray, String>>> lists = new ArrayList<>();
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345));
    lists.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345));
    lists.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 3, 4, 5));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 4, 5, 6));
    lists.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 1, 5, 1));
    lists.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 3, 4, 5, 6, 7));
    lists.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 3, 4, 5, 6, 7, 8));
    val shapes = new long[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {3, 4, 5, 6}, {3, 1, 5, 1}, {3, 4, 5, 6, 7},
            {3, 4, 5, 6, 7, 8}};
    // lists and shapes are parallel: lists.get(i) must hold arrays of shape shapes[i]
    assertEquals(shapes.length, lists.size());
    for (int i = 0; i < shapes.length; i++) {
        List<Pair<INDArray, String>> list = lists.get(i);
        val shape = shapes[i];
        for (Pair<INDArray, String> p : list) {
            INDArray arr = p.getFirst();
            assertArrayEquals(shape, arr.shape());
            val thisStride = arr.stride();
            val ib = arr.shapeInfo();
            DataBuffer db = arr.shapeInfoDataBuffer();
            //Check shape calculation
            assertEquals(shape.length, Shape.rank(ib));
            assertEquals(shape.length, Shape.rank(db));
            assertArrayEquals(shape, Shape.shape(ib));
            assertArrayEquals(shape, Shape.shape(db));
            for (int j = 0; j < shape.length; j++) {
                assertEquals(shape[j], Shape.size(ib, j));
                assertEquals(shape[j], Shape.size(db, j));
                assertEquals(thisStride[j], Shape.stride(ib, j));
                assertEquals(thisStride[j], Shape.stride(db, j));
            }
            //Check base offset
            assertEquals(Shape.offset(ib), Shape.offset(db));
            //Check offset calculation:
            NdIndexIterator iter = new NdIndexIterator(shape);
            while (iter.hasNext()) {
                val next = iter.next();
                long offset1 = Shape.getOffset(ib, next);
                assertEquals(offset1, Shape.getOffset(db, next));
                switch (shape.length) {
                    case 2:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1]));
                        break;
                    case 3:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2]));
                        break;
                    case 4:
                        assertEquals(offset1, Shape.getOffset(ib, next[0], next[1], next[2], next[3]));
                        assertEquals(offset1, Shape.getOffset(db, next[0], next[1], next[2], next[3]));
                        break;
                    case 5:
                    case 6:
                        //No 5 and 6d getOffset overloads
                        break;
                    default:
                        throw new RuntimeException("Unexpected rank: " + shape.length);
                }
            }
        }
    }
}
示例27
/**
 * Checks that putScalar(index array, value) and the rank-specific putScalar(i, j, ..., value) overloads write
 * identical results into c- and f-ordered arrays. Covers shapes of rank 2 to 4, including size-1 dimensions.
 * Each position receives its visit number as its value.
 */
@Test
public void testPutScalar() {
    //Check that the various putScalar methods have the same result...
    val shapes = new int[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {1, 4, 5}, {3, 1, 5}, {3, 4, 1}, {1, 1, 5},
            {3, 4, 5, 6}, {1, 4, 5, 6}, {3, 1, 5, 6}, {3, 4, 1, 6}, {3, 4, 5, 1}, {1, 1, 5, 6},
            {3, 1, 1, 6}, {3, 1, 1, 1}};
    for (int[] shape : shapes) {
        final int rank = shape.length;
        NdIndexIterator positions = new NdIndexIterator(shape);
        INDArray viaArrayC = Nd4j.create(shape, 'c');
        INDArray viaArrayF = Nd4j.create(shape, 'f');
        INDArray viaArgsC = Nd4j.create(shape, 'c');
        INDArray viaArgsF = Nd4j.create(shape, 'f');
        for (int value = 0; positions.hasNext(); value++) {
            val idx = positions.next();
            viaArrayC.putScalar(idx, value);
            viaArrayF.putScalar(idx, value);
            if (rank == 2) {
                viaArgsC.putScalar(idx[0], idx[1], value);
                viaArgsF.putScalar(idx[0], idx[1], value);
            } else if (rank == 3) {
                viaArgsC.putScalar(idx[0], idx[1], idx[2], value);
                viaArgsF.putScalar(idx[0], idx[1], idx[2], value);
            } else if (rank == 4) {
                viaArgsC.putScalar(idx[0], idx[1], idx[2], idx[3], value);
                viaArgsF.putScalar(idx[0], idx[1], idx[2], idx[3], value);
            } else {
                throw new RuntimeException();
            }
        }
        assertEquals(viaArrayC, viaArrayF);
        assertEquals(viaArrayC, viaArgsC);
        assertEquals(viaArrayC, viaArgsF);
    }
}
示例28
/**
 * Writes {@code element} into this array at the positions described by {@code indices}.
 * <p>
 * The interpretation depends on the shape of {@code indices}:
 * <ul>
 *   <li>If {@code indices} has one row per dimension of this array, each column is a full coordinate. The
 *       coordinates receive successive values of {@code element}, read in the index iterator's order.</li>
 *   <li>If it is a matrix or column vector, every entry is a slice index. {@code element} is assigned to each
 *       selected slice.</li>
 *   <li>If it is a row vector, every entry is likewise a slice index receiving {@code element}.</li>
 * </ul>
 *
 * @param indices coordinate matrix or slice indices (rank at most 2)
 * @param element values to write
 * @return this array
 * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    if (indices.rank() > 2) {
        throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
    }
    if (indices.rows() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        for (int i = 0; i < indices.columns(); i++) {
            int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
            putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next()));
        }
    } else if (indices.isMatrix() || indices.isColumnVector()) {
        for (int i = 0; i < indices.rows(); i++) {
            INDArray row = indices.getRow(i);
            for (int j = 0; j < row.length(); j++) {
                INDArray slice = slice(row.getInt(j));
                Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
            }
        }
    } else if (indices.isRowVector()) {
        // A vector of slice indices: assign element to each selected slice, as the matrix branch does
        // (this branch previously only collected the slices into an unused list and wrote nothing)
        for (int i = 0; i < indices.length(); i++) {
            INDArray slice = slice(indices.getInt(i));
            Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
        }
    }
    return this;
}
示例29
/**
 * Writes {@code element} into this array at the positions described by nested index lists.
 * <p>
 * The interpretation depends on the number of lists:
 * <ul>
 *   <li>If there is one list per dimension, the lists are combined through {@code SpecifiedIndex.iterate}, and
 *       each produced coordinate receives the next value of {@code element} in the index iterator's order.</li>
 *   <li>Otherwise, every integer in every list is treated as a slice index, and {@code element} is assigned to
 *       each selected slice.</li>
 * </ul>
 *
 * @param indices per-dimension index lists, or lists of slice indices
 * @param element values to write
 * @return this array
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    if (indices.size() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
        for (int i = 0; i < indArrayIndices.length; i++) {
            indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
        }
        boolean hasNext = true;
        Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
        while (hasNext) {
            try {
                List<List<Long>> next = iterate.next();
                for (int i = 0; i < next.size(); i++) {
                    int[] curr = Ints.toArray(next.get(i));
                    putScalar(curr, element.getDouble(ndIndexIterator.next()));
                }
            } catch (NoSuchElementException e) {
                // The loop ends on the first NoSuchElementException. NOTE(review): this also ends it silently if
                // ndIndexIterator runs out because element has fewer values than coordinates — confirm intended.
                hasNext = false;
            }
        }
    } else if (indices.size() >= 2) {
        for (List<Integer> row : indices) {
            for (int sliceIdx : row) {
                INDArray slice = slice(sliceIdx);
                Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
            }
        }
    } else if (indices.size() == 1) {
        // Every entry of the single list is a slice index receiving element. Previously this branch looped
        // indices.size() (== 1) times, read only the first entry, and assigned nothing.
        for (int sliceIdx : indices.get(0)) {
            INDArray slice = slice(sliceIdx);
            Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice}));
        }
    }
    return this;
}
示例30
/**
 * Orders NDArrayWritables by increasingly expensive criteria:
 * <ol>
 *   <li>a null array sorts first;</li>
 *   <li>then by rank;</li>
 *   <li>then by length;</li>
 *   <li>then dimension by dimension on shape;</li>
 *   <li>finally element-wise on contents, in c order, using {@link Double#compare}.</li>
 * </ol>
 * Contents are only compared when everything else is equal.
 *
 * @param o another NDArrayWritable (the object is cast, so other types cause a ClassCastException)
 * @return negative, zero or positive as this is less than, equal to, or greater than {@code o}
 */
@Override
public int compareTo(@NonNull Object o) {
    NDArrayWritable other = (NDArrayWritable) o;
    //Conventions used here for ordering NDArrays: x.compareTo(y): -ve if x < y, 0 if x == y, +ve if x > y
    if (this.array == null) {
        return other.array == null ? 0 : -1;
    }
    if (other.array == null) {
        return 1;
    }
    int cmp = Integer.compare(this.array.rank(), other.array.rank());
    if (cmp != 0) {
        return cmp;
    }
    cmp = Long.compare(this.array.length(), other.array.length());
    if (cmp != 0) {
        return cmp;
    }
    final int rank = this.array.rank();
    for (int dim = 0; dim < rank; dim++) {
        cmp = Long.compare(this.array.size(dim), other.array.size(dim));
        if (cmp != 0) {
            return cmp;
        }
    }
    //Same rank, length and shape: fall back to comparing contents
    for (NdIndexIterator it = new NdIndexIterator('c', this.array.shape()); it.hasNext(); ) {
        long[] pos = it.next();
        cmp = Double.compare(this.array.getDouble(pos), other.array.getDouble(pos));
        if (cmp != 0) {
            return cmp;
        }
    }
    //Same rank, length, shape and contents: must be equal
    return 0;
}