Python源码示例:reid.utils.serialization.load_checkpoint()

示例1
def main(args):
    """Train (or only evaluate) a DenseNet model, then run the test pass.

    Builds the train / gallery / query loaders and constructs the model. If
    ``args.resume`` is set, it restores the model weights from that
    checkpoint and then wraps the model in ``nn.DataParallel`` on the GPU.
    In evaluate-only mode it runs a single evaluation and returns.
    Otherwise it trains, saves the trained weights to
    ``<logs_dir>/model.pth.tar`` and evaluates.

    Args:
        args: parsed command-line namespace. Attributes read here:
            evaluate, logs_dir, train_path, gallery_path, query_path,
            height, width, batch_size, workers, name_pattern, num_feature,
            true_class, num_iteration, start_epoch, resume, output_feature.
    """
    cudnn.benchmark = True

    # Outside evaluate-only mode, send printed output to both the console
    # and <logs_dir>/log.txt.
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))

    def build_loader(data_path, relabel, mode):
        # All loaders share image size, batch size, worker count and file-name
        # pattern; only the data path, relabeling flag and mode differ.
        return get_loader(data_path, args.height, args.width, relabel=relabel,
                          batch_size=args.batch_size, mode=mode,
                          num_workers=args.workers,
                          name_pattern=args.name_pattern)

    train_loader = build_loader(args.train_path, True, 'train')
    gallery_loader = build_loader(args.gallery_path, False, 'test')
    query_loader = build_loader(args.query_path, False, 'test')

    def run_test(net):
        # Shared by the evaluate-only path and the post-training path.
        tester = Evaluator(net)
        print("Test:")
        tester.evaluate(query_loader, gallery_loader,
                        query_loader.dataset.ret, gallery_loader.dataset.ret,
                        args.output_feature)

    net = DenseNet(num_feature=args.num_feature,
                   num_classes=args.true_class,
                   num_iteration=args.num_iteration)

    first_epoch = args.start_epoch
    if args.resume:
        # Weights go into the bare model, before the DataParallel wrap below.
        restored = load_checkpoint(args.resume)
        net.load_state_dict(restored['state_dict'])

    net = nn.DataParallel(net).cuda()

    if args.evaluate:
        run_test(net)
        return

    net = train(args, net, train_loader, first_epoch)
    # Save the unwrapped module's weights (net.module), not the DataParallel wrapper's.
    save_checkpoint({'state_dict': net.module.state_dict()},
                    fpath=osp.join(args.logs_dir, 'model.pth.tar'))
    run_test(net)
示例2
def resume(self, ckpt_file, step):
    """Rebuild ``self.model`` and load its weights from ``ckpt_file``.

    A fresh network is created from ``self.model_name`` with the stored
    dropout, class count and mode. It is wrapped in ``nn.DataParallel``,
    moved to the GPU, and then the loaded weights are applied to it.

    Args:
        ckpt_file: checkpoint path handed to ``load_checkpoint``.
        step: step number; only used in the progress message.
    """
    print("continued from step", step)
    net = models.create(self.model_name,
                        dropout=self.dropout,
                        num_classes=self.num_classes,
                        mode=self.mode)
    self.model = nn.DataParallel(net).cuda()
    # The loaded object is applied directly to the DataParallel wrapper, so the
    # checkpoint is presumably a raw state dict saved from a wrapped model
    # (keys prefixed with "module.") rather than a {'state_dict': ...} dict.
    # NOTE(review): confirm against the code that writes these checkpoints.
    weights = load_checkpoint(ckpt_file)
    self.model.load_state_dict(weights)