Java源码示例:org.dmg.pmml.regression.RegressionTable
示例1
/**
 * Loads a logistic regression model from a PMML file.
 *
 * <p>Only the first model of the PMML document is read, and it is cast to {@code RegressionModel}.
 * Its first regression table supplies the coefficients (one per numeric predictor, in document
 * order) and the intercept.
 *
 * @param path Path to the PMML file.
 * @return The loaded model, or {@code null} if reading or unmarshalling fails with an
 *     {@code IOException}, {@code JAXBException} or {@code SAXException} (the stack trace is
 *     printed in that case). Other failures, e.g. a first model that is not a
 *     {@code RegressionModel}, are not caught and propagate to the caller.
 */
public static LogisticRegressionModel load(String path) {
    try (InputStream in = new FileInputStream(new File(path))) {
        PMML document = PMMLUtil.unmarshal(in);

        RegressionModel regressionModel = (RegressionModel)document.getModels().get(0);
        RegressionTable table = regressionModel.getRegressionTables().get(0);

        int predictorCnt = table.getNumericPredictors().size();
        Vector weights = new DenseVector(predictorCnt);

        for (int idx = 0; idx < predictorCnt; idx++) {
            weights.set(idx, table.getNumericPredictors().get(idx).getCoefficient());
        }

        double intercept = table.getIntercept();

        return new LogisticRegressionModel(weights, intercept);
    }
    catch (IOException | JAXBException | SAXException e) {
        // Best-effort loading: report the problem and signal failure with null.
        e.printStackTrace();

        return null;
    }
}
示例2
/**
 * Marshals a minimal PMML 4.4 document containing one empty {@code RegressionModel}.
 * Checks that the XML output carries the PMML 4.4 namespace, the version attribute and the model element.
 */
@Test
public void marshal() throws Exception {
	RegressionModel regressionModel = new RegressionModel()
		.addRegressionTables(new RegressionTable());

	PMML pmml = new PMML(Version.PMML_4_4.getVersion(), new Header(), new DataDictionary());
	pmml.addModels(regressionModel);

	Class<?>[] contextClasses = {org.dmg.pmml.ObjectFactory.class, org.dmg.pmml.regression.ObjectFactory.class};
	JAXBContext context = JAXBContextFactory.createContext(contextClasses, null);

	Marshaller marshaller = context.createMarshaller();

	String xml;

	try(ByteArrayOutputStream buffer = new ByteArrayOutputStream()){
		marshaller.marshal(pmml, buffer);

		xml = buffer.toString("UTF-8");
	}

	String[] expectedFragments = {
		"<PMML xmlns=\"http://www.dmg.org/PMML-4_4\"",
		" version=\"4.4\">",
		"<RegressionModel>",
		"</RegressionModel>",
		"</PMML>"
	};

	for(String fragment : expectedFragments){
		assertTrue(xml.contains(fragment));
	}
}
示例3
/**
 * Tells whether a regression table is a trivial default table.
 *
 * <p>A table counts as default when it has no numeric predictors, no categorical predictors and
 * no predictor terms, and its intercept is either absent ({@code null}) or equal to zero.
 *
 * @param regressionTable the table to inspect
 * @return {@code true} if the table contributes nothing beyond a zero/absent intercept
 */
static
private boolean isDefault(RegressionTable regressionTable){
	boolean hasPredictors = regressionTable.hasNumericPredictors()
		|| regressionTable.hasCategoricalPredictors()
		|| regressionTable.hasPredictorTerms();

	if(hasPredictors){
		return false;
	}

	Number intercept = regressionTable.getIntercept();

	// A NaN intercept compares unequal to zero and therefore is not default.
	return (intercept == null) || (intercept.doubleValue() == 0d);
}
示例4
/**
 * Encodes the estimator as a classification {@code RegressionModel} over an integer-coded target
 * field named {@code "_target"}.
 *
 * <p>The underlying regression model comes from {@code encodeRegressionModel}.
 * <ul>
 *   <li>If it has exactly one regression table, that table is labelled with category {@code "1"}.
 *       An additional {@code RegressionTable(0)} (presumably a zero-intercept table, TODO confirm)
 *       is appended and labelled with category {@code "0"}.</li>
 *   <li>If it has more than two tables, table {@code i} is labelled with category
 *       {@code String.valueOf(i)}.</li>
 * </ul>
 * In both cases the model uses SOFTMAX normalization with a FLOAT probability output.
 *
 * @param encoder the encoder that owns the target data field
 * @return the encoded classification model
 * @throws IllegalArgumentException if the regression model has zero or exactly two tables
 */
@Override
public RegressionModel encodeModel(TensorFlowEncoder encoder){
	DataField dataField = encoder.createDataField(FieldName.create("_target"), OpType.CATEGORICAL, DataType.INTEGER);

	RegressionModel regressionModel = encodeRegressionModel(encoder);

	List<RegressionTable> regressionTables = regressionModel.getRegressionTables();

	List<String> categories;

	if(regressionTables.size() == 1){
		categories = Arrays.asList("0", "1");

		// The single encoded table scores the "1" category...
		regressionTables.get(0)
			.setTargetCategory(categories.get(1));

		// ...and is paired with a constant table for the "0" category so that SOFTMAX can be applied.
		RegressionTable passiveRegressionTable = new RegressionTable(0)
			.setTargetCategory(categories.get(0));

		regressionModel.addRegressionTables(passiveRegressionTable);
	} else

	if(regressionTables.size() > 2){
		categories = new ArrayList<>();

		for(int i = 0; i < regressionTables.size(); i++){
			RegressionTable regressionTable = regressionTables.get(i);

			String category = String.valueOf(i);

			regressionTable.setTargetCategory(category);

			categories.add(category);
		}
	} else

	{
		throw new IllegalArgumentException("Expected 1 (binary) or more than 2 (multi-class) regression tables, got " + regressionTables.size());
	}

	dataField = encoder.toCategorical(dataField.getName(), categories);

	CategoricalLabel categoricalLabel = new CategoricalLabel(dataField);

	regressionModel
		.setMiningFunction(MiningFunction.CLASSIFICATION)
		.setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX)
		.setMiningSchema(ModelUtil.createMiningSchema(categoricalLabel))
		.setOutput(ModelUtil.createProbabilityOutput(DataType.FLOAT, categoricalLabel));

	return regressionModel;
}
示例5
/**
 * Combines one member model per target category into a classification model chain.
 *
 * <p>For each target category, the prediction of the corresponding member model (obtained through
 * {@code MODEL_PREDICTION}) becomes the sole feature of a {@code RegressionTable} with coefficient
 * {@code 1}. The intercept argument to {@code createRegressionTable} is {@code null}.
 * These tables form a final classification {@code RegressionModel}, which is appended after the
 * member models, and the whole list is passed to {@code createModelChain}.
 *
 * @param models member models, one per category of the categorical label, in label order
 * @param normalizationMethod normalization of the final regression model; must be {@code null},
 *     {@code NONE}, {@code SIMPLEMAX} or {@code SOFTMAX}
 * @param hasProbabilityDistribution whether to attach a probability output to the final model
 * @param schema schema whose label is cast to {@code CategoricalLabel}
 * @return the model chain built by {@code createModelChain}
 * @throws IllegalArgumentException if the number of models differs from the number of categories,
 *     if the normalization method is not supported, or if the member models declare different math
 *     contexts (a missing math context counts as {@code MathContext.DOUBLE})
 */
static
public MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema){
	CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

	// Exactly one member model is required per target category
	if(categoricalLabel.size() != models.size()){
		throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " model(s) (one per target category), got " + models.size());
	} // End if

	if(normalizationMethod != null){

		switch(normalizationMethod){
			case NONE:
			case SIMPLEMAX:
			case SOFTMAX:
				break;
			default:
				throw new IllegalArgumentException("Unsupported normalization method: " + normalizationMethod);
		}
	}

	// All member models must agree on a single math context, which the final model then inherits
	MathContext mathContext = null;

	List<RegressionTable> regressionTables = new ArrayList<>();

	for(int i = 0; i < categoricalLabel.size(); i++){
		Model model = models.get(i);

		MathContext modelMathContext = model.getMathContext();
		if(modelMathContext == null){
			modelMathContext = MathContext.DOUBLE;
		} // End if

		if(mathContext == null){
			mathContext = modelMathContext;
		} else

		{
			if(!Objects.equals(mathContext, modelMathContext)){
				throw new IllegalArgumentException("Inconsistent math contexts: model " + i + " uses " + modelMathContext + ", expected " + mathContext);
			}
		}

		Feature feature = MODEL_PREDICTION.apply(model);

		RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(1d), null)
			.setTargetCategory(categoricalLabel.getValue(i));

		regressionTables.add(regressionTable);
	}

	RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
		.setNormalizationMethod(normalizationMethod)
		.setMathContext(ModelUtil.simplifyMathContext(mathContext))
		.setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);

	List<Model> segmentationModels = new ArrayList<>(models);
	segmentationModels.add(regressionModel);

	return createModelChain(segmentationModels, schema);
}
示例6
/**
 * Checks that {@code FieldReferenceFinder} collects the field names referenced by a regression
 * table's numeric predictors, and that {@code reset()} clears what was collected.
 */
@Test
public void apply(){
	FieldName x1 = FieldName.create("x1");
	FieldName x2 = FieldName.create("x2");

	RegressionTable regressionTable = new RegressionTable()
		.addNumericPredictors(new NumericPredictor(x1, 1d));

	FieldReferenceFinder finder = new FieldReferenceFinder();

	// One predictor: only x1 is referenced
	finder.applyTo(regressionTable);
	assertEquals(Collections.singleton(x1), finder.getFieldNames());

	finder.reset();
	assertEquals(Collections.emptySet(), finder.getFieldNames());

	// Two predictors: both names are collected
	regressionTable.addNumericPredictors(new NumericPredictor(x2, -1d));

	finder.applyTo(regressionTable);
	assertEquals(new HashSet<>(Arrays.asList(x1, x2)), finder.getFieldNames());

	finder.reset();
	assertEquals(Collections.emptySet(), finder.getFieldNames());
}
示例7
/**
 * Replaces the table's target category with the result of {@code parseTargetValue}, then continues
 * with the inherited visiting logic.
 *
 * @param regressionTable the table being visited
 * @return whatever the superclass {@code visit} returns
 */
@Override
public VisitorAction visit(RegressionTable regressionTable){
	regressionTable.setTargetCategory(
		parseTargetValue(regressionTable.getTargetCategory())
	);

	VisitorAction action = super.visit(regressionTable);

	return action;
}
示例8
/**
 * Evaluates the model in regression mode.
 *
 * <p>Preconditions checked by this method:
 * <ul>
 *   <li>If the model declares a {@code targetField}, it must match the resolved target field.</li>
 *   <li>The model must contain exactly one regression table.</li>
 * </ul>
 * If evaluating the table yields {@code null}, the target's default regression result is returned.
 * Otherwise the value is normalized with the model's normalization method before being returned.
 *
 * @throws InvalidAttributeException if the declared target field mismatches, or if the
 *     normalization method is SIMPLEMAX
 * @throws InvalidElementListException if the model does not have exactly one regression table
 * @throws UnsupportedAttributeException for any other normalization method not listed below
 */
@Override
protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context){
	RegressionModel model = getModel();
	TargetField target = getTargetField();

	FieldName declaredTargetName = model.getTargetField();
	if(declaredTargetName != null && !Objects.equals(target.getFieldName(), declaredTargetName)){
		throw new InvalidAttributeException(model, PMMLAttributes.REGRESSIONMODEL_TARGETFIELD, declaredTargetName);
	}

	List<RegressionTable> tables = model.getRegressionTables();
	if(tables.size() != 1){
		throw new InvalidElementListException(tables);
	}

	Value<V> value = evaluateRegressionTable(valueFactory, tables.get(0), context);
	if(value == null){
		return TargetUtil.evaluateRegressionDefault(valueFactory, target);
	}

	RegressionModel.NormalizationMethod method = model.getNormalizationMethod();
	switch(method){
		// SIMPLEMAX is rejected as invalid for a regression evaluation
		case SIMPLEMAX:
			throw new InvalidAttributeException(model, method);
		case NONE:
		case SOFTMAX:
		case LOGIT:
		case EXP:
		case PROBIT:
		case CLOGLOG:
		case LOGLOG:
		case CAUCHIT:
			RegressionModelUtil.normalizeRegressionResult(method, value);
			break;
		default:
			throw new UnsupportedAttributeException(model, method);
	}

	return TargetUtil.evaluateRegression(target, value);
}
示例9
/**
 * Runs a STRICT {@code ValueParser} over a PMML document containing a single categorical field
 * and checks the results.
 *
 * <p>While the field is STRING-typed, all values stay strings. After the field is switched to
 * BOOLEAN and the document is parsed again:
 * <ul>
 *   <li>"false"/"true" become {@code Boolean}s in the data field values, the categorical
 *       predictors and the missing-value replacement;</li>
 *   <li>the "N/A" invalid value and the invalid-value replacement remain strings.</li>
 * </ul>
 */
@Test
public void parseRegressionModel(){
	Value falseValue = new Value("false");
	Value trueValue = new Value("true");
	Value invalidValue = new Value("N/A");

	DataField dataField = new DataField(FieldName.create("x1"), OpType.CATEGORICAL, DataType.STRING)
		.addValues(falseValue, trueValue, invalidValue);

	FieldName name = dataField.getName();

	DataDictionary dataDictionary = new DataDictionary()
		.addDataFields(dataField);

	CategoricalPredictor falseTerm = new CategoricalPredictor(name, "false", -1d);
	CategoricalPredictor trueTerm = new CategoricalPredictor(name, "true", 1d);

	RegressionTable table = new RegressionTable()
		.addCategoricalPredictors(falseTerm, trueTerm);

	MiningField miningField = new MiningField(name)
		.setMissingValueReplacement("false")
		.setInvalidValueReplacement("N/A");

	MiningSchema miningSchema = new MiningSchema()
		.addMiningFields(miningField);

	RegressionModel regressionModel = new RegressionModel(MiningFunction.REGRESSION, miningSchema, null)
		.addRegressionTables(table);

	PMML pmml = new PMML(Version.PMML_4_3.getVersion(), new Header(), dataDictionary)
		.addModels(regressionModel);

	List<DataField> dataFields = dataDictionary.getDataFields();

	ValueParser valueParser = new ValueParser(ValueParser.Mode.STRICT);

	// Pass 1: STRING field, nothing is converted
	valueParser.applyTo(pmml);

	DataField parsedField = dataFields.get(0);

	assertEquals("false", falseValue.getValue());
	assertEquals("true", trueValue.getValue());
	assertEquals("N/A", invalidValue.getValue());

	assertEquals("false", falseTerm.getValue());
	assertEquals("true", trueTerm.getValue());

	assertEquals("false", miningField.getMissingValueReplacement());
	assertEquals("N/A", miningField.getInvalidValueReplacement());

	// Pass 2: BOOLEAN field, valid values become Booleans while invalid ones stay strings
	parsedField.setDataType(DataType.BOOLEAN);

	valueParser.applyTo(pmml);

	assertEquals(Boolean.FALSE, falseValue.getValue());
	assertEquals(Boolean.TRUE, trueValue.getValue());
	assertEquals("N/A", invalidValue.getValue());

	assertEquals(Boolean.FALSE, falseTerm.getValue());
	assertEquals(Boolean.TRUE, trueTerm.getValue());

	assertEquals(Boolean.FALSE, miningField.getMissingValueReplacement());
	assertEquals("N/A", miningField.getInvalidValueReplacement());
}
示例10
/**
 * Builds a SOFTMAX-normalized classification {@code RegressionModel}.
 *
 * <p>Row {@code i} of {@code coefficients} and element {@code i} of {@code intercepts} produce the
 * regression table for the {@code i}-th category of the schema's categorical label. Before each
 * table is built, the row's coefficients are passed through {@code RegressionTableUtil.simplify},
 * together with a fresh copy of the schema features.
 *
 * @param converter converter forwarded to {@code RegressionTableUtil.simplify}
 * @param coefficients one row of coefficients per target category
 * @param intercepts one intercept per target category
 * @param schema schema whose label is cast to {@code CategoricalLabel}
 * @return the classification regression model
 */
static
public <C extends ModelConverter<?> & HasRegressionTableOptions> Model createSoftmaxClassification(C converter, Matrix coefficients, Vector intercepts, Schema schema){
	CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

	int categoryCount = categoricalLabel.size();

	// Presumably rejects a coefficient matrix whose row count differs from the category count
	MatrixUtil.checkRows(categoryCount, coefficients);

	List<RegressionTable> regressionTables = new ArrayList<>(categoryCount);

	for(int row = 0; row < categoryCount; row++){
		Object category = categoricalLabel.getValue(row);

		// Fresh copies per row, since simplify() receives (and presumably edits) these lists
		List<Feature> rowFeatures = new ArrayList<>(schema.getFeatures());
		List<Double> rowCoefficients = new ArrayList<>(MatrixUtil.getRow(coefficients, row));

		RegressionTableUtil.simplify(converter, category, rowFeatures, rowCoefficients);

		double intercept = intercepts.apply(row);

		regressionTables.add(
			RegressionModelUtil.createRegressionTable(rowFeatures, rowCoefficients, intercept)
				.setTargetCategory(category)
		);
	}

	return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables)
		.setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
}