本文整理汇总了Golang中github.com/sjwhitworth/golearn/base.NewLazilyFilteredInstances函数的典型用法代码示例。如果您正苦于以下问题：Golang NewLazilyFilteredInstances函数的具体用法？Golang NewLazilyFilteredInstances怎么用？Golang NewLazilyFilteredInstances使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了NewLazilyFilteredInstances函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Golang代码示例。
示例1: TestRandomForest1
// TestRandomForest1 splits the iris dataset 60/40, discretises both halves
// with a Chi-Merge filter (significance 0.90), fits a bag of ten random trees
// (two candidate Attributes each) on the training half and prints the model,
// its predictions, the confusion matrix and summary metrics for the test half.
//
// NOTE(review): the Chi-Merge filter is trained on inst, the full dataset,
// so the test rows influence the discretisation used to evaluate them.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail only this test instead of panicking the whole test binary.
		testEnv.Fatalf("unable to parse iris dataset: %v", err)
	}
	rand.Seed(time.Now().UnixNano())
	trainData, testData := base.InstancesTrainTestSplit(inst, 0.6)
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	trainDataf := base.NewLazilyFilteredInstances(trainData, filt)
	testDataf := base.NewLazilyFilteredInstances(testData, filt)
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(trainDataf)
	fmt.Println(rf)
	predictions := rf.Predict(testDataf)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testDataf, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetMacroPrecision(confusionMat))
	fmt.Println(eval.GetMacroRecall(confusionMat))
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:Gudym,项目名称:golearn,代码行数:29,代码来源:bagging_test.go
示例2: TestChiMergeFilter
// TestChiMergeFilter discretises the first two iris Attributes with a
// Chi-Merge filter (significance 0.90) and checks that the filtered view
// exposes exactly one class Attribute, named "Species".
func TestChiMergeFilter(t *testing.T) {
	Convey("Chi-Merge Filter", t, func() {
		// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
		// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
		instances, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		Convey("Create and train the filter", func() {
			filter := NewChiMergeFilter(instances, 0.90)
			filter.AddAttribute(instances.AllAttributes()[0])
			filter.AddAttribute(instances.AllAttributes()[1])
			filter.Train()
			Convey("Filter the dataset", func() {
				filteredInstances := base.NewLazilyFilteredInstances(instances, filter)
				classAttributes := filteredInstances.AllClassAttributes()
				Convey("There should only be one class attribute", func() {
					So(len(classAttributes), ShouldEqual, 1)
				})
				expectedClassAttribute := "Species"
				Convey(fmt.Sprintf("The class attribute should be %s", expectedClassAttribute), func() {
					// Guard the index below. The sibling length check does not
					// stop this branch, so an empty result would otherwise panic.
					So(len(classAttributes), ShouldBeGreaterThan, 0)
					So(classAttributes[0].GetName(), ShouldEqual, expectedClassAttribute)
				})
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:31,代码来源:chimerge_test.go
示例3: main
// main loads the iris dataset and builds a Chi-Merge-discretised view of it
// (significance 0.999). It then runs findBestSplit on a 60% training sample
// of the raw, undiscretised data and finally prints the discretised view.
//
// NOTE(review): the dataset path is an absolute path on the original author's
// machine; adjust it (or make it configurable) before running elsewhere.
func main() {
	rand.Seed(44111342)
	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("/home/kralli/go/src/github.com/sjwhitworth/golearn/examples/datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.999)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)
	// Create a 60-40 training-test split from the raw (continuous) data; only
	// the training portion is used here, and irisf is not involved.
	trainData, _ := base.InstancesTrainTestSplit(iris, 0.60)
	findBestSplit(trainData)
	fmt.Println(irisf)
}
开发者ID:krallistic,项目名称:go_stuff,代码行数:32,代码来源:cart_tree.go
示例4: Predict
// Predict issues predictions. Each class-specific classifier is expected
// to output a value between 0 (indicating that a given instance is not
// a given class) and 1 (indicating that the given instance is definitely
// that class). For each instance, the class with the highest value is chosen.
// The result is undefined if several underlying models output the same value.
//
// If any underlying classifier's Predict fails, that error is returned
// immediately together with a nil grid.
func (m *OneVsAllModel) Predict(what base.FixedDataGrid) (base.FixedDataGrid, error) {
	ret := base.GeneratePredictionVector(what)
	vecs := make([]base.FixedDataGrid, m.maxClassVal+1)
	specs := make([]base.AttributeSpec, m.maxClassVal+1)
	// Run each one-vs-all classifier on its own filtered view of the input.
	for i := uint64(0); i <= m.maxClassVal; i++ {
		c := base.NewLazilyFilteredInstances(what, m.filters[i])
		p, err := m.classifiers[i].Predict(c)
		if err != nil {
			return nil, err
		}
		vecs[i] = p
		specs[i] = base.ResolveAttributes(p, p.AllClassAttributes())[0]
	}
	_, rows := ret.Size()
	spec := base.ResolveAttributes(ret, ret.AllClassAttributes())[0]
	for i := 0; i < rows; i++ {
		// Seed the arg-max with class 0's own score rather than a fixed 0.0.
		// That way the true maximum is still chosen if a classifier emits
		// values below 0. Strict '>' keeps the lowest class index on ties.
		class := uint64(0)
		best := base.UnpackBytesToFloat(vecs[0].Get(specs[0], i))
		for j := uint64(1); j <= m.maxClassVal; j++ {
			if val := base.UnpackBytesToFloat(vecs[j].Get(specs[j], i)); val > best {
				class, best = j, val
			}
		}
		ret.Set(spec, i, base.PackU64ToBytes(class))
	}
	return ret, nil
}
开发者ID:CTLife,项目名称:golearn,代码行数:35,代码来源:one_v_all.go
示例5: TestBinaryFilterClassPreservation
// TestBinaryFilterClassPreservation binarises every Attribute of the contrived
// dataset, the class Attribute included. It then verifies that the class
// Attributes derived from "arbitraryClass" are still reported as class
// Attributes by the filtered view.
func TestBinaryFilterClassPreservation(t *testing.T) {
	Convey("Given a contrived dataset...", t, func() {
		// Load the contrived dataset
		inst, err := base.ParseCSVToInstances("./binary_test.csv", true)
		So(err, ShouldEqual, nil)

		// Register every Attribute (class Attribute included) with the filter
		binFilter := NewBinaryConvertFilter()
		for _, attr := range inst.AllAttributes() {
			binFilter.AddAttribute(attr)
		}
		binFilter.Train()

		// Wrap the dataset in a lazily-filtered view
		instF := base.NewLazilyFilteredInstances(inst, binFilter)

		Convey("All the expected class Attributes should be present if discretised...", func() {
			seen := make(map[string]bool)
			for _, attr := range instF.AllClassAttributes() {
				seen[attr.GetName()] = true
			}
			for _, want := range []string{"arbitraryClass_hi", "arbitraryClass_there", "arbitraryClass_world"} {
				So(seen[want], ShouldEqual, true)
			}
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:33,代码来源:binary_test.go
示例6: BenchmarkBaggingRandomForestPredict
// BenchmarkBaggingRandomForestPredict measures BaggedModel.Predict for a bag
// of ten random trees (two candidate Attributes each) over the whole
// Chi-Merge-discretised iris dataset. Loading, filtering and fitting happen
// before the timer is reset and are not measured.
func BenchmarkBaggingRandomForestPredict(t *testing.B) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fatal does not interpret format verbs; Fatalf does.
		t.Fatalf("Unable to parse CSV to instances: %s", err.Error())
	}
	rand.Seed(time.Now().UnixNano())
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)
	rf := new(BaggedModel)
	for i := 0; i < 10; i++ {
		rf.AddModel(trees.NewRandomTree(2))
	}
	rf.Fit(instf)
	t.ResetTimer()
	// Run exactly t.N iterations so the testing framework can calibrate the
	// measurement. A fixed count makes the reported ns/op meaningless.
	for i := 0; i < t.N; i++ {
		rf.Predict(instf)
	}
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:25,代码来源:bagging_test.go
示例7: TestRandomTreeClassificationAfterDiscretisation
// TestRandomTreeClassificationAfterDiscretisation splits iris 60/40 and
// discretises both halves with a Chi-Merge filter (significance 0.9) fitted
// on the full dataset. The train/predict checks are delegated to
// verifyTreeClassification.
func TestRandomTreeClassificationAfterDiscretisation(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		train, test := base.InstancesTrainTestSplit(iris, 0.6)

		chiMerge := filters.NewChiMergeFilter(iris, 0.9)
		for _, attr := range base.NonClassFloatAttributes(iris) {
			chiMerge.AddAttribute(attr)
		}
		chiMerge.Train()

		verifyTreeClassification(
			base.NewLazilyFilteredInstances(train, chiMerge),
			base.NewLazilyFilteredInstances(test, chiMerge),
		)
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:17,代码来源:tree_test.go
示例8: convertToFloatInsts
// convertToFloatInsts returns a lazily-filtered view of X. Every Attribute
// reported by X.AllAttributes() is registered with a FloatConvertFilter, so
// all of them are exposed as FloatAttributes.
func (m *MultiLayerNet) convertToFloatInsts(X base.FixedDataGrid) base.FixedDataGrid {
	toFloat := filters.NewFloatConvertFilter()
	for _, attr := range X.AllAttributes() {
		toFloat.AddAttribute(attr)
	}
	toFloat.Train()
	return base.NewLazilyFilteredInstances(X, toFloat)
}
开发者ID:nickpoorman,项目名称:golearn,代码行数:11,代码来源:layered.go
示例9: convertToBinary
// convertToBinary wraps src in a lazily-filtered view whose non-class
// Attributes go through a BinaryConvertFilter. Class Attributes are not
// registered with the filter.
func convertToBinary(src base.FixedDataGrid) base.FixedDataGrid {
	binFilter := filters.NewBinaryConvertFilter()
	for _, attr := range base.NonClassAttributes(src) {
		binFilter.AddAttribute(attr)
	}
	binFilter.Train()
	return base.NewLazilyFilteredInstances(src, binFilter)
}
开发者ID:CTLife,项目名称:golearn,代码行数:11,代码来源:bernoulli_nb_test.go
示例10: TestBaggedModelRandomForest
// TestBaggedModelRandomForest fits a bag of ten random trees (two candidate
// Attributes each) on Chi-Merge-discretised iris training data. It requires
// accuracy above 0.5 on the discretised test split.
func TestBaggedModelRandomForest(t *testing.T) {
	Convey("Given data", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		Convey("Splitting the data into training and test data", func() {
			train, test := base.InstancesTrainTestSplit(iris, 0.6)
			Convey("Filtering the split datasets", func() {
				rand.Seed(time.Now().UnixNano())
				chiMerge := filters.NewChiMergeFilter(iris, 0.90)
				for _, attr := range base.NonClassFloatAttributes(iris) {
					chiMerge.AddAttribute(attr)
				}
				chiMerge.Train()
				trainF := base.NewLazilyFilteredInstances(train, chiMerge)
				testF := base.NewLazilyFilteredInstances(test, chiMerge)
				Convey("Fitting and Predicting with a Bagged Model of 10 Random Trees", func() {
					bag := new(BaggedModel)
					for n := 0; n < 10; n++ {
						bag.AddModel(trees.NewRandomTree(2))
					}
					bag.Fit(trainF)
					preds := bag.Predict(testF)
					cm, cmErr := evaluation.GetConfusionMatrix(testF, preds)
					So(cmErr, ShouldBeNil)
					Convey("Predictions are somewhat accurate", func() {
						So(evaluation.GetAccuracy(cm), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})
	})
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:38,代码来源:bagging_test.go
示例11: TestBinning
// TestBinning bins the first iris Attribute into 10 bins with a binning
// filter. It then checks, row by row, that the binned string values match
// the reference dataset iris_binned.csv.
func TestBinning(t *testing.T) {
	Convey("Given some data and a reference", t, func() {
		// Read the data. Parse failures are reported through So so they show
		// up as test failures instead of panics.
		inst1, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)
		inst2, err := base.ParseCSVToInstances("../examples/datasets/iris_binned.csv", true)
		So(err, ShouldBeNil)
		//
		// Construct the binning filter and the LazilyFilteredInstances view
		binAttr := inst1.AllAttributes()[0]
		filt := NewBinningFilter(inst1, 10)
		filt.AddAttribute(binAttr)
		filt.Train()
		inst1f := base.NewLazilyFilteredInstances(inst1, filt)
		// Retrieve the categorical version of the original Attribute
		var cAttr base.Attribute
		for _, a := range inst1f.AllAttributes() {
			if a.GetName() == binAttr.GetName() {
				cAttr = a
			}
		}
		// Fail clearly if the filtered view lost the Attribute, rather than
		// passing a nil Attribute to GetAttribute below.
		So(cAttr, ShouldNotBeNil)
		cAttrSpec, err := inst1f.GetAttribute(cAttr)
		So(err, ShouldBeNil)
		binAttrSpec, err := inst2.GetAttribute(binAttr)
		So(err, ShouldBeNil)
		//
		// Check the binned values against the reference
		Convey("Discretized version should match reference", func() {
			_, rows := inst1.Size()
			for i := 0; i < rows; i++ {
				val1 := inst1f.Get(cAttrSpec, i)
				val2 := inst2.Get(binAttrSpec, i)
				val1s := cAttr.GetStringFromSysVal(val1)
				val2s := binAttr.GetStringFromSysVal(val2)
				So(val1s, ShouldEqual, val2s)
			}
		})
	})
}
开发者ID:JacobXie,项目名称:golearn,代码行数:48,代码来源:binning_test.go
示例12: TestRandomForest
// TestRandomForest discretises iris with Chi-Merge and trains a 10-tree random
// forest (3 features per tree) on a 60% split, expecting accuracy above 0.35
// on the rest. It also checks that requesting more features than there are
// non-class Attributes makes Fit return an error.
func TestRandomForest(t *testing.T) {
	Convey("Given a valid CSV file", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		Convey("When Chi-Merge filtering the data", func() {
			chiMerge := filters.NewChiMergeFilter(iris, 0.90)
			for _, attr := range base.NonClassFloatAttributes(iris) {
				chiMerge.AddAttribute(attr)
			}
			chiMerge.Train()
			irisF := base.NewLazilyFilteredInstances(iris, chiMerge)

			Convey("Splitting the data into test and training sets", func() {
				train, test := base.InstancesTrainTestSplit(irisF, 0.60)

				Convey("Fitting and predicting with a Random Forest", func() {
					forest := NewRandomForest(10, 3)
					So(forest.Fit(train), ShouldBeNil)
					preds, predErr := forest.Predict(test)
					So(predErr, ShouldBeNil)
					cm, cmErr := evaluation.GetConfusionMatrix(test, preds)
					So(cmErr, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(cm), ShouldBeGreaterThan, 0.35)
					})
				})
			})
		})

		Convey("Fitting with a Random Forest with too many features compared to the data", func() {
			tooMany := len(base.NonClassAttributes(iris)) + 1
			fitErr := NewRandomForest(10, tooMany).Fit(iris)
			Convey("Should return an error", func() {
				So(fitErr, ShouldNotBeNil)
			})
		})
	})
}
开发者ID:CTLife,项目名称:golearn,代码行数:44,代码来源:randomforest_test.go
示例13: Fit
// Fit creates n filtered datasets (where n is the number of values
// a CategoricalAttribute can take) and uses them to train the
// underlying classifiers.
//
// Fit panics in three cases: a class Attribute is not a
// *base.CategoricalAttribute, there is no class Attribute at all, or the
// highest system value of the class Attribute is 0 (fewer than two classes).
// When several class Attributes are present, the last one returned by
// AllClassAttributes is used.
func (m *OneVsAllModel) Fit(using base.FixedDataGrid) {
	var classAttr *base.CategoricalAttribute
	// Do some validation
	for _, a := range using.AllClassAttributes() {
		c, ok := a.(*base.CategoricalAttribute)
		if !ok {
			panic("Unsupported ClassAttribute type")
		}
		classAttr = c
	}
	if classAttr == nil {
		// Without this check, classAttr.GetValues below would be called on a
		// nil pointer and fail with an unhelpful message.
		panic("No class Attribute found")
	}
	attrs := m.generateAttributes(using)
	// Find the highest stored value
	val := uint64(0)
	classVals := classAttr.GetValues()
	for _, s := range classVals {
		cur := base.UnpackBytesToU64(classAttr.GetSysValFromString(s))
		if cur > val {
			val = cur
		}
	}
	if val == 0 {
		panic("Must have more than one class!")
	}
	m.maxClassVal = val
	// Create individual filtered instances for training
	filters := make([]*oneVsAllFilter, val+1)
	classifiers := make([]base.Classifier, val+1)
	for i := uint64(0); i <= val; i++ {
		f := &oneVsAllFilter{
			attrs,
			classAttr,
			i,
		}
		filters[i] = f
		// NOTE(review): indexing classVals by the system value assumes the
		// system values are exactly 0..len(classVals)-1 in GetValues order.
		classifiers[i] = m.NewClassifierFunction(classVals[int(i)])
		classifiers[i].Fit(base.NewLazilyFilteredInstances(using, f))
	}
	m.filters = filters
	m.classifiers = classifiers
}
开发者ID:CTLife,项目名称:golearn,代码行数:47,代码来源:one_v_all.go
示例14: TestChiMerge4
// TestChiMerge4 discretises the first two iris Attributes with Chi-Merge
// (significance 0.90) and prints the filtered view. It then checks that
// exactly one class Attribute, "Species", survives filtering.
func TestChiMerge4(testEnv *testing.T) {
	// See http://sci2s.ugr.es/keel/pdf/algorithm/congreso/1992-Kerber-ChimErge-AAAI92.pdf
	// Randy Kerber, ChiMerge: Discretisation of Numeric Attributes, 1992
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail only this test instead of panicking the whole test binary.
		testEnv.Fatalf("unable to parse iris dataset: %v", err)
	}
	filt := NewChiMergeFilter(inst, 0.90)
	filt.AddAttribute(inst.AllAttributes()[0])
	filt.AddAttribute(inst.AllAttributes()[1])
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)
	// Println uses instf's String method, so one print is enough.
	fmt.Println(instf)
	clsAttrs := instf.AllClassAttributes()
	if len(clsAttrs) != 1 {
		testEnv.Fatalf("expected 1 class Attribute, got %d", len(clsAttrs))
	}
	if name := clsAttrs[0].GetName(); name != "Species" {
		testEnv.Fatalf("class Attribute wrong: got %q, want %q", name, "Species")
	}
}
开发者ID:Gudym,项目名称:golearn,代码行数:23,代码来源:chimerge_test.go
示例15: TestRandomForest1
// TestRandomForest1 discretises iris with Chi-Merge (significance 0.90) and
// splits the filtered view 60/40. It trains a 10-tree random forest
// (3 features per tree) on the training part and prints the predictions,
// confusion matrix and summary for the test part.
func TestRandomForest1(testEnv *testing.T) {
	inst, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
	if err != nil {
		// Fail only this test instead of panicking the whole test binary.
		testEnv.Fatalf("unable to parse iris dataset: %v", err)
	}
	filt := filters.NewChiMergeFilter(inst, 0.90)
	for _, a := range base.NonClassFloatAttributes(inst) {
		filt.AddAttribute(a)
	}
	filt.Train()
	instf := base.NewLazilyFilteredInstances(inst, filt)
	trainData, testData := base.InstancesTrainTestSplit(instf, 0.60)
	rf := NewRandomForest(10, 3)
	rf.Fit(trainData)
	predictions := rf.Predict(testData)
	fmt.Println(predictions)
	confusionMat := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(confusionMat)
	fmt.Println(eval.GetSummary(confusionMat))
}
开发者ID:JacobXie,项目名称:golearn,代码行数:23,代码来源:randomforest_test.go
示例16: TestRandomTreeClassification
// TestRandomTreeClassification discretises a 60/40 iris split with Chi-Merge
// and checks test-set accuracy for trees built three ways:
//   - InferID3Tree with a RandomTreeRuleGenerator (2 Attributes),
//   - InferID3Tree with an InformationGainRuleGenerator,
//   - NewRandomTree(2), both with and without a prior Prune pass.
//
// NOTE(review): the pruning case prunes on the same data it then evaluates on,
// so its accuracy figure is optimistic.
func TestRandomTreeClassification(t *testing.T) {
	Convey("Predictions on filtered data with a Random Tree", t, func() {
		iris, err := base.ParseCSVToInstances("../examples/datasets/iris_headers.csv", true)
		So(err, ShouldBeNil)

		train, test := base.InstancesTrainTestSplit(iris, 0.6)
		chiMerge := filters.NewChiMergeFilter(iris, 0.9)
		for _, attr := range base.NonClassFloatAttributes(iris) {
			chiMerge.AddAttribute(attr)
		}
		chiMerge.Train()
		trainF := base.NewLazilyFilteredInstances(train, chiMerge)
		testF := base.NewLazilyFilteredInstances(test, chiMerge)

		Convey("Using InferID3Tree to create the tree and do the fitting", func() {
			Convey("Using a RandomTreeRule", func() {
				ruleGen := new(RandomTreeRuleGenerator)
				ruleGen.Attributes = 2
				node := InferID3Tree(trainF, ruleGen)
				Convey("Predicting with the tree", func() {
					preds, predErr := node.Predict(testF)
					So(predErr, ShouldBeNil)
					cm, cmErr := evaluation.GetConfusionMatrix(testF, preds)
					So(cmErr, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(cm), ShouldBeGreaterThan, 0.5)
					})
				})
			})
			Convey("Using a InformationGainRule", func() {
				node := InferID3Tree(trainF, new(InformationGainRuleGenerator))
				Convey("Predicting with the tree", func() {
					preds, predErr := node.Predict(testF)
					So(predErr, ShouldBeNil)
					cm, cmErr := evaluation.GetConfusionMatrix(testF, preds)
					So(cmErr, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(cm), ShouldBeGreaterThan, 0.5)
					})
				})
			})
		})

		Convey("Using NewRandomTree to create the tree", func() {
			tree := NewRandomTree(2)
			Convey("Fitting with the tree", func() {
				So(tree.Fit(trainF), ShouldBeNil)
				Convey("Predicting with the tree, *without* pruning first", func() {
					preds, predErr := tree.Predict(testF)
					So(predErr, ShouldBeNil)
					cm, cmErr := evaluation.GetConfusionMatrix(testF, preds)
					So(cmErr, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(cm), ShouldBeGreaterThan, 0.5)
					})
				})
				Convey("Predicting with the tree, pruning first", func() {
					tree.Prune(testF)
					preds, predErr := tree.Predict(testF)
					So(predErr, ShouldBeNil)
					cm, cmErr := evaluation.GetConfusionMatrix(testF, preds)
					So(cmErr, ShouldBeNil)
					Convey("Predictions should be somewhat accurate", func() {
						So(evaluation.GetAccuracy(cm), ShouldBeGreaterThan, 0.4)
					})
				})
			})
		})
	})
}
开发者ID:GeekFreaker,项目名称:golearn,代码行数:88,代码来源:tree_test.go
示例17: TestFloatFilter
// TestFloatFilter runs every non-class Attribute of the contrived
// binary_test.csv dataset through a FloatConvertFilter. It checks that the
// filtered view exposes only FloatAttributes for non-class data, keeps the
// original class Attributes (with "arbitraryClass" still categorical), uses
// the expected Attribute names, and holds the expected converted values.
//
// NOTE(review): this listing is truncated by the source page (see the marker
// at the end); the remaining assertions and closing braces are not shown.
func TestFloatFilter(t *testing.T) {
Convey("Given a contrived dataset...", t, func() {
// Read the contrived dataset
inst, err := base.ParseCSVToInstances("./binary_test.csv", true)
So(err, ShouldEqual, nil)
// Add Attributes to the filter
bFilt := NewFloatConvertFilter()
bAttrs := base.NonClassAttributes(inst)
for _, a := range bAttrs {
bFilt.AddAttribute(a)
}
bFilt.Train()
// Construct a LazilyFilteredInstances to handle it
instF := base.NewLazilyFilteredInstances(inst, bFilt)
Convey("All the non-class Attributes should be floats...", func() {
// Check that all the Attributes are the right type
for _, a := range base.NonClassAttributes(instF) {
_, ok := a.(*base.FloatAttribute)
So(ok, ShouldEqual, true)
}
})
// Check that all the class Attributes made it
Convey("All the class Attributes should have survived...", func() {
origClassAttrs := inst.AllClassAttributes()
newClassAttrs := instF.AllClassAttributes()
intersectClassAttrs := base.AttributeIntersect(origClassAttrs, newClassAttrs)
So(len(intersectClassAttrs), ShouldEqual, len(origClassAttrs))
})
// Check that the Attributes have the right names
Convey("Attribute names should be correct...", func() {
origNames := []string{"floatAttr", "shouldBe1Binary",
"shouldBe3Binary_stoicism", "shouldBe3Binary_heroism",
"shouldBe3Binary_romanticism", "arbitraryClass"}
origMap := make(map[string]bool)
for _, a := range origNames {
origMap[a] = false
}
// Every filtered Attribute must be a known name; mark each one seen.
for _, a := range instF.AllAttributes() {
name := a.GetName()
_, ok := origMap[name]
So(ok, ShouldBeTrue)
origMap[name] = true
}
// ...and every known name must have been seen.
for a := range origMap {
So(origMap[a], ShouldEqual, true)
}
})
Convey("All Attributes should be the correct type...", func() {
for _, a := range instF.AllAttributes() {
if a.GetName() == "arbitraryClass" {
_, ok := a.(*base.CategoricalAttribute)
So(ok, ShouldEqual, true)
} else {
_, ok := a.(*base.FloatAttribute)
So(ok, ShouldEqual, true)
}
}
})
// Check that the Attributes have been discretised correctly
Convey("FloatConversion should have worked", func() {
// Build Attribute map
attrMap := make(map[string]base.Attribute)
for _, a := range instF.AllAttributes() {
attrMap[a.GetName()] = a
}
// For each attribute
for name := range attrMap {
So(name, ShouldBeIn, []string{
"floatAttr",
"shouldBe1Binary",
"shouldBe3Binary_stoicism",
"shouldBe3Binary_heroism",
"shouldBe3Binary_romanticism",
"arbitraryClass",
})
attr := attrMap[name]
as, err := instF.GetAttribute(attr)
So(err, ShouldEqual, nil)
// Compare the packed float values of rows 0, 1 and 2.
if name == "floatAttr" {
So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(1.0))
So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(1.0))
So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(0.0))
} else if name == "shouldBe1Binary" {
So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(0.0))
So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(1.0))
So(instF.Get(as, 2), ShouldResemble, base.PackFloatToBytes(1.0))
} else if name == "shouldBe3Binary_stoicism" {
So(instF.Get(as, 0), ShouldResemble, base.PackFloatToBytes(1.0))
So(instF.Get(as, 1), ShouldResemble, base.PackFloatToBytes(0.0))
//......... part of the code is omitted here .........
开发者ID:CTLife,项目名称:golearn,代码行数:101,代码来源:float_test.go
示例18: TestBinaryFilter
// TestBinaryFilter runs every non-class Attribute of the contrived
// binary_test.csv dataset through a BinaryConvertFilter. It checks the
// resulting Attribute types and names, that the class Attributes are
// preserved, and the binarised values of the first three rows.
func TestBinaryFilter(t *testing.T) {
	Convey("Given a contrived dataset...", t, func() {
		// Read the contrived dataset
		inst, err := base.ParseCSVToInstances("./binary_test.csv", true)
		So(err, ShouldBeNil)
		// Add Attributes to the filter
		bFilt := NewBinaryConvertFilter()
		bAttrs := base.NonClassAttributes(inst)
		for _, a := range bAttrs {
			bFilt.AddAttribute(a)
		}
		bFilt.Train()
		// Construct a LazilyFilteredInstances to handle it
		instF := base.NewLazilyFilteredInstances(inst, bFilt)
		Convey("All the non-class Attributes should be binary...", func() {
			// Check that all the Attributes are the right type
			for _, a := range base.NonClassAttributes(instF) {
				_, ok := a.(*base.BinaryAttribute)
				So(ok, ShouldEqual, true)
			}
		})
		// Check that all the class Attributes made it
		Convey("All the class Attributes should have survived...", func() {
			origClassAttrs := inst.AllClassAttributes()
			newClassAttrs := instF.AllClassAttributes()
			intersectClassAttrs := base.AttributeIntersect(origClassAttrs, newClassAttrs)
			So(len(intersectClassAttrs), ShouldEqual, len(origClassAttrs))
		})
		// Check that the Attributes have the right names
		Convey("Attribute names should be correct...", func() {
			origNames := []string{"floatAttr", "shouldBe1Binary",
				"shouldBe3Binary_stoicism", "shouldBe3Binary_heroism",
				"shouldBe3Binary_romanticism", "arbitraryClass"}
			origMap := make(map[string]bool)
			for _, a := range origNames {
				origMap[a] = false
			}
			for _, a := range instF.AllAttributes() {
				name := a.GetName()
				if _, ok := origMap[name]; !ok {
					t.Error(fmt.Sprintf("Weird: %s", name))
				}
				origMap[name] = true
			}
			for a := range origMap {
				So(origMap[a], ShouldEqual, true)
			}
		})
		// Check that the Attributes have been discretised correctly
		Convey("Discretisation should have worked", func() {
			// Expected binarised values for rows 0, 1 and 2 of each
			// non-class Attribute.
			expected := map[string][]byte{
				"floatAttr":                   {1, 1, 0},
				"shouldBe1Binary":             {0, 1, 1},
				"shouldBe3Binary_stoicism":    {1, 0, 0},
				"shouldBe3Binary_heroism":     {0, 1, 0},
				"shouldBe3Binary_romanticism": {0, 0, 1},
			}
			// Build Attribute map
			attrMap := make(map[string]base.Attribute)
			for _, a := range instF.AllAttributes() {
				attrMap[a.GetName()] = a
			}
			for name, attr := range attrMap {
				// Retrieve AttributeSpec
				as, err := instF.GetAttribute(attr)
				So(err, ShouldBeNil)
				if name == "arbitraryClass" {
					// The class Attribute was not added to the filter.
					continue
				}
				want, ok := expected[name]
				if !ok {
					// Error does not format its arguments (the original passed
					// "%s" straight to it), so build the message first.
					t.Error(fmt.Sprintf("Shouldn't have %s", name))
					continue
				}
				for row, v := range want {
					So(instF.Get(as, row), ShouldResemble, []byte{v})
				}
			}
		})
	})
}
开发者ID:JacobXie,项目名称:golearn,代码行数:100,代码来源:binary_test.go
示例19: main
// main compares three models on a Chi-Merge-discretised 60/40 split of the
// iris dataset: ID3, a single random tree and a 100-tree random forest. Each
// model is trained on the training portion, and a summary of its confusion
// matrix on the test portion is printed.
func main() {
	var tree base.Classifier
	rand.Seed(time.Now().UTC().UnixNano())
	// Load in the iris dataset
	iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
	if err != nil {
		panic(err)
	}
	// Discretise the iris dataset with Chi-Merge
	filt := filters.NewChiMergeFilter(iris, 0.99)
	for _, a := range base.NonClassFloatAttributes(iris) {
		filt.AddAttribute(a)
	}
	filt.Train()
	irisf := base.NewLazilyFilteredInstances(iris, filt)
	// Create a 60-40 training-test split
	trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
	//
	// First up, use ID3
	//
	tree = trees.NewID3DecisionTree(0.6)
	// (Parameter controls train-prune split.)
	// Train the ID3 tree
	tree.Fit(trainData)
	// Generate predictions
	predictions := tree.Predict(testData)
	// Evaluate
	fmt.Println("ID3 Performance")
	cf := eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetSummary(cf))
	//
	// Next up, Random Trees
	//
	// Consider two randomly-chosen attributes
	tree = trees.NewRandomTree(2)
	// Fit on the training split. Fitting on testData and then evaluating on
	// the same data would inflate the reported performance.
	tree.Fit(trainData)
	predictions = tree.Predict(testData)
	fmt.Println("RandomTree Performance")
	cf = eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetSummary(cf))
	//
	// Finally, Random Forests
	//
	tree = ensemble.NewRandomForest(100, 3)
	tree.Fit(trainData)
	predictions = tree.Predict(testData)
	fmt.Println("RandomForest Performance")
	cf = eval.GetConfusionMatrix(testData, predictions)
	fmt.Println(eval.GetSummary(cf))
}
开发者ID:JacobXie,项目名称:golearn,代码行数:62,代码来源:trees.go
示例20: main
// main discretises the iris dataset with Chi-Merge (significance 0.999) and
// makes a 60/40 training-test split. It compares ID3 trees built with
// information gain, information gain ratio and Gini coefficient rule
// generators, each using a 0.6 train-prune parameter, and prints a
// confusion-matrix summary for each. A random-tree comparison follows.
//
// NOTE(review): this listing is truncated by the source page at the marker
// below; the remainder of the function is not shown.
func main() {
var tree base.Classifier
// Fixed seed so the split and any randomised models are reproducible.
rand.Seed(44111342)
// Load in the iris dataset
iris, err := base.ParseCSVToInstances("../datasets/iris_headers.csv", true)
if err != nil {
panic(err)
}
// Discretise the iris dataset with Chi-Merge
filt := filters.NewChiMergeFilter(iris, 0.999)
for _, a := range base.NonClassFloatAttributes(iris) {
filt.AddAttribute(a)
}
filt.Train()
irisf := base.NewLazilyFilteredInstances(iris, filt)
// Create a 60-40 training-test split
trainData, testData := base.InstancesTrainTestSplit(irisf, 0.60)
//
// First up, use ID3
//
tree = trees.NewID3DecisionTree(0.6)
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err := tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain)")
cf, err := evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.InformationGainRatioRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (information gain ratio)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
tree = trees.NewID3DecisionTreeFromRule(0.6, new(trees.GiniCoefficientRuleGenerator))
// (Parameter controls train-prune split.)
// Train the ID3 tree
err = tree.Fit(trainData)
if err != nil {
panic(err)
}
// Generate predictions
predictions, err = tree.Predict(testData)
if err != nil {
panic(err)
}
// Evaluate
fmt.Println("ID3 Performance (gini index generator)")
cf, err = evaluation.GetConfusionMatrix(testData, predictions)
if err != nil {
panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
}
fmt.Println(evaluation.GetSummary(cf))
//
// Next up, Random Trees
//
// Consider two randomly-chosen attributes
tree = trees.NewRandomTree(2)
//......... part of the code is omitted here .........
开发者ID:CTLife,项目名称:golearn,代码行数:101,代码来源:trees.go
注：本文中的github.com/sjwhitworth/golearn/base.NewLazilyFilteredInstances函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论