本文整理汇总了Golang中github.com/sjwhitworth/golearn/base.GetClass函数的典型用法代码示例。如果您正苦于以下问题：Golang GetClass函数的具体用法？Golang GetClass怎么用？Golang GetClass使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了GetClass函数的16个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Golang代码示例。
示例1: TestLinearRegression
// TestLinearRegression fits a linear regression on a train split of the exams
// dataset, predicts the held-out split and prints the reference class next to
// the predicted class for every predicted row. Any parse, fit or predict error
// aborts the test.
func TestLinearRegression(t *testing.T) {
	lr := NewLinearRegression()

	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)

	if fitErr := lr.Fit(trainData); fitErr != nil {
		t.Fatal(fitErr)
	}

	predictions, err := lr.Predict(testData)
	if err != nil {
		t.Fatal(err)
	}

	// Size's second return value is used as the row count.
	_, rowCount := predictions.Size()
	for row := 0; row < rowCount; row++ {
		expected := base.GetClass(testData, row)
		predicted := base.GetClass(predictions, row)
		fmt.Printf("Expected: %s || Predicted: %s\n", expected, predicted)
	}
}
开发者ID:JacobXie,项目名称:golearn,代码行数:25,代码来源:linear_regression_test.go
示例2: TestKnnClassifier
// TestKnnClassifier fits a classifier built with NewKnnClassifier("euclidean", 2)
// on knn_train.csv, predicts knn_test.csv and asserts the class strings
// reported for prediction rows 0 and 1.
func TestKnnClassifier(t *testing.T) {
Convey("Given labels, a classifier and data", t, func() {
// NOTE(review): the `false` argument presumably means "no header row" — confirm
// against ParseCSVToInstances.
trainingData, err := base.ParseCSVToInstances("knn_train.csv", false)
So(err, ShouldBeNil)
testingData, err := base.ParseCSVToInstances("knn_test.csv", false)
So(err, ShouldBeNil)
cls := NewKnnClassifier("euclidean", 2)
cls.Fit(trainingData)
predictions := cls.Predict(testingData)
// Guard the GetClass calls below against a missing prediction grid.
So(predictions, ShouldNotEqual, nil)
Convey("When predicting the label for our first vector", func() {
// Row 0 of the prediction grid.
result := base.GetClass(predictions, 0)
Convey("The result should be 'blue", func() {
So(result, ShouldEqual, "blue")
})
})
Convey("When predicting the label for our second vector", func() {
// Row 1 of the prediction grid.
result2 := base.GetClass(predictions, 1)
Convey("The result should be 'red", func() {
So(result2, ShouldEqual, "red")
})
})
})
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:28,代码来源:knn_test.go
示例3: TestLinearRegression
// TestLinearRegression exercises LinearRegression in three scenarios:
//
//   - Predict before any Fit must return NoTrainingDataError;
//   - Fit on the exam.csv dataset must return NotEnoughDataError;
//   - Fit on a train split of exams.csv followed by Predict on the held-out
//     split must yield predictions within 5% of the reference class value.
func TestLinearRegression(t *testing.T) {
	Convey("Doing a linear regression", t, func() {
		lr := NewLinearRegression()
		Convey("With no training data", func() {
			Convey("Predicting", func() {
				testData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
				So(err, ShouldBeNil)
				_, err = lr.Predict(testData)
				Convey("Should result in a NoTrainingDataError", func() {
					So(err, ShouldEqual, NoTrainingDataError)
				})
			})
		})
		Convey("With not enough training data", func() {
			trainingDatum, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
			So(err, ShouldBeNil)
			Convey("Fitting", func() {
				err = lr.Fit(trainingDatum)
				Convey("Should result in a NotEnoughDataError", func() {
					So(err, ShouldEqual, NotEnoughDataError)
				})
			})
		})
		Convey("With sufficient training data", func() {
			instances, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
			So(err, ShouldBeNil)
			trainData, testData := base.InstancesTrainTestSplit(instances, 0.1)
			Convey("Fitting and Predicting", func() {
				err := lr.Fit(trainData)
				So(err, ShouldBeNil)
				predictions, err := lr.Predict(testData)
				So(err, ShouldBeNil)
				Convey("It makes reasonable predictions", func() {
					_, rows := predictions.Size()
					for i := 0; i < rows; i++ {
						// A class string that does not parse would otherwise
						// silently become 0 and could make the comparison
						// pass for the wrong reason.
						expectedValue, err := strconv.ParseFloat(base.GetClass(testData, i), 64)
						So(err, ShouldBeNil)
						predictedValue, err := strconv.ParseFloat(base.GetClass(predictions, i), 64)
						So(err, ShouldBeNil)

						// Allow 5% of the reference value; keep the delta
						// non-negative so negative targets are still comparable.
						tolerance := expectedValue * 0.05
						if tolerance < 0 {
							tolerance = -tolerance
						}
						So(predictedValue, ShouldAlmostEqual, expectedValue, tolerance)
					}
				})
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:57,代码来源:linear_regression_test.go
示例4: TestLayeredXORInline
// TestLayeredXORInline builds a 4x3 XOR table in memory (two inputs followed
// by the XOR result), marks the attribute named "2" as the class, fits a
// MultiLayerNet created with NewMultiLayerNet([]int{3}) and checks selected
// network weights, one activation and the predicted class of every row.
func TestLayeredXORInline(t *testing.T) {
	Convey("Given an inline XOR dataset...", t, func() {
		rawXOR := []float64{
			1, 0, 1,
			0, 1, 1,
			0, 0, 0,
			1, 1, 0,
		}
		data := mat64.NewDense(4, 3, rawXOR)
		XORData := base.InstancesFromMat64(4, 3, data)
		XORData.AddClassAttribute(base.GetAttributeByName(XORData, "2"))

		net := NewMultiLayerNet([]int{3})
		net.MaxIterations = 20000
		net.Fit(XORData)

		Convey("After running for 20000 iterations, should have some predictive power...", func() {
			Convey("The right nodes should be connected in the network...", func() {
				for _, node := range []int{1, 2} {
					So(net.network.GetWeight(node, node), ShouldAlmostEqual, 1.000)
				}
				for node := 1; node <= 6; node++ {
					So(net.network.GetWeight(6, node), ShouldAlmostEqual, 0.000)
				}
			})

			// Activate a 6x1 vector whose first entry is 1 and inspect entry 5.
			out := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
			net.network.Activate(out, 2)
			So(out.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)

			Convey("And Predict() should do OK too...", func() {
				pred := net.Predict(XORData)
				// Precision 1 makes the class strings compare as "0.0"/"1.0".
				for _, attr := range pred.AllAttributes() {
					floatAttr, isFloat := attr.(*base.FloatAttribute)
					So(isFloat, ShouldBeTrue)
					floatAttr.Precision = 1
				}
				for row, want := range []string{"1.0", "1.0", "0.0", "0.0"} {
					So(base.GetClass(pred, row), ShouldEqual, want)
				}
			})
		})
	})
}
开发者ID:thedadams,项目名称:golearn,代码行数:56,代码来源:layered_test.go
示例5: TestLayeredXOR
// TestLayeredXOR loads xor.csv, fits a MultiLayerNet created with
// NewMultiLayerNet([]int{3}) and checks selected network weights, one
// activation and the predicted class of rows 0-3. The dataset, the network
// and the activation vector are printed along the way.
func TestLayeredXOR(t *testing.T) {
	Convey("Given an XOR dataset...", t, func() {
		XORData, err := base.ParseCSVToInstances("xor.csv", false)
		So(err, ShouldEqual, nil)
		fmt.Println(XORData)

		net := NewMultiLayerNet([]int{3})
		net.MaxIterations = 20000
		net.Fit(XORData)

		Convey("After running for 20000 iterations, should have some predictive power...", func() {
			Convey("The right nodes should be connected in the network...", func() {
				fmt.Println(net.network)
				So(net.network.GetWeight(1, 1), ShouldAlmostEqual, 1.000)
				So(net.network.GetWeight(2, 2), ShouldAlmostEqual, 1.000)
				for from := 1; from <= 6; from++ {
					So(net.network.GetWeight(6, from), ShouldAlmostEqual, 0.000)
				}
			})

			// Activate a 6x1 vector whose first entry is 1 and inspect entry 5.
			activation := mat64.NewDense(6, 1, []float64{1.0, 0.0, 0.0, 0.0, 0.0, 0.0})
			net.network.Activate(activation, 2)
			fmt.Println(activation)
			So(activation.At(5, 0), ShouldAlmostEqual, 1.0, 0.1)

			Convey("And Predict() should do OK too...", func() {
				pred := net.Predict(XORData)
				// Precision 1 makes the class strings compare as "0.0"/"1.0".
				for _, attr := range pred.AllAttributes() {
					floatAttr, isFloat := attr.(*base.FloatAttribute)
					if !isFloat {
						panic("All of these should be FloatAttributes!")
					}
					floatAttr.Precision = 1
				}
				wantClasses := []string{"0.0", "1.0", "1.0", "0.0"}
				for row := 0; row < len(wantClasses); row++ {
					So(base.GetClass(pred, row), ShouldEqual, wantClasses[row])
				}
			})
		})
	})
}
开发者ID:JacobXie,项目名称:golearn,代码行数:53,代码来源:layered_test.go
示例6: ChiMBuildFrequencyTable
// ChiMBuildFrequencyTable scans every row of inst and returns one
// FrequencyTableEntry per distinct value of attr, in order of first
// appearance. Each entry's Frequency map counts how many rows with that
// value carry each class string (as reported by base.GetClass).
//
// attr must be a *base.FloatAttribute; any other type makes the type
// assertion panic. The function also panics if attr cannot be resolved in
// inst or if the row iteration reports an error.
func ChiMBuildFrequencyTable(attr base.Attribute, inst base.FixedDataGrid) []*FrequencyTableEntry {
	ret := make([]*FrequencyTableEntry, 0)
	attribute := attr.(*base.FloatAttribute)

	attrSpec, err := inst.GetAttribute(attr)
	if err != nil {
		panic(err)
	}
	attrSpecs := []base.AttributeSpec{attrSpec}

	err = inst.MapOverRows(attrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		valueConv := attribute.GetFloatFromSysVal(row[0])
		class := base.GetClass(inst, rowNo)

		// Values are unique within the table (an entry is only appended when
		// no match exists), so the scan can stop at the first hit.
		for _, entry := range ret {
			if entry.Value == valueConv {
				entry.Frequency[class]++
				return true, nil
			}
		}

		ret = append(ret, &FrequencyTableEntry{
			Value:     valueConv,
			Frequency: map[string]int{class: 1},
		})
		return true, nil
	})
	// Surface iteration failures the same way as the GetAttribute failure
	// above instead of returning a silently truncated table.
	if err != nil {
		panic(err)
	}
	return ret
}
开发者ID:Gudym,项目名称:golearn,代码行数:35,代码来源:chimerge_funcs.go
示例7: vote
// vote tallies the class of each training row listed in values and returns
// the class with the most votes.
//
// maxmap is caller-supplied scratch storage: every existing key is reset to
// zero (the keys themselves are kept), then one vote is added per element of
// values using the class base.GetClass reports for that training row.
//
// Ties are broken deterministically in favour of the lexicographically
// smallest class name; iterating the map alone would pick an arbitrary one.
// If maxmap ends up empty, the empty string is returned.
func (KNN *KNNClassifier) vote(maxmap map[string]int, values []int) string {
	// Reset the tallies but keep previously seen classes as candidates.
	for label := range maxmap {
		maxmap[label] = 0
	}

	// A missing key reads as zero, so a plain increment covers new classes.
	for _, elem := range values {
		maxmap[base.GetClass(KNN.TrainingData, elem)]++
	}

	var maxClass string
	maxVal := -1
	for label, count := range maxmap {
		if count > maxVal || (count == maxVal && label < maxClass) {
			maxVal = count
			maxClass = label
		}
	}
	return maxClass
}
开发者ID:nickpoorman,项目名称:golearn,代码行数:27,代码来源:knn.go
示例8: getNumericAttributeEntropy
// getNumericAttributeEntropy searches for the binary split of a float
// attribute with the lowest split entropy.
//
// Every row contributes a (value, class) pair; the pairs are sorted via
// splitVec and the midpoint of each adjacent pair is tried as a threshold.
// For a threshold, rows with value < threshold are counted under key "0" and
// all others under key "1"; getSplitEntropy scores that distribution.
//
// It returns the lowest entropy found and the threshold that produced it.
// Both are +Inf when no candidate scores below +Inf (e.g. fewer than two
// rows). It panics if attr cannot be resolved in f.
func getNumericAttributeEntropy(f base.FixedDataGrid, attr *base.FloatAttribute) (float64, float64) {
	spec, err := f.GetAttribute(attr)
	if err != nil {
		panic(err)
	}

	// Collect one (value, class) sample per row.
	_, rowCount := f.Size()
	samples := make([]numericSplitRef, rowCount)
	f.MapOverRows([]base.AttributeSpec{spec}, func(val [][]byte, row int) (bool, error) {
		class := base.GetClass(f, row)
		value := base.UnpackBytesToFloat(val[0])
		samples[row] = numericSplitRef{value, class}
		return true, nil
	})
	sort.Sort(splitVec(samples))

	bestEntropy := math.Inf(1)
	bestThreshold := math.Inf(1)

	// Try the midpoint between each pair of neighbouring samples.
	for i := 0; i+1 < len(samples); i++ {
		threshold := (samples[i].val + samples[i+1].val) / 2

		below := make(map[string]int)
		rest := make(map[string]int)
		for _, s := range samples {
			if s.val < threshold {
				below[s.class]++
			} else {
				rest[s.class]++
			}
		}
		dist := map[string]map[string]int{"0": below, "1": rest}

		if entropy := getSplitEntropy(dist); entropy < bestEntropy {
			bestEntropy = entropy
			bestThreshold = threshold
		}
	}
	return bestEntropy, bestThreshold
}
开发者ID:CTLife,项目名称:golearn,代码行数:53,代码来源:entropy.go
示例9: GetConfusionMatrix
// GetConfusionMatrix tallies reference classes against predicted classes.
//
// ref holds the reference rows and gen the generated (predicted) rows; row i
// of one is compared with row i of the other. The result is keyed first by
// reference class and then by predicted class, and each cell holds the number
// of rows with that combination.
//
// An error is returned (with a nil matrix) when the two grids report
// different row counts.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) (map[string]map[string]int, error) {
	_, refRows := ref.Size()
	_, genRows := gen.Size()
	if refRows != genRows {
		return nil, errors.New(fmt.Sprintf("Row count mismatch: ref has %d rows, gen has %d rows", refRows, genRows))
	}

	matrix := make(map[string]map[string]int)
	for row := 0; row < int(refRows); row++ {
		actual := base.GetClass(ref, row)
		predicted := base.GetClass(gen, row)

		// Create the inner map the first time a reference class appears.
		counts, seen := matrix[actual]
		if !seen {
			counts = make(map[string]int)
			matrix[actual] = counts
		}
		counts[predicted]++
	}
	return matrix, nil
}
开发者ID:CTLife,项目名称:golearn,代码行数:24,代码来源:confusion.go
示例10: GetConfusionMatrix
// GetConfusionMatrix counts, for every row index, the pairing of the class in
// ref (reference) with the class in gen (generated/predicted). The returned
// map is indexed as matrix[referenceClass][predictedClass].
//
// It panics when ref and gen report different row counts.
func GetConfusionMatrix(ref base.FixedDataGrid, gen base.FixedDataGrid) map[string]map[string]int {
	_, refRows := ref.Size()
	_, genRows := gen.Size()
	if refRows != genRows {
		panic("Row counts should match")
	}

	matrix := make(map[string]map[string]int)
	total := int(refRows)
	for row := 0; row < total; row++ {
		want := base.GetClass(ref, row)
		got := base.GetClass(gen, row)

		// Only non-nil inner maps are ever stored, so nil means "not seen yet".
		if matrix[want] == nil {
			matrix[want] = make(map[string]int)
		}
		matrix[want][got]++
	}
	return matrix
}
开发者ID:Gudym,项目名称:golearn,代码行数:25,代码来源:confusion.go
示例11: main
// main reads two integers from the command line, trains a MultiLayerNet on
// xnor.csv and prints the predicted class of the dataset row that matches
// the two inputs, framed by blank lines.
//
// Both arguments must be 0 or 1. Rows are addressed in the order
// (0,0), (0,1), (1,0), (1,1), i.e. row index = 2*xOne + xTwo. Missing or
// out-of-range arguments print "Your input is incorrect. Quitting. " without
// training; arguments that are not integers, or an unreadable xnor.csv,
// panic with the underlying error.
func main() {
	const badInput = "Your input is incorrect. Quitting. "

	if len(os.Args) < 3 {
		fmt.Println("")
		fmt.Println(badInput)
		fmt.Println("")
		return
	}

	// Check each parse separately so a bad first argument is not masked by
	// a valid second one.
	xOne, err := strconv.ParseInt(os.Args[1], 10, 64)
	if err != nil {
		panic(err)
	}
	xTwo, err := strconv.ParseInt(os.Args[2], 10, 64)
	if err != nil {
		panic(err)
	}
	fmt.Println("")

	// Reject invalid input before spending 20000 training iterations.
	if (xOne != 0 && xOne != 1) || (xTwo != 0 && xTwo != 1) {
		fmt.Println(badInput)
		fmt.Println("")
		return
	}

	XNORData, err := base.ParseCSVToInstances("xnor.csv", false)
	if err != nil {
		panic(err)
	}
	net := neural.NewMultiLayerNet([]int{3})
	net.MaxIterations = 20000
	net.Fit(XNORData)
	pred := net.Predict(XNORData)

	inputVectorIndex := int(2*xOne + xTwo)
	fmt.Println(base.GetClass(pred, inputVectorIndex))
	fmt.Println("")
}
开发者ID:nick11roberts,项目名称:neural-xnor,代码行数:37,代码来源:main.go
示例12: processData
// processData converts every row of x into an instance holding the row's
// class string and the float values of its numeric non-class attributes.
//
// x must have exactly one class attribute, and that attribute must be either
// a CategoricalAttribute with exactly two values or a BinaryAttribute;
// otherwise the function panics.
func processData(x base.FixedDataGrid) instances {
	_, rowCount := x.Size()
	result := make(instances, rowCount)

	// Numeric, non-class attributes form the feature vector.
	numericAttrs := base.NonClassFloatAttributes(x)
	numericAttrSpecs := base.ResolveAttributes(x, numericAttrs)

	classAttrs := x.AllClassAttributes()
	if len(classAttrs) != 1 {
		panic("Only one classAttribute supported!")
	}

	// The single class attribute has to be two-valued.
	switch attr := classAttrs[0].(type) {
	case *base.CategoricalAttribute:
		if len(attr.GetValues()) != 2 {
			panic("To many values for Attribute!")
		}
	case *base.BinaryAttribute:
		// Binary attributes are accepted as they are.
	default:
		panic("Wrong class Attribute type!")
	}

	x.MapOverRows(numericAttrSpecs, func(row [][]byte, rowNo int) (bool, error) {
		features := make([]float64, len(numericAttrSpecs))
		for col := range numericAttrSpecs {
			features[col] = base.UnpackBytesToFloat(row[col])
		}
		class := base.GetClass(x, rowNo)
		result[rowNo] = instance{class, features}
		return true, nil
	})
	return result
}
开发者ID:CTLife,项目名称:golearn,代码行数:45,代码来源:average.go
示例13: Predict
// Predict returns a classification for the vector, based on a vector input, using the KNN algorithm.
func (KNN *KNNClassifier) Predict(what base.FixedDataGrid) base.FixedDataGrid {
// Check what distance function we are using
var distanceFunc pairwise.PairwiseDistanceFunc
switch KNN.DistanceFunc {
case "euclidean":
distanceFunc = pairwise.NewEuclidean()
case "manhattan":
distanceFunc = pairwise.NewManhattan()
default:
panic("unsupported distance function")
}
// Check Compatibility
allAttrs := base.CheckCompatible(what, KNN.TrainingData)
if allAttrs == nil {
// Don't have the same Attributes
return nil
}
// Remove the Attributes which aren't numeric
allNumericAttrs := make([]base.Attribute, 0)
for _, a := range allAttrs {
if fAttr, ok := a.(*base.FloatAttribute); ok {
allNumericAttrs = append(allNumericAttrs, fAttr)
}
}
// Generate return vector
ret := base.GeneratePredictionVector(what)
// Resolve Attribute specifications for both
whatAttrSpecs := base.ResolveAttributes(what, allNumericAttrs)
trainAttrSpecs := base.ResolveAttributes(KNN.TrainingData, allNumericAttrs)
// Reserve storage for most the most similar items
distances := make(map[int]float64)
// Reserve storage for voting map
maxmap := make(map[string]int)
// Reserve storage for row computations
trainRowBuf := make([]float64, len(allNumericAttrs))
predRowBuf := make([]float64, len(allNumericAttrs))
// Iterate over all outer rows
what.MapOverRows(whatAttrSpecs, func(predRow [][]byte, predRowNo int) (bool, error) {
// Read the float values out
for i, _ := range allNumericAttrs {
predRowBuf[i] = base.UnpackBytesToFloat(predRow[i])
}
predMat := utilities.FloatsToMatrix(predRowBuf)
// Find the closest match in the training data
KNN.TrainingData.MapOverRows(trainAttrSpecs, func(trainRow [][]byte, srcRowNo int) (bool, error) {
// Read the float values out
for i, _ := range allNumericAttrs {
trainRowBuf[i] = base.UnpackBytesToFloat(trainRow[i])
}
// Compute the distance
trainMat := utilities.FloatsToMatrix(trainRowBuf)
distances[srcRowNo] = distanceFunc.Distance(predMat, trainMat)
return true, nil
})
sorted := utilities.SortIntMap(distances)
values := sorted[:KNN.NearestNeighbours]
// Reset maxMap
for a := range maxmap {
maxmap[a] = 0
}
// Refresh maxMap
for _, elem := range values {
label := base.GetClass(KNN.TrainingData, elem)
if _, ok := maxmap[label]; ok {
maxmap[label]++
} else {
maxmap[label] = 1
}
}
// Sort the maxMap
var maxClass string
maxVal := -1
for a := range maxmap {
if maxmap[a] > maxVal {
maxVal = maxmap[a]
maxClass = a
}
}
base.SetClass(ret, predRowNo, maxClass)
return true, nil
//.........这里部分代码省略.........
开发者ID:hpxro7,项目名称:golearn,代码行数:101,代码来源:knn.go
示例14: TestSimple
// TestSimple fits a BernoulliNBClassifier on test/simple_train.csv (after
// convertToBinary) and checks three things: the counts and conditional
// probabilities stored by Fit, the behaviour of PredictOne on hand-written
// documents, and the classes Predict assigns to test/simple_test.csv.
func TestSimple(t *testing.T) {
	Convey("Given a simple training data", t, func() {
		trainingData, err := base.ParseCSVToInstances("test/simple_train.csv", false)
		So(err, ShouldBeNil)

		nb := NewBernoulliNBClassifier()
		nb.Fit(convertToBinary(trainingData))

		Convey("Check if Fit is working as expected", func() {
			Convey("All data needed for prior should be correctly calculated", func() {
				So(nb.classInstances["blue"], ShouldEqual, 2)
				So(nb.classInstances["red"], ShouldEqual, 2)
				So(nb.trainingInstances, ShouldEqual, 4)
			})

			Convey("'red' conditional probabilities should be correct", func() {
				red := nb.condProb["red"]
				tok0, tok1, tok2 := red[0], red[1], red[2]
				So(tok0, ShouldAlmostEqual, 1.0)
				So(tok1, ShouldAlmostEqual, 1.0/3.0)
				So(tok2, ShouldAlmostEqual, 1.0)
			})

			Convey("'blue' conditional probabilities should be correct", func() {
				blue := nb.condProb["blue"]
				tok0, tok1, tok2 := blue[0], blue[1], blue[2]
				So(tok0, ShouldAlmostEqual, 1.0)
				So(tok1, ShouldAlmostEqual, 1.0)
				So(tok2, ShouldAlmostEqual, 1.0/3.0)
			})
		})

		Convey("PredictOne should work as expected", func() {
			Convey("Using a document with different number of cols should panic", func() {
				// Two columns instead of the three used by the other documents.
				shortDoc := [][]byte{{0}, {2}}
				So(func() { nb.PredictOne(shortDoc) }, ShouldPanic)
			})

			Convey("Token 1 should be a good predictor of the blue class", func() {
				blueDocs := [][][]byte{
					{{0}, {1}, {0}},
					{{1}, {1}, {0}},
				}
				for _, doc := range blueDocs {
					So(nb.PredictOne(doc), ShouldEqual, "blue")
				}
			})

			Convey("Token 2 should be a good predictor of the red class", func() {
				redDocs := [][][]byte{
					{{0}, {0}, {1}},
					{{1}, {0}, {1}},
				}
				for _, doc := range redDocs {
					So(nb.PredictOne(doc), ShouldEqual, "red")
				}
			})
		})

		Convey("Predict should work as expected", func() {
			testData, err := base.ParseCSVToInstances("test/simple_test.csv", false)
			So(err, ShouldBeNil)
			predictions := nb.Predict(convertToBinary(testData))

			Convey("All simple predicitions should be correct", func() {
				for row, want := range []string{"blue", "red", "blue", "red"} {
					So(base.GetClass(predictions, row), ShouldEqual, want)
				}
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:73,代码来源:bernoulli_nb_test.go
示例15: Fit
// Fit trains the Bernoulli Naive Bayes model on X. It records the number of
// training rows, the feature attributes, the number of rows per class and,
// for every class and feature, the conditional probability
//
//	(rows of the class with a non-zero feature value + 1) / (rows of the class + 1)
//
// Fit panics if any non-class attribute of X is not a BinaryAttribute or if
// X does not have exactly one class attribute.
func (nb *BernoulliNBClassifier) Fit(X base.FixedDataGrid) {
	// Check that all feature Attributes are binary
	classAttrs := X.AllClassAttributes()
	allAttrs := X.AllAttributes()
	featAttrs := base.AttributeDifference(allAttrs, classAttrs)
	for i := range featAttrs {
		if _, ok := featAttrs[i].(*base.BinaryAttribute); !ok {
			panic(fmt.Sprintf("%v: Should be BinaryAttribute", featAttrs[i]))
		}
	}
	featAttrSpecs := base.ResolveAttributes(X, featAttrs)

	// Check that only one classAttribute is defined
	if len(classAttrs) != 1 {
		panic("Only one class Attribute can be used")
	}

	// Number of features and instances in this training set
	_, nb.trainingInstances = X.Size()
	nb.attrs = featAttrs
	nb.features = len(featAttrs)

	// Number of instances in class
	nb.classInstances = make(map[string]int)

	// Number of documents with given term (by class)
	docsContainingTerm := make(map[string][]int)

	// This algorithm could be vectorized after binarizing the data
	// matrix. Since mat64 doesn't have this function, an iterative
	// version is used.
	X.MapOverRows(featAttrSpecs, func(docVector [][]byte, r int) (bool, error) {
		class := base.GetClass(X, r)

		// A missing key reads as zero, so this also registers new classes.
		nb.classInstances[class]++

		// Allocate the per-class counters as soon as the class is seen, not
		// only when a non-zero feature shows up. Otherwise a class whose
		// rows are all zero has no counters and the probability loop below
		// would index a nil slice.
		termCounts, ok := docsContainingTerm[class]
		if !ok {
			termCounts = make([]int, nb.features)
			docsContainingTerm[class] = termCounts
		}

		// In Bernoulli Naive Bayes the presence and absence of features
		// are considered. All non-zero values are treated as presence.
		for feat, v := range docVector {
			if v[0] > 0 {
				termCounts[feat]++
			}
		}
		return true, nil
	})

	// Pre-calculate conditional probabilities for each class
	for c, docsInClass := range nb.classInstances {
		classTerms := docsContainingTerm[c]
		classCondProb := make([]float64, nb.features)
		for feat := range classCondProb {
			// Smoothed estimate; see the formula in the doc comment.
			classCondProb[feat] = float64(classTerms[feat]+1) / float64(docsInClass+1)
		}
		nb.condProb[c] = classCondProb
	}
}
开发者ID:JacobXie,项目名称:golearn,代码行数:77,代码来源:bernoulli_nb.go
示例16: Predict
// Predict gathers predictions from all the classifiers
// and outputs the most common (majority) class for each row.
//
// Models are evaluated by runtime.NumCPU() worker goroutines; a single
// collector goroutine tallies the class each model predicts for each row,
// so the tally map is only ever touched by one goroutine at a time.
//
// IMPORTANT: in the event of a tie, the lexicographically smallest class
// name among the tied classes is output, which makes the result independent
// of map iteration order. Rows that received no votes are left as
// base.GeneratePredictionVector created them.
func (b *BaggedModel) Predict(from base.FixedDataGrid) base.FixedDataGrid {
	n := runtime.NumCPU()
	// Channel to receive the results as they come in
	votes := make(chan base.DataGrid, n)
	// Count the votes for each class, keyed by row and then by class
	voting := make(map[int](map[string]int))

	// Create a goroutine to collect the votes
	var votingwait sync.WaitGroup
	votingwait.Add(1)
	go func() {
		defer votingwait.Done()
		// Runs until the votes channel is closed and drained.
		for incoming := range votes {
			cSpecs := base.ResolveAttributes(incoming, incoming.AllClassAttributes())
			incoming.MapOverRows(cSpecs, func(row [][]byte, predRow int) (bool, error) {
				// Create the per-row tally the first time the row is seen
				if _, ok := voting[predRow]; !ok {
					voting[predRow] = make(map[string]int)
				}
				voting[predRow][base.GetClass(incoming, predRow)]++
				return true, nil
			})
		}
	}()

	// Create workers to process the predictions
	processpipe := make(chan int, n)
	var processwait sync.WaitGroup
	for i := 0; i < n; i++ {
		processwait.Add(1)
		go func() {
			defer processwait.Done()
			// Runs until processpipe is closed and drained.
			for modelIdx := range processpipe {
				c := b.Models[modelIdx]
				l := b.generatePredictionInstances(modelIdx, from)
				votes <- c.Predict(l)
			}
		}()
	}

	// Send all the models to the workers for prediction
	for i := range b.Models {
		processpipe <- i
	}
	close(processpipe) // Finished sending models to be predicted
	processwait.Wait() // Predictors all finished processing
	close(votes)       // Close the vote channel and allow it to drain
	votingwait.Wait()  // All the votes are in

	// Generate the overall consensus
	ret := base.GeneratePredictionVector(from)
	for row, tally := range voting {
		maxClass := ""
		maxCount := 0
		// Find the most popular class; every recorded count is at least 1,
		// so the first class seen always replaces the initial values.
		for class, count := range tally {
			if count > maxCount || (count == maxCount && class < maxClass) {
				maxClass = class
				maxCount = count
			}
		}
		base.SetClass(ret, row, maxClass)
	}
	return ret
}
开发者ID:JacobXie,项目名称:golearn,代码行数:82,代码来源:bagging.go
注：本文中的github.com/sjwhitworth/golearn/base.GetClass函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论