Python源码示例:tensorflow.keras.datasets.mnist.load_data()
示例1
def fit_VAE():
    """Smoke-test ``BernoulliVAE.fit`` on a single 128-image batch of MNIST.

    Only the first batch of the training split is needed, so the test split and
    all labels returned by ``mnist.load_data()`` are discarded instead of being
    rescaled and then thrown away.
    """
    # for testing
    from numpy_ml.neural_nets.models.vae import BernoulliVAE

    np.random.seed(12345)
    (X_train, _), (_, _) = mnist.load_data()

    batch_size = 128
    n_batches = 1
    # Slice first so that only the images actually used are cast and scaled.
    X_train = X_train[: batch_size * n_batches]
    # scale pixel intensities to [0, 1] and append an axis at position 3
    X_train = np.expand_dims(X_train.astype("float32") / 255.0, 3)

    BV = BernoulliVAE()
    BV.fit(X_train, n_epochs=1, verbose=False)
示例2
def get_mnist_data():
    """Load MNIST prepared for a Keras CNN.

    Images are reshaped to ``(N, MNIST.img_rows, MNIST.img_cols, 1)``, cast to
    float32 and scaled to [0, 1]. Labels are one-hot encoded with
    ``MNIST.num_classes`` columns. The training shape and sample counts are printed.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``.
    """
    def to_model_input(images):
        # add a single channel axis, then cast and rescale pixel values
        shaped = images.reshape(images.shape[0], MNIST.img_rows, MNIST.img_cols, 1)
        scaled = shaped.astype('float32')
        scaled /= 255
        return scaled

    def to_one_hot(labels):
        # convert class vectors to binary class matrices
        return keras.utils.to_categorical(labels, MNIST.num_classes)

    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = to_model_input(x_train), to_model_input(x_test)

    print('x_train shape:', x_train.shape)
    print(x_train.shape[0], 'train samples')
    print(x_test.shape[0], 'test samples')

    return x_train, to_one_hot(y_train), x_test, to_one_hot(y_test)
示例3
def __init__(self, **kwargs):
    """Build a retrieval split: labels < 5 form the train set, labels >= 5 the test set.

    The official train/test split from ``load_data()`` is ignored. Samples from
    both halves are pooled and re-partitioned by label, and a trailing axis is
    appended to the images. ``kwargs`` is accepted but not used here.
    """
    super(MnistRet, self).__init__(name=None, queries_in_collection=True)
    self.name = "RET_MNIST"
    (x_train, y_train), (x_test, y_test) = load_data()

    def pool(train_idx, test_idx):
        # train-split samples first, then test-split samples
        images = np.concatenate([x_train[train_idx], x_test[test_idx]], axis=0)
        labels = np.concatenate([y_train[train_idx], y_test[test_idx]], axis=0)
        return images[..., None], labels

    self.train_images, self.train_labels = pool(
        np.flatnonzero(y_train < 5), np.flatnonzero(y_test < 5))
    self.test_images, self.test_labels = pool(
        np.flatnonzero(y_train >= 5), np.flatnonzero(y_test >= 5))
示例4
def get_mnist():
    """Build shuffled, batched ``tf.data`` pipelines for MNIST train/validation.

    Images are scaled to float32 in [0, 1] and labels are cast to int32. At most
    ``N_TRAIN_EXAMPLES`` / ``N_VALID_EXAMPLES`` individual examples are kept,
    then grouped into batches of ``BATCHSIZE``.

    Returns:
        Tuple ``(train_ds, valid_ds)`` of ``tf.data.Dataset`` objects.
    """
    (x_train, y_train), (x_valid, y_valid) = mnist.load_data()
    x_train = x_train.astype("float32") / 255
    x_valid = x_valid.astype("float32") / 255
    y_train = y_train.astype("int32")
    y_valid = y_valid.astype("int32")

    # ``take`` must come before ``batch``: after batching it counts *batches*,
    # which would keep up to N_*_EXAMPLES * BATCHSIZE examples and make the
    # example caps ineffective.
    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_ds = train_ds.shuffle(len(x_train)).take(N_TRAIN_EXAMPLES).batch(BATCHSIZE)

    valid_ds = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
    valid_ds = valid_ds.shuffle(len(x_valid)).take(N_VALID_EXAMPLES).batch(BATCHSIZE)
    return train_ds, valid_ds
# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args).
示例5
def main(unused_argv):
    """Write the MNIST train and test splits to TFRecord files via ``convert_to``.

    Each split's images are reshaped to ``(N, 28, 28, 1)`` and passed to
    ``convert_to`` together with their labels and the split name.
    """
    # Get the data.
    splits = mnist.load_data()
    for (images, labels), split_name in zip(splits, ('train', 'test')):
        # Convert to Examples and write the result to TFRecords.
        convert_to(images.reshape(len(images), 28, 28, 1), labels, split_name)
示例6
def load_data(k=8, noise_level=0.0):
    """
    Load MNIST together with a K-NN pixel-grid graph for graph signal
    classification, as described by [Defferrard et al. (2016)](https://arxiv.org/abs/1606.09375).

    The K-NN graph is fixed: it is built from a regular grid of pixels using
    their 2d coordinates, then a fraction of its edges may be randomly flipped.
    Node features of each graph are the MNIST digit flattened to
    ``MNIST_SIZE ** 2`` values and rescaled to [0, 1]. Labels are the MNIST
    classes. 10000 samples of the training split are held out for validation.

    :param k: int, number of neighbours for each node;
    :param noise_level: fraction of edges to flip (from 0 to 1 and vice versa);
    :return:
        - X_train, y_train: training node features and labels;
        - X_val, y_val: validation node features and labels;
        - X_test, y_test: test node features and labels;
        - A: adjacency matrix of the grid (float32);
    """
    grid = _mnist_grid_graph(k)
    adjacency = _flip_random_edges(grid, noise_level).astype(np.float32)

    (X_train, y_train), (X_test, y_test) = m.load_data()
    n_nodes = MNIST_SIZE ** 2
    # one feature vector per digit, pixel values in [0, 1]
    X_train = (X_train / 255.0).reshape(-1, n_nodes)
    X_test = (X_test / 255.0).reshape(-1, n_nodes)

    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=10000)
    return X_train, y_train, X_val, y_val, X_test, y_test, adjacency
示例7
def main():
    """Train and save a magnitude-pruned CNN on MNIST.

    Images are reshaped for the backend's image data format, cast to float32
    and scaled to [0, 1]. Labels are one-hot encoded with ``num_classes``.
    Pruning uses ``ConstantSparsity(0.75, begin_step=2000, frequency=100)``,
    applied either to the whole model (when ``prune_whole_model`` is set) or
    per layer through ``build_layerwise_model``.
    """
    # input image dimensions
    img_rows, img_cols = 28, 28
    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    if K.image_data_format() == "channels_first":
        input_shape = (1, img_rows, img_cols)
    else:
        input_shape = (img_rows, img_cols, 1)
    x_train = x_train.reshape((x_train.shape[0],) + input_shape)
    x_test = x_test.reshape((x_test.shape[0],) + input_shape)

    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    x_train /= 255
    x_test /= 255
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    pruning_params = {
        "pruning_schedule":
            pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)
    }
    if prune_whole_model:
        model = prune.prune_low_magnitude(build_model(input_shape), **pruning_params)
    else:
        model = build_layerwise_model(input_shape, **pruning_params)

    train_and_save(model, x_train, y_train, x_test, y_test)
示例8
def load_mnist():
    """Return every MNIST sample (train and test splits pooled) as ``(x, y)``.

    ``x`` is reshaped to ``(N, 28, 28, 1)`` and scaled to [0, 1]. ``y`` holds the
    raw integer labels in the same order.
    """
    # the data, shuffled and split between train and test sets
    from tensorflow.keras.datasets import mnist
    splits = mnist.load_data()
    x = np.concatenate([images for images, _ in splits])
    y = np.concatenate([labels for _, labels in splits])
    x = x.reshape([-1, 28, 28, 1]) / 255.0
    print('MNIST samples', x.shape)
    return x, y
示例9
def load_mnist_test():
    """Return only the MNIST test split as ``(x, y)``.

    ``x`` is reshaped to ``(N, 28, 28, 1)`` and scaled to [0, 1]; ``y`` holds the
    raw integer labels.
    """
    # the data, shuffled and split between train and test sets
    from tensorflow.keras.datasets import mnist
    test_split = mnist.load_data()[1]
    x, y = test_split
    x = x.reshape([-1, 28, 28, 1]) / 255.0
    print('MNIST samples', x.shape)
    return x, y
示例10
def load_fashion_mnist():
    """Return every Fashion-MNIST sample (train and test pooled) as ``(x, y)``.

    ``x`` is reshaped to ``(N, 28, 28, 1)`` and scaled to [0, 1]; ``y`` holds the
    raw integer labels in the same order.
    """
    from tensorflow.keras.datasets import fashion_mnist  # this requires keras>=2.0.9
    splits = fashion_mnist.load_data()
    x = np.concatenate([images for images, _ in splits])
    y = np.concatenate([labels for _, labels in splits])
    x = x.reshape([-1, 28, 28, 1]) / 255.0
    print('Fashion MNIST samples', x.shape)
    return x, y
示例11
def load_data(dataset):
    """Return ``load_data_conv(dataset)`` with every sample flattened to one vector.

    The first axis of the images (sample count) is kept and all remaining axes
    are collapsed. Labels are passed through unchanged.
    """
    images, labels = load_data_conv(dataset)
    n_samples = images.shape[0]
    flat_images = images.reshape([n_samples, -1])
    return flat_images, labels
示例12
def _dataset():
    """Load MNIST scaled to [0, 1] with a channel axis and 10-class one-hot labels.

    The channel axis goes at position 1 for a ``channels_first`` backend and
    last otherwise.

    Returns:
        Tuple ``(x_train, y_train, x_test, y_test)``.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    if keras.backend.image_data_format() == 'channels_first':
        channel_axis = 1
    else:
        channel_axis = -1

    def as_images(raw):
        return np.expand_dims(raw / 255, channel_axis)

    def as_one_hot(labels):
        return to_categorical(labels, 10)

    return as_images(x_train), as_one_hot(y_train), as_images(x_test), as_one_hot(y_test)
示例13
def __init__(self, seq_len=5, square_count=3, square_size=5, noise_ratio=0.15, digits=range(10), max_angle=180):
    """Build transformed MNIST samples with ``heal_image`` for the selected digits.

    Only samples whose label is in ``digits`` are kept. Each kept image is passed
    to ``heal_image(img, seq_len, square_count, square_size, noise_ratio,
    max_angle)``, which returns a (transformed image, rotation) pair. Images,
    rotations and labels are stored as numpy arrays, separately for the train
    and test splits. All train samples are processed before any test sample.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    def build(images, labels):
        # returns (images, rotations, labels) arrays for the kept samples
        kept_images, kept_rotations, kept_labels = [], [], []
        for img, label in zip(images, labels):
            if label not in digits:
                continue
            healed, rotation = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle)
            kept_images.append(healed)
            kept_rotations.append(rotation)
            kept_labels.append(label)
        return np.array(kept_images), np.array(kept_rotations), np.array(kept_labels)

    self.train_images, self.train_rotations, self.train_labels = build(x_train, y_train)
    self.test_images, self.test_rotations, self.test_labels = build(x_test, y_test)
示例14
def train_mnist(config):
    """Train a one-hidden-layer MLP on MNIST, reporting through ``TuneReporterCallback``.

    Args:
        config: dict with ``"hidden"`` (units of the hidden Dense layer),
            ``"lr"`` (SGD learning rate) and ``"momentum"`` (SGD momentum).
    """
    # https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf
    batch_size = 128
    num_classes = 10
    epochs = 12
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(config["hidden"], activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(num_classes, activation="softmax")
    ])
    model.compile(
        loss="sparse_categorical_crossentropy",
        # ``learning_rate`` rather than the deprecated ``lr`` alias: newer Keras
        # optimizers drop ``lr`` with only a warning (or reject it), which would
        # silently ignore the tuned learning rate.
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=config["lr"], momentum=config["momentum"]),
        metrics=["accuracy"])
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        callbacks=[TuneReporterCallback()])
示例15
def UseNetwork(weights_f, load_weights=False):
    """Train (unless loading) a ``QDenseModel`` on flattened MNIST and evaluate it.

    Images are reshaped to ``(N, RESHAPED)``, cast to float32 and divided by 255.
    Labels are one-hot encoded with ``NB_CLASSES``. After evaluation,
    ``print_qstats`` is called and the test score/accuracy are printed.

    Args:
        weights_f: weight file location; when training, weights are saved here
            if it is truthy.
        load_weights: load weights when it is True (training is then skipped).
    """
    model = QDenseModel(weights_f, load_weights)
    batch_size = BATCH_SIZE
    (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()

    # flatten, cast and rescale pixel values
    x_train_ = x_train_.reshape(60000, RESHAPED).astype("float32")
    x_test_ = x_test_.reshape(10000, RESHAPED).astype("float32")
    x_train_ /= 255
    x_test_ /= 255
    print(x_train_.shape[0], "train samples")
    print(x_test_.shape[0], "test samples")

    y_train_, y_test_ = (to_categorical(labels, NB_CLASSES) for labels in (y_train_, y_test_))

    if not load_weights:
        model.fit(
            x_train_,
            y_train_,
            batch_size=batch_size,
            epochs=NB_EPOCH,
            verbose=VERBOSE,
            validation_split=VALIDATION_SPLIT)
        if weights_f:
            model.save_weights(weights_f)

    test_loss, test_acc = model.evaluate(x_test_, y_test_, verbose=VERBOSE)[:2]
    print_qstats(model)
    print("Test score:", test_loss)
    print("Test accuracy:", test_acc)
示例16
def UseNetwork(weights_f, load_weights=False):
    """Train (unless loading) a ``QDenseModel`` on (N, 28, 28, 1) MNIST and evaluate it.

    Pixel values are cast to float32 and divided by 256. (not 255). Presumably
    this keeps inputs in a power-of-two fixed-point range for quantization; confirm.
    Labels are one-hot encoded with ``NB_CLASSES``.

    Args:
        weights_f: weight file location; when training, weights are saved here
            if it is truthy.
        load_weights: load weights when it is True (training is then skipped).

    Returns:
        Tuple ``(model, x_train_, x_test_)`` with the preprocessed images.
    """
    model = QDenseModel(weights_f, load_weights)
    batch_size = BATCH_SIZE
    (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()

    x_train_ = x_train_.reshape(60000, 28, 28, 1).astype("float32")
    x_test_ = x_test_.reshape(10000, 28, 28, 1).astype("float32")
    x_train_ /= 256.
    x_test_ /= 256.
    # Optional alternative (disabled): map inputs to [-1, 1) with 2*x - 1.0.
    print(x_train_.shape[0], "train samples")
    print(x_test_.shape[0], "test samples")

    y_train_, y_test_ = (to_categorical(labels, NB_CLASSES) for labels in (y_train_, y_test_))

    if not load_weights:
        model.fit(
            x_train_,
            y_train_,
            batch_size=batch_size,
            epochs=NB_EPOCH,
            verbose=VERBOSE,
            validation_split=VALIDATION_SPLIT)
        if weights_f:
            model.save_weights(weights_f)

    test_loss, test_acc = model.evaluate(x_test_, y_test_, verbose=False)[:2]
    print("Test score:", test_loss)
    print("Test accuracy:", test_acc)
    return model, x_train_, x_test_
示例17
def UseNetwork(weights_f, load_weights=False):
    """Train (unless loading) a ``QDenseModel`` on flattened (N, 784) MNIST and evaluate it.

    Pixel values are cast to float32 and divided by 256. (not 255). Presumably
    this keeps inputs in a power-of-two fixed-point range for quantization; confirm.
    Labels are one-hot encoded with ``NB_CLASSES``.

    Args:
        weights_f: weight file location; when training, weights are saved here
            if it is truthy.
        load_weights: load weights when it is True (training is then skipped).

    Returns:
        Tuple ``(model, x_train_)`` with the preprocessed training images.
    """
    model = QDenseModel(weights_f, load_weights)
    batch_size = BATCH_SIZE
    (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()

    n_pixels = 28*28
    x_train_ = x_train_.reshape(60000, n_pixels).astype("float32")
    x_test_ = x_test_.reshape(10000, n_pixels).astype("float32")
    x_train_ /= 256.
    x_test_ /= 256.
    # Optional alternative (disabled): map inputs to [-1, 1) with 2*x - 1.0.
    print(x_train_.shape[0], "train samples")
    print(x_test_.shape[0], "test samples")

    y_train_, y_test_ = (to_categorical(labels, NB_CLASSES) for labels in (y_train_, y_test_))

    if not load_weights:
        model.fit(
            x_train_,
            y_train_,
            batch_size=batch_size,
            epochs=NB_EPOCH,
            verbose=VERBOSE,
            validation_split=VALIDATION_SPLIT)
        if weights_f:
            model.save_weights(weights_f)

    test_loss, test_acc = model.evaluate(x_test_, y_test_, verbose=False)[:2]
    print("Test score:", test_loss)
    print("Test accuracy:", test_acc)
    return model, x_train_
示例18
def UseNetwork(weights_f, load_weights=False):
    """Train (unless loading) a ``QConv2DModel`` on (N, 28, 28, 1) MNIST and evaluate it.

    Pixel values are cast to float32 and divided by 256. (not 255). Presumably
    this keeps inputs in a power-of-two fixed-point range for quantization; confirm.
    Labels are one-hot encoded with ``NB_CLASSES``.

    Args:
        weights_f: weight file location; when training, weights are saved here
            if it is truthy.
        load_weights: load weights when it is True (training is then skipped).

    Returns:
        Tuple ``(model, x_train, x_test)`` with the preprocessed images.
    """
    model = QConv2DModel(weights_f, load_weights)
    batch_size = BATCH_SIZE
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train.reshape(60000, 28, 28, 1).astype("float32")
    x_test = x_test.reshape(10000, 28, 28, 1).astype("float32")
    x_train /= 256.
    x_test /= 256.
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    y_train, y_test = (to_categorical(labels, NB_CLASSES) for labels in (y_train, y_test))

    if not load_weights:
        model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=NB_EPOCH,
            verbose=VERBOSE,
            validation_split=VALIDATION_SPLIT)
        if weights_f:
            model.save_weights(weights_f)

    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=False)[:2]
    print("Test score:", test_loss)
    print("Test accuracy:", test_acc)
    return model, x_train, x_test
示例19
def UseNetwork(weights_f, load_weights=False):
    """Train (unless loading) a ``QDenseModel`` on flattened (N, 784) MNIST and evaluate it.

    Pixel values are cast to float32 and divided by 256. (not 255). Presumably
    this keeps inputs in a power-of-two fixed-point range for quantization; confirm.
    Labels are one-hot encoded with ``NB_CLASSES``.

    Args:
        weights_f: weight file location; when training, weights are saved here
            if it is truthy.
        load_weights: load weights when it is True (training is then skipped).

    Returns:
        Tuple ``(model, x_train)`` with the preprocessed training images.
    """
    model = QDenseModel(weights_f, load_weights)
    batch_size = BATCH_SIZE
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.reshape(x_train.shape[0], 28*28)
    x_test = x_test.reshape(x_test.shape[0], 28*28)
    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    x_train /= 256.
    x_test /= 256.
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")
    # One-hot encode the labels unpacked above (this variant uses
    # un-suffixed names; there is no ``y_train_`` / ``y_test_`` in scope).
    y_train = to_categorical(y_train, NB_CLASSES)
    y_test = to_categorical(y_test, NB_CLASSES)
    if not load_weights:
        model.fit(
            x_train,
            y_train,
            batch_size=batch_size,
            epochs=NB_EPOCH,
            verbose=VERBOSE,
            validation_split=VALIDATION_SPLIT)
        if weights_f:
            model.save_weights(weights_f)
    score = model.evaluate(x_test, y_test, verbose=False)
    print("Test score:", score[0])
    print("Test accuracy:", score[1])
    return model, x_train
示例20
def build_and_train_models():
    """Load MNIST, build the DCGAN discriminator/generator/adversarial models, and train.

    Training images are reshaped to ``(N, image_size, image_size, 1)`` and
    scaled to [0, 1]. Both optimizers are RMSprop; the adversarial one uses half
    the discriminator's learning rate and decay.
    """
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()
    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    model_name = "dcgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = build_discriminator(inputs)
    # [1] or original paper uses Adam,
    # but discriminator converges easily with RMSprop
    # ``learning_rate`` replaces the deprecated ``lr`` alias.
    # NOTE(review): ``decay`` is a legacy-optimizer argument that newer
    # (non-legacy) Keras optimizers may reject — confirm the installed version.
    optimizer = RMSprop(learning_rate=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = build_generator(inputs, image_size)
    generator.summary()
    # build adversarial model
    optimizer = RMSprop(learning_rate=lr * 0.5, decay=decay * 0.5)
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    # adversarial = generator + discriminator
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    train(models, x_train, params)
示例21
def build_and_train_models():
    """Load MNIST, build the conditional-GAN discriminator/generator/adversarial models, and train.

    Training images are reshaped to ``(N, image_size, image_size, 1)`` and
    scaled to [0, 1]. Labels are one-hot encoded; their width (``num_labels``)
    is ``max(label) + 1``. Both optimizers are RMSprop; the adversarial one uses
    half the discriminator's learning rate and decay.
    """
    # load MNIST dataset
    (x_train, y_train), (_, _) = mnist.load_data()
    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    # plain Python int rather than a numpy uint8 scalar, since it is used as a
    # layer shape and passed on to train()
    num_labels = int(np.amax(y_train)) + 1
    y_train = to_categorical(y_train)
    model_name = "cgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )
    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    labels = Input(shape=label_shape, name='class_labels')
    discriminator = build_discriminator(inputs, labels, image_size)
    # [1] or original paper uses Adam,
    # but discriminator converges easily with RMSprop
    # ``learning_rate`` replaces the deprecated ``lr`` alias.
    # NOTE(review): ``decay`` is a legacy-optimizer argument that newer
    # (non-legacy) Keras optimizers may reject — confirm the installed version.
    optimizer = RMSprop(learning_rate=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = build_generator(inputs, labels, image_size)
    generator.summary()
    # build adversarial model = generator + discriminator
    optimizer = RMSprop(learning_rate=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    outputs = discriminator([generator([inputs, labels]), labels])
    adversarial = Model([inputs, labels],
                        outputs,
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
示例22
def build_and_train_models():
    """Load the dataset, build LSGAN discriminator,
    generator, and adversarial models.
    Call the LSGAN train routine.

    Training images are reshaped to ``(N, image_size, image_size, 1)`` and
    scaled to [0, 1]. Both models are compiled with MSE loss (LSGAN). The
    discriminator is built with ``activation=None``.
    """
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()
    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train,
                         [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    model_name = "lsgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    input_shape = (image_size, image_size, 1)
    batch_size = 64
    lr = 2e-4
    decay = 6e-8
    train_steps = 40000
    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, activation=None)
    # [1] uses Adam, but discriminator easily
    # converges with RMSprop
    # ``learning_rate`` replaces the deprecated ``lr`` alias.
    # NOTE(review): ``decay`` is a legacy-optimizer argument that newer
    # (non-legacy) Keras optimizers may reject — confirm the installed version.
    optimizer = RMSprop(learning_rate=lr, decay=decay)
    # LSGAN uses MSE loss [2]
    discriminator.compile(loss='mse',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()
    # build adversarial model = generator + discriminator
    optimizer = RMSprop(learning_rate=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator
    # during adversarial training
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # LSGAN uses MSE loss [2]
    adversarial.compile(loss='mse',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    gan.train(models, x_train, params)
示例23
def load_data():
    """Load MNIST (source domain) and SVHN (target domain) for domain transfer.

    MNIST images are zero-padded from 28x28 to 32x32 to match SVHN and given a
    trailing channel axis ("channels_last"). The SVHN ``.mat`` files are fetched
    with ``get_file`` and read with ``loadmat``.

    Returns:
        Whatever ``other_utils.load_data`` returns for the
        ``(source, target, test_source, test_target)`` data tuple together with
        the preview titles and filenames.
    """
    def _pad_to_svhn(images):
        # pad with zeros 28x28 MNIST image to become 32x32 (svhn is 32x32),
        # then reshape to (size, rows, cols, channels) for CNN output/validation;
        # we assume data format "channels_last"
        padded = np.pad(images,
                        ((0, 0), (2, 2), (2, 2)),
                        'constant',
                        constant_values=0)
        size, rows, cols = padded.shape[:3]
        channels = 1
        return padded.reshape(size, rows, cols, channels)

    # load mnist data
    (source_data, _), (test_source_data, _) = mnist.load_data()
    source_data = _pad_to_svhn(source_data)
    test_source_data = _pad_to_svhn(test_source_data)

    # load SVHN data. Use the paths get_file returns instead of rebuilding them
    # from a separately computed data directory, which can disagree with the
    # cache location get_file actually wrote to.
    train_path = get_file('train_32x32.mat',
                          origin='http://ufldl.stanford.edu/housenumbers/train_32x32.mat')
    test_path = get_file('test_32x32.mat',
                         origin='http://ufldl.stanford.edu/housenumbers/test_32x32.mat')
    target_data = loadmat(train_path)
    test_target_data = loadmat(test_path)

    # source data, target data, test_source data
    data = (source_data, target_data, test_source_data, test_target_data)
    filenames = ('mnist_test_source.png', 'svhn_test_target.png')
    titles = ('MNIST test source images', 'SVHN test target images')
    return other_utils.load_data(data, titles, filenames)
示例24
def setup(self, config):
    """Prepare MNIST datasets, the model, loss/optimizer, metrics and tf.function steps.

    Args:
        config: dict; optional keys ``"batch"`` (training batch size, default 32)
            and ``"hiddens"`` (passed to ``MyModel``, default 128).

    Sets ``train_ds`` / ``test_ds`` (test batches are fixed at 32), ``model``,
    ``loss_object``, ``optimizer``, four Mean/SparseCategoricalAccuracy metrics,
    and the compiled ``tf_train_step`` / ``tf_test_step`` callables.
    """
    # IMPORTANT: See the above note.
    import tensorflow as tf

    (x_train, y_train), (x_test, y_test) = load_data()
    # scale to [0, 1] and add a channels dimension
    x_train = (x_train / 255.0)[..., tf.newaxis]
    x_test = (x_test / 255.0)[..., tf.newaxis]

    train_batch = config.get("batch", 32)
    self.train_ds = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                     .shuffle(10000)
                     .batch(train_batch))
    self.test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

    self.model = MyModel(hiddens=config.get("hiddens", 128))
    self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    self.optimizer = tf.keras.optimizers.Adam()

    metrics = tf.keras.metrics
    self.train_loss = metrics.Mean(name="train_loss")
    self.train_accuracy = metrics.SparseCategoricalAccuracy(name="train_accuracy")
    self.test_loss = metrics.Mean(name="test_loss")
    self.test_accuracy = metrics.SparseCategoricalAccuracy(name="test_accuracy")

    @tf.function
    def train_step(images, labels):
        # forward pass and loss are recorded on the tape so gradients can flow
        with tf.GradientTape() as tape:
            preds = self.model(images)
            batch_loss = self.loss_object(labels, preds)
        variables = self.model.trainable_variables
        grads = tape.gradient(batch_loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        self.train_loss(batch_loss)
        self.train_accuracy(labels, preds)

    @tf.function
    def test_step(images, labels):
        preds = self.model(images)
        self.test_loss(self.loss_object(labels, preds))
        self.test_accuracy(labels, preds)

    self.tf_train_step = train_step
    self.tf_test_step = test_step