Source: R/evaluate_trainedmodel_performance.R
evaluate_trainedmodel_performance.Rd
Given a directory of trained models and a set of test images, this function evaluates the classification performance of each model.
evaluate_trainedmodel_performance(
trained_models_dir,
image_data_dir,
output_dir = "data/",
positive.class = "Gibbons",
negative.class = "Noise"
)
trained_models_dir: Path to the directory containing trained models (.pt files).
image_data_dir: Path to the directory containing image data for evaluation.
output_dir: Path to the directory where the performance scores will be saved.
positive.class: Label for the positive class.
negative.class: Label for the negative class.
CSV files containing a summary of model performance are written to output_dir.
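When several summary files accumulate in output_dir, it can help to stack them and pick an operating threshold per model. The sketch below is a minimal base-R example, assuming each CSV has the CNN.Architecture, Threshold and F1 columns seen in the example output further down. The helpers collect_performance() and best_by_f1() are hypothetical and not part of gibbonNetR, and the demo files contain toy values only.

```r
# Hypothetical helpers (not part of gibbonNetR).
# Stack every performance CSV found under output_dir, keeping only the
# columns that all files share, and record the source file of each row.
collect_performance <- function(output_dir) {
  files <- list.files(output_dir, pattern = "\\.csv$",
                      recursive = TRUE, full.names = TRUE)
  if (length(files) == 0) stop("No .csv files found in ", output_dir)
  tables <- lapply(files, function(f) {
    df <- read.csv(f)
    df$SourceFile <- basename(f)
    df
  })
  common <- Reduce(intersect, lapply(tables, names))
  do.call(rbind, lapply(tables, function(df) df[common]))
}

# For each CNN.Architecture, keep the row (threshold) with the highest F1.
# Rows with NA F1 (e.g. when nothing is predicted positive) are dropped first.
best_by_f1 <- function(perf) {
  perf <- perf[!is.na(perf$F1), , drop = FALSE]
  groups <- split(perf, perf$CNN.Architecture)
  do.call(rbind, lapply(groups, function(g) g[which.max(g$F1), , drop = FALSE]))
}

# Toy input for illustration only: two fake summary files.
demo_dir <- file.path(tempdir(), "perf_demo")
dir.create(demo_dir, showWarnings = FALSE)
write.csv(data.frame(CNN.Architecture = "modelA", Threshold = c(0.1, 0.5, 0.9),
                     F1 = c(0.60, 0.80, NA)),
          file.path(demo_dir, "a.csv"), row.names = FALSE)
write.csv(data.frame(CNN.Architecture = "modelB", Threshold = c(0.1, 0.5, 0.9),
                     F1 = c(0.70, 0.65, 0.50)),
          file.path(demo_dir, "b.csv"), row.names = FALSE)
best <- best_by_f1(collect_performance(demo_dir))
print(best[, c("CNN.Architecture", "Threshold", "F1")])
```

Keeping only shared columns lets binary and multiclass summaries (which may differ in columns such as Class or Top1Accuracy) be stacked without an rbind error.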
Takes the directory of models trained with 'train_CNN_binary' and a test folder created using 'spectrogram_images'.
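Because the function relies on the outputs of those two steps, a quick pre-flight check can catch path mistakes before any model is loaded. The sketch below is a minimal example, not part of gibbonNetR: it assumes the models are stored as .pt files (as described for trained_models_dir) and that the test folder holds one sub-folder per class, named after positive.class and negative.class. That folder layout and the helper name check_evaluation_inputs() are assumptions.

```r
# Hypothetical pre-flight check (not part of gibbonNetR).
# Assumes models are saved as .pt files and that image_data_dir contains
# one sub-folder per class (assumed layout).
check_evaluation_inputs <- function(trained_models_dir, image_data_dir,
                                    positive.class = "Gibbons",
                                    negative.class = "Noise") {
  models <- list.files(trained_models_dir, pattern = "\\.pt$", full.names = TRUE)
  if (length(models) == 0) {
    stop("No .pt model files found in ", trained_models_dir)
  }
  class_dirs <- basename(list.dirs(image_data_dir, recursive = FALSE))
  missing <- setdiff(c(positive.class, negative.class), class_dirs)
  if (length(missing) > 0) {
    stop("Missing class folder(s) in ", image_data_dir, ": ",
         paste(missing, collapse = ", "))
  }
  invisible(models)  # return the model paths for later use
}

# Demo on a throwaway folder structure with an empty placeholder model file.
demo_root <- file.path(tempdir(), "check_demo")
dir.create(file.path(demo_root, "models"), recursive = TRUE, showWarnings = FALSE)
dir.create(file.path(demo_root, "test", "Gibbons"), recursive = TRUE, showWarnings = FALSE)
dir.create(file.path(demo_root, "test", "Noise"), recursive = TRUE, showWarnings = FALSE)
invisible(file.create(file.path(demo_root, "models", "demo_model.pt")))
models <- check_evaluation_inputs(file.path(demo_root, "models"),
                                  file.path(demo_root, "test"))
```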
{
# Set directory paths for trained models and test images
trained_models_dir <- system.file("extdata", "trainedresnetbinary", package = "gibbonNetR")
image_data_dir <- system.file("extdata", "binary", "test", package = "gibbonNetR")
# Evaluate the performance of the trained models using the test images
evaluate_trainedmodel_performance(
trained_models_dir = trained_models_dir,
image_data_dir = image_data_dir,
output_dir = file.path(tempdir(), "binary_eval/"), # dedicated folder, so files from other runs are not read
positive.class = "Gibbons", # Label for positive class
negative.class = "Noise" # Label for negative class
)
# Find the location of saved evaluation files (only .csv files written by this run)
CSVName <- list.files(file.path(tempdir(), "binary_eval"), pattern = "\\.csv$", recursive = TRUE, full.names = TRUE)
# Check the output of the first file
head(read.csv(CSVName[1]))
}
#> Warning: Some torch operators might not yet be implemented for the MPS device. A
#> temporary fix is to set the `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a
#> fall back for those operators:
#> ℹ Add `PYTORCH_ENABLE_MPS_FALLBACK=1` to your `.Renviron` file, for example use
#> `usethis::edit_r_environ()`.
#> ✖ Using `Sys.setenv()` doesn't work because the env var must be set before R
#> starts.
#> Sensitivity Specificity Pos.Pred.Value Neg.Pred.Value Precision Recall
#> 1 1 0 0.5 NA 0.5 1
#> 2 1 0 0.5 NA 0.5 1
#> 3 0 1 NA 0.5 NA 0
#> 4 0 1 NA 0.5 NA 0
#> 5 0 1 NA 0.5 NA 0
#> 6 0 1 NA 0.5 NA 0
#> F1 Prevalence Detection.Rate Detection.Prevalence Balanced.Accuracy
#> 1 0.6666667 0.5 0.5 1 0.5
#> 2 0.6666667 0.5 0.5 1 0.5
#> 3 NA 0.5 0.0 0 0.5
#> 4 NA 0.5 0.0 0 0.5
#> 5 NA 0.5 0.0 0 0.5
#> 6 NA 0.5 0.0 0 0.5
#> Training.Data N.epochs CNN.Architecture AUC Threshold Class
#> 1 test multi 3_resnet18_model 0.57 0.1 female.gibbon
#> 2 test multi 3_resnet18_model 0.57 0.2 female.gibbon
#> 3 test multi 3_resnet18_model 0.57 0.3 female.gibbon
#> 4 test multi 3_resnet18_model 0.57 0.4 female.gibbon
#> 5 test multi 3_resnet18_model 0.57 0.5 female.gibbon
#> 6 test multi 3_resnet18_model 0.57 0.6 female.gibbon
#> TestDataPath
#> 1 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 2 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 3 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 4 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 5 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 6 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> Top1Accuracy
#> 1 0.125
#> 2 0.125
#> 3 0.125
#> 4 0.125
#> 5 0.125
#> 6 0.125
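Each column from Sensitivity to Balanced.Accuracy follows from the confusion matrix obtained at one score threshold; the names correspond to the statistics reported in caret's confusionMatrix()$byClass. The base-R sketch below recomputes them, plus a rank-based AUC, to show where the NA values come from. With balanced classes, a threshold below every score gives Sensitivity 1, Specificity 0, Neg.Pred.Value NA and F1 0.667 (the pattern in rows 1 and 2 above), while a threshold above every score gives the pattern of rows 3 to 6. It assumes labels coded 1 (positive class) and 0 (negative class) and that a score at or above the threshold counts as positive. The function names are hypothetical, and the package may compute AUC differently (for example with a ROC library); the scores are toy values.

```r
# Hypothetical helpers (not gibbonNetR's internal code).
# labels: 1 = positive class, 0 = negative class.
# scores: predicted probability of the positive class.
safe_div <- function(num, den) if (den == 0) NA_real_ else num / den

threshold_metrics <- function(scores, labels, threshold) {
  pred <- as.integer(scores >= threshold)  # assumption: ties count as positive
  tp <- sum(pred == 1 & labels == 1)
  fp <- sum(pred == 1 & labels == 0)
  tn <- sum(pred == 0 & labels == 0)
  fn <- sum(pred == 0 & labels == 1)
  n <- length(labels)
  sens <- safe_div(tp, tp + fn)
  spec <- safe_div(tn, tn + fp)
  ppv <- safe_div(tp, tp + fp)   # NA when nothing is predicted positive
  f1 <- if (is.na(ppv) || is.na(sens)) NA_real_ else safe_div(2 * ppv * sens, ppv + sens)
  data.frame(
    Sensitivity = sens, Specificity = spec,
    Pos.Pred.Value = ppv, Neg.Pred.Value = safe_div(tn, tn + fn),
    Precision = ppv, Recall = sens, F1 = f1,
    Prevalence = (tp + fn) / n,
    Detection.Rate = tp / n,
    Detection.Prevalence = (tp + fp) / n,
    Balanced.Accuracy = (sens + spec) / 2,
    Threshold = threshold
  )
}

# Rank-based (Mann-Whitney) AUC; rank() averages tied scores.
auc_rank <- function(scores, labels) {
  n_pos <- sum(labels == 1)
  n_neg <- sum(labels == 0)
  r <- rank(scores)
  (sum(r[labels == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Toy scores for illustration only: 4 positive and 4 negative clips.
scores <- c(0.95, 0.85, 0.45, 0.35, 0.65, 0.25, 0.15, 0.05)
labels <- c(1, 1, 1, 1, 0, 0, 0, 0)
by_threshold <- do.call(rbind, lapply(seq(0.1, 0.9, by = 0.1), function(t) {
  threshold_metrics(scores, labels, t)
}))
print(by_threshold[, c("Threshold", "Sensitivity", "Specificity", "F1")])
auc_rank(scores, labels)
```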