本文整理汇总了Java中org.nd4j.linalg.indexing.NDArrayIndex类的典型用法代码示例。如果您正苦于以下问题:Java NDArrayIndex类的具体用法?Java NDArrayIndex怎么用?Java NDArrayIndex使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
NDArrayIndex类属于org.nd4j.linalg.indexing包,在下文中一共展示了NDArrayIndex类的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: computeCoordinateMagnitude
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Feeds one coordinate of a player's positional error history through {@link PSDController}.
 *
 * <p>Entry 0 of {@code gameBuffer} provides the current error row; the same row taken from every
 * entry, in buffer order, provides the history. The controller's time step is the difference between
 * the timestamps of the player states stored in entries 0 and 1, so the buffer must contain at least
 * two entries (otherwise {@code get(1)} throws).
 *
 * @param gameBuffer A {@link List} of {@link Tuple2} pairing a formation with a map from the player's
 *     identity to a {@link Tuple2} of the {@link Player} state and a vector representing the difference
 *     between the agent's current position and the desired position.
 * @param player The {@link Player} to compute the coordinate magnitude for.
 * @param coordinateIndex The index of the coordinate (0 = x, 1 = y, 2 = orientation).
 * @return The coordinate magnitude as an INDArray.
 */
private INDArray computeCoordinateMagnitude(
    final List<Tuple2<F, Map<PlayerIdentity, Tuple2<Player, INDArray>>>> gameBuffer,
    final Player player,
    final int coordinateIndex
) {
    final Tuple2<Player, INDArray> current = gameBuffer.get(0).getT2().get(player.getIdentity());
    final Tuple2<Player, INDArray> previous = gameBuffer.get(1).getT2().get(player.getIdentity());

    // Row `coordinateIndex` of every buffered error array, in buffer order
    final List<INDArray> history = gameBuffer.stream()
        .map(entry -> entry.getT2().get(player.getIdentity()).getT2())
        .map(error -> error.get(NDArrayIndex.interval(coordinateIndex, coordinateIndex + 1), NDArrayIndex.all()))
        .collect(Collectors.toList());

    final double deltaTime = (double) (current.getT1().getTimestamp() - previous.getT1().getTimestamp());

    return PSDController.apply(
        current.getT2().get(NDArrayIndex.interval(coordinateIndex, coordinateIndex + 1), NDArrayIndex.all()),
        history,
        this.proportionalFactorZ,
        this.summationFactorZ,
        this.differenceFactorZ,
        deltaTime);
}
开发者ID:delta-leonis,项目名称:subra,代码行数:31,代码来源:PSDFormationDeducer.java
示例2: loadFeaturesFromString
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Used post training to convert a String to a features INDArray that can be passed to the network output method.
 *
 * <p>The text is tokenized with {@code tokenizerFactory}; tokens for which {@code wordVectors.hasWord}
 * is false are discarded. The remaining tokens are written, in order, one per time step into an array
 * of shape {@code [1, vectorSize, outputLength]}, where {@code outputLength} is the number of known
 * tokens capped at {@code maxLength}. If no token is known (or {@code maxLength < 1}), a single
 * all-zero time step is returned so the array is never empty.
 *
 * @param reviewContents Contents of the review to vectorize
 * @param maxLength Maximum length (if the review is longer than this, it is truncated to maxLength).
 *     Use Integer.MAX_VALUE to not truncate
 * @return Features array for the given input String
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength) {
    List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
    List<String> tokensFiltered = new ArrayList<>();
    for (String t : tokens) {
        if (wordVectors.hasWord(t)) {
            tokensFiltered.add(t);
        }
    }
    // Size the time dimension by the known tokens, capped at maxLength. Using Math.max here would
    // allocate maxLength steps (e.g. Integer.MAX_VALUE) regardless of how long the review is.
    // Keep at least one step so a review with no known words still yields a valid, zero-filled array.
    int outputLength = Math.max(1, Math.min(maxLength, tokensFiltered.size()));
    INDArray features = Nd4j.create(1, vectorSize, outputLength);

    // Walk the filtered tokens: words without a vector must not occupy a time step.
    int written = Math.min(outputLength, tokensFiltered.size());
    for (int j = 0; j < written; j++) {
        INDArray vector = wordVectors.getWordVectorMatrix(tokensFiltered.get(j));
        features.put(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
    }
    return features;
}
开发者ID:IsaacChanghau,项目名称:NeuralNetworksLite,代码行数:26,代码来源:SentimentExampleIterator.java
示例3: testMulTensorVector
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks the contraction
 *
 * X_{b c} = \sum_{a} W_{a b c} v_{a}
 *
 * computed as one matrix product between reshapes: W is flattened to (A, B*C), transposed,
 * multiplied by the (A, 1) column v, and the (B*C, 1) result is reshaped back to (B, C).
 * The result is compared against an explicit triple loop.
 */
@Test
public void testMulTensorVector() {
    final int A = 5;
    final int B = 6;
    final int C = 7;
    // random operands
    final INDArray W = Nd4j.rand(new int[] {A, B, C});
    final INDArray v = Nd4j.rand(new int[] {A, 1});

    // fast path: flatten, transpose, multiply, unflatten
    final INDArray flatProduct = W.reshape(new int[] {A, B * C}).transpose().mmul(v);
    final INDArray X = flatProduct.reshape(new int[] {B, C});

    // reference: explicit summation over the first index for every (b, c)
    for (int row = 0; row < B; row++) {
        for (int col = 0; col < C; col++) {
            double expected = 0;
            for (int k = 0; k < A; k++) {
                final double weight =
                        W.get(NDArrayIndex.point(k), NDArrayIndex.point(row), NDArrayIndex.point(col)).getDouble(0);
                expected += weight * v.getDouble(k);
            }
            Assert.assertEquals(X.getScalar(row, col).getDouble(0), expected, EPS);
        }
    }
}
开发者ID:broadinstitute,项目名称:gatk-protected,代码行数:32,代码来源:Nd4jUnitTest.java
示例4: testMulTensorMatrix
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks the contraction
 *
 * X_{a} = \sum_{b, c} W_{a b c} V_{b c}
 *
 * computed as a single matrix-vector product after flattening W to (A, B*C) and V to a (B*C, 1)
 * column, against an explicit triple loop.
 */
@Test
public void testMulTensorMatrix() {
    final int A = 5;
    final int B = 6;
    final int C = 7;
    // random operands
    final INDArray W = Nd4j.rand(new int[] {A, B, C});
    final INDArray V = Nd4j.rand(new int[] {B, C});

    // fast path: one matrix product of the flattened operands
    final INDArray flatW = W.reshape(new int[] {A, B * C});
    final INDArray flatV = V.reshape(new int[] {B * C, 1});
    final INDArray X = flatW.mmul(flatV);

    // reference: explicit double sum for every a
    for (int i = 0; i < A; i++) {
        double expected = 0;
        for (int j = 0; j < B; j++) {
            for (int k = 0; k < C; k++) {
                final double weight =
                        W.get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.point(k)).getDouble(0);
                final double entry = V.get(NDArrayIndex.point(j), NDArrayIndex.point(k)).getDouble(0);
                expected += weight * entry;
            }
        }
        Assert.assertEquals(X.getScalar(i).getDouble(0), expected, EPS);
    }
}
开发者ID:broadinstitute,项目名称:gatk-protected,代码行数:30,代码来源:Nd4jUnitTest.java
示例5: testINDArrayToApacheMatrix
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks INDArray-to-Apache-matrix conversion on row, column and general matrices in both 'c' and 'f'
 * order, and on interval views of each of them.
 */
@Test
public void testINDArrayToApacheMatrix() {
    // creation order is kept fixed so the random draws are reproducible
    final INDArray rowArrCOrder = Nd4j.randn('c', new int[] {1, 4});
    final INDArray rowArrFOrder = Nd4j.randn('f', new int[] {1, 4});
    final INDArray colArrCOrder = Nd4j.randn('c', new int[] {4, 1});
    final INDArray colArrFOrder = Nd4j.randn('f', new int[] {4, 1});
    final INDArray generalCOrder = Nd4j.randn('c', new int[] {4, 5});
    final INDArray generalFOrder = Nd4j.randn('f', new int[] {4, 5});

    final INDArray[] wholeArrays = {
        rowArrCOrder, rowArrFOrder, colArrCOrder, colArrFOrder, generalCOrder, generalFOrder
    };
    for (final INDArray whole : wholeArrays) {
        assertINDArrayToApacheMatrixCorrectness(whole);
    }

    /* the same checks on INDArray views */
    final INDArray[] views = {
        rowArrCOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3)),
        rowArrFOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3)),
        colArrCOrder.get(NDArrayIndex.interval(0, 3), NDArrayIndex.all()),
        colArrFOrder.get(NDArrayIndex.interval(0, 3), NDArrayIndex.all()),
        generalCOrder.get(NDArrayIndex.interval(1, 4), NDArrayIndex.interval(2, 4)),
        generalFOrder.get(NDArrayIndex.interval(1, 4), NDArrayIndex.interval(2, 4))
    };
    for (final INDArray view : views) {
        assertINDArrayToApacheMatrixCorrectness(view);
    }
}
开发者ID:broadinstitute,项目名称:gatk-protected,代码行数:25,代码来源:Nd4jApacheAdapterUtilsUnitTest.java
示例6: testINDArrayToApacheVector
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks INDArray-to-Apache-vector conversion on row and column vectors in both 'c' and 'f' order,
 * and on length-2 interval views of each of them.
 */
@Test
public void testINDArrayToApacheVector() {
    // creation order is kept fixed so the random draws are reproducible
    final INDArray rowArrCOrder = Nd4j.randn('c', new int[] {1, 5});
    final INDArray rowArrFOrder = Nd4j.randn('f', new int[] {1, 5});
    final INDArray colArrCOrder = Nd4j.randn('c', new int[] {5, 1});
    final INDArray colArrFOrder = Nd4j.randn('f', new int[] {5, 1});

    final INDArray[] wholeVectors = {rowArrCOrder, rowArrFOrder, colArrCOrder, colArrFOrder};
    for (final INDArray whole : wholeVectors) {
        assertINDArrayToApacheVectorCorrectness(whole);
    }

    /* the same checks on INDArray views */
    final INDArray[] views = {
        rowArrCOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)),
        rowArrFOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)),
        colArrCOrder.get(NDArrayIndex.interval(2, 4), NDArrayIndex.all()),
        colArrFOrder.get(NDArrayIndex.interval(2, 4), NDArrayIndex.all())
    };
    for (final INDArray view : views) {
        assertINDArrayToApacheVectorCorrectness(view);
    }
}
开发者ID:broadinstitute,项目名称:gatk-protected,代码行数:19,代码来源:Nd4jApacheAdapterUtilsUnitTest.java
示例7: testGetIndices
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks interval indexing on a 3-D array created after setting the global factory order to 'f'.
 *
 * <p>With that order, {@code linspace(1, 24).reshape(4, 3, 2)} is
 * [[[1,13],[5,17],[9,21]], [[2,14],[6,18],[10,22]], [[3,15],[7,19],[11,23]], [[4,16],[8,20],[12,24]]].
 * Note: the factory order is changed globally and is not restored afterwards.
 */
@Test
public void testGetIndices() {
    Nd4j.factory().setOrder('f');
    INDArray test = Nd4j.linspace(1, 24, 24).reshape(new int[] {4, 3, 2});

    // index 1 of dimension 0, indices 1..2 of dimension 1
    NDArrayIndex dim0 = NDArrayIndex.interval(1, 2);
    NDArrayIndex dim1 = NDArrayIndex.interval(1, 3);
    INDArray get = test.get(dim0, dim1);
    assertTrue(Arrays.equals(new int[] {1, 2, 2}, get.shape()));
    assertEquals(Nd4j.create(new float[] {6, 10, 18, 22}, new int[] {1, 2, 2}), get);

    // indices 1..2 of dimension 0, index 1 of dimension 1
    INDArray test2 = test.get(NDArrayIndex.interval(1, 3), NDArrayIndex.interval(1, 2));
    assertEquals(5, test2.offset());
    assertTrue(Arrays.equals(new int[] {2, 1, 2}, test2.shape()));
    INDArray anotherGet = Nd4j.create(new float[] {6, 7, 10, 11, 18, 19, 22, 23}, new int[] {2, 1, 2});
    assertEquals(test2, anotherGet);

    INDArray linear = test2.slice(0).linearView();
    assertEquals(10, linear.getFloat(1), 1e-1);

    INDArray row = Nd4j.create(new float[] {7, 11});
    INDArray result = test2.slice(1);
    assertEquals(row, result);
}
开发者ID:wlin12,项目名称:JNN,代码行数:26,代码来源:NDArrayTests.java
示例8: testGetIndices2d
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks that interval indexing on a 3x2 matrix (global factory order 'f') agrees with
 * getRow/getRows and selects the expected single element.
 *
 * <p>With 'f' order, {@code linspace(1, 6).reshape(3, 2)} is [[1,4],[2,5],[3,6]].
 */
@Test
public void testGetIndices2d() {
    Nd4j.factory().setOrder('f');
    INDArray threeByTwo = Nd4j.linspace(1, 6, 6).reshape(3, 2);
    INDArray row0 = threeByTwo.getRow(0);
    INDArray row1 = threeByTwo.getRow(1);
    INDArray rows1And2 = threeByTwo.getRows(new int[] {1, 2});

    assertEquals(row0, threeByTwo.get(NDArrayIndex.interval(0, 1)));
    assertEquals(row1, threeByTwo.get(NDArrayIndex.interval(1, 2)));

    INDArray element11 = threeByTwo.get(NDArrayIndex.interval(1, 2), NDArrayIndex.interval(1, 2));
    // result discarded; only checks that printing a single-element view does not throw
    element11.toString();
    assertEquals(Nd4j.create(new float[] {5}), element11);

    assertEquals(rows1And2, threeByTwo.get(NDArrayIndex.interval(1, 3)));
}
开发者ID:wlin12,项目名称:JNN,代码行数:22,代码来源:NDArrayTests.java
示例9: getSubMatricesWithShape
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Builds three {@code rows x cols} sub-matrix views, each taken at a different offset from its own
 * freshly created {@code (2*rows+4) x (2*cols+4)} linspace matrix in the requested ordering.
 *
 * <p>Each view has a separate backing array, so in-place modifications a test makes on one view
 * cannot affect the others.
 *
 * @param ordering ordering ('c' or 'f') used to reshape every backing matrix
 * @param rows number of rows of each sub-matrix
 * @param cols number of columns of each sub-matrix
 * @param seed random seed, reset before each backing matrix is created
 * @return three (view, description) pairs
 */
public static List<Pair<INDArray, String>> getSubMatricesWithShape(char ordering, int rows, int cols, int seed) {
    //Create 3 identical matrices. Could do get() on single original array, but in-place modifications on one
    //might mess up tests for another
    Nd4j.getRandom().setSeed(seed);
    int[] shape = new int[] {2 * rows + 4, 2 * cols + 4};
    int len = ArrayUtil.prod(shape);
    INDArray orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray first = orig.get(NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, cols));

    Nd4j.getRandom().setSeed(seed);
    // Honour `ordering` here too; reshape(shape) silently used the default order for this matrix only.
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray second = orig.get(NDArrayIndex.interval(3, rows + 3), NDArrayIndex.interval(3, cols + 3));

    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray third = orig.get(NDArrayIndex.interval(rows, 2 * rows), NDArrayIndex.interval(cols, 2 * cols));

    String baseMsg = "getSubMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
    List<Pair<INDArray, String>> list = new ArrayList<>(3);
    list.add(new Pair<>(first, baseMsg + ".get(0)"));
    list.add(new Pair<>(second, baseMsg + ".get(1)"));
    list.add(new Pair<>(third, baseMsg + ".get(2)"));
    return list;
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:NDArrayCreationUtil.java
示例10: setStateViewArray
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Binds the updater state {@code m} and {@code v} to the two halves of a flat view array.
 *
 * <p>The first half of {@code viewArray} backs {@code m} and the second half backs {@code v}.
 * Each half is then reshaped with {@code Shape.newShapeNoCopy} to {@code gradientShape}.
 *
 * @param viewArray row vector holding the updater state
 * @param gradientShape shape of the gradient arrays the state must match
 * @param gradientOrder gradient ordering; {@code 'f'} is passed to the reshape as the f-order flag
 * @param initialize if true, {@code viewArray} is zeroed first
 * @throws IllegalArgumentException if {@code viewArray} is not a row vector
 * @throws IllegalStateException if the reshape returns null for either half
 */
@Override
public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector()) {
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    }
    if (initialize) {
        viewArray.assign(0);
    }
    final int total = viewArray.length();
    final int split = total / 2;
    final boolean fOrder = gradientOrder == 'f';

    final INDArray firstHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, split));
    final INDArray secondHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(split, total));
    // Match the expected shape of the input gradient arrays
    this.m = Shape.newShapeNoCopy(firstHalf, gradientShape, fOrder);
    this.v = Shape.newShapeNoCopy(secondHalf, gradientShape, fOrder);
    if (m == null || v == null) {
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");
    }
    this.gradientReshapeOrder = gradientOrder;
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:19,代码来源:NadamUpdater.java
示例11: toFlattened
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Copies the elements of every array in {@code matrices} into one row vector whose length is the
 * sum of the arrays' lengths.
 *
 * <p>Arrays are visited in the collection's iteration order. Each array fills the next contiguous
 * range of positions in the result.
 *
 * @param matrices the ndarrays to build a flattened representation of
 * @return a {@code 1 x totalLength} array holding the inputs' elements
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    int totalLength = 0;
    for (INDArray arr : matrices) {
        totalLength += arr.length();
    }
    INDArray flattened = Nd4j.create(1, totalLength);
    int offset = 0;
    for (INDArray arr : matrices) {
        int end = offset + arr.length();
        flattened.put(new INDArrayIndex[] {NDArrayIndex.interval(offset, end)}, arr);
        offset = end;
    }
    return flattened;
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:BaseNDArrayFactory.java
示例12: makeDataSetSameL
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Builds a time-series DataSet with random features of shape {@code [batchSize, 2, timesteps]}.
 * Labels have shape {@code [batchSize, 2 or 1, timesteps]}.
 *
 * <p>For example {@code i}, the last label row (row 1 when {@code twoClass}, otherwise row 0) is
 * filled by a Bernoulli draw with probability {@code minorityDist[i]}. When {@code twoClass}, row 0
 * is then set to {@code Transforms.not} of that row.
 */
public static DataSet makeDataSetSameL(int batchSize, int timesteps, float[] minorityDist, boolean twoClass) {
    INDArray features = Nd4j.rand(1, batchSize * timesteps * 2).reshape(batchSize, 2, timesteps);
    int labelRows = twoClass ? 2 : 1;
    INDArray labels = Nd4j.zeros(new int[] {batchSize, labelRows, timesteps});
    int sampledRow = labelRows - 1;
    for (int i = 0; i < batchSize; i++) {
        INDArray sampled = labels.get(NDArrayIndex.point(i), NDArrayIndex.point(sampledRow), NDArrayIndex.all());
        Nd4j.getExecutioner().exec(new BernoulliDistribution(sampled, minorityDist[i]));
        if (twoClass) {
            // the other row is the complement of the sampled one
            labels.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.all())
                    .assign(Transforms.not(sampled.dup()));
        }
    }
    return new DataSet(features, labels);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:UnderSamplingPreProcessorTest.java
示例13: Construct4dDataSet
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Builds a random 4-D image DataSet of shape {@code [nExamples, nChannels, height, width]}.
 * Channels 1 and 2 are rescaled and shifted so the per-channel statistics differ. This requires
 * {@code nChannels >= 3}.
 *
 * <p>The expected per-channel mean, std, min and max (over dimensions 0, 2, 3) are precomputed,
 * along with the label mean and std over dimension 0.
 */
public Construct4dDataSet(int nExamples, int nChannels, int height, int width) {
    INDArray allImages = Nd4j.rand(new int[] {nExamples, nChannels, height, width});

    INDArray channelOne = allImages.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all());
    channelOne.muli(100).addi(200);
    INDArray channelTwo = allImages.get(NDArrayIndex.all(), NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all());
    channelTwo.muli(0.001).subi(10);

    // NOTE(review): labels have nChannels rows while features have nExamples rows; these only line up
    // when nExamples == nChannels -- confirm this is intended.
    INDArray labels = Nd4j.linspace(1, nChannels, nChannels).reshape(nChannels, 1);
    sampleDataSet = new DataSet(allImages, labels);

    expectedMean = allImages.mean(0, 2, 3);
    expectedStd = allImages.std(0, 2, 3);
    expectedLabelMean = labels.mean(0);
    expectedLabelStd = labels.std(0);
    expectedMin = allImages.min(0, 2, 3);
    expectedMax = allImages.max(0, 2, 3);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:21,代码来源:PreProcessor3D4DTest.java
示例14: testGetRow
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks that getRow and point/all indexing select the same row, and that the selection is a
 * writable view. Also checks that a SpecifiedIndex picks out the listed rows.
 */
@Test
public void testGetRow() {
    INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
    INDArray viaGetRow = arr.getRow(1);
    INDArray viaIndex = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
    assertEquals(Nd4j.create(new double[] {4, 5, 6}), viaGetRow);
    assertEquals(viaGetRow, viaIndex);

    // writing through the indexed row must stick
    INDArray oneToThree = Nd4j.linspace(1, 3, 3);
    viaIndex.assign(oneToThree);
    assertEquals(oneToThree, viaIndex);

    // rows 1 and 2 of a 3x3 matrix
    INDArray threeByThree = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray lastTwoRows = threeByThree.get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    assertEquals(Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}}), lastTwoRows);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:18,代码来源:SlicingTestsC.java
示例15: testResolvePointVector
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Checks that resolving a single point index against a 4-element linspace array has one of two
 * outcomes. Either the index is left unchanged, or it is expanded to two point indices
 * (0, then 1).
 */
@Test
public void testResolvePointVector() {
    INDArray arr = Nd4j.linspace(1, 4, 4);
    INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};
    INDArrayIndex[] resolved = NDArrayIndex.resolve(arr.shape(), getPoint);
    if (resolved.length != getPoint.length) {
        assertEquals(2, resolved.length);
        assertTrue(resolved[0] instanceof PointIndex);
        assertEquals(0, resolved[0].current());
        assertTrue(resolved[1] instanceof PointIndex);
        assertEquals(1, resolved[1].current());
        return;
    }
    assertArrayEquals(getPoint, resolved);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:17,代码来源:NDArrayIndexResolveTests.java
示例16: testIndexPointInterval
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 1x2 block of ones into a 3x3x3 zero array at slice 1, rows 1..2 (inclusive interval),
 * column 1, and checks the result element-wise.
 *
 * <p>The check compares against an explicitly built expected array rather than {@code toString()}
 * output, whose decimal separator depends on the default locale. The string form previously had to
 * accept both "0.00" and "0,00" and still failed under other locales.
 */
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    INDArrayIndex x = NDArrayIndex.point(1);
    INDArrayIndex y = NDArrayIndex.interval(1, 2, true);
    INDArrayIndex z = NDArrayIndex.point(1);
    INDArray value = Nd4j.ones(1, 2);
    zeros.put(new INDArrayIndex[] {x, y, z}, value);

    INDArray expected = Nd4j.zeros(3, 3, 3);
    expected.putScalar(new int[] {1, 1, 1}, 1.0);
    expected.putScalar(new int[] {1, 2, 1}, 1.0);
    assertEquals(expected, zeros);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:23,代码来源:ShapeResolutionTestsC.java
示例17: setStateViewArray
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Binds the updater state {@code m} and {@code u} to the two halves of a flat view array.
 *
 * <p>The first half of {@code viewArray} backs {@code m} and the second half backs {@code u}.
 * Each half is then reshaped with {@code Shape.newShapeNoCopy} to {@code gradientShape}.
 *
 * @param viewArray row vector holding the updater state
 * @param gradientShape shape of the gradient arrays the state must match
 * @param gradientOrder gradient ordering; {@code 'f'} is passed to the reshape as the f-order flag
 * @param initialize if true, {@code viewArray} is zeroed first
 * @throws IllegalArgumentException if {@code viewArray} is not a row vector
 * @throws IllegalStateException if the reshape returns null for either half
 */
@Override
public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector()) {
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    }
    if (initialize) {
        viewArray.assign(0);
    }
    final int total = viewArray.length();
    final int split = total / 2;
    final boolean fOrder = gradientOrder == 'f';

    final INDArray firstHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, split));
    final INDArray secondHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(split, total));
    // Match the expected shape of the input gradient arrays
    this.m = Shape.newShapeNoCopy(firstHalf, gradientShape, fOrder);
    this.u = Shape.newShapeNoCopy(secondHalf, gradientShape, fOrder);
    if (m == null || u == null) {
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");
    }
    this.gradientReshapeOrder = gradientOrder;
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:19,代码来源:AdaMaxUpdater.java
示例18: testIndexIntervalAll
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Puts ones into a 3x3x3 zero array at slices 0..1 and columns 1..2 (inclusive intervals) of every
 * row, and checks the result element-wise.
 *
 * <p>The check compares against an explicitly built expected array rather than {@code toString()}
 * output, whose decimal separator depends on the default locale. The string form previously had to
 * accept both "0.00" and "0,00" and still failed under other locales.
 */
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    INDArrayIndex x = NDArrayIndex.interval(0, 1, true);
    INDArrayIndex y = NDArrayIndex.all();
    INDArrayIndex z = NDArrayIndex.interval(1, 2, true);
    INDArray value = Nd4j.ones(2, 6);
    zeros.put(new INDArrayIndex[] {x, y, z}, value);

    INDArray expected = Nd4j.zeros(3, 3, 3);
    for (int slice = 0; slice <= 1; slice++) {
        for (int row = 0; row < 3; row++) {
            for (int col = 1; col <= 2; col++) {
                expected.putScalar(new int[] {slice, row, col}, 1.0);
            }
        }
    }
    assertEquals(expected, zeros);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java
示例19: testIndexPointIntervalAll
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Puts a 3x2 block of ones into a 3x3x3 zero array at slice 1, every row, and columns 1..2
 * (inclusive interval), and checks the result element-wise.
 *
 * <p>The check compares against an explicitly built expected array rather than {@code toString()}
 * output, whose decimal separator depends on the default locale. The string form previously had to
 * accept both "0.00" and "0,00" and still failed under other locales.
 */
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    INDArrayIndex x = NDArrayIndex.point(1);
    INDArrayIndex y = NDArrayIndex.all();
    INDArrayIndex z = NDArrayIndex.interval(1, 2, true);
    INDArray value = Nd4j.ones(3, 2);
    zeros.put(new INDArrayIndex[] {x, y, z}, value);

    INDArray expected = Nd4j.zeros(3, 3, 3);
    for (int row = 0; row < 3; row++) {
        for (int col = 1; col <= 2; col++) {
            expected.putScalar(new int[] {1, row, col}, 1.0);
        }
    }
    assertEquals(expected, zeros);
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:22,代码来源:ShapeResolutionTestsC.java
示例20: setStateViewArray
import org.nd4j.linalg.indexing.NDArrayIndex; //导入依赖的package包/类
/**
 * Binds the updater state {@code msg} and {@code msdx} to the two halves of a flat view array.
 *
 * <p>The first half of {@code viewArray} backs {@code msg} and the second half backs {@code msdx}.
 * Each half is then reshaped with {@code Shape.newShapeNoCopy} to {@code gradientShape}.
 *
 * @param viewArray row vector holding the updater state
 * @param gradientShape shape of the gradient arrays the state must match
 * @param gradientOrder gradient ordering; {@code 'f'} is passed to the reshape as the f-order flag
 * @param initialize if true, {@code viewArray} is zeroed first
 * @throws IllegalArgumentException if {@code viewArray} is not a row vector
 * @throws IllegalStateException if the reshape returns null for either half
 */
@Override
public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector()) {
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    }
    if (initialize) {
        viewArray.assign(0);
    }
    final int total = viewArray.length();
    final int split = total / 2;
    final boolean fOrder = gradientOrder == 'f';

    final INDArray firstHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, split));
    final INDArray secondHalf = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(split, total));
    // Match the expected shape of the input gradient arrays
    this.msg = Shape.newShapeNoCopy(firstHalf, gradientShape, fOrder);
    this.msdx = Shape.newShapeNoCopy(secondHalf, gradientShape, fOrder);
    if (msg == null || msdx == null) {
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");
    }
}
开发者ID:deeplearning4j,项目名称:nd4j,代码行数:17,代码来源:AdaDeltaUpdater.java
注:本文中的org.nd4j.linalg.indexing.NDArrayIndex类示例整理自GitHub/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论