
Java LogisticRegressionWithSGD Class Code Examples


This article collects typical usage examples of org.apache.spark.mllib.classification.LogisticRegressionWithSGD in Java. If you are wondering how LogisticRegressionWithSGD is used in practice or what real calls to it look like, the examples selected here may help.



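The LogisticRegressionWithSGD class belongs to the org.apache.spark.mllib.classification package. Seven code examples of the class are shown below, sorted by popularity. Note that the class is deprecated since Spark 2.0.0 in favor of ml.classification.LogisticRegression or LogisticRegressionWithLBFGS.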

Example 1: trainWithSGD

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
@SuppressWarnings("unchecked")
public T trainWithSGD(int numIterations) {
    // Train the model selected by modelName
    if (modelName.equals("SVMModel")) {
        SVMModel svmmodel = SVMWithSGD.train(trainingData.rdd(), numIterations);
        this.model = (T) (Object) svmmodel;
    } else if (modelName.equals("LogisticRegressionModel")) {
        LogisticRegressionModel lrmodel = LogisticRegressionWithSGD.train(trainingData.rdd(), numIterations);
        this.model = (T) (Object) lrmodel;
    }

    // Evaluate the trained model on the validation data
    EvaluateProcess<T> evalProcess = new EvaluateProcess<T>(model, modelName, validData, numClasses);
    evalProcess.evalute(numClasses); // method name as spelled in the source project
    return model;
}
 
Developer ID: Chih-Ling-Hsu, Project: Spark-Machine-Learning-Modules, Lines of code: 18, Source file: TrainModel.java
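LogisticRegressionWithSGD.train(rdd, numIterations) minimizes the logistic loss by gradient descent (per the Spark API docs, this overload uses a step size of 1.0 and a mini-batch fraction of 1.0). The plain-Java sketch below, which does not use Spark, shows the core loop: average the gradient (sigmoid(w·x) - y)·x over the data and step with a size that decays as stepSize/√t, similar to the schedule Spark's GradientDescent uses. The class name SgdLogisticSketch is hypothetical, and the sketch omits the intercept, regularization and mini-batch sampling, so treat it as an illustration of the technique rather than Spark's implementation.

```java
/**
 * Hypothetical plain-Java sketch of binary logistic regression trained by
 * gradient descent on labels in {0, 1}. Not Spark code.
 */
final class SgdLogisticSketch {

    /** Sigmoid function 1 / (1 + e^(-z)). */
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** Dot product of two equal-length vectors. */
    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int j = 0; j < a.length; j++) {
            s += a[j] * b[j];
        }
        return s;
    }

    /**
     * Each iteration uses every point (mini-batch fraction 1.0), averages the
     * log-loss gradient, and steps with size stepSize / sqrt(t).
     */
    static double[] train(double[][] x, double[] y, int numIterations, double stepSize) {
        int d = x[0].length;
        double[] w = new double[d];
        for (int t = 1; t <= numIterations; t++) {
            double[] grad = new double[d];
            for (int i = 0; i < x.length; i++) {
                // Gradient of the log loss for one point: (sigmoid(w.x) - y) * x
                double multiplier = sigmoid(dot(w, x[i])) - y[i];
                for (int j = 0; j < d; j++) {
                    grad[j] += multiplier * x[i][j];
                }
            }
            double thisStep = stepSize / Math.sqrt(t);
            for (int j = 0; j < d; j++) {
                w[j] -= thisStep * grad[j] / x.length;
            }
        }
        return w;
    }
}
```

Here a constant first feature can stand in for the intercept, which Spark instead handles through setIntercept(true) as in Example 6.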


Example 2: shouldExportAndImportCorrectly

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    // Train the model in Spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    // Export the model
    byte[] exportedModel = ModelExporter.export(lrmodel);

    // Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // Check that the fields match within tolerance;
    // there may be edge cases, e.g. the order of elements in a list may change
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), 0.01);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), 0.01);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), 0.01);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), 0.01);
    for (int i = 0; i < importedModel.getNumFeatures(); i++) {
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], 0.01);
    }
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 25, Source file: LogisticRegressionExporterTest.java
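Examples 2 and 3 serialize a trained model to a byte[] and check that the deserialized fields match the original. The spark-transformers export format is not shown here, so the sketch below uses a hypothetical LrModelFields class and a simple DataOutputStream layout purely to illustrate such a round trip; it is not that library's format.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

/** Hypothetical container for the model fields that the round-trip test compares. */
final class LrModelFields {
    final double intercept;
    final double threshold;
    final double[] weights;

    LrModelFields(double intercept, double threshold, double[] weights) {
        this.intercept = intercept;
        this.threshold = threshold;
        this.weights = weights;
    }

    /** Writes the fields in a fixed order: intercept, threshold, weight count, weights. */
    byte[] export() {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeDouble(intercept);
            out.writeDouble(threshold);
            out.writeInt(weights.length);
            for (double w : weights) {
                out.writeDouble(w);
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Reads the fields back in the same order they were written. */
    static LrModelFields importFrom(byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            double intercept = in.readDouble();
            double threshold = in.readDouble();
            double[] weights = new double[in.readInt()];
            for (int i = 0; i < weights.length; i++) {
                weights[i] = in.readDouble();
            }
            return new LrModelFields(intercept, threshold, weights);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

Because writeDouble and readDouble preserve the exact bit pattern, this particular layout round-trips exactly; the tolerance in the original tests matters when a format stores values less precisely.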


Example 3: shouldExportAndImportCorrectly

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    // Train the model in Spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    // Export the model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);

    // Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // Check that the fields match within tolerance;
    // there may be edge cases, e.g. the order of elements in a list may change
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON);
    for (int i = 0; i < importedModel.getNumFeatures(); i++) {
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], EPSILON);
    }
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 25, Source file: LogisticRegressionExporterTest.java


Example 4: testLogisticRegression

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
@Test
public void testLogisticRegression() {
    // Prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    // Train the model in Spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

    // Export the model
    byte[] exportedModel = ModelExporter.export(lrmodel);

    // Import it and get a Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // Validate that the Transformer reproduces Spark's predictions
    List<LabeledPoint> testPoints = trainingData.collect();
    for (LabeledPoint point : testPoints) {
        Vector v = point.features();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, 0.01);
    }
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 30, Source file: LogisticRegressionBridgeTest.java
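Examples 4 and 5 compare lrmodel.predict(v) with the transformer's "prediction" output. For a binary model the prediction starts from the margin w·x + b: the score is the sigmoid of the margin, and while a threshold is set (0.5 by default) the model returns 1.0 when the score exceeds it and 0.0 otherwise; after clearThreshold() it returns the raw score. A plain-Java sketch of that rule, using a hypothetical class LrPredictSketch and the same "features"/"prediction" map keys as the test:

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical plain-Java sketch of a binary logistic-regression predictor. */
final class LrPredictSketch {
    final double[] weights;
    final double intercept;
    final Double threshold; // null means "return the raw probability"

    LrPredictSketch(double[] weights, double intercept, Double threshold) {
        this.weights = weights;
        this.intercept = intercept;
        this.threshold = threshold;
    }

    /** Sigmoid of the margin, thresholded into a 0.0/1.0 label when a threshold is set. */
    double predict(double[] features) {
        double margin = intercept;
        for (int j = 0; j < weights.length; j++) {
            margin += weights[j] * features[j];
        }
        double score = 1.0 / (1.0 + Math.exp(-margin));
        if (threshold == null) {
            return score;
        }
        return score > threshold ? 1.0 : 0.0;
    }

    /** Map-based interface: reads "features" (double[]) and writes "prediction" (Double). */
    void transform(Map<String, Object> data) {
        double[] features = (double[]) data.get("features");
        data.put("prediction", predict(features));
    }
}
```

Usage: `new LrPredictSketch(new double[] {2.0, -1.0}, 0.0, 0.5).predict(new double[] {1.0, 0.0})` has margin 2.0, score about 0.88, and therefore returns 1.0.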


Example 5: testLogisticRegression

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
@Test
public void testLogisticRegression() {
    // Prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    // Train the model in Spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());

    // Export the model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);

    // Import it and get a Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // Validate that the Transformer reproduces Spark's predictions
    List<LabeledPoint> testPoints = trainingData.collect();
    for (LabeledPoint point : testPoints) {
        Vector v = point.features();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer ID: flipkart-incubator, Project: spark-transformers, Lines of code: 30, Source file: LogisticRegressionBridgeTest.java


Example 6: trainWithSGD

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
/**
 * TODO add another overloaded method to avoid regularization.
 * This method uses the stochastic gradient descent (SGD) algorithm to train a logistic regression model.
 *
 * @param trainingDataset               Training dataset as a JavaRDD of labeled points
 * @param initialLearningRate           Initial learning rate
 * @param noOfIterations                Number of iterations
 * @param regularizationType            Regularization type: L1 or L2
 * @param regularizationParameter       Regularization parameter
 * @param dataFractionPerSGDIteration   Fraction of the data used in each SGD iteration
 * @return                              Logistic regression model
 */
public LogisticRegressionModel trainWithSGD(JavaRDD<LabeledPoint> trainingDataset, double initialLearningRate,
        int noOfIterations, String regularizationType, double regularizationParameter,
        double dataFractionPerSGDIteration) {
    LogisticRegressionWithSGD lrSGD = new LogisticRegressionWithSGD(initialLearningRate, noOfIterations,
            regularizationParameter, dataFractionPerSGDIteration);
    // Choose the updater that applies the requested regularization
    if (MLConstants.L1.equals(regularizationType)) {
        lrSGD.optimizer().setUpdater(new L1Updater());
    } else if (MLConstants.L2.equals(regularizationType)) {
        lrSGD.optimizer().setUpdater(new SquaredL2Updater());
    }
    // Fit an intercept term in addition to the feature weights
    lrSGD.setIntercept(true);
    return lrSGD.run(trainingDataset.rdd());
}
 
Developer ID: wso2-attic, Project: carbon-ml, Lines of code: 26, Source file: LogisticRegression.java
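Example 6 switches between L1Updater and SquaredL2Updater. Following the update formulas described for these Spark classes, with a per-iteration step η = stepSize/√t, SquaredL2Updater computes w ← w·(1 - η·regParam) - η·g, while L1Updater takes the plain gradient step and then soft-thresholds each weight toward zero by regParam·η. The plain-Java sketch below (hypothetical class UpdaterSketch, not Spark code) shows one step of each, which makes the practical difference visible: L1 can set small weights to exactly zero, while L2 only shrinks them.

```java
/** Hypothetical plain-Java sketch of one regularized gradient step. Not Spark code. */
final class UpdaterSketch {

    /** L2: shrink the weights multiplicatively, then take the gradient step. */
    static double[] squaredL2Step(double[] w, double[] grad, double stepSize, int iter, double regParam) {
        double step = stepSize / Math.sqrt(iter);
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            out[i] = w[i] * (1.0 - step * regParam) - step * grad[i];
        }
        return out;
    }

    /** L1: take the gradient step, then soft-threshold each weight toward zero. */
    static double[] l1Step(double[] w, double[] grad, double stepSize, int iter, double regParam) {
        double step = stepSize / Math.sqrt(iter);
        double shrinkage = regParam * step;
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double v = w[i] - step * grad[i];
            out[i] = Math.signum(v) * Math.max(0.0, Math.abs(v) - shrinkage);
        }
        return out;
    }
}
```

For instance, with zero gradient, stepSize 1.0, iteration 1 and regParam 0.1, the weights {1.0, 0.05} become {0.9, 0.045} under L2 but {0.9, 0.0} under L1.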


Example 7: trainInternal

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; // import the required package/class
@Override
protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
  throws LensException {
  // Static overload: train(input, numIterations, stepSize, miniBatchFraction)
  LogisticRegressionModel lrModel = LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize,
    minBatchFraction);
  return new LogitRegressionClassificationModel(modelId, lrModel);
}
 
Developer ID: apache, Project: lens, Lines of code: 8, Source file: LogisticRegressionAlgo.java
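Example 7 passes a mini-batch fraction to train. In each iteration the optimizer samples roughly that fraction of the training points and averages the gradient over the sample, so a fraction of 1.0 means every point is used. The sketch below imitates only the sampling step with an independent coin flip per point; the helper MiniBatchSketch is hypothetical and uses java.util.Random in place of Spark's RDD sampling.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical helper that picks the point indices used in one mini-batch step. */
final class MiniBatchSketch {

    /**
     * Keeps each index independently with probability miniBatchFraction,
     * so the expected batch size is miniBatchFraction * n.
     */
    static List<Integer> sampleBatch(int n, double miniBatchFraction, Random rnd) {
        List<Integer> batch = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (rnd.nextDouble() < miniBatchFraction) {
                batch.add(i);
            }
        }
        return batch;
    }
}
```

Smaller fractions make each iteration cheaper but noisier, which is the usual trade-off behind choosing this parameter.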



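Note: The org.apache.spark.mllib.classification.LogisticRegressionWithSGD examples in this article were collected from source code and documentation platforms such as GitHub and MSDocs. The snippets come from open-source projects contributed by their developers, and copyright in the source code remains with the original authors. Please consult each project's license before distributing or using the code; do not republish without permission.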