
Java PipelineModel Class Code Examples


This article collects typical usage examples of the Java class org.apache.spark.ml.PipelineModel. If you are wondering what the PipelineModel class does and how to use it, the curated examples below may help.



The PipelineModel class belongs to the org.apache.spark.ml package. Twenty code examples are shown below, sorted by popularity by default. You can upvote the examples you find useful; your feedback helps recommend better Java code samples.
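Before diving into the examples, it helps to recall what a PipelineModel is: the object returned by Pipeline.fit(), holding one fitted Transformer per pipeline stage, in order. The following minimal sketch (class name and column names are illustrative; it assumes Spark 2.x is on the classpath) shows the typical fit/stages/transform round trip that the examples below build on:

```java
import java.util.Arrays;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class PipelineModelSketch {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .master("local[*]")
                .appName("PipelineModelSketch")
                .getOrCreate();

        // A tiny labeled text dataset, just enough to fit the pipeline.
        StructType schema = DataTypes.createStructType(new StructField[]{
                DataTypes.createStructField("text", DataTypes.StringType, false),
                DataTypes.createStructField("label", DataTypes.DoubleType, false)
        });
        Dataset<Row> training = spark.createDataFrame(Arrays.asList(
                RowFactory.create("a b c d e spark", 1.0),
                RowFactory.create("b d", 0.0),
                RowFactory.create("spark f g h", 1.0),
                RowFactory.create("hadoop mapreduce", 0.0)
        ), schema);

        // An estimator pipeline: tokenize -> hash term frequencies -> logistic regression.
        Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
        HashingTF hashingTF = new HashingTF()
                .setInputCol("words").setOutputCol("features").setNumFeatures(1000);
        LogisticRegression lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01);
        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

        // fit() returns a PipelineModel: the fitted counterpart of each stage, in order.
        PipelineModel model = pipeline.fit(training);
        for (Transformer stage : model.stages()) {
            System.out.println(stage.getClass().getSimpleName());
        }

        // The fitted model transforms new data end to end.
        model.transform(training).select("text", "prediction").show();

        spark.stop();
    }
}
```

Most examples below follow this shape, then either walk model.stages() to convert each fitted stage, or serialize the whole PipelineModel.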

Example 1: registerFeatures

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Override
public void registerFeatures(SparkMLEncoder encoder){
	RFormulaModel transformer = getTransformer();

	ResolvedRFormula resolvedFormula = transformer.resolvedFormula();

	String targetCol = resolvedFormula.label();

	String labelCol = transformer.getLabelCol();
	if(!(targetCol).equals(labelCol)){
		List<Feature> features = encoder.getFeatures(targetCol);

		encoder.putFeatures(labelCol, features);
	}

	PipelineModel pipelineModel = transformer.pipelineModel();

	Transformer[] stages = pipelineModel.stages();
	for(Transformer stage : stages){
		FeatureConverter<?> featureConverter = ConverterUtil.createFeatureConverter(stage);

		featureConverter.registerFeatures(encoder);
	}
}
 
Developer: jpmml | Project: jpmml-sparkml | Lines: 25 | Source: RFormulaModelConverter.java


Example 2: toPMMLByteArray

import org.apache.spark.ml.PipelineModel; // import the required package/class
static
public byte[] toPMMLByteArray(StructType schema, PipelineModel pipelineModel){
	PMML pmml = toPMML(schema, pipelineModel);

	ByteArrayOutputStream os = new ByteArrayOutputStream(1024 * 1024);

	try {
		MetroJAXBUtil.marshalPMML(pmml, os);
	} catch(JAXBException je){
		throw new RuntimeException(je);
	}

	return os.toByteArray();
}
 
Developer: jpmml | Project: jpmml-sparkml | Lines: 15 | Source: ConverterUtil.java


Example 3: run

import org.apache.spark.ml.PipelineModel; // import the required package/class
private void run() throws Exception {
	SparkConf sparkConf = new SparkConf();

	try(JavaSparkContext sparkContext = new JavaSparkContext(sparkConf)){
		SQLContext sqlContext = new SQLContext(sparkContext);

		DataFrameReader reader = sqlContext.read()
			.format("com.databricks.spark.csv")
			.option("header", "true")
			.option("inferSchema", "true");

		DataFrame dataFrame = reader.load(this.csvInput.getAbsolutePath());

		StructType schema = dataFrame.schema();
		System.out.println(schema.treeString());

		Pipeline pipeline = createPipeline(this.function, this.formula);

		PipelineModel pipelineModel = pipeline.fit(dataFrame);

		PMML pmml = ConverterUtil.toPMML(schema, pipelineModel);

		try(OutputStream os = new FileOutputStream(this.pmmlOutput.getAbsolutePath())){
			MetroJAXBUtil.marshalPMML(pmml, os);
		}
	}
}
 
Developer: jpmml | Project: jpmml-sparkml-bootstrap | Lines: 28 | Source: Main.java


Example 4: CMMModel

import org.apache.spark.ml.PipelineModel; // import the required package/class
/**
 * Creates a conditional Markov model.
 * @param pipelineModel a fitted preprocessing pipeline
 * @param weights the weight vector of the model
 * @param markovOrder the Markov order of the model
 * @param tagDictionary a map from words to the set of their possible tag indices
 */
public CMMModel(PipelineModel pipelineModel, Vector weights, MarkovOrder markovOrder, Map<String, Set<Integer>> tagDictionary) {
	this.pipelineModel = pipelineModel;
	this.contextExtractor = new ContextExtractor(markovOrder, Constants.REGEXP_FILE);
	this.weights = weights;
	this.tags = ((StringIndexerModel)(pipelineModel.stages()[2])).labels();
	String[] features = ((CountVectorizerModel)(pipelineModel.stages()[1])).vocabulary();
	featureMap = new HashMap<String, Integer>();
	for (int j = 0; j < features.length; j++) {
		featureMap.put(features[j], j);
	}
	this.tagDictionary = tagDictionary;
}
 
Developer: phuonglh | Project: vn.vitk | Lines: 19 | Source: CMMModel.java


Example 5: load

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Override
public CMMModel load(String path) {
	org.apache.spark.ml.util.DefaultParamsReader.Metadata metadata = DefaultParamsReader.loadMetadata(path, sc(), CMMModel.class.getName());
	String pipelinePath = new Path(path, "pipelineModel").toString();
	PipelineModel pipelineModel = PipelineModel.load(pipelinePath);
	String dataPath = new Path(path, "data").toString();
	DataFrame df = sqlContext().read().format("parquet").load(dataPath);
	Row row = df.select("markovOrder", "weights", "tagDictionary").head();
	// load the Markov order
	MarkovOrder order = MarkovOrder.values()[row.getInt(0)-1];
	// load the weight vector
	Vector w = row.getAs(1);
	// load the tag dictionary
	@SuppressWarnings("unchecked")
	scala.collection.immutable.HashMap<String, WrappedArray<Integer>> td = (scala.collection.immutable.HashMap<String, WrappedArray<Integer>>)row.get(2);
	Map<String, Set<Integer>> tagDict = new HashMap<String, Set<Integer>>();
	Iterator<Tuple2<String, WrappedArray<Integer>>> iterator = td.iterator();
	while (iterator.hasNext()) {
		Tuple2<String, WrappedArray<Integer>> tuple = iterator.next();
		Set<Integer> labels = new HashSet<Integer>();
		scala.collection.immutable.List<Integer> list = tuple._2().toList();
		for (int i = 0; i < list.size(); i++)
			labels.add(list.apply(i));
		tagDict.put(tuple._1(), labels);
	}
	// build a CMM model
	CMMModel model = new CMMModel(pipelineModel, w, order, tagDict);
	DefaultParamsReader.getAndSetParams(model, metadata);
	return model;
}
 
Developer: phuonglh | Project: vn.vitk | Lines: 31 | Source: CMMModel.java


Example 6: TransitionBasedParserMLP

import org.apache.spark.ml.PipelineModel; // import the required package/class
/**
 * Creates a transition-based parser using an MLP transition classifier.
 * @param jsc the Java Spark context
 * @param classifierFileName the path to a saved transition classifier
 * @param featureFrame the feature frame
 */
public TransitionBasedParserMLP(JavaSparkContext jsc, String classifierFileName, FeatureFrame featureFrame) {
	this.featureFrame = featureFrame;
	this.classifier = TransitionClassifier.load(jsc, new Path(classifierFileName, "data").toString());
	this.pipelineModel = PipelineModel.load(new Path(classifierFileName, "pipelineModel").toString());
	this.transitionName = ((StringIndexerModel)pipelineModel.stages()[2]).labels();
	String[] features = ((CountVectorizerModel)(pipelineModel.stages()[1])).vocabulary();
	this.featureMap = new HashMap<String, Integer>();
	for (int j = 0; j < features.length; j++) {
		this.featureMap.put(features[j], j);
	}
	
}
 
Developer: phuonglh | Project: vn.vitk | Lines: 19 | Source: TransitionBasedParserMLP.java


Example 7: getModelInfo

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Override
public PipelineModelInfo getModelInfo(final PipelineModel from) {
    final PipelineModelInfo modelInfo = new PipelineModelInfo();
    final ModelInfo stages[] = new ModelInfo[from.stages().length];
    for (int i = 0; i < from.stages().length; i++) {
        Transformer sparkModel = from.stages()[i];
        stages[i] = ModelInfoAdapterFactory.getAdapter(sparkModel.getClass()).adapt(sparkModel);
    }
    modelInfo.setStages(stages);
    return modelInfo;
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 12 | Source: PipelineModelInfoAdapter.java


Example 8: getModelInfo

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Override
public PipelineModelInfo getModelInfo(final PipelineModel from, final DataFrame df) {
    final PipelineModelInfo modelInfo = new PipelineModelInfo();
    final ModelInfo stages[] = new ModelInfo[from.stages().length];
    for (int i = 0; i < from.stages().length; i++) {
        Transformer sparkModel = from.stages()[i];
        stages[i] = ModelInfoAdapterFactory.getAdapter(sparkModel.getClass()).adapt(sparkModel, df);
    }
    modelInfo.setStages(stages);
    return modelInfo;
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 12 | Source: PipelineModelInfoAdapter.java


Example 9: build

import org.apache.spark.ml.PipelineModel; // import the required package/class
public Transformer build(){
	Evaluator evaluator = getEvaluator();

	PMMLTransformer pmmlTransformer = new PMMLTransformer(evaluator, this.columnProducers);

	if(this.exploded){
		ColumnExploder columnExploder = new ColumnExploder(pmmlTransformer.getOutputCol());

		ColumnPruner columnPruner = new ColumnPruner(ScalaUtil.singletonSet(pmmlTransformer.getOutputCol()));

		PipelineModel pipelineModel = new PipelineModel(null, new Transformer[]{pmmlTransformer, columnExploder, columnPruner});

		return pipelineModel;
	}

	return pmmlTransformer;
}
 
Developer: jeremyore | Project: spark-pmml-import | Lines: 18 | Source: TransformerBuilder.java


Example 10: getSource

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Override
public Class<PipelineModel> getSource() {
    return PipelineModel.class;
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 5 | Source: PipelineModelInfoAdapter.java


Example 11: testDecisionTreeRegressionPrediction

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testDecisionTreeRegressionPrediction() {
	// Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/regression_test.libsvm";

	Dataset<Row> data = spark.read().format("libsvm").load(datapath);

	// Split the data into training and test sets (30% held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];

	StringIndexer indexer = new StringIndexer()
			.setInputCol("label")
			.setOutputCol("labelIndex")
			.setHandleInvalid("skip");

	DecisionTreeRegressor regressionModel = new DecisionTreeRegressor()
			.setLabelCol("labelIndex")
			.setFeaturesCol("features");

	Pipeline pipeline = new Pipeline()
			.setStages(new PipelineStage[]{indexer, regressionModel});

	PipelineModel sparkPipeline = pipeline.fit(trainingData);

	byte[] exportedModel = ModelExporter.export(sparkPipeline);

	Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
	List<Row> output = sparkPipeline.transform(testData).select("features", "prediction", "label").collectAsList();

	// compare predictions
	for (Row row : output) {
		Map<String, Object> data_ = new HashMap<>();
		data_.put("features", ((SparseVector) row.get(0)).toArray());
		data_.put("label", (row.get(2)).toString());
		transformer.transform(data_);
		System.out.println(data_);
		System.out.println(data_.get("prediction"));
		assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);
	}
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 42 | Source: DecisionTreeRegressionModelBridgePipelineTest.java


Example 12: testGradientBoostClassification

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testGradientBoostClassification() {
	// Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/binary_classification_test.libsvm";

	Dataset<Row> data = spark.read().format("libsvm").load(datapath);
	StringIndexer indexer = new StringIndexer()
               .setInputCol("label")
               .setOutputCol("labelIndex");
	// Split the data into training and test sets (30% held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];

	// Train a gradient-boosted tree model.
	GBTClassifier classificationModel = new GBTClassifier()
			.setLabelCol("labelIndex")
			.setFeaturesCol("features");

	Pipeline pipeline = new Pipeline()
			.setStages(new PipelineStage[]{indexer, classificationModel});

	PipelineModel sparkPipeline = pipeline.fit(trainingData);

	// Export this model
	byte[] exportedModel = ModelExporter.export(sparkPipeline);

	// Import and get Transformer
	Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

	List<Row> sparkOutput = sparkPipeline.transform(testData).select("features", "prediction", "label").collectAsList();
	
	// compare predictions
	for (Row row : sparkOutput) {
		Map<String, Object> data_ = new HashMap<>();
		data_.put("features", ((SparseVector) row.get(0)).toArray());
		data_.put("label", (row.get(2)).toString());
		transformer.transform(data_);
		System.out.println(data_);
		System.out.println(data_.get("prediction")+" ,"+row.get(1));
		assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);
	}

}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 45 | Source: GradientBoostClassificationModelPipelineTest.java


Example 13: testDecisionTreeClassificationWithPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testDecisionTreeClassificationWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    String datapath = "src/test/resources/classification_test.libsvm";
    Dataset<Row> data = spark.read().format("libsvm").load(datapath);

    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});

    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");

    // Train a DecisionTree model.
    DecisionTreeClassifier classificationModel = new DecisionTreeClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classificationModel});


    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    List<Row> output = sparkPipeline.transform(testData).select("features", "label","prediction","rawPrediction").collectAsList();

    //compare predictions
    for (Row row : output) {
    	Map<String, Object> data_ = new HashMap<>();
    	double [] actualRawPrediction = ((DenseVector) row.get(3)).toArray();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        data_.put("label", (row.get(1)).toString());
        transformer.transform(data_);
        System.out.println(data_);
        System.out.println(data_.get("prediction"));
        assertEquals((double)data_.get("prediction"), (double)row.get(2), EPSILON);
        assertArrayEquals((double[]) data_.get("rawPrediction"), actualRawPrediction, EPSILON);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 54 | Source: DecisionTreeClassificationModelBridgePipelineTest.java


Example 14: testPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testPipeline() {
    // Prepare training documents, which are labeled.
    StructType schema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
            createStructField("label", DoubleType, false)
    });
    Dataset<Row> trainingData = spark.createDataFrame(Arrays.asList(
            cr(0L, "a b c d e spark", 1.0),
            cr(1L, "b d", 0.0),
            cr(2L, "spark f g h", 1.0),
            cr(3L, "hadoop mapreduce", 0.0)
    ), schema);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and LogisticRegression.
    RegexTokenizer tokenizer = new RegexTokenizer()
            .setInputCol("text")
            .setOutputCol("words")
            .setPattern("\\s")
            .setGaps(true)
            .setToLowercase(false);

    HashingTF hashingTF = new HashingTF()
            .setNumFeatures(1000)
            .setInputCol(tokenizer.getOutputCol())
            .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel sparkPipelineModel = pipeline.fit(trainingData);


    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipelineModel);
    System.out.println(new String(exportedModel));

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //prepare test data
    StructType testSchema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
    });
    Dataset<Row> testData = spark.createDataFrame(Arrays.asList(
            cr(4L, "spark i j k"),
            cr(5L, "l m n"),
            cr(6L, "mapreduce spark"),
            cr(7L, "apache hadoop")
    ), testSchema);

    //verify that predictions for spark pipeline and exported pipeline are the same
    List<Row> predictions = sparkPipelineModel.transform(testData).select("id", "text", "probability", "prediction").collectAsList();
    for (Row r : predictions) {
        System.out.println(r);
        double sparkPipelineOp = r.getDouble(3);
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("text", r.getString(1));
        transformer.transform(data);
        double exportedPipelineOp = (double) data.get("prediction");
        double exportedPipelineProb = (double) data.get("probability");
        assertEquals(sparkPipelineOp, exportedPipelineOp, 0.01);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 70 | Source: PipelineBridgeTest.java


Example 15: testRandomForestRegressionWithPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testRandomForestRegressionWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestRegressionModel regressionModel = new RandomForestRegressor()
            .setFeaturesCol("features").fit(trainingData);

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{regressionModel});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 42 | Source: RandomForestRegressionModelInfoAdapterBridgeTest.java


Example 16: testDecisionTreeRegressionWithPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testDecisionTreeRegressionWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeRegressor dt = new DecisionTreeRegressor()
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{dt});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 42 | Source: DecisionTreeRegressionModelBridgeTest.java


Example 17: testPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testPipeline() {
    // Prepare training documents, which are labeled.
    StructType schema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
            createStructField("label", DoubleType, false)
    });
    DataFrame trainingData = sqlContext.createDataFrame(Arrays.asList(
            cr(0L, "a b c d e spark", 1.0),
            cr(1L, "b d", 0.0),
            cr(2L, "spark f g h", 1.0),
            cr(3L, "hadoop mapreduce", 0.0)
    ), schema);

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and LogisticRegression.
    RegexTokenizer tokenizer = new RegexTokenizer()
            .setInputCol("text")
            .setOutputCol("words")
            .setPattern("\\s")
            .setGaps(true)
            .setToLowercase(false);

    HashingTF hashingTF = new HashingTF()
            .setNumFeatures(1000)
            .setInputCol(tokenizer.getOutputCol())
            .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

    // Fit the pipeline to training documents.
    PipelineModel sparkPipelineModel = pipeline.fit(trainingData);


    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipelineModel, trainingData);
    System.out.println(new String(exportedModel));

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //prepare test data
    StructType testSchema = createStructType(new StructField[]{
            createStructField("id", LongType, false),
            createStructField("text", StringType, false),
    });
    DataFrame testData = sqlContext.createDataFrame(Arrays.asList(
            cr(4L, "spark i j k"),
            cr(5L, "l m n"),
            cr(6L, "mapreduce spark"),
            cr(7L, "apache hadoop")
    ), testSchema);

    //verify that predictions for spark pipeline and exported pipeline are the same
    Row[] predictions = sparkPipelineModel.transform(testData).select("id", "text", "probability", "prediction").collect();
    for (Row r : predictions) {
        System.out.println(r);
        double sparkPipelineOp = r.getDouble(3);
        Map<String, Object> data = new HashMap<String, Object>();
        data.put("text", r.getString(1));
        transformer.transform(data);
        double exportedPipelineOp = (double) data.get("prediction");
        double exportedPipelineProb = (double) data.get("probability");
        assertEquals(sparkPipelineOp, exportedPipelineOp, EPSILON);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 70 | Source: PipelineBridgeTest.java


Example 18: testRandomForestClassificationWithPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testRandomForestClassificationWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");

    // Train a RandomForest model.
    RandomForestClassifier classifier = new RandomForestClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features")
            .setPredictionCol("prediction")
            .setRawPredictionCol("rawPrediction")
            .setProbabilityCol("probability");


    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classifier});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("label", "features", "prediction", "rawPrediction", "probability").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(1);
        double actual = row.getDouble(2);
        double [] actualProbability = ((Vector) row.get(4)).toArray();
        double[] actualRaw = ((Vector) row.get(3)).toArray();

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put("features", v.toArray());
        inputData.put("label", row.get(0).toString());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");
        double[] probability = (double[]) inputData.get("probability");
        double[] rawPrediction = (double[]) inputData.get("rawPrediction");

        assertEquals(actual, predicted, EPSILON);
        assertArrayEquals(actualProbability, probability, EPSILON);
        assertArrayEquals(actualRaw, rawPrediction, EPSILON);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 58 | Source: RandomForestClassificationModelInfoAdapterBridgeTest.java


Example 19: testDecisionTreeClassificationWithPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void testDecisionTreeClassificationWithPipeline() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    StringIndexer indexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex");

    // Train a DecisionTree model.
    DecisionTreeClassifier classificationModel = new DecisionTreeClassifier()
            .setLabelCol("labelIndex")
            .setFeaturesCol("features");

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classificationModel});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline = pipeline.fit(trainingData);

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline, null);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = sparkPipeline.transform(testData).select("label", "features", "prediction").collect();

    //compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(1);
        double actual = row.getDouble(2);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put("features", v.toArray());
        inputData.put("label", row.get(0).toString());
        transformer.transform(inputData);
        double predicted = (double) inputData.get("prediction");

        assertEquals(actual, predicted, EPSILON);
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 48 | Source: DecisionTreeClassificationModelBridgeTest.java


Example 20: shouldWorkCorrectlyWithPipeline

import org.apache.spark.ml.PipelineModel; // import the required package/class
@Test
public void shouldWorkCorrectlyWithPipeline() {

    //Prepare test data
    DataFrame df = getDataFrame();
    Row[] originalData = df.orderBy("id").select("id", "a", "b", "c", "d").collect();

    //prepare transformation pipeline
    FillNAValuesTransformer fillNAValuesTransformer = new FillNAValuesTransformer();
    fillNAValuesTransformer.setNAValueMap( getFillNAMap() );
    Pipeline pipeline = new Pipeline();
    pipeline.setStages(new PipelineStage[]{fillNAValuesTransformer});
    PipelineModel model = pipeline.fit(df);

    //predict
    Row[] sparkOutput = model.transform(df).orderBy("id").select("id", "a", "b", "c", "d").collect();

    //export
    byte[] exportedModel = ModelExporter.export(model, df);
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    //verify correctness
    assertTrue(transformer.getInputKeys().size() == 4);
    assertTrue(transformer.getInputKeys().containsAll(Arrays.asList("a", "b", "c", "d")));
    assertTrue(transformer.getOutputKeys().size() == 4);
    assertTrue(transformer.getOutputKeys().containsAll(Arrays.asList("a", "b", "c", "d")));
    for( int i=0; i < originalData.length; i++) {
        Map<String, Object> input = new HashMap<String, Object>();
        input.put("a", originalData[i].get(1));
        input.put("b", originalData[i].get(2));
        input.put("c", originalData[i].get(3));
        input.put("d", originalData[i].get(4));

        transformer.transform(input);

        assertEquals(sparkOutput[i].get(1), input.get("a"));
        assertEquals(sparkOutput[i].get(2), input.get("b"));
        assertEquals(sparkOutput[i].get(3), input.get("c"));
        assertEquals(sparkOutput[i].get(4), input.get("d"));
    }
}
 
Developer: flipkart-incubator | Project: spark-transformers | Lines: 42 | Source: FillNAValuesTransformerBridgeTest.java



Note: The org.apache.spark.ml.PipelineModel examples above were collected from open-source projects on GitHub, MSDocs, and other source-code and documentation platforms. The copyright of each snippet belongs to its original author; consult the corresponding project's license before distributing or reusing the code. Please do not reproduce this article without permission.

