This function trains convolutional neural network (CNN) models on a given dataset using transfer learning with one of several architectures (AlexNet, VGG16, VGG19, ResNet18, ResNet50, or ResNet152). The trained model is saved along with metadata for later use.
train_CNN_multi(
input.data.path,
test.data,
architecture,
unfreeze.param = TRUE,
batch_size = 32,
learning_rate,
save.model = FALSE,
class_weights = c(0.49, 0.49, 0.02),
epoch.iterations = 1,
early.stop = "yes",
output.base.path = tempdir(),
brightness = 0,
contrast = 0,
saturation = 0,
trainingfolder,
noise.category = "Noise"
)
input.data.path: Character. Path to the input data folder.
test.data: Character. Path to the test data folder.
architecture: Character. Specifies the CNN architecture to use ('alexnet', 'vgg16', 'vgg19', 'resnet18', 'resnet50', or 'resnet152').
unfreeze.param: Logical. Indicates whether all layers of the pretrained CNN should be unfrozen for retraining. Default is TRUE.
batch_size: Numeric. Batch size for training the model. Default is 32.
learning_rate: Numeric. Learning rate for training the model.
save.model: Logical. Specifies whether to save the trained model for future use. Default is FALSE.
class_weights: Numeric vector. Weights assigned to the classes to handle class imbalance (see the sketch after this list for one way to derive them from the training data). Default is c(0.49, 0.49, 0.02).
epoch.iterations: List of integers. Number of epochs for training the model. Default is 1.
early.stop: Character. Indicates whether early stopping should be applied. Use "yes" to apply and "no" to skip. Default is "yes".
output.base.path: Character. Base path where the output files should be saved. Default is tempdir().
brightness: Numeric. Brightness adjustment factor for color jitter. A value of 0 means no change; higher values increase brightness. Default is 0.
contrast: Numeric. Contrast adjustment factor for color jitter. A value of 0 means no change; higher values increase contrast. Default is 0.
saturation: Numeric. Saturation adjustment factor for color jitter. A value of 0 means no change; higher values increase color saturation. Default is 0.
trainingfolder: Character. A descriptive name for the training data, used for naming output files.
noise.category: Character. Label for the noise category. Default is "Noise".
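The default class_weights assume two call classes and one dominant noise class. As a minimal sketch of one way to derive weights for a different dataset (assuming the training images sit in one subfolder per class under input.data.path/train, as produced by 'spectrogram_images'; folder names are illustrative), inverse-frequency weights can be computed from the per-class image counts:

# Minimal sketch: inverse-frequency class weights from per-class image counts.
# Assumes one subfolder per class under <input.data.path>/train.
train.dir <- file.path(input.data.path, "train")
class.dirs <- list.dirs(train.dir, recursive = FALSE)
counts <- sapply(class.dirs, function(d) length(list.files(d)))
names(counts) <- basename(class.dirs)
class_weights <- (1 / counts) / sum(1 / counts)  # weights sum to 1
round(class_weights, 3)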
The function generates multiple output files, including:
Trained Models: if save.model = TRUE, saved model files (.pt) for each specified architecture.
Training Logs: logs_model.csv, containing logs of the training sessions, including loss and accuracy metrics.
Metadata: model_metadata.csv, containing metadata from the training run.
Model Predictions: predictions on the test data, saved for each architecture in output_TrainedModel_testdata.csv.
Performance Evaluation: a .csv summarizing performance for each architecture, saved in the 'performance_tables_multi' subfolder of output.base.path (a sketch for reading these files follows this list).
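A minimal sketch of inspecting these outputs after a run (assuming output.base.path holds the same value that was passed to train_CNN_multi; file names follow the conventions listed above):

# Minimal sketch: read the training logs and performance tables from a run.
log.file <- list.files(output.base.path, pattern = "logs_model.csv",
  recursive = TRUE, full.names = TRUE)[1]
logs <- read.csv(log.file)
head(logs)

perf.files <- list.files(output.base.path, pattern = "\\.csv$",
  recursive = TRUE, full.names = TRUE)
perf.files <- grep("performance_tables_multi", perf.files, value = TRUE)
performance <- do.call(rbind, lapply(perf.files, read.csv))
head(performance)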
Requires train, valid, and test folders created using 'spectrogram_images'; the expected layout is sketched below.
Uses torch's nn_module and other torch functions.
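A minimal sketch of the expected folder layout (the class subfolder names shown are illustrative; the actual class names are taken from the subfolders created by 'spectrogram_images'):

# Expected layout under input.data.path (class folder names are illustrative):
# input.data.path/
#   train/  class_1/ ... noise/
#   valid/  class_1/ ... noise/
#   test/   class_1/ ... noise/

# Quick check that the three splits are present before training
stopifnot(all(c("train", "valid", "test") %in%
  basename(list.dirs(input.data.path, recursive = FALSE))))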
input.data.path <- system.file("extdata", "multiclass/", package = "gibbonNetR")
test.data <- system.file("extdata", "multiclass/test/", package = "gibbonNetR")
result <- train_CNN_multi(
input.data.path = input.data.path,
test.data = test.data,
architecture = "alexnet", # Choose architecture
unfreeze.param = TRUE,
class_weights = rep((1 / 5), 5),
batch_size = 6,
learning_rate = 0.001,
epoch.iterations = 1, # Or any other list of integer epochs
early.stop = "yes",
output.base.path = paste(tempdir(), "/", sep = ""),
trainingfolder = "test",
noise.category = "noise"
)
print(result)
#> Training alexnet
#> Detected classes: female.gibbon, hornbill.helmeted, hornbill.rhino, long.argus, noise
#> Warning: Some torch operators might not yet be implemented for the MPS device. A
#> temporary fix is to set the `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a
#> fall back for those operators:
#> ℹ Add `PYTORCH_ENABLE_MPS_FALLBACK=1` to your `.Renviron` file, for example use
#> `usethis::edit_r_environ()`.
#> ✖ Using `Sys.setenv()` doesn't work because the env var must be set before R
#> starts.
#> Epoch 1/1
#> Train metrics: Loss: 1.5896 - Acc: 0.1458
#> Valid metrics: Loss: 1.5147 - Acc: 0.5952
#> `geom_line()`: Each group consists of only one observation.
#> ℹ Do you need to adjust the group aesthetic?
#> Warning: Some torch operators might not yet be implemented for the MPS device. A
#> temporary fix is to set the `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a
#> fall back for those operators:
#> ℹ Add `PYTORCH_ENABLE_MPS_FALLBACK=1` to your `.Renviron` file, for example use
#> `usethis::edit_r_environ()`.
#> ✖ Using `Sys.setenv()` doesn't work because the env var must be set before R
#> starts.
#> Skipping hornbill.helmeted - cannot calculate performance
#> Skipping long.argus - cannot calculate performance
#> /var/folders/1s/x8xb37tj45j86tn_stc4v44w0000gn/T//RtmpScZf01/_test_multi_unfrozen_TRUE_/performance_tables_multi/
#> NULL
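To compare several of the supported architectures on the same data, the call above can be repeated over a vector of architecture names; a minimal sketch (parameter values mirror the example above and are illustrative, not tuned):

# Minimal sketch: train each architecture in turn so the resulting
# performance tables in 'performance_tables_multi' can be compared.
architectures <- c("alexnet", "resnet18", "resnet50")
for (arch in architectures) {
  train_CNN_multi(
    input.data.path  = input.data.path,
    test.data        = test.data,
    architecture     = arch,
    unfreeze.param   = TRUE,
    class_weights    = rep(1 / 5, 5),
    batch_size       = 6,
    learning_rate    = 0.001,
    epoch.iterations = 1,
    early.stop       = "yes",
    output.base.path = paste(tempdir(), "/", sep = ""),
    trainingfolder   = "test",
    noise.category   = "noise"
  )
}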