Java源码示例:org.tensorflow.framework.AttrValue
示例1
/**
 * Resolves the TensorFlow shape of a graph node.
 *
 * <p>An explicit {@code shape} attribute of SHAPE type takes precedence; otherwise the first
 * entry of the inferred {@code _output_shapes} list attribute is used.
 *
 * @param node the TensorFlow node definition
 * @return the node's (first output's) shape
 * @throws IllegalArgumentException if there is no usable {@code shape} attribute and
 *     {@code _output_shapes} is missing, not a list, or empty
 */
private static TensorShapeProto tensorFlowShape(NodeDef node) {
    // Use the explicit shape if available...
    AttrValue attrShape = node.getAttrMap().get("shape");
    if (attrShape != null && attrShape.getValueCase() == AttrValue.ValueCase.SHAPE) {
        return attrShape.getShape();
    }
    // ... otherwise fall back to the inferred shape.
    AttrValue attrOutputShapes = node.getAttrMap().get("_output_shapes");
    if (attrOutputShapes == null) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
                "does not exist");
    }
    if (attrOutputShapes.getValueCase() != AttrValue.ValueCase.LIST) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
                "is not of expected type");
    }
    AttrValue.ListValue outputShapes = attrOutputShapes.getList();
    // An empty list previously surfaced as an uninformative IndexOutOfBoundsException.
    if (outputShapes.getShapeCount() == 0) {
        throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
                "is empty");
    }
    // Only the first output's shape is used; multi-output nodes are not supported yet.
    return outputShapes.getShape(0);
}
示例2
/**
 * Returns the list-valued attribute {@code key} converted to {@code Value}s.
 *
 * <p>Bool lists map to {@code BooleanValue}; int and float lists map to {@code DoubleValue}.
 * The first non-empty list kind wins, checked in the order bool, int, float. An absent key, a
 * non-list attribute, or any other list kind yields {@link Optional#empty()}.
 */
@Override
public Optional<List<Value>> getList(String key) {
    if (!attributeMap.containsKey(key)) {
        return Optional.empty();
    }
    AttrValue attr = attributeMap.get(key);
    if (attr.getValueCase() != AttrValue.ValueCase.LIST) {
        return Optional.empty();
    }
    AttrValue.ListValue values = attr.getList();
    if (!values.getBList().isEmpty()) {
        return Optional.of(values.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
    }
    if (!values.getIList().isEmpty()) {
        return Optional.of(values.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
    }
    if (!values.getFList().isEmpty()) {
        return Optional.of(values.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
    }
    // TODO: the remaining list kinds are not converted yet.
    return Optional.empty();
}
示例3
/**
 * Configures this matrix multiplication from a TensorFlow node.
 *
 * <p>Reads the {@code transpose_a}/{@code transpose_b} flags. An absent flag is treated as
 * {@code false}, which is TensorFlow's declared default, so graphs exported with stripped default
 * attributes no longer fail with a NullPointerException. Any argument that is a placeholder or
 * has no shape yet is registered as a property to resolve later.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    AttrValue attrTransposeA = attributesForNode.get("transpose_a");
    AttrValue attrTransposeB = attributesForNode.get("transpose_b");
    boolean isTransposeA = attrTransposeA != null && attrTransposeA.getB();
    boolean isTransposeB = attrTransposeB != null && attrTransposeB.getB();
    this.mMulTranspose = MMulTranspose.builder()
            .transposeA(isTransposeA).transposeB(isTransposeB)
            .build();
    for (val arg : args()) {
        // Shape not available yet: defer until it can be resolved.
        if (sameDiff.isPlaceHolder(arg.getVarName()) || arg.getShape() == null) {
            sameDiff.addPropertyToResolve(this, arg.getVarName());
        }
    }
}
示例4
/**
 * Returns attribute {@code key} as a {@code TensorValue} of the given type when it holds a
 * tensor; in every other case delegates to the single-argument {@code get(key)}.
 */
@Override
public Optional<Value> get(String key, OrderedTensorType type) {
    if (!attributeMap.containsKey(key)) {
        return get(key);
    }
    AttrValue attr = attributeMap.get(key);
    if (attr.getValueCase() != AttrValue.ValueCase.TENSOR) {
        return get(key);
    }
    return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attr.getTensor(), type)));
}
示例5
/**
 * Configures this op from the TF node's properties, then appends the {@code use_locking} flag
 * as a boolean argument. The flag is {@code false} when the attribute is absent.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    boolean useLocking = nodeDef.containsAttr("use_locking")
            && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
示例6
/**
 * Imports a TF node carrying a string {@code message} attribute.
 *
 * <p>libnd4j custom ops have no string arguments, so the message is stored as a scalar string
 * constant named {@code <node name>/message}. That constant is appended to this op's inputs and
 * marked as an input of this op.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    String message = attributesForNode.get("message").getS().toStringUtf8();
    String nodeName = nodeDef.getName();
    SDVariable messageConstant = initWith.constant(nodeName + "/message", Nd4j.scalar(message));
    String constantName = messageConstant.name();

    // Existing inputs first, then the message constant as the trailing input.
    List<String> inputs = new ArrayList<>(2);
    inputs.addAll(initWith.getOps().get(nodeName).getInputsToOp());
    inputs.add(constantName);
    initWith.getOps().get(nodeName).setInputsToOp(inputs);

    initWith.getVariables().get(constantName).setInputsForOp(Collections.singletonList(getOwnName()));
}
示例7
/**
 * Configures this permute/transpose op from a TensorFlow node.
 *
 * <p>When the node has a second input, the graph node with that name is looked up. If it yields
 * an array, that array's values become {@code permuteDims}. Once the first argument's shape and
 * array are known, the array (plus the permutation array, if any) is added as an input argument.
 * With no explicit permutation, {@code permuteDims} defaults to reversing all dimensions.
 *
 * @throws ND4JIllegalStateException if the permutation lists fewer dimensions than the input has
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
// Permute dimensions are not supplied as a second input: nothing more to configure here.
if (nodeDef.getInputCount() < 2)
return;
// Locate the graph node producing the second (permutation) input.
// NOTE(review): there is no break, so the last node with a matching name wins (names are presumably
// unique). An input reference with an output suffix or control prefix would not match; confirm.
NodeDef permuteDimsNode = null;
for (int i = 0; i < graph.getNodeCount(); i++) {
if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) {
permuteDimsNode = graph.getNode(i);
}
}
// NOTE(review): permuteDimsNode is still null if no node matched. How getNDArrayFromTensor handles
// null is not visible here; confirm.
INDArray permuteArrayOp = TFGraphMapper.getNDArrayFromTensor(permuteDimsNode);
if (permuteArrayOp != null) {
this.permuteDims = permuteArrayOp.data().asInt();
}
// Defer the rest until the input is properly mapped: known shape, not a placeholder, array present.
if (arg().getShape() == null || arg().getVariableType() == VariableType.PLACEHOLDER || arg().getArr() == null) {
return;
}
INDArray arr = sameDiff.getArrForVarName(arg().name());
if(permuteArrayOp != null){
addInputArgument(arr, permuteArrayOp);
} else {
addInputArgument(arr);
}
// No explicit permutation: default to reversing the dimension order.
if (arr != null && permuteDims == null) {
this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
}
// Reject permutations that list fewer dimensions than the input has.
if (permuteDims != null && permuteDims.length < arg().getShape().length)
throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
}
示例8
/**
 * Configures this op from the TF node's properties, then records whether {@code use_locking} is
 * set as a boolean argument. An absent attribute counts as {@code false}.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    boolean useLocking = false;
    if (nodeDef.containsAttr("use_locking")) {
        useLocking = nodeDef.getAttrOrThrow("use_locking").getB();
    }
    bArguments.add(useLocking);
}
示例9
/**
 * Configures this op from a TF node by inspecting its second input.
 *
 * <p>If that input already has an array and the array is a scalar, its value is stored in
 * {@code pow}. Otherwise nothing is configured here.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val powInputName = nodeDef.getInput(1);
    val powArr = initWith.getArrForVarName(powInputName);
    // NOTE(review): non-scalar second inputs are not handled here; presumably they are consumed
    // as a regular input. Confirm.
    if (powArr != null && powArr.isScalar()) {
        this.pow = powArr.getDouble(0);
    }
}
示例10
/**
 * Configures this op from a TF node. The node's {@code out_type} attribute sets
 * {@code dataType} and is also appended, encoded as a data-type byte, to the integer arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    val tfOutType = nodeDef.getAttrOrThrow("out_type").getType();
    dataType = TFGraphMapper.convertType(tfOutType);
    val nd4jType = DataTypeAdapter.dtypeConv(tfOutType);
    iArguments.add((long) FlatBuffersMapper.getDataTypeAsByte(nd4jType));
}
示例11
/**
 * Configures this op from the TF node's properties and appends the {@code use_locking} flag
 * ({@code false} when the attribute is missing) as a boolean argument.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    boolean hasLockingAttr = nodeDef.containsAttr("use_locking");
    bArguments.add(hasLockingAttr ? nodeDef.getAttrOrThrow("use_locking").getB() : false);
}
示例12
/**
 * Builds the LSTM configuration from a TF BlockLSTM node's {@code forget_bias},
 * {@code cell_clip} and {@code use_peephole} attributes. It then adds the configuration's
 * integer and floating-point arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    float forgetBias = attributesForNode.get("forget_bias").getF();
    float cellClip = attributesForNode.get("cell_clip").getF();
    boolean usePeephole = attributesForNode.get("use_peephole").getB();
    configuration = LSTMConfiguration.builder()
            .forgetBias(forgetBias)
            .clippingCellValue(cellClip)
            .peepHole(usePeephole)
            .dataFormat(RnnDataFormat.TNS) // TF BlockLSTM is always time-major
            .build();
    addIArgument(configuration.iArgs(true));
    addTArgument(configuration.tArgs());
}
示例13
/**
 * Configures this SVD op from a TensorFlow {@code Svd} node.
 *
 * <p>If {@code full_matrices} or {@code compute_uv} is absent (e.g. in graphs exported with
 * stripped default attributes), it falls back to TensorFlow's declared default: {@code false}
 * and {@code true} respectively. Previously an absent attribute caused a NullPointerException.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    AttrValue attrFullMatrices = attributesForNode.get("full_matrices");
    AttrValue attrComputeUv = attributesForNode.get("compute_uv");
    this.fullUV = attrFullMatrices != null && attrFullMatrices.getB();
    this.computeUv = attrComputeUv == null || attrComputeUv.getB();
    // NOTE(review): 16 is passed as the third integer argument; its meaning is not visible here. Confirm.
    this.switchNum = 16;
    addIArgument(ArrayUtil.fromBoolean(fullUV), ArrayUtil.fromBoolean(computeUv), switchNum);
}
示例14
/**
 * Configures this matrix multiplication from a TensorFlow {@code MatMul} node.
 *
 * <p>Example node:
 * <pre>
 * name: "MatMul"
 * op: "MatMul"
 * input: "input"
 * input: "Variable/read"
 * attr { key: "transpose_b" value { b: false } }
 * attr { key: "transpose_a" value { b: false } }
 * attr { key: "T" value { type: DT_FLOAT } }
 * </pre>
 *
 * <p>If {@code transpose_a} or {@code transpose_b} is absent, it is treated as {@code false}
 * (TensorFlow's declared default) instead of throwing a NullPointerException.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    AttrValue attrTransposeA = attributesForNode.get("transpose_a");
    AttrValue attrTransposeB = attributesForNode.get("transpose_b");
    boolean isTransposeA = attrTransposeA != null && attrTransposeA.getB();
    boolean isTransposeB = attrTransposeB != null && attrTransposeB.getB();
    this.mMulTranspose = MMulTranspose.builder()
            .transposeA(isTransposeA).transposeB(isTransposeB)
            .build();
}
示例15
/**
 * Configures this op from a TF node. The output type is taken from the {@code T} attribute
 * when present, and is otherwise left unchanged. Arguments are then added.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    boolean declaresType = attributesForNode.containsKey("T");
    if (declaresType) {
        AttrValue typeAttr = attributesForNode.get("T");
        outputType = TFGraphMapper.convertType(typeAttr.getType());
    }
    addArgs();
}
示例16
/**
 * Configures this matrix multiplication from a TF {@code MatMul}, {@code BatchMatMul} or
 * {@code BatchMatMulV2} node, then resets the integer arguments to the two transpose flags.
 *
 * <p>{@code MatMul} uses {@code transpose_a}/{@code transpose_b}. The batch variants prefer those
 * names but fall back to {@code adj_x}/{@code adj_y}. A flag that is absent under every accepted
 * name is treated as {@code false} (TensorFlow's declared default) instead of throwing a
 * NullPointerException.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    AttrValue attrA;
    AttrValue attrB;
    if (nodeDef.getOp().equalsIgnoreCase("MatMul")) {
        attrA = attributesForNode.get("transpose_a");
        attrB = attributesForNode.get("transpose_b");
    } else {
        // BatchMatMul, BatchMatMulV2: in practice these use "adj_x"/"adj_y" instead of "transpose_a"/"transpose_b".
        attrA = attributesForNode.containsKey("transpose_a")
                ? attributesForNode.get("transpose_a")
                : attributesForNode.get("adj_x");
        attrB = attributesForNode.containsKey("transpose_b")
                ? attributesForNode.get("transpose_b")
                : attributesForNode.get("adj_y");
    }
    boolean isTransposeA = attrA != null && attrA.getB();
    boolean isTransposeB = attrB != null && attrB.getB();
    this.mt = MMulTranspose.builder()
            .transposeA(isTransposeA).transposeB(isTransposeB)
            .build();
    iArguments.clear();
    addIArgument(ArrayUtil.fromBoolean(mt.isTransposeA()), ArrayUtil.fromBoolean(mt.isTransposeB()));
}
示例17
/**
 * Configures this op from the TF node's properties, then appends a boolean argument that is
 * {@code true} only if the node has a {@code use_locking} attribute set to true.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    final String lockingKey = "use_locking";
    boolean locking = nodeDef.containsAttr(lockingKey) && nodeDef.getAttrOrThrow(lockingKey).getB();
    bArguments.add(locking);
}
示例18
/**
 * Configures this sequence-mask op from a TF node.
 *
 * <p>If the second input yields no array, {@code is_static_maxlen} is set. After the
 * property-based initialization, {@code maxLen} is added as an integer argument whenever
 * {@code is_static_maxlen} is set.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val maxlenNode = TFGraphMapper.getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
    val maxlenArr = TFGraphMapper.getNDArrayFromTensor(maxlenNode);
    // NOTE(review): the original comment reads "No 2nd input". The flag is set when the second input
    // does not resolve to an array, which reads inverted relative to its name. Confirm intent.
    boolean noMaxlenArray = maxlenArr == null;
    if (noMaxlenArray) {
        this.is_static_maxlen = true;
    }
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    if (is_static_maxlen) {
        addIArgument(this.maxLen);
    }
}
示例19
/**
 * Sets the output type from the TF node's {@code output_type} attribute, or to
 * {@code DataType.LONG} when the attribute is absent.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    outputType = attributesForNode.containsKey("output_type")
            ? TFGraphMapper.convertType(attributesForNode.get("output_type").getType())
            : DataType.LONG;
}
示例20
/**
 * Configures this fake-quantization op from a TF {@code FakeQuantWithMinMaxArgs}-style node.
 *
 * <p>{@code narrow_range} is only overridden when present, matching the original behaviour.
 * {@code num_bits}, {@code min} and {@code max} fall back to TensorFlow's declared defaults
 * (8, -6 and 6) when absent, instead of throwing a NullPointerException. This brings them in
 * line with the sibling FakeQuant import that guards {@code num_bits}.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (attributesForNode.containsKey("narrow_range")) {
        this.narrowRange = attributesForNode.get("narrow_range").getB();
    }
    AttrValue attrNumBits = attributesForNode.get("num_bits");
    AttrValue attrMin = attributesForNode.get("min");
    AttrValue attrMax = attributesForNode.get("max");
    this.numBits = attrNumBits != null ? (int) attrNumBits.getI() : 8;
    this.min = attrMin != null ? attrMin.getF() : -6.0f;
    this.max = attrMax != null ? attrMax.getF() : 6.0f;
    addArgs();
}
示例21
/**
 * Configures this resize op from a TF node. The {@code align_corners} and
 * {@code half_pixel_centers} flags are each {@code false} when the attribute is absent.
 * Arguments are then added.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    AttrValue alignAttr = attributesForNode.get("align_corners");
    AttrValue halfPixelAttr = attributesForNode.get("half_pixel_centers");
    this.alignCorners = alignAttr != null && alignAttr.getB();
    this.halfPixelCenters = halfPixelAttr != null && halfPixelAttr.getB();
    addArgs();
}
示例22
/**
 * Configures this fake-quantization op from a TF node. Only attributes present on the node
 * override the current {@code numBits}/{@code narrowRange} values. Both values are then added as
 * integer and boolean arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (attributesForNode.containsKey("num_bits")) {
        this.numBits = (int) attributesForNode.get("num_bits").getI();
    }
    if (attributesForNode.containsKey("narrow_range")) {
        this.narrowRange = attributesForNode.get("narrow_range").getB();
    }
    addIArgument(numBits);
    addBArgument(narrowRange);
}
示例23
/**
 * Reads an integer from this node's second input and adds it as an integer argument, provided
 * that input resolves to an array (its first element is used).
 *
 * <p>Compared with the previous version, this drops the unused {@code initWith.getVariable(...)}
 * lookup and renames the local {@code var}, which shadowed Java's reserved type name.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val indexInputName = nodeDef.getInput(1);
    val indexNode = TFGraphMapper.getNodeWithNameFromGraph(graph, indexInputName);
    val indexArr = TFGraphMapper.getArrayFrom(indexNode, graph);
    // No array available for the second input: nothing to configure statically.
    if (indexArr != null) {
        addIArgument(indexArr.getInt(0));
    }
}
示例24
/**
 * Configures this batch-normalization op from a TensorFlow fused batch-norm node.
 *
 * <p>TF orders the inputs as [input, gamma, beta, mean, variance], while libnd4j expects
 * [input, mean, variance, gamma, beta], so this op's input list is reordered. Gamma and beta are
 * always applied. {@code epsilon} falls back to TensorFlow's declared default (0.0001) when
 * absent. When {@code data_format} is present, it selects the normalization axis.
 *
 * @throws IllegalStateException if the op has fewer than 5 inputs or {@code data_format} is unknown
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    SameDiffOp op = initWith.getOps().get(this.getOwnName());
    List<String> list = op.getInputsToOp();
    if (list.size() < 5) {
        throw new IllegalStateException("Expected 5 inputs [input, gamma, beta, mean, variance] for op \""
                + getOwnName() + "\" but got " + list.size());
    }
    // Reorder TF [input, gamma, beta, mean, variance] -> libnd4j [input, mean, variance, gamma, beta].
    // Copied into a mutable list because Arrays.asList is fixed-size (defensive against later appends).
    List<String> newList = new java.util.ArrayList<>(
            Arrays.asList(list.get(0), list.get(3), list.get(4), list.get(1), list.get(2)));
    op.setInputsToOp(newList);
    this.applyGamma = true;
    this.applyBeta = true;
    AttrValue attrEpsilon = attributesForNode.get("epsilon");
    this.epsilon = attrEpsilon != null ? attrEpsilon.getF() : 0.0001f;
    // NOTE(review): TF's declared default data_format is "NHWC". When the attribute is absent, jaxis is
    // left unchanged here; confirm that its existing value matches NHWC.
    if (attributesForNode.containsKey("data_format")) {
        String dataFormat = attributesForNode.get("data_format").getS().toStringUtf8();
        //TODO not sure if these conv1d/3d cases appear. But BN definitely uses "NCHW" or "NHWC"
        if (dataFormat.equalsIgnoreCase(Conv2DConfig.NCHW) || dataFormat.equalsIgnoreCase(Conv1DConfig.NCW) || dataFormat.equalsIgnoreCase(Conv3DConfig.NCDHW)) {
            jaxis = new int[]{1};
        } else if (dataFormat.equalsIgnoreCase(Conv2DConfig.NHWC)) {
            jaxis = new int[]{3};
        } else if (dataFormat.equalsIgnoreCase(Conv1DConfig.NWC)) {
            jaxis = new int[]{2};
        } else if (dataFormat.equalsIgnoreCase(Conv3DConfig.NDHWC)) {
            jaxis = new int[]{4};
        } else {
            throw new IllegalStateException("Unknown data format: \"" + dataFormat + "\"" );
        }
    }
    addArgs();
}
示例25
/**
 * Chooses the output type: the TF node's {@code output_type} attribute when declared, otherwise
 * {@code DataType.LONG}.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    if (!attributesForNode.containsKey("output_type")) {
        outputType = DataType.LONG;
        return;
    }
    outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType());
}
示例26
/**
 * Sets {@code axis} from the array held by this node's second input and adds it as an integer
 * argument. When no array is available, {@code Integer.MAX_VALUE} is used instead.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val axisNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(1));
    val axisArr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", axisNode, graph);
    this.axis = axisArr != null ? axisArr.data().asInt()[0] : Integer.MAX_VALUE;
    addIArgument(this.axis);
}
示例27
/**
 * Sets {@code axis} from the TF node's mandatory {@code axis} attribute and adds the arguments.
 * {@code getAttrOrThrow} fails if the attribute is missing.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    this.axis = (int) nodeDef.getAttrOrThrow("axis").getI();
    addArgs();
}
示例28
/**
 * Configures this split op from a TF node.
 *
 * <p>{@code num_split} is always stored and added as an integer argument. When the first input
 * (the split dimension) resolves to an array, its first element is stored in {@code splitDim}
 * and added as an integer argument before {@code num_split}.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val numSplits = (int) attributesForNode.get("num_split").getI();
    this.numSplit = numSplits;
    val splitDimNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, nodeDef.getInput(0));
    val splitDimArr = TFGraphMapper.getInstance().getArrayFrom(splitDimNode, graph);
    if (splitDimArr != null) {
        int dim = splitDimArr.getInt(0);
        this.splitDim = dim;
        addIArgument(dim);
    }
    addIArgument(numSplits);
}
示例29
/**
 * Configures this 2D pooling op from a TensorFlow pooling node.
 *
 * <p>Kernel size and strides come from the {@code ksize}/{@code strides} list attributes at
 * indices 1 and 2. NOTE(review): this assumes NHWC layout; confirm for NCHW graphs.
 * {@code padding} is a string attribute ("SAME"/"VALID") that selects same-mode.
 *
 * <p>Fix: {@code AttrValue} holds a single {@code oneof} value. For the string {@code padding}
 * attribute, {@code getList()} is therefore the empty default, and the former unconditional
 * {@code padding.get(0)} always threw. Explicit padding values are now used only if the
 * attribute actually carries a list; otherwise padding is 0.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val aStrides = nodeDef.getAttrOrThrow("strides");
    val tfStrides = aStrides.getList().getIList();
    val sH = tfStrides.get(1);
    val sW = tfStrides.get(2);
    val aKernels = nodeDef.getAttrOrThrow("ksize");
    val tfKernels = aKernels.getList().getIList();
    val kH = tfKernels.get(1);
    val kW = tfKernels.get(2);
    val aPadding = nodeDef.getAttrOrThrow("padding");
    val padding = aPadding.getList().getIList();
    int pH = padding.size() >= 2 ? padding.get(0).intValue() : 0;
    int pW = padding.size() >= 2 ? padding.get(1).intValue() : 0;
    val paddingMode = aPadding.getS().toStringUtf8().replaceAll("\"", "");
    boolean isSameMode = paddingMode.equalsIgnoreCase("SAME");
    if (!isSameMode) {
        log.debug("Mode: {}", paddingMode);
    }
    this.config = Pooling2DConfig.builder()
            .sH(sH.intValue())
            .sW(sW.intValue())
            .type(null)
            .isSameMode(isSameMode)
            .kH(kH.intValue())
            .kW(kW.intValue())
            .pH(pH)
            .pW(pW)
            .build();
    addArgs();
    log.debug("Pooling: k: [{},{}]; s: [{}, {}], padding: {}", kH, kW, sH, sW, aPadding);
}
示例30
/**
 * Configures this op from the TF node's attributes via
 * {@code TFGraphMapper.initFunctionFromProperties}, then adds the resulting arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    String tfOpName = nodeDef.getOp();
    TFGraphMapper.initFunctionFromProperties(tfOpName, this, attributesForNode, nodeDef, graph);
    addArgs();
}