本文整理汇总了Golang中github.com/huichen/mlf/data.Dataset类的典型用法代码示例。如果您正苦于以下问题:Golang Dataset类的具体用法?Golang Dataset怎么用?Golang Dataset使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
在下文中一共展示了Dataset类的9个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Golang代码示例。
示例1: SaveLibSVMDataset
// SaveLibSVMDataset writes the dataset to path in libsvm text format.
//
// Each instance becomes one line: the label (the LabelString when it is
// non-empty, otherwise the numeric Label) followed by space-separated
// "index:value" pairs for every non-zero feature. Feature index 0 is
// skipped because it always holds the constant bias term 1, which also
// keeps the output compatible with libsvm's 1-based feature indexing.
// Fatally exits if the file cannot be created.
func SaveLibSVMDataset(path string, set data.Dataset) {
	log.Print("保存数据集到libsvm格式文件", path)

	f, err := os.Create(path)
	if err != nil {
		log.Fatalf("无法打开文件\"%v\",错误提示:%v\n", path, err)
	}
	// Register Close only after the error check: if Create failed, f may
	// be nil and closing it would be unsafe. (The original deferred
	// before checking err.)
	defer f.Close()

	w := bufio.NewWriter(f)
	defer w.Flush()

	iter := set.CreateIterator()
	iter.Start()
	for !iter.End() {
		instance := iter.GetInstance()

		// Prefer the human-readable label string when present.
		if instance.Output.LabelString == "" {
			fmt.Fprintf(w, "%d ", instance.Output.Label)
		} else {
			fmt.Fprintf(w, "%s ", instance.Output.LabelString)
		}

		for _, k := range instance.Features.Keys() {
			// Skip feature 0: it is always the constant 1 (bias term).
			if k == 0 {
				continue
			}
			// Fetch the value once instead of calling Get twice.
			if value := instance.Features.Get(k); value != 0 {
				// libsvm feature indices start at 1.
				fmt.Fprintf(w, "%d:%s ", k, strconv.FormatFloat(value, 'f', -1, 64))
			}
		}
		fmt.Fprint(w, "\n")
		iter.Next()
	}
}
开发者ID:reginald1787,项目名称:mlf,代码行数:35,代码来源:libsvm_dataset_saver.go
示例2: feeder
// feeder streams every instance of the dataset into ch in iterator
// order. It sends exactly set.NumInstances() values and does not close
// the channel; the consumer must read that many instances.
func (rbm *RBM) feeder(set data.Dataset, ch chan *data.Instance) {
	cursor := set.CreateIterator()
	cursor.Start()
	for sent := 0; sent < set.NumInstances(); sent++ {
		ch <- cursor.GetInstance()
		cursor.Next()
	}
}
开发者ID:sguzwf,项目名称:mlf,代码行数:9,代码来源:rbm.go
示例3: Evaluate
// Evaluate fills result.Metrics with confusion-matrix counts. Each
// metric is named "confusion:M/N", where M is the true label and N is
// the predicted label; its value counts the instances so classified.
func (e *ConfusionMatrixEvaluator) Evaluate(m supervised.Model, set data.Dataset) (result Evaluation) {
	result.Metrics = make(map[string]float64)
	it := set.CreateIterator()
	for it.Start(); !it.End(); it.Next() {
		inst := it.GetInstance()
		predicted := m.Predict(inst)
		key := fmt.Sprintf("confusion:%d/%d", inst.Output.Label, predicted.Label)
		result.Metrics[key]++
	}
	return
}
开发者ID:reginald1787,项目名称:mlf,代码行数:14,代码来源:confusion_matrix.go
示例4: Evaluate
// Evaluate computes binary-classification quality metrics for model m
// over the dataset: "precision", "recall", "fscore" (F1), plus the raw
// confusion counts "tp"/"fp"/"tn"/"fn". Label 0 is treated as negative
// and any non-zero predicted label as positive. Fatally exits if an
// instance carries a true label outside {0, 1}.
//
// NOTE: when tp+fp or tp+fn is zero, the corresponding ratio (and the
// F-score) is NaN, following IEEE float division semantics.
func (e *PREvaluator) Evaluate(m supervised.Model, set data.Dataset) (result Evaluation) {
	tp := 0 // true-positive
	tn := 0 // true-negative
	fp := 0 // false-positive
	fn := 0 // false-negative
	iter := set.CreateIterator()
	iter.Start()
	for !iter.End() {
		instance := iter.GetInstance()
		// Binary classification requires labels 0 or 1. The original
		// check (> 2) wrongly accepted label 2 as valid.
		if instance.Output.Label >= 2 {
			log.Fatal("调用PREvaluator但不是二分类问题")
		}
		out := m.Predict(instance)
		if out.Label == 0 {
			if instance.Output.Label == 0 {
				tn++
			} else {
				fn++
			}
		} else {
			if instance.Output.Label == 0 {
				fp++
			} else {
				tp++
			}
		}
		iter.Next()
	}
	result.Metrics = make(map[string]float64)
	result.Metrics["precision"] = float64(tp) / float64(tp+fp)
	result.Metrics["recall"] = float64(tp) / float64(tp+fn)
	result.Metrics["tp"] = float64(tp)
	result.Metrics["fp"] = float64(fp)
	result.Metrics["tn"] = float64(tn)
	result.Metrics["fn"] = float64(fn)
	result.Metrics["fscore"] =
		2 * result.Metrics["precision"] * result.Metrics["recall"] / (result.Metrics["precision"] + result.Metrics["recall"])
	return
}
开发者ID:reginald1787,项目名称:mlf,代码行数:43,代码来源:precision_recall.go
示例5: Evaluate
// Evaluate computes a single "accuracy" metric: the fraction of
// instances whose predicted label equals the true label.
func (e *AccuracyEvaluator) Evaluate(m supervised.Model, set data.Dataset) (result Evaluation) {
	var correct, total int
	it := set.CreateIterator()
	for it.Start(); !it.End(); it.Next() {
		inst := it.GetInstance()
		total++
		if m.Predict(inst).Label == inst.Output.Label {
			correct++
		}
	}
	result.Metrics = map[string]float64{
		"accuracy": float64(correct) / float64(total),
	}
	return
}
开发者ID:hycxa,项目名称:mlf,代码行数:21,代码来源:accuracy.go
示例6: OptimizeWeights
// OptimizeWeights runs (mini-batch) gradient descent on weights in place.
//
// derivative_func computes the gradient contribution of a single
// instance; set supplies the training data. Each outer iteration sweeps
// the whole dataset, applying an update every GDBatchSize instances
// (plus a final flush for the partial batch). The loop stops when
// MaxIterations is reached, or once the relative weight change stays
// below ConvergingDeltaWeight for more than ConvergingSteps consecutive
// iterations. If |w| becomes NaN the process is aborted via log.Fatal.
func (opt *gdOptimizer) OptimizeWeights(
	weights *util.Matrix, derivative_func ComputeInstanceDerivativeFunc, set data.Dataset) {
	// Gradient accumulator, same shape as weights.
	derivative := weights.Populate()
	// Learning-rate scheduler.
	learningRate := NewLearningRate(opt.options)
	// Optimization loop state.
	iterator := set.CreateIterator()
	step := 0
	var learning_rate float64
	convergingSteps := 0
	oldWeights := weights.Populate()
	weightsDelta := weights.Populate()
	instanceDerivative := weights.Populate()
	log.Print("开始梯度递降优化")
	for {
		if opt.options.MaxIterations > 0 && step >= opt.options.MaxIterations {
			break
		}
		step++
		// Reset the accumulated gradient before each pass over the data.
		derivative.Clear()
		// Sweep all instances, accumulating each one's gradient scaled
		// by 1/N so the sum is an average.
		iterator.Start()
		instancesProcessed := 0
		for !iterator.End() {
			instance := iterator.GetInstance()
			derivative_func(weights, instance, instanceDerivative)
			derivative.Increment(instanceDerivative, 1.0/float64(set.NumInstances()))
			iterator.Next()
			instancesProcessed++
			if opt.options.GDBatchSize > 0 && instancesProcessed >= opt.options.GDBatchSize {
				// Mini-batch boundary: add the regularization term,
				// scaled by batchSize/N^2. NOTE(review): the N^2 factor
				// appears to give each batch its share of the 1/N-scaled
				// regularizer — worth confirming against the math.
				derivative.Increment(ComputeRegularization(weights, opt.options),
					float64(instancesProcessed)/(float64(set.NumInstances())*float64(set.NumInstances())))
				// Weight increment proposed by the optimizer.
				delta := opt.GetDeltaX(weights, derivative)
				// Apply the update scaled by the current learning rate.
				learning_rate = learningRate.ComputeLearningRate(delta)
				weights.Increment(delta, learning_rate)
				// Start the next mini-batch from scratch.
				derivative.Clear()
				instancesProcessed = 0
			}
		}
		if instancesProcessed > 0 {
			// Flush the final partial batch the same way.
			derivative.Increment(ComputeRegularization(weights, opt.options),
				float64(instancesProcessed)/(float64(set.NumInstances())*float64(set.NumInstances())))
			delta := opt.GetDeltaX(weights, derivative)
			learning_rate = learningRate.ComputeLearningRate(delta)
			weights.Increment(delta, learning_rate)
		}
		// Track the inter-iteration change in weights (w - w_old) for
		// the convergence test below.
		weightsDelta.WeightedSum(weights, oldWeights, 1, -1)
		oldWeights.DeepCopy(weights)
		weightsNorm := weights.Norm()
		weightsDeltaNorm := weightsDelta.Norm()
		log.Printf("#%d |w|=%1.3g |dw|/|w|=%1.3g lr=%1.3g", step, weightsNorm, weightsDeltaNorm/weightsNorm, learning_rate)
		// Abort on numeric overflow (divergence).
		if math.IsNaN(weightsNorm) {
			log.Fatal("优化失败:不收敛")
		}
		// Converged once the relative change stays small for more than
		// ConvergingSteps consecutive iterations.
		if weightsDelta.Norm()/weights.Norm() < opt.options.ConvergingDeltaWeight {
			convergingSteps++
			if convergingSteps > opt.options.ConvergingSteps {
				log.Printf("收敛")
				break
			}
		}
	}
}
开发者ID:hycxa,项目名称:mlf,代码行数:84,代码来源:gd.go
示例7: Train
// Train fits a maximum-entropy classifier on the given dataset and
// returns it as a Model. Fatally exits when the dataset is not labeled
// for supervised learning. The weight matrix is sparse or dense
// depending on the dataset's FeatureIsSparse option.
func (trainer *MaxEntClassifierTrainer) Train(set data.Dataset) Model {
	// The training data must carry supervised labels.
	if !set.GetOptions().IsSupervisedLearning {
		log.Fatal("训练数据不是分类问题数据")
	}

	// Fresh optimizer for this training run (renamed from "optimizer"
	// to avoid shadowing the package name).
	opt := optimizer.NewOptimizer(trainer.options.Optimizer)

	// Allocate the feature-weight matrix: one row per label.
	featureDimension := set.GetOptions().FeatureDimension
	numLabels := set.GetOptions().NumLabels
	var weights *util.Matrix
	if set.GetOptions().FeatureIsSparse {
		weights = util.NewSparseMatrix(numLabels)
	} else {
		weights = util.NewMatrix(numLabels, featureDimension)
	}

	// Optimize the weights against the max-ent derivative.
	opt.OptimizeWeights(weights, MaxEntComputeInstanceDerivative, set)

	return &MaxEntClassifier{
		Weights:           weights,
		NumLabels:         numLabels,
		FeatureDimension:  featureDimension,
		FeatureDictionary: set.GetFeatureDictionary(),
		LabelDictionary:   set.GetLabelDictionary(),
	}
}
开发者ID:numb3r3,项目名称:mlf,代码行数:30,代码来源:maxent_classifier_trainer.go
示例8: Train
// Train learns the RBM weight matrix from the dataset using parallel
// derivative workers.
//
// It randomizes the (hiddenDim x visibleDim) weight matrix, launches
// options.Worker goroutines that turn instances into weight
// derivatives, and applies the derivatives in mini-batches of
// options.BatchSize. Training stops after options.MaxIter iterations or
// once the relative weight change drops below options.Delta; a zero
// value disables the corresponding stop condition.
func (rbm *RBM) Train(set data.Dataset) {
	featureDimension := set.GetOptions().FeatureDimension
	visibleDim := featureDimension
	// +1 adds a bias unit to the hidden layer.
	hiddenDim := rbm.options.NumHiddenUnits + 1
	log.Printf("#visible = %d, #hidden = %d", featureDimension-1, hiddenDim-1)
	// Initialize weights with small random values in (-0.01, 0.01),
	// holding the write lock because workers read the matrix concurrently.
	rbm.lock.Lock()
	rbm.lock.weights = util.NewMatrix(hiddenDim, visibleDim)
	oldWeights := util.NewMatrix(hiddenDim, visibleDim)
	batchDerivative := util.NewMatrix(hiddenDim, visibleDim)
	for i := 0; i < hiddenDim; i++ {
		for j := 0; j < visibleDim; j++ {
			value := (rand.Float64()*2 - 1) * 0.01
			rbm.lock.weights.Set(i, j, value)
		}
	}
	rbm.lock.Unlock()
	// Start the worker goroutines: ch feeds instances in, out yields
	// per-instance weight derivatives.
	ch := make(chan *data.Instance, rbm.options.Worker)
	out := make(chan *util.Matrix, rbm.options.Worker)
	for iWorker := 0; iWorker < rbm.options.Worker; iWorker++ {
		go rbm.derivativeWorker(ch, out, visibleDim, hiddenDim)
	}
	iteration := 0
	delta := 1.0
	for (rbm.options.MaxIter == 0 || iteration < rbm.options.MaxIter) &&
		(rbm.options.Delta == 0 || delta > rbm.options.Delta) {
		iteration++
		// Stream this epoch's instances to the workers.
		go rbm.feeder(set, ch)
		iBatch := 0
		batchDerivative.Clear()
		numInstances := set.NumInstances()
		for it := 0; it < numInstances; it++ {
			// Derivatives arrive in arbitrary (worker-completion) order.
			derivative := <-out
			batchDerivative.Increment(derivative, rbm.options.LearningRate)
			iBatch++
			if iBatch == rbm.options.BatchSize || it == numInstances-1 {
				// Apply the accumulated mini-batch update under the lock.
				rbm.lock.Lock()
				rbm.lock.weights.Increment(batchDerivative, 1.0)
				rbm.lock.Unlock()
				iBatch = 0
				batchDerivative.Clear()
			}
		}
		// Compute the relative change |w - w_old| / |w| for the stop
		// test; batchDerivative is reused here as scratch space (it is
		// cleared at the top of the next iteration).
		rbm.lock.RLock()
		weightsNorm := rbm.lock.weights.Norm()
		batchDerivative.DeepCopy(rbm.lock.weights)
		batchDerivative.Increment(oldWeights, -1.0)
		derivativeNorm := batchDerivative.Norm()
		delta = derivativeNorm / weightsNorm
		log.Printf("iter = %d, delta = %f, |weight| = %f",
			iteration, delta, weightsNorm)
		oldWeights.DeepCopy(rbm.lock.weights)
		rbm.lock.RUnlock()
	}
}
开发者ID:sguzwf,项目名称:mlf,代码行数:64,代码来源:rbm.go
示例9: OptimizeWeights
func (opt *lbfgsOptimizer) OptimizeWeights(
weights *util.Matrix, derivative_func ComputeInstanceDerivativeFunc, set data.Dataset) {
// 学习率计算器
learningRate := NewLearningRate(opt.options)
// 偏导数向量
derivative := weights.Populate()
// 优化循环
step := 0
convergingSteps := 0
oldWeights := weights.Populate()
weightsDelta := weights.Populate()
// 为各个工作协程开辟临时资源
numLbfgsThreads := *lbfgs_threads
if numLbfgsThreads == 0 {
numLbfgsThreads = runtime.NumCPU()
}
workerSet := make([]data.Dataset, numLbfgsThreads)
workerDerivative := make([]*util.Matrix, numLbfgsThreads)
workerInstanceDerivative := make([]*util.Matrix, numLbfgsThreads)
for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
workerBuckets := []data.SkipBucket{
{true, iWorker},
{false, 1},
{true, numLbfgsThreads - 1 - iWorker},
}
workerSet[iWorker] = data.NewSkipDataset(set, workerBuckets)
workerDerivative[iWorker] = weights.Populate()
workerInstanceDerivative[iWorker] = weights.Populate()
}
log.Print("开始L-BFGS优化")
for {
if opt.options.MaxIterations > 0 && step >= opt.options.MaxIterations {
break
}
step++
// 开始工作协程
workerChannel := make(chan int, numLbfgsThreads)
for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
go func(iw int) {
workerDerivative[iw].Clear()
iterator := workerSet[iw].CreateIterator()
iterator.Start()
for !iterator.End() {
instance := iterator.GetInstance()
derivative_func(
weights, instance, workerInstanceDerivative[iw])
// log.Print(workerInstanceDerivative[iw].GetValues(0))
workerDerivative[iw].Increment(
workerInstanceDerivative[iw], float64(1)/float64(set.NumInstances()))
iterator.Next()
}
workerChannel <- iw
}(iWorker)
}
derivative.Clear()
// 等待工作协程结束
for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
<-workerChannel
}
for iWorker := 0; iWorker < numLbfgsThreads; iWorker++ {
derivative.Increment(workerDerivative[iWorker], 1)
}
// 添加正则化项
derivative.Increment(ComputeRegularization(weights, opt.options), 1.0/float64(set.NumInstances()))
// 计算特征权重的增量
delta := opt.GetDeltaX(weights, derivative)
// 根据学习率更新权重
learning_rate := learningRate.ComputeLearningRate(delta)
weights.Increment(delta, learning_rate)
weightsDelta.WeightedSum(weights, oldWeights, 1, -1)
oldWeights.DeepCopy(weights)
weightsNorm := weights.Norm()
weightsDeltaNorm := weightsDelta.Norm()
log.Printf("#%d |dw|/|w|=%f |w|=%f lr=%1.3g", step, weightsDeltaNorm/weightsNorm, weightsNorm, learning_rate)
// 判断是否溢出
if math.IsNaN(weightsNorm) {
log.Fatal("优化失败:不收敛")
}
// 判断是否收敛
if weightsDeltaNorm/weightsNorm < opt.options.ConvergingDeltaWeight {
convergingSteps++
if convergingSteps > opt.options.ConvergingSteps {
log.Printf("收敛")
break
}
} else {
//.........这里部分代码省略.........
开发者ID:sguzwf,项目名称:mlf,代码行数:101,代码来源:lbfgs.go
注:本文中的github.com/huichen/mlf/data.Dataset类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论