Training a full model#

This tutorial illustrates how to train a machine learning model for miniML. The model combines a convolutional neural network (CNN) with a long short-term memory (LSTM) network. The notebook shows how to train a full model using training data of cerebellar granule cell (GC) miniature excitatory postsynaptic currents (mEPSCs), as described in the miniML manuscript.

For details, please refer to the manuscript in eLife and to the GitHub repository.

eLife: https://doi.org/10.7554/eLife.98485.1

GitHub: delvendahl/miniML

Preparation#

Labeled training data are available in the Zenodo repository. Details on the dataset are described in the miniML paper. We begin by downloading the training dataset from Zenodo:

import requests

file_url = 'https://zenodo.org/records/14507343/files/1_GC_mepsc_train.h5'

response = requests.get(file_url)

if response.status_code == 200:
    with open('../_data/GC_mEPSC_training_data.h5', 'wb') as file:
        file.write(response.content)
    print('Downloaded GC_mEPSC_training_data.h5')
else:
    print(f'Download failed with status code {response.status_code}')
Downloaded GC_mEPSC_training_data.h5

Now we set up our Python environment:

import tensorflow as tf
from tensorflow.keras.layers import (Input, BatchNormalization, AveragePooling1D, Conv1D,
                                     Bidirectional, LSTM, Dense, Dropout, LeakyReLU)
from tensorflow.keras.optimizers.legacy import Adam
from sklearn.preprocessing import minmax_scale
from sklearn.model_selection import train_test_split
from sklearn.metrics import auc, confusion_matrix, ConfusionMatrixDisplay, roc_curve, matthews_corrcoef
import time
import h5py
import numpy as np
import matplotlib.pyplot as plt

print('Using TF version', tf.__version__)
print('GPUs available:', len(tf.config.list_physical_devices('GPU')))
Using TF version 2.15.0
GPUs available: 1

Set hyperparameters for training#

The following parameters are used for training the model:

settings = {
    'training_size': 0.8,
    'testing_size': None,
    'learn_rate': 2e-5,
    'epsilon': 1e-8,
    'patience': 8,
    'epochs': 100,
    'batch_size': 128,
    'dropout': 0.2,
    'training_data': '../_data/GC_mEPSC_training_data.h5'
}
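With `training_size` set to 0.8, the train/test split sizes for the 30,140-segment dataset loaded below can be checked with a quick back-of-the-envelope calculation:

```python
# Expected train/test split for the GC mEPSC dataset (30,140 segments).
n_segments = 30140
train_fraction = 0.8

n_train = int(n_segments * train_fraction)
n_test = n_segments - n_train
print(n_train, n_test)  # → 24112 6028
```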

Load training data#

Now we can load the training data. The data are stored in an HDF5 file containing 30,140 data segments of 600 time points each. Each segment carries a label of 0 (negative) or 1 (positive).

We min-max scale the data and split the dataset into training and test sets. The training set is used to train the model, while the test set is used to evaluate the model. Both sets are tf.data.Dataset objects.

with h5py.File(settings['training_data'], 'r') as f:
    x = f['events'][:]
    y = f['scores'][:]

print(f'loaded events with shape {x.shape}')
print(f'loaded scores with shape {y.shape}') 
print(f'ratio of pos/neg scores: {y.sum()/(y.shape[0]-y.sum()):.2f}')

scaled_data = minmax_scale(x, feature_range=(0,1), axis=1)
scaled_data = np.expand_dims(scaled_data, axis=2)
y = np.expand_dims(y, axis=1)
x_train, x_test, y_train, y_test = train_test_split(scaled_data, y, train_size=settings['training_size'], random_state=1234)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
loaded events with shape (30140, 600)
loaded scores with shape (30140,)
ratio of pos/neg scores: 1.03
2025-03-21 17:33:27.681055: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2025-03-21 17:33:27.681094: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-03-21 17:33:27.681099: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-03-21 17:33:27.681151: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-03-21 17:33:27.681176: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
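The per-segment scaling used above can be illustrated in plain NumPy. This sketch reproduces what `minmax_scale(..., axis=1)` does, mapping each segment independently onto [0, 1]; the toy array here is a stand-in for the real event data:

```python
import numpy as np

def scale_segments(x):
    """Min-max scale each row (data segment) independently to [0, 1]."""
    mins = x.min(axis=1, keepdims=True)
    maxs = x.max(axis=1, keepdims=True)
    return (x - mins) / (maxs - mins)

# Toy stand-in for the (n_segments, 600) event array.
rng = np.random.default_rng(42)
events = rng.normal(size=(4, 600))
scaled = scale_segments(events)

print(scaled.min(axis=1), scaled.max(axis=1))  # each row spans [0, 1]
```

Scaling each segment separately (rather than the whole array at once) makes the network invariant to absolute event amplitude.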

Build the model#

Here, we define the model architecture: a stack of convolutional (CNN) layers followed by a bidirectional long short-term memory (LSTM) layer and dense output layers.

def build_model(x_train, dropout_rate):
    model = tf.keras.models.Sequential()
    model.add(Input(shape=(x_train.shape[1:])))
              
    model.add(Conv1D(filters=32, kernel_size=9, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=3, strides=3))
    
    model.add(Conv1D(filters=48, kernel_size=7, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2, strides=2))
    
    model.add(Conv1D(filters=64, kernel_size=5, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    model.add(AveragePooling1D(pool_size=2, strides=2))
    
    model.add(Conv1D(filters=80, kernel_size=3, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU())
    
    model.add(Bidirectional(LSTM(96, dropout=dropout_rate), merge_mode='sum'))
    model.add(Dense(128, activation=LeakyReLU()))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1, activation='sigmoid'))
    
    return model


model = build_model(x_train, settings['dropout'])
model.compile(optimizer=Adam(learning_rate=settings['learn_rate'], epsilon=settings['epsilon'], amsgrad=True),
              loss=tf.keras.losses.BinaryCrossentropy(), 
              metrics=['Accuracy'])

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 conv1d (Conv1D)             (None, 600, 32)           320       
                                                                 
 batch_normalization (Batch  (None, 600, 32)           128       
 Normalization)                                                  
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 600, 32)           0         
                                                                 
 average_pooling1d (Average  (None, 200, 32)           0         
 Pooling1D)                                                      
                                                                 
 conv1d_1 (Conv1D)           (None, 200, 48)           10800     
                                                                 
 batch_normalization_1 (Bat  (None, 200, 48)           192       
 chNormalization)                                                
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 200, 48)           0         
                                                                 
 average_pooling1d_1 (Avera  (None, 100, 48)           0         
 gePooling1D)                                                    
                                                                 
 conv1d_2 (Conv1D)           (None, 100, 64)           15424     
                                                                 
 batch_normalization_2 (Bat  (None, 100, 64)           256       
 chNormalization)                                                
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 100, 64)           0         
                                                                 
 average_pooling1d_2 (Avera  (None, 50, 64)            0         
 gePooling1D)                                                    
                                                                 
 conv1d_3 (Conv1D)           (None, 50, 80)            15440     
                                                                 
 batch_normalization_3 (Bat  (None, 50, 80)            320       
 chNormalization)                                                
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 50, 80)            0         
                                                                 
 bidirectional (Bidirection  (None, 96)                135936    
 al)                                                             
                                                                 
 dense (Dense)               (None, 128)               12416     
                                                                 
 dropout (Dropout)           (None, 128)               0         
                                                                 
 dense_1 (Dense)             (None, 1)                 129       
                                                                 
=================================================================
Total params: 191361 (747.50 KB)
Trainable params: 190913 (745.75 KB)
Non-trainable params: 448 (1.75 KB)
_________________________________________________________________
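The parameter counts in the summary can be verified by hand. For a `Conv1D` layer, the count is `kernel_size * in_channels * filters` kernel weights plus one bias per filter:

```python
def conv1d_params(kernel_size, in_channels, filters):
    # kernel weights plus one bias per output filter
    return kernel_size * in_channels * filters + filters

print(conv1d_params(9, 1, 32))    # → 320, matches conv1d in the summary
print(conv1d_params(7, 32, 48))   # → 10800
print(conv1d_params(5, 48, 64))   # → 15424
print(conv1d_params(3, 64, 80))   # → 15440
```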

Train the model#

Now we can train our model using the training data. Note that running this cell can take a few minutes on a GPU; on a CPU, training takes considerably longer.

We also provide an executable notebook for model training on Kaggle, which is available here.

early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', 
    patience=settings['patience'],
    restore_best_weights=True)

start = time.time()
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(settings['batch_size'], num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(settings['batch_size'], num_parallel_calls=tf.data.AUTOTUNE)

train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

history = model.fit(train_dataset,
                    verbose=2,
                    epochs=settings['epochs'],
                    validation_data=test_dataset,
                    shuffle=True,
                    callbacks=[early_stopping_callback])

print('')
print('----')
print(f'train shape: {x_train.shape}')
print(f'score on val: {model.evaluate(test_dataset, verbose=0)[1]}')
print(f'score on train: {model.evaluate(train_dataset, verbose=0)[1]}')
print('----')
print(f'training time (s): {time.time()-start:.4f}')
Epoch 1/100
2025-03-21 17:33:29.182640: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
189/189 - 12s - loss: 0.4564 - Accuracy: 0.8126 - val_loss: 0.5859 - val_Accuracy: 0.8200 - 12s/epoch - 62ms/step
Epoch 2/100
189/189 - 9s - loss: 0.3674 - Accuracy: 0.8407 - val_loss: 0.4523 - val_Accuracy: 0.8069 - 9s/epoch - 46ms/step
Epoch 3/100
189/189 - 9s - loss: 0.2600 - Accuracy: 0.8943 - val_loss: 0.3641 - val_Accuracy: 0.8135 - 9s/epoch - 47ms/step
Epoch 4/100
189/189 - 9s - loss: 0.1402 - Accuracy: 0.9509 - val_loss: 0.1687 - val_Accuracy: 0.9335 - 9s/epoch - 46ms/step
Epoch 5/100
189/189 - 9s - loss: 0.1153 - Accuracy: 0.9573 - val_loss: 0.1067 - val_Accuracy: 0.9600 - 9s/epoch - 47ms/step
Epoch 6/100
189/189 - 9s - loss: 0.1059 - Accuracy: 0.9591 - val_loss: 0.0901 - val_Accuracy: 0.9668 - 9s/epoch - 46ms/step
Epoch 7/100
189/189 - 9s - loss: 0.0987 - Accuracy: 0.9616 - val_loss: 0.0857 - val_Accuracy: 0.9655 - 9s/epoch - 45ms/step
Epoch 8/100
189/189 - 8s - loss: 0.0947 - Accuracy: 0.9628 - val_loss: 0.0809 - val_Accuracy: 0.9691 - 8s/epoch - 43ms/step
Epoch 9/100
189/189 - 8s - loss: 0.0916 - Accuracy: 0.9640 - val_loss: 0.0770 - val_Accuracy: 0.9700 - 8s/epoch - 43ms/step
Epoch 10/100
189/189 - 8s - loss: 0.0861 - Accuracy: 0.9662 - val_loss: 0.0772 - val_Accuracy: 0.9690 - 8s/epoch - 44ms/step
Epoch 11/100
189/189 - 9s - loss: 0.0860 - Accuracy: 0.9661 - val_loss: 0.0752 - val_Accuracy: 0.9710 - 9s/epoch - 45ms/step
Epoch 12/100
189/189 - 9s - loss: 0.0831 - Accuracy: 0.9675 - val_loss: 0.0721 - val_Accuracy: 0.9715 - 9s/epoch - 46ms/step
Epoch 13/100
189/189 - 8s - loss: 0.0828 - Accuracy: 0.9660 - val_loss: 0.0730 - val_Accuracy: 0.9726 - 8s/epoch - 42ms/step
Epoch 14/100
189/189 - 8s - loss: 0.0808 - Accuracy: 0.9676 - val_loss: 0.0693 - val_Accuracy: 0.9728 - 8s/epoch - 42ms/step
Epoch 15/100
189/189 - 8s - loss: 0.0768 - Accuracy: 0.9692 - val_loss: 0.0695 - val_Accuracy: 0.9723 - 8s/epoch - 42ms/step
Epoch 16/100
189/189 - 9s - loss: 0.0771 - Accuracy: 0.9690 - val_loss: 0.0677 - val_Accuracy: 0.9721 - 9s/epoch - 47ms/step
Epoch 17/100
189/189 - 9s - loss: 0.0750 - Accuracy: 0.9695 - val_loss: 0.0671 - val_Accuracy: 0.9736 - 9s/epoch - 45ms/step
Epoch 18/100
189/189 - 9s - loss: 0.0745 - Accuracy: 0.9684 - val_loss: 0.0668 - val_Accuracy: 0.9730 - 9s/epoch - 45ms/step
Epoch 19/100
189/189 - 9s - loss: 0.0726 - Accuracy: 0.9697 - val_loss: 0.0656 - val_Accuracy: 0.9733 - 9s/epoch - 46ms/step
Epoch 20/100
189/189 - 8s - loss: 0.0725 - Accuracy: 0.9710 - val_loss: 0.0649 - val_Accuracy: 0.9736 - 8s/epoch - 43ms/step
Epoch 21/100
189/189 - 8s - loss: 0.0703 - Accuracy: 0.9710 - val_loss: 0.0651 - val_Accuracy: 0.9746 - 8s/epoch - 42ms/step
Epoch 22/100
189/189 - 9s - loss: 0.0709 - Accuracy: 0.9704 - val_loss: 0.0643 - val_Accuracy: 0.9741 - 9s/epoch - 45ms/step
Epoch 23/100
189/189 - 8s - loss: 0.0691 - Accuracy: 0.9719 - val_loss: 0.0661 - val_Accuracy: 0.9726 - 8s/epoch - 44ms/step
Epoch 24/100
189/189 - 8s - loss: 0.0692 - Accuracy: 0.9723 - val_loss: 0.0644 - val_Accuracy: 0.9731 - 8s/epoch - 45ms/step
Epoch 25/100
189/189 - 9s - loss: 0.0676 - Accuracy: 0.9720 - val_loss: 0.0625 - val_Accuracy: 0.9745 - 9s/epoch - 46ms/step
Epoch 26/100
189/189 - 9s - loss: 0.0653 - Accuracy: 0.9732 - val_loss: 0.0634 - val_Accuracy: 0.9731 - 9s/epoch - 47ms/step
Epoch 27/100
189/189 - 9s - loss: 0.0660 - Accuracy: 0.9724 - val_loss: 0.0638 - val_Accuracy: 0.9740 - 9s/epoch - 46ms/step
Epoch 28/100
189/189 - 9s - loss: 0.0668 - Accuracy: 0.9723 - val_loss: 0.0629 - val_Accuracy: 0.9740 - 9s/epoch - 47ms/step
Epoch 29/100
189/189 - 8s - loss: 0.0657 - Accuracy: 0.9728 - val_loss: 0.0653 - val_Accuracy: 0.9731 - 8s/epoch - 42ms/step
Epoch 30/100
189/189 - 9s - loss: 0.0643 - Accuracy: 0.9731 - val_loss: 0.0623 - val_Accuracy: 0.9731 - 9s/epoch - 45ms/step
Epoch 31/100
189/189 - 9s - loss: 0.0626 - Accuracy: 0.9735 - val_loss: 0.0648 - val_Accuracy: 0.9730 - 9s/epoch - 45ms/step
Epoch 32/100
189/189 - 9s - loss: 0.0633 - Accuracy: 0.9739 - val_loss: 0.0615 - val_Accuracy: 0.9736 - 9s/epoch - 46ms/step
Epoch 33/100
189/189 - 8s - loss: 0.0624 - Accuracy: 0.9739 - val_loss: 0.0632 - val_Accuracy: 0.9741 - 8s/epoch - 43ms/step
Epoch 34/100
189/189 - 8s - loss: 0.0618 - Accuracy: 0.9738 - val_loss: 0.0614 - val_Accuracy: 0.9753 - 8s/epoch - 42ms/step
Epoch 35/100
189/189 - 8s - loss: 0.0597 - Accuracy: 0.9752 - val_loss: 0.0619 - val_Accuracy: 0.9741 - 8s/epoch - 44ms/step
Epoch 36/100
189/189 - 9s - loss: 0.0611 - Accuracy: 0.9742 - val_loss: 0.0617 - val_Accuracy: 0.9740 - 9s/epoch - 47ms/step
Epoch 37/100
189/189 - 9s - loss: 0.0613 - Accuracy: 0.9747 - val_loss: 0.0598 - val_Accuracy: 0.9751 - 9s/epoch - 46ms/step
Epoch 38/100
189/189 - 9s - loss: 0.0613 - Accuracy: 0.9741 - val_loss: 0.0622 - val_Accuracy: 0.9753 - 9s/epoch - 46ms/step
Epoch 39/100
189/189 - 8s - loss: 0.0590 - Accuracy: 0.9754 - val_loss: 0.0603 - val_Accuracy: 0.9751 - 8s/epoch - 42ms/step
Epoch 40/100
189/189 - 9s - loss: 0.0602 - Accuracy: 0.9752 - val_loss: 0.0594 - val_Accuracy: 0.9754 - 9s/epoch - 46ms/step
Epoch 41/100
189/189 - 8s - loss: 0.0600 - Accuracy: 0.9747 - val_loss: 0.0623 - val_Accuracy: 0.9741 - 8s/epoch - 44ms/step
Epoch 42/100
189/189 - 8s - loss: 0.0607 - Accuracy: 0.9748 - val_loss: 0.0607 - val_Accuracy: 0.9758 - 8s/epoch - 45ms/step
Epoch 43/100
189/189 - 8s - loss: 0.0580 - Accuracy: 0.9756 - val_loss: 0.0597 - val_Accuracy: 0.9761 - 8s/epoch - 45ms/step
Epoch 44/100
189/189 - 8s - loss: 0.0582 - Accuracy: 0.9759 - val_loss: 0.0598 - val_Accuracy: 0.9758 - 8s/epoch - 43ms/step
Epoch 45/100
189/189 - 9s - loss: 0.0576 - Accuracy: 0.9760 - val_loss: 0.0601 - val_Accuracy: 0.9741 - 9s/epoch - 46ms/step
Epoch 46/100
189/189 - 8s - loss: 0.0565 - Accuracy: 0.9767 - val_loss: 0.0606 - val_Accuracy: 0.9743 - 8s/epoch - 45ms/step
Epoch 47/100
189/189 - 8s - loss: 0.0569 - Accuracy: 0.9765 - val_loss: 0.0598 - val_Accuracy: 0.9754 - 8s/epoch - 45ms/step
Epoch 48/100
189/189 - 8s - loss: 0.0553 - Accuracy: 0.9771 - val_loss: 0.0587 - val_Accuracy: 0.9748 - 8s/epoch - 43ms/step
Epoch 49/100
189/189 - 9s - loss: 0.0571 - Accuracy: 0.9763 - val_loss: 0.0626 - val_Accuracy: 0.9736 - 9s/epoch - 47ms/step
Epoch 50/100
189/189 - 9s - loss: 0.0571 - Accuracy: 0.9760 - val_loss: 0.0600 - val_Accuracy: 0.9748 - 9s/epoch - 46ms/step
Epoch 51/100
189/189 - 9s - loss: 0.0571 - Accuracy: 0.9764 - val_loss: 0.0584 - val_Accuracy: 0.9758 - 9s/epoch - 46ms/step
Epoch 52/100
189/189 - 9s - loss: 0.0555 - Accuracy: 0.9765 - val_loss: 0.0582 - val_Accuracy: 0.9764 - 9s/epoch - 46ms/step
Epoch 53/100
189/189 - 8s - loss: 0.0549 - Accuracy: 0.9776 - val_loss: 0.0588 - val_Accuracy: 0.9764 - 8s/epoch - 41ms/step
Epoch 54/100
189/189 - 8s - loss: 0.0557 - Accuracy: 0.9767 - val_loss: 0.0576 - val_Accuracy: 0.9763 - 8s/epoch - 44ms/step
Epoch 55/100
189/189 - 8s - loss: 0.0533 - Accuracy: 0.9779 - val_loss: 0.0579 - val_Accuracy: 0.9763 - 8s/epoch - 43ms/step
Epoch 56/100
189/189 - 8s - loss: 0.0540 - Accuracy: 0.9767 - val_loss: 0.0573 - val_Accuracy: 0.9769 - 8s/epoch - 45ms/step
Epoch 57/100
189/189 - 9s - loss: 0.0542 - Accuracy: 0.9779 - val_loss: 0.0607 - val_Accuracy: 0.9738 - 9s/epoch - 47ms/step
Epoch 58/100
189/189 - 8s - loss: 0.0547 - Accuracy: 0.9771 - val_loss: 0.0580 - val_Accuracy: 0.9761 - 8s/epoch - 44ms/step
Epoch 59/100
189/189 - 9s - loss: 0.0543 - Accuracy: 0.9764 - val_loss: 0.0574 - val_Accuracy: 0.9769 - 9s/epoch - 45ms/step
Epoch 60/100
189/189 - 8s - loss: 0.0544 - Accuracy: 0.9771 - val_loss: 0.0569 - val_Accuracy: 0.9756 - 8s/epoch - 44ms/step
Epoch 61/100
189/189 - 9s - loss: 0.0520 - Accuracy: 0.9776 - val_loss: 0.0595 - val_Accuracy: 0.9758 - 9s/epoch - 46ms/step
Epoch 62/100
189/189 - 8s - loss: 0.0515 - Accuracy: 0.9788 - val_loss: 0.0600 - val_Accuracy: 0.9761 - 8s/epoch - 44ms/step
Epoch 63/100
189/189 - 9s - loss: 0.0530 - Accuracy: 0.9780 - val_loss: 0.0563 - val_Accuracy: 0.9768 - 9s/epoch - 47ms/step
Epoch 64/100
189/189 - 9s - loss: 0.0527 - Accuracy: 0.9776 - val_loss: 0.0582 - val_Accuracy: 0.9746 - 9s/epoch - 47ms/step
Epoch 65/100
189/189 - 9s - loss: 0.0504 - Accuracy: 0.9789 - val_loss: 0.0561 - val_Accuracy: 0.9763 - 9s/epoch - 46ms/step
Epoch 66/100
189/189 - 9s - loss: 0.0514 - Accuracy: 0.9780 - val_loss: 0.0560 - val_Accuracy: 0.9758 - 9s/epoch - 47ms/step
Epoch 67/100
189/189 - 9s - loss: 0.0515 - Accuracy: 0.9784 - val_loss: 0.0615 - val_Accuracy: 0.9745 - 9s/epoch - 46ms/step
Epoch 68/100
189/189 - 9s - loss: 0.0520 - Accuracy: 0.9784 - val_loss: 0.0571 - val_Accuracy: 0.9764 - 9s/epoch - 45ms/step
Epoch 69/100
189/189 - 9s - loss: 0.0516 - Accuracy: 0.9788 - val_loss: 0.0567 - val_Accuracy: 0.9761 - 9s/epoch - 46ms/step
Epoch 70/100
189/189 - 9s - loss: 0.0491 - Accuracy: 0.9795 - val_loss: 0.0584 - val_Accuracy: 0.9768 - 9s/epoch - 47ms/step
Epoch 71/100
189/189 - 9s - loss: 0.0501 - Accuracy: 0.9786 - val_loss: 0.0595 - val_Accuracy: 0.9753 - 9s/epoch - 45ms/step
Epoch 72/100
189/189 - 8s - loss: 0.0511 - Accuracy: 0.9784 - val_loss: 0.0560 - val_Accuracy: 0.9758 - 8s/epoch - 43ms/step
Epoch 73/100
189/189 - 9s - loss: 0.0518 - Accuracy: 0.9779 - val_loss: 0.0634 - val_Accuracy: 0.9741 - 9s/epoch - 46ms/step
Epoch 74/100
189/189 - 8s - loss: 0.0502 - Accuracy: 0.9779 - val_loss: 0.0546 - val_Accuracy: 0.9776 - 8s/epoch - 43ms/step
Epoch 75/100
189/189 - 8s - loss: 0.0489 - Accuracy: 0.9794 - val_loss: 0.0589 - val_Accuracy: 0.9748 - 8s/epoch - 43ms/step
Epoch 76/100
189/189 - 9s - loss: 0.0492 - Accuracy: 0.9792 - val_loss: 0.0555 - val_Accuracy: 0.9764 - 9s/epoch - 46ms/step
Epoch 77/100
189/189 - 8s - loss: 0.0501 - Accuracy: 0.9801 - val_loss: 0.0589 - val_Accuracy: 0.9750 - 8s/epoch - 45ms/step
Epoch 78/100
189/189 - 8s - loss: 0.0485 - Accuracy: 0.9791 - val_loss: 0.0582 - val_Accuracy: 0.9746 - 8s/epoch - 43ms/step
Epoch 79/100
189/189 - 8s - loss: 0.0496 - Accuracy: 0.9791 - val_loss: 0.0541 - val_Accuracy: 0.9766 - 8s/epoch - 44ms/step
Epoch 80/100
189/189 - 8s - loss: 0.0495 - Accuracy: 0.9793 - val_loss: 0.0541 - val_Accuracy: 0.9776 - 8s/epoch - 43ms/step
Epoch 81/100
189/189 - 9s - loss: 0.0466 - Accuracy: 0.9811 - val_loss: 0.0616 - val_Accuracy: 0.9733 - 9s/epoch - 46ms/step
Epoch 82/100
189/189 - 8s - loss: 0.0483 - Accuracy: 0.9809 - val_loss: 0.0571 - val_Accuracy: 0.9753 - 8s/epoch - 45ms/step
Epoch 83/100
189/189 - 8s - loss: 0.0471 - Accuracy: 0.9803 - val_loss: 0.0600 - val_Accuracy: 0.9740 - 8s/epoch - 43ms/step
Epoch 84/100
189/189 - 8s - loss: 0.0475 - Accuracy: 0.9802 - val_loss: 0.0561 - val_Accuracy: 0.9756 - 8s/epoch - 43ms/step
Epoch 85/100
189/189 - 8s - loss: 0.0471 - Accuracy: 0.9803 - val_loss: 0.0560 - val_Accuracy: 0.9761 - 8s/epoch - 44ms/step
Epoch 86/100
189/189 - 9s - loss: 0.0471 - Accuracy: 0.9806 - val_loss: 0.0576 - val_Accuracy: 0.9756 - 9s/epoch - 47ms/step
Epoch 87/100
189/189 - 8s - loss: 0.0470 - Accuracy: 0.9803 - val_loss: 0.0603 - val_Accuracy: 0.9753 - 8s/epoch - 44ms/step
Epoch 88/100
189/189 - 8s - loss: 0.0463 - Accuracy: 0.9807 - val_loss: 0.0572 - val_Accuracy: 0.9763 - 8s/epoch - 44ms/step
----
train shape: (24112, 600, 1)
score on val: 0.9776045083999634
score on train: 0.9839913845062256
----
training time (s): 751.2502
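The early-stopping behavior configured above (stop after `patience` epochs without improvement and restore the weights from the best epoch) amounts to logic like the following. This is a simplified sketch for illustration, not the actual Keras implementation:

```python
def early_stop_epoch(val_losses, patience):
    """Return the index of the epoch whose weights would be restored:
    the epoch with the lowest validation loss, with training halting
    after `patience` epochs without improvement."""
    best_epoch, best_loss, wait = 0, float('inf'), 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss, wait = epoch, loss, 0
        else:
            wait += 1
            if wait >= patience:
                break
    return best_epoch

# Made-up loss curve: improvement stalls after epoch 5.
losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.34, 0.38, 0.39, 0.40, 0.41]
print(early_stop_epoch(losses, patience=3))  # → 5 (lowest val_loss, 0.34)
```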

Plot training and evaluate the model#

We can plot the training and validation accuracy and loss to visualize the training progress.

acc = history.history['Accuracy']
val_acc = history.history['val_Accuracy']
epochs = range(1, len(acc) + 1)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
axs[0].plot(epochs, acc, '#12B5CB', label='Training')
axs[0].plot(epochs, val_acc, '#425066', label='Validation')
axs[0].set_title('Training and validation accuracy\n')
axs[0].set_xlabel('Epochs')
axs[0].set_ylabel('Accuracy')
axs[0].legend()

loss = history.history['loss']
val_loss = history.history['val_loss']

axs[1].plot(epochs, loss, '#12B5CB', label='Training')
axs[1].plot(epochs, val_loss, '#425066', label='Validation')
axs[1].set_title('Training and validation loss\n')
axs[1].set_xlabel('Epochs')
axs[1].set_ylabel('Loss')
axs[1].legend()
plt.tight_layout()
plt.show()

best_epoch = val_acc.index(max(val_acc)) + 1
print(f'Best epoch: {best_epoch} (accuracy={max(val_acc):.4f})')
../_images/3ece5bcb963e7f0d4b831593489fe2b8f324ead804a450e03f97a8efe50952f0.png
Best epoch: 74 (accuracy=0.9776)

We can plot the ROC curve and the confusion matrix to further evaluate our trained model.

y_preds = model.predict(x_test, verbose=0).ravel()
fpr, tpr, thresholds = roc_curve(y_test, y_preds)

fig, axs = plt.subplots(1, 2, figsize=(8, 4))
axs[0].plot([0, 1], [0, 1], 'y--')
axs[0].plot(fpr, tpr, marker='.')
axs[0].set_xlabel('False positive rate')
axs[0].set_ylabel('True positive rate')
axs[0].set_title('ROC curve\n')

optimal_threshold = thresholds[np.argmax(tpr - fpr)]
y_pred2 = (model.predict(x_test, verbose=0) >= optimal_threshold).astype(int)
cm = confusion_matrix(y_test, y_pred2)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues', ax=axs[1])
axs[1].set_title('Confusion matrix\n')
plt.tight_layout()
plt.show()

print(f'Area under curve, AUC = {auc(fpr, tpr)}')
print(f'Optimal threshold value = {optimal_threshold}')
print(f'Matthews correlation coefficient = {matthews_corrcoef(y_test, y_preds>0.5)}')
../_images/ee735e8aea3cf1a53f6d18991c31d1f75d07ba89d727295346662f688471f051.png
Area under curve, AUC = 0.9982805833631123
Optimal threshold value = 0.45347926020622253
Matthews correlation coefficient = 0.9551876774883516
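The Matthews correlation coefficient combines all four confusion-matrix entries into a single score in [-1, 1] and is robust to class imbalance. It can be computed directly from the counts; the second set of numbers below is made up for illustration, not the tutorial's actual confusion matrix:

```python
import math

def mcc(tn, fp, fn, tp):
    """Matthews correlation coefficient from confusion-matrix counts."""
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return (tp * tn - fp * fn) / denom if denom else 0.0

# A perfect classifier gives MCC = 1.0.
print(mcc(tn=100, fp=0, fn=0, tp=100))  # → 1.0

# Hypothetical counts for a strong but imperfect classifier.
print(round(mcc(tn=2900, fp=70, fn=65, tp=2993), 3))  # close to 1
```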

Save trained model and training settings#

Finally, we save the trained model to an .h5 file and the training settings to a text file.

model.save('../results/lstm_full_training.h5')

with open('../results/full_training_settings.txt', 'w') as text_file:
    text_file.write('\n'.join(f'{key}: {value}' for key, value in settings.items()))
/opt/miniconda3/envs/miniml/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`.
  saving_api.save_model(