Golang base.GetClass Function Code Examples


This article collects typical usage examples of the Golang GetClass function from github.com/sjwhitworth/golearn/base. If you have been wondering what GetClass does, how to call it, or what real-world uses look like, the selected code examples below should help.



A total of 16 GetClass code examples are shown below, ordered by popularity.
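Before the collected examples, here is a minimal, self-contained sketch of the typical GetClass pattern: loading a dataset and printing the class label of each row. It is not taken from any of the projects below; the CSV file name is only a placeholder.

package main

import (
	"fmt"

	"github.com/sjwhitworth/golearn/base"
)

func main() {
	// Parse a CSV whose last column is the class attribute
	// ("iris.csv" is a placeholder path; substitute your own data).
	inst, err := base.ParseCSVToInstances("iris.csv", true)
	if err != nil {
		panic(err)
	}

	// Size returns (columns, rows); GetClass returns the string form
	// of the class value stored on a given row.
	_, rows := inst.Size()
	for i := 0; i < rows; i++ {
		fmt.Println(base.GetClass(inst, i))
	}
}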

Example 1: TestLinearRegression

func TestLinearRegression(t *testing.T) {
	lr := NewLinearRegression()

	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}

	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)
	err = lr.Fit(trainData)
	if err != nil {
		t.Fatal(err)
	}

	predictions, err := lr.Predict(testData)
	if err != nil {
		t.Fatal(err)
	}

	_, rows := predictions.Size()

	for i := 0; i < rows; i++ {
		fmt.Printf("Expected: %s || Predicted: %s\n", base.GetClass(testData, i), base.GetClass(predictions, i))
	}
}
Developer: JacobXie, Project: golearn, Lines: 25, Source: linear_regression_test.go


Example 2: TestKnnClassifier

func TestKnnClassifier(t *testing.T) {
	Convey("Given labels, a classifier and data", t, func() {
		trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
		So(err, ShouldBeNil)

		testingData, err := base.ParseCSVToInstances("knn_test.csv", false)
		So(err, ShouldBeNil)

		cls := NewKnnClassifier("euclidean", 2)
		cls.Fit(trainingData)
		predictions := cls.Predict(testingData)
		So(predictions, ShouldNotEqual, nil)

		Convey("When predicting the label for our first vector", func() {
			result := base.GetClass(predictions, 0)
			Convey("The result should be 'blue", func() {
				So(result, ShouldEqual, "blue")
			})
		})

		Convey("When predicting the label for our second vector", func() {
			result2 := base.GetClass(predictions, 1)
			Convey("The result should be 'red", func() {
				So(result2, ShouldEqual, "red")
			})
		})
	})
}
Developer: GeekFreaker, Project: golearn, Lines: 28, Source: knn_test.go


Example 3: TestLinearRegression

func TestLinearRegression(t *testing.T) {
	Convey("Doing a  linear regression", t, func() {
		lr := NewLinearRegression()

		Convey("With no training data", func() {
			Convey("Predicting", func() {
				testData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
				So(err, ShouldBeNil)

				_, err = lr.Predict(testData)

				Convey("Should result in a NoTrainingDataError", func() {
					So(err, ShouldEqual, NoTrainingDataError)
				})

			})
		})

		Convey("With not enough training data", func() {
			trainingDatum, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
			So(err, ShouldBeNil)

			Convey("Fitting", func() {
				err = lr.Fit(trainingDatum)

				Convey("Should result in a NotEnoughDataError", func() {
					So(err, ShouldEqual, NotEnoughDataError)
				})
			})
		})

		Convey("With sufficient training data", func() {
			instances, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
			So(err, ShouldBeNil)
			trainData, testData := base.InstancesTrainTestSplit(instances, 0.1)

			Convey("Fitting and Predicting", func() {
				err := lr.Fit(trainData)
				So(err, ShouldBeNil)

				predictions, err := lr.Predict(testData)
				So(err, ShouldBeNil)

				Convey("It makes reasonable predictions", func() {
					_, rows := predictions.Size()

					for i := 0; i < rows; i++ {
						actualValue, _ := strconv.ParseFloat(base.GetClass(testData, i), 64)
						expectedValue, _ := strconv.ParseFloat(base.GetClass(predictions, i), 64)

						So(actualValue, ShouldAlmostEqual, expectedValue, actualValue*0.05)
					}
				})
			})
		})
	})
}
Developer: CTLife, Project: golearn, Lines: 57, Source: linear_regression_test.go


Example 4: TestLayeredXORInline

func TestLayeredXORInline(t *testing.T) {

	Convey("Given an inline XOR dataset...", t, func() {

		data := mat64.NewDense(4, 3, []float64{
			1, 0, 1,
			0, 1, 1,
			0, 0, 0,
			1, 1, 0,
		})

		XORData := base.InstancesFromMat64(4, 3, data)
		classAttr := base.GetAttributeByName(XORData, "2")
		XORData.AddClassAttribute(classAttr)

		net := NewMultiLayerNet([]int{3})
		net.MaxIterations = 20000
		net.Fit(XORData)

		Convey("After running for 20000 iterations, should have some predictive power...", func() {

			Convey("The right nodes should be connected in the network...", func() {
				So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
				So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)

				for i := 1; i <= 6; i++ {
					So(net.network.GetWeight(6, i), ShouldAlmostEqual, 0.000)
				}

			})
			out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
			net.network.Activate(out, 2)
			So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)

			Convey("And Predict() should do OK too...", func() {

				pred := net.Predict(XORData)

				for _, a := range pred.AllAttributes() {
					af, ok := a.(*base.FloatAttribute)
					So(ok, ShouldBeTrue)

					af.Precision = 1
				}

				So(base.GetClass(pred, 0), ShouldEqual, "1.0")
				So(base.GetClass(pred, 1), ShouldEqual, "1.0")
				So(base.GetClass(pred, 2), ShouldEqual, "0.0")
				So(base.GetClass(pred, 3), ShouldEqual, "0.0")

			})
		})

	})

}
Developer: thedadams, Project: golearn, Lines: 56, Source: layered_test.go


Example 5: TestLayeredXOR

func TestLayeredXOR(t *testing.T) {

	Convey("Given an XOR dataset...", t, func() {

		XORData, err := base.ParseCSVToInstances("xor.csv", false)
		So(err, ShouldEqual, nil)

		fmt.Println(XORData)
		net := NewMultiLayerNet([]int{3})
		net.MaxIterations = 20000
		net.Fit(XORData)

		Convey("After running for 20000 iterations, should have some predictive power...", func() {

			Convey("The right nodes should be connected in the network...", func() {

				fmt.Println(net.network)
				So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
				So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)

				for i := 1; i <= 6; i++ {
					So(net.network.GetWeight(6, i), ShouldAlmostEqual, 0.000)
				}

			})
			out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
			net.network.Activate(out, 2)
			fmt.Println(out)
			So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)

			Convey("And Predict() should do OK too...", func() {

				pred := net.Predict(XORData)

				for _, a := range pred.AllAttributes() {
					af, ok := a.(*base.FloatAttribute)
					if !ok {
						panic("All of these should be FloatAttributes!")
					}
					af.Precision = 1
				}

				So(base.GetClass(pred, 0), ShouldEqual, "0.0")
				So(base.GetClass(pred, 1), ShouldEqual, "1.0")
				So(base.GetClass(pred, 2), ShouldEqual, "1.0")
				So(base.GetClass(pred, 3), ShouldEqual, "0.0")

			})
		})

	})

}
Developer: JacobXie, Project: golearn, Lines: 53, Source: layered_test.go


Example 6: ChiMBuildFrequencyTable

func ChiMBuildFrequencyTable(attr base.Attribute, inst base.FixedDataGrid) []*FrequencyTableEntry {
	ret := make([]*FrequencyTableEntry, 0)
	attribute := attr.(*base.FloatAttribute)

	attrSpec, err := inst.GetAttribute(attr)
	if err != nil {
		panic(err)
	}
	attrSpecs := []base.AttributeSpec{attrSpec}

	err = inst.MapOverRows(attrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		value := row[0]
		valueConv := attribute.GetFloatFromSysVal(value)
		class := base.GetClass(inst, rowNo)
		// Search the frequency table for the value
		found := false
		for _, entry := range ret {
			if entry.Value == valueConv {
				found = true
				entry.Frequency[class] += 1
			}
		}
		if !found {
			newEntry := &FrequencyTableEntry{
				valueConv,
				make(map[string]int),
			}
			newEntry.Frequency[class] = 1
			ret = append(ret, newEntry)
		}
		return true, nil
	})

	return ret
}
Developer: Gudym, Project: golearn, Lines: 35, Source: chimerge_funcs.go


Example 7: vote

func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
	// Reset maxMap
	for a := range maxmap {
		maxmap[a] = 0
	}

	// Refresh maxMap
	for _, elem := range values {
		label := base.GetClass(KNN.TrainingData, elem)
		if _, ok := maxmap[label]; ok {
			maxmap[label]++
		} else {
			maxmap[label] = 1
		}
	}

	// Sort the maxMap
	var maxClass string
	maxVal := -1
	for a := range maxmap {
		if maxmap[a] > maxVal {
			maxVal = maxmap[a]
			maxClass = a
		}
	}
	return maxClass
}
Developer: nickpoorman, Project: golearn, Lines: 27, Source: knn.go


Example 8: getNumericAttributeEntropy

func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {

	// Resolve Attribute
	attrSpec, err := f.GetAttribute(attr)
	if err != nil {
		panic(err)
	}

	// Build sortable vector
	_, rows := f.Size()
	refs := make([]numericSplitRef, rows)
	f.MapOverRows([]base.AttributeSpec{attrSpec}, func(val [][]byte, row int) (bool, error) {
		cls := base.GetClass(f, row)
		v := base.UnpackBytesToFloat(val[0])
		refs[row] = numericSplitRef{v, cls}
		return true, nil
	})

	// Sort
	sort.Sort(splitVec(refs))

	generateCandidateSplitDistribution := func(val float64) map[string]map[string]int {
		presplit := make(map[string]int)
		postplit := make(map[string]int)
		for _, i := range refs {
			if i.val < val {
				presplit[i.class]++
			} else {
				postplit[i.class]++
			}
		}
		ret := make(map[string]map[string]int)
		ret["0"] = presplit
		ret["1"] = postplit
		return ret
	}

	minSplitEntropy := math.Inf(1)
	minSplitVal := math.Inf(1)
	// Consider each candidate split point (midpoint of adjacent sorted values)
	for i := 0; i < len(refs)-1; i++ {
		val := refs[i].val + refs[i+1].val
		val /= 2
		splitDist := generateCandidateSplitDistribution(val)
		splitEntropy := getSplitEntropy(splitDist)
		if splitEntropy < minSplitEntropy {
			minSplitEntropy = splitEntropy
			minSplitVal = val
		}
	}

	return minSplitEntropy, minSplitVal
}
Developer: CTLife, Project: golearn, Lines: 53, Source: entropy.go


Example 9: GetConfusionMatrix

// GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref')
// and generated (`gen') Instances.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) (map[string]map[string]int, error) {
	_, refRows := ref.Size()
	_, genRows := gen.Size()

	if refRows != genRows {
		return nil, errors.New(fmt.Sprintf("Row count mismatch: ref has %d rows, gen has %d rows", refRows, genRows))
	}

	ret := make(map[string]map[string]int)

	for i := 0; i < int(refRows); i++ {
		referenceClass := base.GetClass(ref, i)
		predictedClass := base.GetClass(gen, i)
		if _, ok := ret[referenceClass]; ok {
			ret[referenceClass][predictedClass] += 1
		} else {
			ret[referenceClass] = make(map[string]int)
			ret[referenceClass][predictedClass] = 1
		}
	}
	return ret, nil
}
Developer: CTLife, Project: golearn, Lines: 24, Source: confusion.go


Example 10: GetConfusionMatrix

// GetConfusionMatrix builds a ConfusionMatrix from a set of reference (`ref')
// and generated (`gen') Instances.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) map[string]map[string]int {

	_, refRows := ref.Size()
	_, genRows := gen.Size()

	if refRows != genRows {
		panic("Row counts should match")
	}

	ret := make(map[string]map[string]int)

	for i := 0; i < int(refRows); i++ {
		referenceClass := base.GetClass(ref, i)
		predictedClass := base.GetClass(gen, i)
		if _, ok := ret[referenceClass]; ok {
			ret[referenceClass][predictedClass] += 1
		} else {
			ret[referenceClass] = make(map[string]int)
			ret[referenceClass][predictedClass] = 1
		}
	}
	return ret
}
Developer: Gudym, Project: golearn, Lines: 25, Source: confusion.go
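As a usage note (not part of either project above), the map returned by both GetConfusionMatrix variants can be reduced to an overall accuracy by summing the diagonal entries, where the reference class equals the predicted class. The helper below is a hypothetical sketch:

// accuracyFromConfusion sums the counts where the reference class equals
// the predicted class and divides by the total number of instances.
func accuracyFromConfusion(cm map[string]map[string]int) float64 {
	correct, total := 0, 0
	for ref, row := range cm {
		for pred, count := range row {
			total += count
			if ref == pred {
				correct += count
			}
		}
	}
	if total == 0 {
		return 0
	}
	return float64(correct) / float64(total)
}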


Example 11: main

func main() {
	xOne, err := strconv.ParseInt(os.Args[1], 10, 64)
	if err != nil {
		panic(err)
	}
	xTwo, err := strconv.ParseInt(os.Args[2], 10, 64)
	if err != nil {
		panic(err)
	}
	fmt.Println("")

	XNORData, _ := base.ParseCSVToInstances("xnor.csv", false)

	net := neural.NewMultiLayerNet([]int{3})
	net.MaxIterations = 20000
	net.Fit(XNORData)

	pred := net.Predict(XNORData)

	var inputVectorIndex int
	if xOne == 0 && xTwo == 0 {
		inputVectorIndex = 0
		fmt.Println(base.GetClass(pred, inputVectorIndex))
	} else if xOne == 0 && xTwo == 1 {
		inputVectorIndex = 1
		fmt.Println(base.GetClass(pred, inputVectorIndex))
	} else if xOne == 1 && xTwo == 0 {
		inputVectorIndex = 2
		fmt.Println(base.GetClass(pred, inputVectorIndex))
	} else if xOne == 1 && xTwo == 1 {
		inputVectorIndex = 3
		fmt.Println(base.GetClass(pred, inputVectorIndex))
	} else {
		fmt.Println("Your input is incorrect. Quitting. ")
	}

	fmt.Println("")

}
Developer: nick11roberts, Project: neural-xnor, Lines: 37, Source: main.go


Example 12: processData

func processData(x base.FixedDataGrid) instances {
	_, rows := x.Size()

	result := make(instances, rows)

	// Retrieve numeric non-class Attributes
	numericAttrs := base.NonClassFloatAttributes(x)
	numericAttrSpecs := base.ResolveAttributes(x, numericAttrs)

	// Retrieve class Attributes
	classAttrs := x.AllClassAttributes()
	if len(classAttrs) != 1 {
		panic("Only one classAttribute supported!")
	}

	// Check that the class Attribute is categorical
	// (with two values) or binary
	classAttr := classAttrs[0]
	if attr, ok := classAttr.(*base.CategoricalAttribute); ok {
		if len(attr.GetValues()) != 2 {
			panic("To many values for Attribute!")
		}
	} else if _, ok := classAttr.(*base.BinaryAttribute); ok {
	} else {
		panic("Wrong class Attribute type!")
	}

	// Convert each row
	x.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		// Allocate a new row
		probRow := make([]float64, len(numericAttrSpecs))

		// Read out the row
		for i, _ := range numericAttrSpecs {
			probRow[i] = base.UnpackBytesToFloat(row[i])
		}

		// Get the class for the values
		class := base.GetClass(x, rowNo)
		instance := instance{class, probRow}
		result[rowNo] = instance
		return true, nil
	})
	return result
}
Developer: CTLife, Project: golearn, Lines: 45, Source: average.go


Example 13: Predict

// Predict returns a classification for each input row, using the KNN algorithm.
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {

	// Check what distance function we are using
	var distanceFunc pairwise.PairwiseDistanceFunc
	switch KNN.DistanceFunc {
	case "euclidean":
		distanceFunc = pairwise.NewEuclidean()
	case "manhattan":
		distanceFunc = pairwise.NewManhattan()
	default:
		panic("unsupported distance function")

	}
	// Check Compatibility
	allAttrs := base.CheckCompatible(what, KNN.TrainingData)
	if allAttrs == nil {
		// Don't have the same Attributes
		return nil
	}

	// Remove the Attributes which aren't numeric
	allNumericAttrs := make([]base.Attribute, 0)
	for _, a := range allAttrs {
		if fAttr, ok := a.(*base.FloatAttribute); ok {
			allNumericAttrs = append(allNumericAttrs, fAttr)
		}
	}

	// Generate return vector
	ret := base.GeneratePredictionVector(what)

	// Resolve Attribute specifications for both
	whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
	trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)

	// Reserve storage for the most similar items
	distances := make(map[int]float64)

	// Reserve storage for voting map
	maxmap := make(map[string]int)

	// Reserve storage for row computations
	trainRowBuf := make([]float64, len(allNumericAttrs))
	predRowBuf := make([]float64, len(allNumericAttrs))

	// Iterate over all outer rows
	what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
		// Read the float values out
		for i, _ := range allNumericAttrs {
			predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
		}

		predMat := utilities.FloatsToMatrix(predRowBuf)

		// Find the closest match in the training data
		KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {

			// Read the float values out
			for i, _ := range allNumericAttrs {
				trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
			}

			// Compute the distance
			trainMat := utilities.FloatsToMatrix(trainRowBuf)
			distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
			return true, nil
		})

		sorted := utilities.SortIntMap(distances)
		values := sorted[:KNN.NearestNeighbours]

		// Reset maxMap
		for a := range maxmap {
			maxmap[a] = 0
		}

		// Refresh maxMap
		for _, elem := range values {
			label := base.GetClass(KNN.TrainingData, elem)
			if _, ok := maxmap[label]; ok {
				maxmap[label]++
			} else {
				maxmap[label] = 1
			}
		}

		// Sort the maxMap
		var maxClass string
		maxVal := -1
		for a := range maxmap {
			if maxmap[a] > maxVal {
				maxVal = maxmap[a]
				maxClass = a
			}
		}

		base.SetClass(ret, predRowNo, maxClass)
		return true, nil

//......... remainder of this function omitted .........
Developer: hpxro7, Project: golearn, Lines: 101, Source: knn.go


Example 14: TestSimple

func TestSimple(t *testing.T) {
	Convey("Given a simple training data", t, func() {
		trainingData, err := base.ParseCSVToInstances("test/simple_train.csv", false)
		So(err, ShouldBeNil)

		nb := NewBernoulliNBClassifier()
		nb.Fit(convertToBinary(trainingData))

		Convey("Check if Fit is working as expected", func() {
			Convey("All data needed for prior should be correctly calculated", func() {
				So(nb.classInstances["blue"], ShouldEqual, 2)
				So(nb.classInstances["red"], ShouldEqual, 2)
				So(nb.trainingInstances, ShouldEqual, 4)
			})

			Convey("'red' conditional probabilities should be correct", func() {
				logCondProbTok0 := nb.condProb["red"][0]
				logCondProbTok1 := nb.condProb["red"][1]
				logCondProbTok2 := nb.condProb["red"][2]

				So(logCondProbTok0, ShouldAlmostEqual, 1.0)
				So(logCondProbTok1, ShouldAlmostEqual, 1.0/3.0)
				So(logCondProbTok2, ShouldAlmostEqual, 1.0)
			})

			Convey("'blue' conditional probabilities should be correct", func() {
				logCondProbTok0 := nb.condProb["blue"][0]
				logCondProbTok1 := nb.condProb["blue"][1]
				logCondProbTok2 := nb.condProb["blue"][2]

				So(logCondProbTok0, ShouldAlmostEqual, 1.0)
				So(logCondProbTok1, ShouldAlmostEqual, 1.0)
				So(logCondProbTok2, ShouldAlmostEqual, 1.0/3.0)
			})
		})

		Convey("PredictOne should work as expected", func() {
			Convey("Using a document with different number of cols should panic", func() {
				testDoc := [][]byte{[]byte{0}, []byte{2}}
				So(func() { nb.PredictOne(testDoc) }, ShouldPanic)
			})

			Convey("Token 1 should be a good predictor of the blue class", func() {
				testDoc := [][]byte{[]byte{0}, []byte{1}, []byte{0}}
				So(nb.PredictOne(testDoc), ShouldEqual, "blue")

				testDoc = [][]byte{[]byte{1}, []byte{1}, []byte{0}}
				So(nb.PredictOne(testDoc), ShouldEqual, "blue")
			})

			Convey("Token 2 should be a good predictor of the red class", func() {
				testDoc := [][]byte{[]byte{0}, []byte{0}, []byte{1}}
				So(nb.PredictOne(testDoc), ShouldEqual, "red")
				testDoc = [][]byte{[]byte{1}, []byte{0}, []byte{1}}
				So(nb.PredictOne(testDoc), ShouldEqual, "red")
			})
		})

		Convey("Predict should work as expected", func() {
			testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
			So(err, ShouldBeNil)

			predictions := nb.Predict(convertToBinary(testData))

			Convey("All simple predicitions should be correct", func() {
				So(base.GetClass(predictions, 0), ShouldEqual, "blue")
				So(base.GetClass(predictions, 1), ShouldEqual, "red")
				So(base.GetClass(predictions, 2), ShouldEqual, "blue")
				So(base.GetClass(predictions, 3), ShouldEqual, "red")
			})
		})
	})
}
Developer: CTLife, Project: golearn, Lines: 73, Source: bernoulli_nb_test.go


Example 15: Fit

// Fit fills the Bernoulli Naive Bayes model with all values
// necessary for calculating the prior probability and p(f_i).
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {

	// Check that all Attributes are binary
	classAttrs := X.AllClassAttributes()
	allAttrs := X.AllAttributes()
	featAttrs := base.AttributeDifference(allAttrs, classAttrs)
	for i := range featAttrs {
		if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
			panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
		}
	}
	featAttrSpecs := base.ResolveAttributes(X, featAttrs)

	// Check that only one classAttribute is defined
	if len(classAttrs) != 1 {
		panic("Only one class Attribute can be used")
	}

	// Number of features and instances in this training set
	_, nb.trainingInstances = X.Size()
	nb.attrs = featAttrs
	nb.features = len(featAttrs)

	// Number of instances in class
	nb.classInstances = make(map[string]int)

	// Number of documents with given term (by class)
	docsContainingTerm := make(map[string][]int)

	// This algorithm could be vectorized after binarizing the data
	// matrix. Since mat64 doesn't have this function, an iterative
	// version is used.
	X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
		class := base.GetClass(X, r)

		// increment number of instances in class
		t, ok := nb.classInstances[class]
		if !ok {
			t = 0
		}
		nb.classInstances[class] = t + 1

		for feat := 0; feat < len(docVector); feat++ {
			v := docVector[feat]
			// In Bernoulli Naive Bayes the presence and absence of
			// features are considered. All non-zero values are
			// treated as presence.
			if v[0] > 0 {
				// Update number of times this feature appeared within
				// given label.
				t, ok := docsContainingTerm[class]
				if !ok {
					t = make([]int, nb.features)
					docsContainingTerm[class] = t
				}
				t[feat] += 1
			}
		}
		return true, nil
	})

	// Pre-calculate conditional probabilities for each class
	for c, _ := range nb.classInstances {
		nb.condProb[c] = make([]float64, nb.features)
		for feat := 0; feat < nb.features; feat++ {
			classTerms, _ := docsContainingTerm[c]
			numDocs := classTerms[feat]
			docsInClass, _ := nb.classInstances[c]

			classCondProb, _ := nb.condProb[c]
			// Calculate conditional probability with laplace smoothing
			classCondProb[feat] = float64(numDocs+1) / float64(docsInClass+1)
		}
	}
}
Developer: JacobXie, Project: golearn, Lines: 77, Source: bernoulli_nb.go


Example 16: Predict

// Predict gathers predictions from all the classifiers
// and outputs the most common (majority) class
//
// IMPORTANT: in the event of a tie, the first class which
// achieved the tie value is output.
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
	n := runtime.NumCPU()
	// Channel to receive the results as they come in
	votes := make(chan base.DataGrid, n)
	// Count the votes for each class
	voting := make(map[int](map[string]int))

	// Create a goroutine to collect the votes
	var votingwait sync.WaitGroup
	votingwait.Add(1)
	go func() {
		for { // Need to resolve the voting problem
			incoming, ok := <-votes
			if ok {
				cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
				incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
					// Check if we've seen this row before...
					if _, ok := voting[predRow]; !ok {
						// If we haven't, create an entry
						voting[predRow] = make(map[string]int)
						// Continue on the current row
					}
					voting[predRow][base.GetClass(incoming, predRow)]++
					return true, nil
				})
			} else {
				votingwait.Done()
				break
			}
		}
	}()

	// Create workers to process the predictions
	processpipe := make(chan int, n)
	var processwait sync.WaitGroup
	for i := 0; i < n; i++ {
		processwait.Add(1)
		go func() {
			for {
				if i, ok := <-processpipe; ok {
					c := b.Models[i]
					l := b.generatePredictionInstances(i, from)
					votes <- c.Predict(l)
				} else {
					processwait.Done()
					break
				}
			}
		}()
	}

	// Send all the models to the workers for prediction
	for i := range b.Models {
		processpipe <- i
	}
	close(processpipe) // Finished sending models to be predicted
	processwait.Wait() // Predictors all finished processing
	close(votes)       // Close the vote channel and allow it to drain
	votingwait.Wait()  // All the votes are in

	// Generate the overall consensus
	ret := base.GeneratePredictionVector(from)
	for i := range voting {
		maxClass := ""
		maxCount := 0
		// Find the most popular class
		for c := range voting[i] {
			votes := voting[i][c]
			if votes > maxCount {
				maxClass = c
				maxCount = votes
			}
		}
		base.SetClass(ret, i, maxClass)
	}
	return ret
}
Developer: JacobXie, Project: golearn, Lines: 82, Source: bagging.go



Note: The github.com/sjwhitworth/golearn/base.GetClass examples in this article were compiled by 纯净天空 from source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets are drawn from open-source projects contributed by their respective developers, and copyright remains with the original authors. Consult each project's License before distributing or using the code; do not republish without permission.

