Java源码示例:org.tensorflow.framework.AttrValue

示例1
/**
 * Resolves the tensor shape for a TensorFlow node: prefers the node's explicit
 * "shape" attribute, falling back to the inferred "_output_shapes" list.
 *
 * @param node the graph node whose shape is requested
 * @return the shape proto for the node's (first) output
 * @throws IllegalArgumentException if neither attribute yields a usable shape
 */
private static TensorShapeProto tensorFlowShape(NodeDef node) {
    // Prefer the explicitly declared shape when present.
    AttrValue declared = node.getAttrMap().get("shape");
    if (declared != null && declared.getValueCase() == AttrValue.ValueCase.SHAPE) {
        return declared.getShape();
    }

    // Otherwise fall back to the shape TensorFlow inferred for the outputs.
    AttrValue inferred = node.getAttrMap().get("_output_shapes");
    if (inferred == null) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
                                           "does not exist");
    }
    if (inferred.getValueCase() != AttrValue.ValueCase.LIST) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
                                           "is not of expected type");
    }

    // Only the first output's shape is used. TODO: support multiple outputs?
    return inferred.getList().getShape(0);
}
 
示例2
/**
 * Returns the attribute {@code key} as a list of values when it is a LIST-typed
 * attribute with a non-empty boolean, integer or float list; empty otherwise.
 *
 * @param key the attribute name to look up
 * @return the converted list, or {@link Optional#empty()} when absent or unsupported
 */
@Override
public Optional<List<Value>> getList(String key) {
    // Single map lookup instead of containsKey + get.
    AttrValue attrValue = attributeMap.get(key);
    if (attrValue != null && attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
        AttrValue.ListValue listValue = attrValue.getList();
        if ( ! listValue.getBList().isEmpty()) {
            return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
        }
        // NOTE(review): integer lists are widened to DoubleValue here — presumably
        // intentional (a single numeric value type); confirm against callers.
        if ( ! listValue.getIList().isEmpty()) {
            return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
        }
        if ( ! listValue.getFList().isEmpty()) {
            return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
        }
        // Other element types (strings, shapes, tensors, ...) are not yet supported.
    }
    return Optional.empty();
}
 
示例3
/**
 * Maps a TF MatMul node's transpose flags onto this op and marks any
 * placeholder/shape-less inputs for later property resolution.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Default to no transposition when an attribute is absent — avoids an NPE
    // on graphs that omit transpose_a/transpose_b.
    AttrValue ta = attributesForNode.get("transpose_a");
    AttrValue tb = attributesForNode.get("transpose_b");
    boolean isTransposeA = ta != null && ta.getB();
    boolean isTransposeB = tb != null && tb.getB();

    this.mMulTranspose = MMulTranspose.builder()
            .transposeA(isTransposeA).transposeB(isTransposeB)
            .build();

    // Any placeholder or shape-less input must be resolved later.
    for(val arg : args()) {
        if(sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
            sameDiff.addPropertyToResolve(this,arg.getVarName());
        }
    }
}
 
示例4
/**
 * Returns the attribute {@code key} as a tensor value when it is TENSOR-typed,
 * otherwise delegates to the scalar {@code get(key)} overload.
 *
 * @param key  the attribute name to look up
 * @param type the Vespa tensor type to convert the TF tensor into
 */
@Override
public Optional<Value> get(String key, OrderedTensorType type) {
    // Single map lookup instead of containsKey + get.
    AttrValue attrValue = attributeMap.get(key);
    if (attrValue != null && attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
        return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type)));
    }
    return get(key);
}
 
示例5
/**
 * Maps TF attributes onto this op via the generic property initializer; the
 * optional "use_locking" attribute becomes the single boolean op argument
 * (false when the attribute is absent).
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Collapses the redundant `getB() == true` if/else cascade into one expression.
    bArguments.add(nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB());
}
 
示例6
/**
 * TF import: libnd4j custom ops have no "string arg" support, so the node's
 * "message" attribute is materialised as a graph constant and appended as an
 * extra input to this op instead.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    String message = attributesForNode.get("message").getS().toStringUtf8();
    String name = nodeDef.getName();

    // Materialise the message string as a constant variable in the graph.
    SDVariable msg = initWith.constant(name + "/message", Nd4j.scalar(message));

    // Append that constant to this op's input list.
    List<String> inputs = new ArrayList<>(2);
    inputs.addAll(initWith.getOps().get(name).getInputsToOp());
    inputs.add(msg.name());
    initWith.getOps().get(name).setInputsToOp(inputs);

    // Record the reverse edge: the constant feeds this op.
    initWith.getVariables().get(msg.name()).setInputsForOp(Collections.singletonList(getOwnName()));
}
 
示例7
/**
 * Maps a TF Transpose node: reads the permute dimensions from the (constant)
 * second input when available, otherwise defaults to a full dimension reversal
 * once the input array is known.
 *
 * @throws ND4JIllegalStateException if a permutation is found that does not
 *                                   cover every input dimension
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    //permute dimensions are not specified as second input
    if (nodeDef.getInputCount() < 2)
        return;

    // Locate the node supplying the permute dimensions.
    NodeDef permuteDimsNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) {
            permuteDimsNode = graph.getNode(i);
            break;  // node names are unique within a GraphDef — stop scanning
        }
    }

    INDArray permuteArrayOp = TFGraphMapper.getNDArrayFromTensor(permuteDimsNode);
    if (permuteArrayOp != null) {
        this.permuteDims = permuteArrayOp.data().asInt();
    }

    //handle once properly mapped: bail out until the input has a concrete,
    //non-placeholder array with a known shape
    if (arg().getShape() == null || arg().getVariableType() == VariableType.PLACEHOLDER || arg().getArr() == null) {
        return;
    }

    INDArray arr = sameDiff.getArrForVarName(arg().name());

    if(permuteArrayOp != null){
        addInputArgument(arr, permuteArrayOp);
    } else {
        addInputArgument(arr);
    }

    // Default permutation: full reversal of the input's dimensions.
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }

    if (permuteDims != null && permuteDims.length < arg().getShape().length)
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
}
 
示例8
/**
 * Maps TF attributes onto this op via the generic property initializer; the
 * optional "use_locking" attribute becomes the single boolean op argument
 * (false when the attribute is absent).
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Collapses the redundant `getB() == true` if/else cascade into one expression.
    bArguments.add(nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB());
}
 
示例9
/**
 * Reads the exponent from the op's second input when it is a known scalar
 * constant; otherwise leaves {@code pow} untouched.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val weightsName = nodeDef.getInput(1);
    // Removed unused local (`initWith.getVariable(weightsName)` result was never used).
    val tmp = initWith.getArrForVarName(weightsName);

    // If the second argument is a known scalar, capture it directly.
    if (tmp != null && tmp.isScalar()) {
        this.pow = tmp.getDouble(0);
    }
}
 
示例10
/**
 * Maps the node's required "out_type" attribute to both the SameDiff data type
 * and the libnd4j int argument encoding of that type.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    // Look the attribute up once instead of twice.
    val tfType = nodeDef.getAttrOrThrow("out_type").getType();
    dataType = TFGraphMapper.convertType(tfType);
    val dtype = DataTypeAdapter.dtypeConv(tfType);
    iArguments.add((long) FlatBuffersMapper.getDataTypeAsByte(dtype));
}
 
示例11
/**
 * Maps TF attributes onto this op via the generic property initializer; the
 * optional "use_locking" attribute becomes the single boolean op argument
 * (false when the attribute is absent).
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Collapses the redundant `getB() == true` if/else cascade into one expression.
    bArguments.add(nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB());
}
 
示例12
/**
 * Builds the LSTM configuration from the TF BlockLSTM attributes and registers
 * the corresponding int/float op arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    float forgetBias = attributesForNode.get("forget_bias").getF();
    float cellClip = attributesForNode.get("cell_clip").getF();
    boolean peephole = attributesForNode.get("use_peephole").getB();

    configuration = LSTMConfiguration.builder()
            .forgetBias(forgetBias)
            .clippingCellValue(cellClip)
            .peepHole(peephole)
            .dataFormat(RnnDataFormat.TNS)  //Always time major for TF BlockLSTM
            .build();

    addIArgument(configuration.iArgs(true));
    addTArgument(configuration.tArgs());
}
 
示例13
/**
 * Reads the SVD flags ("full_matrices", "compute_uv") from the TF node
 * attributes and forwards them, plus the fixed switch number, as int args.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    fullUV = attributesForNode.get("full_matrices").getB();
    computeUv = attributesForNode.get("compute_uv").getB();
    switchNum = 16;  // fixed value for the TF import path
    addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum);
}
 
示例14
/**
 * Maps a TF MatMul node onto this op. A MatMul node carries boolean
 * "transpose_a"/"transpose_b" attributes (plus a "T" dtype attribute, which is
 * not consumed here).
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    val isTransposeA = attributesForNode.get("transpose_a").getB();
    val isTransposeB = attributesForNode.get("transpose_b").getB();
    this.mMulTranspose = MMulTranspose.builder()
            .transposeA(isTransposeA).transposeB(isTransposeB)
            .build();
    // Removed unused local `val args = args();`.
}
 
示例15
/**
 * Captures the optional "T" dtype attribute as the output type, then records
 * the op arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Single map lookup instead of containsKey + get.
    AttrValue t = attributesForNode.get("T");
    if (t != null) {
        outputType = TFGraphMapper.convertType(t.getType());
    }

    addArgs();
}
 
示例16
/**
 * Maps transpose flags for MatMul / BatchMatMul / BatchMatMulV2 onto this op.
 * MatMul always uses "transpose_a"/"transpose_b"; the batch variants use
 * "adj_x"/"adj_y" in practice, with "transpose_a"/"transpose_b" accepted when
 * present.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    boolean isTransposeA;
    boolean isTransposeB;
    if(nodeDef.getOp().equalsIgnoreCase("MatMul")){
        isTransposeA = attributesForNode.get("transpose_a").getB();
        isTransposeB = attributesForNode.get("transpose_b").getB();
    } else {
        //BatchMatMul, BatchMatMulV2 — dedupe the primary/fallback lookup
        isTransposeA = readBoolAttr(attributesForNode, "transpose_a", "adj_x");
        isTransposeB = readBoolAttr(attributesForNode, "transpose_b", "adj_y");
    }
    this.mt = MMulTranspose.builder()
            .transposeA(isTransposeA).transposeB(isTransposeB)
            .build();
    iArguments.clear();
    addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()));
}

/** Reads the boolean attribute {@code primary}, falling back to {@code fallback} when absent. */
private static boolean readBoolAttr(Map<String, AttrValue> attrs, String primary, String fallback) {
    AttrValue v = attrs.containsKey(primary) ? attrs.get(primary) : attrs.get(fallback);
    return v.getB();
}
 
示例17
/**
 * Maps TF attributes onto this op via the generic property initializer; the
 * optional "use_locking" attribute becomes the single boolean op argument
 * (false when the attribute is absent).
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Collapses the redundant `getB() == true` if/else cascade into one expression.
    bArguments.add(nodeDef.containsAttr("use_locking") && nodeDef.getAttrOrThrow("use_locking").getB());
}
 
示例18
/**
 * TF import for a sequence-mask-style op: when the second input does not
 * resolve to a constant tensor, the max length is treated as static and
 * appended as an int argument.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val targetNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
    val maxlen = TFGraphMapper.getNDArrayFromTensor(targetNode);
    if (maxlen == null){
        // No 2nd input
        this.is_static_maxlen = true;
    }
    // NOTE(review): this.maxLen is not assigned in this method — presumably
    // initFunctionFromProperties populates it from the node attributes before
    // it is used below; confirm the property mapping covers "maxlen".
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    if (is_static_maxlen) {
        addIArgument(this.maxLen);
    }

}
 
示例19
/**
 * Captures the optional "output_type" attribute, defaulting to LONG when the
 * node does not specify one.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Single map lookup instead of containsKey + get.
    AttrValue v = attributesForNode.get("output_type");
    outputType = v != null ? TFGraphMapper.convertType(v.getType()) : DataType.LONG;
}
 
示例20
/**
 * Reads fake-quantization parameters from the node attributes: "narrow_range"
 * is optional, while "num_bits", "min" and "max" are required.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    AttrValue narrow = attributesForNode.get("narrow_range");
    if (narrow != null) {
        this.narrowRange = narrow.getB();
    }
    this.numBits = (int) attributesForNode.get("num_bits").getI();
    this.min = attributesForNode.get("min").getF();
    this.max = attributesForNode.get("max").getF();
    addArgs();
}
 
示例21
/**
 * Maps resize attributes onto this op; both "align_corners" and
 * "half_pixel_centers" default to false when absent.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    this.alignCorners = false;
    this.halfPixelCenters = false;

    AttrValue alignAttr = attributesForNode.get("align_corners");
    if (alignAttr != null) {
        this.alignCorners = alignAttr.getB();
    }
    AttrValue halfPixelAttr = attributesForNode.get("half_pixel_centers");
    if (halfPixelAttr != null) {
        this.halfPixelCenters = halfPixelAttr.getB();
    }

    addArgs();
}
 
示例22
/**
 * Reads the optional "narrow_range" and "num_bits" fake-quantization
 * attributes (keeping field defaults when absent) and records the arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    AttrValue narrow = attributesForNode.get("narrow_range");
    if (narrow != null) {
        this.narrowRange = narrow.getB();
    }
    AttrValue bits = attributesForNode.get("num_bits");
    if (bits != null) {
        this.numBits = (int) bits.getI();
    }
    addIArgument(numBits);
    addBArgument(narrowRange);
}
 
示例23
/**
 * Reads an index from the op's second input when it resolves to a known
 * constant array, and records it as an int argument.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val inputOne = nodeDef.getInput(1);
    // Removed unused local (`initWith.getVariable(inputOne)` result was never used).
    val nodeWithIndex = TFGraphMapper.getNodeWithNameFromGraph(graph,inputOne);
    val var = TFGraphMapper.getArrayFrom(nodeWithIndex,graph);
    if(var != null) {
        val idx = var.getInt(0);
        addIArgument(idx);
    }
}
 
示例24
/**
 * Maps a TF FusedBatchNorm node: reorders the inputs to libnd4j's expected
 * order, enables gamma/beta, reads epsilon, and derives the channel axis from
 * the optional "data_format" attribute.
 *
 * @throws IllegalStateException if "data_format" holds an unrecognised value
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    //Switch order: TF uses [input, gamma, beta, mean, variance]; libnd4j expects [input, mean, variance, gamma, beta]
    SameDiffOp op = initWith.getOps().get(this.getOwnName());
    List<String> list = op.getInputsToOp();
    List<String> newList = Arrays.asList(list.get(0), list.get(3), list.get(4), list.get(1), list.get(2));
    op.setInputsToOp(newList);

    // TF FusedBatchNorm always supplies gamma and beta (see reorder above).
    this.applyGamma = true;
    this.applyBeta = true;
    this.epsilon = attributesForNode.get("epsilon").getF();

    // Channel axis depends on data layout: channels-first -> axis 1,
    // channels-last -> last axis for the respective rank.
    if(attributesForNode.containsKey("data_format")){
        String dataFormat = attributesForNode.get("data_format").getS().toStringUtf8();
        //TODO not sure if these conv1d/3d cases appear. But BN definitely uses "NCHW" or "NHWC"
        if(dataFormat.equalsIgnoreCase(Conv2DConfig.NCHW) || dataFormat.equalsIgnoreCase(Conv1DConfig.NCW) || dataFormat.equalsIgnoreCase(Conv3DConfig.NCDHW)){
            jaxis = new int[]{1};
        } else if(dataFormat.equalsIgnoreCase(Conv2DConfig.NHWC)){
            jaxis = new int[]{3};
        } else if(dataFormat.equalsIgnoreCase(Conv1DConfig.NWC)){
            jaxis = new int[]{2};
        } else if(dataFormat.equalsIgnoreCase(Conv3DConfig.NDHWC)){
            jaxis = new int[]{4};
        } else {
            throw new IllegalStateException("Unknown data format: \"" + dataFormat + "\"" );
        }
    }



    addArgs();
}
 
示例25
/**
 * Captures the optional "output_type" attribute, defaulting to LONG when the
 * node does not specify one.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Single map lookup instead of containsKey + get.
    AttrValue v = attributesForNode.get("output_type");
    outputType = v != null ? TFGraphMapper.convertType(v.getType()) : DataType.LONG;
}
 
示例26
/**
 * Resolves the concat/stack axis from the op's second input. When that input
 * is not a known constant, Integer.MAX_VALUE is recorded as a sentinel.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val targetNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
    val dimArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", targetNode, graph);

    // Both branches record the axis; only the value differs.
    this.axis = dimArr != null ? dimArr.data().asInt()[0] : Integer.MAX_VALUE;
    addIArgument(this.axis);
}
 
示例27
/**
 * Reads the required integer "axis" attribute and records the op arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    this.axis = (int) nodeDef.getAttrOrThrow("axis").getI();
    addArgs();
}
 
示例28
/**
 * Maps a TF Split node: reads "num_split" from the attributes and, when the
 * first input resolves to a constant array, the split dimension from it.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    this.numSplit = (int) attributesForNode.get("num_split").getI();

    // Local renamed from 'splitDim' to avoid shadowing the field of the same name.
    val splitDimArr = TFGraphMapper.getInstance().getArrayFrom(
            TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(0)), graph);
    if (splitDimArr != null) {
        this.splitDim = splitDimArr.getInt(0);
        addIArgument(this.splitDim);
    }

    addIArgument(this.numSplit);
}
 
示例29
/**
 * Maps a TF 2D-pooling node: kernel size and strides come from the "ksize" and
 * "strides" attribute lists (NHWC layout: index 1 = height, index 2 = width),
 * and the padding mode from the "padding" attribute.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    val sH = tfStrides.get(1);
    val sW = tfStrides.get(2);

    val aKernels = nodeDef.getAttrOrThrow("ksize");
    val tfKernels = aKernels.getList().getIList();

    val kH = tfKernels.get(1);
    val kW = tfKernels.get(2);

    val aPadding = nodeDef.getAttrOrThrow("padding");
    // NOTE(review): "padding" is read both as an int list (explicit pad values)
    // and as a string (mode). In standard TF graphs this attribute is a string
    // ("SAME"/"VALID"), which would leave this list empty — confirm that
    // padding.get(0)/get(1) below cannot hit an empty list at runtime.
    val padding = aPadding.getList().getIList();

    // Strip quotes that may surround the serialized mode string.
    val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"","");

    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");

    if (!isSameMode)
        log.debug("Mode: {}", paddingMode);

    Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
            .sH(sH.intValue())
            .sW(sW.intValue())
            .type(null)
            .isSameMode(isSameMode)
            .kH(kH.intValue())
            .kW(kW.intValue())
            .pH(padding.get(0).intValue())
            .pW(padding.get(1).intValue())
            .build();
    this.config = pooling2DConfig;
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kH, kW, sH, sW, aPadding);


}
 
示例30
/**
 * Delegates TF attribute mapping to the generic property initializer, then
 * records this op's arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    addArgs();
}