本文整理汇总了Golang中github.com/sjwhitworth/golearn/filters.NewChiMergeFilter函数的典型用法代码示例。如果您正苦于以下问题：Golang NewChiMergeFilter函数的具体用法是什么？Golang NewChiMergeFilter怎么用？有哪些Golang NewChiMergeFilter的使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
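在下文中一共展示了NewChiMergeFilter函数的17个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Golang代码示例。注意：这些示例来自golearn的不同版本，API并不一致——较早的版本调用AddAllNumericAttributes()、Build()和Run()，较新的版本则通过AddAttribute()与Train()训练过滤器，再用base.NewLazilyFilteredInstances包装数据。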
示例1: main
// main is a work-in-progress CART experiment. It loads the iris dataset,
// discretises every non-class float attribute with a Chi-Merge filter,
// splits the discretised data 60/40 and hands the training portion to
// findBestSplit.
//
// NOTE(review): the dataset path is an absolute path on the original
// author's machine; adjust it before running elsewhere.
func main() {
	// tree is never assigned below, so the final Println shows <nil>.
	var tree base.Classifier
	rand.Seed(44111342)

	// Load in the iris dataset.
	iris, err := base.ParseCSVToInstances("/home/kralli/go/src/github.com/sjwhitworth/golearn/examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Discretise the iris dataset with Chi-Merge.
	filt := filters.NewChiMergeFilter(iris, 0.999)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)

	// Create a 60-40 training-test split of the *discretised* data. The
	// previous version split the raw iris data, so the filter above had no
	// effect on what findBestSplit saw. The test portion is unused so far.
	trainData, _ := base.InstancesTrainTestSplit(irisf, 0.60)
	findBestSplit(trainData)

	fmt.Println(tree)
	fmt.Println(irisf)
}
开发者ID:krallistic,项目名称:go_stuff,代码行数:32,代码来源:cart_tree.go
示例2: BenchmarkBaggingRandomForestPredict
// BenchmarkBaggingRandomForestPredict measures Predict on a BaggedModel of
// ten random trees (each considering two attributes) over the Chi-Merge
// discretised iris dataset. Loading, discretisation and fitting all happen
// before the timer is reset, so only prediction is timed.
func BenchmarkBaggingRandomForestPredict(t *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatalf, not Fatal: the message carries a format verb.
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}
	rand.Seed(time.Now().UnixNano())

	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(instf)

	t.ResetTimer()
	// Iterate t.N times so the benchmark framework can calibrate the run;
	// a fixed iteration count makes the reported ns/op meaningless.
	for i := 0; i < t.N; i++ {
		rf.Predict(instf)
	}
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:25,代码来源:bagging_test.go
示例3: TestPruning
// TestPruning discretises the iris data with a Chi-Merge filter, grows a
// random tree (NewRandomTree(2)) on part of the training split, prunes it
// against the rest of the training split, and prints predictions and
// evaluation metrics for the held-out test split. It asserts nothing about
// the metrics; it only fails if loading the dataset fails.
func TestPruning(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	// Construct the discretiser from the training split only, so the test
	// rows are not among the instances the filter is built from.
	// NOTE(review): assumes Build() derives its intervals from the instances
	// passed to NewChiMergeFilter — confirm against the filters package.
	filt := filters.NewChiMergeFilter(trainData, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	fmt.Println(testData)
	filt.Run(testData)
	filt.Run(trainData)

	root := NewRandomTree(2)
	// Hold back part of the training data to use as the pruning set.
	fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
	root.Fit(fittrainData)
	root.Prune(fittestData)
	fmt.Println(root)

	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:25,代码来源:tree_test.go
示例4: main
// main compares three tree-based classifiers on the iris dataset after
// Chi-Merge discretisation: an ID3 decision tree, a random tree and a random
// forest. For each one it prints a title and the summary of its confusion
// matrix on the held-out split.
func main() {
	var tree base.Classifier
	rand.Seed(time.Now().UTC().UnixNano())

	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Chi-Merge discretisation of all numeric attributes. Run's result is not
	// used, so it is called for its effect on iris (presumably in-place).
	chiMerge := filters.NewChiMergeFilter(iris, 0.99)
	chiMerge.AddAllNumericAttributes()
	chiMerge.Build()
	chiMerge.Run(iris)

	// 60-40 training-test split: element 0 trains, element 1 tests.
	split := base.InstancesTrainTestSplit(iris, 0.60)
	trainData, testData := split[0], split[1]

	// fitAndReport trains whichever classifier tree currently holds and
	// prints its performance on the test rows under the given title.
	fitAndReport := func(title string) {
		tree.Fit(trainData)
		predictions := tree.Predict(testData)
		fmt.Println(title)
		cf := eval.GetConfusionMatrix(testData, predictions)
		fmt.Println(eval.GetSummary(cf))
	}

	// ID3; the parameter controls the train-prune split.
	tree = trees.NewID3DecisionTree(0.6)
	fitAndReport("ID3 Performance")

	// A random tree that considers two randomly-chosen attributes.
	tree = trees.NewRandomTree(2)
	fitAndReport("RandomTree Performance")

	// A random forest; NewRandomForest(100, 3) presumably means 100 trees
	// with 3 features each — confirm against the ensemble package.
	tree = ensemble.NewRandomForest(100, 3)
	fitAndReport("RandomForest Performance")
}
开发者ID:24hours,项目名称:golearn,代码行数:60,代码来源:trees.go
示例5: TestRandomForest1
// TestRandomForest1 fits a BaggedModel of ten random trees (two attributes
// each) on Chi-Merge-discretised iris training data. It prints the model,
// its predictions and evaluation metrics for the test split, and asserts
// nothing about the metrics.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}
	rand.Seed(time.Now().UnixNano())
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	// NOTE(review): the filter is constructed from the full dataset (test
	// rows included) and then applied to both splits; confirm this is intended.
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
	testDataf := base.NewLazilyFilteredInstances(testData, filt)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainDataf)
	fmt.Println(rf)

	predictions := rf.Predict(testDataf)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:Gudym,项目名称:golearn,代码行数:29,代码来源:bagging_test.go
示例6: TestRandomForest1
// TestRandomForest1 fits a BaggedModel of ten random trees (two attributes
// each) on the Chi-Merge-discretised training split (insts[0]) of the iris
// data. It prints the model, its predictions and evaluation metrics for the
// test split (insts[1]), and asserts nothing about the metrics.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}
	rand.Seed(time.Now().UnixNano())
	insts := base.InstancesTrainTestSplit(inst, 0.6)

	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(insts[1])
	filt.Run(insts[0])

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(insts[0])
	fmt.Println(rf)

	predictions := rf.Predict(insts[1])
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:24hours,项目名称:golearn,代码行数:27,代码来源:bagging_test.go
示例7: TestRandomTree
// TestRandomTree discretises the whole iris dataset with Chi-Merge, then
// builds a tree with InferID3Tree using a RandomTreeRuleGenerator limited to
// two attributes. It only prints the data and the resulting tree; it asserts
// nothing about them.
func TestRandomTree(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}

	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(inst)
	fmt.Println(inst)

	r := new(RandomTreeRuleGenerator)
	r.Attributes = 2
	root := InferID3Tree(inst, r)
	fmt.Println(root)
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:16,代码来源:tree_test.go
示例8: TestRandomTreeClassificationAfterDiscretisation
// TestRandomTreeClassificationAfterDiscretisation splits the iris data 60/40,
// trains a Chi-Merge filter over every non-class float attribute, and hands
// lazily filtered views of both splits to verifyTreeClassification, which
// performs the actual checks.
func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		train, test := base.InstancesTrainTestSplit(iris, 0.6)

		// Configure and train the discretiser on all non-class float columns.
		chiMerge := filters.NewChiMergeFilter(iris, 0.9)
		floatAttrs := base.NonClassFloatAttributes(iris)
		for i := range floatAttrs {
			chiMerge.AddAttribute(floatAttrs[i])
		}
		chiMerge.Train()

		verifyTreeClassification(
			base.NewLazilyFilteredInstances(train, chiMerge),
			base.NewLazilyFilteredInstances(test, chiMerge),
		)
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:17,代码来源:tree_test.go
示例9: TestRandomForest
// TestRandomForest exercises NewRandomForest on the iris dataset using nested
// GoConvey scopes.
//
// First branch: discretise every non-class float attribute with a Chi-Merge
// filter (parameter 0.90), split the filtered data 60/40, fit
// NewRandomForest(10, 3) on the training part, predict the test part, and
// require accuracy greater than 0.35.
//
// Second branch: a forest whose second argument is one more than the number
// of non-class attributes must make Fit return an error.
func TestRandomForest(t *testing.T) {
	Convey("Given a valid CSV file", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		Convey("When Chi-Merge filtering the data", func() {
			filt := filters.NewChiMergeFilter(inst, 0.90)
			for _, a := range base.NonClassFloatAttributes(inst) {
				filt.AddAttribute(a)
			}
			filt.Train()
			// Wrap inst with the trained filter (the name suggests filtering
			// is applied lazily on access).
			instf := base.NewLazilyFilteredInstances(inst, filt)
			Convey("Splitting the data into test and training sets", func() {
				trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
				Convey("Fitting and predicting with a Random Forest", func() {
					rf := NewRandomForest(10, 3)
					// Assigns the err declared in the outermost Convey closure.
					err = rf.Fit(trainData)
					So(err, ShouldBeNil)
					predictions, err := rf.Predict(testData)
					So(err, ShouldBeNil)
					confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
					So(err, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.35)
					})
				})
			})
		})
		// Requesting more features than there are non-class attributes must
		// be rejected by Fit.
		Convey("Fitting with a Random Forest with too many features compared to the data", func() {
			rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1)
			err = rf.Fit(inst)
			Convey("Should return an error", func() {
				So(err, ShouldNotBeNil)
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:44,代码来源:randomforest_test.go
示例10: TestRandomForest1
// TestRandomForest1 discretises the iris data with a Chi-Merge filter
// constructed from the 60% training split, fits NewRandomForest(10, 3) on
// that split, and prints predictions and a summary for the test split. It
// asserts nothing about the results.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)

	filt := filters.NewChiMergeFilter(trainData, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(testData)
	filt.Run(trainData)

	rf := NewRandomForest(10, 3)
	rf.Fit(trainData)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:19,代码来源:randomforest_test.go
示例11: BenchmarkBaggingRandomForestFit
// BenchmarkBaggingRandomForestFit measures Fit on a BaggedModel of ten random
// trees (two attributes each) over the Chi-Merge-discretised iris dataset.
// Loading and discretisation happen before the timer is reset, so only
// fitting is timed.
func BenchmarkBaggingRandomForestFit(testEnv *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Abort the benchmark cleanly instead of panicking the test binary.
		testEnv.Fatal(err)
	}
	rand.Seed(time.Now().UnixNano())

	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(inst)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}

	testEnv.ResetTimer()
	// Iterate testEnv.N times so the benchmark framework can calibrate the
	// run; a fixed iteration count makes the reported ns/op meaningless.
	for i := 0; i < testEnv.N; i++ {
		rf.Fit(inst)
	}
}
开发者ID:24hours,项目名称:golearn,代码行数:20,代码来源:bagging_test.go
示例12: TestBaggedModelRandomForest
// TestBaggedModelRandomForest checks, via nested GoConvey scopes, that a
// BaggedModel of ten random trees (NewRandomTree(2)) trained on the
// Chi-Merge-filtered 60% training split of the iris data reaches accuracy
// greater than 0.5 on the filtered test split.
func TestBaggedModelRandomForest(t *testing.T) {
	Convey("Given data", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		Convey("Splitting the data into training and test data", func() {
			trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
			Convey("Filtering the split datasets", func() {
				rand.Seed(time.Now().UnixNano())
				// The filter is constructed from the full dataset (inst), not
				// just the training split, then applied to each split.
				filt := filters.NewChiMergeFilter(inst, 0.90)
				for _, a := range base.NonClassFloatAttributes(inst) {
					filt.AddAttribute(a)
				}
				filt.Train()
				trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
				testDataf := base.NewLazilyFilteredInstances(testData, filt)
				Convey("Fitting and Predicting with a Bagged Model of 10 Random Trees", func() {
					rf := new(BaggedModel)
					for i := 0; i < 10; i++ {
						rf.AddModel(trees.NewRandomTree(2))
					}
					// NOTE(review): any value returned by Fit is discarded here.
					rf.Fit(trainDataf)
					predictions := rf.Predict(testDataf)
					confusionMat, err := evaluation.GetConfusionMatrix(testDataf, predictions)
					So(err, ShouldBeNil)
					Convey("Predictions are somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})
	})
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:38,代码来源:bagging_test.go
示例13: TestRandomForest1
// TestRandomForest1 discretises the iris data with a Chi-Merge filter,
// splits the filtered data 60/40, fits NewRandomForest(10, 3) on the
// training part, and prints predictions and a summary for the test part. It
// asserts nothing about the results.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}

	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)
	trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)

	rf := NewRandomForest(10, 3)
	rf.Fit(trainData)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:JacobXie,项目名称:golearn,代码行数:23,代码来源:randomforest_test.go
示例14: TestRandomTreeClassification2
// TestRandomTreeClassification2 discretises the iris data with Chi-Merge,
// fits a random tree (NewRandomTree(2)) on insts[0] — the training split,
// which here receives 0.4 of the rows — and prints predictions and
// evaluation metrics for insts[1]. It asserts nothing about the metrics.
func TestRandomTreeClassification2(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}
	insts := base.InstancesTrainTestSplit(inst, 0.4)

	filt := filters.NewChiMergeFilter(inst, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	fmt.Println(insts[1])
	filt.Run(insts[1])
	filt.Run(insts[0])

	root := NewRandomTree(2)
	root.Fit(insts[0])
	fmt.Println(root)

	predictions := root.Predict(insts[1])
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:24hours,项目名称:golearn,代码行数:23,代码来源:tree_test.go
示例15: TestRandomTreeClassification
// TestRandomTreeClassification discretises the iris data with a Chi-Merge
// filter constructed from the training split, builds a tree with
// InferID3Tree using a RandomTreeRuleGenerator limited to two attributes, and
// prints predictions and evaluation metrics for the test split. It asserts
// nothing about the metrics.
func TestRandomTreeClassification(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail this test cleanly instead of panicking the whole test binary.
		testEnv.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	// Construct the discretiser from the training split only, so the test
	// rows are not among the instances the filter is built from.
	// NOTE(review): assumes Build() derives its intervals from the instances
	// passed to NewChiMergeFilter — confirm against the filters package.
	filt := filters.NewChiMergeFilter(trainData, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(trainData)
	filt.Run(testData)
	fmt.Println(inst)

	r := new(RandomTreeRuleGenerator)
	r.Attributes = 2
	root := InferID3Tree(trainData, r)
	fmt.Println(root)

	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:24,代码来源:tree_test.go
示例16: main
// main loads the iris dataset and discretises every non-class float attribute
// with a Chi-Merge filter (parameter 0.999). It splits the filtered data 60/40
// and, for each of three ID3 rule generators (information gain, information
// gain ratio, Gini coefficient), fits a tree on the training part, predicts
// the test part, and prints a confusion-matrix summary. It then moves on to a
// random tree. Any error panics. The remainder of this function is not shown
// in this excerpt.
func main() {
	var tree base.Classifier
	// Fixed seed so runs are repeatable.
	rand.Seed(44111342)
	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.999)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)
	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
	//
	// First up, use ID3
	//
	tree = trees.NewID3DecisionTree(0.6)
	// (Parameter controls train-prune split.)
	// Train the ID3 tree
	err = tree.Fit(trainData)
	if err != nil {
		panic(err)
	}
	// Generate predictions
	predictions, err := tree.Predict(testData)
	if err != nil {
		panic(err)
	}
	// Evaluate
	fmt.Println("ID3 Performance (information gain)")
	cf, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(cf))
	// Same procedure, splitting on information gain ratio.
	tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
	// (Parameter controls train-prune split.)
	// Train the ID3 tree
	err = tree.Fit(trainData)
	if err != nil {
		panic(err)
	}
	// Generate predictions
	predictions, err = tree.Predict(testData)
	if err != nil {
		panic(err)
	}
	// Evaluate
	fmt.Println("ID3 Performance (information gain ratio)")
	cf, err = evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(cf))
	// Same procedure, splitting with the Gini coefficient rule generator.
	tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
	// (Parameter controls train-prune split.)
	// Train the ID3 tree
	err = tree.Fit(trainData)
	if err != nil {
		panic(err)
	}
	// Generate predictions
	predictions, err = tree.Predict(testData)
	if err != nil {
		panic(err)
	}
	// Evaluate
	fmt.Println("ID3 Performance (gini index generator)")
	cf, err = evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
	}
	fmt.Println(evaluation.GetSummary(cf))
	//
	// Next up, Random Trees
	//
	// Consider two randomly-chosen attributes
	tree = trees.NewRandomTree(2)
	//......... part of the code is omitted here .........
开发者ID:CTLife,项目名称:golearn,代码行数:101,代码来源:trees.go
示例17: TestRandomTreeClassification
// TestRandomTreeClassification checks, via nested GoConvey scopes, that trees
// trained on Chi-Merge-filtered iris data (60/40 split, filter trained on the
// full dataset) predict the filtered test split with accuracy above a
// threshold. It covers four cases:
//   - InferID3Tree with a RandomTreeRuleGenerator (2 attributes), accuracy > 0.5
//   - InferID3Tree with an InformationGainRuleGenerator, accuracy > 0.5
//   - NewRandomTree(2) without pruning, accuracy > 0.5
//   - NewRandomTree(2) after pruning, accuracy > 0.4
func TestRandomTreeClassification(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)
		filter := filters.NewChiMergeFilter(instances, 0.9)
		for _, a := range base.NonClassFloatAttributes(instances) {
			filter.AddAttribute(a)
		}
		filter.Train()
		filteredTrainData := base.NewLazilyFilteredInstances(trainData, filter)
		filteredTestData := base.NewLazilyFilteredInstances(testData, filter)
		Convey("Using InferID3Tree to create the tree and do the fitting", func() {
			Convey("Using a RandomTreeRule", func() {
				randomTreeRuleGenerator := new(RandomTreeRuleGenerator)
				randomTreeRuleGenerator.Attributes = 2
				root := InferID3Tree(filteredTrainData, randomTreeRuleGenerator)
				Convey("Predicting with the tree", func() {
					predictions, err := root.Predict(filteredTestData)
					So(err, ShouldBeNil)
					confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
					So(err, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
					})
				})
			})
			Convey("Using a InformationGainRule", func() {
				informationGainRuleGenerator := new(InformationGainRuleGenerator)
				root := InferID3Tree(filteredTrainData, informationGainRuleGenerator)
				Convey("Predicting with the tree", func() {
					predictions, err := root.Predict(filteredTestData)
					So(err, ShouldBeNil)
					confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
					So(err, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})
		Convey("Using NewRandomTree to create the tree", func() {
			root := NewRandomTree(2)
			Convey("Fitting with the tree", func() {
				// Assigns the err declared at the top of the outer Convey.
				err = root.Fit(filteredTrainData)
				So(err, ShouldBeNil)
				Convey("Predicting with the tree, *without* pruning first", func() {
					predictions, err := root.Predict(filteredTestData)
					So(err, ShouldBeNil)
					confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
					So(err, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.5)
					})
				})
				Convey("Predicting with the tree, pruning first", func() {
					// NOTE(review): the tree is pruned on the same test rows it
					// is evaluated on below, so this accuracy is not a held-out
					// estimate.
					root.Prune(filteredTestData)
					predictions, err := root.Predict(filteredTestData)
					So(err, ShouldBeNil)
					confusionMatrix, err := evaluation.GetConfusionMatrix(filteredTestData, predictions)
					So(err, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(confusionMatrix), ShouldBeGreaterThan, 0.4)
					})
				})
			})
		})
	})
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:88,代码来源:tree_test.go
注：本文中的github.com/sjwhitworth/golearn/filters.NewChiMergeFilter函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论