Python源码示例:ignite.contrib.handlers.ProgressBar()
示例1
def test_pbar(capsys):
    """The last non-blank progress frame shows the attached metric "a"."""
    engine = Engine(update_fn)
    ProgressBar().attach(engine, ["a"])
    engine.run([1, 2], max_epochs=2)

    # stderr is split on carriage returns; blank fragments are discarded.
    frames = [frame.strip() for frame in capsys.readouterr().err.split("\r")]
    frames = [frame for frame in frames if frame]

    assert frames[-1] == "Epoch [2/2]: [1/2] 50%|█████ , a=1 [00:00<00:00]"
示例2
def test_pbar_file(tmp_path):
    """ProgressBar writes its frames to a user-supplied file object.

    Both file handles are managed with ``with`` so they are closed even if
    the run or the assertion fails.
    """
    n_epochs = 2
    loader = [1, 2]
    engine = Engine(update_fn)
    file_path = tmp_path / "temp.txt"

    # Closing the writer (at the end of the ``with``) forces its buffer to
    # disk; file.flush() alone was not sufficient here. An explicit UTF-8
    # encoding keeps the Unicode bar glyphs in the expected string
    # independent of the platform's locale encoding.
    with open(str(file_path), "w+", encoding="utf-8") as file:
        pbar = ProgressBar(file=file)
        pbar.attach(engine, ["a"])
        engine.run(loader, max_epochs=n_epochs)

    with open(str(file_path), "r", encoding="utf-8") as file:
        lines = file.readlines()

    expected = "Epoch [2/2]: [1/2] 50%|█████ , a=1 [00:00<00:00]\n"
    assert lines[-2] == expected
示例3
def test_pbar_batch_indeces(capsys):
    """With persist=True every batch index 1..4 shows up in some progress frame."""
    engine = Engine(lambda e, b: time.sleep(0.1))

    @engine.on(Events.ITERATION_STARTED)
    def print_iter(_):
        print("iteration: ", engine.state.iteration)

    ProgressBar(persist=True).attach(engine)
    engine.run(list(range(4)), max_epochs=1)

    stderr_text = capsys.readouterr().err
    stripped = [chunk.strip() for chunk in stderr_text.split("\r")]
    # The character just before the first "/" of each non-blank frame is read
    # as the batch index (presumably frames look like "[k/4] ..."; only a
    # single digit is parsed, which suffices for 4 batches).
    seen = {int(frame.split("/")[0][-1]) for frame in stripped if frame}

    assert sorted(seen) == [1, 2, 3, 4]
示例4
def test_pbar_no_metric_names(capsys):
    """Without metric names the bar shows no ", name=value" suffix."""
    engine = Engine(update_fn)
    ProgressBar().attach(engine)
    engine.run([1, 2], max_epochs=2)

    pieces = capsys.readouterr().err.split("\r")
    frames = [text for text in (piece.strip() for piece in pieces) if text]

    assert frames[-1] == "Epoch [2/2]: [1/2] 50%|█████ [00:00<00:00]"
示例5
def test_pbar_with_output(capsys):
    """A dict returned by output_transform is shown as key=value pairs."""
    engine = Engine(update_fn)
    ProgressBar().attach(engine, output_transform=lambda x: {"a": x})
    engine.run([1, 2], max_epochs=2)

    frames = [s.strip() for s in capsys.readouterr().err.split("\r")]
    last_frame = [s for s in frames if s][-1]

    assert last_frame == "Epoch [2/2]: [1/2] 50%|█████ , a=1 [00:00<00:00]"
示例6
def test_pbar_with_scalar_output(capsys):
    """A scalar returned by output_transform is shown under the name "output"."""
    engine = Engine(update_fn)
    ProgressBar().attach(engine, output_transform=lambda x: x)
    engine.run([1, 2], max_epochs=2)

    raw = capsys.readouterr().err
    frames = [line for line in (part.strip() for part in raw.split("\r")) if line]

    assert frames[-1] == "Epoch [2/2]: [1/2] 50%|█████ , output=1 [00:00<00:00]"
示例7
def test_pbar_with_str_output(capsys):
    """A string returned by output_transform is shown verbatim as output=<str>."""
    engine = Engine(update_fn)
    ProgressBar().attach(engine, output_transform=lambda x: "red")
    engine.run([1, 2], max_epochs=2)

    frames = []
    for part in capsys.readouterr().err.split("\r"):
        text = part.strip()
        if text:
            frames.append(text)

    assert frames[-1] == "Epoch [2/2]: [1/2] 50%|█████ , output=red [00:00<00:00]"
示例8
def test_pbar_output_tensor(capsys):
    """Tensor outputs are rendered as output=... or output_<i>=... entries."""

    def _check(out_tensor, out_msg):
        def _constant_output(engine, batch):
            return out_tensor

        engine = Engine(_constant_output)
        ProgressBar(desc="Output tensor").attach(engine, output_transform=lambda x: x)
        engine.run([1, 2, 3, 4, 5], max_epochs=1)

        frames = [s.strip() for s in capsys.readouterr().err.split("\r")]
        frames = [s for s in frames if s]
        assert frames[-1] == "Output tensor: [4/5] 80%|████████ , {} [00:00<00:00]".format(out_msg)

    cases = [
        (torch.tensor([5, 0]), "output_0=5, output_1=0"),
        (torch.tensor(123), "output=123"),
        (torch.tensor(1.234), "output=1.23"),
    ]
    for out_tensor, out_msg in cases:
        _check(out_tensor, out_msg)
示例9
def test_pbar_on_epochs(capsys):
    """A bar driven by EPOCH_STARTED counts epochs rather than iterations."""
    engine = Engine(update_fn)
    ProgressBar().attach(engine, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED)
    engine.run([1, 2, 3, 4, 5], max_epochs=10)

    stripped = (piece.strip() for piece in capsys.readouterr().err.split("\r"))
    frames = [piece for piece in stripped if piece]

    assert frames[-1] == "Epoch: [9/10] 90%|█████████ [00:00<00:00]"
示例10
def test_pbar_wrong_events_order():
    """attach() rejects misordered event pairs and filtered closing events."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    misordered_pairs = [
        (Events.COMPLETED, Events.COMPLETED),
        (Events.COMPLETED, Events.EPOCH_COMPLETED),
        (Events.COMPLETED, Events.ITERATION_COMPLETED),
        (Events.EPOCH_COMPLETED, Events.EPOCH_COMPLETED),
        (Events.ITERATION_COMPLETED, Events.ITERATION_STARTED),
    ]
    for event_name, closing_event_name in misordered_pairs:
        with pytest.raises(ValueError, match="should be called before closing event"):
            pbar.attach(engine, event_name=event_name, closing_event_name=closing_event_name)

    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
示例11
def test_pbar_on_callable_events(capsys):
    """A filtered event (every=10) updates the bar only every 10 iterations."""
    engine = Engine(update_fn)
    ProgressBar().attach(
        engine,
        event_name=Events.ITERATION_STARTED(every=10),
        closing_event_name=Events.EPOCH_COMPLETED,
    )
    engine.run(list(range(100)), max_epochs=1)

    frames = [text.strip() for text in capsys.readouterr().err.split("\r")]
    frames = [text for text in frames if text]

    assert frames[-1] == "Epoch: [90/100] 90%|█████████ [00:00<00:00]"
示例12
def test_pbar_log_message(capsys):
    """log_message() emits the given text on stdout."""
    ProgressBar().log_message("test")

    stripped = [piece.strip() for piece in capsys.readouterr().out.split("\r")]
    lines = [line for line in stripped if line]

    assert lines[-1] == "test"
示例13
def test_pbar_log_message_file(tmp_path):
    """log_message() writes to the file object the ProgressBar was given.

    Both handles are managed with ``with`` so neither leaks, even on failure.
    """
    file_path = tmp_path / "temp.txt"

    # Closing the writer (at the end of the ``with``) forces its buffer to
    # disk; file.flush() alone was not sufficient here.
    with open(str(file_path), "w+", encoding="utf-8") as file:
        pbar = ProgressBar(file=file)
        pbar.log_message("test")

    with open(str(file_path), "r", encoding="utf-8") as file:
        lines = file.readlines()

    expected = "test\n"
    assert lines[0] == expected
示例14
def test_pbar_with_metric(capsys):
    """A RunningAverage metric named in metric_names is shown on the bar."""
    n_iters = 2
    loss_values = iter(range(n_iters))

    def step(engine, batch):
        return next(loss_values)

    trainer = Engine(step)
    RunningAverage(alpha=0.5, output_transform=lambda x: x).attach(trainer, "batchloss")
    ProgressBar().attach(trainer, metric_names=["batchloss"])
    trainer.run(data=list(range(n_iters)), max_epochs=1)

    frames = [s.strip() for s in capsys.readouterr().err.split("\r")]
    frames = [s for s in frames if s]

    assert frames[-1] == "Epoch: [1/2] 50%|█████ , batchloss=0.5 [00:00<00:00]"
示例15
def test_pbar_with_all_metric(capsys):
    """metric_names="all" shows every attached metric, sorted by name."""
    n_iters = 2
    loss_values = iter(range(n_iters))
    another_loss_values = iter(range(1, n_iters + 1))

    def step(engine, batch):
        return next(loss_values), next(another_loss_values)

    trainer = Engine(step)
    RunningAverage(alpha=0.5, output_transform=lambda x: x[0]).attach(trainer, "batchloss")
    RunningAverage(alpha=0.5, output_transform=lambda x: x[1]).attach(trainer, "another batchloss")
    ProgressBar().attach(trainer, metric_names="all")
    trainer.run(data=list(range(n_iters)), max_epochs=1)

    pieces = (p.strip() for p in capsys.readouterr().err.split("\r"))
    frames = [p for p in pieces if p]

    assert frames[-1] == "Epoch: [1/2] 50%|█████ , another batchloss=1.5, batchloss=0.5 [00:00<00:00]"
示例16
def test_pbar_with_tqdm_kwargs(capsys):
    """A custom desc replaces the default "Epoch" prefix of the bar."""
    engine = Engine(update_fn)
    ProgressBar(desc="My description: ").attach(engine, output_transform=lambda x: x)
    engine.run([1, 2, 3, 4, 5], max_epochs=10)

    frames = []
    for fragment in capsys.readouterr().err.split("\r"):
        fragment = fragment.strip()
        if fragment:
            frames.append(fragment)

    assert frames[-1] == "My description: [10/10]: [4/5] 80%|████████ , output=1 [00:00<00:00]"
示例17
def test_pbar_for_validation(capsys):
    """A single-epoch run with desc shows no epoch counter in the prefix."""
    engine = Engine(update_fn)
    ProgressBar(desc="Validation").attach(engine)
    engine.run([1, 2, 3, 4, 5], max_epochs=1)

    frames = [part.strip() for part in capsys.readouterr().err.split("\r")]
    last_frame = [part for part in frames if part][-1]

    assert last_frame == "Validation: [4/5] 80%|████████ [00:00<00:00]"
示例18
def test_pbar_on_custom_events(capsys):
    """Events from a CustomPeriodicEvent not attached to the engine are rejected."""
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"):
        periodic = CustomPeriodicEvent(n_iterations=15)

    with pytest.raises(ValueError, match=r"not in allowed events for this engine"):
        pbar.attach(
            engine,
            event_name=periodic.Events.ITERATIONS_15_COMPLETED,
            closing_event_name=Events.EPOCH_COMPLETED,
        )
示例19
def test_pbar_with_nan_input():
    """TerminateOnNan stops the run at the first NaN output with a ProgressBar attached."""

    def update(engine, batch):
        return batch.item()

    def create_engine():
        engine = Engine(update)
        engine.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        ProgressBar().attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
        return engine

    # (input values, iteration at which the run is expected to stop)
    cases = (
        ([np.nan] * 25, 1),
        ([1] * 1000 + [np.nan] * 25, 1001),
    )
    for values, expected_iteration in cases:
        engine = create_engine()
        engine.run(torch.from_numpy(np.array(values)))
        assert engine.should_terminate
        assert engine.state.iteration == expected_iteration
        assert engine.state.epoch == 1
示例20
def test_tqdm_logger_iter_without_epoch_length(capsys):
    """The bar works on a generator that is re-armed every `size` iterations."""
    size = 11

    def finite_size_data_iter(size):
        yield from range(size)

    def train_step(trainer, batch):
        return None

    trainer = Engine(train_step)

    @trainer.on(Events.ITERATION_COMPLETED(every=size))
    def restart_iter():
        # Swap in a fresh generator once the current one is exhausted.
        trainer.state.dataloader = finite_size_data_iter(size)

    ProgressBar(persist=True).attach(trainer)
    trainer.run(finite_size_data_iter(size), max_epochs=5)

    frames = [chunk.strip() for chunk in capsys.readouterr().err.split("\r")]
    frames = [chunk for chunk in frames if chunk]

    assert frames[-1] == "Epoch [5/5]: [11/11] 100%|██████████ [00:00<00:00]"
示例21
def train():
    """Prune a CIFAR-10 network with SNIP, train it, and log metrics to TensorBoard.

    Relies on names defined elsewhere in the file (not visible here):
    ``cifar10_experiment``, ``SNIP``, ``apply_prune_mask``, ``device``,
    ``LOG_INTERVAL`` and ``EPOCHS``. The TensorBoard writer is closed even
    if setup or training raises.
    """
    writer = SummaryWriter()
    try:
        net, optimiser, lr_scheduler, train_loader, val_loader = cifar10_experiment()

        # Pre-training pruning using SNIP. 0.05 is presumably the fraction of
        # weights kept — confirm against SNIP's signature.
        keep_masks = SNIP(net, 0.05, train_loader, device)  # TODO: shuffle?
        apply_prune_mask(net, keep_masks)

        trainer = create_supervised_trainer(net, optimiser, F.nll_loss, device)
        evaluator = create_supervised_evaluator(net, {
            'accuracy': Accuracy(),
            'nll': Loss(F.nll_loss)
        }, device)

        pbar = ProgressBar()
        pbar.attach(trainer)

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_training_loss(engine):
            # The LR scheduler is stepped once per iteration, not per epoch.
            lr_scheduler.step()
            if engine.state.iteration % LOG_INTERVAL == 0:
                writer.add_scalar("training/loss", engine.state.output,
                                  engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_epoch(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_nll = metrics['nll']
            writer.add_scalar("validation/loss", avg_nll, engine.state.iteration)
            writer.add_scalar("validation/accuracy", avg_accuracy,
                              engine.state.iteration)

        trainer.run(train_loader, EPOCHS)
    finally:
        writer.close()
示例22
def run(train_batch_size, val_batch_size, epochs, lr, momentum, display_gpu_info):
    """Train ``Net`` with SGD and report per-epoch train/validation metrics.

    A persistent ProgressBar shows all trainer metrics (a running-average
    "loss" and, optionally, GPU info). After every epoch the model is
    evaluated on both loaders, and the results are printed via
    ``pbar.log_message`` so they do not garble the bar.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy(), "nll": Loss(F.nll_loss)}, device=device
    )

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    if display_gpu_info:
        from ignite.contrib.metrics import GpuInfo

        GpuInfo().attach(trainer, name="gpu")

    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    def _log_results(loader, label, epoch):
        """Evaluate on ``loader`` and print a one-line "<label> Results" summary."""
        evaluator.run(loader)
        metrics = evaluator.state.metrics
        pbar.log_message(
            "{} Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                label, epoch, metrics["accuracy"], metrics["nll"]
            )
        )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        _log_results(train_loader, "Training", engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        _log_results(val_loader, "Validation", engine.state.epoch)

    trainer.run(train_loader, max_epochs=epochs)