本文整理汇总了Golang中github.com/gonum/floats.Equal函数的典型用法代码示例。如果您正苦于以下问题：Golang Equal函数的具体用法？Golang Equal怎么用？Golang Equal使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了Equal函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Golang代码示例。
示例1: TestFlattenTriangular
// TestFlattenTriangular verifies that flattenTriangular returns the expected
// flat slice for one upper-triangular and one lower-triangular 3x3 input.
func TestFlattenTriangular(t *testing.T) {
	type triCase struct {
		a   [][]float64
		ans []float64
		ul  blas.Uplo
	}
	cases := []triCase{
		{
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			ul:  blas.Upper,
			ans: []float64{1, 2, 3, 4, 5, 6},
		},
		{
			a: [][]float64{
				{1, 0, 0},
				{2, 3, 0},
				{4, 5, 6},
			},
			ul:  blas.Lower,
			ans: []float64{1, 2, 3, 4, 5, 6},
		},
	}
	for idx, tc := range cases {
		got := flattenTriangular(tc.a, tc.ul)
		if floats.Equal(got, tc.ans) {
			continue
		}
		t.Errorf("Case %v. Want %v, got %v.", idx, tc.ans, got)
	}
}
开发者ID:gidden,项目名称:cloudlus,代码行数:31,代码来源:common_test.go
示例2: Func
// Func returns the function value at params. If params is element-wise equal
// to b.lastParams, the cached b.lastFunc is returned unchanged. Otherwise
// b.funcGrad is invoked with params and b.lastGrad, and its result is stored
// in b.lastFunc before being returned.
// NOTE(review): b.lastParams is not updated in this method; presumably
// funcGrad refreshes it — confirm.
func (b *BatchGradient) Func(params []float64) float64 {
	if !floats.Equal(params, b.lastParams) {
		b.lastFunc = b.funcGrad(params, b.lastGrad)
	}
	return b.lastFunc
}
开发者ID:reggo,项目名称:reggo,代码行数:7,代码来源:combineloss_2.go
示例3: initNextLinesearch
// initNextLinesearch begins a new line search starting at loc. It saves loc.X
// in ls.x, asks the NextDirectioner for a direction (written to ls.dir) and an
// initial step, initializes the Linesearcher, and writes the first trial point
// to xNext. It returns the evaluation and iteration types to perform next.
//
// ErrNonNegativeStepDirection is returned when the projection of the gradient
// onto the direction is not negative, and ErrNoProgress when the trial point
// is indistinguishable from the starting point.
func (ls *LinesearchMethod) initNextLinesearch(loc *Location, xNext []float64) (EvaluationType, IterationType, error) {
	// abort records that nothing should be evaluated or iterated and
	// reports err to the caller.
	abort := func(err error) (EvaluationType, IterationType, error) {
		ls.evalType = NoEvaluation
		ls.iterType = NoIteration
		return ls.evalType, ls.iterType, err
	}

	copy(ls.x, loc.X)

	var step float64
	if !ls.first {
		step = ls.NextDirectioner.NextDirection(loc, ls.dir)
	} else {
		step = ls.NextDirectioner.InitDirection(loc, ls.dir)
		ls.first = false
	}

	slope := floats.Dot(loc.Gradient, ls.dir)
	if slope >= 0 {
		return abort(ErrNonNegativeStepDirection)
	}

	ls.evalType = ls.Linesearcher.Init(loc.F, slope, step)
	floats.AddScaledTo(xNext, ls.x, step, ls.dir)

	// If rounding makes the trial point identical to the start of this
	// iteration, no progress is possible.
	if floats.Equal(ls.x, xNext) {
		return abort(ErrNoProgress)
	}

	ls.iterType = MinorIteration
	return ls.evalType, ls.iterType, nil
}
开发者ID:jmptrader,项目名称:optimize,代码行数:32,代码来源:linesearch.go
示例4: initNextLinesearch
// initNextLinesearch starts the next line search from the complete location
// held in loc. The starting point is saved in ls.x, loc.X is overwritten with
// the first trial point, and the evaluation operation to perform at loc.X is
// returned.
//
// It panics if the Linesearcher's Init does not return an evaluation
// operation. Non-descent directions and steps too small to change loc.X are
// reported through ls.error with ErrNonNegativeStepDirection and
// ErrNoProgress respectively.
func (ls *LinesearchMethod) initNextLinesearch(loc *Location) (Operation, error) {
	copy(ls.x, loc.X)

	var stepSize float64
	if !ls.first {
		stepSize = ls.NextDirectioner.NextDirection(loc, ls.dir)
	} else {
		ls.first = false
		stepSize = ls.NextDirectioner.InitDirection(loc, ls.dir)
	}

	slope := floats.Dot(loc.Gradient, ls.dir)
	if slope >= 0 {
		return ls.error(ErrNonNegativeStepDirection)
	}

	nextOp := ls.Linesearcher.Init(loc.F, slope, stepSize)
	if !nextOp.isEvaluation() {
		panic("linesearch: Linesearcher returned invalid operation")
	}

	floats.AddScaledTo(loc.X, ls.x, stepSize, ls.dir)
	if floats.Equal(ls.x, loc.X) {
		// The step is so small that, after rounding, the trial point
		// coincides with the starting point of this iteration.
		return ls.error(ErrNoProgress)
	}

	ls.lastStep = stepSize
	ls.eval = NoOperation // Nothing in loc is valid at the new point yet.
	ls.lastOp = nextOp
	return ls.lastOp, nil
}
开发者ID:jacobxk,项目名称:optimize,代码行数:38,代码来源:linesearch.go
示例5: evaluate
// evaluate computes the quantities selected by evalType for problem p at
// xNext, writing the results into loc and counting each evaluation in stats.
// When loc.X differs from xNext, loc is invalidated and xNext is copied into
// loc.X before anything is evaluated.
//
// evaluate panics if NoEvaluation is requested at a new location, or if
// evalType contains bits other than FuncEvaluation, GradEvaluation and
// HessEvaluation.
func evaluate(p *Problem, evalType EvaluationType, xNext []float64, loc *Location, stats *Stats) {
	if movedToNewPoint := !floats.Equal(loc.X, xNext); movedToNewPoint {
		if evalType == NoEvaluation {
			// An optimizer asking for nothing at a new point has no clear
			// meaning, so there is no sensible action to take.
			panic("optimize: no evaluation requested at new location")
		}
		invalidate(loc)
		copy(loc.X, xNext)
	}

	// pending tracks the requested bits that have not been handled yet.
	pending := evalType

	if evalType&FuncEvaluation != 0 {
		loc.F = p.Func(loc.X)
		stats.FuncEvaluations++
		pending &^= FuncEvaluation
	}

	if evalType&GradEvaluation != 0 {
		p.Grad(loc.X, loc.Gradient)
		stats.GradEvaluations++
		pending &^= GradEvaluation
	}

	if evalType&HessEvaluation != 0 {
		p.Hess(loc.X, loc.Hessian)
		stats.HessEvaluations++
		pending &^= HessEvaluation
	}

	if pending == NoEvaluation {
		return
	}
	panic(fmt.Sprintf("optimize: unknown evaluation type %v", evalType))
}
开发者ID:jmptrader,项目名称:optimize,代码行数:36,代码来源:local.go
示例6: DrsclTest
// DrsclTest exercises impl.Drscl with a moderate, a very large and a very
// small scale factor. For each case it checks that one call changes the
// vector and that a second call with the reciprocal factor restores the
// original values to within 1e-14.
func DrsclTest(t *testing.T, impl Drscler) {
	type drsclCase struct {
		x []float64
		a float64
	}
	cases := []drsclCase{
		{
			x: []float64{1, 2, 3, 4, 5},
			a: 4,
		},
		{
			x: []float64{1, 2, 3, 4, 5},
			a: math.MaxFloat64,
		},
		{
			x: []float64{1, 2, 3, 4, 5},
			a: 1e-307,
		},
	}
	for _, tc := range cases {
		n := len(tc.x)
		scaled := make([]float64, n)
		copy(scaled, tc.x)

		// The scaled values cannot be compared directly because of floating
		// point scaling issues (which is the reason Drscl exists). Instead,
		// scale and then scale back, and check that the result is
		// approximately x. If overflow or underflow occurred, the round trip
		// will not match.
		impl.Drscl(n, tc.a, scaled, 1)
		if floats.Equal(scaled, tc.x) {
			t.Errorf("x unchanged during call to drscl. a = %v, x = %v.", tc.a, tc.x)
		}

		impl.Drscl(n, 1/tc.a, scaled, 1)
		if !floats.EqualApprox(scaled, tc.x, 1e-14) {
			t.Errorf("x not equal after scaling and unscaling. a = %v, x = %v.", tc.a, tc.x)
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:35,代码来源:drscl.go
示例7: Grad
// Grad writes the gradient at params into deriv. If params is element-wise
// equal to b.lastParams the cached b.lastGrad is copied out directly;
// otherwise b.funcGrad is invoked with params and b.lastGrad first, its
// return value is stored in b.lastFunc, and b.lastGrad is then copied into
// deriv.
// NOTE(review): b.lastParams is not updated in this method; presumably
// funcGrad refreshes it — confirm.
func (b *BatchGradient) Grad(params, deriv []float64) {
	if !floats.Equal(params, b.lastParams) {
		b.lastFunc = b.funcGrad(params, b.lastGrad)
	}
	copy(deriv, b.lastGrad)
}
开发者ID:reggo,项目名称:reggo,代码行数:8,代码来源:combineloss_2.go
示例8: TestQuantile
// TestQuantile checks Quantile against known answers for every probability in
// each test case and every cumulant kind, with and without weights. It also
// verifies that Quantile does not modify its x or w arguments.
func TestQuantile(t *testing.T) {
	cumulantKinds := []CumulantKind{Empirical}
	for i, test := range []struct {
		p   []float64
		x   []float64
		w   []float64
		ans [][]float64 // ans[k][j] is the expected value for kind k and p[j].
	}{
		{
			p:   []float64{0, 0.05, 0.1, 0.15, 0.45, 0.5, 0.55, 0.85, 0.9, 0.95, 1},
			x:   []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			w:   nil,
			ans: [][]float64{{1, 1, 1, 2, 5, 5, 6, 9, 9, 10, 10}},
		},
		{
			p:   []float64{0, 0.05, 0.1, 0.15, 0.45, 0.5, 0.55, 0.85, 0.9, 0.95, 1},
			x:   []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			w:   []float64{3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
			ans: [][]float64{{1, 1, 1, 2, 5, 5, 6, 9, 9, 10, 10}},
		},
	} {
		// Snapshot the inputs so mutation by Quantile can be detected.
		copyX := make([]float64, len(test.x))
		copy(copyX, test.x)
		var copyW []float64
		if test.w != nil {
			copyW = make([]float64, len(test.w))
			copy(copyW, test.w)
		}
		for j, p := range test.p {
			for k, kind := range cumulantKinds {
				v := Quantile(p, kind, test.x, test.w)
				if !floats.Equal(copyX, test.x) {
					t.Errorf("x changed for case %d kind %d percentile %v", i, k, p)
				}
				if !floats.Equal(copyW, test.w) {
					// This message previously said "x changed", which hid
					// which argument had actually been modified.
					t.Errorf("w changed for case %d kind %d percentile %v", i, k, p)
				}
				if v != test.ans[k][j] {
					t.Errorf("mismatch case %d kind %d percentile %v. Expected: %v, found: %v", i, k, p, test.ans[k][j], v)
				}
			}
		}
	}
}
开发者ID:cjslep,项目名称:stat,代码行数:44,代码来源:stat_test.go
示例9: TestCDF
// TestCDF checks CDF against known answers for every query point in each test
// case and every cumulant kind, with and without weights. It also verifies
// that CDF does not modify its x or weights arguments.
func TestCDF(t *testing.T) {
	cumulantKinds := []CumulantKind{Empirical}
	for i, test := range []struct {
		q       []float64
		x       []float64
		weights []float64
		ans     [][]float64 // ans[k][j] is the expected value for kind k and q[j].
	}{
		// The zero-value case has no query points, so CDF is never called
		// for it.
		{},
		{
			q:   []float64{0, 0.9, 1, 1.1, 2.9, 3, 3.1, 4.9, 5, 5.1},
			x:   []float64{1, 2, 3, 4, 5},
			ans: [][]float64{{0, 0, 0.2, 0.2, 0.4, 0.6, 0.6, 0.8, 1, 1}},
		},
		{
			q:       []float64{0, 0.9, 1, 1.1, 2.9, 3, 3.1, 4.9, 5, 5.1},
			x:       []float64{1, 2, 3, 4, 5},
			weights: []float64{1, 1, 1, 1, 1},
			ans:     [][]float64{{0, 0, 0.2, 0.2, 0.4, 0.6, 0.6, 0.8, 1, 1}},
		},
	} {
		// Snapshot the inputs so mutation by CDF can be detected.
		copyX := make([]float64, len(test.x))
		copy(copyX, test.x)
		var copyW []float64
		if test.weights != nil {
			copyW = make([]float64, len(test.weights))
			copy(copyW, test.weights)
		}
		for j, q := range test.q {
			for k, kind := range cumulantKinds {
				v := CDF(q, kind, test.x, test.weights)
				if !floats.Equal(copyX, test.x) {
					t.Errorf("x changed for case %d kind %d percentile %v", i, k, q)
				}
				if !floats.Equal(copyW, test.weights) {
					// This message previously said "x changed", which hid
					// which argument had actually been modified.
					t.Errorf("weights changed for case %d kind %d percentile %v", i, k, q)
				}
				if v != test.ans[k][j] {
					t.Errorf("mismatch case %d kind %d percentile %v. Expected: %v, found: %v", i, k, q, test.ans[k][j], v)
				}
			}
		}
	}
}
开发者ID:cjslep,项目名称:stat,代码行数:44,代码来源:stat_test.go
示例10: TestFlatten2D
// TestFlatten2D verifies that Flatten2D concatenates the rows of a 3x2 slice
// of slices into a single slice in row order.
func TestFlatten2D(t *testing.T) {
	var (
		input = [][]float64{{11, 22}, {33, 44}, {55, 66}}
		want  = []float64{11, 22, 33, 44, 55, 66}
	)
	got := Flatten2D(input)
	if floats.Equal(got, want) {
		return
	}
	t.Fatalf("Flatten failed. expected %+v, got %+v", want, got)
}
开发者ID:henrylee2cn,项目名称:gjoa,代码行数:10,代码来源:floatx_test.go
示例11: Iterate
// Iterate advances the line search by one step. What it does depends on
// ls.iterType as left by the previous call:
//   - SubIteration: the missing fields of loc have now been evaluated, so
//     loc.X is copied to xNext and MajorIteration is announced with
//     NoEvaluation.
//   - MajorIteration: a new line search is started via initNextLinesearch.
//   - otherwise: the Linesearcher is asked whether it has finished. If so,
//     loc.X is copied to xNext and either MajorIteration or SubIteration is
//     announced, depending on whether complementEval reports anything left to
//     evaluate. If not, the Linesearcher supplies a new step and xNext is
//     computed from ls.x, the step and ls.dir.
//
// It returns the evaluation type and iteration type to perform next. On
// error (from the Linesearcher, or ErrNoProgress when xNext equals ls.x)
// both are set to their "No" values.
func (ls *LinesearchMethod) Iterate(loc *Location, xNext []float64) (EvaluationType, IterationType, error) {
if ls.iterType == SubIteration {
// We needed to evaluate invalid fields of Location. Now we have them
// and can announce MajorIteration.
copy(xNext, loc.X)
ls.evalType = NoEvaluation
ls.iterType = MajorIteration
return ls.evalType, ls.iterType, nil
}
if ls.iterType == MajorIteration {
// The linesearch previously signaled MajorIteration. Since we're here,
// it means that the previous location is not good enough to converge,
// so start the next linesearch.
return ls.initNextLinesearch(loc, xNext)
}
// Projection of the gradient at loc onto the search direction.
projGrad := floats.Dot(loc.Gradient, ls.dir)
if ls.Linesearcher.Finished(loc.F, projGrad) {
copy(xNext, loc.X)
// Check if the last evaluation evaluated all fields of Location.
ls.evalType = complementEval(loc, ls.evalType)
if ls.evalType == NoEvaluation {
// Location is complete and MajorIteration can be announced directly.
ls.iterType = MajorIteration
} else {
// Location is not complete, evaluate its invalid fields in SubIteration.
ls.iterType = SubIteration
}
return ls.evalType, ls.iterType, nil
}
// Line search not done, just iterate.
stepSize, evalType, err := ls.Linesearcher.Iterate(loc.F, projGrad)
if err != nil {
ls.evalType = NoEvaluation
ls.iterType = NoIteration
return ls.evalType, ls.iterType, err
}
// The trial point is always measured from ls.x, the start of the current
// line search, not from the previous trial point.
floats.AddScaledTo(xNext, ls.x, stepSize, ls.dir)
// Compare the starting point for the current iteration with the next
// evaluation point to make sure that rounding errors do not prevent progress.
if floats.Equal(ls.x, xNext) {
ls.evalType = NoEvaluation
ls.iterType = NoIteration
return ls.evalType, ls.iterType, ErrNoProgress
}
ls.evalType = evalType
ls.iterType = MinorIteration
return ls.evalType, ls.iterType, nil
}
开发者ID:jmptrader,项目名称:optimize,代码行数:53,代码来源:linesearch.go
示例12: TestLegendreSingle
// TestLegendreSingle verifies that the locations and weights produced one at
// a time by Legendre.FixedLocationSingle are exactly equal to those produced
// in bulk by Legendre.FixedLocations, for several sizes and intervals.
func TestLegendreSingle(t *testing.T) {
	type legendreCase struct {
		n        int
		min, max float64
	}
	cases := []legendreCase{
		{n: 100, min: -1, max: 1},
		{n: 50, min: -3, max: -1},
		{n: 1000, min: 2, max: 7},
	}
	for c, tc := range cases {
		quad := Legendre{}
		size := tc.n

		// Bulk computation.
		batchX := make([]float64, size)
		batchW := make([]float64, size)
		quad.FixedLocations(batchX, batchW, tc.min, tc.max)

		// Point-by-point computation.
		singleX := make([]float64, size)
		singleW := make([]float64, size)
		for i := 0; i < size; i++ {
			singleX[i], singleW[i] = quad.FixedLocationSingle(size, i, tc.min, tc.max)
		}

		if !floats.Equal(batchX, singleX) {
			t.Errorf("Case %d: xs mismatch batch and single", c)
		}
		if !floats.Equal(batchW, singleW) {
			t.Errorf("Case %d: weights mismatch batch and single", c)
		}
	}
}
开发者ID:sbinet,项目名称:gonum-integrate,代码行数:40,代码来源:legendre_test.go
示例13: denseEqual
// denseEqual reports whether a has the dimensions recorded in acomp and its
// backing data slice is exactly equal to acomp.data. The data is only
// compared when both dimensions match.
func denseEqual(a *Dense, acomp matComp) bool {
	rows, cols := a.Dims()
	return rows == acomp.r &&
		cols == acomp.c &&
		floats.Equal(a.mat.Data, acomp.data)
}
开发者ID:RomainVabre,项目名称:origin,代码行数:13,代码来源:mul_test.go
示例14: TestCopy2D
// TestCopy2D verifies that CopyFloat2D returns a copy of its input: the
// result has the same number of rows, holds equal values, and shares neither
// the outer slice nor any row's backing storage with the source.
func TestCopy2D(t *testing.T) {
	s1 := [][]float64{{11, 22}, {33, 44}, {55, 66}}
	s2 := CopyFloat2D(s1)
	// Guard the indexing below so a short result fails the test instead of
	// panicking.
	if len(s2) != len(s1) {
		t.Fatalf("Copy failed. want: %+v, have: %+v", s1, s2)
	}
	for k := range s1 {
		// Same element address means s2 aliases the outer slice of s1.
		if &s1[k] == &s2[k] {
			t.Fatalf("Slices have the same address, not a copy.")
		}
		if !floats.Equal(s1[k], s2[k]) {
			t.Fatalf("Copy failed. want: %+v, have: %+v", s1, s2)
		}
		// The outer-slice check alone cannot detect a shallow copy whose rows
		// still point at the source's storage, so compare the first element
		// address of each row as well. Row lengths are known equal here.
		if len(s1[k]) > 0 && &s1[k][0] == &s2[k][0] {
			t.Fatalf("Row %d shares its backing array with the source, not a copy.", k)
		}
	}
}
开发者ID:henrylee2cn,项目名称:gjoa,代码行数:15,代码来源:floatx_test.go
示例15: TestPredictFeaturized
// TestPredictFeaturized runs predictFeaturized on a small hand-computed case
// and checks that the output matches to within 1e-14 and that the input
// vector z is left unmodified.
func TestPredictFeaturized(t *testing.T) {
	cases := []struct {
		z              []float64
		featureWeights [][]float64
		output         []float64
		Name           string
	}{
		{
			Name: "General",
			z:    []float64{1, 2, 3},
			featureWeights: [][]float64{
				{3, 4},
				{1, 2},
				{0.5, 0.4},
			},
			output: []float64{6.5, 9.2},
		},
	}
	for _, tc := range cases {
		input := make([]float64, len(tc.z))
		copy(input, tc.z)

		weights := flatten(tc.featureWeights)
		// NOTE(review): this snapshot of the weights is taken but never
		// compared afterwards; a "weights unchanged" check looks intended.
		weightsSnapshot := &mat64.Dense{}
		weightsSnapshot.Clone(weights)

		got := make([]float64, len(tc.output))
		predictFeaturized(input, weights, got)

		// The input vector must not be modified by the call.
		if !floats.Equal(tc.z, input) {
			t.Errorf("z changed during call")
		}
		if !floats.EqualApprox(got, tc.output, 1e-14) {
			t.Errorf("output doesn't match for test %v. Expected %v, found %v", tc.Name, tc.output, got)
		}
	}
}
开发者ID:reggo,项目名称:reggo,代码行数:38,代码来源:kitchensink_test.go
示例16: TestHistogram
// TestHistogram checks Histogram against expected bin totals for weighted and
// unweighted data. Each case uses four dividers and expects five totals. A
// nil destination is passed, so the returned slice is what gets compared.
func TestHistogram(t *testing.T) {
	type histCase struct {
		x        []float64
		weights  []float64
		dividers []float64
		ans      []float64
	}
	cases := []histCase{
		{
			x:        []float64{1, 3, 5, 6, 7, 8},
			dividers: []float64{2, 4, 6, 7},
			ans:      []float64{1, 1, 1, 1, 2},
		},
		{
			x:        []float64{1, 3, 5, 6, 7, 8},
			dividers: []float64{2, 4, 6, 7},
			weights:  []float64{1, 2, 1, 1, 1, 2},
			ans:      []float64{1, 2, 1, 1, 3},
		},
		{
			x:        []float64{1, 8},
			dividers: []float64{2, 4, 6, 7},
			weights:  []float64{1, 2},
			ans:      []float64{1, 0, 0, 0, 2},
		},
		{
			x:        []float64{1, 8},
			dividers: []float64{2, 4, 6, 7},
			ans:      []float64{1, 0, 0, 0, 1},
		},
	}
	for idx, tc := range cases {
		got := Histogram(nil, tc.dividers, tc.x, tc.weights)
		if floats.Equal(got, tc.ans) {
			continue
		}
		t.Errorf("Hist mismatch case %d. Expected %v, Found %v", idx, tc.ans, got)
	}
}
开发者ID:cjslep,项目名称:stat,代码行数:36,代码来源:stat_test.go
示例17: Dorml2Test
//.........这里部分代码省略.........
{3, 5, 4, 20, 6},
{4, 3, 5, 20, 6},
{4, 5, 3, 20, 6},
{5, 3, 4, 20, 6},
{5, 4, 3, 20, 6},
} {
var ma, na, mc, nc int
if side == blas.Left {
ma = test.adim
na = test.common
mc = test.common
nc = test.cdim
} else {
ma = test.adim
na = test.common
mc = test.cdim
nc = test.common
}
// Generate a random matrix
lda := test.lda
if lda == 0 {
lda = na
}
a := make([]float64, ma*lda)
for i := range a {
a[i] = rnd.Float64()
}
ldc := test.ldc
if ldc == 0 {
ldc = nc
}
// Compute random C matrix
c := make([]float64, mc*ldc)
for i := range c {
c[i] = rnd.Float64()
}
// Compute LQ
k := min(ma, na)
tau := make([]float64, k)
work := make([]float64, 1)
impl.Dgelqf(ma, na, a, lda, tau, work, -1)
work = make([]float64, int(work[0]))
impl.Dgelqf(ma, na, a, lda, tau, work, len(work))
// Build Q from result
q := constructQ("LQ", ma, na, a, lda, tau)
cMat := blas64.General{
Rows: mc,
Cols: nc,
Stride: ldc,
Data: make([]float64, len(c)),
}
copy(cMat.Data, c)
cMatCopy := blas64.General{
Rows: cMat.Rows,
Cols: cMat.Cols,
Stride: cMat.Stride,
Data: make([]float64, len(cMat.Data)),
}
copy(cMatCopy.Data, cMat.Data)
switch {
default:
panic("bad test")
case side == blas.Left && trans == blas.NoTrans:
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
case side == blas.Left && trans == blas.Trans:
blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
case side == blas.Right && trans == blas.NoTrans:
blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
case side == blas.Right && trans == blas.Trans:
blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
}
// Do Dorm2r ard compare
if side == blas.Left {
work = make([]float64, nc)
} else {
work = make([]float64, mc)
}
aCopy := make([]float64, len(a))
copy(aCopy, a)
tauCopy := make([]float64, len(tau))
copy(tauCopy, tau)
impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
if !floats.Equal(a, aCopy) {
t.Errorf("a changed in call")
}
if !floats.Equal(tau, tauCopy) {
t.Errorf("tau changed in call")
}
if !floats.EqualApprox(cMat.Data, c, 1e-14) {
isLeft := side == blas.Left
isTrans := trans == blas.Trans
t.Errorf("Multiplication mismatch. IsLeft = %v. IsTrans = %v", isLeft, isTrans)
}
}
}
}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:dorml2.go
示例18: TestGradient
func TestGradient(t *testing.T) {
for i, test := range []struct {
nDim int
tol float64
method Method
}{
{
nDim: 2,
tol: 2e-4,
method: Forward,
},
{
nDim: 2,
tol: 1e-6,
method: Central,
},
{
nDim: 40,
tol: 2e-4,
method: Forward,
},
{
nDim: 40,
tol: 1e-6,
method: Central,
},
} {
x := make([]float64, test.nDim)
for i := range x {
x[i] = rand.Float64()
}
xcopy := make([]float64, len(x))
copy(xcopy, x)
r := Rosenbrock{len(x)}
trueGradient := make([]float64, len(x))
r.FDf(x, trueGradient)
settings := DefaultSettings()
settings.Method = test.method
// try with gradient nil
gradient := Gradient(nil, r.F, x, settings)
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient)
}
if !floats.Equal(x, xcopy) {
t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i)
}
for i := range gradient {
gradient[i] = rand.Float64()
}
Gradient(gradient, r.F, x, settings)
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
}
if !floats.Equal(x, xcopy) {
t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i)
}
// Try with known value
for i := range gradient {
gradient[i] = rand.Float64()
}
settings.OriginKnown = true
settings.OriginValue = r.F(x)
Gradient(gradient, r.F, x, settings)
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
}
// Concurrently
for i := range gradient {
gradient[i] = rand.Float64()
}
settings.Concurrent = true
settings.OriginKnown = false
settings.Workers = 1000
Gradient(gradient, r.F, x, settings)
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
}
if !floats.Equal(x, xcopy) {
t.Errorf("Case %v: x modified during call to gradient in parallel", i)
}
// Concurrently with origin known
for i := range gradient {
gradient[i] = rand.Float64()
}
settings.OriginKnown = true
Gradient(gradient, r.F, x, settings)
if !floats.EqualApprox(gradient, trueGradient, test.tol) {
t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
}
// With default settings
for i := range gradient {
//.........这里部分代码省略.........
开发者ID:sbinet,项目名称:gonum-diff,代码行数:101,代码来源:gradient_test.go
示例19: Iterate
// Iterate advances the line search by one step and returns the next operation
// to perform on loc. What it does depends on ls.lastOp, the operation returned
// by the previous call:
//   - MajorIteration: a new line search is started via initNextLinesearch.
//   - NoOperation: nothing special is done yet (see the TODO below) and the
//     line search is continued.
//   - anything else: the operation is recorded in ls.eval as done; if a
//     MajorIteration was deferred (ls.nextMajor), it is announced now,
//     otherwise the line search is continued.
//
// Continuing means asking the Linesearcher for the next operation and step.
// When a new step is returned, loc.X is overwritten with the new trial point
// computed from ls.x, the step and ls.dir. Errors are routed through
// ls.error. Iterate panics if the Linesearcher returns an operation other
// than MajorIteration or a function/gradient evaluation.
func (ls *LinesearchMethod) Iterate(loc *Location) (Operation, error) {
switch ls.lastOp {
case NoOperation:
// TODO(vladimir-ch): Either Init has not been called, or the caller is
// trying to resume the optimization run after Iterate previously
// returned with an error. Decide what is the proper thing to do. See also #125.
case MajorIteration:
// The previous updated location did not converge the full
// optimization. Initialize a new Linesearch.
return ls.initNextLinesearch(loc)
default:
// Update the indicator of valid fields of loc.
ls.eval |= ls.lastOp
if ls.nextMajor {
ls.nextMajor = false
// Linesearcher previously finished, and the invalid fields of loc
// have now been validated. Announce MajorIteration.
ls.lastOp = MajorIteration
return ls.lastOp, nil
}
}
// Continue the linesearch.
// Values that ls.eval does not mark as evaluated are passed to the
// Linesearcher as NaN.
f := math.NaN()
if ls.eval&FuncEvaluation != 0 {
f = loc.F
}
projGrad := math.NaN()
if ls.eval&GradEvaluation != 0 {
projGrad = floats.Dot(loc.Gradient, ls.dir)
}
op, step, err := ls.Linesearcher.Iterate(f, projGrad)
if err != nil {
return ls.error(err)
}
switch op {
case MajorIteration:
// Linesearch has been finished.
ls.lastOp = complementEval(loc, ls.eval)
if ls.lastOp == NoOperation {
// loc is complete, MajorIteration can be declared directly.
ls.lastOp = MajorIteration
} else {
// Declare MajorIteration on the next call to Iterate.
ls.nextMajor = true
}
case FuncEvaluation, GradEvaluation, FuncEvaluation | GradEvaluation:
if step != ls.lastStep {
// We are moving to a new location, and not, say, evaluating extra
// information at the current location.
// Compute the next evaluation point and store it in loc.X.
floats.AddScaledTo(loc.X, ls.x, step, ls.dir)
if floats.Equal(ls.x, loc.X) {
// Step size has become so small that the next evaluation point is
// indistinguishable from the starting point for the current
// iteration due to rounding errors.
return ls.error(ErrNoProgress)
}
ls.lastStep = step
ls.eval = NoOperation // Indicate all invalid fields of loc.
}
ls.lastOp = op
default:
panic("linesearch: Linesearcher returned invalid operation")
}
return ls.lastOp, nil
}
开发者ID:jgcarvalho,项目名称:zdd,代码行数:78,代码来源:linesearch.go
示例20: marginalLikelihoodDerivative
// marginalLikelihoodDerivative writes into grad the derivative of the scaled
// marginal-likelihood objective with respect to each parameter in x, plus the
// gradient of a barrier penalty for parameters outside their bounds.
//
// x holds the kernel hyperparameters in its first g.kernel.NumHyper()
// entries; when trainNoise is true its last entry is the log of the noise,
// otherwise g.noise is used. mem supplies scratch storage and caches the
// factorization from the last x that was evaluated.
// NOTE(review): the loop over x indexes barrierGrad (sized from grad) and
// bounds, so this assumes len(x) <= len(grad) and len(x) <= len(bounds) —
// confirm against callers.
func (g *GP) marginalLikelihoodDerivative(x, grad []float64, trainNoise bool, mem *margLikeMemory) {
// d/dTheta_j log[(p|X,theta)] =
// 1/2 * y^T * K^-1 dK/dTheta_j * K^-1 * y - 1/2 * tr(K^-1 * dK/dTheta_j)
// 1/2 * α^T * dK/dTheta_j * α - 1/2 * tr(K^-1 dK/dTheta_j)
// Multiply by the same -2
// -α^T * dK/dTheta_j * α + tr(K^-1 dK/dTheta_j)
// This first computation is an inner product.
n := len(g.outputs)
nHyper := g.kernel.NumHyper()
k := mem.k
chol := mem.chol
alpha := mem.alpha
dKdTheta := mem.dKdTheta
kInvDK := mem.kInvDK
y := mat64.NewVector(n, g.outputs)
var noise float64
if trainNoise {
// The noise parameter is stored as its logarithm in the last entry of x.
noise = math.Exp(x[len(x)-1])
} else {
noise = g.noise
}
// If x is the same, then reuse what has been computed in the function.
if !floats.Equal(mem.lastX, x) {
copy(mem.lastX, x)
g.kernel.SetHyper(x[:nHyper])
g.setKernelMat(k, noise)
//chol.Cholesky(k, false)
chol.Factorize(k)
// alpha is obtained by solving against the factorized kernel matrix with y.
alpha.SolveCholeskyVec(chol, y)
}
g.setKernelMatDeriv(dKdTheta, trainNoise, noise)
for i := range dKdTheta {
// grad[i] = -α^T * dK/dTheta_i * α + tr(K^-1 * dK/dTheta_i)
kInvDK.SolveCholesky(chol, dKdTheta[i])
inner := mat64.Inner(alpha, dKdTheta[i], alpha)
grad[i] = -inner + mat64.Trace(kInvDK)
}
// Scale the gradient by the number of outputs.
floats.Scale(1/float64(n), grad)
bounds := g.kernel.Bounds()
if trainNoise {
bounds = append(bounds, Bound{minLogNoise, maxLogNoise})
}
barrierGrad := make([]float64, len(grad))
for i, v := range x {
// Barrier penalty gradient: barrierPow * diff^(barrierPow-1), signed so
// that it points away from the violated bound. The original comment calls
// this quadratic, which presumably means barrierPow is 2 — confirm.
if v < bounds[i].Min {
diff := bounds[i].Min - v
barrierGrad[i] = -(barrierPow) * math.Pow(diff, barrierPow-1)
}
if v > bounds[i].Max {
diff := v - bounds[i].Max
barrierGrad[i] = (barrierPow) * math.Pow(diff, barrierPow-1)
}
}
// NOTE(review): the two Println calls below look like leftover debugging
// output; they write to stdout on every gradient evaluation.
fmt.Println("noise, minNoise", x[len(x)-1], bounds[len(x)-1].Min)
fmt.Println("barrier Grad", barrierGrad)
floats.Add(grad, barrierGrad)
//copy(grad, barrierGrad)
}
开发者ID:btracey,项目名称:gaussproc,代码行数:62,代码来源:gp.go
注：本文中的github.com/gonum/floats.Equal函数示例整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论