本文整理汇总了Java中org.nd4j.linalg.api.buffer.DataBuffer类的典型用法代码示例。如果您正苦于以下问题:Java DataBuffer类的具体用法?Java DataBuffer怎么用?Java DataBuffer使用的例子?那么恭喜您,这里精选的类代码示例或许可以为您提供帮助。
DataBuffer类属于org.nd4j.linalg.api.buffer包,在下文中一共展示了DataBuffer类的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: copy
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Copies the contents of {@code x} into {@code y}.
 *
 * <p>Both arrays must carry the same data type, which is checked up front by
 * {@code DataTypeValidation.assertSameDataType}. Double buffers go through the
 * double overload of {@code JavaBlas.rcopy}; every other data type is routed
 * to the float overload.
 *
 * @param x the source array
 * @param y the destination array
 * @return {@code y}
 */
@Override
public INDArray copy(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);

    final boolean doublePrecision = x.data().dataType().equals(DataBuffer.DOUBLE);
    if (!doublePrecision) {
        // Anything that is not a double buffer is treated as float.
        JavaBlas.rcopy(
                x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
        return y;
    }

    JavaBlas.rcopy(
            x.length(),
            x.data().asDouble(), x.offset(), x.secondaryStride(),
            y.data().asDouble(), y.offset(), y.secondaryStride());
    return y;
}
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:BlasWrapper.java
示例2: applyTransformToOrigin
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Applies the transform to element {@code i} of {@code origin} and stores the
 * result back into {@code origin} at the same index.
 *
 * <p>For complex arrays the transformed complex number is stored as is. For
 * every other array the result is read as a double, and a NaN or infinite
 * result is replaced by {@code Nd4j.EPS_THRESHOLD} before it is stored.
 *
 * @param origin the array whose element is transformed in place
 * @param i the index of the element to apply the transform to
 */
@Override
public void applyTransformToOrigin(INDArray origin, int i) {
    // Check for the array type that is cast to below. Testing an INDArray
    // against IComplexNumber (a scalar type) would not select complex arrays.
    if (origin instanceof IComplexNDArray) {
        IComplexNDArray complexOrigin = (IComplexNDArray) origin;
        IComplexNumber transformed = apply(origin, getFromOrigin(origin, i), i);
        complexOrigin.putScalar(i, transformed);
        return;
    }

    Number transformed = apply(origin, getFromOrigin(origin, i), i);
    double val = transformed.doubleValue();
    if (Double.isNaN(val) || Double.isInfinite(val)) {
        // Keep non-finite results out of the array.
        val = Nd4j.EPS_THRESHOLD;
    }
    // The value is written as a double regardless of the buffer's data type,
    // so no per-type dispatch is needed here.
    origin.putScalar(i, val);
}
开发者ID:wlin12,项目名称:JNN,代码行数:25,代码来源:BaseElementWiseOp.java
示例3: cumsum
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Creates a function that overwrites each element of its input with the
 * running total of all elements up to and including that linear index.
 *
 * <p>The input array is modified in place and is also the value returned by
 * the function. Elements are read with {@code getDouble}, so the accumulation
 * is done in double precision whatever the buffer's data type is.
 *
 * @return a function computing an in-place cumulative sum
 */
public static Function<INDArray, INDArray> cumsum() {
    return new Function<INDArray, INDArray>() {
        @Override
        public INDArray apply(INDArray input) {
            double runningTotal = 0.0;
            for (int i = 0; i < input.length(); i++) {
                runningTotal += input.getDouble(i);
                input.putScalar(i, runningTotal);
            }
            return input;
        }
    };
}
开发者ID:wlin12,项目名称:JNN,代码行数:18,代码来源:DimensionFunctions.java
示例4: nrm2
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the norm of a complex array by delegating to the native BLAS
 * {@code nrm2} routine matching the buffer's data type
 * ({@code scnrm2} for float, {@code dznrm2} for double).
 *
 * @param x the complex array
 * @return the value returned by the native routine
 * @throws IllegalStateException if the buffer is neither float nor double
 */
@Override
public double nrm2(IComplexNDArray x) {
    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        return NativeBlas.scnrm2(
                x.length(),
                x.data().asFloat(),
                x.offset(),
                x.secondaryStride());
    }

    if (!x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    return NativeBlas.dznrm2(
            x.length(),
            x.data().asDouble(),
            x.offset(),
            x.secondaryStride());
}
开发者ID:wlin12,项目名称:JNN,代码行数:20,代码来源:BlasWrapper.java
示例5: dotc
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the dot product of two complex arrays through the native BLAS
 * {@code dotc} routines ({@code cdotc} for float, {@code zdotc} for double).
 * In reference BLAS the {@code dotc} variants conjugate the first vector.
 *
 * <p>Both arrays must share a data type, which is checked by
 * {@code DataTypeValidation.assertSameDataType}.
 *
 * @param x the first complex array
 * @param y the second complex array
 * @return the result wrapped as a {@code ComplexFloat} or {@code ComplexDouble}
 * @throws IllegalStateException if the buffer is neither float nor double
 */
@Override
public IComplexNumber dotc(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);

    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        return new ComplexFloat(
                NativeBlas.cdotc(
                        x.length(),
                        x.data().asFloat(), x.blasOffset(), x.secondaryStride(),
                        y.data().asFloat(), y.blasOffset(), y.secondaryStride()));
    }

    if (!x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    return new ComplexDouble(
            NativeBlas.zdotc(
                    x.length(),
                    x.data().asDouble(), x.blasOffset(), x.secondaryStride(),
                    y.data().asDouble(), y.blasOffset(), y.secondaryStride()));
}
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:BlasWrapper.java
示例6: swap
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Exchanges the contents of {@code x} and {@code y}.
 *
 * <p>Both arrays must share a data type, which is checked by
 * {@code DataTypeValidation.assertSameDataType}. Float buffers use the float
 * overload of {@code JavaBlas.rswap}; every other data type is routed to the
 * double overload.
 *
 * @param x the first array
 * @param y the second array
 * @return {@code y}
 */
@Override
public INDArray swap(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);

    final boolean singlePrecision = x.data().dataType().equals(DataBuffer.FLOAT);
    if (!singlePrecision) {
        // Anything that is not a float buffer is treated as double.
        JavaBlas.rswap(
                x.length(),
                x.data().asDouble(), x.offset(), x.secondaryStride(),
                y.data().asDouble(), y.offset(), y.secondaryStride());
        return y;
    }

    JavaBlas.rswap(
            x.length(),
            x.data().asFloat(), x.offset(), x.secondaryStride(),
            y.data().asFloat(), y.offset(), y.secondaryStride());
    return y;
}
开发者ID:wlin12,项目名称:JNN,代码行数:28,代码来源:BlasWrapper.java
示例7: iamax
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Finds the index of the element with the largest absolute value by calling
 * the native BLAS {@code isamax} (float) or {@code idamax} (double) routine.
 *
 * <p>The native result is decremented by one before it is returned; reference
 * BLAS reports this index 1-based, so the decrement presumably yields a
 * 0-based index (confirm against the NativeBlas binding).
 *
 * @param x the array to search
 * @return the native routine's result minus one
 * @throws IllegalStateException if the buffer is neither float nor double
 */
@Override
public int iamax(INDArray x) {
    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        int nativeIndex = NativeBlas.isamax(
                x.length(),
                x.data().asFloat(),
                x.offset(),
                x.secondaryStride());
        return nativeIndex - 1;
    }

    if (!x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    int nativeIndex = NativeBlas.idamax(
            x.length(),
            x.data().asDouble(),
            x.offset(),
            x.secondaryStride());
    return nativeIndex - 1;
}
开发者ID:wlin12,项目名称:JNN,代码行数:24,代码来源:BlasWrapper.java
示例8: asum
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the absolute-value sum of a complex array by delegating to the
 * native BLAS {@code scasum} (float) or {@code dzasum} (double) routine.
 *
 * <p>The array offset is halved before it is passed on.
 * NOTE(review): presumably this converts an offset counted in real components
 * into one counted in complex elements; confirm against the binding.
 *
 * @param x the complex array
 * @return the value returned by the native routine
 * @throws IllegalStateException if the buffer is neither float nor double
 */
@Override
public double asum(IComplexNDArray x) {
    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        return NativeBlas.scasum(
                x.length(),
                x.data().asFloat(),
                x.offset() / 2,
                x.secondaryStride());
    }

    if (!x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    return NativeBlas.dzasum(
            x.length(),
            x.data().asDouble(),
            x.offset() / 2,
            x.secondaryStride());
}
开发者ID:wlin12,项目名称:JNN,代码行数:23,代码来源:BlasWrapper.java
示例9: posv
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Calls the native LAPACK {@code posv} driver ({@code sposv} for float,
 * {@code dposv} for double) on {@code A} and {@code B}.
 *
 * <p>Both arrays must share a data type. The status code starts at {@code -1}
 * and keeps that value when the buffer is neither float nor double, in which
 * case it is still handed to {@code checkInfo}. The label {@code "DPOSV"} is
 * used in error reporting for both precisions.
 *
 * @param uplo passed through to the native routine unchanged
 * @param A the coefficient matrix; its row count is used as order and leading dimension
 * @param B the right-hand sides; its column count is used as the number of right-hand sides
 * @throws LapackArgumentException if the native routine reports a positive status
 */
@Override
public void posv(char uplo, INDArray A, INDArray B) {
    final int order = A.rows();
    final int rightHandSides = B.columns();
    DataTypeValidation.assertSameDataType(A, B);

    int status = -1;
    if (A.data().dataType().equals(DataBuffer.FLOAT)) {
        status = NativeBlas.sposv(
                uplo, order, rightHandSides,
                A.data().asFloat(), A.offset(), A.rows(),
                B.data().asFloat(), B.offset(), B.rows());
    } else if (A.data().dataType().equals(DataBuffer.DOUBLE)) {
        status = NativeBlas.dposv(
                uplo, order, rightHandSides,
                A.data().asDouble(), A.offset(), A.rows(),
                B.data().asDouble(), B.offset(), B.rows());
    }

    checkInfo("DPOSV", status);
    if (status > 0) {
        throw new LapackArgumentException("DPOSV",
                "Leading minor of order i of A is not positive definite.");
    }
}
开发者ID:wlin12,项目名称:JNN,代码行数:35,代码来源:BlasWrapper.java
示例10: syev
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Calls the native LAPACK {@code syev} eigenvalue driver ({@code ssyev} for
 * float buffers, {@code dsyev} for every other data type).
 *
 * <p>Both arrays must share a data type, which is checked by
 * {@code DataTypeValidation.assertSameDataType}.
 *
 * @param jobz passed through to the native routine unchanged
 * @param uplo passed through to the native routine unchanged
 * @param a the input matrix; its row count is used as order and leading dimension
 * @param w the array handed to the native routine as the eigenvalue buffer
 * @return the status code reported by the native routine (never positive)
 * @throws LapackConvergenceException if the native routine reports a positive status
 */
@Override
public int syev(char jobz, char uplo, INDArray a, INDArray w) {
    DataTypeValidation.assertSameDataType(a, w);

    final boolean singlePrecision = a.data().dataType().equals(DataBuffer.FLOAT);
    final int status = singlePrecision
            ? NativeBlas.ssyev(
                    jobz, uplo, a.rows(),
                    a.data().asFloat(), a.offset(), a.rows(),
                    w.data().asFloat(), w.offset())
            : NativeBlas.dsyev(
                    jobz, uplo, a.rows(),
                    a.data().asDouble(), a.offset(), a.rows(),
                    w.data().asDouble(), w.offset());

    if (status > 0) {
        throw new LapackConvergenceException("SYEV",
                "Eigenvalues could not be computed " + status
                        + " off-diagonal elements did not converge");
    }
    return status;
}
开发者ID:wlin12,项目名称:JNN,代码行数:39,代码来源:BlasWrapper.java
示例11: alloc
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Allocates a cuBLAS buffer sized for {@code ndarray} and copies the array's
 * host data into it.
 *
 * <p>The host pointer starts {@code offset * size(ndarray)} bytes into the
 * backing buffer. {@code ndarray.length()} elements are copied. When the array
 * covers its whole backing buffer the host data is read contiguously (stride
 * 1), otherwise it is read with the array's major stride. The destination is
 * always written with stride 1.
 *
 * @param ndarray the ndarray to allocate for and upload
 * @return the pointer filled in by {@code JCublas.cublasAlloc}
 */
public static Pointer alloc(JCublasNDArray ndarray) {
    final Pointer devicePointer = new Pointer();

    // Wrap the host buffer in the precision that matches its data type.
    final boolean singlePrecision = ndarray.data().dataType().equals(DataBuffer.FLOAT);
    final Pointer hostData = (singlePrecision
            ? Pointer.to(ndarray.data().asFloat())
            : Pointer.to(ndarray.data().asDouble()))
            .withByteOffset(ndarray.offset() * size(ndarray));

    JCublas.cublasAlloc(ndarray.length(), size(ndarray), devicePointer);

    // Stride through the host data so that only the array's own elements are copied.
    final int hostStride = ndarray.length() == ndarray.data().length()
            ? 1
            : ndarray.majorStride();
    JCublas.cublasSetVector(
            ndarray.length(),
            size(ndarray),
            hostData,
            hostStride,
            devicePointer,
            1);
    return devicePointer;
}
开发者ID:wlin12,项目名称:JNN,代码行数:46,代码来源:SimpleJCublas.java
示例12: copy
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Copies {@code x} into {@code y} using cuBLAS.
 *
 * <p>Both arguments are cast to {@code JCublasComplexNDArray}, uploaded with
 * {@code alloc}, copied with {@code cublasScopy} (float buffers) or
 * {@code cublasDcopy} (all other data types), and the result is read back
 * into {@code y}'s buffer through {@code getData}. Both pointers are freed
 * afterwards.
 *
 * <p>NOTE(review): the real-valued copy routines are called with
 * {@code x.length()} elements for complex arrays; confirm that this covers
 * both the real and imaginary components.
 *
 * @param x the origin
 * @param y the destination
 */
public static void copy(IComplexNDArray x, IComplexNDArray y) {
    DataTypeValidation.assertSameDataType(x, y);
    JCublas.cublasInit();

    final JCublasComplexNDArray source = (JCublasComplexNDArray) x;
    final JCublasComplexNDArray destination = (JCublasComplexNDArray) y;
    final Pointer sourcePointer = alloc(source);
    final Pointer destinationPointer = alloc(destination);

    final boolean singlePrecision = source.data().dataType().equals(DataBuffer.FLOAT);
    if (singlePrecision) {
        JCublas.cublasScopy(x.length(), sourcePointer, 1, destinationPointer, 1);
        getData(destination, destinationPointer, Pointer.to(destination.data().asFloat()));
    } else {
        JCublas.cublasDcopy(x.length(), sourcePointer, 1, destinationPointer, 1);
        getData(destination, destinationPointer, Pointer.to(destination.data().asDouble()));
    }

    free(sourcePointer, destinationPointer);
}
开发者ID:wlin12,项目名称:JNN,代码行数:45,代码来源:SimpleJCublas.java
示例13: cumsumi
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Cumulative sum along a dimension.
 *
 * <p>Three cases are handled:
 * <ul>
 *   <li>Vectors are summed in place, element by element.</li>
 *   <li>For {@code Integer.MAX_VALUE} or the last dimension, a duplicate of the
 *       ravelled array is summed and returned. NOTE(review): this branch
 *       returns the duplicate instead of {@code this}; confirm whether callers
 *       rely on that.</li>
 *   <li>Otherwise each vector along {@code dimension} has {@code cumsumi(0)}
 *       applied to it.</li>
 * </ul>
 *
 * @param dimension the dimension to perform cumulative sum along
 * @return the cumulative sum along the specified dimension
 */
@Override
public INDArray cumsumi(int dimension) {
    if (isVector()) {
        // Elements are read as doubles whatever the buffer type is, so no
        // per-type dispatch is needed.
        double runningTotal = 0.0;
        for (int i = 0; i < length; i++) {
            runningTotal += getDouble(i);
            putScalar(i, runningTotal);
        }
    }
    else if (dimension == Integer.MAX_VALUE || dimension == shape.length - 1) {
        INDArray flattened = ravel().dup();
        double prevVal = flattened.getDouble(0);
        for (int i = 1; i < flattened.length(); i++) {
            double d = prevVal + flattened.getDouble(i);
            flattened.putScalar(i, d);
            prevVal = d;
        }
        return flattened;
    }
    else {
        for (int i = 0; i < vectorsAlongDimension(dimension); i++) {
            INDArray vec = vectorAlongDimension(i, dimension);
            vec.cumsumi(0);
        }
    }
    return this;
}
开发者ID:wlin12,项目名称:JNN,代码行数:45,代码来源:BaseNDArray.java
示例14: dot
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Computes the dot product of {@code x} and {@code y} with
 * {@code JavaBlas.rdot}, using the overload that matches the buffer's data type.
 *
 * <p>Both arrays must share a data type, which is checked by
 * {@code DataTypeValidation.assertSameDataType}.
 *
 * @param x the first array
 * @param y the second array
 * @return the dot product as a double
 * @throws IllegalStateException if the buffer is neither float nor double
 */
@Override
public double dot(INDArray x, INDArray y) {
    DataTypeValidation.assertSameDataType(x, y);

    if (x.data().dataType().equals(DataBuffer.FLOAT)) {
        return JavaBlas.rdot(
                x.length(),
                x.data().asFloat(), x.offset(), x.secondaryStride(),
                y.data().asFloat(), y.offset(), y.secondaryStride());
    }

    if (!x.data().dataType().equals(DataBuffer.DOUBLE)) {
        throw new IllegalStateException("Illegal data type");
    }

    return JavaBlas.rdot(
            x.length(),
            x.data().asDouble(), x.offset(), x.secondaryStride(),
            y.data().asDouble(), y.offset(), y.secondaryStride());
}
开发者ID:wlin12,项目名称:JNN,代码行数:31,代码来源:BlasWrapper.java
示例15: addi
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Adds {@code other} to this array and stores the sum in {@code result}.
 *
 * <p>Scalars on either side are delegated to the scalar {@code addi}
 * overload. When {@code result} aliases this array or {@code other}, the
 * remaining operand is accumulated into it with a BLAS {@code axpy} whose
 * scale factor of one is typed to match the buffer (double or float).
 * Otherwise the sum is written element by element through linear views.
 *
 * @param other the second ndarray to add
 * @param result the result ndarray
 * @return the result of the addition
 */
@Override
public INDArray addi(INDArray other, INDArray result) {
    if (other.isScalar()) {
        return result.addi(other.getDouble(0), result);
    }
    if (isScalar()) {
        return other.addi(getDouble(0), result);
    }

    if (result == this) {
        // result already holds this array's values: accumulate other into it.
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().axpy(1.0, other, result);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, other, result);
        }
        return result;
    }

    if (result == other) {
        // result already holds other's values: accumulate this array into it.
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().axpy(1.0, this, result);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, this, result);
        }
        return result;
    }

    // No aliasing: write the element-wise sum into result.
    INDArray target = result.linearView();
    INDArray addend = other.linearView();
    INDArray self = linearView();
    for (int idx = 0; idx < target.length(); idx++) {
        target.putScalar(idx, addend.getDouble(idx) + self.getDouble(idx));
    }
    return result;
}
开发者ID:wlin12,项目名称:JNN,代码行数:47,代码来源:BaseNDArray.java
示例16: element
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Returns the single value held by a scalar ndarray.
 *
 * <p>The value is read at this array's offset. Float buffers yield the result
 * of {@code getFloat}; every other data type yields the result of
 * {@code getDouble}.
 *
 * @return the individual item in this ndarray
 * @throws IllegalStateException if this ndarray is not a scalar
 */
@Override
public Object element() {
    if (isScalar()) {
        // Separate return statements on purpose: a conditional expression
        // would promote the float result to double.
        if (data.dataType().equals(DataBuffer.FLOAT)) {
            return data.getFloat(offset);
        }
        return data.getDouble(offset);
    }
    throw new IllegalStateException("Unable to retrieve element from non scalar matrix");
}
开发者ID:wlin12,项目名称:JNN,代码行数:15,代码来源:BaseNDArray.java
示例17: BaseComplexNDArray
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Creates a complex ndarray over an existing buffer.
 *
 * <p>The ordering is taken from {@code Nd4j.order()}. All fields are assigned
 * before {@code initShape} is called with the requested shape.
 *
 * @param data the backing buffer
 * @param shape the shape handed to {@code initShape}
 * @param stride the stride of the array
 * @param offset the offset into the buffer
 */
public BaseComplexNDArray(DataBuffer data, int[] shape, int[] stride, int offset) {
    this.ordering = Nd4j.order();
    this.offset = offset;
    this.stride = stride;
    this.data = data;
    // Shape initialisation runs last, once the fields above are in place.
    initShape(shape);
}
开发者ID:wlin12,项目名称:JNN,代码行数:9,代码来源:BaseComplexNDArray.java
示例18: subi
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Subtracts {@code other} from this array and stores the difference in
 * {@code result}.
 *
 * <p>Both arguments are cast to {@code IComplexNDArray} up front. A scalar
 * {@code other} is delegated to the complex-scalar {@code subi} overload.
 * When {@code result} aliases this array, {@code other} is accumulated with
 * {@code Nd4j.NEG_UNIT} as the scale factor. When it aliases {@code other},
 * {@code result} is first scaled by {@code Nd4j.NEG_UNIT} (in the buffer's
 * precision) and this array is then accumulated with {@code Nd4j.UNIT}.
 * Otherwise this array is copied into {@code result} before {@code other} is
 * accumulated with {@code Nd4j.NEG_UNIT}.
 *
 * @param other the second ndarray to subtract
 * @param result the result ndarray
 * @return the result of the subtraction
 */
@Override
public IComplexNDArray subi(INDArray other, INDArray result) {
    IComplexNDArray complexOther = (IComplexNDArray) other;
    IComplexNDArray complexResult = (IComplexNDArray) result;

    if (other.isScalar()) {
        return subi(complexOther.getComplex(0), result);
    }

    if (result == this) {
        Nd4j.getBlasWrapper().axpy(Nd4j.NEG_UNIT, complexOther, complexResult);
        return complexResult;
    }

    if (result == other) {
        // Scale result in place, choosing the scalar type that matches the buffer.
        if (data.dataType().equals(DataBuffer.DOUBLE)) {
            Nd4j.getBlasWrapper().scal(Nd4j.NEG_UNIT.asDouble(), complexResult);
        } else {
            Nd4j.getBlasWrapper().scal(Nd4j.NEG_UNIT.asFloat(), complexResult);
        }
        Nd4j.getBlasWrapper().axpy(Nd4j.UNIT, this, complexResult);
        return complexResult;
    }

    Nd4j.getBlasWrapper().copy(this, result);
    Nd4j.getBlasWrapper().axpy(Nd4j.NEG_UNIT, complexOther, complexResult);
    return complexResult;
}
开发者ID:wlin12,项目名称:JNN,代码行数:36,代码来源:BaseComplexNDArray.java
示例19: testRowsColumns
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Checks the row and column counts reported for a 2x3 matrix, a 6x1 column
 * vector and a one-dimensional array of length 6, all created over the same
 * six-element buffer.
 */
@Test
public void testRowsColumns() {
    DataBuffer buffer = Nd4j.linspace(1, 6, 6).data();

    // 2 x 3 matrix
    INDArray matrix = Nd4j.create(buffer, new int[]{2, 3});
    assertEquals(2, matrix.rows());
    assertEquals(3, matrix.columns());

    // 6 x 1 column vector
    INDArray column = Nd4j.create(buffer, new int[]{6, 1});
    assertEquals(6, column.rows());
    assertEquals(1, column.columns());

    // A one-dimensional shape is expected to report one row and six columns.
    INDArray row = Nd4j.create(buffer, new int[]{6});
    assertEquals(1, row.rows());
    assertEquals(6, row.columns());
}
开发者ID:wlin12,项目名称:JNN,代码行数:15,代码来源:NDArrayTests.java
示例20: testVectorInit
import org.nd4j.linalg.api.buffer.DataBuffer; //导入依赖的package包/类
/**
 * Checks that complex arrays created with shapes {@code {4}} and
 * {@code {1, 4}} report themselves as row vectors, and that shape
 * {@code {4, 1}} reports itself as a column vector.
 */
@Test
public void testVectorInit() {
    DataBuffer buffer = Nd4j.linspace(1, 4, 4).data();

    // One-dimensional shape
    IComplexNDArray implicitRow = Nd4j.createComplex(buffer, new int[]{4});
    assertEquals(true, implicitRow.isRowVector());

    // Explicit 1 x 4 shape
    IComplexNDArray explicitRow = Nd4j.createComplex(buffer, new int[]{1, 4});
    assertEquals(true, explicitRow.isRowVector());

    // Explicit 4 x 1 shape
    IComplexNDArray column = Nd4j.createComplex(buffer, new int[]{4, 1});
    assertEquals(true, column.isColumnVector());
}
开发者ID:wlin12,项目名称:JNN,代码行数:12,代码来源:ComplexNDArrayTests.java
注:本文中的org.nd4j.linalg.api.buffer.DataBuffer类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论