AI for Plant Disease Detection
Use case of a classifier model for detecting downy mildew disease in a Chinese cabbage field.
Aug 20, 2025 · Lukas Wiku
Intro
The dataset used in this article is derived from aerial imagery of Chinese cabbage fields. It consists of image patches, each 150 x 150 pixels, categorized into three distinct classes: background, diseased, and healthy.
🎯 Goal
Develop a model capable of classifying each image patch into its designated class.
To begin with, let's do the usual Python imports required for processing the images.
Dependencies and Setup
GPU Check
We recommend using a GPU for faster training, although the code in this article also works on a CPU. To check whether PyTorch is correctly installed and has access to a GPU, run the code below:
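Here is a minimal sketch of such a check. The name of the device variable is our own choice.

```python
import torch

# Report the installed PyTorch version and whether CUDA is usable
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
```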
If you have a CUDA-compatible GPU, its name will be printed.
Prepare the Dataset
The images are organized into three folders based on their class labels: background, diseased, and healthy. They are my own drone-captured images of a Chinese cabbage field, taken in Korea as part of my academic research. The image collection process began with an in-field survey, during which disease symptoms were visually detected and confirmed. All images were captured between 11am and 1pm so that the lighting conditions were consistent.
This is my current directory structure:
Store the Image Directories in Memory
Firstly, we need to store the image paths in memory. A convenient way to do this is to use the pandas library to create a DataFrame. This DataFrame will have two columns: image_directory, which holds the full path to each image file, and class_id, which represents the category (background, diseased, or healthy) that each image belongs to.
The class_id numbers correspond to the following classes:
- 0: background
- 1: healthy
- 2: diseased
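Here is a minimal sketch of building this DataFrame. It assumes each class folder sits directly under one dataset root. The root path, the CLASS_IDS dictionary, and the accepted file extensions are assumptions, not details taken from the article.

```python
from pathlib import Path

import pandas as pd

# Folder name -> class_id, following the mapping listed above
CLASS_IDS = {"background": 0, "healthy": 1, "diseased": 2}
# Hypothetical set of accepted image extensions; adjust to your files
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".tif", ".tiff"}


def build_image_dataframe(root: str) -> pd.DataFrame:
    """Collect every image under root/<class_name>/ into a DataFrame."""
    records = []
    for class_name, class_id in CLASS_IDS.items():
        class_dir = Path(root) / class_name
        for path in sorted(class_dir.glob("*")):
            if path.suffix.lower() in IMAGE_EXTENSIONS:
                records.append({"image_directory": str(path.resolve()),
                                "class_id": class_id})
    return pd.DataFrame(records, columns=["image_directory", "class_id"])
```

Call build_image_dataframe on the dataset root, then run df["class_id"].value_counts(). This shows how many patches each class has, and it is how an imbalance like the one discussed below becomes visible.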
Looking at the number of data points in each class, it seems we are dealing with an imbalanced dataset. A model trained on an imbalanced dataset is likely to be biased, producing predictions that favor the class with the most data points.
There are a couple of ways to deal with this problem, such as truncation, upsampling, and data augmentation. While each approach has its pros and cons, let's start with the most straightforward one: truncation. Here, we cut off the excess data points from the background and healthy classes, keeping only 1500 data points for each.
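The truncation step could look like the sketch below. It assumes the DataFrame has the class_id column described earlier, and it keeps a random subset of each over-represented class. The article does not say whether it sampled randomly or simply kept the first 1500 rows, so the random choice (and the seed) is an assumption.

```python
import pandas as pd


def truncate_per_class(df: pd.DataFrame, max_per_class: int = 1500,
                       seed: int = 42) -> pd.DataFrame:
    """Randomly keep at most max_per_class rows for each class_id."""
    parts = []
    for _, group in df.groupby("class_id"):
        n = min(len(group), max_per_class)
        parts.append(group.sample(n=n, random_state=seed))
    # Shuffle the combined result so the classes are interleaved
    return pd.concat(parts).sample(frac=1, random_state=seed).reset_index(drop=True)
```

Classes with fewer than 1500 patches, such as diseased here, are kept whole.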
Sampling
Although this dataset is not perfectly balanced, its distribution accurately reflects reality in the domain of plant pests and diseases. In practice, healthy samples almost always outnumber diseased ones. Diseased samples typically make up only 10-30% of the total population, which makes gathering them a challenge in itself.
Split to Train and Test Set
Next, divide the DataFrame into train and test sets with a 7:3 ratio.
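A stratified 7:3 split keeps the class proportions the same in both sets. The sketch below does this with pandas alone; the function name and seed are hypothetical. scikit-learn's train_test_split(df, test_size=0.3, stratify=df["class_id"]) is a common equivalent.

```python
import pandas as pd


def stratified_split(df: pd.DataFrame, train_frac: float = 0.7, seed: int = 42):
    """Split df into train/test sets while preserving class proportions."""
    # Sample train_frac of the rows within each class (requires a unique index)
    train_df = df.groupby("class_id").sample(frac=train_frac, random_state=seed)
    # Every row not drawn for training goes to the test set
    test_df = df.drop(train_df.index)
    return train_df.reset_index(drop=True), test_df.reset_index(drop=True)
```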
Using PyTorch's Dataset & DataLoaders
With our DataFrame of image directories ready, the next step is to load and process the images for model training. In PyTorch, this is efficiently handled using the Dataset [1] and DataLoader [2] classes.
We will create a custom dataset by defining a new class that inherits from torch.utils.data.Dataset. This custom class will contain all the logic necessary to open the images and perform any required operations before they are fed into the model.
The core functionality lies in the __getitem__ method. Each time a sample is requested from the dataset, the operations defined in this method are executed. This makes it a highly flexible place to manipulate our data, allowing us to perform nearly any kind of transformation or processing.
With our custom dataset class prepared, we can now create training and testing dataset instances and load them using DataLoader.
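Wrapping the datasets in DataLoaders might look like the sketch below. The batch size of 32 is a hypothetical default. Only the training loader is shuffled, so each epoch sees the batches in a new order while evaluation stays deterministic.

```python
from torch.utils.data import DataLoader, Dataset


def make_loaders(train_ds: Dataset, test_ds: Dataset, batch_size: int = 32,
                 num_workers: int = 0):
    """Wrap the datasets in DataLoaders; only the training data is shuffled."""
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)
    return train_loader, test_loader
```

Pass in the training and testing dataset instances created from the split DataFrames. Each batch is then an (images, labels) pair. If the patches are resized to 299 x 299, the images tensor is shaped (batch_size, 3, 299, 299).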
Prepare the Model
For simplicity, we're going to pick a ready-to-use architecture, inception_v3, as our plant disease detection model. Using a pretrained inception_v3 model means that it has already been trained on the ImageNet dataset, so its weights have already 'learned' useful image features.
However, to make it applicable to our use case, the model needs to be tuned to 'adapt' to our dataset. This approach is commonly known as transfer learning.
To do this, the last fully-connected (FC) layer has to be 'cut out' and replaced with a new FC layer. In this case, the new FC layer has 3 outputs (nodes), one each for the background, diseased, and healthy classes.
For the training, we will:
- Use CrossEntropyLoss as the loss function.
- Optimize with Stochastic Gradient Descent (SGD).
While many modern approaches rely on the Adam optimizer, starting with SGD is a solid way to understand the fundamentals of optimization.
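In code, this amounts to the sketch below. The learning rate of 0.001 and momentum of 0.9 are hypothetical starting values, since the article does not list its settings.

```python
import torch.nn as nn
import torch.optim as optim


def make_criterion_and_optimizer(model: nn.Module, lr: float = 0.001,
                                 momentum: float = 0.9):
    """Cross-entropy loss and SGD with momentum over all trainable parameters."""
    criterion = nn.CrossEntropyLoss()
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(params, lr=lr, momentum=momentum)
    return criterion, optimizer
```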
Prepare Utility Functions
Before diving into the core task, let's create a few utility functions to help us perform a last-minute dataset inspection and, later on, evaluate our model's predictions.
Prediction Function
This function puts the model in evaluation mode, disables gradient calculations, and performs a forward pass on each batch of data to get predictions. The true and predicted labels are then gathered and returned as two separate NumPy arrays.
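Here is a minimal sketch of such a function. The name predict and its signature are our own.

```python
import numpy as np
import torch


def predict(model, loader, device):
    """Return (y_true, y_pred) as NumPy arrays for every sample in loader."""
    model.eval()  # disable dropout and use running batch-norm statistics
    y_true, y_pred = [], []
    with torch.no_grad():  # gradients are not needed for inference
        for images, labels in loader:
            logits = model(images.to(device))
            y_pred.append(logits.argmax(dim=1).cpu())
            y_true.append(labels.cpu())
    return torch.cat(y_true).numpy(), torch.cat(y_pred).numpy()
```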
Confusion Matrix
This function plots the predictions as a confusion matrix, a simple yet effective way of showing prediction performance for each class. A confusion matrix created before training serves as a baseline to compare against after the model has been trained.
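Here is a sketch of this helper. It computes the matrix with NumPy and plots it with Matplotlib. Each row is normalized by the number of true samples in that class, which we assume is how the per-class percentages discussed later were obtained. The class order follows the class_id mapping above.

```python
import matplotlib.pyplot as plt
import numpy as np

CLASS_NAMES = ["background", "healthy", "diseased"]  # index = class_id


def confusion_matrix_normalized(y_true, y_pred, num_classes=3):
    """Rows = true class, columns = predicted class; each row sums to 1."""
    counts = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(counts, (np.asarray(y_true), np.asarray(y_pred)), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    # Guard against classes without samples to avoid dividing by zero
    return counts / np.maximum(row_sums, 1)


def plot_confusion_matrix(y_true, y_pred, class_names=CLASS_NAMES):
    cm = confusion_matrix_normalized(y_true, y_pred, len(class_names))
    fig, ax = plt.subplots(figsize=(5, 4))
    im = ax.imshow(cm, cmap="Blues", vmin=0, vmax=1)
    ax.set_xticks(range(len(class_names)))
    ax.set_xticklabels(class_names)
    ax.set_yticks(range(len(class_names)))
    ax.set_yticklabels(class_names)
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("True class")
    # Write the percentage into every cell
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, f"{cm[i, j]:.0%}", ha="center", va="center",
                    color="white" if cm[i, j] > 0.5 else "black")
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig, cm
```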
Display Images
We shall also create a helper function to display sample images. It provides a quick last-minute check that the dataset preprocessing and model creation are working correctly.
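A sketch of such a helper follows. The name show_images is a hypothetical stand-in, not the article's own displayImgs implementation. It assumes the images were normalized with ImageNet statistics and undoes that normalization before plotting.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch

CLASS_NAMES = ["background", "healthy", "diseased"]  # index = class_id
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def show_images(images, labels, preds=None, class_names=CLASS_NAMES, ncols=4):
    """Show a batch of normalized image tensors (N, 3, H, W) in a grid."""
    n = len(images)
    nrows = max(1, int(np.ceil(n / ncols)))
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows),
                             squeeze=False)
    for ax in axes.ravel():
        ax.axis("off")
    for i, ax in enumerate(axes.ravel()[:n]):
        # Undo the ImageNet normalization so the colors look natural
        img = images[i].detach().cpu().numpy().transpose(1, 2, 0)
        img = np.clip(img * IMAGENET_STD + IMAGENET_MEAN, 0, 1)
        ax.imshow(img)
        title = f"true: {class_names[int(labels[i])]}"
        if preds is not None:
            title += f"\npred: {class_names[int(preds[i])]}"
        ax.set_title(title, fontsize=9)
    fig.tight_layout()
    return fig
```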
Evaluation Function
This function loops through a DataLoader and feeds each batch to the model. The model's predictions are compared with the true labels to count the correct predictions, from which the overall accuracy is computed.
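Here is a sketch that returns both the mean loss and the accuracy. Returning the loss as well is our own addition, convenient for the loss plots later on.

```python
import torch


def evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy) of model over every batch in loader."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            # Weight the batch loss by batch size so the mean is per sample
            total_loss += criterion(logits, labels).item() * labels.size(0)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total
```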
Training
This section is the (almost) final step of the pipeline. All the preparation up to this point (building image patches, validating the dataset, and loading the model) has been groundwork for embracing the black box. We can now be fairly confident that the dataset (image patches) was prepared as intended and that the model loads and works correctly.
The objective of this step is to fine-tune the weights of the pretrained inception_v3 model.
In PyTorch, the training process is typically more explicit than in TensorFlow. It involves iterating over the dataset, performing forward passes to generate predictions, computing the loss, resetting gradients, backpropagating, updating the model's weights, and repeating these steps for each batch and epoch.
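A compact sketch of that loop follows. The function names, the default of 10 epochs, and the history dictionary are assumptions. One inception_v3 detail matters here: in training mode, torchvision's model returns both the main and the auxiliary logits. This sketch adds the auxiliary loss with a weight of 0.4, the weighting used in PyTorch's fine-tuning tutorial. The article does not say how it handled the auxiliary output.

```python
import torch


def train_one_epoch(model, loader, criterion, optimizer, device, aux_weight=0.4):
    """One pass over the training data; returns (mean loss, accuracy)."""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()              # reset gradients from the previous step
        outputs = model(images)            # forward pass
        if isinstance(outputs, tuple):     # inception_v3 in train mode: (logits, aux)
            logits, aux_logits = outputs
            loss = criterion(logits, labels) + aux_weight * criterion(aux_logits, labels)
        else:
            logits = outputs
            loss = criterion(logits, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        total_loss += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total


@torch.no_grad()
def evaluate_epoch(model, loader, criterion, device):
    """Mean loss and accuracy on held-out data, without updating weights."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, correct / total


def fit(model, train_loader, test_loader, criterion, optimizer, device, epochs=10):
    """Train for a fixed number of epochs and record per-epoch metrics."""
    model.to(device)
    history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion,
                                                optimizer, device)
        test_loss, test_acc = evaluate_epoch(model, test_loader, criterion, device)
        for key, value in zip(history, (train_loss, train_acc, test_loss, test_acc)):
            history[key].append(value)
        print(f"epoch {epoch}/{epochs}  train loss {train_loss:.4f} acc {train_acc:.4f}  "
              f"test loss {test_loss:.4f} acc {test_acc:.4f}")
    return history
```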
The final testing accuracy is 0.8520, which is a solid result for an initial model.
An experienced deep learning practitioner might point out that the training setup, such as the learning_rate, batch_size, epochs, and similar parameters, could be optimized further. More advanced techniques such as early stopping and checkpointing could also be beneficial.
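Early stopping and checkpointing can be combined in a small helper like the one below. This is a sketch with hypothetical names and a default patience of 3 epochs. Ideally it would monitor a separate validation split. Monitoring the test set, the only held-out set in this article, would leak test information into model selection.

```python
import copy
import math


class EarlyStopping:
    """Signal a stop when the monitored loss has not improved for `patience`
    epochs, keeping an in-memory copy of the best weights as a checkpoint."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = math.inf
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss: float, model) -> bool:
        """Record this epoch's loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_state = copy.deepcopy(model.state_dict())  # checkpoint
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

After each epoch, call stopper.step(val_loss, model) and break out of the loop when it returns True. Afterwards, model.load_state_dict(stopper.best_state) restores the best weights instead of the final ones.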
Accuracy & Loss Plots
The accuracy and loss plots above suggest that the model may have slightly overfitted. This matters because we keep the weights from the final epoch rather than from the best-performing one. While the training logic could certainly be improved, we will keep this article concise and won't dive into those adjustments here.
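Curves like these can be drawn from per-epoch records. The sketch assumes a history dictionary with train_loss, test_loss, train_acc and test_acc lists; these keys are hypothetical.

```python
import matplotlib.pyplot as plt


def plot_history(history):
    """Plot per-epoch accuracy and loss for the training and test sets."""
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))
    ax_acc.plot(epochs, history["train_acc"], label="train")
    ax_acc.plot(epochs, history["test_acc"], label="test")
    ax_acc.set(title="Accuracy", xlabel="Epoch", ylabel="Accuracy")
    ax_loss.plot(epochs, history["train_loss"], label="train")
    ax_loss.plot(epochs, history["test_loss"], label="test")
    ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")
    for ax in (ax_acc, ax_loss):
        ax.legend()
    fig.tight_layout()
    return fig
```

A test loss that rises while the training loss keeps falling is the typical signature of overfitting.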
There are several good articles demonstrating better training strategies [3], [4], [5].
The Confusion Matrix
With the training done, let us now evaluate the model's performance using the confusion matrix. This will help us understand how well the model is performing across different classes.
Key analysis:
- The background class, which has more distinguishable features than the other two, is the easiest to recognize, with 95% of its samples predicted correctly.
- The healthy class is predicted correctly 86% of the time, although 10% of its samples are misclassified as diseased.
- The diseased class, however, is predicted correctly only 51% of the time on the test samples; nearly half (47%) are misclassified as healthy. This indicates that the model may not be sensitive enough to detect diseased leaves.
Display Predictions
Iterate over the train_loader once, then display the results with the displayImgs function to visualize the model's predictions.
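Here is a minimal sketch of fetching one batch and its predictions. The helper name is hypothetical. The exact signature of displayImgs is not shown in the article, so passing images, labels, and predictions to it is an assumption.

```python
import torch


def predict_one_batch(model, loader, device):
    """Fetch a single batch from loader and return (images, labels, preds)."""
    model.eval()
    images, labels = next(iter(loader))  # iterate over the loader once
    with torch.no_grad():
        preds = model(images.to(device)).argmax(dim=1).cpu()
    return images, labels, preds
```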
Save the Model
In PyTorch, you can save the weights with torch.save() and load them later with torch.load(). When loading the weights, make sure the model architecture is defined exactly as it was during training, including the replaced 3-class FC layer.
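Here is a sketch that saves only the state_dict, the approach recommended in PyTorch's documentation, and loads it back into a freshly built model. The helper names are our own. The weights_only flag requires PyTorch 1.13 or later.

```python
import torch
import torch.nn as nn


def save_weights(model: nn.Module, path: str) -> None:
    # Save only the learned parameters (state_dict), not the Python object
    torch.save(model.state_dict(), path)


def load_weights(model: nn.Module, path: str, device="cpu") -> nn.Module:
    # `model` must already have the same architecture used during training
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.to(device).eval()
```

To reuse the cabbage classifier, rebuild inception_v3 with the same 3-class heads and pass it to load_weights together with the saved file. No ImageNet download is needed, since the saved weights overwrite the model's parameters.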
Application
Imagine that we are working with a large aerial orthomosaic image of a farm field, potentially several gigabytes in size. It is rich in information but too memory-expensive to process as a single blob. To handle this, we need to process it in manageable chunks, for example 2000 x 2000 pixels each.
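Computing the chunk grid itself is simple arithmetic. The sketch below yields window offsets and sizes; the function name and the 2000-pixel default are illustrative. A windowed raster reader such as rasterio can then load one window at a time, for example with src.read(window=Window(col_off, row_off, w, h)), so the full orthomosaic never has to sit in memory.

```python
def iter_windows(height: int, width: int, chunk: int = 2000):
    """Yield (row_off, col_off, h, w) windows that tile a height x width image.
    Windows on the bottom and right edges are smaller when the size is not
    a multiple of `chunk`."""
    for row_off in range(0, height, chunk):
        for col_off in range(0, width, chunk):
            yield (row_off, col_off,
                   min(chunk, height - row_off), min(chunk, width - col_off))
```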
Assuming the necessary patches have already been extracted, we will demonstrate how to run inference on them and then aggregate the predictions into a heatmap that highlights diseased areas.
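Here is a sketch of that step under a few assumptions:
- A chunk is already in memory as an RGB uint8 NumPy array.
- It is cut into non-overlapping 150 x 150 patches, and partial patches at the right and bottom edges are dropped.
- Each patch receives the same resize and ImageNet normalization assumed for training.
- class_id 2 is diseased, per the mapping above.

The result is a grid of diseased probabilities, one cell per patch.

```python
import numpy as np
import torch
import torch.nn.functional as F

PATCH = 150        # patch size of the training data (150 x 150 pixels)
DISEASED_ID = 2    # class_id of "diseased" in the mapping used in this article
# ImageNet statistics, assumed to match the training-time normalization
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def diseased_heatmap(chunk, model, device="cpu", batch_size=64, input_size=299):
    """Return a (rows, cols) grid of 'diseased' probabilities, one per 150 x 150
    patch of an RGB uint8 chunk shaped (H, W, 3). `model` is assumed to be on
    `device` already. Ragged edges are dropped."""
    rows, cols = chunk.shape[0] // PATCH, chunk.shape[1] // PATCH
    if rows == 0 or cols == 0:
        return np.zeros((rows, cols), dtype=np.float32)
    # Cut into patches ordered row by row: (rows * cols, 3, PATCH, PATCH)
    patches = (chunk[: rows * PATCH, : cols * PATCH]
               .reshape(rows, PATCH, cols, PATCH, 3)
               .transpose(0, 2, 4, 1, 3)
               .reshape(-1, 3, PATCH, PATCH))
    model.eval()
    probs = []
    with torch.no_grad():
        for start in range(0, len(patches), batch_size):
            x = torch.from_numpy(patches[start:start + batch_size]).float() / 255.0
            # Resize to the model's input size, then normalize as in training
            x = F.interpolate(x, size=(input_size, input_size), mode="bilinear",
                              align_corners=False)
            x = ((x - MEAN) / STD).to(device)
            probs.append(torch.softmax(model(x), dim=1)[:, DISEASED_ID].cpu())
    return torch.cat(probs).numpy().reshape(rows, cols)
```

You can display the grid with, for example, plt.imshow(heat, cmap="Reds", vmin=0, vmax=1) to highlight the diseased areas. Grids from neighboring chunks can be placed side by side in the same order as their windows.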
Takeaways
We have successfully developed a working inception_v3 model that reaches 0.85 overall accuracy in classifying background, diseased, and healthy patches of Chinese cabbage field images. However, it needs further improvement to become more sensitive to diseased leaves, as it correctly classified only 51% of the diseased samples.
Nevertheless, we can use this initial experiment as a foundation for further development and refinement. This might involve collecting more diverse training data, refining the model architecture, augmenting the dataset, and experimenting with different hyperparameters to improve performance. In present-day research, more advanced techniques that leverage transfer learning, self-supervised learning, and ensemble methods are being explored to enhance model robustness and accuracy.
Since large orthomosaic images cannot be processed all at once, the final application must also incorporate a patch-based processing strategy. This requires a method for dividing each image into smaller, analyzable patches. Our model could then process each patch independently, possibly even classifying the stage of infection (e.g. mild, medium, severe) for each patch in the context of the overall image.
The repertoire of techniques available for patch-based processing is vast, including methods for efficient data loading, augmentation strategies tailored to small image regions, and specialized neural network architectures designed to handle variable input sizes. When working with a consultant like Supertype, it is often a matter of leveraging the existing body of knowledge and tools in the ecosystem to develop a plant disease detection system that is time-efficient and scalable.