Java源码示例:org.nd4j.linalg.api.iter.NdIndexIterator

示例1
/**
 * Runs LAPACK {@code gesvd} on a random M x N matrix and checks, element by element,
 * that U * diag(S) * VT reproduces the matrix that was decomposed.
 *
 * @param M           number of rows of the random input matrix
 * @param N           number of columns of the random input matrix
 * @param matrixOrder ordering character handed to the array factory calls
 */
void testSvd(int M, int N, char matrixOrder) {
    INDArray input = Nd4j.rand(M, N, matrixOrder);
    // Compare against a copy taken before the LAPACK call (presumably gesvd overwrites its input)
    INDArray reference = input.dup();
    INDArray left = Nd4j.create(M, M, matrixOrder);
    INDArray singular = Nd4j.create(N, matrixOrder);
    INDArray rightT = Nd4j.create(N, N, matrixOrder);

    Nd4j.getBlasWrapper().lapack().gesvd(input, singular, left, rightT);

    // Spread the singular values over the diagonal of an M x N matrix
    INDArray sigma = Nd4j.create(M, N);
    int diagLength = Math.min(M, N);
    for (int d = 0; d < diagLength; d++) {
        sigma.put(d, d, singular.getDouble(d));
    }

    INDArray rebuilt = left.mmul(sigma).mmul(rightT);
    for (NdIndexIterator positions = new NdIndexIterator(rebuilt.shape()); positions.hasNext(); ) {
        int[] pos = positions.next();
        assertEquals("SVD did not factorize properly", rebuilt.getDouble(pos), reference.getDouble(pos), 1e-5);
    }
}
 
示例2
/**
 * Fills {@code ret} with Gaussian samples and returns it.
 *
 * <p>When the random generator exposes a state pointer, the work is delegated to a
 * {@code GaussianDistribution} op run by the executioner. Otherwise each element is
 * written individually as {@code standardDeviation * random.nextGaussian() + mean},
 * where the mean comes from {@code means} at the same index if that array is set,
 * or from the scalar {@code mean} otherwise.
 *
 * @param ret array to fill
 * @return the result of the executioner call, or {@code ret} itself on the element-wise path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        if (means != null) {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, means, standardDeviation), random);
        } else {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, mean, standardDeviation), random);
        }
    } else {
        Iterator<long[]> idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // The counter is a long: an int counter would wrap around before reaching a length above Integer.MAX_VALUE
        if (means != null) {
            for (long i = 0; i < len; i++) {
                long[] idx = idxIter.next();
                ret.putScalar(idx, standardDeviation * random.nextGaussian() + means.getDouble(idx));
            }
        } else {
            for (long i = 0; i < len; i++) {
                ret.putScalar(idxIter.next(), standardDeviation * random.nextGaussian() + mean);
            }
        }
        return ret;
    }
}
 
示例3
/**
 * Compares {@code actual} with the {@code expected} array element by element using a
 * relative-error criterion.
 *
 * <p>An element pair passes when both values are zero, when the absolute difference is
 * below {@code minAbsoluteError}, or when the relative error
 * {@code |d1-d2| / (|d1|+|d2|)} does not exceed {@code maxRelativeError}. A pair whose
 * relative error is undefined (NaN) is reported as a failure, unless both values are NaN
 * or both are the same infinity.
 *
 * @param actual array to check against the expected values
 * @return {@code null} when every element passes, otherwise a message describing the first failing position
 * @throws IllegalStateException if the shapes of the two arrays differ
 */
@Override
public String apply(INDArray actual) {
    //TODO switch to binary relative error ops
    if(!Arrays.equals(expected.shape(), actual.shape())){
        throw new IllegalStateException("Shapes differ! " + Arrays.toString(expected.shape()) + " vs " + Arrays.toString(actual.shape()));
    }

    NdIndexIterator iter = new NdIndexIterator(expected.shape());
    while(iter.hasNext()){
        long[] next = iter.next();
        double d1 = expected.getDouble(next);
        double d2 = actual.getDouble(next);
        if(d1 == 0.0 && d2 == 0){
            continue;
        }
        double absError = Math.abs(d1-d2);
        if(absError < minAbsoluteError){
            continue;
        }
        double re = absError / (Math.abs(d1) + Math.abs(d2));
        // A NaN relative error (one NaN operand, or mismatched infinities) makes "re > max" false,
        // which would let a genuine mismatch pass silently. Treat it as a failure unless the two
        // values are indistinguishable (both NaN, or the same infinity).
        boolean bothNaN = Double.isNaN(d1) && Double.isNaN(d2);
        boolean sameInfinity = Double.isInfinite(d1) && d1 == d2;
        boolean undefinedMismatch = Double.isNaN(re) && !bothNaN && !sameInfinity;
        if(re > maxRelativeError || undefinedMismatch){
            return "Failed on relative error at position " + Arrays.toString(next) + ": relativeError=" + re + ", maxRE=" + maxRelativeError + ", absError=" +
                    absError + ", minAbsError=" + minAbsoluteError + " - values (" + d1 + "," + d2 + ")";
        }
    }
    return null;
}
 
示例4
/**
 * Draws Gaussian samples into {@code ret} and returns it.
 *
 * <p>If the random generator has a state pointer, a {@code GaussianDistribution} op is
 * executed through the executioner. Otherwise the array is filled one element at a time
 * with {@code standardDeviation * random.nextGaussian()} plus either the per-element mean
 * read from {@code means} (when set) or the scalar {@code mean}.
 *
 * @param ret array to fill
 * @return the result of the executioner call, or {@code ret} itself on the element-wise path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        if (means != null) {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, means, standardDeviation), random);
        } else {
            return Nd4j.getExecutioner().exec(new GaussianDistribution(
                    ret, mean, standardDeviation), random);
        }
    } else {
        Iterator<long[]> idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // long counter: comparing an int counter with a long length overflows for lengths above Integer.MAX_VALUE
        if (means != null) {
            for (long i = 0; i < len; i++) {
                long[] idx = idxIter.next();
                ret.putScalar(idx, standardDeviation * random.nextGaussian() + means.getDouble(idx));
            }
        } else {
            for (long i = 0; i < len; i++) {
                ret.putScalar(idxIter.next(), standardDeviation * random.nextGaussian() + mean);
            }
        }
        return ret;
    }
}
 
示例5
/**
 * Demonstrates reading one element, overwriting one element, and walking every element
 * of a 2 x 6 array with an index iterator. All results are printed to standard output.
 */
public static void main(String[] args) {
    float[] values = new float[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
    int[] shape = new int[]{2, 6};
    INDArray matrix = Nd4j.create(values, shape);
    System.out.println("打印原有数组");
    System.out.println(matrix);

    // Read the value stored at a given index
    System.out.println("获取数组下标为0, 3的值");
    double picked = matrix.getDouble(0, 3);
    System.out.println(picked);

    // Overwrite the value stored at a given index with a scalar
    System.out.println("修改数组下标为0, 3的值");
    matrix.putScalar(0, 3, 100);

    System.out.println(matrix);

    // Traverse the ndarray with an index iterator, using c order
    System.out.println("使用索引迭代器遍历ndarray");
    for (NdIndexIterator positions = new NdIndexIterator(2, 6); positions.hasNext(); ) {
        long[] position = positions.next();
        double element = matrix.getDouble(position);

        System.out.println(element);
    }
}
 
示例6
/**
 * Tells whether every element of the given array compares equal to zero.
 *
 * @param array array to scan
 * @return {@code false} as soon as an element different from zero is found, {@code true} otherwise
 */
private static boolean allZeros(INDArray array) {
    for (NdIndexIterator positions = new NdIndexIterator(array.shape()); positions.hasNext(); ) {
        if (array.getDouble(positions.next()) != 0) {
            return false;
        }
    }
    return true;
}
 
示例7
/**
 * Runs LAPACK {@code syev} on a random symmetric N x N matrix and checks, element by
 * element, that {@code Aorig * A} equals {@code A * diag(V)}, where {@code A} is the
 * array passed to {@code syev} and {@code V} receives its vector output.
 *
 * @param N           size of the square matrix
 * @param matrixOrder ordering character handed to the random array factory
 */
void testEv(int N, char matrixOrder) {
    INDArray A = Nd4j.rand(N, N, matrixOrder);
    // Mirror the lower triangle into the upper triangle so the matrix is symmetric
    for (int r = 1; r < N; r++) {
        for (int c = 0; c < r; c++) {
            double v = A.getDouble(r, c);
            A.putScalar(c, r, v);
        }
    }

    INDArray Aorig = A.dup();
    INDArray V = Nd4j.create(N);

    Nd4j.getBlasWrapper().lapack().syev('V', 'U', A, V);

    // Place the values returned in V on the diagonal of an N x N matrix
    INDArray VV = Nd4j.create(N, N);
    for (int i = 0; i < N; i++) {
        VV.put(i, i, V.getDouble(i));
    }

    INDArray L = Aorig.mmul(A);
    INDArray R = A.mmul(VV);

    NdIndexIterator iter = new NdIndexIterator(L.shape());
    while(iter.hasNext()){
        int[] pos = iter.next();
        // The message names the decomposition actually under test (it previously said "SVD")
        assertEquals("Eigen decomposition did not factorize properly", L.getDouble(pos), R.getDouble(pos), 1e-5);
    }
}
 
示例8
/**
 * Checks that an iterator over a 2 x 2 shape yields the four index pairs in the order
 * {0,0}, {0,1}, {1,0}, {1,1}.
 */
@Test
public void testIterate() {
    val expectedOrder = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};
    val indexIter = new NdIndexIterator(2, 2);

    int step = 0;
    while (step < 4) {
        assertArrayEquals(expectedOrder[step], indexIter.next());
        step++;
    }
}
 
示例9
/**
 * Overwrites every element of {@code target} with a fresh value from {@code sample()}.
 *
 * @param target array to fill
 * @return the same {@code target} instance
 */
@Override
public INDArray sample(INDArray target) {
    // Explicit index iteration: for consistent values irrespective of c vs. fortran ordering
    Iterator<long[]> positions = new NdIndexIterator(target.shape());
    long remaining = target.length();
    while (remaining > 0) {
        target.putScalar(positions.next(), sample());
        remaining--;
    }
    return target;
}
 
示例10
/**
 * Fills {@code ret} with uniformly distributed samples and returns it.
 *
 * <p>When the random generator exposes a state pointer, a {@code UniformDistribution} op
 * parameterised with {@code lower} and {@code upper} is run by the executioner. Otherwise
 * every element is set to the result of a separate {@code sample()} call.
 *
 * @param ret array to fill
 * @return the result of the executioner call, or {@code ret} itself on the element-wise path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        return Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(
                ret, lower, upper), random);
    } else {
        val idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // long counter: an int counter would overflow before reaching a length above Integer.MAX_VALUE
        for (long i = 0; i < len; i++) {
            ret.putScalar(idxIter.next(), sample());
        }
        return ret;
    }
}
 
示例11
/**
 * Checks that evaluating [2, 5, 10] arrays directly gives the same
 * {@code EvaluationCalibration} (and the same stats string) as evaluating the 2d arrays
 * obtained by stacking, for every (dim 0, dim 2) index pair, the slice along dimension 1.
 */
@Test
public void testEvaluationCalibration3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        // Fix dimensions 0 and 2, keep all of dimension 1
        INDArrayIndex[] selector = {NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    EvaluationCalibration eval3d = new EvaluationCalibration();
    EvaluationCalibration eval2d = new EvaluationCalibration();

    eval3d.eval(label, prediction);
    eval2d.eval(label2d, prediction2d);

    System.out.println(eval2d.stats());

    assertEquals(eval2d, eval3d);

    assertEquals(eval2d.stats(), eval3d.stats());
}
 
示例12
/**
 * Checks that evaluating [2, 3, 10] arrays with a [2, 10] mask gives the same
 * {@code EvaluationCalibration} as evaluating the stacked 2d rows that correspond to the
 * non-zero mask entries only.
 */
@Test
public void testEvaluationCalibration3dMasking() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10);

    List<INDArray> keptPredictions = new ArrayList<>();
    List<INDArray> keptLabels = new ArrayList<>();

    //Check "DL4J-style" 2d per timestep masking [minibatch, seqLength] mask shape
    INDArray mask2d = Nd4j.randomBernoulli(0.5, 2, 10);
    for (NdIndexIterator positions = new NdIndexIterator(2, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        if (mask2d.getDouble(pos[0], pos[1]) == 0.0) {
            // Masked-out position: contributes no row
            continue;
        }
        INDArrayIndex[] selector = {NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1])};
        keptPredictions.add(prediction.get(selector));
        keptLabels.add(label.get(selector));
    }
    INDArray prediction2d = Nd4j.vstack(keptPredictions);
    INDArray label2d = Nd4j.vstack(keptLabels);

    EvaluationCalibration maskedEval3d = new EvaluationCalibration();
    EvaluationCalibration rowEval2d = new EvaluationCalibration();
    maskedEval3d.eval(label, prediction, mask2d);
    rowEval2d.eval(label2d, prediction2d);

    assertEquals(maskedEval3d, rowEval2d);
}
 
示例13
/**
 * Checks that {@code RegressionEvaluation} yields the same score for every metric
 * (tolerance 1e-6) whether it is fed [2, 5, 10] arrays directly or the 2d arrays built by
 * stacking the dimension-1 slices for each (dim 0, dim 2) index pair.
 */
@Test
public void testRegressionEval3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        INDArrayIndex[] selector = {NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    RegressionEvaluation eval3d = new RegressionEvaluation();
    RegressionEvaluation eval2d = new RegressionEvaluation();

    eval3d.eval(label, prediction);
    eval2d.eval(label2d, prediction2d);

    for (Metric metric : Metric.values()) {
        double score3d = eval3d.scoreForMetric(metric);
        double score2d = eval2d.scoreForMetric(metric);
        assertEquals(metric.toString(), score2d, score3d, 1e-6);
    }
}
 
示例14
/**
 * Checks that {@code RegressionEvaluation} yields the same score for every metric
 * (tolerance 1e-5) whether it is fed [2, 3, 10, 10] arrays directly or the 2d arrays built
 * by stacking the dimension-1 slices for each (dim 0, dim 2, dim 3) index triple.
 */
@Test
public void testRegressionEval4d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        INDArrayIndex[] selector = {
                NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1]), NDArrayIndex.point(pos[2])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    RegressionEvaluation eval4d = new RegressionEvaluation();
    RegressionEvaluation eval2d = new RegressionEvaluation();

    eval4d.eval(label, prediction);
    eval2d.eval(label2d, prediction2d);

    for (Metric metric : Metric.values()) {
        double score4d = eval4d.scoreForMetric(metric);
        double score2d = eval2d.scoreForMetric(metric);
        assertEquals(metric.toString(), score2d, score4d, 1e-5);
    }
}
 
示例15
/**
 * Checks that {@code EvaluationBinary} yields the same score for every metric and for
 * output indices 0..4 (tolerance 1e-6) whether it is fed [2, 5, 10] arrays directly or the
 * 2d arrays built by stacking the dimension-1 slices for each (dim 0, dim 2) index pair.
 */
@Test
public void testEvaluationBinary3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        INDArrayIndex[] selector = {NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    EvaluationBinary eval3d = new EvaluationBinary();
    EvaluationBinary eval2d = new EvaluationBinary();

    eval3d.eval(label, prediction);
    eval2d.eval(label2d, prediction2d);

    for (EvaluationBinary.Metric metric : EvaluationBinary.Metric.values()) {
        int output = 0;
        while (output < 5) {
            double score3d = eval3d.scoreForMetric(metric, output);
            double score2d = eval2d.scoreForMetric(metric, output);
            assertEquals(metric.toString(), score2d, score3d, 1e-6);
            output++;
        }
    }
}
 
示例16
/**
 * Checks that {@code EvaluationBinary} yields the same score for every metric and for
 * output indices 0..2 (tolerance 1e-6) whether it is fed [2, 3, 10, 10] arrays directly or
 * the 2d arrays built by stacking the dimension-1 slices for each (dim 0, dim 2, dim 3)
 * index triple.
 */
@Test
public void testEvaluationBinary4d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        INDArrayIndex[] selector = {
                NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1]), NDArrayIndex.point(pos[2])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    EvaluationBinary eval4d = new EvaluationBinary();
    EvaluationBinary eval2d = new EvaluationBinary();

    eval4d.eval(label, prediction);
    eval2d.eval(label2d, prediction2d);

    for (EvaluationBinary.Metric metric : EvaluationBinary.Metric.values()) {
        int output = 0;
        while (output < 3) {
            double score4d = eval4d.scoreForMetric(metric, output);
            double score2d = eval2d.scoreForMetric(metric, output);
            assertEquals(metric.toString(), score2d, score4d, 1e-6);
            output++;
        }
    }
}
 
示例17
/**
 * Checks that {@code ROCBinary} yields the same score for every metric and for output
 * indices 0..4 (tolerance 1e-6) whether it is fed [2, 5, 10] arrays directly or the 2d
 * arrays built by stacking the dimension-1 slices for each (dim 0, dim 2) index pair.
 */
@Test
public void testROCBinary3d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 5, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 5, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        INDArrayIndex[] selector = {NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    ROCBinary roc3d = new ROCBinary();
    ROCBinary roc2d = new ROCBinary();

    roc3d.eval(label, prediction);
    roc2d.eval(label2d, prediction2d);

    for (ROCBinary.Metric metric : ROCBinary.Metric.values()) {
        int output = 0;
        while (output < 5) {
            double score3d = roc3d.scoreForMetric(metric, output);
            double score2d = roc2d.scoreForMetric(metric, output);
            assertEquals(metric.toString(), score2d, score3d, 1e-6);
            output++;
        }
    }
}
 
示例18
/**
 * Checks that {@code ROCBinary} yields the same score for every metric and for output
 * indices 0..2 (tolerance 1e-6) whether it is fed [2, 3, 10, 10] arrays directly or the 2d
 * arrays built by stacking the dimension-1 slices for each (dim 0, dim 2, dim 3) index
 * triple.
 */
@Test
public void testROCBinary4d() {
    INDArray prediction = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);
    INDArray label = Nd4j.rand(DataType.FLOAT, 2, 3, 10, 10);

    List<INDArray> predictionRows = new ArrayList<>();
    List<INDArray> labelRows = new ArrayList<>();
    for (NdIndexIterator positions = new NdIndexIterator(2, 10, 10); positions.hasNext(); ) {
        long[] pos = positions.next();
        INDArrayIndex[] selector = {
                NDArrayIndex.point(pos[0]), NDArrayIndex.all(), NDArrayIndex.point(pos[1]), NDArrayIndex.point(pos[2])};
        predictionRows.add(prediction.get(selector));
        labelRows.add(label.get(selector));
    }

    INDArray prediction2d = Nd4j.vstack(predictionRows);
    INDArray label2d = Nd4j.vstack(labelRows);

    ROCBinary roc4d = new ROCBinary();
    ROCBinary roc2d = new ROCBinary();

    roc4d.eval(label, prediction);
    roc2d.eval(label2d, prediction2d);

    for (ROCBinary.Metric metric : ROCBinary.Metric.values()) {
        int output = 0;
        while (output < 3) {
            double score4d = roc4d.scoreForMetric(metric, output);
            double score2d = roc2d.scoreForMetric(metric, output);
            assertEquals(metric.toString(), score2d, score4d, 1e-6);
            output++;
        }
    }
}
 
示例19
/**
 * Verifies the first four index pairs produced for a 2 x 2 shape:
 * {0,0}, {0,1}, {1,0}, {1,1}, in that order.
 */
@Test
public void testIterate() {
    val indexIter = new NdIndexIterator(2, 2);
    val expectedOrder = new long[][] {{0, 0}, {0, 1}, {1, 0}, {1, 1},};

    int checked = 0;
    do {
        assertArrayEquals(expectedOrder[checked], indexIter.next());
        checked++;
    } while (checked < 4);
}
 
示例20
/**
 * Replaces each element of {@code target} with a value returned by {@code sample()}.
 *
 * @param target array to fill
 * @return the same {@code target} instance
 */
@Override
public INDArray sample(INDArray target) {
    long total = target.length();
    // Positions are visited through an index iterator: for consistent values irrespective of c vs. fortran ordering
    Iterator<long[]> positions = new NdIndexIterator(target.shape());
    long written = 0;
    while (written < total) {
        target.putScalar(positions.next(), sample());
        written++;
    }
    return target;
}
 
示例21
/**
 * Draws uniformly distributed samples into {@code ret} and returns it.
 *
 * <p>If the random generator has a state pointer, a {@code UniformDistribution} op built
 * from {@code lower} and {@code upper} is executed through the executioner. Otherwise
 * each element is assigned the value of its own {@code sample()} call.
 *
 * @param ret array to fill
 * @return the result of the executioner call, or {@code ret} itself on the element-wise path
 */
@Override
public INDArray sample(INDArray ret) {
    if (random.getStatePointer() != null) {
        return Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.random.impl.UniformDistribution(
                ret, lower, upper), random);
    } else {
        val idxIter = new NdIndexIterator(ret.shape()); //For consistent values irrespective of c vs. fortran ordering
        long len = ret.length();
        // The counter matches the long length; an int counter would wrap for lengths above Integer.MAX_VALUE
        for (long i = 0; i < len; i++) {
            ret.putScalar(idxIter.next(), sample());
        }
        return ret;
    }
}
 
示例22
/**
 * Writes values from {@code element} into this array at positions described by {@code indices}.
 *
 * <p>When {@code indices} has as many rows as this array has dimensions, each column of
 * {@code indices} is read as one full coordinate and receives the next value of
 * {@code element}. Otherwise, for matrix or column-vector indices, every entry is used as
 * a slice number and {@code element} is assigned to that slice.
 *
 * @param indices index array of rank at most 2
 * @param element source of the values to write
 * @return this array
 * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    if(indices.rank() > 2) {
        throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
    }

    if(indices.rows() == rank()) {
        // One coordinate per column; values are taken from element in iterator order
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        for(int i = 0; i < indices.columns(); i++) {
            int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
            putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next()));
        }
    }
    else {
        // NOTE(review): arrList is only appended to and never read within this method
        List<INDArray> arrList = new ArrayList<>();

        if(indices.isMatrix() || indices.isColumnVector()) {
            for(int i = 0; i < indices.rows(); i++) {
                INDArray row = indices.getRow(i);
                for(int j = 0; j < row.length(); j++) {
                    // Assign element into the slice selected by this index entry
                    INDArray slice = slice(row.getInt(j));
                    Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice}));
                    arrList.add(slice(row.getInt(j)));
                }
            }
        }
        else if(indices.isRowVector()) {
            // NOTE(review): this branch only collects slices; no Assign is executed here,
            // so as far as this method shows nothing is written - confirm whether that is intended
            for(int i = 0; i < indices.length(); i++) {
                arrList.add(slice(indices.getInt(i)));
            }
        }
    }
    return this;
}
 
示例23
/**
 * Copies the elements of an array into a new {@code float[]}, visiting positions with a
 * 'c'-ordered index iterator.
 *
 * @param arr array to copy
 * @return a new float array holding one entry per element of {@code arr}
 * @throws ND4JArraySizeException if the array length exceeds {@code Integer.MAX_VALUE}
 */
public static float[] asFloat(INDArray arr) {
    long total = arr.length();
    if (total > Integer.MAX_VALUE) {
        throw new ND4JArraySizeException();
    }

    float[] flattened = new float[(int) total];
    NdIndexIterator positions = new NdIndexIterator('c', arr.shape());
    int slot = 0;
    while (slot < total) {
        flattened[slot] = arr.getFloat(positions.next());
        slot++;
    }
    return flattened;
}
 
示例24
/**
 * Orders two {@code NDArrayWritable} instances by their wrapped arrays.
 *
 * <p>Ordering rules, applied in turn so that contents are compared as late as possible:
 * a null array sorts first, then lower rank, then shorter length, then smaller size along
 * the first differing dimension, and finally the first differing element in 'c' order.
 *
 * @param o the other writable; it is cast to {@code NDArrayWritable}
 * @return negative if this sorts before {@code o}, zero if equal, positive if after
 */
@Override
public int compareTo(@NotNull Object o) {
    NDArrayWritable that = (NDArrayWritable) o;

    // Null handling: null before non-null, two nulls are equal
    if (this.array == null) {
        return that.array == null ? 0 : -1;
    }
    if (that.array == null) {
        return 1;
    }

    int byRank = Integer.compare(array.rank(), that.array.rank());
    if (byRank != 0) {
        return byRank;
    }

    int byLength = Long.compare(array.length(), that.array.length());
    if (byLength != 0) {
        return byLength;
    }

    int rank = array.rank();
    for (int dim = 0; dim < rank; dim++) {
        int bySize = Long.compare(array.size(dim), that.array.size(dim));
        if (bySize != 0) {
            return bySize;
        }
    }

    // Rank, length and shape all match: fall back to comparing contents
    for (NdIndexIterator positions = new NdIndexIterator('c', array.shape()); positions.hasNext(); ) {
        long[] pos = positions.next();
        int byValue = Double.compare(array.getDouble(pos), that.array.getDouble(pos));
        if (byValue != 0) {
            return byValue;
        }
    }

    // No difference found anywhere
    return 0;
}
 
示例25
/**
 * Checks that the variance of a [100, 5, 10, 10] array over dimensions 1, 2, 3 equals the
 * variance of its [100, 500] reshape over dimension 1, for both {@code var(false, ...)} and
 * {@code var(true, ...)}, and that the {@code var(false, ...)} result matches a manual
 * sum / sum-of-squares calculation.
 *
 * <p>The context data type is switched to DOUBLE for the duration of the test and is
 * restored afterwards, even when an assertion fails.
 */
@Test
public void testVarianceSingleVsMultipleDimensions() {
    // this test should always run in double
    DataBuffer.Type type = Nd4j.dataType();
    DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);
    try {
        Nd4j.getRandom().setSeed(12345);

        //Generate C order random numbers. Strides: [500,100,10,1]
        INDArray fourd = Nd4j.rand('c', new int[] {100, 5, 10, 10}).muli(10);
        INDArray twod = Shape.newShapeNoCopy(fourd, new int[] {100, 5 * 10 * 10}, false);

        //Population variance. These two should be identical
        INDArray var4 = fourd.var(false, 1, 2, 3);
        INDArray var2 = twod.var(false, 1);

        //Manual calculation of population variance, not bias corrected
        //https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Na.C3.AFve_algorithm
        double[] sums = new double[100];
        double[] sumSquares = new double[100];
        NdIndexIterator iter = new NdIndexIterator(fourd.shape());
        while (iter.hasNext()) {
            val next = iter.next();
            double d = fourd.getDouble(next);

            // FIXME: int cast
            sums[(int) next[0]] += d;
            sumSquares[(int) next[0]] += d * d;
        }

        double[] manualVariance = new double[100];
        val N = (fourd.length() / sums.length);
        for (int i = 0; i < sums.length; i++) {
            manualVariance[i] = (sumSquares[i] - (sums[i] * sums[i]) / N) / N;
        }

        INDArray var4bias = fourd.var(true, 1, 2, 3);
        INDArray var2bias = twod.var(true, 1);

        assertArrayEquals(var2.data().asDouble(), var4.data().asDouble(), 1e-5);
        assertArrayEquals(manualVariance, var2.data().asDouble(), 1e-5);
        assertArrayEquals(var2bias.data().asDouble(), var4bias.data().asDouble(), 1e-5);
    } finally {
        // Restore the previous data type so a failure here cannot leak DOUBLE into later tests
        DataTypeUtil.setDTypeForContext(type);
    }
}
 
示例26
/**
 * Cross-checks the {@code Shape} helpers that read rank, shape, size, stride and offset
 * from an array's shape information, once through {@code shapeInfo()} and once through
 * {@code shapeInfoDataBuffer()}, against the array's own {@code shape()} and {@code stride()}.
 * The test arrays come from {@code NDArrayCreationUtil} and cover ranks 2 to 6.
 */
@Test
public void testBufferToIntShapeStrideMethods() {
    //Specifically: Shape.shape(IntBuffer), Shape.shape(DataBuffer)
    //.isRowVectorShape(DataBuffer), .isRowVectorShape(IntBuffer)
    //Shape.size(DataBuffer,int), Shape.size(IntBuffer,int)
    //Also: Shape.stride(IntBuffer), Shape.stride(DataBuffer)
    //Shape.stride(DataBuffer,int), Shape.stride(IntBuffer,int)

    List<List<Pair<INDArray, String>>> arraysPerShape = new ArrayList<>();
    arraysPerShape.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 4, 12345));
    arraysPerShape.add(NDArrayCreationUtil.getAllTestMatricesWithShape(1, 4, 12345));
    arraysPerShape.add(NDArrayCreationUtil.getAllTestMatricesWithShape(3, 1, 12345));
    arraysPerShape.add(NDArrayCreationUtil.getAll3dTestArraysWithShape(12345, 3, 4, 5));
    arraysPerShape.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 4, 5, 6));
    arraysPerShape.add(NDArrayCreationUtil.getAll4dTestArraysWithShape(12345, 3, 1, 5, 1));
    arraysPerShape.add(NDArrayCreationUtil.getAll5dTestArraysWithShape(12345, 3, 4, 5, 6, 7));
    arraysPerShape.add(NDArrayCreationUtil.getAll6dTestArraysWithShape(12345, 3, 4, 5, 6, 7, 8));

    // Expected shape for each entry of arraysPerShape, in the same order
    val shapes = new long[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {3, 4, 5, 6}, {3, 1, 5, 1}, {3, 4, 5, 6, 7},
                    {3, 4, 5, 6, 7, 8}};

    for (int s = 0; s < shapes.length; s++) {
        List<Pair<INDArray, String>> candidates = arraysPerShape.get(s);
        val shape = shapes[s];

        for (Pair<INDArray, String> candidate : candidates) {
            INDArray arr = candidate.getFirst();

            assertArrayEquals(shape, arr.shape());

            val strides = arr.stride();

            val shapeInfo = arr.shapeInfo();
            DataBuffer shapeInfoBuffer = arr.shapeInfoDataBuffer();

            //Rank and full shape must agree for both representations
            assertEquals(shape.length, Shape.rank(shapeInfo));
            assertEquals(shape.length, Shape.rank(shapeInfoBuffer));

            assertArrayEquals(shape, Shape.shape(shapeInfo));
            assertArrayEquals(shape, Shape.shape(shapeInfoBuffer));

            //Per-dimension size and stride
            for (int dim = 0; dim < shape.length; dim++) {
                assertEquals(shape[dim], Shape.size(shapeInfo, dim));
                assertEquals(shape[dim], Shape.size(shapeInfoBuffer, dim));

                assertEquals(strides[dim], Shape.stride(shapeInfo, dim));
                assertEquals(strides[dim], Shape.stride(shapeInfoBuffer, dim));
            }

            //Base offset
            assertEquals(Shape.offset(shapeInfo), Shape.offset(shapeInfoBuffer));

            //Offset of every index: array form versus the explicit-coordinate overloads
            for (NdIndexIterator positions = new NdIndexIterator(shape); positions.hasNext(); ) {
                val next = positions.next();
                long expectedOffset = Shape.getOffset(shapeInfo, next);

                assertEquals(expectedOffset, Shape.getOffset(shapeInfoBuffer, next));

                if (shape.length == 2) {
                    assertEquals(expectedOffset, Shape.getOffset(shapeInfo, next[0], next[1]));
                    assertEquals(expectedOffset, Shape.getOffset(shapeInfoBuffer, next[0], next[1]));
                } else if (shape.length == 3) {
                    assertEquals(expectedOffset, Shape.getOffset(shapeInfo, next[0], next[1], next[2]));
                    assertEquals(expectedOffset, Shape.getOffset(shapeInfoBuffer, next[0], next[1], next[2]));
                } else if (shape.length == 4) {
                    assertEquals(expectedOffset, Shape.getOffset(shapeInfo, next[0], next[1], next[2], next[3]));
                    assertEquals(expectedOffset, Shape.getOffset(shapeInfoBuffer, next[0], next[1], next[2], next[3]));
                } else if (shape.length != 5 && shape.length != 6) {
                    //No 5 and 6d getOffset overloads, so those ranks are skipped; anything else is unexpected
                    throw new RuntimeException();
                }
            }
        }
    }
}
 
示例27
/**
 * Checks that the array-index {@code putScalar} overload and the explicit-coordinate
 * overloads (ranks 2 to 4) produce equal arrays, for both 'c' and 'f' ordered targets,
 * across a set of shapes that includes size-1 dimensions.
 */
@Test
public void testPutScalar() {
    //Check that the various putScalar methods have the same result...
    val shapes = new int[][] {{3, 4}, {1, 4}, {3, 1}, {3, 4, 5}, {1, 4, 5}, {3, 1, 5}, {3, 4, 1}, {1, 1, 5},
                    {3, 4, 5, 6}, {1, 4, 5, 6}, {3, 1, 5, 6}, {3, 4, 1, 6}, {3, 4, 5, 1}, {1, 1, 5, 6},
                    {3, 1, 1, 6}, {3, 1, 1, 1}};

    for (int[] shape : shapes) {
        int rank = shape.length;
        NdIndexIterator positions = new NdIndexIterator(shape);
        INDArray byArrayC = Nd4j.create(shape, 'c');
        INDArray byArrayF = Nd4j.create(shape, 'f');
        INDArray byCoordsC = Nd4j.create(shape, 'c');
        INDArray byCoordsF = Nd4j.create(shape, 'f');

        // Each position receives its running visit number as value
        for (int counter = 0; positions.hasNext(); counter++) {
            val currIdx = positions.next();
            byArrayC.putScalar(currIdx, counter);
            byArrayF.putScalar(currIdx, counter);

            if (rank == 2) {
                byCoordsC.putScalar(currIdx[0], currIdx[1], counter);
                byCoordsF.putScalar(currIdx[0], currIdx[1], counter);
            } else if (rank == 3) {
                byCoordsC.putScalar(currIdx[0], currIdx[1], currIdx[2], counter);
                byCoordsF.putScalar(currIdx[0], currIdx[1], currIdx[2], counter);
            } else if (rank == 4) {
                byCoordsC.putScalar(currIdx[0], currIdx[1], currIdx[2], currIdx[3], counter);
                byCoordsF.putScalar(currIdx[0], currIdx[1], currIdx[2], currIdx[3], counter);
            } else {
                throw new RuntimeException();
            }
        }
        assertEquals(byArrayC, byArrayF);
        assertEquals(byArrayC, byCoordsC);
        assertEquals(byArrayC, byCoordsF);
    }
}
 
示例28
/**
 * Writes values from {@code element} into this array at positions described by {@code indices}.
 *
 * <p>When {@code indices} has as many rows as this array has dimensions, each column of
 * {@code indices} is read as one full coordinate and receives the next value of
 * {@code element}. Otherwise, for matrix or column-vector indices, every entry is used as
 * a slice number and {@code element} is assigned to that slice.
 *
 * @param indices index array of rank at most 2
 * @param element source of the values to write
 * @return this array
 * @throws ND4JIllegalArgumentException if {@code indices} has rank greater than 2
 */
@Override
public INDArray put(INDArray indices, INDArray element) {
    if(indices.rank() > 2) {
        throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
    }

    if(indices.rows() == rank()) {
        // One coordinate per column; values are taken from element in iterator order
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        for(int i = 0; i < indices.columns(); i++) {
            int[] specifiedIndex = indices.getColumn(i).dup().data().asInt();
            putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next()));
        }

    }
    else {
        // NOTE(review): arrList is only appended to and never read within this method
        List<INDArray> arrList = new ArrayList<>();

        if(indices.isMatrix() || indices.isColumnVector()) {
            for(int i = 0; i < indices.rows(); i++) {
                INDArray row = indices.getRow(i);
                for(int j = 0; j < row.length(); j++) {
                    // Assign element into the slice selected by this index entry
                    INDArray slice = slice(row.getInt(j));
                    Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice,element},new INDArray[]{slice}));
                    arrList.add(slice(row.getInt(j)));
                }


            }
        }
        else if(indices.isRowVector()) {
            // NOTE(review): this branch only collects slices; no Assign is executed here,
            // so as far as this method shows nothing is written - confirm whether that is intended
            for(int i = 0; i < indices.length(); i++) {
                arrList.add(slice(indices.getInt(i)));
            }
        }

    }


    return this;

}
 
示例29
/**
 * Writes values from {@code element} into this array at positions described by lists of
 * integer indices.
 *
 * <p>When there is one index list per dimension of this array, the lists are wrapped in
 * {@code SpecifiedIndex} objects, the coordinates produced by
 * {@code SpecifiedIndex.iterate} are visited, and each receives the next value of
 * {@code element}. Otherwise, when at least two lists are given, every entry of every
 * list is used as a slice number and {@code element} is assigned to that slice.
 *
 * @param indices index lists
 * @param element source of the values to write
 * @return this array
 */
@Override
public INDArray put(List<List<Integer>> indices, INDArray element) {
    if(indices.size() == rank()) {
        NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape());
        INDArrayIndex[] indArrayIndices = new INDArrayIndex[indices.size()];
        for(int i = 0; i < indArrayIndices.length; i++) {
            indArrayIndices[i] = new SpecifiedIndex(Ints.toArray(indices.get(i)));
        }
        boolean hasNext = true;
        Generator<List<List<Long>>> iterate = SpecifiedIndex.iterate(indArrayIndices);
        // The loop ends when a NoSuchElementException is caught.
        // NOTE(review): the catch covers the whole try body, so an exception of that type from
        // ndIndexIterator.next() or putScalar would end the loop as well - confirm intended
        while(hasNext) {
            try {
                List<List<Long>> next = iterate.next();
                for(int i = 0; i < next.size(); i++) {
                    int[] curr = Ints.toArray(next.get(i));
                    putScalar(curr,element.getDouble(ndIndexIterator.next()));
                }
            }
            catch(NoSuchElementException e) {
                hasNext = false;
            }
        }

    }
    else {
        // NOTE(review): arrList is only appended to and never read within this method
        List<INDArray> arrList = new ArrayList<>();

        if(indices.size() >= 2) {
            for(int i = 0; i < indices.size(); i++) {
                List<Integer> row = indices.get(i);
                for(int j = 0; j < row.size(); j++) {
                    // Assign element into the slice selected by this index entry
                    INDArray slice = slice(row.get(j));
                    Nd4j.getExecutioner().exec(new Assign(new INDArray[]{slice,element},new INDArray[]{slice}));
                    arrList.add(slice(row.get(j)));
                }


            }
        }
        else if(indices.size() == 1) {
            // NOTE(review): the bound is indices.size() (1 in this branch), so only the first entry
            // of the single list is visited, and no Assign is executed - confirm whether
            // indices.get(0).size() and an actual write were intended
            for(int i = 0; i < indices.size(); i++) {
                arrList.add(slice(indices.get(0).get(i)));
            }
        }

    }


    return this;
}
 
示例30
/**
 * Compares this writable with another {@code NDArrayWritable} by their wrapped arrays.
 *
 * <p>The comparison stops at the first criterion that differs: null array first, then
 * rank, then length, then the size of each dimension in turn, then the element values
 * visited in 'c' order. Contents are only inspected when everything else is equal.
 *
 * @param o the other writable; it is cast to {@code NDArrayWritable}
 * @return negative if this sorts before {@code o}, zero if equal, positive if after
 */
@Override
public int compareTo(@NonNull Object o) {
    NDArrayWritable rhs = (NDArrayWritable) o;

    // A missing array sorts before a present one; two missing arrays tie
    boolean lhsMissing = this.array == null;
    boolean rhsMissing = rhs.array == null;
    if (lhsMissing && rhsMissing) {
        return 0;
    }
    if (lhsMissing) {
        return -1;
    }
    if (rhsMissing) {
        return 1;
    }

    int rankOrder = Integer.compare(array.rank(), rhs.array.rank());
    if (rankOrder != 0) {
        return rankOrder;
    }

    int lengthOrder = Long.compare(array.length(), rhs.array.length());
    if (lengthOrder != 0) {
        return lengthOrder;
    }

    int dim = 0;
    while (dim < array.rank()) {
        int sizeOrder = Long.compare(array.size(dim), rhs.array.size(dim));
        if (sizeOrder != 0) {
            return sizeOrder;
        }
        dim++;
    }

    // Same rank, length and shape from here on: compare element by element
    NdIndexIterator positions = new NdIndexIterator('c', array.shape());
    while (positions.hasNext()) {
        long[] pos = positions.next();
        double mine = array.getDouble(pos);
        double theirs = rhs.array.getDouble(pos);
        int valueOrder = Double.compare(mine, theirs);
        if (valueOrder != 0) {
            return valueOrder;
        }
    }

    // Every criterion matched
    return 0;
}