Java源码示例:org.dmg.pmml.regression.RegressionTable

示例1
/**
 * Reads a logistic regression model from a PMML file.
 *
 * The first model of the document is cast to {@code RegressionModel} and its first
 * regression table is used: the coefficients of the table's numeric predictors, taken in
 * list order, form the weight vector, and the table's intercept becomes the model intercept.
 *
 * NOTE(review): weights are matched to features purely by position, not by field name;
 * confirm the exporter always writes predictors in feature order.
 *
 * @param path Filesystem path of the PMML document.
 * @return The reconstructed model, or {@code null} when reading or parsing fails with an
 *     {@code IOException}, {@code JAXBException} or {@code SAXException} (the stack trace is
 *     printed to standard error). Other failures, such as a non-regression first model,
 *     propagate to the caller.
 */
public static LogisticRegressionModel load(String path) {
    try (InputStream in = new FileInputStream(path)) {
        PMML document = PMMLUtil.unmarshal(in);

        RegressionModel regressionMdl = (RegressionModel)document.getModels().get(0);
        RegressionTable table = regressionMdl.getRegressionTables().get(0);

        int featureCnt = table.getNumericPredictors().size();
        Vector weights = new DenseVector(featureCnt);

        for (int idx = 0; idx < featureCnt; idx++)
            weights.set(idx, table.getNumericPredictors().get(idx).getCoefficient());

        double intercept = table.getIntercept();

        return new LogisticRegressionModel(weights, intercept);
    }
    catch (IOException | JAXBException | SAXException e) {
        e.printStackTrace();

        return null;
    }
}
 
示例2
/**
 * Marshals a minimal PMML 4.4 document that contains one empty regression model and checks
 * the namespace, the version attribute and the model element in the produced XML.
 */
@Test
public void marshal() throws Exception {
	RegressionModel model = new RegressionModel();
	model.addRegressionTables(new RegressionTable());

	PMML pmml = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());
	pmml.addModels(model);

	// Both the core and the regression object factories must be registered with the context
	Class<?>[] objectFactories = {org.dmg.pmml.ObjectFactory.class, org.dmg.pmml.regression.ObjectFactory.class};

	Marshaller marshaller = JAXBContextFactory.createContext(objectFactories, null)
		.createMarshaller();

	String xml;

	try(ByteArrayOutputStream buffer = new ByteArrayOutputStream()){
		marshaller.marshal(pmml, buffer);

		xml = buffer.toString("UTF-8");
	}

	String[] expectedFragments = {
		"<PMML xmlns=\"http://www.dmg.org/PMML-4_4\"",
		" version=\"4.4\">",
		"<RegressionModel>",
		"</RegressionModel>",
		"</PMML>"
	};

	for(String fragment : expectedFragments){
		assertTrue(xml.contains(fragment));
	}
}
 
示例3
/**
 * Tells whether a regression table contributes nothing.
 *
 * A table qualifies when it has no numeric predictors, no categorical predictors and no
 * predictor terms, and its intercept is either absent ({@code null}) or zero.
 */
static
private boolean isDefault(RegressionTable regressionTable){
	boolean hasAnyPredictor = regressionTable.hasNumericPredictors()
		|| regressionTable.hasCategoricalPredictors()
		|| regressionTable.hasPredictorTerms();

	if(hasAnyPredictor){
		return false;
	}

	Number intercept = regressionTable.getIntercept();

	// A NaN intercept is not equal to zero, so such a table is not treated as default
	return (intercept == null) || (intercept.doubleValue() == 0d);
}
 
示例4
/**
 * Encodes the estimator as a classification {@code RegressionModel} over an integer target
 * field named {@code _target}, with softmax normalization and a probability output.
 *
 * How the regression tables from {@code encodeRegressionModel} are labelled depends on
 * their count:
 * <ul>
 *   <li>One table is taken as the score of category "1". A second table, built with
 *       {@code new RegressionTable(0)} (presumably a zero intercept and no predictors), is
 *       appended for category "0", so that softmax over the two behaves like a logistic
 *       function.</li>
 *   <li>More than two tables are labelled "0", "1", ... in list order.</li>
 * </ul>
 *
 * @param encoder the encoder that owns the data dictionary
 * @return the configured regression model
 * @throws IllegalArgumentException if the model has neither exactly one nor more than two
 *     regression tables
 */
@Override
public RegressionModel encodeModel(TensorFlowEncoder encoder){
	DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CATEGORICAL, DataType.INTEGER);

	RegressionModel regressionModel = encodeRegressionModel(encoder);

	List<RegressionTable> regressionTables = regressionModel.getRegressionTables();

	// Captured before the binary branch appends the passive table
	int tableCount = regressionTables.size();

	List<String> categories;

	if(tableCount == 1){
		categories = Arrays.asList("0", "1");

		// The single encoded table scores the positive category
		regressionTables.get(0)
			.setTargetCategory(categories.get(1));

		RegressionTable passiveRegressionTable = new RegressionTable(0)
			.setTargetCategory(categories.get(0));

		regressionModel.addRegressionTables(passiveRegressionTable);
	} else

	if(tableCount > 2){
		categories = new ArrayList<>(tableCount);

		for(int i = 0; i < tableCount; i++){
			String category = String.valueOf(i);

			regressionTables.get(i).setTargetCategory(category);

			categories.add(category);
		}
	} else

	{
		throw new IllegalArgumentException("Expected exactly 1 or more than 2 regression tables, got " + tableCount);
	}

	dataField = encoder.toCategorical(dataField.getName(), categories);

	CategoricalLabel categoricalLabel = new CategoricalLabel(dataField);

	regressionModel
		.setMiningFunction(MiningFunction.CLASSIFICATION)
		.setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
		.setMiningSchema(ModelUtil.createMiningSchema(categoricalLabel))
		.setOutput(ModelUtil.createProbabilityOutput(DataType.FLOAT, categoricalLabel));

	return regressionModel;
}
 
示例5
/**
 * Combines one member model per target category into a classification model chain.
 *
 * For every category {@code i} of the label, a regression table is built whose only term
 * is the prediction of {@code models.get(i)} (obtained via {@code MODEL_PREDICTION}), with
 * coefficient 1 and a {@code null} intercept. The tables form a final classification
 * {@code RegressionModel}, which is appended after the member models. The resulting list is
 * passed to {@code createModelChain}.
 *
 * @param models member models, one per target category, in label order
 * @param normalizationMethod normalization of the final regression model; must be
 *     {@code null}, {@code NONE}, {@code SIMPLEMAX} or {@code SOFTMAX}
 * @param hasProbabilityDistribution whether the final model gets a probability output
 * @param schema schema whose label is expected to be a {@code CategoricalLabel}
 * @return the model chain built by {@code createModelChain}
 * @throws IllegalArgumentException in any of these cases:
 *     <ul>
 *       <li>the number of models differs from the number of target categories;</li>
 *       <li>the normalization method is unsupported;</li>
 *       <li>the member models do not share one math context (a missing math context counts
 *           as {@code DOUBLE}).</li>
 *     </ul>
 */
static
public MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema){
    CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

    if(categoricalLabel.size() != models.size()){
        throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " models (one per target category), got " + models.size());
    }

    if(normalizationMethod != null){

        switch(normalizationMethod){
            case NONE:
            case SIMPLEMAX:
            case SOFTMAX:
                break;
            default:
                throw new IllegalArgumentException("Unsupported normalization method: " + normalizationMethod);
        }
    }

    // Common math context of all member models; the final model and its output inherit it
    MathContext mathContext = null;

    List<RegressionTable> regressionTables = new ArrayList<>();

    for(int i = 0; i < categoricalLabel.size(); i++){
        Model model = models.get(i);

        MathContext modelMathContext = model.getMathContext();
        if(modelMathContext == null){
            modelMathContext = MathContext.DOUBLE;
        }

        if(mathContext == null){
            mathContext = modelMathContext;
        } else

        if(!Objects.equals(mathContext, modelMathContext)){
            throw new IllegalArgumentException("Expected all models to use math context " + mathContext + ", got " + modelMathContext + " for the model at index " + i);
        }

        Feature feature = MODEL_PREDICTION.apply(model);

        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(1d), null)
                .setTargetCategory(categoricalLabel.getValue(i));

        regressionTables.add(regressionTable);
    }

    RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
            .setNormalizationMethod(normalizationMethod)
            .setMathContext(ModelUtil.simplifyMathContext(mathContext))
            .setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);

    List<Model> segmentationModels = new ArrayList<>(models);
    segmentationModels.add(regressionModel);

    return createModelChain(segmentationModels, schema);
}
 
示例6
/**
 * Checks that FieldReferenceFinder collects the fields referenced by numeric predictors,
 * and that reset() empties the collected set.
 */
@Test
public void apply(){
	FieldName x1 = FieldName.create("x1");
	FieldName x2 = FieldName.create("x2");

	RegressionTable regressionTable = new RegressionTable()
		.addNumericPredictors(new NumericPredictor(x1, 1d));

	FieldReferenceFinder finder = new FieldReferenceFinder();

	// With a single predictor only its field is reported
	finder.applyTo(regressionTable);
	assertEquals(Collections.singleton(x1), finder.getFieldNames());

	finder.reset();
	assertEquals(Collections.emptySet(), finder.getFieldNames());

	// After adding a second predictor both fields are reported
	regressionTable.addNumericPredictors(new NumericPredictor(x2, -1d));

	finder.applyTo(regressionTable);
	assertEquals(new HashSet<>(Arrays.asList(x1, x2)), finder.getFieldNames());

	finder.reset();
	assertEquals(Collections.emptySet(), finder.getFieldNames());
}
 
示例7
/**
 * Rewrites the table's target category with the result of {@code parseTargetValue}, then
 * hands over to the default traversal.
 */
@Override
public VisitorAction visit(RegressionTable regressionTable){
	Object parsedTargetCategory = parseTargetValue(regressionTable.getTargetCategory());

	regressionTable.setTargetCategory(parsedTargetCategory);

	return super.visit(regressionTable);
}
 
示例8
/**
 * Evaluates a regression-function {@code RegressionModel}.
 *
 * @throws InvalidAttributeException if the model's {@code targetField} attribute is set and
 *     names a different field than the resolved target, or if the normalization method is
 *     {@code SIMPLEMAX}
 * @throws InvalidElementListException unless the model has exactly one regression table
 * @throws UnsupportedAttributeException for normalization methods not handled here
 */
@Override
protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context){
	RegressionModel model = getModel();
	TargetField target = getTargetField();

	FieldName declaredTargetName = model.getTargetField();
	boolean targetMismatch = (declaredTargetName != null) && !Objects.equals(target.getFieldName(), declaredTargetName);
	if(targetMismatch){
		throw new InvalidAttributeException(model, PMMLAttributes.REGRESSIONMODEL_TARGETFIELD, declaredTargetName);
	}

	List<RegressionTable> tables = model.getRegressionTables();
	if(tables.size() != 1){
		throw new InvalidElementListException(tables);
	}

	Value<V> value = evaluateRegressionTable(valueFactory, tables.get(0), context);

	// A null table result falls back to the target's default regression result
	if(value == null){
		return TargetUtil.evaluateRegressionDefault(valueFactory, target);
	}

	RegressionModel.NormalizationMethod normalization = model.getNormalizationMethod();
	switch(normalization){
		case SIMPLEMAX:
			throw new InvalidAttributeException(model, normalization);
		case NONE:
		case SOFTMAX:
		case LOGIT:
		case EXP:
		case PROBIT:
		case CLOGLOG:
		case LOGLOG:
		case CAUCHIT:
			RegressionModelUtil.normalizeRegressionResult(normalization, value);
			break;
		default:
			throw new UnsupportedAttributeException(model, normalization);
	}

	return TargetUtil.evaluateRegression(target, value);
}
 
示例9
/**
 * Checks that ValueParser (STRICT mode) converts category values according to the field's
 * data type.
 *
 * While the field is STRING-typed, everything stays a string. After retyping the field to
 * BOOLEAN, valid values and predictor values become Booleans, while the invalid value "N/A"
 * and the invalid-value replacement keep their string form.
 */
@Test
public void parseRegressionModel(){
	Value falseValue = new Value("false");
	Value trueValue = new Value("true");
	Value invalidValue = new Value("N/A");

	DataField dataField = new DataField(FieldName.create("x1"), OpType.CATEGORICAL, DataType.STRING)
		.addValues(falseValue, trueValue, invalidValue);

	DataDictionary dataDictionary = new DataDictionary()
		.addDataFields(dataField);

	FieldName fieldName = dataField.getName();

	CategoricalPredictor falseTerm = new CategoricalPredictor(fieldName, "false", -1d);
	CategoricalPredictor trueTerm = new CategoricalPredictor(fieldName, "true", 1d);

	MiningField miningField = new MiningField(fieldName)
		.setMissingValueReplacement("false")
		.setInvalidValueReplacement("N/A");

	RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, new MiningSchema().addMiningFields(miningField), null)
		.addRegressionTables(new RegressionTable().addCategoricalPredictors(falseTerm, trueTerm));

	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), dataDictionary)
		.addModels(regressionModel);

	// Captured before parsing; the field is looked up from this list after the first pass
	List<DataField> dataFields = dataDictionary.getDataFields();

	ValueParser parser = new ValueParser(ValueParser.Mode.STRICT);

	// Pass 1: STRING-typed field
	parser.applyTo(pmml);

	assertEquals("false", falseValue.getValue());
	assertEquals("true", trueValue.getValue());
	assertEquals("N/A", invalidValue.getValue());

	assertEquals("false", falseTerm.getValue());
	assertEquals("true", trueTerm.getValue());

	assertEquals("false", miningField.getMissingValueReplacement());
	assertEquals("N/A", miningField.getInvalidValueReplacement());

	// Pass 2: BOOLEAN-typed field
	dataFields.get(0).setDataType(DataType.BOOLEAN);

	parser.applyTo(pmml);

	assertEquals(Boolean.FALSE, falseValue.getValue());
	assertEquals(Boolean.TRUE, trueValue.getValue());
	assertEquals("N/A", invalidValue.getValue());

	assertEquals(Boolean.FALSE, falseTerm.getValue());
	assertEquals(Boolean.TRUE, trueTerm.getValue());

	assertEquals(Boolean.FALSE, miningField.getMissingValueReplacement());
	assertEquals("N/A", miningField.getInvalidValueReplacement());
}
 
示例10
/**
 * Builds a softmax-normalized classification {@code RegressionModel}, with one regression
 * table per target category.
 *
 * For category {@code k}:
 * <ul>
 *   <li>the table's terms start from row {@code k} of {@code coefficients}, paired with
 *       copies of the schema features, and are passed through
 *       {@code RegressionTableUtil.simplify};</li>
 *   <li>the table's intercept is {@code intercepts.apply(k)}.</li>
 * </ul>
 * The row count of {@code coefficients} is checked against the number of categories via
 * {@code MatrixUtil.checkRows}.
 */
static
public <C extends ModelConverter<?> & HasRegressionTableOptions> Model createSoftmaxClassification(C converter, Matrix coefficients, Vector intercepts, Schema schema){
	CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

	int classCount = categoricalLabel.size();

	MatrixUtil.checkRows(classCount, coefficients);

	List<RegressionTable> regressionTables = new ArrayList<>(classCount);

	for(int row = 0; row < classCount; row++){
		Object targetCategory = categoricalLabel.getValue(row);

		// Fresh mutable copies for each category, since simplify() receives the lists directly
		List<Feature> tableFeatures = new ArrayList<>(schema.getFeatures());
		List<Double> tableCoefficients = new ArrayList<>(MatrixUtil.getRow(coefficients, row));

		RegressionTableUtil.simplify(converter, targetCategory, tableFeatures, tableCoefficients);

		double intercept = intercepts.apply(row);

		regressionTables.add(RegressionModelUtil.createRegressionTable(tableFeatures, tableCoefficients, intercept)
			.setTargetCategory(targetCategory));
	}

	return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
		.setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
}