Java源码示例:org.nd4j.imports.graphmapper.tf.TFGraphMapper
示例1
/**
 * Imports the {@code simpleif_0} frozen TF graph, feeds a 2x2 array of -1 into {@code input_0}
 * and checks that the first graph output is a 2x2 array of 1.
 */
@Test
public void testCondMapping2() throws Exception {
    // Presumably forces ND4J backend initialisation before import — confirm.
    Nd4j.create(1);
    // try-with-resources so the classpath stream is closed when the test finishes.
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));
        // FlatBuffers fixture export for libnd4j. Disabled like in the sibling tests: it writes to a
        // hard-coded path outside this module and may fail when that directory does not exist.
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
        //log.info("{}", tg.asFlatPrint());
        val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例2
/**
 * Imports the {@code simplewhile_1} frozen TF graph, feeds a 2x2 array of -4 and the scalar 1,
 * and checks that the first graph output is a 2x2 array of -1.
 */
@Test
public void testWhileDualMapping1() throws Exception {
    Nd4j.create(1);
    // Close the classpath stream once the test is done with the graph.
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-4.0);
        val input1 = Nd4j.scalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));
        //log.info("{}", tg.asFlatPrint());
        INDArray array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(-1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例3
/**
 * Imports {@code max_add_2.pb.txt} and checks that it holds exactly two variables,
 * {@code zeros} and {@code ones}, whose arrays sum to 0 and 12.
 */
@Test
@Ignore
public void importGraph1() throws Exception {
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()) {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(graph);
        assertEquals(2, graph.variableMap().size());
        SDVariable var0 = graph.variableMap().get("zeros");
        SDVariable var1 = graph.variableMap().get("ones");
        assertNotNull(var0);
        assertNotNull(var1);
        assertNotNull(var0.getArr());
        assertNotNull(var1.getArr());
        assertEquals(0.0, var0.getArr().sumNumber().doubleValue(), 1e-5);
        assertEquals(12.0, var1.getArr().sumNumber().doubleValue(), 1e-5);
    }
}
示例4
/**
 * Imports {@code simplewhile_0}, binds the scalar 4 to {@code input_1} and exports the graph as a
 * FlatBuffers file for the libnd4j tests. The output check is currently disabled.
 */
@Test
public void testWhileMapping2() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input = Nd4j.scalar(4.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
        // NOTE(review): writes to a path relative to the working directory, outside this module.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
        //log.info("{}", tg.asFlatPrint());
        /*
        val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(2);
        assertNotNull(array);
        assertEquals(exp, array);*/
    }
}
示例5
/**
 * Imports the LeNet CNN graph and checks that the {@code conv2d/kernel} variable has an array of
 * shape [5, 5, 1, 32].
 */
@Test
public void testLenet() throws Exception {
    /**
     * Produced with:
     * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
     */
    Nd4j.create(1);
    val resource = new ClassPathResource("tf_graphs/lenet_cnn.pb");
    // The resource is read twice: once as a raw GraphDef to list node names, once through the
    // importer. Each stream is closed after use.
    try (java.io.InputStream rawIn = resource.getInputStream()) {
        val rawGraph = GraphDef.parseFrom(rawIn);
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
    }
    try (java.io.InputStream in = resource.getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));
        // this is NHWC weights. will be changed soon.
        assertArrayEquals(new int[]{5,5,1,32}, shape);
        System.out.println(convNode);
    }
}
示例6
/**
 * Imports the nested while-loop graph, serialises it to FlatBuffers and checks the variable count
 * and the paired-input counts of the first two nodes.
 */
@Test
public void testIntermediateLoop3() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        // now converting to FlatBuffer
        val fb = tg.asFlatBuffers(true);
        assertNotNull(fb);
        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(15, graph.variablesLength());
        //assertEquals("phi/Assign", graph.nodes(0).name());
        //assertEquals("alpha/Assign", graph.nodes(1).name());
        assertEquals(2, graph.nodes(0).inputPairedLength());
        assertEquals(2, graph.nodes(1).inputPairedLength());
        //  tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/nested_while.fb"));
    }
}
示例7
/**
 * Imports the tensor-array graph, sets {@code input_matrix} to a 3x2 array of ones and checks
 * the FlatBuffers serialisation: 36 variables and more than one node.
 */
@Test
@Ignore
public void testIntermediateTensorArraySimple1() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        // Check the import result before it is dereferenced below.
        assertNotNull(tg);
        tg.updateVariable("input_matrix", Nd4j.ones(3, 2));
        val fb = tg.asFlatBuffers();
        assertNotNull(fb);
        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(36, graph.variablesLength());
        assertTrue(graph.nodesLength() > 1);
        /*   assertEquals("strided_slice", graph.nodes(0).name());
             assertEquals("TensorArray", graph.nodes(1).name());
        */
        //   assertEquals(4, graph.nodes(0).inputPairedLength());
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_array.fb"));
    }
}
示例8
/**
 * Imports the {@code simpleif_0} frozen TF graph, feeds a 2x2 array of -1 and checks that the
 * end result is a 2x2 array of 1.
 */
@Test
public void testCondMapping2() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);
        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例9
/**
 * Imports {@code simplewhile_0}, feeds a 2x2 array of 1 into {@code input_0} and checks that the
 * end result is a 2x2 array of 1.
 */
@Test
public void testWhileMapping1() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);
        val input = Nd4j.create(2, 2).assign(1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_3.fb"));
        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例10
/**
 * Imports {@code simplewhile_0}, binds the true scalar 4 to {@code input_1} and exports the graph
 * as a FlatBuffers file for libnd4j. The output check is currently disabled.
 */
@Test
public void testWhileMapping2() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);
        val input = Nd4j.trueScalar(4.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));
        // NOTE(review): writes to a path relative to the working directory, outside this module.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));
        //log.info("{}", tg.asFlatPrint());
        /*
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(2);
        assertNotNull(array);
        assertEquals(exp, array);*/
    }
}
示例11
/**
 * Imports {@code simplewhile_1}, feeds a 2x2 array of -9 and the true scalar 1, and checks that
 * the end result is a 2x2 array of -3.
 */
@Test
public void testWhileDualMapping2() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-9.0);
        val input1 = Nd4j.trueScalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));
        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例12
/**
 * Imports the nested while/cond graph, feeds a 2x2 array of 1 and a 3x3 array of 2, and checks
 * that the end result is a 2x2 array of 15.
 */
@Test
public void testMixedWhileCond1() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(1.0);
        val input1 = Nd4j.create(3, 3).assign(2.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_nested.fb"));
        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        //val array = tg.getVariable("output").getArr();
        val exp = Nd4j.create(2, 2).assign(15.0);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例13
/**
 * Imports the {@code partition_stitch_misc} graph, binds three constant-filled inputs and exports
 * it as a FlatBuffers file for libnd4j.
 */
@Test
@Ignore
public void testCrash_119_matrix_diag() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 5, 4).assign(1.0);
        val input1 = Nd4j.create(2, 3, 5, 4).assign(2.0);
        val input2 = Nd4j.create(3, 1, 5, 4).assign(3.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.associateArrayWithVariable(input2, tg.getVariable("input_2"));
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb"));
    }
}
示例14
/**
 * Registers the {@code expand_dim} graph with a gRPC inference server on 127.0.0.1:40123, sends a
 * 3x4 input and checks the {@code output} result against a precomputed [3, 1, 4] array.
 * NOTE(review): needs a running server on that port.
 */
@Test
public void testSimpleGraph_1() throws Exception {
    val exp = Nd4j.create(new double[] {-0.95938617, -1.20301781, 1.22260064, 0.50172403, 0.59972949, 0.78568028, 0.31609724, 1.51674747, 0.68013491, -0.05227458, 0.25903158,1.13243439}, new long[]{3, 1, 4});

    // configuring client
    val client = new GraphInferenceGrpcClient("127.0.0.1", 40123);

    val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE);

    // preparing and registering graph (it's optional, and graph might be embedded into Docker image
    // The classpath stream is only needed for the import, so it is closed right after registration.
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        client.registerGraph(graphId, tg, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build());
    }

    //defining input
    val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4});
    val operands = new Operands().addArgument("input_0", input0);

    // sending request and getting result
    val result = client.output(graphId, operands);
    assertEquals(exp, result.getById("output"));
}
示例15
/**
 * Configures this op from a TensorFlow node. It reads the optional {@code dtype} and
 * {@code init} attributes, fixes the array ordering to C and then registers the op arguments.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    final boolean hasDtype = attributesForNode.containsKey("dtype");
    final boolean hasInit = attributesForNode.containsKey("init");

    if (hasDtype) {
        // Translate the TF dtype enum into the ND4J output data type.
        outputType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
    }

    if (hasInit) {
        initialize = attributesForNode.get("init").getB();
    }

    // TF carries no ordering information here, so plain C order is used.
    order = 'c';
    addArgs();
}
示例16
/**
 * Imports the LeNet CNN graph and checks that the {@code conv2d/kernel} variable has an array of
 * shape [5, 5, 1, 32].
 */
@Test
public void testLenet() throws Exception {
    /**
     * Produced with:
     * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
     */
    Nd4j.create(1);
    val resource = new ClassPathResource("tf_graphs/lenet_cnn.pb");
    // The resource is read twice: once as a raw GraphDef to list node names, once through the
    // importer. Each stream is closed after use.
    try (java.io.InputStream rawIn = resource.getInputStream()) {
        val rawGraph = GraphDef.parseFrom(rawIn);
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
    }
    try (java.io.InputStream in = resource.getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));
        // this is NHWC weights. will be changed soon.
        assertArrayEquals(new long[]{5,5,1,32}, shape);
        System.out.println(convNode);
    }
}
示例17
/**
 * Imports the {@code partition_stitch_misc} graph, binds three constant-filled inputs and exports
 * it as a FlatBuffers file for libnd4j.
 */
@Test
@Ignore
public void testCrash_119_matrix_diag() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 5, 4).assign(1.0);
        val input1 = Nd4j.create(2, 3, 5, 4).assign(2.0);
        val input2 = Nd4j.create(3, 1, 5, 4).assign(3.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.associateArrayWithVariable(input2, tg.getVariable("input_2"));
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb"));
    }
}
示例18
/**
 * Imports {@code simplewhile_1}, feeds a 2x2 array of -9 and the scalar 1, and checks that the
 * first graph output is a 2x2 array of -3.
 */
@Test
public void testWhileDualMapping2() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-9.0);
        val input1 = Nd4j.scalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));
        //log.info("{}", tg.asFlatPrint());
        val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
示例19
/**
 * Configures this op from a TensorFlow node. It finds the graph node named by this node's last
 * input, reads that node's tensor value and, if one is present, adds its first element as an
 * integer argument. It then sets the tensor-array data type from the {@code dtype} attribute.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    // Name of the node feeding the last input (presumably a constant holding the index — confirm).
    final String indexNodeName = nodeDef.getInput(nodeDef.getInputCount() - 1);

    // Walk the graph from the end. The first hit is the last node with that name, which is the
    // same node a full forward scan that keeps overwriting its match would end up with.
    NodeDef indexNode = null;
    for (int j = graph.getNodeCount() - 1; j >= 0 && indexNode == null; j--) {
        NodeDef candidate = graph.getNode(j);
        if (candidate.getName().equals(indexNodeName)) {
            indexNode = candidate;
        }
    }

    val arr = TFGraphMapper.getNDArrayFromTensor(indexNode);
    if (arr != null) {
        addIArgument(arr.getInt(0));
    }

    this.tensorArrayDataType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
}
示例20
/**
 * Imports the {@code transpose} graph, binds two 3x3 inputs (the second is the transpose of the
 * first) and exports it as a FlatBuffers file for libnd4j.
 */
@Test
@Ignore
public void testCrash_119_transpose() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(new double[]{0.98114507, 0.96400015, 0.58669623, 0.60073098, 0.75425418, 0.44258752, 0.76373084, 0.96593234, 0.34067846}, new int[] {3, 3});
        val input1 = Nd4j.create(new double[]{0.98114507, 0.60073098, 0.76373084, 0.96400015, 0.75425418, 0.96593234, 0.58669623, 0.44258752, 0.34067846}, new int[] {3, 3});
        tg.associateArrayWithVariable(input0, tg.getVariable("input"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/transpose.fb"));
    }
}
示例21
/**
 * Loads the model at {@code pathToModel}. Files detected as TensorFlow are imported with
 * {@link TFGraphMapper}, SameDiff zip files are loaded with {@code SameDiff.load}, and anything
 * else is read with {@code SameDiff.fromFlatFile}.
 *
 * @return the loaded graph
 * @throws Exception whatever the underlying loader throws
 */
@Override
public SameDiff loadModel() throws Exception {
    // Parameterized logging: the message is only formatted when debug logging is enabled.
    if (ModelGuesser.isTensorflowFile(pathToModel)) {
        log.debug("Loading tensorflow model from {}", pathToModel.getAbsolutePath());
        return TFGraphMapper.importGraph(pathToModel);
    }
    if (ModelGuesser.isSameDiffZipFile(pathToModel)) {
        log.debug("Loading samediff zip model from {}", pathToModel.getAbsolutePath());
        return SameDiff.load(pathToModel, true);
    }
    log.debug("Loading samediff model from {}", pathToModel.getAbsolutePath());
    return SameDiff.fromFlatFile(pathToModel);
}
示例22
/**
 * Imports a frozen BERT MRPC graph from a local path, replacing its input iterator with
 * placeholders, and writes the graph structure to a UI log file for visual inspection.
 */
@Test @Ignore
public void writeBertUI() throws Exception {
    //Test used to generate graph for visualization to work out appropriate subgraph structure to replace
    File f = new File("C:/Temp/TF_Graphs/mrpc_output/frozen/bert_mrpc_frozen.pb");
    int minibatchSize = 4;

    Map<String, TFImportOverride> m = new HashMap<>();
    // Replace the training iterator with 3 placeholders named "IteratorGetNext" (i.e. output :0),
    // "IteratorGetNext:1" and "IteratorGetNext:4".
    m.put("IteratorGetNext", (inputs, controlDepInputs, nodeDef, initWith, attributesForNode, graph) ->
            Arrays.asList(
                    initWith.placeHolder("IteratorGetNext", DataType.INT, minibatchSize, 128),
                    initWith.placeHolder("IteratorGetNext:1", DataType.INT, minibatchSize, 128),
                    initWith.placeHolder("IteratorGetNext:4", DataType.INT, minibatchSize, 128)
            ));

    //Skip the "IteratorV2" op - we don't want or need this
    TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) -> "IteratorV2".equals(nodeDef.getName());

    SameDiff sd = TFGraphMapper.importGraph(f, m, filter);

    LogFileWriter w = new LogFileWriter(new File("C:/Temp/BERT_UI.bin"));
    w.writeGraphStructure(sd);
    w.writeFinishStaticMarker();
}
示例23
/**
 * Smoke test: imports the Inception graph and checks that the import returns a graph.
 */
@Test
@Ignore
public void importGraph2() throws Exception {
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()) {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(graph);
    }
}
示例24
/**
 * Smoke test: imports the logistic-regression graph and checks that the import returns a graph.
 */
@Test
@Ignore
public void importGraph3() throws Exception {
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()) {
        SameDiff graph = TFGraphMapper.importGraph(in);
        assertNotNull(graph);
    }
}
示例25
/**
 * Imports the {@code expand_dim} graph, binds a 3x4 input and exports it as a FlatBuffers file
 * for libnd4j.
 */
@Test
@Ignore
public void testCrash_119_expand_dim() throws Exception {
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(in);
        assertNotNull(tg);
        val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4});
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/expand_dim.fb"));
    }
}
示例26
/**
 * Sets the output data type from the TF {@code output_type} attribute, or to
 * {@code DataType.LONG} when the node does not declare one.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    outputType = attributesForNode.containsKey("output_type")
            ? TFGraphMapper.convertType(attributesForNode.get("output_type").getType())
            : DataType.LONG;
}
示例27
/**
 * Initialises this op from the TF node's properties, then appends one boolean argument taken
 * from the node's optional {@code use_locking} attribute ({@code false} when absent).
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: the attribute is only read when it exists. The old if/else over `== true`
    // added the same true/false values.
    boolean useLocking = nodeDef.containsAttr("use_locking")
            && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
示例28
/**
 * Manual profiling helper: imports a convolution graph from a developer-local path and exports
 * it as a FlatBuffers file for libnd4j. It is ignored because the source path is machine-specific.
 */
@Test
@Ignore
public void testProfConv() throws Exception {
    Nd4j.create(1);

    val source = new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt");
    val target = new File("../../../libnd4j/tests_cpu/resources/profiling_conv.fb");

    val tg = TFGraphMapper.getInstance().importGraph(source);
    assertNotNull(tg);

    tg.asFlatFile(target);
}
示例29
/**
 * Initialises this op from the TF node's properties, then reorders its inputs for libnd4j.
 * TF supplies [logits, labels] but libnd4j expects [labels, logits], so the op's input list is
 * replaced by those two names in swapped order.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    SameDiffOp sdOp = initWith.getOps().get(getOwnName());
    List<String> tfOrder = sdOp.getInputsToOp();
    String labels = tfOrder.get(1);
    String logits = tfOrder.get(0);
    sdOp.setInputsToOp(Arrays.asList(labels, logits));
}
示例30
/**
 * Runs the parent TF initialisation, overrides {@code dataType} from the optional {@code Tidx}
 * attribute, and registers the resulting data type as a D-argument.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    final String tidxKey = "Tidx";
    if (attributesForNode.containsKey(tidxKey)) {
        AttrValue tidx = attributesForNode.get(tidxKey);
        dataType = TFGraphMapper.convertType(tidx.getType());
    }

    addDArgument(dataType);
}