• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Golang floats.Equal函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Golang 中 github.com/gonum/floats.Equal 函数的典型用法代码示例。如果您正苦于以下问题:Golang Equal 函数的具体用法?Golang Equal 怎么用?Golang Equal 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了Equal函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Golang代码示例。

示例1: TestFlattenTriangular

// TestFlattenTriangular runs flattenTriangular on one upper- and one
// lower-triangular 3x3 matrix and checks the returned flat slice. In both
// cases the expected output lists the entries of the stored triangle row by
// row.
func TestFlattenTriangular(t *testing.T) {
	type triCase struct {
		a   [][]float64
		ans []float64
		ul  blas.Uplo
	}

	cases := []triCase{
		{
			a: [][]float64{
				{1, 2, 3},
				{0, 4, 5},
				{0, 0, 6},
			},
			ul:  blas.Upper,
			ans: []float64{1, 2, 3, 4, 5, 6},
		},
		{
			a: [][]float64{
				{1, 0, 0},
				{2, 3, 0},
				{4, 5, 6},
			},
			ul:  blas.Lower,
			ans: []float64{1, 2, 3, 4, 5, 6},
		},
	}

	for idx, tc := range cases {
		got := flattenTriangular(tc.a, tc.ul)
		if floats.Equal(got, tc.ans) {
			continue
		}
		t.Errorf("Case %v. Want %v, got %v.", idx, tc.ans, got)
	}
}
开发者ID:gidden,项目名称:cloudlus,代码行数:31,代码来源:common_test.go


示例2: Func

// Func returns the function value at params. When params equals
// b.lastParams the stored b.lastFunc is returned as is; otherwise
// b.funcGrad is called, which also fills b.lastGrad, and its result is
// stored in b.lastFunc before being returned.
func (b *BatchGradient) Func(params []float64) float64 {
	if !floats.Equal(params, b.lastParams) {
		// Different point: recompute the value (and gradient) there.
		b.lastFunc = b.funcGrad(params, b.lastGrad)
	}
	return b.lastFunc
}
开发者ID:reggo,项目名称:reggo,代码行数:7,代码来源:combineloss_2.go


示例3: initNextLinesearch

// initNextLinesearch starts a new line search from loc. It saves loc.X into
// ls.x, obtains a search direction and initial step size from
// ls.NextDirectioner, initializes ls.Linesearcher, and writes the first trial
// point into xNext.
//
// It returns ErrNonNegativeStepDirection when the projection of loc.Gradient
// onto the direction is not negative, and ErrNoProgress when the trial point
// compares equal to the starting point. On either error the evaluation and
// iteration types are reset to NoEvaluation and NoIteration.
func (ls *LinesearchMethod) initNextLinesearch(loc *Location, xNext []float64) (EvaluationType, IterationType, error) {
	// fail resets the stored state and reports err to the caller.
	fail := func(err error) (EvaluationType, IterationType, error) {
		ls.evalType = NoEvaluation
		ls.iterType = NoIteration
		return ls.evalType, ls.iterType, err
	}

	copy(ls.x, loc.X)

	var initStep float64
	switch {
	case ls.first:
		initStep = ls.NextDirectioner.InitDirection(loc, ls.dir)
		ls.first = false
	default:
		initStep = ls.NextDirectioner.NextDirection(loc, ls.dir)
	}

	slope := floats.Dot(loc.Gradient, ls.dir)
	if slope >= 0 {
		return fail(ErrNonNegativeStepDirection)
	}

	ls.evalType = ls.Linesearcher.Init(loc.F, slope, initStep)

	// If rounding makes the trial point identical to the starting point, no
	// progress can be made along this direction.
	floats.AddScaledTo(xNext, ls.x, initStep, ls.dir)
	if floats.Equal(ls.x, xNext) {
		return fail(ErrNoProgress)
	}

	ls.iterType = MinorIteration
	return ls.evalType, ls.iterType, nil
}
开发者ID:jmptrader,项目名称:optimize,代码行数:32,代码来源:linesearch.go


示例4: initNextLinesearch

// initNextLinesearch begins a new line search starting from the complete
// location held in loc. The starting point is saved in ls.x, the first trial
// point is written into loc.X, and the operation to perform at loc.X is
// returned.
//
// ErrNonNegativeStepDirection is reported (through ls.error) when the
// projected gradient along the new direction is not negative, and
// ErrNoProgress when the trial point compares equal to the starting point.
// It panics if the Linesearcher's Init does not return an evaluation.
func (ls *LinesearchMethod) initNextLinesearch(loc *Location) (Operation, error) {
	copy(ls.x, loc.X)

	var stepSize float64
	switch {
	case ls.first:
		ls.first = false
		stepSize = ls.NextDirectioner.InitDirection(loc, ls.dir)
	default:
		stepSize = ls.NextDirectioner.NextDirection(loc, ls.dir)
	}

	slope := floats.Dot(loc.Gradient, ls.dir)
	if slope >= 0 {
		return ls.error(ErrNonNegativeStepDirection)
	}

	initOp := ls.Linesearcher.Init(loc.F, slope, stepSize)
	if !initOp.isEvaluation() {
		panic("linesearch: Linesearcher returned invalid operation")
	}

	floats.AddScaledTo(loc.X, ls.x, stepSize, ls.dir)
	if floats.Equal(ls.x, loc.X) {
		// The step is too small to move away from the starting point once
		// rounding is taken into account.
		return ls.error(ErrNoProgress)
	}

	ls.lastStep = stepSize
	// loc.X has moved, so none of the fields of loc are valid any more.
	ls.eval = NoOperation

	ls.lastOp = initOp
	return initOp, nil
}
开发者ID:jacobxk,项目名称:optimize,代码行数:38,代码来源:linesearch.go


示例5: evaluate

// evaluate performs the evaluations requested by evalType for problem p at
// xNext, writing the results into loc and incrementing the matching counters
// in stats.
//
// When xNext differs from loc.X, loc is invalidated and xNext is copied into
// loc.X before anything is evaluated. evaluate panics if NoEvaluation is
// requested at such a new location, or if evalType contains bits other than
// FuncEvaluation, GradEvaluation and HessEvaluation.
func evaluate(p *Problem, evalType EvaluationType, xNext []float64, loc *Location, stats *Stats) {
	moved := !floats.Equal(loc.X, xNext)
	if moved && evalType == NoEvaluation {
		// Optimizers should not request NoEvaluation at a new location.
		// The intent and therefore an appropriate action are both unclear.
		panic("optimize: no evaluation requested at new location")
	}
	if moved {
		invalidate(loc)
		copy(loc.X, xNext)
	}

	// pending tracks the requested bits that have not been handled yet.
	pending := evalType

	if evalType&FuncEvaluation != 0 {
		loc.F = p.Func(loc.X)
		stats.FuncEvaluations++
		pending &^= FuncEvaluation
	}

	if evalType&GradEvaluation != 0 {
		p.Grad(loc.X, loc.Gradient)
		stats.GradEvaluations++
		pending &^= GradEvaluation
	}

	if evalType&HessEvaluation != 0 {
		p.Hess(loc.X, loc.Hessian)
		stats.HessEvaluations++
		pending &^= HessEvaluation
	}

	// Anything left over is a bit this function does not know how to serve.
	if pending != NoEvaluation {
		panic(fmt.Sprintf("optimize: unknown evaluation type %v", evalType))
	}
}
开发者ID:jmptrader,项目名称:optimize,代码行数:36,代码来源:local.go


示例6: DrsclTest

// DrsclTest exercises impl.Drscl with a moderate, a very large and a very
// small scaling factor. For each case it scales a copy of x by a, checks that
// the copy changed, then scales by 1/a and checks that the result is
// approximately x again.
func DrsclTest(t *testing.T, impl Drscler) {
	type drsclCase struct {
		x []float64
		a float64
	}

	cases := []drsclCase{
		{x: []float64{1, 2, 3, 4, 5}, a: 4},
		{x: []float64{1, 2, 3, 4, 5}, a: math.MaxFloat64},
		{x: []float64{1, 2, 3, 4, 5}, a: 1e-307},
	}

	for _, tc := range cases {
		n := len(tc.x)
		work := make([]float64, n)
		copy(work, tc.x)

		// The scaled values cannot be compared directly because of the
		// floating point scaling issues Drscl exists to handle. Instead,
		// scale and then scale back, and check that this yields
		// approximately x. If overflow or underflow occurs the round trip
		// will not match.
		impl.Drscl(n, tc.a, work, 1)
		if floats.Equal(work, tc.x) {
			t.Errorf("x unchanged during call to drscl. a = %v, x = %v.", tc.a, tc.x)
		}

		impl.Drscl(n, 1/tc.a, work, 1)
		if roundTrip := floats.EqualApprox(work, tc.x, 1e-14); !roundTrip {
			t.Errorf("x not equal after scaling and unscaling. a = %v, x = %v.", tc.a, tc.x)
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:35,代码来源:drscl.go


示例7: Grad

// Grad writes the gradient at params into deriv. When params equals
// b.lastParams the stored b.lastGrad is copied out directly; otherwise
// b.funcGrad is called first to refresh b.lastFunc and b.lastGrad.
func (b *BatchGradient) Grad(params, deriv []float64) {
	if !floats.Equal(params, b.lastParams) {
		// Different point: recompute the value and gradient there.
		b.lastFunc = b.funcGrad(params, b.lastGrad)
	}
	copy(deriv, b.lastGrad)
}
开发者ID:reggo,项目名称:reggo,代码行数:8,代码来源:combineloss_2.go


示例8: TestQuantile

// TestQuantile checks Quantile against known answers for every probability in
// each test case, once without weights and once with uniform weights. It also
// verifies that Quantile leaves its x and w inputs unmodified.
func TestQuantile(t *testing.T) {
	cumulantKinds := []CumulantKind{Empirical}
	for i, test := range []struct {
		p   []float64
		x   []float64
		w   []float64
		ans [][]float64 // ans[k][j] is the expected value for kind k and p[j].
	}{
		{
			p:   []float64{0, 0.05, 0.1, 0.15, 0.45, 0.5, 0.55, 0.85, 0.9, 0.95, 1},
			x:   []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			w:   nil,
			ans: [][]float64{{1, 1, 1, 2, 5, 5, 6, 9, 9, 10, 10}},
		},
		{
			p:   []float64{0, 0.05, 0.1, 0.15, 0.45, 0.5, 0.55, 0.85, 0.9, 0.95, 1},
			x:   []float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
			w:   []float64{3, 3, 3, 3, 3, 3, 3, 3, 3, 3},
			ans: [][]float64{{1, 1, 1, 2, 5, 5, 6, 9, 9, 10, 10}},
		},
	} {
		// Keep pristine copies so that mutation of the inputs can be detected.
		copyX := make([]float64, len(test.x))
		copy(copyX, test.x)
		var copyW []float64
		if test.w != nil {
			copyW = make([]float64, len(test.w))
			copy(copyW, test.w)
		}
		for j, p := range test.p {
			for k, kind := range cumulantKinds {
				v := Quantile(p, kind, test.x, test.w)
				if !floats.Equal(copyX, test.x) {
					t.Errorf("x changed for case %d kind %d percentile %v", i, k, p)
				}
				// Report the weights by name so a failure here is not
				// mistaken for a change to x.
				if !floats.Equal(copyW, test.w) {
					t.Errorf("w changed for case %d kind %d percentile %v", i, k, p)
				}
				if v != test.ans[k][j] {
					t.Errorf("mismatch case %d kind %d percentile %v. Expected: %v, found: %v", i, k, p, test.ans[k][j], v)
				}
			}
		}
	}
}
开发者ID:cjslep,项目名称:stat,代码行数:44,代码来源:stat_test.go


示例9: TestCDF

// TestCDF checks CDF against known answers for every query point in each test
// case, once without weights and once with unit weights. It also verifies
// that CDF leaves its x and weights inputs unmodified. The empty first case
// has no query points, so its inner loop body never runs.
func TestCDF(t *testing.T) {
	cumulantKinds := []CumulantKind{Empirical}
	for i, test := range []struct {
		q       []float64
		x       []float64
		weights []float64
		ans     [][]float64 // ans[k][j] is the expected value for kind k and q[j].
	}{
		{},
		{
			q:   []float64{0, 0.9, 1, 1.1, 2.9, 3, 3.1, 4.9, 5, 5.1},
			x:   []float64{1, 2, 3, 4, 5},
			ans: [][]float64{{0, 0, 0.2, 0.2, 0.4, 0.6, 0.6, 0.8, 1, 1}},
		},
		{
			q:       []float64{0, 0.9, 1, 1.1, 2.9, 3, 3.1, 4.9, 5, 5.1},
			x:       []float64{1, 2, 3, 4, 5},
			weights: []float64{1, 1, 1, 1, 1},
			ans:     [][]float64{{0, 0, 0.2, 0.2, 0.4, 0.6, 0.6, 0.8, 1, 1}},
		},
	} {
		// Keep pristine copies so that mutation of the inputs can be detected.
		copyX := make([]float64, len(test.x))
		copy(copyX, test.x)
		var copyW []float64
		if test.weights != nil {
			copyW = make([]float64, len(test.weights))
			copy(copyW, test.weights)
		}
		for j, q := range test.q {
			for k, kind := range cumulantKinds {
				v := CDF(q, kind, test.x, test.weights)
				if !floats.Equal(copyX, test.x) {
					t.Errorf("x changed for case %d kind %d percentile %v", i, k, q)
				}
				// Report the weights by name so a failure here is not
				// mistaken for a change to x.
				if !floats.Equal(copyW, test.weights) {
					t.Errorf("weights changed for case %d kind %d percentile %v", i, k, q)
				}
				if v != test.ans[k][j] {
					t.Errorf("mismatch case %d kind %d percentile %v. Expected: %v, found: %v", i, k, q, test.ans[k][j], v)
				}
			}
		}
	}
}
开发者ID:cjslep,项目名称:stat,代码行数:44,代码来源:stat_test.go


示例10: TestFlatten2D

// TestFlatten2D checks that Flatten2D turns a 3x2 slice of slices into the
// six values in row order.
func TestFlatten2D(t *testing.T) {
	input := [][]float64{{11, 22}, {33, 44}, {55, 66}}
	want := []float64{11, 22, 33, 44, 55, 66}

	got := Flatten2D(input)
	if floats.Equal(got, want) {
		return
	}
	t.Fatalf("Flatten failed. expected %+v, got %+v", want, got)
}
开发者ID:henrylee2cn,项目名称:gjoa,代码行数:10,代码来源:floatx_test.go


示例11: Iterate

// Iterate advances the line search by one step. It inspects loc, writes the
// next point to evaluate into xNext, and returns the evaluation to perform
// there together with the type of iteration this step represents.
//
// The behavior depends on the iteration type stored from the previous call:
//   - SubIteration: the missing fields of loc have been evaluated, so loc.X is
//     copied to xNext and MajorIteration is announced with NoEvaluation.
//   - MajorIteration: a new line search is started via initNextLinesearch.
//   - otherwise: the Linesearcher is asked whether it has finished; if so the
//     location is completed (possibly through a SubIteration), and if not the
//     Linesearcher supplies the next step size.
//
// An error from the Linesearcher is returned as is, and ErrNoProgress is
// returned when the new trial point compares equal to the line search's
// starting point ls.x. In both cases the stored state is reset to
// NoEvaluation and NoIteration.
func (ls *LinesearchMethod) Iterate(loc *Location, xNext []float64) (EvaluationType, IterationType, error) {
	if ls.iterType == SubIteration {
		// We needed to evaluate invalid fields of Location. Now we have them
		// and can announce MajorIteration.
		copy(xNext, loc.X)
		ls.evalType = NoEvaluation
		ls.iterType = MajorIteration
		return ls.evalType, ls.iterType, nil
	}

	if ls.iterType == MajorIteration {
		// The linesearch previously signaled MajorIteration. Since we're here,
		// it means that the previous location is not good enough to converge,
		// so start the next linesearch.
		return ls.initNextLinesearch(loc, xNext)
	}

	// Projection of the gradient at loc onto the current search direction.
	projGrad := floats.Dot(loc.Gradient, ls.dir)
	if ls.Linesearcher.Finished(loc.F, projGrad) {
		copy(xNext, loc.X)
		// Check if the last evaluation evaluated all fields of Location.
		ls.evalType = complementEval(loc, ls.evalType)
		if ls.evalType == NoEvaluation {
			// Location is complete and MajorIteration can be announced directly.
			ls.iterType = MajorIteration
		} else {
			// Location is not complete, evaluate its invalid fields in SubIteration.
			ls.iterType = SubIteration
		}
		return ls.evalType, ls.iterType, nil
	}

	// Line search not done, just iterate.
	stepSize, evalType, err := ls.Linesearcher.Iterate(loc.F, projGrad)
	if err != nil {
		ls.evalType = NoEvaluation
		ls.iterType = NoIteration
		return ls.evalType, ls.iterType, err
	}

	// The trial point is measured from the line search's starting point ls.x,
	// not from the current loc.X.
	floats.AddScaledTo(xNext, ls.x, stepSize, ls.dir)
	// Compare the starting point for the current iteration with the next
	// evaluation point to make sure that rounding errors do not prevent progress.
	if floats.Equal(ls.x, xNext) {
		ls.evalType = NoEvaluation
		ls.iterType = NoIteration
		return ls.evalType, ls.iterType, ErrNoProgress
	}

	ls.evalType = evalType
	ls.iterType = MinorIteration
	return ls.evalType, ls.iterType, nil
}
开发者ID:jmptrader,项目名称:optimize,代码行数:53,代码来源:linesearch.go


示例12: TestLegendreSingle

// TestLegendreSingle checks that computing the Legendre locations and weights
// one index at a time with FixedLocationSingle gives exactly the same values
// as the batch computation done by FixedLocations, for several sizes and
// integration intervals.
func TestLegendreSingle(t *testing.T) {
	type legendreCase struct {
		n        int
		min, max float64
	}

	cases := []legendreCase{
		{n: 100, min: -1, max: 1},
		{n: 50, min: -3, max: -1},
		{n: 1000, min: 2, max: 7},
	}

	for c, tc := range cases {
		quad := Legendre{}
		size := tc.n

		// Batch computation.
		batchX := make([]float64, size)
		batchW := make([]float64, size)
		quad.FixedLocations(batchX, batchW, tc.min, tc.max)

		// The same quantities, one index at a time.
		singleX := make([]float64, size)
		singleW := make([]float64, size)
		for idx := range singleX {
			loc, wt := quad.FixedLocationSingle(size, idx, tc.min, tc.max)
			singleX[idx] = loc
			singleW[idx] = wt
		}

		if sameX := floats.Equal(batchX, singleX); !sameX {
			t.Errorf("Case %d: xs mismatch batch and single", c)
		}
		if sameW := floats.Equal(batchW, singleW); !sameW {
			t.Errorf("Case %d: weights mismatch batch and single", c)
		}
	}
}
开发者ID:sbinet,项目名称:gonum-integrate,代码行数:40,代码来源:legendre_test.go


示例13: denseEqual

// denseEqual reports whether a has the row and column counts recorded in
// acomp and whether its backing data slice equals acomp.data element for
// element.
func denseEqual(a *Dense, acomp matComp) bool {
	rows, cols := a.Dims()
	// The data comparison is only reached when both dimensions agree.
	return rows == acomp.r &&
		cols == acomp.c &&
		floats.Equal(a.mat.Data, acomp.data)
}
开发者ID:RomainVabre,项目名称:origin,代码行数:13,代码来源:mul_test.go


示例14: TestCopy2D

// TestCopy2D verifies that CopyFloat2D returns a deep copy of its argument:
// the result must have the same number of rows, every row must hold the same
// values as the source row, and no row may share storage with the source.
func TestCopy2D(t *testing.T) {

	s1 := [][]float64{{11, 22}, {33, 44}, {55, 66}}

	s2 := CopyFloat2D(s1)

	// Guard the indexing below against a result with too few rows.
	if len(s2) != len(s1) {
		t.Fatalf("Copy failed. want %d rows, have %d", len(s1), len(s2))
	}

	for k := range s1 {
		if !floats.Equal(s1[k], s2[k]) {
			t.Fatalf("Copy failed. want: %+v, have: %+v", s1, s2)
		}
		// Compare the addresses of the first elements rather than of the
		// slice headers: &s1[k] and &s2[k] always differ when the outer
		// slices are distinct, even if the rows share one backing array.
		// Rows are known to be non-empty and of equal length at this point.
		if &s1[k][0] == &s2[k][0] {
			t.Fatal("Slices have the same address, not a copy.")
		}
	}
}
开发者ID:henrylee2cn,项目名称:gjoa,代码行数:15,代码来源:floatx_test.go


示例15: TestPredictFeaturized

// TestPredictFeaturized runs predictFeaturized on a small hand-computed case
// and checks that the output matches to within 1e-14 and that the feature
// vector z is left unmodified.
func TestPredictFeaturized(t *testing.T) {
	type featurizedCase struct {
		z              []float64
		featureWeights [][]float64
		output         []float64
		Name           string
	}

	cases := []featurizedCase{
		{
			Name: "General",
			z:    []float64{1, 2, 3},
			featureWeights: [][]float64{
				{3, 4},
				{1, 2},
				{0.5, 0.4},
			},
			output: []float64{6.5, 9.2},
		},
	}

	for _, tc := range cases {
		zIn := make([]float64, len(tc.z))
		copy(zIn, tc.z)

		weights := flatten(tc.featureWeights)
		// NOTE(review): this clone is taken but never compared against
		// weights afterwards — confirm whether a check that the weights are
		// unchanged was intended.
		weightsCopy := new(mat64.Dense)
		weightsCopy.Clone(weights)

		got := make([]float64, len(tc.output))

		predictFeaturized(zIn, weights, got)

		// The input vector must come back untouched.
		if unchanged := floats.Equal(tc.z, zIn); !unchanged {
			t.Errorf("z changed during call")
		}

		if matches := floats.EqualApprox(got, tc.output, 1e-14); !matches {
			t.Errorf("output doesn't match for test %v. Expected %v, found %v", tc.Name, tc.output, got)
		}
	}
}
开发者ID:reggo,项目名称:reggo,代码行数:38,代码来源:kitchensink_test.go


示例16: TestHistogram

// TestHistogram checks the counts returned by Histogram for weighted and
// unweighted data against hand-computed answers. Each case passes nil as the
// first argument and compares the returned slice with the expected one.
func TestHistogram(t *testing.T) {
	type histCase struct {
		x        []float64
		weights  []float64
		dividers []float64
		ans      []float64
	}

	cases := []histCase{
		{
			x:        []float64{1, 3, 5, 6, 7, 8},
			dividers: []float64{2, 4, 6, 7},
			ans:      []float64{1, 1, 1, 1, 2},
		},
		{
			x:        []float64{1, 3, 5, 6, 7, 8},
			dividers: []float64{2, 4, 6, 7},
			weights:  []float64{1, 2, 1, 1, 1, 2},
			ans:      []float64{1, 2, 1, 1, 3},
		},
		{
			x:        []float64{1, 8},
			dividers: []float64{2, 4, 6, 7},
			weights:  []float64{1, 2},
			ans:      []float64{1, 0, 0, 0, 2},
		},
		{
			x:        []float64{1, 8},
			dividers: []float64{2, 4, 6, 7},
			ans:      []float64{1, 0, 0, 0, 1},
		},
	}

	for idx, tc := range cases {
		got := Histogram(nil, tc.dividers, tc.x, tc.weights)
		if floats.Equal(got, tc.ans) {
			continue
		}
		t.Errorf("Hist mismatch case %d. Expected %v, Found %v", idx, tc.ans, got)
	}
}
开发者ID:cjslep,项目名称:stat,代码行数:36,代码来源:stat_test.go


示例17: Dorml2Test


//.........这里部分代码省略.........
				{3, 5, 4, 20, 6},
				{4, 3, 5, 20, 6},
				{4, 5, 3, 20, 6},
				{5, 3, 4, 20, 6},
				{5, 4, 3, 20, 6},
			} {
				var ma, na, mc, nc int
				if side == blas.Left {
					ma = test.adim
					na = test.common
					mc = test.common
					nc = test.cdim
				} else {
					ma = test.adim
					na = test.common
					mc = test.cdim
					nc = test.common
				}
				// Generate a random matrix
				lda := test.lda
				if lda == 0 {
					lda = na
				}
				a := make([]float64, ma*lda)
				for i := range a {
					a[i] = rnd.Float64()
				}
				ldc := test.ldc
				if ldc == 0 {
					ldc = nc
				}
				// Compute random C matrix
				c := make([]float64, mc*ldc)
				for i := range c {
					c[i] = rnd.Float64()
				}

				// Compute LQ
				k := min(ma, na)
				tau := make([]float64, k)
				work := make([]float64, 1)
				impl.Dgelqf(ma, na, a, lda, tau, work, -1)
				work = make([]float64, int(work[0]))
				impl.Dgelqf(ma, na, a, lda, tau, work, len(work))

				// Build Q from result
				q := constructQ("LQ", ma, na, a, lda, tau)

				cMat := blas64.General{
					Rows:   mc,
					Cols:   nc,
					Stride: ldc,
					Data:   make([]float64, len(c)),
				}
				copy(cMat.Data, c)
				cMatCopy := blas64.General{
					Rows:   cMat.Rows,
					Cols:   cMat.Cols,
					Stride: cMat.Stride,
					Data:   make([]float64, len(cMat.Data)),
				}
				copy(cMatCopy.Data, cMat.Data)
				switch {
				default:
					panic("bad test")
				case side == blas.Left && trans == blas.NoTrans:
					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
				case side == blas.Left && trans == blas.Trans:
					blas64.Gemm(blas.Trans, blas.NoTrans, 1, q, cMatCopy, 0, cMat)
				case side == blas.Right && trans == blas.NoTrans:
					blas64.Gemm(blas.NoTrans, blas.NoTrans, 1, cMatCopy, q, 0, cMat)
				case side == blas.Right && trans == blas.Trans:
					blas64.Gemm(blas.NoTrans, blas.Trans, 1, cMatCopy, q, 0, cMat)
				}
				// Do Dorm2r ard compare
				if side == blas.Left {
					work = make([]float64, nc)
				} else {
					work = make([]float64, mc)
				}
				aCopy := make([]float64, len(a))
				copy(aCopy, a)
				tauCopy := make([]float64, len(tau))
				copy(tauCopy, tau)
				impl.Dorml2(side, trans, mc, nc, k, a, lda, tau, c, ldc, work)
				if !floats.Equal(a, aCopy) {
					t.Errorf("a changed in call")
				}
				if !floats.Equal(tau, tauCopy) {
					t.Errorf("tau changed in call")
				}
				if !floats.EqualApprox(cMat.Data, c, 1e-14) {
					isLeft := side == blas.Left
					isTrans := trans == blas.Trans
					t.Errorf("Multiplication mismatch. IsLeft = %v. IsTrans = %v", isLeft, isTrans)
				}
			}
		}
	}
}
开发者ID:rawlingsj,项目名称:gofabric8,代码行数:101,代码来源:dorml2.go


示例18: TestGradient

func TestGradient(t *testing.T) {
	for i, test := range []struct {
		nDim   int
		tol    float64
		method Method
	}{
		{
			nDim:   2,
			tol:    2e-4,
			method: Forward,
		},
		{
			nDim:   2,
			tol:    1e-6,
			method: Central,
		},
		{
			nDim:   40,
			tol:    2e-4,
			method: Forward,
		},
		{
			nDim:   40,
			tol:    1e-6,
			method: Central,
		},
	} {
		x := make([]float64, test.nDim)
		for i := range x {
			x[i] = rand.Float64()
		}
		xcopy := make([]float64, len(x))
		copy(xcopy, x)

		r := Rosenbrock{len(x)}
		trueGradient := make([]float64, len(x))
		r.FDf(x, trueGradient)

		settings := DefaultSettings()
		settings.Method = test.method

		// try with gradient nil
		gradient := Gradient(nil, r.F, x, settings)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch in serial with nil. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in serial with nil.", i)
		}

		for i := range gradient {
			gradient[i] = rand.Float64()
		}

		Gradient(gradient, r.F, x, settings)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in serial with non-nil.", i)
		}

		// Try with known value
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		settings.OriginKnown = true
		settings.OriginValue = r.F(x)
		Gradient(gradient, r.F, x, settings)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with known origin in serial. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// Concurrently
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		settings.Concurrent = true
		settings.OriginKnown = false
		settings.Workers = 1000
		Gradient(gradient, r.F, x, settings)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with unknown origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
		}
		if !floats.Equal(x, xcopy) {
			t.Errorf("Case %v: x modified during call to gradient in parallel", i)
		}

		// Concurrently with origin known
		for i := range gradient {
			gradient[i] = rand.Float64()
		}
		settings.OriginKnown = true
		Gradient(gradient, r.F, x, settings)
		if !floats.EqualApprox(gradient, trueGradient, test.tol) {
			t.Errorf("Case %v: gradient mismatch with known origin in parallel. Want: %v, Got: %v.", i, trueGradient, gradient)
		}

		// With default settings
		for i := range gradient {
//.........这里部分代码省略.........
开发者ID:sbinet,项目名称:gonum-diff,代码行数:101,代码来源:gradient_test.go


示例19: Iterate

// Iterate advances the line search by one step and returns the next operation
// to perform on loc. When a new point has to be evaluated it is written into
// loc.X.
//
// The behavior depends on the operation returned by the previous call
// (ls.lastOp):
//   - NoOperation: no special handling yet (see the TODO); execution falls
//     through to continuing the line search.
//   - MajorIteration: a new line search is started via initNextLinesearch.
//   - anything else: the operation is recorded as completed in ls.eval, and
//     if a major iteration was pending (ls.nextMajor) it is announced now.
//
// Errors from the Linesearcher and ErrNoProgress are reported through
// ls.error. Iterate panics if the Linesearcher returns an operation other
// than MajorIteration or a combination of FuncEvaluation and GradEvaluation.
func (ls *LinesearchMethod) Iterate(loc *Location) (Operation, error) {
	switch ls.lastOp {
	case NoOperation:
		// TODO(vladimir-ch): Either Init has not been called, or the caller is
		// trying to resume the optimization run after Iterate previously
		// returned with an error. Decide what is the proper thing to do. See also #125.

	case MajorIteration:
		// The previous updated location did not converge the full
		// optimization. Initialize a new Linesearch.
		return ls.initNextLinesearch(loc)

	default:
		// Update the indicator of valid fields of loc.
		ls.eval |= ls.lastOp

		if ls.nextMajor {
			ls.nextMajor = false

			// Linesearcher previously finished, and the invalid fields of loc
			// have now been validated. Announce MajorIteration.
			ls.lastOp = MajorIteration
			return ls.lastOp, nil
		}
	}

	// Continue the linesearch.

	// Pass NaN for any quantity that ls.eval does not mark as evaluated at
	// the current loc.
	f := math.NaN()
	if ls.eval&FuncEvaluation != 0 {
		f = loc.F
	}
	projGrad := math.NaN()
	if ls.eval&GradEvaluation != 0 {
		projGrad = floats.Dot(loc.Gradient, ls.dir)
	}
	op, step, err := ls.Linesearcher.Iterate(f, projGrad)
	if err != nil {
		return ls.error(err)
	}

	switch op {
	case MajorIteration:
		// Linesearch has been finished.

		// complementEval's result is treated below as the evaluations still
		// needed to complete loc; NoOperation means nothing is missing.
		ls.lastOp = complementEval(loc, ls.eval)
		if ls.lastOp == NoOperation {
			// loc is complete, MajorIteration can be declared directly.
			ls.lastOp = MajorIteration
		} else {
			// Declare MajorIteration on the next call to Iterate.
			ls.nextMajor = true
		}

	case FuncEvaluation, GradEvaluation, FuncEvaluation | GradEvaluation:
		if step != ls.lastStep {
			// We are moving to a new location, and not, say, evaluating extra
			// information at the current location.

			// Compute the next evaluation point and store it in loc.X.
			floats.AddScaledTo(loc.X, ls.x, step, ls.dir)
			if floats.Equal(ls.x, loc.X) {
				// Step size has become so small that the next evaluation point is
				// indistinguishable from the starting point for the current
				// iteration due to rounding errors.
				return ls.error(ErrNoProgress)
			}
			ls.lastStep = step
			ls.eval = NoOperation // Indicate all invalid fields of loc.
		}
		ls.lastOp = op

	default:
		panic("linesearch: Linesearcher returned invalid operation")
	}

	return ls.lastOp, nil
}
开发者ID:jgcarvalho,项目名称:zdd,代码行数:78,代码来源:linesearch.go


示例20: marginalLikelihoodDerivative

// marginalLikelihoodDerivative stores in grad the derivative, with respect to
// each entry of x, of the scaled negative log marginal likelihood plus a
// barrier penalty for values of x outside the kernel's bounds.
//
// x holds the kernel hyperparameters; when trainNoise is true its last entry
// is the log of the noise, otherwise g.noise is used. mem supplies scratch
// storage and caches the kernel matrix, its Cholesky factorization and alpha
// for the most recent x (mem.lastX), so they are only recomputed when x
// changes.
func (g *GP) marginalLikelihoodDerivative(x, grad []float64, trainNoise bool, mem *margLikeMemory) {
	// d/dTheta_j log[(p|X,theta)] =
	//		1/2 * y^T * K^-1 dK/dTheta_j * K^-1 * y - 1/2 * tr(K^-1 * dK/dTheta_j)
	//		1/2 * α^T * dK/dTheta_j * α - 1/2 * tr(K^-1 dK/dTheta_j)
	// Multiply by the same -2
	//		-α^T * dK/dTheta_j * α + tr(K^-1 dK/dTheta_j)
	// This first computation is an inner product.
	n := len(g.outputs)
	nHyper := g.kernel.NumHyper()
	k := mem.k
	chol := mem.chol
	alpha := mem.alpha
	dKdTheta := mem.dKdTheta
	kInvDK := mem.kInvDK

	y := mat64.NewVector(n, g.outputs)

	var noise float64
	if trainNoise {
		// The noise is optimized in log space; convert it back.
		noise = math.Exp(x[len(x)-1])
	} else {
		noise = g.noise
	}

	// If x is the same, then reuse what has been computed in the function.
	if !floats.Equal(mem.lastX, x) {
		copy(mem.lastX, x)
		g.kernel.SetHyper(x[:nHyper])
		g.setKernelMat(k, noise)
		//chol.Cholesky(k, false)
		chol.Factorize(k)
		alpha.SolveCholeskyVec(chol, y)
	}
	g.setKernelMatDeriv(dKdTheta, trainNoise, noise)
	for i := range dKdTheta {
		// grad[i] = -α^T * dK/dTheta_i * α + tr(K^-1 * dK/dTheta_i)
		kInvDK.SolveCholesky(chol, dKdTheta[i])
		inner := mat64.Inner(alpha, dKdTheta[i], alpha)
		grad[i] = -inner + mat64.Trace(kInvDK)
	}
	// Scale by the number of training outputs.
	floats.Scale(1/float64(n), grad)

	bounds := g.kernel.Bounds()
	if trainNoise {
		bounds = append(bounds, Bound{minLogNoise, maxLogNoise})
	}
	barrierGrad := make([]float64, len(grad))
	for i, v := range x {
		// Quadratic barrier penalty.
		if v < bounds[i].Min {
			diff := bounds[i].Min - v
			barrierGrad[i] = -(barrierPow) * math.Pow(diff, barrierPow-1)
		}
		if v > bounds[i].Max {
			diff := v - bounds[i].Max
			barrierGrad[i] = (barrierPow) * math.Pow(diff, barrierPow-1)
		}
	}
	// NOTE(review): the two Println calls below look like leftover debug
	// output written to stdout on every call — confirm whether they should be
	// removed. When trainNoise is false, x[len(x)-1] is the last kernel
	// hyperparameter rather than the noise.
	fmt.Println("noise, minNoise", x[len(x)-1], bounds[len(x)-1].Min)
	fmt.Println("barrier Grad", barrierGrad)
	floats.Add(grad, barrierGrad)
	//copy(grad, barrierGrad)
}
开发者ID:btracey,项目名称:gaussproc,代码行数:62,代码来源:gp.go



注:本文中的github.com/gonum/floats.Equal函数示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Golang floats.EqualApprox函数代码示例发布时间:2022-05-23
下一篇:
Golang floats.Dot函数代码示例发布时间:2022-05-23
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap