Dataset Performance
This notebook evaluates the three SimpleCNN models for each task using three benchmark datasets: PlantVillage, PlantDoc, and DiaMOS. We present and analyze the performance metrics Accuracy and F1 Score for each model–dataset combination.
Setup¶
We import the necessary datasets for each task. Although they are all sourced from the same directory, DATASETS_DIR, the classification criteria are specified per task and dataset.
DATASETS_DIR = Directories.EXTERNAL_DATA_DIR.value / "huggingface"
from lib.data import (
    CombinedDiamosDataset,
    CombinedPlantDocDataset,
    CombinedPlantVillageDataset,
    DiamosDiseaseDetection,
    DiamosSymptomIdentification,
    PlantDocDiseaseDetection,
    PlantDocSymptomIdentification,
    PlantVillageDiseaseDetection,
    PlantVillageSymptomIdentification,
)
We define a basic resizing preprocessing step, applied during dataset loading, to ensure architectural compatibility with varying image dimensions. All images are resized to a fixed square size of 32×32 pixels and then converted to tensors so that PyTorch can interface with the data directly.
transform_pipeline = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)
Evaluation¶
We evaluate the trained models on each dataset for each task to assess how well MegaPlant helps models generalize. We restore the trained model weights from disk by loading them into an identical model architecture, then evaluate on each full dataset without any data splits.
We evaluate model performance using the F1-score and accuracy, as these metrics are the most commonly reported in prior work on plant disease image classification, thereby ensuring direct comparability and reproducibility of our results. Accuracy quantifies the overall proportion of correctly classified samples and is defined as the ratio of true predictions to the total number of predictions. The F1-score provides a balanced assessment of a model’s predictive capability by combining precision and recall into a single harmonic mean, making it particularly suitable for datasets with class imbalance.
In these metrics, true positives (TP) refer to samples correctly classified as belonging to the positive class, true negatives (TN) denote samples correctly classified as belonging to the negative class, false positives (FP) represent negative samples incorrectly classified as positive, and false negatives (FN) indicate positive samples incorrectly classified as negative. These quantities form the basis for computing accuracy, precision, recall, and the F1-score. We use the scikit-learn Python library by Pedregosa et al. (2011) to compute the evaluation metrics.
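As a minimal, self-contained illustration (separate from the evaluation pipeline below, using hypothetical labels), the snippet computes accuracy and F1 once directly from the TP/TN/FP/FN counts and once with scikit-learn, showing that the two agree.

from sklearn.metrics import accuracy_score, f1_score

# Hypothetical binary ground truth and predictions (1 = diseased, 0 = healthy)
y_true = [1, 1, 1, 0, 0, 1, 0, 1]
y_pred = [1, 0, 1, 0, 1, 1, 0, 1]

# Confusion-matrix counts
tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)  # 4
tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)  # 2
fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)  # 1
fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)  # 1

accuracy = (tp + tn) / (tp + tn + fp + fn)          # 6/8 = 0.75
precision = tp / (tp + fp)                          # 4/5 = 0.80
recall = tp / (tp + fn)                             # 4/5 = 0.80
f1 = 2 * precision * recall / (precision + recall)  # 0.80

# scikit-learn reproduces the hand-computed values
assert accuracy == accuracy_score(y_true, y_pred)
assert abs(f1 - f1_score(y_true, y_pred)) < 1e-12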
Disease Detection¶
# Load model weights
disease_detection_model = SimpleCNN(channels=3, output_dim=1)
disease_detection_model.load_state_dict(torch.load(Directories.MODELS_DIR.value / "disease_detection_model.pth"))
We specify the testing procedure in a function test for binary classification. This function evaluates a trained PyTorch model on a given dataset using the provided DataLoader. It computes the model’s performance by calculating both the F1 score and accuracy, and returns these two metrics as a tuple. The model argument represents the neural network to be evaluated, while the DataLoader supplies the evaluation data.
def test(model: torch.nn.Module, data_loader: DataLoader) -> Tuple:
    """
    Evaluate the model on the given data loader and return F1 score and accuracy.

    Parameters
    ----------
    model : torch.nn.Module
        The trained model to evaluate.
    data_loader : DataLoader
        DataLoader for the dataset to evaluate on.

    Returns
    -------
    tuple
        A tuple containing F1 score and accuracy.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []
    all_outputs = []
    THRESHOLD = 0.5

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Evaluating Dataset {data_loader.dataset.__class__.__name__}"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = (outputs >= THRESHOLD).long()

            all_outputs.extend(outputs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    f1 = f1_score(all_labels, all_preds)
    accuracy = accuracy_score(all_labels, all_preds)
    return f1, accuracy
PlantVillage¶
plantvillage = PlantVillageDiseaseDetection(data_path=DATASETS_DIR / "plantvillage", transforms=transform_pipeline)
plantvillage_loader = DataLoader(plantvillage, batch_size=32, shuffle=True)
f1, accuracy = test(disease_detection_model, plantvillage_loader)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")Evaluating Dataset PlantVillageDiseaseDetection: 100%|██████████| 1698/1698 [00:32<00:00, 52.47it/s]
F1 Score: 0.9901
Accuracy: 0.9856
PlantDoc¶
plantdoc = PlantDocDiseaseDetection(data_path=DATASETS_DIR / "plantdoc", transforms=transform_pipeline)
plantdoc_loader = DataLoader(plantdoc, batch_size=32, shuffle=True)
f1, accuracy = test(disease_detection_model, plantdoc_loader)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")Evaluating Dataset PlantDocDiseaseDetection: 100%|██████████| 92/92 [00:27<00:00, 3.29it/s]F1 Score: 0.8954
Accuracy: 0.8392
DiaMOS¶
diamos = DiamosDiseaseDetection(data_path=DATASETS_DIR / "diamos", transforms=transform_pipeline)
diamos_loader = DataLoader(diamos, batch_size=32, shuffle=True)
f1, accuracy = test(disease_detection_model, diamos_loader)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")Evaluating Dataset DiamosDiseaseDetection: 100%|██████████| 94/94 [03:48<00:00, 2.43s/it]F1 Score: 0.9918
Accuracy: 0.9837
Symptom Identification¶
symptom_identifier = SimpleCNN(channels=3, output_dim=12)
symptom_identifier.load_state_dict(torch.load(Directories.MODELS_DIR.value / "symptom_identification_model.pth"))
<All keys matched successfully>
We specify a new test procedure test_si for multi-class classification tasks. It takes the same inputs as the test function but returns a tuple containing the list of true labels and the list of predictions, so that downstream we can both calculate the F1 score and accuracy and generate a classification report with a per-class F1 score.
def test_si(model: torch.nn.Module, data_loader: DataLoader) -> Tuple[List[int], List[int]]:
    """
    Evaluate the symptom identification model on the given data loader and return predictions and labels.

    Parameters
    ----------
    model : torch.nn.Module
        The trained symptom identification model to evaluate.
    data_loader : DataLoader
        DataLoader for the dataset to evaluate on.

    Returns
    -------
    tuple
        A tuple containing lists of true labels and predicted labels.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    all_preds = []
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Evaluating Dataset {data_loader.dataset.__class__.__name__}"):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = outputs.argmax(dim=1)

            all_outputs.extend(outputs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    return all_labels, all_preds
PlantDoc¶
plantdoc_si = PlantDocSymptomIdentification(data_path=DATASETS_DIR / "plantdoc", transforms=transform_pipeline)
plantdoc_si_loader = DataLoader(plantdoc_si, batch_size=32, shuffle=True)
all_labels, all_preds = test_si(symptom_identifier, plantdoc_si_loader)
Evaluating Dataset PlantDocSymptomIdentification: 100%|██████████| 66/66 [00:16<00:00, 4.01it/s]
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")F1 Score: 0.7692
Accuracy: 0.7569
print(classification_report(all_labels, all_preds))
              precision    recall  f1-score   support
0 0.83 0.83 0.83 769
1 0.84 0.80 0.82 238
2 0.00 0.00 0.00 0
3 0.68 0.68 0.68 130
4 0.03 1.00 0.06 2
5 0.74 0.67 0.71 91
6 0.80 0.52 0.63 54
7 0.75 0.56 0.64 79
8 0.77 0.77 0.77 223
9 0.83 0.61 0.70 93
10 0.72 0.74 0.73 415
11 0.00 0.00 0.00 0
accuracy 0.76 2094
macro avg 0.58 0.60 0.55 2094
weighted avg 0.79 0.76 0.77 2094
PlantVillage¶
plantvillage_si = PlantVillageSymptomIdentification(data_path=DATASETS_DIR / "plantvillage", transforms=transform_pipeline)
plantvillage_si_loader = DataLoader(plantvillage_si, batch_size=32, shuffle=True)
all_labels, all_preds = test_si(symptom_identifier, plantvillage_si_loader)
Evaluating Dataset PlantVillageSymptomIdentification: 100%|██████████| 1226/1226 [00:22<00:00, 54.17it/s]
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")F1 Score: 0.9302
Accuracy: 0.9305
print(classification_report(all_labels, all_preds))
              precision    recall  f1-score   support
0 0.96 0.78 0.86 6970
1 0.99 0.99 0.99 5507
2 0.99 0.98 0.98 5357
3 0.99 0.96 0.97 2887
4 0.97 0.85 0.91 1676
5 0.97 0.87 0.92 952
6 0.98 0.88 0.93 373
7 0.96 0.95 0.96 1801
8 0.97 0.97 0.97 1467
9 0.92 0.92 0.92 630
10 0.83 0.98 0.90 10492
11 0.98 0.96 0.97 1109
accuracy 0.93 39221
macro avg 0.96 0.92 0.94 39221
weighted avg 0.94 0.93 0.93 39221
DiaMOS¶
diamos_si = DiamosSymptomIdentification(data_path=DATASETS_DIR / "diamos", transforms=transform_pipeline)
diamos_si_loader = DataLoader(diamos_si, batch_size=32, shuffle=True)
all_labels, all_preds = test_si(symptom_identifier, diamos_si_loader)
Evaluating Dataset DiamosSymptomIdentification: 100%|██████████| 93/93 [03:34<00:00, 2.31s/it]
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")F1 Score: 0.8919
Accuracy: 0.8853
print(classification_report(all_labels, all_preds))
              precision    recall  f1-score   support
0 0.00 0.00 0.00 0
1 0.00 0.00 0.00 0
2 1.00 0.69 0.81 54
3 0.00 0.00 0.00 0
4 0.90 0.95 0.93 2025
6 0.00 0.00 0.00 0
7 0.00 0.00 0.00 0
8 0.00 0.00 0.00 0
9 0.00 0.00 0.00 0
10 0.90 0.75 0.82 884
accuracy 0.89 2963
macro avg 0.28 0.24 0.26 2963
weighted avg 0.90 0.89 0.89 2963
Combined Identification and Detection¶
test_comb = test_si
combined_classifier = SimpleCNN(channels=3, output_dim=13)
combined_classifier.load_state_dict(torch.load(Directories.MODELS_DIR.value / "combined_identification_model.pth"))
<All keys matched successfully>
PlantDoc¶
combined_plantdoc = CombinedPlantDocDataset(data_path=DATASETS_DIR / "plantdoc", transforms=transform_pipeline)
combined_plantdoc_loader = DataLoader(combined_plantdoc, batch_size=32, shuffle=True)
all_labels, all_preds = test_comb(combined_classifier, combined_plantdoc_loader)
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")F1 Score: 0.7897
Accuracy: 0.7754
print(classification_report(all_labels, all_preds))
              precision    recall  f1-score   support
0 0.82 0.84 0.83 769
1 0.80 0.74 0.77 238
2 0.00 0.00 0.00 0
3 0.87 0.75 0.81 130
4 0.02 1.00 0.04 2
5 0.76 0.56 0.65 91
6 0.80 0.44 0.57 54
7 0.66 0.72 0.69 79
8 0.79 0.75 0.77 223
9 0.71 0.61 0.66 93
10 0.69 0.70 0.69 415
11 0.00 0.00 0.00 0
12 0.88 0.84 0.86 822
accuracy 0.78 2916
macro avg 0.60 0.61 0.56 2916
weighted avg 0.81 0.78 0.79 2916
labels_for_combined = []
preds_for_combined = []

# Collapse the combined model's multi-class labels/predictions to binary: healthy (0) vs diseased (1)
for label, pred in zip(all_labels, all_preds):
    if label == combined_plantdoc.CLASS_MAP['healthy']:
        labels_for_combined.append(0)
    else:
        labels_for_combined.append(1)
    if pred == combined_plantdoc.CLASS_MAP['healthy']:
        preds_for_combined.append(0)
    else:
        preds_for_combined.append(1)
accuracy_score(labels_for_combined, preds_for_combined)
0.9242112482853223
f1_score(labels_for_combined, preds_for_combined)
0.947717057014431
PlantVillage¶
combined_plantvillage = CombinedPlantVillageDataset(data_path=DATASETS_DIR / "plantvillage", transforms=transform_pipeline)
combined_plantvillage_loader = DataLoader(combined_plantvillage, batch_size=32, shuffle=True)
all_labels, all_preds = test_comb(combined_classifier, combined_plantvillage_loader)
Evaluating Dataset CombinedPlantVillageDataset: 100%|██████████| 1698/1698 [00:40<00:00, 42.18it/s]
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")F1 Score: 0.9475
Accuracy: 0.9478
print(classification_report(all_labels, all_preds))
              precision    recall  f1-score   support
0 0.94 0.80 0.87 6970
1 0.99 0.99 0.99 5507
2 0.99 0.99 0.99 5357
3 0.99 0.97 0.98 2887
4 0.97 0.89 0.93 1676
5 0.96 0.93 0.94 952
6 0.88 0.95 0.91 373
7 0.94 0.94 0.94 1801
8 0.98 0.97 0.97 1467
9 0.88 0.91 0.89 630
10 0.86 0.95 0.90 10492
11 0.97 0.98 0.98 1109
12 0.98 0.98 0.98 15084
accuracy 0.95 54305
macro avg 0.95 0.94 0.94 54305
weighted avg 0.95 0.95 0.95 54305
labels_for_combined = []
preds_for_combined = []

# Collapse to binary: healthy (0) vs diseased (1), using this dataset's class map
for label, pred in zip(all_labels, all_preds):
    if label == combined_plantvillage.CLASS_MAP['healthy']:
        labels_for_combined.append(0)
    else:
        labels_for_combined.append(1)
    if pred == combined_plantvillage.CLASS_MAP['healthy']:
        preds_for_combined.append(0)
    else:
        preds_for_combined.append(1)
accuracy_score(labels_for_combined, preds_for_combined)
0.9905349415339287
f1_score(labels_for_combined, preds_for_combined)
0.9934437102987321
DiaMOS¶
combined_diamos = CombinedDiamosDataset(data_path=DATASETS_DIR / "diamos", transforms=transform_pipeline)
combined_diamos_loader = DataLoader(combined_diamos, batch_size=32, shuffle=True)
all_labels, all_preds = test_comb(combined_classifier, combined_diamos_loader)
Evaluating Dataset CombinedDiamosDataset: 100%|██████████| 94/94 [03:32<00:00, 2.26s/it]
f1 = f1_score(all_labels, all_preds, average='weighted')
accuracy = accuracy_score(all_labels, all_preds)
print(f"F1 Score: {f1:.4f}")
print(f"Accuracy: {accuracy:.4f}")F1 Score: 0.8705
Accuracy: 0.8666
print(classification_report(all_labels, all_preds))
              precision    recall  f1-score   support
0 0.00 0.00 0.00 0
1 0.00 0.00 0.00 0
2 1.00 0.72 0.84 54
3 0.00 0.00 0.00 0
4 0.89 0.94 0.91 2025
5 0.00 0.00 0.00 0
6 0.00 0.00 0.00 0
7 0.00 0.00 0.00 0
8 0.00 0.00 0.00 0
9 0.00 0.00 0.00 0
10 0.88 0.71 0.79 884
12 0.59 0.67 0.63 43
accuracy 0.87 3006
macro avg 0.28 0.25 0.26 3006
weighted avg 0.88 0.87 0.87 3006
labels_for_combined = []
preds_for_combined = []

# Collapse to binary: healthy (0) vs diseased (1), using this dataset's class map
for label, pred in zip(all_labels, all_preds):
    if label == combined_diamos.CLASS_MAP['healthy']:
        labels_for_combined.append(0)
    else:
        labels_for_combined.append(1)
    if pred == combined_diamos.CLASS_MAP['healthy']:
        preds_for_combined.append(0)
    else:
        preds_for_combined.append(1)
accuracy_score(labels_for_combined, preds_for_combined)
0.9886892880904857
f1_score(labels_for_combined, preds_for_combined)
0.9942567567567567
Results and Discussion¶
Disease Detection¶
| Dataset | F1 Score | Accuracy |
|---|---|---|
| DiaMOS | 0.9918 | 0.9837 |
| PlantVillage | 0.9901 | 0.9856 |
| PlantDoc | 0.8954 | 0.8392 |
We observe very high scores and only small variance in performance across the three datasets.
Symptom Identification¶
| Dataset | F1 Score | Accuracy |
|---|---|---|
| DiaMOS | 0.8919 | 0.8853 |
| PlantVillage | 0.9302 | 0.9305 |
| PlantDoc | 0.7692 | 0.7569 |
Considering that symptom identification is a harder task, we see more variance and lower overall accuracy across all three datasets.
Combined Identification and Detection¶
| Dataset | F1 Score | Accuracy | Binary F1 Score | Binary Accuracy |
|---|---|---|---|---|
| DiaMOS | 0.8705 | 0.8666 | 0.9942 | 0.9886 |
| PlantVillage | 0.9475 | 0.9478 | 0.9934 | 0.9905 |
| PlantDoc | 0.7897 | 0.7754 | 0.9477 | 0.9242 |
Binary F1 Score and Accuracy are derived from the model’s outputs, where all symptom-present cases are mapped to the unhealthy/diseased class (1), and all healthy cases are mapped to the healthy class (0). This makes it easier to compare against disease detection models.
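As a compact restatement of the mapping loops used above (where combined_dataset stands in for whichever combined dataset is being evaluated, and CLASS_MAP['healthy'] gives the index of its healthy class), the binarization amounts to:

# Collapse multi-class labels/predictions to binary healthy (0) / diseased (1)
healthy_idx = combined_dataset.CLASS_MAP['healthy']
binary_labels = [0 if label == healthy_idx else 1 for label in all_labels]
binary_preds = [0 if pred == healthy_idx else 1 for pred in all_preds]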
Compared to the symptom identification task, the multi-class scores show no noticeable difference; however, when we compare the binarized outputs against the dedicated disease detection model, we get a noticeable increase in F1 score and accuracy, particularly on PlantDoc.
Summary¶
We evaluated the three task-specific SimpleCNN models on each dataset and compared their performance across tasks. From this analysis, we observed that the combined-task model may outperform the disease-detection model, although it centralizes the decision-making responsibility, which can introduce its own trade-offs.
References¶
- Hughes, D. P., & Salathé, M. (2015). An open access repository of images on plant health to enable the development of mobile disease diagnostics. 10.48550/ARXIV.1511.08060
- Singh, D., Jain, N., Jain, P., Kayal, P., Kumawat, S., & Batra, N. (2020). PlantDoc: A Dataset for Visual Plant Disease Detection. Proceedings of the 7th ACM IKDD CoDS and 25th COMAD, 249–253. 10.1145/3371158.3371196
- Fenu, G., & Malloci, F. M. (2021). DiaMOS Plant: A Dataset for Diagnosis and Monitoring Plant Disease. Agronomy, 11(11), 2107. 10.3390/agronomy11112107
- Paszke, A., Gross, S., Massa, F., Lerer, A., Bradbury, J., Chanan, G., Killeen, T., Lin, Z., Gimelshein, N., Antiga, L., Desmaison, A., Köpf, A., Yang, E., DeVito, Z., Raison, M., Tejani, A., Chilamkurthy, S., Steiner, B., Fang, L., … Chintala, S. (2019). PyTorch: An Imperative Style, High-Performance Deep Learning Library. arXiv. 10.48550/ARXIV.1912.01703
- Pedregosa, F., Varoquaux, G., Gramfort, A., Michel, V., Thirion, B., Grisel, O., Blondel, M., Prettenhofer, P., Weiss, R., Dubourg, V., Vanderplas, J., Passos, A., Cournapeau, D., Brucher, M., Perrot, M., & Duchesnay, É. (2011). Scikit-learn: Machine Learning in Python. Journal of Machine Learning Research, 12, 2825–2830. https://jmlr.csail.mit.edu/papers/v12/pedregosa11a.html