
Java NDArrayIndex Class Code Examples


This article collects typical usage examples of org.nd4j.linalg.indexing.NDArrayIndex in Java. If you want to know how NDArrayIndex is used in practice, or are looking for concrete examples of it, the selected code examples below may help.



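NDArrayIndex is a class in the org.nd4j.linalg.indexing package. Its static factory methods build the per-dimension indices passed to INDArray.get and INDArray.put: all() selects a whole dimension, point(i) selects a single position, and interval(begin, end) selects the half-open range [begin, end) (the three-argument form with inclusive = true selects the closed range). Twenty code examples that use it are shown below, ordered by popularity by default.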

Example 1: computeCoordinateMagnitude

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
/**
 * @param gameBuffer      A {@link List} of {@link Tuple2} containing a game state with the
 *                        corresponding formation, and a mapping from {@link Player} to a vector
 *                        representing the difference between the agent's current position and the
 *                        desired position.
 * @param player          The {@link Player} to compute the coordinate magnitude for.
 * @param coordinateIndex The index of the coordinate (0 = x, 1 = y, 2 = orientation).
 * @return The coordinate magnitude as an INDArray.
 */
private INDArray computeCoordinateMagnitude(
    final List<Tuple2<F, Map<PlayerIdentity, Tuple2<Player, INDArray>>>> gameBuffer,
    final Player player,
    final int coordinateIndex
) {
  return PSDController.apply(
      gameBuffer.get(0).getT2().get(player.getIdentity()).getT2()
          .get(NDArrayIndex.interval(coordinateIndex, coordinateIndex + 1), NDArrayIndex.all()),
      gameBuffer.stream()
          .map(Tuple2::getT2)
          .map(map -> map.get(player.getIdentity()))
          .map(error -> error.getT2()
              .get(NDArrayIndex.interval(coordinateIndex, coordinateIndex + 1),
                  NDArrayIndex.all()))
          .collect(Collectors.toList()),
      this.proportionalFactorZ,
      this.summationFactorZ,
      this.differenceFactorZ,
      (double) (gameBuffer.get(0).getT2().get(player.getIdentity()).getT1().getTimestamp()
          - gameBuffer.get(1).getT2().get(player.getIdentity()).getT1().getTimestamp()));
}
 
Developer ID: delta-leonis, project: subra, lines of code: 31, source: PSDFormationDeducer.java
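Example 1 selects one coordinate with interval(coordinateIndex, coordinateIndex + 1) rather than point(coordinateIndex). The end of a two-argument interval is exclusive, so this picks exactly one row but keeps the row dimension, whereas a point index collapses that dimension (how the collapsed result's rank is reported has varied between ND4J versions). The plain-Java sketch below mimics the difference on a double[][]; the class and method names are hypothetical, it does not call ND4J, and unlike ND4J's get, which returns views for these index types, it copies the data.

```java
import java.util.Arrays;

// Plain-Java illustration (hypothetical helpers, not ND4J):
// interval = half-open range that keeps the dimension, point = one index that drops it.
final class IntervalVsPoint {

    // Rows [from, to) of m as a new 2-D array: the row dimension is kept.
    static double[][] intervalRows(double[][] m, int from, int to) {
        if (from < 0 || to > m.length || from > to) {
            throw new IndexOutOfBoundsException("bad interval [" + from + ", " + to + ")");
        }
        double[][] out = new double[to - from][];
        for (int r = from; r < to; r++) {
            out[r - from] = m[r].clone();
        }
        return out;
    }

    // Row i of m as a 1-D array: the row dimension is dropped.
    static double[] pointRow(double[][] m, int i) {
        return m[i].clone();
    }

    public static void main(String[] args) {
        // 3 x 2 "error" matrix: row 0 = x, row 1 = y, row 2 = orientation
        double[][] error = {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}};
        double[][] yAsMatrix = intervalRows(error, 1, 2); // shape 1 x 2
        double[] yAsVector = pointRow(error, 1);          // shape 2
        System.out.println(yAsMatrix.length + " x " + yAsMatrix[0].length); // 1 x 2
        System.out.println(Arrays.toString(yAsVector));                     // [3.0, 4.0]
    }
}
```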


Example 2: loadFeaturesFromString

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
/**
 * Used post training to convert a String to a features INDArray that can be passed to the network output method
 *
 * @param reviewContents Contents of the review to vectorize
 * @param maxLength Maximum length (if review is longer than this: truncate to maxLength). Use Integer.MAX_VALUE to not truncate
 * @return Features array for the given input String
 */
public INDArray loadFeaturesFromString(String reviewContents, int maxLength){
	List<String> tokens = tokenizerFactory.create(reviewContents).getTokens();
	List<String> tokensFiltered = new ArrayList<>();
	for(String t : tokens ){
		if(wordVectors.hasWord(t)) tokensFiltered.add(t);
	}
	// Truncate to maxLength, and only use tokens that actually have a word vector
	int outputLength = Math.min(maxLength, tokensFiltered.size());

	INDArray features = Nd4j.create(1, vectorSize, outputLength);

	for( int j=0; j<outputLength; j++ ){
		String token = tokensFiltered.get(j);
		INDArray vector = wordVectors.getWordVectorMatrix(token);
		features.put(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
	}

	return features;
}
 
Developer ID: IsaacChanghau, project: NeuralNetworksLite, lines of code: 26, source: SentimentExampleIterator.java
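The rule stated in Example 2's Javadoc (keep only tokens that have a word vector, and never emit more than maxLength time steps) can be sketched in plain Java as follows. Assumptions for illustration only: a Map<String, double[]> stands in for the wordVectors model, whitespace splitting stands in for tokenizerFactory, and the nested array features[d][j] plays the role of the [1, vectorSize, length] INDArray with its leading dimension dropped. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Plain-Java sketch of the vectorization rule described in Example 2 (not DL4J code).
// Assumes every vector in the map has at least vectorSize entries.
final class ReviewVectorizer {

    // Returns features[vectorSize][outputLength]; column j holds the vector of kept token j.
    static double[][] loadFeatures(String review, Map<String, double[]> vectors,
                                   int vectorSize, int maxLength) {
        List<String> kept = new ArrayList<>();
        for (String t : review.trim().split("\\s+")) {
            if (vectors.containsKey(t)) kept.add(t); // skip out-of-vocabulary words
        }
        // Math.min, not Math.max: the sequence must never be longer than maxLength
        int outputLength = Math.min(maxLength, kept.size());
        double[][] features = new double[vectorSize][outputLength];
        for (int j = 0; j < outputLength; j++) {
            double[] v = vectors.get(kept.get(j));
            for (int d = 0; d < vectorSize; d++) {
                features[d][j] = v[d]; // analogous to put([point(0), all(), point(j)], vector)
            }
        }
        return features;
    }

    public static void main(String[] args) {
        Map<String, double[]> vectors = Map.of(
                "good", new double[] {1.0, 0.0},
                "movie", new double[] {0.0, 1.0});
        double[][] f = loadFeatures("a good movie indeed", vectors, 2, Integer.MAX_VALUE);
        System.out.println(f[0].length); // 2 (only "good" and "movie" have vectors)
    }
}
```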


Example 3: testMulTensorVector

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
/**
 * A test for implementing a multiplication like:
 *
 *      X_{b c} = \sum_{a} W_{a b c} v_{a}
 *
 * using matrix products and successive reshapes.
 */
@Test
public void testMulTensorVector() {
    /* generate random data */
    final int A = 5;
    final int B = 6;
    final int C = 7;
    final INDArray W = Nd4j.rand(new int[] {A, B, C});
    final INDArray v = Nd4j.rand(new int[] {A, 1});

    /* result using reshapes and matrix products */
    final INDArray X = W.reshape(new int[] {A, B*C}).transpose().mmul(v).reshape(new int[] {B, C});

    /* check against brute force result */
    for (int b = 0; b < B; b++) {
        for (int c = 0; c < C; c++) {
            double prod = 0;
            for (int a = 0; a < A; a++) {
                prod += W.get(NDArrayIndex.point(a), NDArrayIndex.point(b), NDArrayIndex.point(c)).getDouble(0) *
                        v.getDouble(a);
            }
            Assert.assertEquals(X.getScalar(b, c).getDouble(0), prod, EPS);
        }
    }
}
 
Developer ID: broadinstitute, project: gatk-protected, lines of code: 32, source: Nd4jUnitTest.java
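Example 3 relies on how a C-order (row-major) reshape maps indices, assuming W is C-ordered as ND4J arrays are by default: after W is reshaped from A x B x C to A x (B*C), W[a][b][c] sits in column b*C + c, so transposing, multiplying by v and reshaping to B x C yields X[b][c] = sum over a of W[a][b][c] * v[a]. The plain-Java sketch below checks that identity on a flat row-major buffer without ND4J; the class and method names are hypothetical.

```java
import java.util.Random;

// Plain-Java check of X[b][c] = sum_a W[a][b][c] * v[a], computed two ways.
// W is stored as a flat row-major (C-order) buffer of length A*B*C.
final class TensorVectorContraction {

    // Row-major offset of W[a][b][c] in a buffer of shape A x B x C.
    static int offset(int a, int b, int c, int B, int C) {
        return (a * B + b) * C + c;
    }

    // Treat the buffer as an A x (B*C) matrix M, compute M^T v,
    // then read entry k = b*C + c back as X[b][c] (the final reshape).
    static double[][] viaReshape(double[] w, double[] v, int A, int B, int C) {
        double[][] x = new double[B][C];
        for (int k = 0; k < B * C; k++) {
            double sum = 0;
            for (int a = 0; a < A; a++) {
                sum += w[a * (B * C) + k] * v[a]; // M[a][k]
            }
            x[k / C][k % C] = sum;
        }
        return x;
    }

    // Brute force over explicit (a, b, c) indices, like the test's nested loops.
    static double[][] bruteForce(double[] w, double[] v, int A, int B, int C) {
        double[][] x = new double[B][C];
        for (int b = 0; b < B; b++)
            for (int c = 0; c < C; c++)
                for (int a = 0; a < A; a++)
                    x[b][c] += w[offset(a, b, c, B, C)] * v[a];
        return x;
    }

    public static void main(String[] args) {
        int A = 5, B = 6, C = 7;
        Random rnd = new Random(42);
        double[] w = new double[A * B * C];
        double[] v = new double[A];
        for (int i = 0; i < w.length; i++) w[i] = rnd.nextDouble();
        for (int i = 0; i < v.length; i++) v[i] = rnd.nextDouble();
        double[][] x1 = viaReshape(w, v, A, B, C);
        double[][] x2 = bruteForce(w, v, A, B, C);
        double maxDiff = 0;
        for (int b = 0; b < B; b++)
            for (int c = 0; c < C; c++)
                maxDiff = Math.max(maxDiff, Math.abs(x1[b][c] - x2[b][c]));
        System.out.println("max difference: " + maxDiff);
    }
}
```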


Example 4: testMulTensorMatrix

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
/**
 * A test for implementing a multiplication like:
 *
 *      X_{a} = \sum_{b, c} W_{a b c} V_{b c}
 */
@Test
public void testMulTensorMatrix() {
    /* generate random data */
    final int A = 5;
    final int B = 6;
    final int C = 7;
    final INDArray W = Nd4j.rand(new int[] {A, B, C});
    final INDArray V = Nd4j.rand(new int[] {B, C});

    /* result using reshapes and matrix products */
    final INDArray X = W.reshape(new int[] {A, B*C}).mmul(V.reshape(new int[] {B*C, 1}));

    /* check against brute force result */
    for (int a = 0; a < A; a++) {
        double prod = 0;
        for (int b = 0; b < B; b++) {
            for (int c = 0; c < C; c++) {
                prod += W.get(NDArrayIndex.point(a), NDArrayIndex.point(b), NDArrayIndex.point(c)).getDouble(0) *
                        V.get(NDArrayIndex.point(b), NDArrayIndex.point(c)).getDouble(0);
            }
        }
        Assert.assertEquals(X.getScalar(a).getDouble(0), prod, EPS);
    }
}
 
Developer ID: broadinstitute, project: gatk-protected, lines of code: 30, source: Nd4jUnitTest.java


Example 5: testINDArrayToApacheMatrix

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
public void testINDArrayToApacheMatrix() {
    final INDArray rowArrCOrder = Nd4j.randn('c', new int[] {1, 4});
    final INDArray rowArrFOrder = Nd4j.randn('f', new int[] {1, 4});
    final INDArray colArrCOrder = Nd4j.randn('c', new int[] {4, 1});
    final INDArray colArrFOrder = Nd4j.randn('f', new int[] {4, 1});
    final INDArray generalCOrder = Nd4j.randn('c', new int[] {4, 5});
    final INDArray generalFOrder = Nd4j.randn('f', new int[] {4, 5});

    assertINDArrayToApacheMatrixCorrectness(rowArrCOrder);
    assertINDArrayToApacheMatrixCorrectness(rowArrFOrder);
    assertINDArrayToApacheMatrixCorrectness(colArrCOrder);
    assertINDArrayToApacheMatrixCorrectness(colArrFOrder);
    assertINDArrayToApacheMatrixCorrectness(generalCOrder);
    assertINDArrayToApacheMatrixCorrectness(generalFOrder);

    /* test on INDArray views */
    assertINDArrayToApacheMatrixCorrectness(rowArrCOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3)));
    assertINDArrayToApacheMatrixCorrectness(rowArrFOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3)));
    assertINDArrayToApacheMatrixCorrectness(colArrCOrder.get(NDArrayIndex.interval(0, 3), NDArrayIndex.all()));
    assertINDArrayToApacheMatrixCorrectness(colArrFOrder.get(NDArrayIndex.interval(0, 3), NDArrayIndex.all()));
    assertINDArrayToApacheMatrixCorrectness(generalCOrder.get(NDArrayIndex.interval(1, 4), NDArrayIndex.interval(2, 4)));
    assertINDArrayToApacheMatrixCorrectness(generalFOrder.get(NDArrayIndex.interval(1, 4), NDArrayIndex.interval(2, 4)));
}
 
Developer ID: broadinstitute, project: gatk-protected, lines of code: 25, source: Nd4jApacheAdapterUtilsUnitTest.java


Example 6: testINDArrayToApacheVector

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
public void testINDArrayToApacheVector() {
    final INDArray rowArrCOrder = Nd4j.randn('c', new int[] {1, 5});
    final INDArray rowArrFOrder = Nd4j.randn('f', new int[] {1, 5});
    final INDArray colArrCOrder = Nd4j.randn('c', new int[] {5, 1});
    final INDArray colArrFOrder = Nd4j.randn('f', new int[] {5, 1});

    assertINDArrayToApacheVectorCorrectness(rowArrCOrder);
    assertINDArrayToApacheVectorCorrectness(rowArrFOrder);
    assertINDArrayToApacheVectorCorrectness(colArrCOrder);
    assertINDArrayToApacheVectorCorrectness(colArrFOrder);

    /* test on INDArray views */
    assertINDArrayToApacheVectorCorrectness(rowArrCOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)));
    assertINDArrayToApacheVectorCorrectness(rowArrFOrder.get(NDArrayIndex.all(), NDArrayIndex.interval(2, 4)));
    assertINDArrayToApacheVectorCorrectness(colArrCOrder.get(NDArrayIndex.interval(2, 4), NDArrayIndex.all()));
    assertINDArrayToApacheVectorCorrectness(colArrFOrder.get(NDArrayIndex.interval(2, 4), NDArrayIndex.all()));
}
 
Developer ID: broadinstitute, project: gatk-protected, lines of code: 19, source: Nd4jApacheAdapterUtilsUnitTest.java


Example 7: testGetIndices

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
public void testGetIndices() {
    /*[[[1.0 ,13.0],[5.0 ,17.0],[9.0 ,21.0]],[[2.0 ,14.0],[6.0 ,18.0],[10.0 ,22.0]],[[3.0 ,15.0],[7.0 ,19.0],[11.0 ,23.0]],[[4.0 ,16.0],[8.0 ,20.0],[12.0 ,24.0]]]*/
    Nd4j.factory().setOrder('f');
    INDArray test = Nd4j.linspace(1, 24, 24).reshape(new int[]{4,3,2});
    NDArrayIndex oneTwo = NDArrayIndex.interval(1, 2);
    NDArrayIndex twoToThree = NDArrayIndex.interval(1,3);
    INDArray get = test.get(oneTwo,twoToThree);
    assertTrue(Arrays.equals(new int[]{1,2,2},get.shape()));
    assertEquals(Nd4j.create(new float[]{6, 10, 18, 22}, new int[]{1, 2, 2}),get);

    INDArray anotherGet = Nd4j.create(new float[]{6, 7, 10, 11, 18, 19, 22, 23}, new int[]{2, 1, 2});
    INDArray test2 = test.get(NDArrayIndex.interval(1,3),NDArrayIndex.interval(1,2));
    assertEquals(5,test2.offset());
    // the view starts at element [1][1][0]; with 'f' strides [1, 4, 12] its offset is 1*1 + 1*4 = 5
    assertTrue(Arrays.equals(new int[]{2,1,2},test2.shape()));
    assertEquals(test2,anotherGet);

    INDArray linear = test2.slice(0).linearView();
    assertEquals(10,linear.getFloat(1),1e-1);

    INDArray row = Nd4j.create(new float[]{7,11});
    INDArray result = test2.slice(1);
    assertEquals(row,result);
}
 
Developer ID: wlin12, project: JNN, lines of code: 26, source: NDArrayTests.java
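The offset() assertion in Example 7 follows from column-major ('f') strides. For shape [4, 3, 2] the strides are [1, 4, 12], so the view returned by get(interval(1, 3), interval(1, 2)), whose first element is test[1][1][0], starts 1*1 + 1*4 + 0*12 = 5 elements into the buffer, and linspace(1, 24, 24) stores 1 + i + 4j + 12k at [i][j][k]. The sketch below reproduces only this arithmetic in plain Java, not ND4J's shape-information handling; the class name is hypothetical.

```java
import java.util.Arrays;

// Column-major ('f') stride and offset arithmetic in plain Java.
final class FOrderOffsets {

    // Stride of dimension d = product of the sizes of all earlier dimensions.
    static int[] fStrides(int[] shape) {
        int[] strides = new int[shape.length];
        int s = 1;
        for (int d = 0; d < shape.length; d++) {
            strides[d] = s;
            s *= shape[d];
        }
        return strides;
    }

    // Linear buffer offset of the element at the given multi-dimensional index.
    static int offset(int[] index, int[] strides) {
        int off = 0;
        for (int d = 0; d < index.length; d++) off += index[d] * strides[d];
        return off;
    }

    public static void main(String[] args) {
        int[] strides = fStrides(new int[] {4, 3, 2});
        System.out.println(Arrays.toString(strides));                 // [1, 4, 12]
        // get(interval(1,3), interval(1,2)) starts at element [1][1][0]
        System.out.println(offset(new int[] {1, 1, 0}, strides));     // 5
        // linspace(1, 24, 24) stores offset + 1, so [1][1][1] holds 18
        System.out.println(offset(new int[] {1, 1, 1}, strides) + 1); // 18
    }
}
```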


Example 8: testGetIndices2d

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
public void testGetIndices2d() {
    Nd4j.factory().setOrder('f');

    // 3 x 2 matrix in 'f' order: element [i][j] = 1 + i + 3j
    INDArray threeByTwo = Nd4j.linspace(1, 6, 6).reshape(3,2);
    INDArray firstRow = threeByTwo.getRow(0);
    INDArray secondRow = threeByTwo.getRow(1);
    INDArray secondAndThirdRow = threeByTwo.getRows(new int[]{1,2});
    INDArray firstRowViaIndexing = threeByTwo.get(NDArrayIndex.interval(0,1));
    assertEquals(firstRow,firstRowViaIndexing);
    INDArray secondRowViaIndexing = threeByTwo.get(NDArrayIndex.interval(1,2));
    assertEquals(secondRow,secondRowViaIndexing);
    INDArray individualElement = threeByTwo.get(NDArrayIndex.interval(1,2),NDArrayIndex.interval(1,2));
    individualElement.toString();
    assertEquals(Nd4j.create(new float[]{5}),individualElement);

    INDArray secondAndThirdRowTest = threeByTwo.get(NDArrayIndex.interval(1, 3));
    assertEquals(secondAndThirdRow, secondAndThirdRowTest);
}
 
Developer ID: wlin12, project: JNN, lines of code: 22, source: NDArrayTests.java


Example 9: getSubMatricesWithShape

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
public static List<Pair<INDArray, String>> getSubMatricesWithShape(char ordering, int rows, int cols, int seed) {
    //Create 3 identical matrices. Could do get() on single original array, but in-place modifications on one
    //might mess up tests for another
    Nd4j.getRandom().setSeed(seed);
    int[] shape = new int[] {2 * rows + 4, 2 * cols + 4};
    int len = ArrayUtil.prod(shape);
    INDArray orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray first = orig.get(NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, cols));
    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray second = orig.get(NDArrayIndex.interval(3, rows + 3), NDArrayIndex.interval(3, cols + 3));
    Nd4j.getRandom().setSeed(seed);
    orig = Nd4j.linspace(1, len, len).reshape(ordering, shape);
    INDArray third = orig.get(NDArrayIndex.interval(rows, 2 * rows), NDArrayIndex.interval(cols, 2 * cols));

    String baseMsg = "getSubMatricesWithShape(" + rows + "," + cols + "," + seed + ")";
    List<Pair<INDArray, String>> list = new ArrayList<>(3);
    list.add(new Pair<>(first, baseMsg + ".get(0)"));
    list.add(new Pair<>(second, baseMsg + ".get(1)"));
    list.add(new Pair<>(third, baseMsg + ".get(2)"));
    return list;
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 23, source: NDArrayCreationUtil.java
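The comment in Example 9 says the three source matrices are meant to be identical, which only holds if every reshape uses the same ordering. As I understand ND4J's reshape(char, ...), the ordering decides where each value of linspace(1, len, len) lands: in 'c' order element (r, c) of a rows x cols matrix holds 1 + r*cols + c, and in 'f' order it holds 1 + r + c*rows. A plain-Java sketch of that mapping (hypothetical class name, no ND4J):

```java
// Value that linspace(1, len, len) reshaped to rows x cols stores at (r, c),
// computed in plain Java from the element's linear position.
final class ReshapeOrder {

    static double valueAt(char order, int rows, int cols, int r, int c) {
        int linear;
        if (order == 'c') {
            linear = r * cols + c;   // row-major: each row is contiguous
        } else if (order == 'f') {
            linear = r + c * rows;   // column-major: each column is contiguous
        } else {
            throw new IllegalArgumentException("order must be 'c' or 'f'");
        }
        return 1 + linear;           // linspace starts at 1 with step 1
    }

    public static void main(String[] args) {
        // Same (row, col) position, different values depending on the ordering
        System.out.println(valueAt('c', 4, 6, 1, 2)); // 9.0
        System.out.println(valueAt('f', 4, 6, 1, 2)); // 10.0
    }
}
```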


Example 10: setStateViewArray

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Override
public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector())
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    if (initialize)
        viewArray.assign(0);
    int length = viewArray.length();
    this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
    this.v = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));

    //Reshape to match the expected shape of the input gradient arrays
    this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f');
    this.v = Shape.newShapeNoCopy(this.v, gradientShape, gradientOrder == 'f');
    if (m == null || v == null)
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");

    this.gradientReshapeOrder = gradientOrder;
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 19, source: NadamUpdater.java
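In Example 10 (and in Examples 17 and 20), interval(0, length / 2) and interval(length / 2, length) carve one flat state row vector into two non-overlapping views, so writes to the updater's two state arrays land directly in viewArray. The sketch below shows the same idea with java.nio.DoubleBuffer views over a plain double[]; it is an analogy for ND4J views rather than ND4J code, and the class name is hypothetical.

```java
import java.nio.DoubleBuffer;
import java.util.Arrays;

// Two non-overlapping views over one backing array, analogous to
// viewArray.get(point(0), interval(0, n/2)) and interval(n/2, n).
final class SplitStateView {

    // Returns {firstHalf, secondHalf}; both share storage with 'state'.
    static DoubleBuffer[] splitInHalves(double[] state) {
        int half = state.length / 2;
        DoubleBuffer first = DoubleBuffer.wrap(state, 0, half).slice();
        DoubleBuffer second = DoubleBuffer.wrap(state, half, state.length - half).slice();
        return new DoubleBuffer[] {first, second};
    }

    public static void main(String[] args) {
        double[] state = new double[6];
        DoubleBuffer[] mv = splitInHalves(state);
        mv[0].put(0, 1.5); // write through the first-half view
        mv[1].put(2, 2.5); // write through the second-half view
        System.out.println(Arrays.toString(state)); // [1.5, 0.0, 0.0, 0.0, 0.0, 2.5]
    }
}
```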


Example 11: toFlattened

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
/**
 * Returns a row vector containing all of the elements of every ndarray,
 * with length equal to the sum of the lengths of the ndarrays
 *
 * @param matrices the ndarrays to get a flattened representation of
 * @return the flattened ndarray
 */
@Override
public INDArray toFlattened(Collection<INDArray> matrices) {
    int length = 0;
    for (INDArray m : matrices)
        length += m.length();
    INDArray ret = Nd4j.create(1, length);
    int linearIndex = 0;
    for (INDArray d : matrices) {
        ret.put(new INDArrayIndex[] {NDArrayIndex.interval(linearIndex, linearIndex + d.length())}, d);
        linearIndex += d.length();
    }

    return ret;

}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 23, source: BaseNDArrayFactory.java
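Example 11 copies each input into the consecutive range [linearIndex, linearIndex + d.length()) of one row vector. The same running-offset loop in plain Java, assuming the inputs are already linear double[] arrays (the ND4J version also has to decide how multi-dimensional inputs are linearized, which this sketch sidesteps); the class name is hypothetical:

```java
import java.util.Arrays;
import java.util.List;

// Plain-Java version of the running-offset concatenation in toFlattened.
final class Flatten {

    static double[] toFlattened(List<double[]> arrays) {
        int length = 0;
        for (double[] a : arrays) length += a.length;   // total output length
        double[] ret = new double[length];
        int linearIndex = 0;
        for (double[] a : arrays) {
            // copy into [linearIndex, linearIndex + a.length)
            System.arraycopy(a, 0, ret, linearIndex, a.length);
            linearIndex += a.length;
        }
        return ret;
    }

    public static void main(String[] args) {
        double[] flat = toFlattened(List.of(new double[] {1, 2}, new double[] {3, 4, 5}));
        System.out.println(Arrays.toString(flat)); // [1.0, 2.0, 3.0, 4.0, 5.0]
    }
}
```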


Example 12: makeDataSetSameL

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
public static DataSet makeDataSetSameL(int batchSize, int timesteps, float[] minorityDist, boolean twoClass) {
    INDArray features = Nd4j.rand(1, batchSize * timesteps * 2).reshape(batchSize, 2, timesteps);
    INDArray labels;
    if (twoClass) {
        labels = Nd4j.zeros(new int[] {batchSize, 2, timesteps});
    } else {
        labels = Nd4j.zeros(new int[] {batchSize, 1, timesteps});
    }
    for (int i = 0; i < batchSize; i++) {
        INDArray l;
        if (twoClass) {
            l = labels.get(NDArrayIndex.point(i), NDArrayIndex.point(1), NDArrayIndex.all());
            Nd4j.getExecutioner().exec(new BernoulliDistribution(l, minorityDist[i]));
            INDArray lOther = labels.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.all());
            lOther.assign(Transforms.not(l.dup()));
        } else {
            l = labels.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.all());
            Nd4j.getExecutioner().exec(new BernoulliDistribution(l, minorityDist[i]));
        }
    }
    return new DataSet(features, labels);
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 23, source: UnderSamplingPreProcessorTest.java
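In the two-class branch of Example 12, row 1 of each example's label matrix receives Bernoulli(p) draws and row 0 receives their logical negation, so every time step is one-hot. A plain-Java sketch of that labelling, with java.util.Random standing in for ND4J's BernoulliDistribution op and a hypothetical class name:

```java
import java.util.Arrays;
import java.util.Random;

// One-hot two-class labels over time: row 1 ~ Bernoulli(p), row 0 = NOT row 1.
final class BernoulliLabels {

    static double[][] twoClassLabels(int timesteps, double p, Random rnd) {
        double[][] labels = new double[2][timesteps];
        for (int t = 0; t < timesteps; t++) {
            double minority = rnd.nextDouble() < p ? 1.0 : 0.0; // one Bernoulli(p) draw
            labels[1][t] = minority;
            labels[0][t] = 1.0 - minority;                      // logical NOT for 0/1 values
        }
        return labels;
    }

    public static void main(String[] args) {
        double[][] l = twoClassLabels(10, 0.3, new Random(7));
        System.out.println(Arrays.toString(l[1]));
    }
}
```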


Example 13: Construct4dDataSet

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
public Construct4dDataSet(int nExamples, int nChannels, int height, int width) {
    INDArray allImages = Nd4j.rand(new int[] {nExamples, nChannels, height, width});
    // Scale and shift channels 1 and 2 in place (get() returns a view), giving the channels different ranges
    allImages.get(NDArrayIndex.all(), NDArrayIndex.point(1), NDArrayIndex.all(), NDArrayIndex.all()).muli(100)
                    .addi(200);
    allImages.get(NDArrayIndex.all(), NDArrayIndex.point(2), NDArrayIndex.all(), NDArrayIndex.all()).muli(0.001)
                    .subi(10);

    INDArray labels = Nd4j.linspace(1, nChannels, nChannels).reshape(nChannels, 1);
    sampleDataSet = new DataSet(allImages, labels);

    expectedMean = allImages.mean(0, 2, 3);
    expectedStd = allImages.std(0, 2, 3);

    expectedLabelMean = labels.mean(0);
    expectedLabelStd = labels.std(0);

    expectedMin = allImages.min(0, 2, 3);
    expectedMax = allImages.max(0, 2, 3);
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 21, source: PreProcessor3D4DTest.java
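allImages.mean(0, 2, 3) in Example 13 reduces over the example, height and width dimensions of an [N, C, H, W] array, leaving one value per channel. The plain-Java sketch below performs the same reduction on a double[][][][], assuming a non-empty, rectangular array; the class name is hypothetical.

```java
import java.util.Arrays;

// Per-channel mean of an [N][C][H][W] array: reduce over dimensions 0, 2 and 3.
final class ChannelMean {

    static double[] meanOverNHW(double[][][][] x) {
        int n = x.length, c = x[0].length, h = x[0][0].length, w = x[0][0][0].length;
        double[] mean = new double[c];
        for (int ch = 0; ch < c; ch++) {
            double sum = 0;
            for (int i = 0; i < n; i++)
                for (int r = 0; r < h; r++)
                    for (int col = 0; col < w; col++)
                        sum += x[i][ch][r][col];
            mean[ch] = sum / ((double) n * h * w);
        }
        return mean;
    }

    public static void main(String[] args) {
        double[][][][] x = new double[2][3][2][2];
        for (double[][][] example : x)
            for (int ch = 0; ch < 3; ch++)
                for (double[] row : example[ch])
                    Arrays.fill(row, ch * 100.0); // channel ch is the constant ch * 100
        System.out.println(Arrays.toString(meanOverNHW(x))); // [0.0, 100.0, 200.0]
    }
}
```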


Example 14: testGetRow

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
public void testGetRow() {
    INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
    INDArray get = arr.getRow(1);
    INDArray get2 = arr.get(NDArrayIndex.point(1), NDArrayIndex.all());
    INDArray assertion = Nd4j.create(new double[] {4, 5, 6});
    assertEquals(assertion, get);
    assertEquals(get, get2);
    get2.assign(Nd4j.linspace(1, 3, 3));
    assertEquals(Nd4j.linspace(1, 3, 3), get2);

    INDArray threeByThree = Nd4j.linspace(1, 9, 9).reshape(3, 3);
    INDArray offsetTest = threeByThree.get(new SpecifiedIndex(1, 2), NDArrayIndex.all());
    INDArray threeByThreeAssertion = Nd4j.create(new double[][] {{4, 5, 6}, {7, 8, 9}});

    assertEquals(threeByThreeAssertion, offsetTest);
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 18, source: SlicingTestsC.java


Example 15: testResolvePointVector

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
public void testResolvePointVector() {
    INDArray arr = Nd4j.linspace(1, 4, 4);
    INDArrayIndex[] getPoint = {NDArrayIndex.point(1)};
    INDArrayIndex[] resolved = NDArrayIndex.resolve(arr.shape(), getPoint);
    if (getPoint.length == resolved.length)
        assertArrayEquals(getPoint, resolved);
    else {
        assertEquals(2, resolved.length);
        assertTrue(resolved[0] instanceof PointIndex);
        assertEquals(0, resolved[0].current());
        assertTrue(resolved[1] instanceof PointIndex);
        assertEquals(1, resolved[1].current());
    }

}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 17, source: NDArrayIndexResolveTests.java


Example 16: testIndexPointInterval

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
@Ignore
public void testIndexPointInterval() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    INDArrayIndex x = NDArrayIndex.point(1);
    INDArrayIndex y = NDArrayIndex.interval(1, 2, true);
    INDArrayIndex z = NDArrayIndex.point(1);
    INDArray value = Nd4j.ones(1, 2);
    zeros.put(new INDArrayIndex[] {x, y, z}, value);

    String f1 = "[[[0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,1,00,0,00]\n" + " [0,00,1,00,0,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String f2 = "[[[0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,1.00,0.00]\n" + " [0.00,1.00,0.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    if (!zeros.toString().equals(f2) && !zeros.toString().equals(f1))
        assertEquals(f2, zeros.toString());

}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 23, source: ShapeResolutionTestsC.java


Example 17: setStateViewArray

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Override
public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector())
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    if (initialize)
        viewArray.assign(0);
    int length = viewArray.length();
    this.m = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
    this.u = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));

    //Reshape to match the expected shape of the input gradient arrays
    this.m = Shape.newShapeNoCopy(this.m, gradientShape, gradientOrder == 'f');
    this.u = Shape.newShapeNoCopy(this.u, gradientShape, gradientOrder == 'f');
    if (m == null || u == null)
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");

    this.gradientReshapeOrder = gradientOrder;
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 19, source: AdaMaxUpdater.java


Example 18: testIndexIntervalAll

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
@Ignore
public void testIndexIntervalAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    INDArrayIndex x = NDArrayIndex.interval(0, 1, true);
    INDArrayIndex y = NDArrayIndex.all();
    INDArrayIndex z = NDArrayIndex.interval(1, 2, true);
    INDArray value = Nd4j.ones(2, 6);
    zeros.put(new INDArrayIndex[] {x, y, z}, value);

    String f1 = "[[[0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]]\n" + "  [[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String f2 = "[[[0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]]\n" + "  [[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    if (!zeros.toString().equals(f1) && !zeros.toString().equals(f2))
        assertEquals(f2, zeros.toString());
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 22, source: ShapeResolutionTestsC.java


Example 19: testIndexPointIntervalAll

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Test
@Ignore
public void testIndexPointIntervalAll() {
    INDArray zeros = Nd4j.zeros(3, 3, 3);
    INDArrayIndex x = NDArrayIndex.point(1);
    INDArrayIndex y = NDArrayIndex.all();
    INDArrayIndex z = NDArrayIndex.interval(1, 2, true);
    INDArray value = Nd4j.ones(3, 2);
    zeros.put(new INDArrayIndex[] {x, y, z}, value);

    String f1 = "[[[0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]\n" + "  [[0,00,1,00,1,00]\n"
                    + " [0,00,1,00,1,00]\n" + " [0,00,1,00,1,00]]\n" + "  [[0,00,0,00,0,00]\n"
                    + " [0,00,0,00,0,00]\n" + " [0,00,0,00,0,00]]]";

    String f2 = "[[[0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]\n" + "  [[0.00,1.00,1.00]\n"
                    + " [0.00,1.00,1.00]\n" + " [0.00,1.00,1.00]]\n" + "  [[0.00,0.00,0.00]\n"
                    + " [0.00,0.00,0.00]\n" + " [0.00,0.00,0.00]]]";

    if (!zeros.toString().equals(f1) && !zeros.toString().equals(f2))
        assertEquals(f2, zeros.toString());
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 22, source: ShapeResolutionTestsC.java
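Examples 16, 18 and 19 each put a value array into a region selected by mixed indices, and counting that region shows how many elements the value is meant to fill: a point index removes its dimension, all() keeps the full size, and interval(begin, end, true) is inclusive, covering end - begin + 1 positions (end - begin when exclusive). The sketch below encodes that rule with hypothetical types; it is a model of the bookkeeping, not ND4J's index resolver, and it ignores any minimum-rank convention a particular ND4J version applies.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Bookkeeping model with hypothetical types (not ND4J): which dimensions survive
// an index, and how many positions each surviving dimension keeps.
final class RegionShape {

    enum Kind { POINT, ALL, INTERVAL }

    static final class Idx {
        final Kind kind;
        final int begin;
        final int end;
        final boolean inclusive;

        private Idx(Kind kind, int begin, int end, boolean inclusive) {
            this.kind = kind;
            this.begin = begin;
            this.end = end;
            this.inclusive = inclusive;
        }

        static Idx point(int i) { return new Idx(Kind.POINT, i, i, true); }
        static Idx all() { return new Idx(Kind.ALL, 0, 0, false); }
        static Idx interval(int begin, int end, boolean inclusive) {
            return new Idx(Kind.INTERVAL, begin, end, inclusive);
        }
    }

    // Shape of the region selected from an array of the given shape.
    static int[] shapeOf(int[] arrayShape, Idx... indices) {
        List<Integer> dims = new ArrayList<>();
        for (int d = 0; d < indices.length; d++) {
            Idx idx = indices[d];
            switch (idx.kind) {
                case POINT:
                    break;                                  // dimension is removed
                case ALL:
                    dims.add(arrayShape[d]);                // full extent is kept
                    break;
                case INTERVAL:
                    dims.add(idx.end - idx.begin + (idx.inclusive ? 1 : 0));
                    break;
            }
        }
        int[] out = new int[dims.size()];
        for (int i = 0; i < out.length; i++) out[i] = dims.get(i);
        return out;
    }

    public static void main(String[] args) {
        int[] cube = {3, 3, 3};
        // Example 19: point(1), all(), interval(1, 2, true)
        System.out.println(Arrays.toString(
                shapeOf(cube, Idx.point(1), Idx.all(), Idx.interval(1, 2, true))));          // [3, 2]
        // Example 18: interval(0, 1, true), all(), interval(1, 2, true)
        System.out.println(Arrays.toString(
                shapeOf(cube, Idx.interval(0, 1, true), Idx.all(), Idx.interval(1, 2, true)))); // [2, 3, 2]
    }
}
```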


Example 20: setStateViewArray

import org.nd4j.linalg.indexing.NDArrayIndex; // the package/class this example depends on
@Override
public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
    if (!viewArray.isRowVector())
        throw new IllegalArgumentException("Invalid input: expect row vector input");
    if (initialize)
        viewArray.assign(0);
    int length = viewArray.length();
    this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2));
    this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length));

    //Reshape to match the expected shape of the input gradient arrays
    this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f');
    this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f');
    if (msg == null || msdx == null)
        throw new IllegalStateException("Could not correctly reshape gradient view arrays");
}
 
Developer ID: deeplearning4j, project: nd4j, lines of code: 17, source: AdaDeltaUpdater.java



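Note: The org.nd4j.linalg.indexing.NDArrayIndex examples in this article were collected from source code and documentation platforms such as GitHub and MSDocs. The snippets come from open-source projects contributed by their developers, and copyright of the source code remains with the original authors. Please consult the corresponding project's license before distributing or using the code, and do not republish without permission.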

