Source: R/evaluate_trainedmodel_performance_multi.R (documentation: evaluate_trainedmodel_performance_multi.Rd)
Evaluates the performance of trained multi-class CNN models on a set of spectrogram images, reporting per-class performance metrics at a range of score thresholds.
evaluate_trainedmodel_performance_multi(
trained_models_dir,
image_data_dir,
output_dir = "data/",
class_names,
noise.category = "noise",
unfreeze = TRUE
)
trained_models_dir: Path to the directory containing the trained models (.pt files).
image_data_dir: Path to the directory containing the image data used for evaluation.
output_dir: Path to the directory where the performance scores will be saved.
class_names: Character vector of class names, as specified by the user; these should match the class folders used for the training data.
noise.category: Label of the noise class. Default is 'noise'.
unfreeze: Logical indicating whether model parameters are unfrozen; set this to match the setting used when the model was trained.
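Because class_names is taken from the training data folders, it can be convenient to read the names from disk rather than typing them. The sketch below assumes a directory with one subfolder per class; class_names_from_folders() is a hypothetical helper, not part of gibbonNetR, and the alphabetical sort is an assumption. Check that the resulting order matches the one used for training.

```r
# Hypothetical helper (not part of gibbonNetR): derive class names from the
# immediate subfolders of an image directory laid out as one folder per class.
class_names_from_folders <- function(image_dir) {
  # recursive = FALSE lists only direct subfolders and excludes image_dir itself;
  # sorting alphabetically is an assumption about the expected class order
  sort(list.dirs(image_dir, full.names = FALSE, recursive = FALSE))
}
```

For example, class_names_from_folders("path/to/training/images") would return the subfolder names, which can then be passed as class_names.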
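Value: .csv files summarizing performance are written to output_dir. As in the example output below, each file contains one row per class and threshold, with columns such as Sensitivity, Specificity, Precision, Recall, F1, AUC and Top1Accuracy.
Details: Intended for models trained with train_CNN_multi() and a test folder created with spectrogram_images(). Classes for which performance cannot be calculated are skipped with a message.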
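The output columns suggest that each class is scored one-vs-rest at a series of thresholds. The sketch below shows the standard formulas for those metrics for a single class and threshold, plus top-1 accuracy across all classes, in base R. The helpers one_vs_rest_metrics() and top1_accuracy() are hypothetical: they reuse the CSV column names for readability but are an illustration, not gibbonNetR's implementation. Treating a score at or above the threshold as a positive prediction is also an assumption.

```r
# Hypothetical helper (not part of gibbonNetR): one-vs-rest metrics for one
# class at one threshold. scores are predicted probabilities for the class;
# is_target is TRUE where the image truly belongs to the class.
one_vs_rest_metrics <- function(scores, is_target, threshold) {
  predicted <- scores >= threshold  # assumption: ">=" counts as positive
  tp <- sum(predicted & is_target)
  fp <- sum(predicted & !is_target)
  fn <- sum(!predicted & is_target)
  tn <- sum(!predicted & !is_target)
  n <- length(scores)
  # Ratios with a zero denominator are undefined and returned as NA
  safe_div <- function(a, b) if (b == 0) NA_real_ else a / b
  sensitivity <- safe_div(tp, tp + fn)
  specificity <- safe_div(tn, tn + fp)
  precision <- safe_div(tp, tp + fp)
  # F1 is the harmonic mean of precision and recall (sensitivity)
  f1 <- if (is.na(precision) || is.na(sensitivity) ||
            precision + sensitivity == 0) {
    NA_real_
  } else {
    2 * precision * sensitivity / (precision + sensitivity)
  }
  data.frame(
    Sensitivity = sensitivity,
    Specificity = specificity,
    Pos.Pred.Value = precision,
    Neg.Pred.Value = safe_div(tn, tn + fn),
    Precision = precision,
    Recall = sensitivity,
    F1 = f1,
    Prevalence = (tp + fn) / n,
    Detection.Rate = tp / n,
    Detection.Prevalence = (tp + fp) / n,
    Balanced.Accuracy = (sensitivity + specificity) / 2,
    Threshold = threshold
  )
}

# Hypothetical helper: share of images whose highest-scoring class is correct.
# prob_matrix has one row per image and one column per class (named).
top1_accuracy <- function(prob_matrix, true_class) {
  predicted <- colnames(prob_matrix)[max.col(prob_matrix, ties.method = "first")]
  mean(predicted == true_class)
}
```

For example, with scores c(0.9, 0.8, 0.7, 0.6), true labels c(TRUE, TRUE, FALSE, FALSE) and a threshold of 0.1, every image is predicted positive, giving Sensitivity 1, Specificity 0, Precision 0.5, an undefined Neg.Pred.Value and F1 = 2/3. This is the same pattern as the first row of the example output.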
{
# Load the package
library(gibbonNetR)
# Set directory paths for trained models and test images
trained_models_dir <- system.file("extdata", "trainedresnetmulti", package = "gibbonNetR")
image_data_dir <- system.file("extdata", "multiclass", "test", package = "gibbonNetR")
class_names <- c("female.gibbon", "hornbill.helmeted", "hornbill.rhino", "long.argus", "noise")
# Evaluate the performance of the trained models using the test images
evaluate_trainedmodel_performance_multi(
trained_models_dir = trained_models_dir,
class_names = class_names,
image_data_dir = image_data_dir,
output_dir = file.path(tempdir(), "data/"),
noise.category = "noise"
)
# Find the location of saved evaluation files
CSVName <- list.files(file.path(tempdir(), "data"), recursive = TRUE, full.names = TRUE)
# Check the output of the first file
head(read.csv(CSVName[1]))
}
#> Evaluating performance of 3_resnet18_model N epochs= multi
#> Warning: Some torch operators might not yet be implemented for the MPS device. A
#> temporary fix is to set the `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a
#> fall back for those operators:
#> ℹ Add `PYTORCH_ENABLE_MPS_FALLBACK=1` to your `.Renviron` file, for example use
#> `usethis::edit_r_environ()`.
#> ✖ Using `Sys.setenv()` doesn't work because the env var must be set before R
#> starts.
#> Skipping hornbill.helmeted - cannot calculate performance
#> Skipping long.argus - cannot calculate performance
#> Sensitivity Specificity Pos.Pred.Value Neg.Pred.Value Precision Recall
#> 1 1 0 0.5 NA 0.5 1
#> 2 1 0 0.5 NA 0.5 1
#> 3 0 1 NA 0.5 NA 0
#> 4 0 1 NA 0.5 NA 0
#> 5 0 1 NA 0.5 NA 0
#> 6 0 1 NA 0.5 NA 0
#> F1 Prevalence Detection.Rate Detection.Prevalence Balanced.Accuracy
#> 1 0.6666667 0.5 0.5 1 0.5
#> 2 0.6666667 0.5 0.5 1 0.5
#> 3 NA 0.5 0.0 0 0.5
#> 4 NA 0.5 0.0 0 0.5
#> 5 NA 0.5 0.0 0 0.5
#> 6 NA 0.5 0.0 0 0.5
#> Training.Data N.epochs CNN.Architecture AUC Threshold Class
#> 1 test multi 3_resnet18_model 0.57 0.1 female.gibbon
#> 2 test multi 3_resnet18_model 0.57 0.2 female.gibbon
#> 3 test multi 3_resnet18_model 0.57 0.3 female.gibbon
#> 4 test multi 3_resnet18_model 0.57 0.4 female.gibbon
#> 5 test multi 3_resnet18_model 0.57 0.5 female.gibbon
#> 6 test multi 3_resnet18_model 0.57 0.6 female.gibbon
#> TestDataPath
#> 1 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 2 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 3 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 4 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 5 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 6 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> Top1Accuracy
#> 1 0.125
#> 2 0.125
#> 3 0.125
#> 4 0.125
#> 5 0.125
#> 6 0.125
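Each CSV contains one row per class and threshold; the first six rows above cover thresholds 0.1 to 0.6 for female.gibbon. A common follow-up is to pick, for each class, the threshold with the best F1 score. The sketch below does this in base R using only columns that appear in the output above; best_threshold_by_class() is a hypothetical helper, not part of gibbonNetR.

```r
# Hypothetical helper (not part of gibbonNetR): for each class, keep the row
# with the highest F1. Rows with NA F1 (e.g. no positive predictions) are dropped.
best_threshold_by_class <- function(perf) {
  keep <- c("Class", "CNN.Architecture", "Threshold", "Precision", "Recall", "F1")
  perf <- perf[!is.na(perf$F1), , drop = FALSE]
  if (nrow(perf) == 0) return(perf[, keep, drop = FALSE])
  # split() groups rows by class; which.max() returns the first maximum
  rows <- lapply(split(perf, perf$Class),
                 function(d) d[which.max(d$F1), , drop = FALSE])
  out <- do.call(rbind, rows)
  rownames(out) <- NULL
  out[, keep, drop = FALSE]
}
```

With the example above, best_threshold_by_class(read.csv(CSVName[1])) would return one row per class that has a defined F1. Selecting on F1 is only one option; Balanced.Accuracy could be used instead by changing the column passed to which.max().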