本文整理汇总了Java中org.apache.spark.mllib.classification.LogisticRegressionWithSGD类的典型用法代码示例。如果您正苦于以下问题:Java LogisticRegressionWithSGD类的具体用法?Java LogisticRegressionWithSGD怎么用?Java LogisticRegressionWithSGD使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
LogisticRegressionWithSGD类属于org.apache.spark.mllib.classification包,在下文中一共展示了LogisticRegressionWithSGD类的7个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: trainWithSGD
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Trains the model selected by {@code modelName} with SGD, evaluates it, and returns it.
 *
 * <p>Only {@code "SVMModel"} and {@code "LogisticRegressionModel"} trigger training; for any
 * other name the current {@code model} field is left as it is before evaluation.
 *
 * @param numIterations number of iterations passed to the SGD trainer
 * @return the {@code model} field after training and evaluation
 */
@SuppressWarnings("unchecked")
public T trainWithSGD(int numIterations){
    // Train the model. The cast goes through Object because T is a type parameter
    // unrelated to the concrete model classes.
    final boolean wantsSvm = modelName.equals("SVMModel");
    if (wantsSvm) {
        this.model = (T) (Object) SVMWithSGD.train(trainingData.rdd(), numIterations);
    } else if (modelName.equals("LogisticRegressionModel")) {
        this.model = (T) (Object) LogisticRegressionWithSGD.train(trainingData.rdd(), numIterations);
    }

    // Evaluate the trained model
    new EvaluateProcess<T>(model, modelName, validData, numClasses).evalute(numClasses);
    return model;
}
开发者ID:Chih-Ling-Hsu,项目名称:Spark-Machine-Learning-Modules,代码行数:18,代码来源:TrainModel.java
示例2: shouldExportAndImportCorrectly
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Round-trips a trained logistic regression model through {@code ModelExporter} and
 * {@code ModelImporter} and checks the imported model info against the trained model.
 */
@Test
public void shouldExportAndImportCorrectly() {
    // Load the fixture file as labeled points.
    final String libsvmPath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> samples = MLUtils.loadLibSVMFile(jsc.sc(), libsvmPath).toJavaRDD();

    // Train the model in Spark using the no-arg trainer.
    final LogisticRegressionModel trained = new LogisticRegressionWithSGD().run(samples.rdd());

    // Export the model, then import it back as model info.
    final byte[] serialized = ModelExporter.export(trained);
    final LogisticRegressionModelInfo restored =
            (LogisticRegressionModelInfo) ModelImporter.importModelInfo(serialized);

    // Field-by-field comparison within a tolerance.
    // There may be edge cases, e.g. the order of elements in a list being changed.
    assertEquals(trained.intercept(), restored.getIntercept(), 0.01);
    assertEquals(trained.numClasses(), restored.getNumClasses(), 0.01);
    assertEquals(trained.numFeatures(), restored.getNumFeatures(), 0.01);
    assertEquals((double) trained.getThreshold().get(), restored.getThreshold(), 0.01);

    // Weights are compared index by index against the trained model's weights.
    final double[] expectedWeights = trained.weights().toArray();
    for (int idx = 0; idx < restored.getNumFeatures(); idx++) {
        assertEquals(expectedWeights[idx], restored.getWeights()[idx], 0.01);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:25,代码来源:LogisticRegressionExporterTest.java
示例3: shouldExportAndImportCorrectly
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Exports a trained logistic regression model, imports it back as model info, and
 * verifies each field matches the trained model within {@code EPSILON}.
 */
@Test
public void shouldExportAndImportCorrectly() {
    // Load the fixture file as labeled points.
    final String fixturePath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> points = MLUtils.loadLibSVMFile(sc.sc(), fixturePath).toJavaRDD();

    // Train the model in Spark using the no-arg trainer.
    final LogisticRegressionModel sparkModel = new LogisticRegressionWithSGD().run(points.rdd());

    // Export the model (second argument passed as null), then import it back.
    final byte[] exportedBytes = ModelExporter.export(sparkModel, null);
    final LogisticRegressionModelInfo modelInfo =
            (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedBytes);

    // Field-by-field comparison within EPSILON.
    // There may be edge cases, e.g. the order of elements in a list being changed.
    assertEquals(sparkModel.intercept(), modelInfo.getIntercept(), EPSILON);
    assertEquals(sparkModel.numClasses(), modelInfo.getNumClasses(), EPSILON);
    assertEquals(sparkModel.numFeatures(), modelInfo.getNumFeatures(), EPSILON);
    assertEquals((double) sparkModel.getThreshold().get(), modelInfo.getThreshold(), EPSILON);

    // Weights are compared index by index against the trained model's weights.
    final double[] sparkWeights = sparkModel.weights().toArray();
    for (int pos = 0; pos < modelInfo.getNumFeatures(); pos++) {
        assertEquals(sparkWeights[pos], modelInfo.getWeights()[pos], EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:25,代码来源:LogisticRegressionExporterTest.java
示例4: testLogisticRegression
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Checks that the transformer built from an exported logistic regression model produces
 * the same prediction as the Spark model for every point in the training data.
 */
@Test
public void testLogisticRegression() {
    // Prepare data
    final String libsvmPath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> samples = MLUtils.loadLibSVMFile(jsc.sc(), libsvmPath).toJavaRDD();

    // Train the model in Spark using the no-arg trainer.
    final LogisticRegressionModel trained = new LogisticRegressionWithSGD().run(samples.rdd());

    // Export the model, then import it and obtain a transformer.
    final byte[] serialized = ModelExporter.export(trained);
    final Transformer transformer = ModelImporter.importAndGetTransformer(serialized);

    // Validate predictions point by point.
    for (LabeledPoint point : samples.collect()) {
        final Vector features = point.features();
        final double expected = trained.predict(features);

        // The transformer reads "features" from the map and the test reads "prediction" back.
        final Map<String, Object> row = new HashMap<>();
        row.put("features", features.toArray());
        transformer.transform(row);

        final double fromTransformer = (double) row.get("prediction");
        assertEquals(expected, fromTransformer, 0.01);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:30,代码来源:LogisticRegressionBridgeTest.java
示例5: testLogisticRegression
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Verifies that predictions from the imported transformer agree with the Spark model's
 * predictions, within {@code EPSILON}, for every point in the training data.
 */
@Test
public void testLogisticRegression() {
    // Prepare data
    final String fixturePath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> points = MLUtils.loadLibSVMFile(sc.sc(), fixturePath).toJavaRDD();

    // Train the model in Spark using the no-arg trainer.
    final LogisticRegressionModel sparkModel = new LogisticRegressionWithSGD().run(points.rdd());

    // Export the model (second argument passed as null), then import it as a transformer.
    final byte[] exportedBytes = ModelExporter.export(sparkModel, null);
    final Transformer imported = ModelImporter.importAndGetTransformer(exportedBytes);

    // Validate predictions point by point.
    for (LabeledPoint labeled : points.collect()) {
        final Vector featureVector = labeled.features();
        final double sparkPrediction = sparkModel.predict(featureVector);

        // The transformer reads "features" from the map and the test reads "prediction" back.
        final Map<String, Object> payload = new HashMap<>();
        payload.put("features", featureVector.toArray());
        imported.transform(payload);

        final double transformerPrediction = (double) payload.get("prediction");
        assertEquals(sparkPrediction, transformerPrediction, EPSILON);
    }
}
开发者ID:flipkart-incubator,项目名称:spark-transformers,代码行数:30,代码来源:LogisticRegressionBridgeTest.java
示例6: trainWithSGD
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Trains a logistic regression model using the stochastic gradient descent (SGD) algorithm.
 *
 * <p>The trainer is always configured with {@code setIntercept(true)}. The updater is replaced
 * only when {@code regularizationType} equals {@code MLConstants.L1} or {@code MLConstants.L2};
 * for any other value the trainer's existing updater is left in place.
 *
 * <p>TODO add another overloaded method to avoid Regularization.
 *
 * @param trainingDataset             Training dataset as a JavaRDD of labeled points
 * @param initialLearningRate         Initial learning rate
 * @param noOfIterations              No of iterations
 * @param regularizationType          Regularization type : L1 or L2
 * @param regularizationParameter     Regularization parameter
 * @param dataFractionPerSGDIteration Data fraction per SGD iteration
 * @return Logistic regression model
 */
public LogisticRegressionModel trainWithSGD(JavaRDD<LabeledPoint> trainingDataset, double initialLearningRate,
        int noOfIterations, String regularizationType, double regularizationParameter,
        double dataFractionPerSGDIteration) {
    final LogisticRegressionWithSGD trainer = new LogisticRegressionWithSGD(
            initialLearningRate, noOfIterations, regularizationParameter, dataFractionPerSGDIteration);
    trainer.setIntercept(true);

    // Pick the updater that corresponds to the requested regularization type.
    if (MLConstants.L1.equals(regularizationType)) {
        trainer.optimizer().setUpdater(new L1Updater());
    } else if (MLConstants.L2.equals(regularizationType)) {
        trainer.optimizer().setUpdater(new SquaredL2Updater());
    }

    return trainer.run(trainingDataset.rdd());
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:26,代码来源:LogisticRegression.java
示例7: trainInternal
import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; //导入依赖的package包/类
/**
 * Trains a logistic regression model with SGD on the given RDD and wraps the result.
 *
 * <p>Training uses the {@code iterations}, {@code stepSize} and {@code minBatchFraction}
 * values that are in scope for this method.
 *
 * @param modelId     identifier passed through to the returned classification model
 * @param trainingRDD labeled points to train on
 * @return a {@code LogitRegressionClassificationModel} wrapping the trained Spark model
 * @throws LensException declared by the overridden method; not thrown directly here
 */
@Override
protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
        throws LensException {
    return new LogitRegressionClassificationModel(
            modelId,
            LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize, minBatchFraction));
}
开发者ID:apache,项目名称:lens,代码行数:8,代码来源:LogisticRegressionAlgo.java
注:本文中的org.apache.spark.mllib.classification.LogisticRegressionWithSGD类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论