本文整理汇总了Golang中github.com/sjwhitworth/golearn/base.InstancesTrainTestSplit函数的典型用法代码示例。如果您正苦于以下问题：Golang InstancesTrainTestSplit函数的具体用法？Golang InstancesTrainTestSplit怎么用？Golang InstancesTrainTestSplit使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
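在下文中一共展示了InstancesTrainTestSplit函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Golang代码示例。注意：这些示例取自golearn的不同分支和版本，API并不一致——在部分旧版本中（如示例4、7、12），InstancesTrainTestSplit返回一个切片（用insts[0]、insts[1]分别取训练集和测试集），而在其他版本中它直接返回两个值；同样，GetConfusionMatrix、Predict等函数在不同版本中是否额外返回error也各不相同，复制代码时请以您所用版本的API为准。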
示例1: TestPruning
// TestPruning discretises the iris dataset with a Chi-Merge filter, fits a
// random tree on one part of the training split, prunes it against the rest of
// the training split, and prints evaluation metrics for the held-out test split.
//
// The Chi-Merge filter is built from trainData only, so the held-out rows do not
// influence the discretisation boundaries they are later evaluated with.
func TestPruning(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatal fails only this test; panic would abort the whole test binary.
		testEnv.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	filt := filters.NewChiMergeFilter(trainData, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	fmt.Println(testData)
	// Run's result is not captured, so it presumably rewrites the instances in
	// place. NOTE(review): testData is filtered before trainData (the filter's
	// source); keep this order unless Run is confirmed order-independent.
	filt.Run(testData)
	filt.Run(trainData)

	root := NewRandomTree(2)
	// Carve a pruning set out of the training data; testData stays untouched.
	fittrainData, fittestData := base.InstancesTrainTestSplit(trainData, 0.6)
	root.Fit(fittrainData)
	root.Prune(fittestData)
	fmt.Println(root)

	predictions := root.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:25,代码来源:tree_test.go
示例2: main
// main loads the iris dataset, trains a KNN classifier ("euclidean", 2) on one
// part of a 0.50 split, and prints the predictions for the other part followed
// by a precision/recall summary. It panics if loading or evaluation fails.
func main() {
	// NOTE(review): the false flag presumably means iris.csv has no header row
	// (the *_headers.csv files elsewhere are loaded with true) — confirm.
	dataset, loadErr := base.ParseCSVToInstances("datasets/iris.csv", false)
	if loadErr != nil {
		panic(loadErr)
	}
	fmt.Println(dataset)

	classifier := knn.NewKnnClassifier("euclidean", 2)
	trainSet, testSet := base.InstancesTrainTestSplit(dataset, 0.50)
	classifier.Fit(trainSet)

	predicted := classifier.Predict(testSet)
	fmt.Println(predicted)

	matrix, evalErr := evaluation.GetConfusionMatrix(testSet, predicted)
	if evalErr != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", evalErr.Error()))
	}
	fmt.Println(evaluation.GetSummary(matrix))
}
开发者ID:raghavkgarg,项目名称:gotutorial,代码行数:30,代码来源:ml1.go
示例3: TestRandomForest1
// TestRandomForest1 fits a bagged ensemble of ten random trees on a
// Chi-Merge-discretised training split of iris and prints evaluation metrics
// for the test split.
//
// The filter is trained on trainData only, so the test rows do not influence
// the discretisation bins used to evaluate them.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatal fails only this test; panic would abort the whole test binary.
		testEnv.Fatal(err)
	}
	rand.Seed(time.Now().UnixNano())
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)

	filt := filters.NewChiMergeFilter(trainData, 0.90)
	for _, a := range base.NonClassFloatAttributes(trainData) {
		filt.AddAttribute(a)
	}
	filt.Train()
	trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
	testDataf := base.NewLazilyFilteredInstances(testData, filt)

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainDataf)
	fmt.Println(rf)

	predictions := rf.Predict(testDataf)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:Gudym,项目名称:golearn,代码行数:29,代码来源:bagging_test.go
示例4: main
// main compares three tree-based classifiers (ID3, a random tree and a random
// forest) on the iris dataset and prints a performance summary for each.
//
// The data is split before discretisation and the Chi-Merge filter is built
// from the training part only, so the bin boundaries never see test rows.
func main() {
	var tree base.Classifier
	rand.Seed(time.Now().UTC().UnixNano())

	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}

	// Split with proportion 0.60; insts[0] is used for training and insts[1]
	// for testing below.
	// NOTE(review): confirm which side of the split the proportion refers to in
	// this golearn version.
	insts := base.InstancesTrainTestSplit(iris, 0.60)

	// Discretise with Chi-Merge, learning the bins from the training part only.
	// Run's result is not captured, so it presumably updates each part in place.
	filt := filters.NewChiMergeFilter(insts[0], 0.99)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(insts[1])
	filt.Run(insts[0])

	//
	// First up, use ID3
	//
	tree = trees.NewID3DecisionTree(0.6)
	// (Parameter controls train-prune split.)
	tree.Fit(insts[0])
	predictions := tree.Predict(insts[1])
	fmt.Println("ID3 Performance")
	cf := eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(eval.GetSummary(cf))

	//
	// Next up, Random Trees
	//
	// Consider two randomly-chosen attributes
	tree = trees.NewRandomTree(2)
	tree.Fit(insts[0])
	predictions = tree.Predict(insts[1])
	fmt.Println("RandomTree Performance")
	cf = eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(eval.GetSummary(cf))

	//
	// Finally, Random Forests
	//
	tree = ensemble.NewRandomForest(100, 3)
	tree.Fit(insts[0])
	predictions = tree.Predict(insts[1])
	fmt.Println("RandomForest Performance")
	cf = eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(eval.GetSummary(cf))
}
开发者ID:24hours,项目名称:golearn,代码行数:60,代码来源:trees.go
示例5: TestPredict
// TestPredict trains an AveragePerceptron on one part of the house-votes-84
// dataset (split proportion 0.5) and requires accuracy of at least 0.65 on the
// other part.
//
// Conditions that leave a nil value behind (constructor, file path, parsing,
// confusion matrix) stop the test with Fatal. Previously they called Fail or
// Errorf and carried on into a nil dereference.
func TestPredict(t *testing.T) {
	a := NewAveragePerceptron(10, 1.2, 0.5, 0.3)
	if a == nil {
		t.Fatalf("Unable to create average perceptron")
	}
	absPath, err := filepath.Abs("../examples/datasets/house-votes-84.csv")
	if err != nil {
		t.Fatal(err)
	}
	rawData, err := base.ParseCSVToInstances(absPath, true)
	if err != nil {
		t.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.5)
	a.Fit(trainData)
	if !a.trained {
		t.Errorf("Perceptron was not trained")
	}
	predictions := a.Predict(testData)
	cf, err := evaluation.GetConfusionMatrix(testData, predictions)
	if err != nil {
		t.Fatalf("Couldn't get confusion matrix: %s", err)
	}
	fmt.Println(evaluation.GetSummary(cf))
	fmt.Println(trainData)
	fmt.Println(testData)
	if evaluation.GetAccuracy(cf) < 0.65 {
		t.Errorf("Perceptron not trained correctly")
	}
}
开发者ID:CTLife,项目名称:golearn,代码行数:35,代码来源:average_test.go
示例6: main
// main loads iris from a hard-coded path, builds a Chi-Merge-filtered lazy view
// of it, runs findBestSplit on the first part of a 0.60 split of the
// unfiltered data, and prints the classifier variable and the filtered view.
//
// NOTE(review): the dataset path is an absolute path on one developer's machine;
// tree is never assigned, so it always prints as nil; and the split is taken
// from iris rather than the filtered view — confirm each is intended.
func main() {
	rand.Seed(44111342)
	var tree base.Classifier

	iris, loadErr := base.ParseCSVToInstances("/home/kralli/go/src/github.com/sjwhitworth/golearn/examples/datasets/iris_headers.csv", true)
	if loadErr != nil {
		panic(loadErr)
	}

	chiMerge := filters.NewChiMergeFilter(iris, 0.999)
	floatAttrs := base.NonClassFloatAttributes(iris)
	for idx := range floatAttrs {
		chiMerge.AddAttribute(floatAttrs[idx])
	}
	chiMerge.Train()
	filtered := base.NewLazilyFilteredInstances(iris, chiMerge)

	// Only the first half of the split is used; the test part is discarded.
	trainPart, _ := base.InstancesTrainTestSplit(iris, 0.60)
	findBestSplit(trainPart)

	fmt.Println(tree)
	fmt.Println(filtered)
}
开发者ID:krallistic,项目名称:go_stuff,代码行数:32,代码来源:cart_tree.go
示例7: TestRandomForest1
// TestRandomForest1 fits a bagged ensemble of ten random trees on the first
// part of a Chi-Merge-discretised iris split (insts[0]) and prints evaluation
// metrics for the second part (insts[1]).
//
// The filter is built from insts[0] only, so the test rows do not influence the
// discretisation boundaries they are evaluated with.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatal fails only this test; panic would abort the whole test binary.
		testEnv.Fatal(err)
	}
	rand.Seed(time.Now().UnixNano())
	insts := base.InstancesTrainTestSplit(inst, 0.6)

	filt := filters.NewChiMergeFilter(insts[0], 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	// Run's result is not captured, so it presumably updates each part in place.
	filt.Run(insts[1])
	filt.Run(insts[0])

	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(insts[0])
	fmt.Println(rf)

	predictions := rf.Predict(insts[1])
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(insts[1], predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:24hours,项目名称:golearn,代码行数:27,代码来源:bagging_test.go
示例8: TestLinearRegression
// TestLinearRegression fits a LinearRegression on one part of exams.csv (split
// proportion 0.1) and prints the expected and predicted class value for every
// row of the other part. Any load, fit or predict error stops the test.
func TestLinearRegression(t *testing.T) {
	lr := NewLinearRegression()

	rawData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
	if err != nil {
		t.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(rawData, 0.1)

	if err = lr.Fit(trainData); err != nil {
		t.Fatal(err)
	}

	predictions, err := lr.Predict(testData)
	if err != nil {
		t.Fatal(err)
	}

	_, rowCount := predictions.Size()
	for row := 0; row < rowCount; row++ {
		want := base.GetClass(testData, row)
		got := base.GetClass(predictions, row)
		fmt.Printf("Expected: %s || Predicted: %s\n", want, got)
	}
}
开发者ID:JacobXie,项目名称:golearn,代码行数:25,代码来源:linear_regression_test.go
示例9: TestLinearRegression
// TestLinearRegression exercises LinearRegression in three goconvey scenarios:
//   - predicting before any Fit must return NoTrainingDataError;
//   - fitting on exam.csv must return NotEnoughDataError;
//   - after fitting on one part of exams.csv (split proportion 0.1), every
//     prediction for the other part must be within 5% of the true value.
func TestLinearRegression(t *testing.T) {
	Convey("Doing a linear regression", t, func() {
		lr := NewLinearRegression()

		Convey("With no training data", func() {
			Convey("Predicting", func() {
				testData, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
				So(err, ShouldBeNil)
				_, err = lr.Predict(testData)
				Convey("Should result in a NoTrainingDataError", func() {
					So(err, ShouldEqual, NoTrainingDataError)
				})
			})
		})

		Convey("With not enough training data", func() {
			trainingDatum, err := base.ParseCSVToInstances("../examples/datasets/exam.csv", true)
			So(err, ShouldBeNil)
			Convey("Fitting", func() {
				err = lr.Fit(trainingDatum)
				Convey("Should result in a NotEnoughDataError", func() {
					So(err, ShouldEqual, NotEnoughDataError)
				})
			})
		})

		Convey("With sufficient training data", func() {
			instances, err := base.ParseCSVToInstances("../examples/datasets/exams.csv", true)
			So(err, ShouldBeNil)
			trainData, testData := base.InstancesTrainTestSplit(instances, 0.1)
			Convey("Fitting and Predicting", func() {
				err := lr.Fit(trainData)
				So(err, ShouldBeNil)
				predictions, err := lr.Predict(testData)
				So(err, ShouldBeNil)
				Convey("It makes reasonable predictions", func() {
					_, rows := predictions.Size()
					for i := 0; i < rows; i++ {
						// Check parse errors: if ignored, both values become 0 and the
						// tolerance becomes 0, so the assertion would pass vacuously.
						expectedValue, err := strconv.ParseFloat(base.GetClass(testData, i), 64)
						So(err, ShouldBeNil)
						predictedValue, err := strconv.ParseFloat(base.GetClass(predictions, i), 64)
						So(err, ShouldBeNil)
						// Tolerance is 5% of the true value, as before.
						So(predictedValue, ShouldAlmostEqual, expectedValue, expectedValue*0.05)
					}
				})
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:57,代码来源:linear_regression_test.go
示例10: Fit
// Fit builds the ID3 decision tree from on using t.Rule.
//
// When t.PruneSplit is greater than 0.001, on is first split with that
// proportion: the tree is inferred from the first part and then pruned against
// the second. Otherwise the tree is inferred from all of on with no pruning.
// It always returns nil.
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) error {
	if t.PruneSplit > 0.001 {
		growSet, pruneSet := base.InstancesTrainTestSplit(on, t.PruneSplit)
		t.Root = InferID3Tree(growSet, t.Rule)
		t.Root.Prune(pruneSet)
		return nil
	}
	t.Root = InferID3Tree(on, t.Rule)
	return nil
}
开发者ID:tanduong,项目名称:golearn,代码行数:11,代码来源:id3.go
示例11: TestRandomTreeClassificationWithoutDiscretisation
// TestRandomTreeClassificationWithoutDiscretisation runs verifyTreeClassification
// on a 0.6 split of the raw iris dataset, with no filter applied.
//
// The Convey description previously read "filtered data", copied from the
// after-discretisation test. That was wrong for this test and made the two
// reports indistinguishable.
func TestRandomTreeClassificationWithoutDiscretisation(t *testing.T) {
	Convey("Predictions on unfiltered data with a Random Tree", t, func() {
		instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		trainData, testData := base.InstancesTrainTestSplit(instances, 0.6)
		verifyTreeClassification(trainData, testData)
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:10,代码来源:tree_test.go
示例12: Fit
// Fit builds the ID3 decision tree from on using an InformationGainRuleGenerator.
//
// When t.PruneSplit is greater than 0.001, on is split with that proportion
// (this API returns the two parts as an indexable pair); the tree is inferred
// from the first part and pruned against the second. Otherwise the tree is
// inferred from all of on with no pruning.
func (t *ID3DecisionTree) Fit(on *base.Instances) {
	generator := new(InformationGainRuleGenerator)
	if t.PruneSplit > 0.001 {
		parts := base.InstancesTrainTestSplit(on, t.PruneSplit)
		growSet, pruneSet := parts[0], parts[1]
		t.Root = InferID3Tree(growSet, generator)
		t.Root.Prune(pruneSet)
		return
	}
	t.Root = InferID3Tree(on, generator)
}
开发者ID:24hours,项目名称:golearn,代码行数:11,代码来源:id3.go
示例13: Fit
// Fit builds the ID3 decision tree from on using an InformationGainRuleGenerator.
//
// When t.PruneSplit is greater than 0.001, on is first split with that
// proportion: the tree is inferred from the first part and pruned against the
// second. Otherwise the tree is inferred from all of on with no pruning.
func (t *ID3DecisionTree) Fit(on base.FixedDataGrid) {
	generator := new(InformationGainRuleGenerator)
	if t.PruneSplit > 0.001 {
		growSet, pruneSet := base.InstancesTrainTestSplit(on, t.PruneSplit)
		t.Root = InferID3Tree(growSet, generator)
		t.Root.Prune(pruneSet)
		return
	}
	t.Root = InferID3Tree(on, generator)
}
开发者ID:JacobXie,项目名称:golearn,代码行数:11,代码来源:id3.go
示例14: BenchmarkFit
// BenchmarkFit measures AveragePerceptron.Fit on one part of the house-votes-84
// dataset (split proportion 0.5). Setup runs before ResetTimer and is not
// timed.
//
// Setup errors now stop the benchmark with Fatal. Previously they were ignored,
// so a missing or unparsable file led to a nil dereference in the split.
func BenchmarkFit(b *testing.B) {
	a := NewAveragePerceptron(10, 1.2, 0.5, 0.3)
	absPath, err := filepath.Abs("../examples/datasets/house-votes-84.csv")
	if err != nil {
		b.Fatal(err)
	}
	rawData, err := base.ParseCSVToInstances(absPath, true)
	if err != nil {
		b.Fatal(err)
	}
	trainData, _ := base.InstancesTrainTestSplit(rawData, 0.5)
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		a.Fit(trainData)
	}
}
开发者ID:CTLife,项目名称:golearn,代码行数:11,代码来源:average_test.go
示例15: main
// main trains a KNN classifier ("euclidean", 2) on one part of a 0.8 split of
// iris_headers.csv and prints the predictions for the other part followed by an
// evaluation summary. It panics if the file cannot be parsed.
func main() {
	dataset, loadErr := base.ParseCSVToInstances("iris_headers.csv", true)
	if loadErr != nil {
		panic(loadErr)
	}

	classifier := knn.NewKnnClassifier("euclidean", 2)
	trainSet, testSet := base.InstancesTrainTestSplit(dataset, 0.8)
	classifier.Fit(trainSet)

	predicted := classifier.Predict(testSet)
	fmt.Println(predicted)
	fmt.Println(evaluation.GetSummary(evaluation.GetConfusionMatrix(testSet, predicted)))
}
开发者ID:vkarthi46,项目名称:ml-algorithms-simple,代码行数:17,代码来源:golearn_sample.go
示例16: TestRandomTreeClassificationAfterDiscretisation
// TestRandomTreeClassificationAfterDiscretisation runs verifyTreeClassification
// on a 0.6 split of iris after every non-class float attribute has been
// discretised by a Chi-Merge filter (parameter 0.9). The filter is applied to
// both parts through lazily filtered views.
//
// NOTE(review): the filter is trained on the full dataset rather than trainData,
// so test rows influence the bin boundaries. Consider training it on trainData,
// after checking the accuracy thresholds in verifyTreeClassification.
func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		dataset, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		trainData, testData := base.InstancesTrainTestSplit(dataset, 0.6)

		chiMerge := filters.NewChiMergeFilter(dataset, 0.9)
		floatAttrs := base.NonClassFloatAttributes(dataset)
		for idx := range floatAttrs {
			chiMerge.AddAttribute(floatAttrs[idx])
		}
		chiMerge.Train()

		verifyTreeClassification(
			base.NewLazilyFilteredInstances(trainData, chiMerge),
			base.NewLazilyFilteredInstances(testData, chiMerge),
		)
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:17,代码来源:tree_test.go
示例17: TestMultiSVMUnweighted
// TestMultiSVMUnweighted fits a MultiLinearSVC ("l1", "l2", nil weights) on one
// part of a 0.4 split of articles.csv and requires accuracy above 0.70 on the
// other part.
//
// The error from Predict is now asserted. Previously it was overwritten by the
// GetConfusionMatrix call without ever being checked.
func TestMultiSVMUnweighted(t *testing.T) {
	Convey("Loading data...", t, func() {
		inst, err := base.ParseCSVToInstances("../examples/datasets/articles.csv", false)
		So(err, ShouldBeNil)
		trainData, testData := base.InstancesTrainTestSplit(inst, 0.4)
		m := NewMultiLinearSVC("l1", "l2", true, 1.0, 1e-4, nil)
		m.Fit(trainData)
		Convey("Predictions should work...", func() {
			predictions, err := m.Predict(testData)
			So(err, ShouldBeNil)
			cf, err := evaluation.GetConfusionMatrix(testData, predictions)
			So(err, ShouldBeNil)
			So(evaluation.GetAccuracy(cf), ShouldBeGreaterThan, 0.70)
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:17,代码来源:multisvc_test.go
示例18: NewTestTrial
// NewTestTrial runs one KNN trial ("euclidean", 2) on the data that
// CSVtoKNNData loads from filename. The data is split with proportion split,
// the classifier is fitted on the first part, and the predictions and an
// evaluation summary for the second part are printed. It panics if the
// confusion matrix cannot be computed; otherwise it always returns true.
func NewTestTrial(filename string, split float64) bool {
	classifier := knn.NewKnnClassifier("euclidean", 2)
	dataset := CSVtoKNNData(filename)
	trainSet, testSet := base.InstancesTrainTestSplit(dataset, split)
	classifier.Fit(trainSet)

	predicted := classifier.Predict(testSet)
	fmt.Println(predicted)

	matrix, evalErr := evaluation.GetConfusionMatrix(testSet, predicted)
	if evalErr != nil {
		panic(fmt.Sprintf("Unable to get confusion matrix: %s", evalErr.Error()))
	}
	fmt.Println(evaluation.GetSummary(matrix))
	return true
}
开发者ID:postfix,项目名称:education,代码行数:18,代码来源:knn.go
示例19: TestRandomForest
// TestRandomForest checks RandomForest on the iris dataset using nested
// goconvey scenarios:
//   - after Chi-Merge filtering and a 0.60 split, NewRandomForest(10, 3) must
//     fit and predict without error and reach accuracy above 0.35;
//   - requesting one more feature than the dataset has non-class attributes
//     must make Fit return an error.
//
// NOTE(review): the Chi-Merge filter is trained on the whole dataset (inst)
// before the train/test split, so test rows influence the discretisation
// boundaries; consider splitting first and training the filter on the
// training part only.
func TestRandomForest(t *testing.T) {
Convey("Given a valid CSV file", t, func() {
inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
So(err, ShouldBeNil)
Convey("When Chi-Merge filtering the data", func() {
// Register every non-class float attribute with the filter, train it on
// inst, then wrap inst in a filtered view.
filt := filters.NewChiMergeFilter(inst, 0.90)
for _, a := range base.NonClassFloatAttributes(inst) {
filt.AddAttribute(a)
}
filt.Train()
instf := base.NewLazilyFilteredInstances(inst, filt)
Convey("Splitting the data into test and training sets", func() {
trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
Convey("Fitting and predicting with a Random Forest", func() {
rf := NewRandomForest(10, 3)
// Assigns to the err declared by ParseCSVToInstances above.
err = rf.Fit(trainData)
So(err, ShouldBeNil)
predictions, err := rf.Predict(testData)
So(err, ShouldBeNil)
confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
So(err, ShouldBeNil)
Convey("Predictions should be somewhat accurate", func() {
So(evaluation.GetAccuracy(confusionMat), ShouldBeGreaterThan, 0.35)
})
})
})
})
Convey("Fitting with a Random Forest with too many features compared to the data", func() {
// One more than the number of non-class attributes in inst.
rf := NewRandomForest(10, len(base.NonClassAttributes(inst))+1)
err = rf.Fit(inst)
Convey("Should return an error", func() {
So(err, ShouldNotBeNil)
})
})
})
}
开发者ID:CTLife,项目名称:golearn,代码行数:44,代码来源:randomforest_test.go
示例20: TestRandomForest1
// TestRandomForest1 discretises iris with a Chi-Merge filter built from the
// training split, fits NewRandomForest(10, 3) on the training split, and prints
// predictions and evaluation metrics for the test split.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatal fails only this test; panic would abort the whole test binary.
		testEnv.Fatal(err)
	}
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.60)

	// The filter learns its bins from trainData only. Run's result is not
	// captured, so it presumably rewrites each split in place.
	filt := filters.NewChiMergeFilter(trainData, 0.90)
	filt.AddAllNumericAttributes()
	filt.Build()
	filt.Run(testData)
	filt.Run(trainData)

	rf := NewRandomForest(10, 3)
	rf.Fit(trainData)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:hsinhoyeh,项目名称:golearn,代码行数:19,代码来源:randomforest_test.go
注：本文中的github.com/sjwhitworth/golearn/base.InstancesTrainTestSplit函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论