Java源码示例:org.nd4j.linalg.api.rng.distribution.Distribution
示例1
/**
 * Compares the new {@code MathProbability} implementation against the legacy ND4J
 * {@code Distribution}: per-seed samples, support bounds, inverse CDF at 0/1, and the CDF at
 * each support bound must all agree.
 */
@Test
public void testSample() {
    final int seed = 1000;
    final Distribution legacy = getOldFunction(seed);
    final MathProbability<Number> candidate = getNewFunction(seed);

    // Re-seed both implementations with the same value on every pass and compare one sample
    // from each; the seed value also serves as the number of passes.
    int current = 0;
    while (current < seed) {
        candidate.setSeed(current);
        legacy.reseedRandomGenerator(current);
        assertSample(candidate, legacy);
        current++;
    }

    // Support bounds must agree.
    Assert.assertThat(candidate.getMaximum().doubleValue(), CoreMatchers.equalTo(legacy.getSupportUpperBound()));
    Assert.assertThat(candidate.getMinimum().doubleValue(), CoreMatchers.equalTo(legacy.getSupportLowerBound()));

    // The inverse CDF at 1 and at 0 must land on the upper and lower support bounds.
    Assert.assertThat(candidate.inverseDistribution(1D).doubleValue(), CoreMatchers.equalTo(legacy.getSupportUpperBound()));
    Assert.assertThat(candidate.inverseDistribution(0D).doubleValue(), CoreMatchers.equalTo(legacy.getSupportLowerBound()));

    // The CDF evaluated at each support bound must match between the two implementations.
    Assert.assertThat(
            candidate.cumulativeDistribution(candidate.getMaximum()),
            CoreMatchers.equalTo(legacy.cumulativeProbability(legacy.getSupportUpperBound())));
    Assert.assertThat(
            candidate.cumulativeDistribution(candidate.getMinimum()),
            CoreMatchers.equalTo(legacy.cumulativeProbability(legacy.getSupportLowerBound())));
}
示例2
/**
 * Smoke test: builds a Nesterovs updater with a rows x cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testNesterovs() {
    final int rows = 10;
    final int cols = 2;
    final int stateSize = rows * cols;

    final NesterovsUpdater updater = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1, 1).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1, 1);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例3
/**
 * Smoke test: builds an AdaGrad updater (learning rate 0.1, default epsilon) with a
 * rows x cols state view and perturbs a normally-initialised weight matrix.
 */
@Test
public void testAdaGrad() {
    final int rows = 10;
    final int cols = 2;
    final int stateSize = rows * cols;

    final AdaGradUpdater updater = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1, 1).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1, 1);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例4
/**
 * Smoke test: builds an AdaDelta updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testAdaDelta() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final AdaDeltaUpdater updater = new AdaDeltaUpdater(new AdaDelta());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例5
/**
 * Smoke test: builds an Adam updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testAdam() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final AdamUpdater updater = new AdamUpdater(new Adam());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例6
/**
 * Smoke test: builds a Nadam updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testNadam() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final NadamUpdater updater = new NadamUpdater(new Nadam());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例7
/**
 * Smoke test: builds an AdaMax updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testAdaMax() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final AdaMaxUpdater updater = new AdaMaxUpdater(new AdaMax());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例8
/**
 * Smoke test: builds a Nesterovs updater with a rows x cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testNesterovs() {
    final int rows = 10;
    final int cols = 2;
    final int stateSize = rows * cols;

    final NesterovsUpdater updater = new NesterovsUpdater(new Nesterovs(0.5, 0.9));
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1, 1).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1, 1);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例9
/**
 * Smoke test: builds an AdaGrad updater (learning rate 0.1, default epsilon) with a
 * rows x cols state view and perturbs a normally-initialised weight matrix.
 */
@Test
public void testAdaGrad() {
    final int rows = 10;
    final int cols = 2;
    final int stateSize = rows * cols;

    final AdaGradUpdater updater = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1, 1).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1, 1);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例10
/**
 * Smoke test: builds an AdaDelta updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testAdaDelta() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final AdaDeltaUpdater updater = new AdaDeltaUpdater(new AdaDelta());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例11
/**
 * Smoke test: builds an Adam updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testAdam() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final AdamUpdater updater = new AdamUpdater(new Adam());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例12
/**
 * Smoke test: builds a Nadam updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testNadam() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final NadamUpdater updater = new NadamUpdater(new Nadam());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例13
/**
 * Smoke test: builds an AdaMax updater with a 2 * rows * cols state view and perturbs a
 * normally-initialised weight matrix.
 */
@Test
public void testAdaMax() {
    final int rows = 10;
    final int cols = 2;
    // NOTE(review): presumably two per-parameter state arrays for this updater — confirm.
    final int stateSize = 2 * rows * cols;

    final AdaMaxUpdater updater = new AdaMaxUpdater(new AdaMax());
    updater.setStateViewArray(Nd4j.zeros(1, stateSize), new long[] {rows, cols}, 'c', true);

    // Fill every row of the weight matrix with draws from createNormal(1e-3, 1e-3).
    final INDArray weights = Nd4j.zeros(rows, cols);
    final Distribution normal = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int row = 0; row < weights.rows(); row++) {
        weights.putRow(row, Nd4j.create(normal.sample(weights.columns())));
    }

    // NOTE(review): this test never invokes the updater on `weights`, so it only checks that
    // construction and state-view setup do not throw. Consider asserting on updater output.
    for (int step = 0; step < 5; step++) {
        weights.addi(Nd4j.randn(rows, cols));
    }
}
示例14
/**
 * Builds the user weight matrix ({@code numberOfUsers} x the layer's nOut) over the given
 * parameter view.
 *
 * @param conf                 configuration whose layer must be a {@code FeedForwardLayer}
 * @param weightParamView      backing view for the weight parameters
 * @param initializeParameters when {@code false}, no weight-init scheme or distribution is
 *                             passed on (both are {@code null})
 * @return the weight matrix produced by {@code createWeightMatrix}
 */
private INDArray createUserWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, boolean initializeParameters) {
    final FeedForwardLayer layer = (FeedForwardLayer) conf.getLayer();
    if (!initializeParameters) {
        return createWeightMatrix(numberOfUsers, layer.getNOut(), null, null, weightParamView, false);
    }
    final Distribution distribution = Distributions.createDistribution(layer.getDist());
    return createWeightMatrix(numberOfUsers, layer.getNOut(), layer.getWeightInit(), distribution, weightParamView, true);
}
示例15
/**
 * Builds the weight matrix ({@code numberOfFeatures} x the layer's nOut) over the given view by
 * delegating to the superclass implementation.
 *
 * @param configuration configuration whose layer must be a {@code FeedForwardLayer}
 * @param view          backing view for the weight parameters
 * @param initialize    when {@code false}, no weight-init scheme or distribution is passed on
 *                      (both are {@code null})
 * @return the weight matrix produced by {@code super.createWeightMatrix}
 */
protected INDArray createWeightMatrix(NeuralNetConfiguration configuration, INDArray view, boolean initialize) {
    final FeedForwardLayer layer = (FeedForwardLayer) configuration.getLayer();
    if (!initialize) {
        return super.createWeightMatrix(numberOfFeatures, layer.getNOut(), null, null, view, false);
    }
    final Distribution distribution = Distributions.createDistribution(layer.getDist());
    return super.createWeightMatrix(numberOfFeatures, layer.getNOut(), layer.getWeightInit(), distribution, view, true);
}
示例16
/**
 * Samples a binomial distribution (3 trials) whose probability array is zero everywhere except
 * index 1 (1e-5), and checks that every sampled element equals 0.
 */
@Test
public void testBinomial() {
    // NOTE(review): assumes Nd4j.create(10) is zero-initialised, so only index 1 has p != 0.
    Distribution distribution = Nd4j.getDistributions().createBinomial(3, Nd4j.create(10).putScalar(1, 0.00001));
    // NOTE(review): with p = 1e-5 and 3 trials at index 1 over 10,000 iterations, a non-zero draw
    // is not improbable (about 1 - (1 - 1e-5)^30000 ~= 26% per run if draws are independent).
    // Confirm whether this test is seeded elsewhere or is expected to be flaky.
    for (int x = 0; x < 10000; x++) {
        INDArray z = distribution.sample(new int[]{1, 10});
        // Count elements equal to 0.0; all of them must match.
        MatchCondition condition = new MatchCondition(z, Conditions.equals(0.0));
        int match = Nd4j.getExecutioner().exec(condition, Integer.MAX_VALUE).getInt(0);
        assertEquals(z.length(), match);
    }
}
示例17
/**
 * Uses the Anderson-Darling test of Gaussianity to check the values produced by the
 * {@code GaussianDistribution} op against a standard normal N(0, 1).
 * See https://en.wikipedia.org/wiki/Anderson%E2%80%93Darling_test
 */
@Test
public void testAndersonDarling() {
    Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
    INDArray z1 = Nd4j.create(1000);
    GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
    Nd4j.getExecutioner().exec(op1, random1);
    val n = z1.length();
    // Used only for its CDF.
    Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
    // The statistic needs the sorted order statistics Y_1 <= ... <= Y_n
    // (the `true` flag is presumably ascending order — confirm).
    Nd4j.sort(z1, true);
    // Replace each sorted sample by its CDF value F(Y_i), in place.
    for (int i = 0; i < n; i++) {
        double res = nd.cumulativeProbability(z1.getDouble(i));
        assertTrue(res >= 0.0);
        assertTrue(res <= 1.0);
        // Clamp away from exactly 0 and 1 so Math.log below stays finite (log(0) = -Infinity).
        if (res == 0) res = 0.0000001;
        if (res == 1) res = 0.9999999;
        z1.putScalar(i, res);
    }
    // A^2 = -n - (1/n) * sum_{i=1..n} (2i - 1) [ln F(Y_i) + ln(1 - F(Y_{n+1-i}))],
    // written here with a 0-based index, hence (2i + 1) and n - i - 1.
    double A = 0.0;
    for (int i = 0; i < n; i++) {
        A -= (2*i+1) * (Math.log(z1.getDouble(i)) + Math.log(1-z1.getDouble(n - i - 1)));
    }
    A = A / n - n;
    // NOTE(review): this small-sample adjustment is the one usually quoted for the case where mean
    // and variance are estimated; here they are fixed at 0 and 1 — confirm the critical value.
    A *= (1 + 4.0/n - 25.0/(n*n));
    assertTrue("Critical (max) value for 1000 points and confidence α = 0.0001 is 1.8692, received: "+ A, A < 1.8692);
}
示例18
/**
 * Uses the Anderson-Darling test of Gaussianity to check the values produced by the
 * {@code GaussianDistribution} op against a standard normal N(0, 1).
 * See https://en.wikipedia.org/wiki/Anderson%E2%80%93Darling_test
 */
@Test
public void testAndersonDarling() {
    Random random1 = Nd4j.getRandomFactory().getNewRandomInstance(119);
    INDArray z1 = Nd4j.create(1000);
    GaussianDistribution op1 = new GaussianDistribution(z1, 0.0, 1.0);
    Nd4j.getExecutioner().exec(op1, random1);
    val n = z1.length();
    // Used only for its CDF.
    Distribution nd = new NormalDistribution(random1, 0.0, 1.0);
    // The statistic needs the sorted order statistics Y_1 <= ... <= Y_n
    // (the `true` flag is presumably ascending order — confirm).
    Nd4j.sort(z1, true);
    // Replace each sorted sample by its CDF value F(Y_i), in place.
    for (int i = 0; i < n; i++) {
        double res = nd.cumulativeProbability(z1.getDouble(i));
        assertTrue(res >= 0.0);
        assertTrue(res <= 1.0);
        // Clamp away from exactly 0 and 1 so Math.log below stays finite (log(0) = -Infinity).
        if (res == 0) res = 0.0000001;
        if (res == 1) res = 0.9999999;
        z1.putScalar(i, res);
    }
    // A^2 = -n - (1/n) * sum_{i=1..n} (2i - 1) [ln F(Y_i) + ln(1 - F(Y_{n+1-i}))],
    // written here with a 0-based index, hence (2i + 1) and n - i - 1.
    double A = 0.0;
    for (int i = 0; i < n; i++) {
        A -= (2*i+1) * (Math.log(z1.getDouble(i)) + Math.log(1-z1.getDouble(n - i - 1)));
    }
    A = A / n - n;
    // NOTE(review): this small-sample adjustment is the one usually quoted for the case where mean
    // and variance are estimated; here they are fixed at 0 and 1 — confirm the critical value.
    A *= (1 + 4.0/n - 25.0/(n*n));
    assertTrue("Critical (max) value for 1000 points and confidence α = 0.0001 is 1.8692, received: "+ A, A < 1.8692);
}
示例19
/**
 * Returns the legacy ND4J uniform distribution with bounds 0.4 and 4, driven by a
 * {@code DefaultRandom} seeded with {@code seed}.
 */
@Override
protected Distribution getOldFunction(int seed) {
    return new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(new DefaultRandom(seed), 0.4D, 4D);
}
示例20
/**
 * Draws one value from the new implementation, then one from the legacy distribution, and
 * requires them to be equal.
 */
@Override
protected void assertSample(MathProbability newFuction, Distribution oldFunction) {
    final Number fromNew = newFuction.sample();
    final Number fromOld = oldFunction.sample();
    Assert.assertThat(fromNew, CoreMatchers.equalTo(fromOld));
}
示例21
/**
 * Returns the legacy ND4J binomial distribution constructed with (10, 0.5), driven by a
 * {@code DefaultRandom} seeded with {@code seed}.
 */
@Override
protected Distribution getOldFunction(int seed) {
    return new org.nd4j.linalg.api.rng.distribution.impl.BinomialDistribution(new DefaultRandom(seed), 10, 0.5D);
}
示例22
/**
 * Draws one value from each implementation and compares them as doubles; the legacy value is
 * rounded up with {@code Math.ceil} before the comparison.
 */
@Override
protected void assertSample(MathProbability newFuction, Distribution oldFunction) {
    final double drawnByNew = newFuction.sample().doubleValue();
    final double drawnByOld = Math.ceil(oldFunction.sample());
    final Number actual = drawnByNew;
    final Number expected = drawnByOld;
    Assert.assertThat(actual, CoreMatchers.equalTo(expected));
}
示例23
/**
 * Returns the legacy ND4J normal distribution constructed with (1, 5), driven by a
 * {@code DefaultRandom} seeded with {@code seed}.
 */
@Override
protected Distribution getOldFunction(int seed) {
    return new org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution(new DefaultRandom(seed), 1D, 5D);
}
示例24
/**
 * Draws one value from the new implementation, then one from the legacy distribution, and
 * requires them to be equal.
 */
@Override
protected void assertSample(MathProbability newFuction, Distribution oldFunction) {
    final Number fromNew = newFuction.sample();
    final Number fromOld = oldFunction.sample();
    Assert.assertThat(fromNew, CoreMatchers.equalTo(fromOld));
}
示例25
/**
 * Creates a distribution-based init scheme.
 *
 * <p>Lombok's {@code @Builder} on this constructor generates a builder whose methods are named
 * after the constructor parameters ({@code order}, {@code distribution}), so renaming these
 * parameters changes the builder API.
 *
 * @param order        ordering character passed to the superclass constructor
 *                     (presumably 'c' or 'f' array order — confirm)
 * @param distribution distribution stored on this scheme; not null-checked here
 */
@Builder
public DistributionInitScheme(char order, Distribution distribution) {
super(order);
this.distribution = distribution;
}
示例26
/**
 * Returns a new {@code BinomialDistribution} built from the trial count {@code n} and the
 * probability array {@code p}.
 */
@Override
public Distribution createBinomial(int n, INDArray p) {
    final Distribution binomial = new BinomialDistribution(n, p);
    return binomial;
}
示例27
/**
 * Returns a new {@code BinomialDistribution} built from the trial count {@code n} and the
 * scalar probability {@code p}.
 */
@Override
public Distribution createBinomial(int n, double p) {
    final Distribution binomial = new BinomialDistribution(n, p);
    return binomial;
}
示例28
/**
 * Returns a new {@code NormalDistribution} built from the mean array {@code mean} and the
 * standard deviation {@code std}.
 */
@Override
public Distribution createNormal(INDArray mean, double std) {
    final Distribution normal = new NormalDistribution(mean, std);
    return normal;
}
示例29
/**
 * Returns a new {@code NormalDistribution} built from the scalar {@code mean} and the
 * standard deviation {@code std}.
 */
@Override
public Distribution createNormal(double mean, double std) {
    final Distribution normal = new NormalDistribution(mean, std);
    return normal;
}
示例30
/**
 * Returns a new {@code LogNormalDistribution} built from {@code mean} and {@code std}.
 */
@Override
public Distribution createLogNormal(double mean, double std) {
    final Distribution logNormal = new LogNormalDistribution(mean, std);
    return logNormal;
}