Python source examples: tensorflow.keras.datasets.mnist.load_data()
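
Before the examples, a minimal sketch (not taken from any of the projects below) of what the call returns: two (images, labels) tuples of NumPy arrays. On first use the dataset is downloaded as mnist.npz and cached under ~/.keras/datasets.

# Minimal sketch: inspect the return value of mnist.load_data().
from tensorflow.keras.datasets import mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, x_train.dtype)  # (60000, 28, 28) uint8
print(y_train.shape, y_train.dtype)  # (60000,) uint8
print(x_test.shape, y_test.shape)    # (10000, 28, 28) (10000,)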

Example 1
def fit_VAE():
    # for testing
    from numpy_ml.neural_nets.models.vae import BernoulliVAE

    np.random.seed(12345)

    (X_train, y_train), (X_test, y_test) = mnist.load_data()

    # scale pixel intensities to [0, 1]
    X_train = np.expand_dims(X_train.astype("float32") / 255.0, 3)
    X_test = np.expand_dims(X_test.astype("float32") / 255.0, 3)

    X_train = X_train[: 128 * 1]  # 1 batch

    BV = BernoulliVAE()
    BV.fit(X_train, n_epochs=1, verbose=False) 
Example 2
def get_mnist_data():
        # the data, shuffled and split between train and test sets
        (x_train, y_train), (x_test, y_test) = mnist.load_data()

        x_train = x_train.reshape(x_train.shape[0], MNIST.img_rows, MNIST.img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], MNIST.img_rows, MNIST.img_cols, 1)

        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
        x_train /= 255
        x_test /= 255
        print('x_train shape:', x_train.shape)
        print(x_train.shape[0], 'train samples')
        print(x_test.shape[0], 'test samples')

        # convert class vectors to binary class matrices
        y_train = keras.utils.to_categorical(y_train, MNIST.num_classes)
        y_test = keras.utils.to_categorical(y_test, MNIST.num_classes)
        return x_train, y_train, x_test, y_test 
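
For reference, the "binary class matrices" in the comment above are one-hot rows; a two-line illustration (independent of the example) shows the conversion:

from tensorflow.keras.utils import to_categorical

print(to_categorical([0, 2], num_classes=3))
# [[1. 0. 0.]
#  [0. 0. 1.]]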
Example 3
def __init__(self, **kwargs):
        super(MnistRet, self).__init__(name=None, queries_in_collection=True)
        self.name = "RET_MNIST"
        (x_train, y_train), (x_test, y_test) = load_data()
        idx_train = np.where(y_train < 5)[0]
        idx_test = np.where(y_test < 5)[0]
        self.train_images = np.concatenate([x_train[idx_train], x_test[idx_test]], axis=0)
        self.train_labels = np.concatenate([y_train[idx_train], y_test[idx_test]], axis=0)

        idx_train = np.where(y_train >= 5)[0]
        idx_test = np.where(y_test >= 5)[0]
        self.test_images = np.concatenate([x_train[idx_train], x_test[idx_test]], axis=0)
        self.test_labels = np.concatenate([y_train[idx_train], y_test[idx_test]], axis=0)

        self.train_images = self.train_images[..., None]
        self.test_images = self.test_images[..., None] 
Example 4
def get_mnist():
    (x_train, y_train), (x_valid, y_valid) = mnist.load_data()
    x_train = x_train.astype("float32") / 255
    x_valid = x_valid.astype("float32") / 255

    y_train = y_train.astype("int32")
    y_valid = y_valid.astype("int32")

    train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    train_ds = train_ds.shuffle(60000).batch(BATCHSIZE).take(N_TRAIN_EXAMPLES)

    valid_ds = tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
    valid_ds = valid_ds.shuffle(10000).batch(BATCHSIZE).take(N_VALID_EXAMPLES)
    return train_ds, valid_ds


# FYI: Objective functions can take additional arguments
# (https://optuna.readthedocs.io/en/stable/faq.html#objective-func-additional-args). 
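
As that FAQ entry describes, extra arguments (such as the datasets built by get_mnist()) can be bound to the objective with functools.partial or a lambda. A sketch of the pattern follows; the objective body is a placeholder, not part of the original example:

import functools
import optuna

def objective(trial, train_ds, valid_ds):
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # Placeholder: build and train a model on train_ds with this lr,
    # then return a validation metric computed on valid_ds.
    return lr  # stand-in value so the sketch runs end to end

train_ds, valid_ds = get_mnist()
study = optuna.create_study(direction="maximize")
study.optimize(
    functools.partial(objective, train_ds=train_ds, valid_ds=valid_ds),
    n_trials=10,
)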
Example 5
def main(unused_argv):
  # Get the data.
  (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
  train_images = train_images.reshape(len(train_images), 28, 28, 1)
  test_images = test_images.reshape(len(test_images), 28, 28, 1)

  # Convert to Examples and write the result to TFRecords.
  convert_to(train_images, train_labels, 'train')
  convert_to(test_images, test_labels, 'test') 
Example 6
def load_data(k=8, noise_level=0.0):
    """
    Loads the MNIST dataset and a K-NN graph to perform graph signal
    classification, as described by [Defferrard et al. (2016)](https://arxiv.org/abs/1606.09375).
    The K-NN graph is statically determined from a regular grid of pixels using
    the 2d coordinates.

    The node features of each graph are the MNIST digits vectorized and rescaled
    to [0, 1].
    Two nodes are connected if they are neighbours according to the K-NN graph.
    Labels are the MNIST class associated to each sample.

    :param k: int, number of neighbours for each node;
    :param noise_level: fraction of edges to flip (from 0 to 1 and vice versa);

    :return:
        - X_train, y_train: training node features and labels;
        - X_val, y_val: validation node features and labels;
        - X_test, y_test: test node features and labels;
        - A: adjacency matrix of the grid;
    """
    A = _mnist_grid_graph(k)
    A = _flip_random_edges(A, noise_level).astype(np.float32)

    (X_train, y_train), (X_test, y_test) = m.load_data()
    X_train, X_test = X_train / 255.0, X_test / 255.0
    X_train = X_train.reshape(-1, MNIST_SIZE ** 2)
    X_test = X_test.reshape(-1, MNIST_SIZE ** 2)

    X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=10000)

    return X_train, y_train, X_val, y_val, X_test, y_test, A 
Example 7
def main():
    # input image dimensions
    img_rows, img_cols = 28, 28

    # the data, shuffled and split between train and test sets
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    if K.image_data_format() == "channels_first":
        x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
        x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)
    else:
        x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)

    x_train = x_train.astype("float32")
    x_test = x_test.astype("float32")
    x_train /= 255
    x_test /= 255
    print("x_train shape:", x_train.shape)
    print(x_train.shape[0], "train samples")
    print(x_test.shape[0], "test samples")

    # convert class vectors to binary class matrices
    y_train = to_categorical(y_train, num_classes)
    y_test = to_categorical(y_test, num_classes)

    pruning_params = {
        "pruning_schedule":
            pruning_schedule.ConstantSparsity(0.75, begin_step=2000, frequency=100)
    }
    
    if prune_whole_model:
        model = build_model(input_shape)
        model = prune.prune_low_magnitude(model, **pruning_params)
    else:
        model = build_layerwise_model(input_shape, **pruning_params)

    train_and_save(model, x_train, y_train, x_test, y_test) 
Example 8
def load_mnist():
    # the data, shuffled and split between train and test sets
    from tensorflow.keras.datasets import mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x = np.concatenate((x_train, x_test))
    y = np.concatenate((y_train, y_test))
    x = x.reshape([-1, 28, 28, 1]) / 255.0
    print('MNIST samples', x.shape)
    return x, y 
Example 9
def load_mnist_test():
    # the data, shuffled and split between train and test sets
    from tensorflow.keras.datasets import mnist
    _, (x, y) = mnist.load_data()
    x = x.reshape([-1, 28, 28, 1]) / 255.0
    print('MNIST samples', x.shape)
    return x, y 
Example 10
def load_fashion_mnist():
    from tensorflow.keras.datasets import fashion_mnist  # this requires keras>=2.0.9
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    x = np.concatenate((x_train, x_test))
    y = np.concatenate((y_train, y_test))
    x = x.reshape([-1, 28, 28, 1]) / 255.0
    print('Fashion MNIST samples', x.shape)
    return x, y 
Example 11
def load_data(dataset):
    # flatten the conv-shaped images returned by load_data_conv
    # (defined elsewhere in the source module)
    x, y = load_data_conv(dataset)
    return x.reshape([x.shape[0], -1]), y
Example 12
def _dataset():

    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    x_train = x_train / 255
    x_test = x_test / 255

    axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1
    x_train = np.expand_dims(x_train, axis)
    x_test = np.expand_dims(x_test, axis)

    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)

    return x_train, y_train, x_test, y_test 
Example 13
def __init__(self, seq_len=5, square_count=3, square_size=5, noise_ratio=0.15, digits=range(10), max_angle=180):
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        mnist_train = [(img, label) for img, label in zip(x_train, y_train) if label in digits]
        mnist_test = [(img, label) for img, label in zip(x_test, y_test) if label in digits]

        train_images = []
        test_images = []
        train_rotations = []
        test_rotations = []
        train_labels = []
        test_labels = []

        for img, label in mnist_train:
            train_img, train_rot = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle)
            train_images.append(train_img)
            train_rotations.append(train_rot)
            train_labels.append(label)

        for img, label in mnist_test:
            test_img, test_rot = heal_image(img, seq_len, square_count, square_size, noise_ratio, max_angle)
            test_images.append(test_img)
            test_rotations.append(test_rot)
            test_labels.append(label)
        
        self.train_images = np.array(train_images)
        self.test_images = np.array(test_images)
        self.train_rotations = np.array(train_rotations)
        self.test_rotations = np.array(test_rotations)
        self.train_labels = np.array(train_labels)
        self.test_labels = np.array(test_labels) 
Example 14
def train_mnist(config):
    # https://github.com/tensorflow/tensorflow/issues/32159
    import tensorflow as tf
    batch_size = 128
    num_classes = 10
    epochs = 12

    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(config["hidden"], activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(num_classes, activation="softmax")
    ])

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=tf.keras.optimizers.SGD(
            lr=config["lr"], momentum=config["momentum"]),
        metrics=["accuracy"])

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
        validation_data=(x_test, y_test),
        callbacks=[TuneReporterCallback()]) 
Example 15
def UseNetwork(weights_f, load_weights=False):
  """Use DenseModel.

  Args:
    weights_f: weight file location.
    load_weights: load weights when it is True.
  """
  model = QDenseModel(weights_f, load_weights)

  batch_size = BATCH_SIZE
  (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()

  x_train_ = x_train_.reshape(60000, RESHAPED)
  x_test_ = x_test_.reshape(10000, RESHAPED)
  x_train_ = x_train_.astype("float32")
  x_test_ = x_test_.astype("float32")

  x_train_ /= 255
  x_test_ /= 255

  print(x_train_.shape[0], "train samples")
  print(x_test_.shape[0], "test samples")

  y_train_ = to_categorical(y_train_, NB_CLASSES)
  y_test_ = to_categorical(y_test_, NB_CLASSES)

  if not load_weights:
    model.fit(
        x_train_,
        y_train_,
        batch_size=batch_size,
        epochs=NB_EPOCH,
        verbose=VERBOSE,
        validation_split=VALIDATION_SPLIT)

    if weights_f:
      model.save_weights(weights_f)

  score = model.evaluate(x_test_, y_test_, verbose=VERBOSE)
  print_qstats(model)
  print("Test score:", score[0])
  print("Test accuracy:", score[1]) 
Example 16
def UseNetwork(weights_f, load_weights=False):
  """Use DenseModel.

  Args:
    weights_f: weight file location.
    load_weights: load weights when it is True.
  """
  model = QDenseModel(weights_f, load_weights)

  batch_size = BATCH_SIZE
  (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()

  x_train_ = x_train_.reshape(60000, 28, 28, 1)
  x_test_ = x_test_.reshape(10000, 28, 28, 1)
  x_train_ = x_train_.astype("float32")
  x_test_ = x_test_.astype("float32")

  x_train_ /= 256.
  x_test_ /= 256.

  # x_train_ = 2*x_train_ - 1.0
  # x_test_ = 2*x_test_ - 1.0

  print(x_train_.shape[0], "train samples")
  print(x_test_.shape[0], "test samples")

  y_train_ = to_categorical(y_train_, NB_CLASSES)
  y_test_ = to_categorical(y_test_, NB_CLASSES)

  if not load_weights:
    model.fit(
        x_train_,
        y_train_,
        batch_size=batch_size,
        epochs=NB_EPOCH,
        verbose=VERBOSE,
        validation_split=VALIDATION_SPLIT)

    if weights_f:
      model.save_weights(weights_f)

  score = model.evaluate(x_test_, y_test_, verbose=False)
  print("Test score:", score[0])
  print("Test accuracy:", score[1])

  return model, x_train_, x_test_ 
Example 17
def UseNetwork(weights_f, load_weights=False):
  """Use DenseModel.

  Args:
    weights_f: weight file location.
    load_weights: load weights when it is True.
  """
  model = QDenseModel(weights_f, load_weights)

  batch_size = BATCH_SIZE
  (x_train_, y_train_), (x_test_, y_test_) = mnist.load_data()

  x_train_ = x_train_.reshape(60000, 28*28)
  x_test_ = x_test_.reshape(10000, 28*28)
  x_train_ = x_train_.astype("float32")
  x_test_ = x_test_.astype("float32")

  x_train_ /= 256.
  x_test_ /= 256.

  # x_train_ = 2*x_train_ - 1.0
  # x_test_ = 2*x_test_ - 1.0

  print(x_train_.shape[0], "train samples")
  print(x_test_.shape[0], "test samples")

  y_train_ = to_categorical(y_train_, NB_CLASSES)
  y_test_ = to_categorical(y_test_, NB_CLASSES)

  if not load_weights:
    model.fit(
        x_train_,
        y_train_,
        batch_size=batch_size,
        epochs=NB_EPOCH,
        verbose=VERBOSE,
        validation_split=VALIDATION_SPLIT)

    if weights_f:
      model.save_weights(weights_f)

  score = model.evaluate(x_test_, y_test_, verbose=False)
  print("Test score:", score[0])
  print("Test accuracy:", score[1])

  return model, x_train_ 
Example 18
def UseNetwork(weights_f, load_weights=False):
  """Use DenseModel.

  Args:
    weights_f: weight file location.
    load_weights: load weights when it is True.
  """
  model = QConv2DModel(weights_f, load_weights)

  batch_size = BATCH_SIZE
  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  x_train = x_train.reshape(60000, 28, 28, 1)
  x_test = x_test.reshape(10000, 28, 28, 1)
  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")

  x_train /= 256.
  x_test /= 256.

  print(x_train.shape[0], "train samples")
  print(x_test.shape[0], "test samples")

  y_train = to_categorical(y_train, NB_CLASSES)
  y_test = to_categorical(y_test, NB_CLASSES)

  if not load_weights:
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=NB_EPOCH,
        verbose=VERBOSE,
        validation_split=VALIDATION_SPLIT)

    if weights_f:
      model.save_weights(weights_f)

  score = model.evaluate(x_test, y_test, verbose=False)
  print("Test score:", score[0])
  print("Test accuracy:", score[1])

  return model, x_train, x_test 
Example 19
def UseNetwork(weights_f, load_weights=False):
  """Use DenseModel.

  Args:
    weights_f: weight file location.
    load_weights: load weights when it is True.
  """
  model = QDenseModel(weights_f, load_weights)

  batch_size = BATCH_SIZE
  (x_train, y_train), (x_test, y_test) = mnist.load_data()

  x_train = x_train.reshape(60000, 28*28)
  x_test = x_test.reshape(10000, 28*28)
  x_train = x_train.astype("float32")
  x_test = x_test.astype("float32")

  x_train /= 256.
  x_test /= 256.

  print(x_train.shape[0], "train samples")
  print(x_test.shape[0], "test samples")

  y_train = to_categorical(y_train, NB_CLASSES)
  y_test = to_categorical(y_test, NB_CLASSES)

  if not load_weights:
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=NB_EPOCH,
        verbose=VERBOSE,
        validation_split=VALIDATION_SPLIT)

    if weights_f:
      model.save_weights(weights_f)

  score = model.evaluate(x_test, y_test, verbose=False)
  print("Test score:", score[0])
  print("Test accuracy:", score[1])

  return model, x_train 
Example 20
def build_and_train_models():
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()

    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    model_name = "dcgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)

    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = build_discriminator(inputs)
    # the original paper [1] uses Adam,
    # but the discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()

    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = build_generator(inputs, image_size)
    generator.summary()

    # build adversarial model
    optimizer = RMSprop(lr=lr * 0.5, decay=decay * 0.5)
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    # adversarial = generator + discriminator
    adversarial = Model(inputs, 
                        discriminator(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    train(models, x_train, params) 
Example 21
def build_and_train_models():
    # load MNIST dataset
    (x_train, y_train), (_, _) = mnist.load_data()

    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    num_labels = np.amax(y_train) + 1
    y_train = to_categorical(y_train)

    model_name = "cgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )

    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    labels = Input(shape=label_shape, name='class_labels')

    discriminator = build_discriminator(inputs, labels, image_size)
    # the original paper [1] uses Adam,
    # but the discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()

    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = build_generator(inputs, labels, image_size)
    generator.summary()

    # build adversarial model = generator + discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator during adversarial training
    discriminator.trainable = False
    outputs = discriminator([generator([inputs, labels]), labels])
    adversarial = Model([inputs, labels],
                        outputs,
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params) 
Example 22
def build_and_train_models():
    """Load the dataset, build LSGAN discriminator,
    generator, and adversarial models.
    Call the LSGAN train routine.
    """
    # load MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()

    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, 
                         [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    model_name = "lsgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100
    input_shape = (image_size, image_size, 1)
    batch_size = 64
    lr = 2e-4
    decay = 6e-8
    train_steps = 40000

    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, activation=None)
    # [1] uses Adam, but the discriminator
    # converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # LSGAN uses MSE loss [2]
    discriminator.compile(loss='mse',
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()

    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()

    # build adversarial model = generator + discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator 
    # during adversarial training
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # LSGAN uses MSE loss [2]
    adversarial.compile(loss='mse',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    gan.train(models, x_train, params) 
Example 23
def load_data():
    # load mnist data
    (source_data, _), (test_source_data, _) = mnist.load_data()

    # pad with zeros 28x28 MNIST image to become 32x32
    # svhn is 32x32
    source_data = np.pad(source_data,
                         ((0,0), (2,2), (2,2)),
                         'constant',
                         constant_values=0)
    test_source_data = np.pad(test_source_data,
                              ((0,0), (2,2), (2,2)),
                              'constant',
                              constant_values=0)
    # input image dimensions
    # we assume data format "channels_last"
    rows = source_data.shape[1]
    cols = source_data.shape[2]
    channels = 1

    # reshape images to row x col x channels
    # for CNN output/validation
    size = source_data.shape[0]
    source_data = source_data.reshape(size,
                                      rows,
                                      cols,
                                      channels)
    size = test_source_data.shape[0]
    test_source_data = test_source_data.reshape(size,
                                                rows,
                                                cols,
                                                channels)

    # load SVHN data
    datadir = get_datadir()
    get_file('train_32x32.mat',
             origin='http://ufldl.stanford.edu/housenumbers/train_32x32.mat')
    get_file('test_32x32.mat',
             'http://ufldl.stanford.edu/housenumbers/test_32x32.mat')
    path = os.path.join(datadir, 'train_32x32.mat')
    target_data = loadmat(path)
    path = os.path.join(datadir, 'test_32x32.mat')
    test_target_data = loadmat(path)

    # source data, target data, test_source data
    data = (source_data, target_data, test_source_data, test_target_data)
    filenames = ('mnist_test_source.png', 'svhn_test_target.png')
    titles = ('MNIST test source images', 'SVHN test target images')
    
    return other_utils.load_data(data, titles, filenames) 
Example 24
def setup(self, config):
        # IMPORTANT: See the above note.
        import tensorflow as tf
        (x_train, y_train), (x_test, y_test) = load_data()
        x_train, x_test = x_train / 255.0, x_test / 255.0

        # Add a channels dimension
        x_train = x_train[..., tf.newaxis]
        x_test = x_test[..., tf.newaxis]
        self.train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
        self.train_ds = self.train_ds.shuffle(10000).batch(
            config.get("batch", 32))

        self.test_ds = tf.data.Dataset.from_tensor_slices((x_test,
                                                           y_test)).batch(32)

        self.model = MyModel(hiddens=config.get("hiddens", 128))
        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
        self.optimizer = tf.keras.optimizers.Adam()
        self.train_loss = tf.keras.metrics.Mean(name="train_loss")
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="train_accuracy")

        self.test_loss = tf.keras.metrics.Mean(name="test_loss")
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name="test_accuracy")

        @tf.function
        def train_step(images, labels):
            with tf.GradientTape() as tape:
                predictions = self.model(images)
                loss = self.loss_object(labels, predictions)
            gradients = tape.gradient(loss, self.model.trainable_variables)
            self.optimizer.apply_gradients(
                zip(gradients, self.model.trainable_variables))

            self.train_loss(loss)
            self.train_accuracy(labels, predictions)

        @tf.function
        def test_step(images, labels):
            predictions = self.model(images)
            t_loss = self.loss_object(labels, predictions)

            self.test_loss(t_loss)
            self.test_accuracy(labels, predictions)

        self.tf_train_step = train_step
        self.tf_test_step = test_step