Java源码示例:org.nd4j.imports.graphmapper.tf.TFGraphMapper

示例1
/**
 * Imports the frozen {@code simpleif_0} TF graph and feeds a 2x2 array filled with -1 into
 * {@code input_0}. Writes a FlatBuffers copy of the graph into the libnd4j test resources,
 * then checks that the first graph output equals a 2x2 array filled with 1.
 */
@Test
public void testCondMapping2() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);

        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));
        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));

        //log.info("{}", tg.asFlatPrint());
        val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例2
/**
 * Imports the frozen {@code simplewhile_1} TF graph and feeds a 2x2 array filled with -4.0
 * into {@code input_0} and the scalar 1.0 into {@code input_1}. Checks that the first graph
 * output equals a 2x2 array filled with -1.
 */
@Test
public void testWhileDualMapping1() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-4.0);
        val input1 = Nd4j.scalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        //log.info("{}", tg.asFlatPrint());

        INDArray array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(-1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例3
/**
 * Imports the {@code max_add_2.pb.txt} graph and checks that it has exactly two variables,
 * {@code zeros} and {@code ones}. Both must have arrays, summing to 0.0 and 12.0 respectively.
 */
@Test
@Ignore
public void importGraph1() throws Exception {
    SameDiff graph;
    // Close the classpath stream as soon as the graph has been imported.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/max_add_2.pb.txt").getInputStream()) {
        graph = TFGraphMapper.getInstance().importGraph(graphStream);
    }

    assertNotNull(graph);

    assertEquals(2, graph.variableMap().size());

    SDVariable var0 = graph.variableMap().get("zeros");
    SDVariable var1 = graph.variableMap().get("ones");

    assertNotNull(var0);
    assertNotNull(var1);

    assertNotNull(var0.getArr());
    assertNotNull(var1.getArr());

    assertEquals(0.0, var0.getArr().sumNumber().doubleValue(), 1e-5);
    assertEquals(12.0, var1.getArr().sumNumber().doubleValue(), 1e-5);
}
 
示例4
/**
 * Imports the frozen {@code simplewhile_0} TF graph and feeds the scalar 4.0 into
 * {@code input_1}. Exports the graph as a FlatBuffers file into the libnd4j test resources.
 *
 * <p>NOTE(review): the output assertions are commented out. Apart from the import null-check,
 * this test only verifies that the export does not throw.
 */
@Test
public void testWhileMapping2() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);
        val input = Nd4j.scalar(4.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));

        //log.info("{}", tg.asFlatPrint());
        /*
        val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(2);
        assertNotNull(array);
        assertEquals(exp, array);*/
    }
}
 
示例5
/**
 * Imports {@code tf_graphs/lenet_cnn.pb}, printing its raw node names first. Checks that the
 * {@code conv2d/kernel} variable has an array of shape {@code [5, 5, 1, 32]}.
 *
 * <p>According to the original note, the frozen graph was produced with:
 * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
 */
@Test
public void testLenet() throws Exception {
    Nd4j.create(1);

    // The resource is read twice (raw protobuf, then import); each stream is closed after use.
    try (java.io.InputStream rawStream = new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()) {
        val rawGraph = GraphDef.parseFrom(rawStream);
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
    }

    try (java.io.InputStream graphStream = new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);

        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));

        // this is NHWC weights. will be changed soon.
        assertArrayEquals(new int[]{5,5,1,32}, shape);
        System.out.println(convNode);
    }
}
 
示例6
/**
 * Imports the {@code nested_while.pb.txt} graph and serialises it with
 * {@code asFlatBuffers(true)}. Checks that the FlatGraph has 15 variables and that its first
 * two nodes each have two paired inputs.
 */
@Test
public void testIntermediateLoop3() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/nested_while.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);

        assertNotNull(tg);

        // now converting to FlatBuffer
        val fb = tg.asFlatBuffers(true);
        assertNotNull(fb);

        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(15, graph.variablesLength());

        //assertEquals("phi/Assign", graph.nodes(0).name());
        //assertEquals("alpha/Assign", graph.nodes(1).name());

        assertEquals(2, graph.nodes(0).inputPairedLength());
        assertEquals(2, graph.nodes(1).inputPairedLength());

        //   tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/nested_while.fb"));
    }
}
 
示例7
/**
 * Imports the {@code tensor_array.pb.txt} graph and sets {@code input_matrix} to a 3x2 array
 * of ones. Serialises the graph to FlatBuffers and checks for 36 variables and more than one
 * node.
 */
@Test
@Ignore
public void testIntermediateTensorArraySimple1() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/tensor_array.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        // Check the import before dereferencing tg, so a failed import shows up as an assertion
        // failure rather than a NullPointerException.
        assertNotNull(tg);
        tg.updateVariable("input_matrix", Nd4j.ones(3, 2));

        val fb = tg.asFlatBuffers();
        assertNotNull(fb);

        val graph = FlatGraph.getRootAsFlatGraph(fb);
        assertEquals(36, graph.variablesLength());

        assertTrue(graph.nodesLength() > 1);
        /*  assertEquals("strided_slice", graph.nodes(0).name());
        assertEquals("TensorArray", graph.nodes(1).name());
        */
        //   assertEquals(4, graph.nodes(0).inputPairedLength());

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/tensor_array.fb"));
    }
}
 
示例8
/**
 * Imports the frozen {@code simpleif_0} TF graph and feeds a 2x2 array filled with -1 into
 * {@code input_0}. Checks that {@code execAndEndResult()} returns a 2x2 array filled with 1.
 */
@Test
public void testCondMapping2() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);

        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));
        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));

        //log.info("{}", tg.asFlatPrint());
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例9
/**
 * Imports the frozen {@code simplewhile_0} TF graph and feeds a 2x2 array of ones into
 * {@code input_0}. Checks that {@code execAndEndResult()} returns a 2x2 array of ones.
 */
@Test
public void testWhileMapping1() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input = Nd4j.create(2, 2).assign(1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_3.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例10
/**
 * Imports the frozen {@code simplewhile_0} TF graph and feeds the scalar
 * {@code Nd4j.trueScalar(4.0)} into {@code input_1}. Exports the graph as a FlatBuffers file
 * into the libnd4j test resources.
 *
 * <p>NOTE(review): the output assertions are commented out. Apart from the import null-check,
 * this test only verifies that the export does not throw.
 */
@Test
public void testWhileMapping2() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input = Nd4j.trueScalar(4.0);
        tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_0_4.fb"));

        //log.info("{}", tg.asFlatPrint());
        /*
        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(2);
        assertNotNull(array);
        assertEquals(exp, array);*/
    }
}
 
示例11
/**
 * Imports the frozen {@code simplewhile_1} TF graph and feeds a 2x2 array filled with -9.0
 * into {@code input_0} and {@code Nd4j.trueScalar(1.0)} into {@code input_1}. Checks that
 * {@code execAndEndResult()} returns a 2x2 array filled with -3.
 */
@Test
public void testWhileDualMapping2() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-9.0);
        val input1 = Nd4j.trueScalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例12
/**
 * Imports the frozen {@code simplewhile_nested} TF graph and feeds a 2x2 array of 1.0 into
 * {@code input_0} and a 3x3 array of 2.0 into {@code input_1}. Checks that
 * {@code execAndEndResult()} returns a 2x2 array filled with 15.0.
 */
@Test
public void testMixedWhileCond1() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_nested/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(1.0);
        val input1 = Nd4j.create(3, 3).assign(2.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_nested.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.execAndEndResult();
        //val array = tg.getVariable("output").getArr();
        val exp = Nd4j.create(2, 2).assign(15.0);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例13
/**
 * Imports the frozen {@code partition_stitch_misc} TF graph and binds constant-filled arrays
 * to {@code input_0} (2x5x4), {@code input_1} (2x3x5x4) and {@code input_2} (3x1x5x4). Exports
 * the graph as a FlatBuffers file into the libnd4j test resources.
 */
@Test
@Ignore
public void testCrash_119_matrix_diag() throws Exception {
    Nd4j.create(1);

    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(tg);

        val input0 = Nd4j.create(2, 5, 4).assign(1.0);
        val input1 = Nd4j.create(2, 3, 5, 4).assign(2.0);
        val input2 = Nd4j.create(3, 1, 5, 4).assign(3.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.associateArrayWithVariable(input2, tg.getVariable("input_2"));

        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb"));
    }
}
 
示例14
/**
 * End-to-end test against a graph inference server expected on 127.0.0.1:40123. Registers the
 * {@code expand_dim} graph under a random id and sends a 3x4 array as {@code input_0}. Checks
 * that the {@code output} result equals a precomputed 3x1x4 array.
 */
@Test
public void testSimpleGraph_1() throws Exception {
    val exp = Nd4j.create(new double[] {-0.95938617, -1.20301781, 1.22260064, 0.50172403, 0.59972949, 0.78568028, 0.31609724, 1.51674747, 0.68013491, -0.05227458, 0.25903158,1.13243439}, new long[]{3, 1, 4});

    // configuring client
    val client = new GraphInferenceGrpcClient("127.0.0.1", 40123);

    val graphId = RandomUtils.nextLong(0, Long.MAX_VALUE);

    // preparing and registering graph (it's optional, and graph might be embedded into Docker image)
    // The classpath stream is closed once the graph has been imported and registered.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);
        client.registerGraph(graphId, tg, ExecutorConfiguration.builder().outputMode(OutputMode.IMPLICIT).build());
    }

    //defining input
    val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4});
    val operands = new Operands().addArgument("input_0", input0);

    // sending request and getting result
    val result = client.output(graphId, operands);
    assertEquals(exp, result.getById("output"));
}
 
示例15
/**
 * Reads the optional TF attributes for this op, fixes C order and registers the op arguments.
 *
 * <p>When {@code dtype} is present it is converted into {@code outputType}. When {@code init}
 * is present its boolean value is stored in {@code initialize}.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    boolean hasDtype = attributesForNode.containsKey("dtype");
    if (hasDtype) {
        AttrValue dtypeAttr = attributesForNode.get("dtype");
        outputType = TFGraphMapper.convertType(dtypeAttr.getType());
    }

    boolean hasInit = attributesForNode.containsKey("init");
    if (hasInit) {
        AttrValue initAttr = attributesForNode.get("init");
        initialize = initAttr.getB();
    }

    // TF carries no array ordering, so plain C order is used.
    this.order = 'c';
    addArgs();
}
 
示例16
/**
 * Imports {@code tf_graphs/lenet_cnn.pb}, printing its raw node names first. Checks that the
 * {@code conv2d/kernel} variable has an array of shape {@code [5, 5, 1, 32]}.
 *
 * <p>According to the original note, the frozen graph was produced with:
 * python  ~/anaconda2/lib/python2.7/site-packages/tensorflow/python/tools/freeze_graph.py  --input_graph=graph2.pb.txt  --input_checkpoint=test3.ckpt  --output_graph=graph_frozen2.pb  --output_node_name=output/BiasAdd --input_binary=False
 */
@Test
public void testLenet() throws Exception {
    Nd4j.create(1);

    // The resource is read twice (raw protobuf, then import); each stream is closed after use.
    try (java.io.InputStream rawStream = new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()) {
        val rawGraph = GraphDef.parseFrom(rawStream);
        val nodeNames = rawGraph.getNodeList().stream().map(node -> node.getName()).collect(Collectors.toList());
        System.out.println(nodeNames);
    }

    try (java.io.InputStream graphStream = new ClassPathResource("tf_graphs/lenet_cnn.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);

        val convNode = tg.getVariable("conv2d/kernel");
        assertNotNull(convNode.getArr());
        val shape = convNode.getShape();
        System.out.println(Arrays.toString(shape));

        // this is NHWC weights. will be changed soon.
        assertArrayEquals(new long[]{5,5,1,32}, shape);
        System.out.println(convNode);
    }
}
 
示例17
/**
 * Imports the frozen {@code partition_stitch_misc} TF graph and binds constant-filled arrays
 * to {@code input_0} (2x5x4), {@code input_1} (2x3x5x4) and {@code input_2} (3x1x5x4). Exports
 * the graph as a FlatBuffers file into the libnd4j test resources.
 */
@Test
@Ignore
public void testCrash_119_matrix_diag() throws Exception {
    Nd4j.create(1);

    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/partition_stitch_misc/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);

        val input0 = Nd4j.create(2, 5, 4).assign(1.0);
        val input1 = Nd4j.create(2, 3, 5, 4).assign(2.0);
        val input2 = Nd4j.create(3, 1, 5, 4).assign(3.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));
        tg.associateArrayWithVariable(input2, tg.getVariable("input_2"));

        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/partition_stitch_misc.fb"));
    }
}
 
示例18
/**
 * Imports the frozen {@code simplewhile_1} TF graph and feeds a 2x2 array filled with -9.0
 * into {@code input_0} and the scalar 1.0 into {@code input_1}. Checks that the first graph
 * output equals a 2x2 array filled with -3.
 */
@Test
public void testWhileDualMapping2() throws Exception {
    Nd4j.create(1);
    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simplewhile_1/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);
        val input0 = Nd4j.create(2, 2).assign(-9.0);
        val input1 = Nd4j.scalar(1.0);
        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        //tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simplewhile_1.fb"));

        //log.info("{}", tg.asFlatPrint());

        val array = tg.outputAll(null).get(tg.outputs().get(0));
        val exp = Nd4j.create(2, 2).assign(-3);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
 
示例19
/**
 * Initialises this op from a TensorFlow node.
 *
 * <p>The node's last input is looked up by name in {@code graph}. If
 * {@link TFGraphMapper#getNDArrayFromTensor} returns a non-null array for that node, element 0
 * is added as an integer argument. The {@code dtype} attribute is converted and stored in
 * {@code tensorArrayDataType}.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    val idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
    NodeDef iddNode = null;
    // Node names are unique within a TF GraphDef, so stop at the first match instead of
    // scanning the remainder of the graph.
    for (NodeDef candidate : graph.getNodeList()) {
        if (candidate.getName().equals(idd)) {
            iddNode = candidate;
            break;
        }
    }

    // NOTE(review): if no node matches, iddNode is still null here and is passed on unchanged.
    val arr = TFGraphMapper.getNDArrayFromTensor(iddNode);

    if (arr != null) {
        int idx = arr.getInt(0);
        addIArgument(idx);
    }

    this.tensorArrayDataType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
}
 
示例20
/**
 * Imports the frozen {@code transpose} TF graph and binds two 3x3 arrays to {@code input} and
 * {@code input_1}. Exports the graph as a FlatBuffers file into the libnd4j test resources.
 */
@Test
@Ignore
public void testCrash_119_transpose() throws Exception {
    Nd4j.create(1);

    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/transpose/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);

        val input0 = Nd4j.create(new double[]{0.98114507, 0.96400015, 0.58669623, 0.60073098, 0.75425418, 0.44258752, 0.76373084, 0.96593234, 0.34067846}, new int[] {3, 3});
        val input1 = Nd4j.create(new double[]{0.98114507, 0.60073098, 0.76373084, 0.96400015, 0.75425418, 0.96593234, 0.58669623, 0.44258752, 0.34067846}, new int[] {3, 3});

        tg.associateArrayWithVariable(input0, tg.getVariable("input"));
        tg.associateArrayWithVariable(input1, tg.getVariable("input_1"));

        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/transpose.fb"));
    }
}
 
示例21
/**
 * Loads the model stored at {@code pathToModel}.
 *
 * <p>TensorFlow files are imported through {@link TFGraphMapper#importGraph}. SameDiff zip
 * files go through {@link SameDiff#load}. Anything else is read with
 * {@link SameDiff#fromFlatFile}.
 *
 * @return the loaded graph
 * @throws Exception propagated from the underlying loaders
 */
@Override
public SameDiff loadModel() throws Exception {
    // Parameterized log messages skip string concatenation when debug logging is disabled.
    if (ModelGuesser.isTensorflowFile(pathToModel)) {
        log.debug("Loading tensorflow model from {}", pathToModel.getAbsolutePath());
        return TFGraphMapper.importGraph(pathToModel);
    }
    if (ModelGuesser.isSameDiffZipFile(pathToModel)) {
        return SameDiff.load(pathToModel, true);
    }

    // Fallback: treat the file as a SameDiff FlatBuffers file.
    log.debug("Loading samediff model from {}", pathToModel.getAbsolutePath());
    return SameDiff.fromFlatFile(pathToModel);
}
 
示例22
/**
 * Generates a UI log of the imported BERT MRPC graph. The log is used to visualise the graph
 * and work out which subgraph structure to replace.
 *
 * <p>The training iterator ({@code IteratorGetNext}) is replaced by INT placeholders of shape
 * [minibatchSize, 128], and the {@code IteratorV2} node is skipped during import. Reads from
 * and writes to hard-coded local Windows paths.
 */
@Test @Ignore
public void writeBertUI() throws Exception {
    File f = new File("C:/Temp/TF_Graphs/mrpc_output/frozen/bert_mrpc_frozen.pb");
    int minibatchSize = 4;

    Map<String, TFImportOverride> m = new HashMap<>();
    // Replace the training iterator with three placeholders named "IteratorGetNext",
    // "IteratorGetNext:1" and "IteratorGetNext:4".
    // NOTE(review): an earlier comment listed ":3" for the third placeholder while the code
    // registers ":4". Confirm which IteratorGetNext output index is intended.
    m.put("IteratorGetNext", (inputs, controlDepInputs, nodeDef, initWith, attributesForNode, graph) ->
            Arrays.asList(
                    initWith.placeHolder("IteratorGetNext", DataType.INT, minibatchSize, 128),
                    initWith.placeHolder("IteratorGetNext:1", DataType.INT, minibatchSize, 128),
                    initWith.placeHolder("IteratorGetNext:4", DataType.INT, minibatchSize, 128)
            ));

    //Skip the "IteratorV2" op - we don't want or need this
    TFOpImportFilter filter = (nodeDef, initWith, attributesForNode, graph) ->
            "IteratorV2".equals(nodeDef.getName());

    SameDiff sd = TFGraphMapper.importGraph(f, m, filter);

    LogFileWriter w = new LogFileWriter(new File("C:/Temp/BERT_UI.bin"));
    // Byte counts returned by the writer are not needed here.
    w.writeGraphStructure(sd);
    w.writeFinishStaticMarker();
}
 
示例23
/**
 * Imports {@code tf_graphs/tensorflow_inception_graph.pb} and checks that a graph is returned.
 */
@Test
@Ignore
public void importGraph2() throws Exception {
    SameDiff graph;
    // Close the classpath stream as soon as the graph has been imported.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()) {
        graph = TFGraphMapper.getInstance().importGraph(graphStream);
    }

    assertNotNull(graph);
}
 
示例24
/**
 * Imports {@code tf_graphs/max_log_reg.pb.txt} and checks that a graph is returned.
 */
@Test
@Ignore
public void importGraph3() throws Exception {
    SameDiff graph;
    // Close the classpath stream as soon as the graph has been imported.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()) {
        graph = TFGraphMapper.importGraph(graphStream);
    }

    assertNotNull(graph);
}
 
示例25
/**
 * Imports the frozen {@code expand_dim} TF graph and binds a 3x4 array to {@code input_0}.
 * Exports the graph as a FlatBuffers file into the libnd4j test resources.
 */
@Test
@Ignore
public void testCrash_119_expand_dim() throws Exception {
    Nd4j.create(1);

    // try-with-resources: the classpath stream is closed when the test body finishes instead of leaking.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/expand_dim/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.importGraph(graphStream);
        assertNotNull(tg);

        val input0 = Nd4j.create(new double[] {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743}, new int[] {3, 4});

        tg.associateArrayWithVariable(input0, tg.getVariable("input_0"));

        // Side effect: writes the graph to a path relative to the working directory.
        tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/expand_dim.fb"));
    }
}
 
示例26
/**
 * Sets {@code outputType} from the TF {@code output_type} attribute. Falls back to
 * {@code DataType.LONG} when the attribute is absent.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    boolean explicitType = attributesForNode.containsKey("output_type");
    outputType = explicitType
            ? TFGraphMapper.convertType(attributesForNode.get("output_type").getType())
            : DataType.LONG;
}
 
示例27
/**
 * Initialises this op's properties from the TF node.
 *
 * <p>Appends one boolean argument: the node's {@code use_locking} attribute, or {@code false}
 * when the attribute is absent.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    // Short-circuit: getAttrOrThrow is only reached when the attribute exists.
    boolean useLocking = nodeDef.containsAttr("use_locking")
            && nodeDef.getAttrOrThrow("use_locking").getB();
    bArguments.add(useLocking);
}
 
示例28
/**
 * Imports a profiling convolution graph from a hard-coded developer path and exports it as a
 * FlatBuffers file into the libnd4j test resources.
 */
@Test
@Ignore
public void testProfConv() throws Exception {
    Nd4j.create(1);
    val modelFile = new File("/home/raver119/develop/workspace/models/profiling_conv.pb.txt");
    val tg = TFGraphMapper.getInstance().importGraph(modelFile);
    assertNotNull(tg);

    val flatTarget = new File("../../../libnd4j/tests_cpu/resources/profiling_conv.fb");
    tg.asFlatFile(flatTarget);
}
 
示例29
/**
 * Initialises this op from the TF node, then reorders its first two inputs.
 *
 * <p>TF supplies {@code [logits, labels]} while libnd4j expects {@code [labels, logits]}, so
 * the op's input list is replaced by the swapped pair.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);

    SameDiffOp op = initWith.getOps().get(getOwnName());
    List<String> tfOrder = op.getInputsToOp();
    // Read index 1 before index 0, matching the original evaluation order.
    String labels = tfOrder.get(1);
    String logits = tfOrder.get(0);
    op.setInputsToOp(Arrays.asList(labels, logits));
}
 
示例30
/**
 * Delegates to the superclass initialisation first. If the TF node has a {@code Tidx}
 * attribute, it overrides {@code dataType}. The resulting {@code dataType} is always registered
 * as a data-type argument.
 */
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);

    final String tidxKey = "Tidx";
    if (attributesForNode.containsKey(tidxKey)) {
        AttrValue tidxAttr = attributesForNode.get(tidxKey);
        dataType = TFGraphMapper.convertType(tidxAttr.getType());
    }
    addDArgument(dataType);
}