本文整理汇总了Java中org.apache.spark.mllib.tree.model.DecisionTreeModel类的典型用法代码示例。如果您正苦于以下问题:Java DecisionTreeModel类的具体用法?Java DecisionTreeModel怎么用?Java DecisionTreeModel使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。
DecisionTreeModel类属于org.apache.spark.mllib.tree.model包,在下文中一共展示了DecisionTreeModel类的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: OnlineFeatureHandler
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Builds a handler around an existing detection model and registers it for online features.
 *
 * <p>The constructor stores the constraint and the detection model, copies the model's feature
 * configuration, unwraps the underlying model into the field that matches the wrapper's concrete
 * type, and finally registers an online feature query with a new event delivery manager.
 *
 * @param featureConstraint constraint passed to the online feature registration
 * @param detectionModel wrapper whose concrete type decides which model field is populated
 * @param onlineMLEventListener listener stored on this handler
 * @param controllerConnector connector handed to the event delivery manager
 */
public OnlineFeatureHandler(FeatureConstraint featureConstraint,
DetectionModel detectionModel,
onlineMLEventListener onlineMLEventListener,
ControllerConnector controllerConnector) {
this.featureConstraint = featureConstraint;
this.detectionModel = detectionModel;
setAthenaMLFeatureConfiguration(detectionModel.getAthenaMLFeatureConfiguration());
// Unwrap the concrete model: exactly one model field is assigned, chosen by the first
// wrapper type that matches.
if (detectionModel instanceof KMeansDetectionModel) {
this.kMeansModel = (KMeansModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof GaussianMixtureDetectionModel) {
this.gaussianMixtureModel = (GaussianMixtureModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof DecisionTreeDetectionModel) {
this.decisionTreeModel = (DecisionTreeModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof NaiveBayesDetectionModel) {
this.naiveBayesModel = (NaiveBayesModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof RandomForestDetectionModel) {
this.randomForestModel = (RandomForestModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof GradientBoostedTreesDetectionModel) {
this.gradientBoostedTreesModel = (GradientBoostedTreesModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof SVMDetectionModel) {
this.svmModel = (SVMModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof LogisticRegressionDetectionModel) {
this.logisticRegressionModel = (LogisticRegressionModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof LinearRegressionDetectionModel) {
this.linearRegressionModel = (LinearRegressionModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof LassoDetectionModel) {
this.lassoModel = (LassoModel) detectionModel.getDetectionModel();
} else if (detectionModel instanceof RidgeRegressionDetectionModel) {
this.ridgeRegressionModel = (RidgeRegressionModel) detectionModel.getDetectionModel();
} else {
// Unsupported model type: no model field is set and construction continues after
// printing a message.
System.out.println("Not supported model");
}
// Registration happens before the listener field is assigned.
// NOTE(review): if registration can deliver events synchronously, the listener would still be
// unset at that point - confirm against EventDeliveryManagerImpl.
this.eventDeliveryManager = new EventDeliveryManagerImpl(controllerConnector, new InternalAthenaFeatureEventListener());
this.eventDeliveryManager.registerOnlineAthenaFeature(null, new QueryIdentifier(QUERY_IDENTIFIER), featureConstraint);
this.onlineMLEventListener = onlineMLEventListener;
System.out.println("Install handler!");
}
开发者ID:shlee89,项目名称:athena,代码行数:41,代码来源:OnlineFeatureHandler.java
示例2: predictorExampleCounts
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Tallies, per predictor, how many training examples passed through a decision node that
 * splits on that predictor.
 *
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return map of predictor index to the number of training examples that reached a
 *  node whose decision is based on that feature. The index is among predictors, not all
 *  features, since there are fewer predictors than features. That is, the index will
 *  match the one used in the {@link RandomForestModel}.
 */
private static Map<Integer,Long> predictorExampleCounts(JavaRDD<LabeledPoint> trainPointData,
                                                        RandomForestModel model) {
  return trainPointData.mapPartitions(partition -> {
      IntLongMap countsByPredictor = HashIntLongMaps.newMutableMap();
      while (partition.hasNext()) {
        double[] features = partition.next().features().toArray();
        for (DecisionTreeModel tree : model.trees()) {
          // Descend from the root to a leaf the same way Node.predict does,
          // counting the predictor used at every internal node on the way.
          org.apache.spark.mllib.tree.model.Node current = tree.topNode();
          while (!current.isLeaf()) {
            Split decision = current.split().get();
            int predictor = decision.feature();
            countsByPredictor.addValue(predictor, 1);
            current = nextNode(features, current, decision, predictor);
          }
        }
      }
      // Copy into a plain HashMap to avoid problem with Kryo serializing Koloboke
      return Collections.<Map<Integer,Long>>singleton(
          new HashMap<>(countsByPredictor)).iterator();
  }).reduce(RDFUpdate::merge);
}
开发者ID:oncewang,项目名称:oryx2,代码行数:32,代码来源:RDFUpdate.java
示例3: generateKMeansModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a decision tree classifier on the given labeled points.
 *
 * <p>Despite the method name, the result is a {@link DecisionTreeModel}; no k-means model is
 * involved. After training succeeds, the algorithm settings are recorded on the summary.
 *
 * @param parsedData labeled training points
 * @param decisionTreeDetectionAlgorithm source of the class count, categorical feature info,
 *     impurity, maximum depth and maximum bins used for training
 * @param decisionTreeModelSummary summary that receives the algorithm settings
 * @return the trained decision tree
 */
public DecisionTreeModel generateKMeansModel(JavaRDD<LabeledPoint> parsedData,
                                             DecisionTreeDetectionAlgorithm decisionTreeDetectionAlgorithm,
                                             DecisionTreeModelSummary decisionTreeModelSummary) {
    final DecisionTreeDetectionAlgorithm settings = decisionTreeDetectionAlgorithm;

    DecisionTreeModel trainedTree = DecisionTree.trainClassifier(
            parsedData,
            settings.getNumClasses(),
            settings.getCategoricalFeaturesInfo(),
            settings.getImpurity(),
            settings.getMaxDepth(),
            settings.getMaxBins());

    // Recorded only after training has completed without throwing.
    decisionTreeModelSummary.setDecisionTreeDetectionAlgorithm(settings);
    return trainedTree;
}
开发者ID:shlee89,项目名称:athena,代码行数:14,代码来源:DecisionTreeDistJob.java
示例4: generateDecisionTreeWithPreprocessing
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Pre-processes the raw MongoDB records and trains a decision tree on the result.
 *
 * @param mongoRDD raw key/BSON pairs to pre-process
 * @param athenaMLFeatureConfiguration feature configuration handed to the pre-processing step
 * @param decisionTreeDetectionAlgorithm training settings
 * @param marking marking handed to the pre-processing step
 * @param decisionTreeModelSummary summary passed to both pre-processing and training
 * @return the trained decision tree
 */
public DecisionTreeModel generateDecisionTreeWithPreprocessing(JavaPairRDD<Object, BSONObject> mongoRDD,
                                                               AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                               DecisionTreeDetectionAlgorithm decisionTreeDetectionAlgorithm,
                                                               Marking marking,
                                                               DecisionTreeModelSummary decisionTreeModelSummary) {
    JavaRDD<LabeledPoint> labeledPoints = rddPreProcessing(
            mongoRDD, athenaMLFeatureConfiguration, decisionTreeModelSummary, marking);
    return generateKMeansModel(labeledPoints, decisionTreeDetectionAlgorithm, decisionTreeModelSummary);
}
开发者ID:shlee89,项目名称:athena,代码行数:13,代码来源:DecisionTreeDistJob.java
示例5: treeNodeExampleCounts
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Counts, for every tree of the forest, how many training examples visited each node.
 *
 * @param trainPointData data to run down trees
 * @param model random decision forest model to count on
 * @return maps of node IDs to the count of training examples that reached that node, one
 *  per tree in the model
 * @see #predictorExampleCounts(JavaRDD,RandomForestModel)
 */
private static List<Map<Integer,Long>> treeNodeExampleCounts(JavaRDD<LabeledPoint> trainPointData,
                                                             RandomForestModel model) {
  return trainPointData.mapPartitions(partition -> {
      DecisionTreeModel[] forest = model.trees();
      // One mutable counter map per tree, in tree order
      List<IntLongMap> countsPerTree = IntStream.range(0, forest.length).
          mapToObj(treeIndex -> HashIntLongMaps.newMutableMap()).collect(Collectors.toList());
      while (partition.hasNext()) {
        double[] features = partition.next().features().toArray();
        for (int treeIndex = 0; treeIndex < forest.length; treeIndex++) {
          IntLongMap visits = countsPerTree.get(treeIndex);
          // Same descent as Node.predict: every node on the path, the final leaf
          // included, is counted exactly once.
          org.apache.spark.mllib.tree.model.Node current = forest[treeIndex].topNode();
          visits.addValue(current.id(), 1);
          while (!current.isLeaf()) {
            Split decision = current.split().get();
            current = nextNode(features, current, decision, decision.feature());
            visits.addValue(current.id(), 1);
          }
        }
      }
      // Copy each counter into a plain HashMap before it leaves the partition
      return Collections.<List<Map<Integer,Long>>>singleton(
          countsPerTree.stream().map(HashMap::new).collect(Collectors.toList())).iterator();
    }
  ).reduce((left, right) -> {
    Preconditions.checkArgument(left.size() == right.size());
    // Fold the right-hand counts into the left-hand maps, tree by tree
    for (int treeIndex = 0; treeIndex < left.size(); treeIndex++) {
      merge(left.get(treeIndex), right.get(treeIndex));
    }
    return left;
  });
}
开发者ID:oncewang,项目名称:oryx2,代码行数:42,代码来源:RDFUpdate.java
示例6: main
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a decision tree classifier on features exported by {@code HogManager} from the
 * serialized training set, evaluates it on the serialized test set and writes the error rate,
 * the model dump, a confusion matrix and the training time under {@code results/}.
 *
 * <p>The Spark context is opened in a try-with-resources block so that it is stopped even when
 * loading, training or evaluation throws.
 *
 * @param args ignored
 */
public static void main(String[] args) {
    Logger.getLogger("org").setLevel(Level.WARN);
    SparkConf sparkConf = new SparkConf()
            .setAppName("ExampleSpark")
            .setMaster("local");

    try (JavaSparkContext jsc = new JavaSparkContext(sparkConf)) {
        IdxManager idx = IOUtils.deserialize("data/idx.ser");
        IdxManager idxTest = IOUtils.deserialize("data/idx-test.ser");

        double[][] inputs = idx.getData();
        double[] outputs = idx.getLabelsVec();
        double[][] inputsTest = idxTest.getData();
        double[] outputsTest = idxTest.getLabelsVec();

        inputs = HogManager.exportDataFeatures(inputs, idx.getNumOfRows(),
                idx.getNumOfCols());
        // NOTE(review): the test features reuse the training set's row/column counts
        // (idx, not idxTest) - confirm both sets share the same dimensions.
        inputsTest = HogManager.exportDataFeatures(inputsTest, idx.getNumOfRows(),
                idx.getNumOfCols());

        List<LabeledPoint> pointList = new ArrayList<>();
        for (int i = 0; i < outputs.length; i++) {
            pointList.add(new LabeledPoint(outputs[i], Vectors.dense(inputs[i])));
        }

        List<LabeledPoint> pointListTest = new ArrayList<>();
        for (int i = 0; i < outputsTest.length; i++) {
            pointListTest.add(new LabeledPoint(outputsTest[i],
                    Vectors.dense(inputsTest[i])));
        }

        JavaRDD<LabeledPoint> trainingData = jsc.parallelize(pointList);
        JavaRDD<LabeledPoint> testData = jsc.parallelize(pointListTest);

        // Set parameters.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        int numClasses = 10;
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        int maxDepth = 10;
        int maxBins = 256;

        // Train a DecisionTree model for classification and time the training step only.
        long startTime = System.currentTimeMillis();
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
                numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);
        long endTime = System.currentTimeMillis();
        long learnTime = endTime - startTime;

        // Evaluate model on test instances and compute test error
        JavaPairRDD<Double, Double> predictionAndLabel =
                testData.mapToPair(
                        p -> new Tuple2<>(model.predict(p.features()), p.label()));
        double testErr = 1.0 * predictionAndLabel.filter(
                pl -> !pl._1().equals(pl._2())).count() / testData.count();

        // Write results
        new File("results").mkdir();
        IOUtils.writeStr("results/dtree_error.data", Double.toString(testErr));
        IOUtils.writeStr("results/dtree_model.data", model.toDebugString());

        double[][] outFinal = new double[outputsTest.length][];
        for (int i = 0; i < outputsTest.length; i++) {
            outFinal[i] = valToVec(model.predict(Vectors.dense(inputsTest[i])));
        }

        ConfusionMatrix cm = new ConfusionMatrix(outFinal, idxTest.getLabels());
        cm.writeClassErrorMatrix("results/confusion_matrix.data");
        IOUtils.writeStr("results/learn_time_ms.data", Long.toString(learnTime));
    }
}
开发者ID:lukago,项目名称:neural-algorithms,代码行数:82,代码来源:ExampleSpark.java
示例7: setDecisionTreeModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Replaces the decision tree model held by this object.
 *
 * @param decisionTreeModel the model to store
 */
public void setDecisionTreeModel(DecisionTreeModel decisionTreeModel) {
this.decisionTreeModel = decisionTreeModel;
}
开发者ID:shlee89,项目名称:athena,代码行数:4,代码来源:DecisionTreeDetectionModel.java
示例8: validate
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Runs every record of {@code mongoRDD} through the decision tree held by the detection model
 * and reports each prediction, together with the time spent on the record, to the summary.
 *
 * <p>For each record a feature vector is built from the configured target features (missing
 * fields stay 0), optionally weighted, made absolute and normalized, and then classified.
 *
 * @param mongoRDD key/BSON pairs to validate
 * @param athenaMLFeatureConfiguration supplies the target features, weights and the
 *     absolute/normalization switches
 * @param decisionTreeDetectionModel supplies the marking, the trained model and the algorithm
 * @param decisionTreeValidationSummary receives per-record results and timings
 */
public void validate(JavaPairRDD<Object, BSONObject> mongoRDD,
AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
DecisionTreeDetectionModel decisionTreeDetectionModel,
DecisionTreeValidationSummary decisionTreeValidationSummary) {
List<AthenaFeatureField> listOfTargetFeatures = athenaMLFeatureConfiguration.getListOfTargetFeatures();
Map<AthenaFeatureField, Integer> weight = athenaMLFeatureConfiguration.getWeight();
Marking marking = decisionTreeDetectionModel.getMarking();
DecisionTreeModel model = (DecisionTreeModel) decisionTreeDetectionModel.getDetectionModel();
Normalizer normalizer = new Normalizer();
int numberOfTargetValue = listOfTargetFeatures.size();
// NOTE(review): the closure below updates decisionTreeValidationSummary from inside foreach;
// whether those updates are visible to the caller on a cluster depends on how the summary is
// implemented - confirm.
mongoRDD.foreach(new VoidFunction<Tuple2<Object, BSONObject>>() {
public void call(Tuple2<Object, BSONObject> t) throws UnknownHostException {
// Start of the per-record timing window.
long start2 = System.nanoTime();
BSONObject feature = (BSONObject) t._2().get(AthenaFeatureField.FEATURE);
BSONObject idx = (BSONObject) t._2();
int originLabel = marking.checkClassificationMarkingElements(idx,feature);
double[] values = new double[numberOfTargetValue];
for (int j = 0; j < numberOfTargetValue; j++) {
// Target features absent from the record keep the value 0.
values[j] = 0;
if (feature.containsField(listOfTargetFeatures.get(j).getValue())) {
Object obj = feature.get(listOfTargetFeatures.get(j).getValue());
if (obj instanceof Long) {
values[j] = (Long) obj;
} else if (obj instanceof Double) {
values[j] = (Double) obj;
} else if (obj instanceof Boolean) {
values[j] = (Boolean) obj ? 1 : 0;
} else {
// Any other value type (including Integer) abandons the whole record: no prediction,
// no summary update and no timing are recorded for it.
return;
}
// Scale by the configured weight, if this feature has one.
if (weight.containsKey(listOfTargetFeatures.get(j))) {
values[j] *= weight.get(listOfTargetFeatures.get(j));
}
// Optionally drop the sign; applied after weighting.
if (athenaMLFeatureConfiguration.isAbsolute()){
values[j] = Math.abs(values[j]);
}
}
}
Vector normedForVal;
if (athenaMLFeatureConfiguration.isNormalization()) {
normedForVal = normalizer.transform(Vectors.dense(values));
} else {
normedForVal = Vectors.dense(values);
}
// originLabel is carried in the LabeledPoint but only the features are used for prediction.
LabeledPoint p = new LabeledPoint(originLabel,normedForVal);
int validatedLabel = (int) model.predict(p.features());
decisionTreeValidationSummary.updateSummary(validatedLabel,idx,feature);
long end2 = System.nanoTime();
long result2 = end2 - start2;
decisionTreeValidationSummary.addTotalNanoSeconds(result2);
}
});
// NOTE(review): the return value is discarded; presumably the call is made for a side effect
// inside the summary - confirm.
decisionTreeValidationSummary.getAverageNanoSeconds();
decisionTreeValidationSummary.setDecisionTreeDetectionAlgorithm((DecisionTreeDetectionAlgorithm) decisionTreeDetectionModel.getDetectionAlgorithm());
}
开发者ID:shlee89,项目名称:athena,代码行数:69,代码来源:DecisionTreeDistJob.java
示例9: generateDecisionTreeAthenaDetectionModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a decision tree from MongoDB-backed data and wraps it, together with its learning
 * summary, in a {@link DecisionTreeDetectionModel}.
 *
 * <p>The recorded learning time covers configuring the detection model, reading the input RDD
 * and training; it does not include creating the summary object.
 *
 * @param sc Spark context used to create the summary and to read the input
 * @param featureConstraint constraint stored on the returned model
 * @param athenaMLFeatureConfiguration feature configuration stored on the model and used in
 *     training
 * @param detectionAlgorithm must be a {@link DecisionTreeDetectionAlgorithm}; it is cast
 * @param indexing indexing stored on the model and the summary
 * @param marking marking stored on the model and the summary, and used in training
 * @return the populated detection model
 */
public DecisionTreeDetectionModel generateDecisionTreeAthenaDetectionModel(JavaSparkContext sc,
                                                                           FeatureConstraint featureConstraint,
                                                                           AthenaMLFeatureConfiguration athenaMLFeatureConfiguration,
                                                                           DetectionAlgorithm detectionAlgorithm,
                                                                           Indexing indexing,
                                                                           Marking marking) {
    DecisionTreeModelSummary summary = new DecisionTreeModelSummary(sc.sc(), indexing, marking);

    final long startedAtNanos = System.nanoTime();

    DecisionTreeDetectionAlgorithm treeAlgorithm = (DecisionTreeDetectionAlgorithm) detectionAlgorithm;
    DecisionTreeDetectionModel result = new DecisionTreeDetectionModel();
    result.setDecisionTreeDetectionAlgorithm(treeAlgorithm);
    summary.setDecisionTreeDetectionAlgorithm(treeAlgorithm);
    result.setFeatureConstraint(featureConstraint);
    result.setAthenaMLFeatureConfiguration(athenaMLFeatureConfiguration);
    result.setIndexing(indexing);
    result.setMarking(marking);

    // Input is read through the MongoDB Hadoop input format using the shared mongodbConfig.
    JavaPairRDD<Object, BSONObject> mongoRDD = sc.newAPIHadoopRDD(
            mongodbConfig,
            MongoInputFormat.class,
            Object.class,
            BSONObject.class);

    DecisionTreeModel trainedTree = new DecisionTreeDistJob().generateDecisionTreeWithPreprocessing(
            mongoRDD, athenaMLFeatureConfiguration, treeAlgorithm, marking, summary);
    result.setDecisionTreeModel(trainedTree);

    // End of the timed section.
    summary.setTotalLearningTime(System.nanoTime() - startedAtNanos);
    result.setClassificationModelSummary(summary);
    return result;
}
开发者ID:shlee89,项目名称:athena,代码行数:45,代码来源:MachineLearningManagerImpl.java
示例10: rdfModelToPMML
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Converts a trained random forest into a PMML document.
 *
 * <p>A forest with exactly one tree becomes a single tree model; otherwise every tree becomes
 * an equally weighted segment of a mining model. The hyper-parameters are attached to the
 * document as extensions.
 *
 * @param rfModel forest to convert; its task type must agree with the input schema
 * @param categoricalValueEncodings encodings used for the trees and the data dictionary
 * @param maxDepth recorded as the "maxDepth" extension
 * @param maxSplitCandidates recorded as the "maxSplitCandidates" extension
 * @param impurity recorded as the "impurity" extension
 * @param nodeIDCounts per-tree node ID counts, indexed like the forest's trees
 * @param predictorIndexCounts per-predictor counts, converted to importances
 * @return the assembled PMML document
 */
private PMML rdfModelToPMML(RandomForestModel rfModel,
                            CategoricalValueEncodings categoricalValueEncodings,
                            int maxDepth,
                            int maxSplitCandidates,
                            String impurity,
                            List<Map<Integer,Long>> nodeIDCounts,
                            Map<Integer,Long> predictorIndexCounts) {

  boolean isClassification = rfModel.algo().equals(Algo.Classification());
  Preconditions.checkState(isClassification == inputSchema.isClassification());

  DecisionTreeModel[] forest = rfModel.trees();

  Model pmmlModel;
  if (forest.length != 1) {
    MiningModel ensemble = new MiningModel();
    pmmlModel = ensemble;

    // Classification forests vote, regression forests average
    Segmentation.MultipleModelMethod combiner;
    if (isClassification) {
      combiner = Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE;
    } else {
      combiner = Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
    }

    List<Segment> segments = new ArrayList<>(forest.length);
    int treeID = 0;
    for (DecisionTreeModel tree : forest) {
      TreeModel treeModel =
          toTreeModel(tree, categoricalValueEncodings, nodeIDCounts.get(treeID));
      Segment segment = new Segment()
          .setId(Integer.toString(treeID))
          .setPredicate(new True())
          .setModel(treeModel)
          .setWeight(1.0); // No weights in MLlib impl now
      segments.add(segment);
      treeID++;
    }
    ensemble.setSegmentation(new Segmentation(combiner, segments));
  } else {
    pmmlModel = toTreeModel(forest[0], categoricalValueEncodings, nodeIDCounts.get(0));
  }

  MiningFunction miningFunction;
  if (isClassification) {
    miningFunction = MiningFunction.CLASSIFICATION;
  } else {
    miningFunction = MiningFunction.REGRESSION;
  }
  pmmlModel.setMiningFunction(miningFunction);

  double[] importances = countsToImportances(predictorIndexCounts);
  pmmlModel.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances));
  DataDictionary dictionary =
      AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

  PMML document = PMMLUtils.buildSkeletonPMML();
  document.setDataDictionary(dictionary);
  document.addModels(pmmlModel);

  AppPMMLUtils.addExtension(document, "maxDepth", maxDepth);
  AppPMMLUtils.addExtension(document, "maxSplitCandidates", maxSplitCandidates);
  AppPMMLUtils.addExtension(document, "impurity", impurity);

  return document;
}
开发者ID:oncewang,项目名称:oryx2,代码行数:55,代码来源:RDFUpdate.java
示例11: MLDecisionTreeModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Creates a wrapper around the given decision tree model.
 *
 * @param model the model to wrap; stored as-is
 */
public MLDecisionTreeModel(DecisionTreeModel model) {
this.model = model;
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:4,代码来源:MLDecisionTreeModel.java
示例12: readExternal
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Restores the wrapped model by reading a single object from the stream.
 *
 * @param in stream to read the model from
 * @throws IOException if reading from the stream fails
 * @throws ClassNotFoundException if the class of the stored object cannot be found
 */
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
    Object restored = in.readObject();
    this.model = (DecisionTreeModel) restored;
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:6,代码来源:MLDecisionTreeModel.java
示例13: getModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Returns the wrapped decision tree model.
 *
 * @return the model currently held by this wrapper
 */
public DecisionTreeModel getModel() {
return model;
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:4,代码来源:MLDecisionTreeModel.java
示例14: setModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Replaces the wrapped decision tree model.
 *
 * @param model the model to store
 */
public void setModel(DecisionTreeModel model) {
this.model = model;
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:4,代码来源:MLDecisionTreeModel.java
示例15: buildDecisionTreeModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * This method builds a decision tree model, evaluates it on the testing data and returns a
 * summary of the result. The trained model is also stored on {@code mlModel}.
 *
 * @param sparkContext JavaSparkContext initialized with the application
 * @param modelID Model ID (not read by this method)
 * @param trainingData Training data as a JavaRDD of LabeledPoints; unpersisted after training
 * @param testingData Testing data as a JavaRDD of LabeledPoints; cached during evaluation
 * @param workflow Machine learning workflow; supplies hyper parameters and dataset version
 * @param mlModel Deployable machine learning model that receives the trained tree
 * @param includedFeatures Features whose names (the map values) are recorded in the summary
 * @param categoricalFeatureInfo Categorical feature description passed to training
 * @return Summary holding the features, algorithm name, confusion matrix, accuracy and
 *         dataset version
 * @throws MLModelBuilderException If any step fails; the original exception is kept as cause
 */
private ModelSummary buildDecisionTreeModel(JavaSparkContext sparkContext, long modelID,
JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow, MLModel mlModel,
SortedMap<Integer, String> includedFeatures, Map<Integer, Integer> categoricalFeatureInfo)
throws MLModelBuilderException {
try {
Map<String, String> hyperParameters = workflow.getHyperParameters();
DecisionTree decisionTree = new DecisionTree();
// MAX_DEPTH and MAX_BINS are parsed as integers; a missing or non-numeric value ends up in
// the catch block below.
DecisionTreeModel decisionTreeModel = decisionTree.train(trainingData, getNoOfClasses(mlModel),
categoricalFeatureInfo, hyperParameters.get(MLConstants.IMPURITY),
Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));
// remove from cache
trainingData.unpersist();
// add test data to cache
testingData.cache();
JavaPairRDD<Double, Double> predictionsAndLabels = decisionTree.test(decisionTreeModel, testingData)
.cache();
ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary = SparkModelUtils
.getClassClassificationModelSummary(sparkContext, testingData, predictionsAndLabels);
// remove from cache
testingData.unpersist();
mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));
classClassificationAndRegressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
classClassificationAndRegressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.DECISION_TREE.toString());
// The cached predictions are still needed for the metrics and are released right after.
MulticlassMetrics multiclassMetrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
predictionsAndLabels.unpersist();
classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(
multiclassMetrics, mlModel));
Double modelAccuracy = getModelAccuracy(multiclassMetrics);
classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());
return classClassificationAndRegressionModelSummary;
} catch (Exception e) {
// NOTE(review): when a step throws after cache() was called, the cached RDDs are not
// unpersisted on this path - confirm whether that matters for long-running servers.
throw new MLModelBuilderException(
"An error occurred while building decision tree model: " + e.getMessage(), e);
}
}
开发者ID:wso2-attic,项目名称:carbon-ml,代码行数:59,代码来源:SupervisedSparkModelBuilder.java
示例16: loadModel
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Loads a previously saved decision tree model.
 *
 * @param sc Spark context used for loading
 * @param modelPath location the model was saved to
 * @return the loaded model
 */
@Override
protected DecisionTreeModel loadModel(SparkContext sc, String modelPath) {
    DecisionTreeModel restored = DecisionTreeModel.load(sc, modelPath);
    return restored;
}
开发者ID:IBMStreams,项目名称:streamsx.sparkMLLib,代码行数:5,代码来源:SparkDecisionTree.java
示例17: AnchoredPredictor
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Creates a predictor backed by the decision tree model stored at the classpath resource
 * {@code /anchoredOrMooredModel}.
 *
 * @param sc Spark context used to load the model
 * @throws NullPointerException if the model resource is not on the classpath
 */
public AnchoredPredictor(JavaSparkContext sc) {
    // Fail with a descriptive message instead of an anonymous NullPointerException when the
    // resource is missing; the exception type is unchanged.
    String dataPath = java.util.Objects.requireNonNull(
            AnchoredPredictor.class.getResource("/anchoredOrMooredModel"),
            "classpath resource /anchoredOrMooredModel not found").toString();
    model = DecisionTreeModel.load(sc.sc(), dataPath);
}
开发者ID:amsa-code,项目名称:risky,代码行数:5,代码来源:AnchoredPredictor.java
示例18: main
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a decision tree classifier on the LIBSVM file of fixes, reports its error on a 30%
 * hold-out split, saves and reloads the model, and writes a readable dump of the tree to
 * {@code target/model.txt}.
 *
 * <p>Both the Spark context and the output stream are opened in try-with-resources blocks so
 * they are released even when a step throws. The report is written as UTF-8.
 *
 * @param args ignored
 * @throws IOException if the old model directory cannot be deleted or the report cannot be
 *     written
 */
public static void main(String[] args) throws IOException {
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    // just run this locally
    sparkConf.setMaster("local[" + Runtime.getRuntime().availableProcessors() + "]");

    try (JavaSparkContext sc = new JavaSparkContext(sparkConf)) {
        // Load and parse the data file.
        String datapath = "/media/an/fixes.libsvm";

        // the feature names are substituted into the model debugString later to
        // make it readable
        List<String> names = Arrays.asList("lat", "lon", "speedKnots", "courseHeadingDiff",
                "preEffectiveSpeedKnots", "preError", "postEffectiveSpeedKnots", "postError");
        List<String> classifications = Arrays.asList("other", "moored", "anchored");

        JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

        // Split the data into training and test sets (30% held out for testing)
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[] { 0.7, 0.3 });
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set parameters.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        int numClassifications = classifications.size();
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        int maxDepth = 8;
        int maxBins = 32;

        // Train a DecisionTree model for classification.
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
                numClassifications, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Evaluate model on test instances and compute test error
        double testErr = (double) testData
                // pair up actual and predicted classification numerical representation
                .map(toPredictionAndActual(model))
                // get the ones that don't match
                .filter(predictionWrong())
                // count them
                .count()
                // divide by total count to get ratio failing test
                / testData.count();

        // Save and load model to demo possible usage in prediction mode
        String modelPath = "target/myModelPath";
        FileUtils.deleteDirectory(new File(modelPath));
        model.save(sc.sc(), modelPath);
        // Loaded only to demonstrate the round trip; the reloaded model is not used further.
        DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), modelPath);

        System.out.println("Test Error: " + testErr);
        String s = useNames(model.toDebugString(), names, classifications);
        System.out.println("Learned classification tree model:\n" + s);

        // The stream is closed even if one of the writes fails.
        try (FileOutputStream fos = new FileOutputStream("target/model.txt")) {
            fos.write(("Test Error: " + testErr + "\n")
                    .getBytes(java.nio.charset.StandardCharsets.UTF_8));
            fos.write(s.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        }
    }
}
开发者ID:amsa-code,项目名称:risky,代码行数:64,代码来源:AnchoredTrainerMain.java
示例19: toPredictionAndActual
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Builds a function that pairs the model's prediction for a labeled point with the point's
 * own label.
 *
 * @param model model used to predict from each point's features
 * @return function mapping a labeled point to its prediction/actual pair
 */
private static Function<LabeledPoint, PredictionAndActual> toPredictionAndActual(
        final DecisionTreeModel model) {
    return point -> {
        double predicted = model.predict(point.features());
        return new PredictionAndActual(predicted, point.label());
    };
}
开发者ID:amsa-code,项目名称:risky,代码行数:5,代码来源:AnchoredTrainerMain.java
示例20: trainInternal
import org.apache.spark.mllib.tree.model.DecisionTreeModel; //导入依赖的package包/类
/**
 * Trains a decision tree on the given RDD using this algorithm's configured algo, impurity and
 * maximum depth, and wraps the result as a classification model.
 *
 * @param modelId identifier given to the returned model
 * @param trainingRDD labeled points to train on
 * @return classification model wrapping the trained tree
 * @throws LensException declared by the overridden method
 */
@Override
protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
    throws LensException {
    DecisionTreeModel trainedTree =
        DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
    SparkDecisionTreeModel wrappedTree = new SparkDecisionTreeModel(trainedTree);
    return new DecisionTreeClassificationModel(modelId, wrappedTree);
}
开发者ID:apache,项目名称:lens,代码行数:7,代码来源:DecisionTreeAlgo.java
注:本文中的org.apache.spark.mllib.tree.model.DecisionTreeModel类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论