Evaluates the performance of each trained model in a directory on a set of test images.

evaluate_trainedmodel_performance_multi(
  trained_models_dir,
  image_data_dir,
  output_dir = "data/",
  class_names,
  noise.category = "noise",
  unfreeze = TRUE
)

Arguments

trained_models_dir

Path to the directory containing trained models (.pt files).

image_data_dir

Path to the directory containing image data for evaluation.

output_dir

Path to the directory where the performance scores will be saved. Default is 'data/'.

class_names

Character vector of class names. Specified by the user to match the names of the training data folders.

noise.category

Category label for noise class. Default is 'noise'.

unfreeze

Logical indicating whether to unfreeze model parameters. Specified by the user to match how the model was trained. Default is TRUE.
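
Because class_names should match the training data folders, one way to build it is to read the folder names directly. The following is a minimal sketch, assuming one subfolder per class; the directory `train_dir` is hypothetical and is created here only so the code runs on its own. Alphabetical order is an assumption based on the example below, not a documented requirement.

```r
# Hypothetical training directory with one subfolder per class,
# created here for illustration only
train_dir <- file.path(tempdir(), "train_demo")
for (cls in c("noise", "female.gibbon")) {
  dir.create(file.path(train_dir, cls), recursive = TRUE, showWarnings = FALSE)
}

# Use the names of the immediate subfolders as class names,
# sorted so the order is stable (assumed, not documented)
class_names <- sort(basename(list.dirs(train_dir, recursive = FALSE)))
class_names
```

Check the result against the folders actually used for training before passing it to the function.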

Value

CSV files summarizing model performance are written to output_dir.
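
Since the results are written as several CSV files, it can be convenient to combine them into one data frame. This is a rough sketch using base R only; the directory name, file names, and values are stand-ins written here so the code is self-contained, and it assumes all files share the same columns.

```r
# Stand-in output directory; in practice use the output_dir
# that was passed to evaluate_trainedmodel_performance_multi()
output_dir <- file.path(tempdir(), "performance_demo")
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

# Two small stand-in files with illustrative values (not real results)
write.csv(data.frame(Class = "female.gibbon", Threshold = 0.1, F1 = 0.5),
          file.path(output_dir, "model_a.csv"), row.names = FALSE)
write.csv(data.frame(Class = "noise", Threshold = 0.1, F1 = 0.5),
          file.path(output_dir, "model_b.csv"), row.names = FALSE)

# Find every CSV under the output directory and stack them row-wise;
# rbind() requires that all files have identical column names
csv_files <- list.files(output_dir, pattern = "\\.csv$",
                        recursive = TRUE, full.names = TRUE)
performance <- do.call(rbind, lapply(csv_files, read.csv))
```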

Note

Takes the directory of models trained with 'train_CNN_multi' and the test folder created using 'spectrogram_images'.

Examples

{
  # Set directory paths for trained models and test images
  trained_models_dir <- system.file("extdata", "trainedresnetmulti", package = "gibbonNetR")
  image_data_dir <- system.file("extdata", "multiclass", "test", package = "gibbonNetR")
  class_names <- c("female.gibbon", "hornbill.helmeted", "hornbill.rhino", "long.argus", "noise")

  # Evaluate the performance of the trained models using the test images
  evaluate_trainedmodel_performance_multi(
    trained_models_dir = trained_models_dir,
    class_names = class_names,
    image_data_dir = image_data_dir,
    output_dir = file.path(tempdir(), "data/"),
    noise.category = "noise"
  )

  # Find the location of saved evaluation files
  CSVName <- list.files(file.path(tempdir(), "data"), recursive = TRUE, full.names = TRUE)

  # Check the output of the first file
  head(read.csv(CSVName[1]))
}
#> Evaluating performance of 3_resnet18_model N epochs= multi
#> Warning: Some torch operators might not yet be implemented for the MPS device. A
#> temporary fix is to set the `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a
#> fall back for those operators:
#>  Add `PYTORCH_ENABLE_MPS_FALLBACK=1` to your `.Renviron` file, for example use
#>   `usethis::edit_r_environ()`.
#>  Using `Sys.setenv()` doesn't work because the env var must be set before R
#>   starts.
#> Skipping hornbill.helmeted - cannot calculate performance
#> Skipping long.argus - cannot calculate performance
#>   Sensitivity Specificity Pos.Pred.Value Neg.Pred.Value Precision Recall
#> 1           1           0            0.5             NA       0.5      1
#> 2           1           0            0.5             NA       0.5      1
#> 3           0           1             NA            0.5        NA      0
#> 4           0           1             NA            0.5        NA      0
#> 5           0           1             NA            0.5        NA      0
#> 6           0           1             NA            0.5        NA      0
#>          F1 Prevalence Detection.Rate Detection.Prevalence Balanced.Accuracy
#> 1 0.6666667        0.5            0.5                    1               0.5
#> 2 0.6666667        0.5            0.5                    1               0.5
#> 3        NA        0.5            0.0                    0               0.5
#> 4        NA        0.5            0.0                    0               0.5
#> 5        NA        0.5            0.0                    0               0.5
#> 6        NA        0.5            0.0                    0               0.5
#>   Training.Data N.epochs CNN.Architecture  AUC Threshold         Class
#> 1          test    multi 3_resnet18_model 0.57       0.1 female.gibbon
#> 2          test    multi 3_resnet18_model 0.57       0.2 female.gibbon
#> 3          test    multi 3_resnet18_model 0.57       0.3 female.gibbon
#> 4          test    multi 3_resnet18_model 0.57       0.4 female.gibbon
#> 5          test    multi 3_resnet18_model 0.57       0.5 female.gibbon
#> 6          test    multi 3_resnet18_model 0.57       0.6 female.gibbon
#>                                                                       TestDataPath
#> 1 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 2 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 3 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 4 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 5 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#> 6 /Users/denaclink/Desktop/RStudioProjects/gibbonNetR/inst/extdata/multiclass/test
#>   Top1Accuracy
#> 1        0.125
#> 2        0.125
#> 3        0.125
#> 4        0.125
#> 5        0.125
#> 6        0.125
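
The output above has one row per class and threshold, and F1 is NA at thresholds where Precision is NA. As a possible next step, the sketch below picks the threshold with the highest F1 for each class. The data frame `perf` is a hand-copied subset of the printed rows (thresholds 0.1 to 0.4 for female.gibbon), not a new result, and the selection rule is only one reasonable choice.

```r
# Subset of the printed output above, copied by hand for illustration
perf <- data.frame(
  Class = rep("female.gibbon", 4),
  Threshold = c(0.1, 0.2, 0.3, 0.4),
  F1 = c(0.6666667, 0.6666667, NA, NA)
)

# Drop rows where F1 could not be calculated
perf_ok <- perf[!is.na(perf$F1), ]

# For each class, keep the row with the highest F1;
# which.max() returns the first row when there are ties
best <- do.call(rbind, lapply(split(perf_ok, perf_ok$Class), function(d) {
  d[which.max(d$F1), ]
}))
best
```

With real output, `perf` would instead come from `read.csv()` on one of the saved files.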