
MegaPlant

An integrated image classification dataset of laboratory and field images

Data Exploration

This notebook introduces the data and its interface, and lightly explores its characteristics. After downloading the dataset, we can import any of the three dataset classes:

  • MegaPlantDataset contains both healthy and unhealthy leaf images and is designed for binary classification.

  • UnhealthyMegaPlantDataset contains only the unhealthy leaf images, divided into 12 classes of plant disease symptoms, and is designed for multi-class symptom identification.

  • CombinedMegaPlantDataset contains both healthy class and unhealthy subclass images, designed for both plant disease detection and symptom-identification.

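import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

from lib.data import MegaPlantDataset, UnhealthyMegaPlantDataset, CombinedMegaPlantDataset

# `Directories` is the project's enum of data paths; its import is not shown in this notebook.
MEGAPLANT_DIR = Directories.EXTERNAL_DATA_DIR.value / "huggingface" / "leaves"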

megaplant = MegaPlantDataset(
    data_path=MEGAPLANT_DIR, transforms=None
    )

unhealthy_megaplant = UnhealthyMegaPlantDataset(
    data_path=MEGAPLANT_DIR, transforms=None
    )

combined_megaplant = CombinedMegaPlantDataset(
    data_path=MEGAPLANT_DIR, transforms=None
    )

As described in our paper, we compiled the dataset according to criteria on the source folder names, namely whether each folder name corresponds to one of the 12 unhealthy subclasses.

We counted the retrieved images by running this command in a Bash terminal:

find path/to/dir -type f | wc -l

Dataset        Reported size   Retrieved images
DiaMOS         3,901           3,006
PlantVillage   54,305          54,306
PlantDoc       2,922           2,598

Because of duplicates and corrupted images, the number of images we retrieved differs from the counts reported in the original authors' papers.
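The file count above can also be reproduced from Python, which makes it easy to look for exact duplicates at the same time. The following is only a sketch under stated assumptions: the helper names `count_files` and `find_exact_duplicates` are hypothetical, and hashing file contents with SHA-256 is one possible way to detect duplicates, not necessarily the procedure used to build MegaPlant.

```python
import hashlib
import tempfile
from collections import defaultdict
from pathlib import Path


def count_files(root: Path) -> int:
    """Count files under root, similar to `find root -type f | wc -l`."""
    return sum(1 for path in root.rglob("*") if path.is_file())


def find_exact_duplicates(root: Path) -> list:
    """Group files whose bytes are identical (assumption: SHA-256 of the content)."""
    by_digest = defaultdict(list)
    for path in sorted(root.rglob("*")):
        if path.is_file():
            digest = hashlib.sha256(path.read_bytes()).hexdigest()
            by_digest[digest].append(path)
    # Keep only digests shared by more than one file
    return [paths for paths in by_digest.values() if len(paths) > 1]


# Tiny demonstration on a throwaway directory (folder and file names are made up)
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "blight").mkdir()
    (root / "blight" / "a.jpg").write_bytes(b"same bytes")
    (root / "blight" / "b.jpg").write_bytes(b"same bytes")
    (root / "c.jpg").write_bytes(b"different bytes")

    total = count_files(root)
    duplicate_groups = find_exact_duplicates(root)
    duplicate_names = [sorted(p.name for p in group) for group in duplicate_groups]

print(total, duplicate_names)
```

Content hashing only catches byte-identical copies; re-encoded or resized copies of the same photograph would need a perceptual comparison instead.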

Sample Images

We can index and iterate over each dataset; every sample is an (image, label) tuple. If transforms=None, image is a pathlib.Path object pointing to the image file; otherwise, it is whatever the given torchvision.transforms pipeline outputs. label is always the integer label of the image.

image, label = megaplant[0]

print("Label:", label)
Image.open(image)
Label: 0
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256>
megaplant.STATUS_MAP
{'healthy': 0, 'unhealthy': 1}

This image from the MegaPlantDataset has label 0, indicating that the leaf is healthy; a label of 1 would indicate an unhealthy leaf.
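The indexing and iteration behaviour described in this section can be illustrated with a small stand-in class. This is a sketch, not the code in lib.data: `ToyLeafDataset` and its file names are hypothetical, and for simplicity the transform here receives the path itself, whereas the real classes presumably hand an opened image to the torchvision transforms.

```python
from pathlib import Path


class ToyLeafDataset:
    """Hypothetical stand-in that mimics the (image, label) interface."""

    def __init__(self, samples, transforms=None):
        self.samples = list(samples)  # list of (Path, int) pairs
        self.transforms = transforms  # optional callable

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # List indexing supports negative indices and raises IndexError
        # past the end, which is what lets a plain `for` loop terminate.
        path, label = self.samples[index]
        if self.transforms is None:
            return path, label
        return self.transforms(path), label


toy = ToyLeafDataset([(Path("healthy_0.jpg"), 0), (Path("spot_0.jpg"), 1)])

first_image, first_label = toy[0]
last_image, last_label = toy[-1]
all_labels = [label for _, label in toy]

# With a transform, the first element of the tuple is the transform's output
toy_names = ToyLeafDataset(toy.samples, transforms=lambda path: path.stem)
first_name, _ = toy_names[0]
```

Defining only `__len__` and `__getitem__` is enough for `len(dataset)`, `dataset[i]` and `for image, label in dataset` to work, which matches how the datasets are used throughout this notebook.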

image, label = unhealthy_megaplant[-2007]

print("Label:", label)
Image.open(image)
Label: 10
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256>
unhealthy_megaplant.SYMPTOM_MAP
{'blight': 0, 'yellowing': 1, 'malformation': 2, 'powdery_mildew': 3, 'feeding': 4, 'mold': 5, 'mosaic': 6, 'rot': 7, 'rust': 8, 'scab': 9, 'spot': 10, 'scorch': 11}

This image from the UnhealthyMegaPlantDataset is of class 10, meaning it shows the ‘spot’ symptom.

image, label = combined_megaplant[0]

print("Label:", label)
Image.open(image)
Label: 0
<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=256x256>
combined_megaplant.CLASS_MAP
{'blight': 0, 'yellowing': 1, 'malformation': 2, 'powdery_mildew': 3, 'feeding': 4, 'mold': 5, 'mosaic': 6, 'rot': 7, 'rust': 8, 'scab': 9, 'spot': 10, 'scorch': 11, 'healthy': 12}

This image from the CombinedMegaPlantDataset is of class 0, meaning it shows the ‘blight’ symptom and is therefore diseased. Had its class been 12 instead, it would be a healthy leaf with no disease or symptom.
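Since the maps go from name to integer, decoding a label back to its name requires inverting them. The sketch below copies the three maps from the outputs above; the helper names `class_name` and `class_to_status` are hypothetical, and collapsing a combined class into a binary status is shown only to illustrate how the label spaces relate, not as something the library is known to provide.

```python
# Maps copied from the outputs printed above
SYMPTOM_MAP = {
    "blight": 0, "yellowing": 1, "malformation": 2, "powdery_mildew": 3,
    "feeding": 4, "mold": 5, "mosaic": 6, "rot": 7, "rust": 8,
    "scab": 9, "spot": 10, "scorch": 11,
}
CLASS_MAP = {**SYMPTOM_MAP, "healthy": 12}
STATUS_MAP = {"healthy": 0, "unhealthy": 1}

# Invert name -> integer into integer -> name
CLASS_NAMES = {index: name for name, index in CLASS_MAP.items()}


def class_name(label: int) -> str:
    """Return the class name for a CombinedMegaPlantDataset label."""
    return CLASS_NAMES[label]


def class_to_status(label: int) -> int:
    """Collapse a combined class label into the binary status label."""
    if class_name(label) == "healthy":
        return STATUS_MAP["healthy"]
    return STATUS_MAP["unhealthy"]


print(class_name(10), class_to_status(10))  # spot 1
print(class_name(12), class_to_status(12))  # healthy 0
```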

Analysis

The analysis consists of simple descriptive statistics about the dataset itself, as well as statistics about image metadata such as resolution and number of channels.

Sample Sizes

print("MegaPlantDataset size:", len(megaplant))
print("UnhealthyMegaPlantDataset size:", len(unhealthy_megaplant))
print("CombinedMegaPlantDataset size:", len(combined_megaplant))
MegaPlantDataset size: 60206
UnhealthyMegaPlantDataset size: 44270
CombinedMegaPlantDataset size: 60206

We have 60,206 leaf images in total, of which 44,270 are unhealthy and the remaining 15,936 are healthy.

from collections import defaultdict

# Make a frequency count of statuses and symptoms
status_counts = defaultdict(int)
class_counts = defaultdict(int)
symptom_int_map = {v: k for k, v in unhealthy_megaplant.SYMPTOM_MAP.items()}

for _, label in megaplant:
    status_counts[label] += 1

for _, label in combined_megaplant:
    class_counts[label] += 1

# Plot status counts
sns.barplot(
    x=list(status_counts.keys()),
    y=list(status_counts.values()),
)
plt.xlabel("Plant Status")
plt.ylabel("Count")
plt.title("Plant Status Counts")
plt.xticks(ticks=[0, 1], labels=["Healthy", "Unhealthy"])

print("Unhealthy portion: ", status_counts[1] / sum(status_counts.values()))
print("Healthy portion: ", status_counts[0] / sum(status_counts.values()))
Unhealthy portion:  0.7353087732119722
Healthy portion:  0.26469122678802776
<Figure size 400x250 with 1 Axes>

The unhealthy class outnumbers the healthy class by roughly 3:1 (73.5% versus 26.5%), which is a significant class imbalance.
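The ratio can be checked directly from the dataset sizes printed earlier. This short worked example assumes, as the matching sizes suggest, that the healthy images are exactly those in MegaPlantDataset that are not in UnhealthyMegaPlantDataset; the variable names are ours.

```python
# Sizes copied from the "Sample Sizes" output
total_images = 60206      # len(megaplant)
unhealthy_images = 44270  # len(unhealthy_megaplant)

healthy_images = total_images - unhealthy_images
unhealthy_share = unhealthy_images / total_images
healthy_share = healthy_images / total_images
imbalance_ratio = unhealthy_images / healthy_images

print("Healthy images:", healthy_images)                        # Healthy images: 15936
print(f"Unhealthy share: {unhealthy_share:.1%}")                # Unhealthy share: 73.5%
print(f"Healthy share: {healthy_share:.1%}")                    # Healthy share: 26.5%
print(f"Unhealthy-to-healthy ratio: {imbalance_ratio:.2f}:1")   # Unhealthy-to-healthy ratio: 2.78:1
```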

sorted_class_counts = dict(
    sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
)
class_int_map = {v: k for k, v in combined_megaplant.CLASS_MAP.items()}
sns.barplot(
    y=[class_int_map[k] for k in sorted_class_counts.keys()],
    x=list(sorted_class_counts.values()),
    orient="h"
)
plt.xlabel("Count")
plt.ylabel("Class")
plt.title("Class Counts")
<Figure size 400x250 with 1 Axes>

A Pareto-like pattern emerges in the class counts, with the spot and blight symptoms making up around 50% of the unhealthy leaf images.

Because the classes are imbalanced in all 3 datasets, we should not rely on overall accuracy alone, since it can be inflated by the majority class. Per-class metrics such as Precision, Recall and F1 Score are needed to capture how well each class is predicted and to avoid misleading conclusions.
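To make the point concrete, the sketch below computes per-class precision, recall and F1 for a degenerate classifier that always predicts the majority class. The labels are toy data built to mimic a 3:1 imbalance, not MegaPlant results, and `per_class_metrics` is a hypothetical helper written for this illustration (undefined ratios are reported as 0.0 by convention).

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Return {label: (precision, recall, f1)} for every label seen."""
    tp, fp, fn = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            tp[truth] += 1
        else:
            fp[pred] += 1   # predicted class gains a false positive
            fn[truth] += 1  # true class gains a false negative

    metrics = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = tp[label] + fp[label]
        actual = tp[label] + fn[label]
        precision = tp[label] / predicted if predicted else 0.0
        recall = tp[label] / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        metrics[label] = (precision, recall, f1)
    return metrics


# Toy labels: 3 unhealthy (1) for every healthy (0)
y_true = [1, 1, 1, 0] * 25
y_pred = [1] * len(y_true)  # always predict the majority class

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
metrics = per_class_metrics(y_true, y_pred)

print("Accuracy:", accuracy)  # Accuracy: 0.75
for label, (precision, recall, f1) in metrics.items():
    print(label, round(precision, 3), round(recall, 3), round(f1, 3))
```

The classifier reaches 75% accuracy without ever detecting a healthy leaf: the healthy class has a recall and F1 of 0, which accuracy alone hides.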

Image Metadata

sizes_freq = defaultdict(int)
channel_freq = defaultdict(int)

for image, _ in unhealthy_megaplant:
    # Open each file once; the context manager closes it after the metadata is read
    with Image.open(image) as img:
        sizes_freq[img.size] += 1
        channel_freq[len(img.getbands())] += 1

sizes_freq = dict(sorted(sizes_freq.items(), key=lambda item: item[1], reverse=True))
channel_freq = dict(sorted(channel_freq.items(), key=lambda item: item[0]))
# Plot the TOP_K most frequent image sizes
TOP_K = 10
sns.barplot(
    y=[f"{size[0]}x{size[1]}" for size in list(sizes_freq.keys())[:TOP_K]],
    x=list(sizes_freq.values())[:TOP_K],
    orient="h"
)
plt.xlabel("Count")
plt.ylabel("Image Size")
plt.title("Image Size Distribution")
plt.show()
<Figure size 400x250 with 1 Axes>

We can see that most of our images are 256x256. Because the remaining images come in other resolutions, the data will need preprocessing before it is passed through a convolutional network.

print("Channel counts:", channel_freq)
sns.barplot(
    x=list(channel_freq.keys()),
    y=list(channel_freq.values()),
)
plt.xlabel("Number of Channels")
plt.ylabel("Count")
plt.title("Channel Count Distribution")
Channel counts: {1: 1, 3: 44260, 4: 9}
<Figure size 400x250 with 1 Axes>

Nearly all of the images (44,260 of 44,270) have 3 channels (RGB, i.e. color). One image has a single channel and 9 images have 4 channels.
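One way to handle the mixed resolutions and channel counts is to convert every image to RGB and resize it to a common size before modeling. The sketch below uses Pillow on synthetic in-memory images; the function name `to_model_input` and the 256x256 target are assumptions made for illustration, not the preprocessing used in our experiments.

```python
from PIL import Image

TARGET_SIZE = (256, 256)  # assumption: match the most common resolution


def to_model_input(image, size=TARGET_SIZE):
    """Convert any PIL image to 3-channel RGB at a fixed resolution."""
    return image.convert("RGB").resize(size)


# Synthetic stand-ins for the cases found above: 1, 4 and 3 channels
samples = [
    Image.new("L", (128, 128)),     # single channel
    Image.new("RGBA", (300, 200)),  # four channels, non-square
    Image.new("RGB", (256, 256)),   # already in the target format
]

processed = [to_model_input(img) for img in samples]

for before, after in zip(samples, processed):
    print(before.mode, before.size, "->", after.mode, after.size)
```

A plain resize does not preserve the aspect ratio of non-square images; cropping or padding before resizing are alternatives if distortion matters.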

Summary

We have shown how to import the datasets and use their interface, including indexing and iteration, and we have explored their characteristics. The main takeaways are that the class imbalance is significant, so we must use performance metrics that remain informative under class imbalance, such as Precision, Recall and F1 Score, and that the images need preprocessing before modeling because their resolutions and channel counts differ.