The challenges faced by supervised learning for eCommerce image classification are well documented. Because supervised learning requires enormous amounts of training data, eCommerce use cases are often constrained by the prohibitive cost and time of data labeling. In the following training flow, Data Collection is the step where we spend up to 90% of our budget and time. In addition, wrong labels, as well as imbalances and biases in the labeled data, hamper the quality of supervised learning. Features learned from supervised training in one domain can be transferred to other domains. However, supervised learning fits the model to whatever data it is trained on, so a generalized image representation is not its explicit or implicit objective.
Workflow for Supervised Image Classification
Not only are supervised models expensive to train, these purpose-built models are also not reusable. In contrast, semi-supervised learning (SSL) allows us to use unlabeled as well as labeled data and to learn representations of image features without labels. Recent research has shown that the SSL approach leads to better representations of image features. It also dramatically reduces the amount of data labeling needed for downstream tasks. The most recent semi-supervised models pretrained on ImageNet even exceed supervised pretrained models on multiple downstream tasks. Image classification is just one of many computer vision problems that can benefit from SSL. The large volume of unlabeled image data in an eCommerce catalog is an untapped resource that exists in abundance. To use SSL to increase the efficiency and performance of image classification, we built a plug-and-play platform. It unifies two things: the pretraining of neural net encoders on unlabeled data, and the automated search over these SSL encoders for image classification tasks.
Our ultimate design goals are:
- Pretrain models with unlabeled Walmart catalog images using SSL and SoTA neural architectures
- Enable a class of downstream classification tasks with better performance, utilizing a library of pretrained SSL neural nets
- Make embeddings from pretrained SSL neural nets available as the new standard for similarity search over Walmart catalog images
Our approach is inspired by recent SSL research for image classification, most of which uses ResNet as the go-to neural net. ResNet is a popular reference model in research publications, but it is rarely the best model for eCommerce use cases. In the research community, ResNet models are typically pretrained and evaluated on ImageNet data. As a result, we have serious concerns on two fronts. The semi-supervised methods may be overfit to ResNet, and their reported results may be overfit to ImageNet data.
Another challenge is that it is not feasible to extend the many open-source implementations of these SSL approaches to support neural nets other than ResNet. To get the research published, the accompanying open-source implementations contain one-off techniques that squeeze out sometimes minute improvements in performance. Some implementations require an enormous amount of computing resources just to reach their reported performance. Collectively, these factors make their implementations impractical to use and their results impossible to translate into the realm of eCommerce applications.
Two common ingredients have contributed to the success of all SSL approaches. The first is pretraining on massive datasets. The second is the use of models with massive learning capacity. To give you a sense of scale, the neural nets are very deep and wide, with up to 1B parameters, and unlabeled pretraining datasets can easily exceed 1B images. As a result, the computing resources required for pretraining are huge. For example, training SimCLR on a single GPU with the original TensorFlow implementation takes about 40 hours per epoch for a training set of 1.2 million images. We have to be careful when we try to translate the published results into real-world settings in the eCommerce domain.
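As a rough sanity check, the 40-hours-per-epoch figure implies the throughput computed below. The second expression is a naive extrapolation to a 1B-image dataset. It assumes the same single-GPU throughput holds at that scale, which is an illustrative assumption only:

```latex
\frac{1.2 \times 10^{6}\ \text{images}}{40\ \text{h} \times 3600\ \text{s/h}}
  = \frac{1.2 \times 10^{6}}{1.44 \times 10^{5}\ \text{s}}
  \approx 8.3\ \text{images/s},
\qquad
\frac{10^{9}\ \text{images}}{8.3\ \text{images/s}}
  \approx 1.2 \times 10^{8}\ \text{s}
  \approx 3.3 \times 10^{4}\ \text{h per epoch}.
```

Under that assumption, a single epoch over 1B images would take roughly 3.8 years of single-GPU time.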
In this blog post, we discuss the development of a unified framework, which consists of:
- A core library that standardizes the multi-stage training of encoders
- A shared library of both SSL pretrained encoders and off-the-shelf supervised encoders
- Built-in support for image classification tasks via automated search for best encoders and extensible hyperparameter tuning
- Auxiliary techniques that complement the core library by design
First, any neural net can be plugged into a set of state-of-the-art SSL programs for pretraining, and the choice of neural net is independent of the choice of SSL approach. EfficientNet and DenseNet are examples of the supported neural nets; SimCLR, BYOL and SwAV are examples of the implemented SSL approaches.
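One way to keep the encoder choice independent of the SSL approach is to use two separate registries, so that any encoder can be paired with any method. The sketch below assumes that design. Every name in it (`ENCODERS`, `SSL_METHODS`, `register`, `build_pretraining_job`) is hypothetical, and dictionaries stand in for real networks. Only the embedding widths are real: EfficientNet-B0 outputs 1280 features and DenseNet-121 outputs 1024.

```python
from typing import Callable, Dict

# Hypothetical registries: names and factories are illustrative only.
ENCODERS: Dict[str, Callable[[], dict]] = {}
SSL_METHODS: Dict[str, Callable[[dict], dict]] = {}


def register(registry: Dict[str, Callable], name: str):
    """Decorator that records a factory under `name` in `registry`."""
    def wrap(fn):
        registry[name] = fn
        return fn
    return wrap


@register(ENCODERS, "efficientnet_b0")
def efficientnet_b0() -> dict:
    # Stand-in for a real network; only the embedding width matters here.
    return {"arch": "efficientnet_b0", "embedding_dim": 1280}


@register(ENCODERS, "densenet121")
def densenet121() -> dict:
    return {"arch": "densenet121", "embedding_dim": 1024}


@register(SSL_METHODS, "simclr")
def simclr(encoder: dict) -> dict:
    # The SSL method only needs the encoder's embedding width to size its head.
    return {"method": "simclr", "encoder": encoder, "head_in": encoder["embedding_dim"]}


@register(SSL_METHODS, "byol")
def byol(encoder: dict) -> dict:
    return {"method": "byol", "encoder": encoder, "head_in": encoder["embedding_dim"]}


def build_pretraining_job(encoder_name: str, method_name: str) -> dict:
    """Combine any registered encoder with any registered SSL method."""
    encoder = ENCODERS[encoder_name]()
    return SSL_METHODS[method_name](encoder)
```

For example, `build_pretraining_job("densenet121", "byol")` sizes the BYOL head from DenseNet's 1024-wide embedding. Adding a new encoder then means registering one factory, with no change to any SSL method.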
After the neural nets are pretrained, the encoders are extracted and made available in an SSL model library for downstream classification tasks to use. Once in the library, they no longer depend on any of the components used in the pretraining process. This makes them first-class candidates for any transfer-learning task. Their generic nature also enables an automated encoder search for optimal performance on downstream classification tasks.
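The post does not show how an encoder is pulled out of an SSL checkpoint. One common pattern is to filter the checkpoint's parameter dictionary and strip a key prefix. The sketch below assumes encoder weights are stored under a hypothetical `encoder.` prefix. In PyTorch, the input would be a `state_dict` (PyTorch Lightning checkpoints keep it under the `"state_dict"` key), and the result would be passed to the encoder's `load_state_dict`. Plain dictionaries keep the sketch self-contained.

```python
from typing import Any, Dict


def extract_encoder_state(ssl_state: Dict[str, Any], prefix: str = "encoder.") -> Dict[str, Any]:
    """Keep only encoder parameters from an SSL checkpoint and strip the prefix.

    Entries for projection heads, target/momentum networks, or prototypes use
    other prefixes and are dropped, so the result carries no dependency on the
    SSL method that produced it. The prefix itself is a hypothetical convention.
    """
    return {
        key[len(prefix):]: value
        for key, value in ssl_state.items()
        if key.startswith(prefix)
    }
```

Note that a key such as `target_encoder.conv1.weight` is excluded, because it does not start with the `encoder.` prefix.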
Next, for any downstream classification task, instead of manually searching for the best pretrained encoder, we automate the search. The automated search is defined by three things: the search space given by a model library, the desired classification performance, and a search strategy chosen from a wide variety of hyperparameter tuning algorithms.
- First, the search space is not confined to encoders from our SSL pretrained model library; it also includes off-the-shelf supervised models trained on ImageNet.
- Then, a performance target is specified per use-case requirement, along with a resource budget for the search (e.g., in terms of GPU time).
- Finally, we are free to use any hyperparameter tuning algorithm for the automated search. As a reference implementation, the framework has been integrated with Ray Tune, a distributed hyperparameter tuning library, and it could easily be extended to work with other tuning tools like Weights & Biases.
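The three ingredients above are a search space, a performance target, and a resource budget. The sketch below shows how they interact, using plain random search as a stand-in for the tuning libraries named in the text. The `evaluate` callback is hypothetical: it is assumed to fine-tune a classifier on the given encoder and return validation accuracy. The encoder names and the learning-rate range are also illustrative assumptions.

```python
import random
from typing import Callable, Dict, List, Optional, Tuple


def search_encoders(
    library: List[str],
    evaluate: Callable[[str, float], float],
    target_accuracy: float,
    gpu_hours_budget: float,
    cost_per_trial: float,
    seed: int = 0,
) -> Tuple[Optional[str], float, List[Dict]]:
    """Random search over (encoder, learning rate) until the target or budget is hit."""
    rng = random.Random(seed)
    best_encoder, best_acc, history = None, float("-inf"), []
    spent = 0.0
    while spent + cost_per_trial <= gpu_hours_budget:  # stop before exceeding the budget
        encoder = rng.choice(library)                   # SSL or ImageNet pretrained encoder
        lr = 10 ** rng.uniform(-4, -2)                  # log-uniform learning rate
        acc = evaluate(encoder, lr)
        spent += cost_per_trial
        history.append({"encoder": encoder, "lr": lr, "accuracy": acc})
        if acc > best_acc:
            best_encoder, best_acc = encoder, acc
        if best_acc >= target_accuracy:
            break  # performance target reached; save the remaining budget
    return best_encoder, best_acc, history
```

In Ray Tune, the same space could be expressed with `tune.choice(library)` and `tune.loguniform(1e-4, 1e-2)`, letting one of its search algorithms or schedulers replace the random sampling loop.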
The framework is a foundation for auxiliary techniques that further improve image classification performance. One example is knowledge distillation, where a classifier acting as a teacher helps train a student classifier built on top of a lightweight encoder. The performance of the student classifier is on par with, or sometimes better than, that of the teacher, with the benefit of lower inference cost.
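The post does not specify its distillation objective. For reference, the standard Hinton-style formulation mixes two terms. The first is a KL-divergence term between temperature-softened teacher and student distributions. The second is the usual cross-entropy on the hard label. The temperature and mixing weight below are hypothetical defaults, and plain Python lists stand in for tensors.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
    temperature: float = 4.0,  # hypothetical default
    alpha: float = 0.7,        # hypothetical weight on the soft-target term
) -> float:
    """Soft-target KL(teacher || student) plus hard-label cross-entropy."""
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s) if pt > 0)
    # Scaling by T^2 keeps the soft term's gradient magnitude comparable across temperatures.
    soft = kl * temperature ** 2
    hard = -math.log(softmax(student_logits)[label])
    return alpha * soft + (1 - alpha) * hard
```

When the student matches the teacher exactly, the KL term vanishes and only the weighted cross-entropy remains. A higher temperature exposes more of the teacher's relative preferences among wrong classes.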
Multi-Stage Image Classification
Let’s see how SSL pretrained encoders can be used for Walmart catalog image classification. First, during pretraining, we adapt ResNet-based implementations to other neural nets. The pretraining of encoders is implemented in standard libraries, so that any neural net can be plugged into the process to learn to represent image features from the unlabeled catalog images. Then, during the training of downstream classifiers, we can pick an encoder from a stable set of SSL pretrained encoders, or from ImageNet pretrained encoders, and plug it into the training process. The framework supports both by design and minimizes any impedance mismatch between the two. The goal is a fair comparison: with everything equal except the starting encoder weights, we can compare the performance of SSL pretrained encoders with that of ImageNet pretrained ones.
To ensure that the encoders are agnostic to the SSL networks used to pretrain them, we created common building blocks. They consist of data preprocessing, data augmentation, an extensible encoder and classifier, and a standard training harness. For SimCLR, we built upon and adapted a PyTorch Lightning-based implementation. This required some fixes and enhancements to the PyTorch Lightning code. For example, we optimized the implementation of the SimCLR loss function, as well as the Gaussian Blur transformation logic shared by SimCLR and BYOL.
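The post does not show its optimized loss. For reference, here is a plain, unoptimized sketch of the NT-Xent loss defined in the SimCLR paper. The two views of each of N images form positive pairs, and the other 2N − 2 embeddings in the batch act as negatives. The temperature default is an assumption. A vectorized PyTorch version typically computes the full 2N × 2N similarity matrix with one matrix multiply and masks the diagonal instead of looping.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0
    return [x / norm for x in v]


def nt_xent(z1: Sequence[Sequence[float]], z2: Sequence[Sequence[float]], tau: float = 0.5) -> float:
    """NT-Xent loss for N positive pairs (z1[k], z2[k]); tau is the temperature."""
    z = [_normalize(v) for v in list(z1) + list(z2)]  # 2N unit vectors
    n = len(z1)
    total = 0.0
    for i in range(2 * n):
        j = i + n if i < n else i - n  # index of i's positive partner
        # Cosine similarity (dot product of unit vectors) scaled by 1/tau.
        sims = [sum(a * b for a, b in zip(z[i], z[k])) / tau for k in range(2 * n)]
        # The denominator excludes the self-similarity term k == i (log-sum-exp for stability).
        others = [s for k, s in enumerate(sims) if k != i]
        m = max(others)
        log_denom = m + math.log(sum(math.exp(s - m) for s in others))
        total += log_denom - sims[j]  # -log(exp(sim_ij) / sum_{k != i} exp(sim_ik))
    return total / (2 * n)
```

With correctly aligned views, the loss is lower than when the positives are swapped, which is exactly the signal that pulls the two views of the same image together.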
To make the platform work with all SSL networks, we need to standardize how encoders are configured and embedded. In particular, the interaction between the encoders, which represent images in embedding space, and the rest of the SSL learning network is isolated to a handful of well-defined methods. To illustrate how the encoders are encapsulated and how they interact with the image classifiers, let’s first look at the sequence diagram:
- Encoder initialized from ImageNet pretrained model
- Encoder pretrained with unlabeled Walmart catalog images using SSL
- Best performing encoder checkpointed
- Classifier initialized with SSL pretrained encoder
- Classifier trained and checkpointed
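The post does not name the handful of methods that isolate an encoder from the rest of the SSL network. The sketch below assumes a hypothetical four-method contract: `embedding_dim`, `embed`, `state_dict` and `load_state_dict`. A toy linear encoder stands in for a real network, so the checkpoint-then-reuse sequence from the diagram can be exercised end to end.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Sequence


class Encoder(ABC):
    """Hypothetical contract between an encoder and any SSL method or classifier."""

    @property
    @abstractmethod
    def embedding_dim(self) -> int:
        """Width of the embedding that downstream heads are built on."""

    @abstractmethod
    def embed(self, image: Sequence[float]) -> List[float]:
        """Map one (flattened) image to its embedding."""

    @abstractmethod
    def state_dict(self) -> Dict[str, List[float]]:
        """Weights to checkpoint."""

    @abstractmethod
    def load_state_dict(self, state: Dict[str, List[float]]) -> None:
        """Restore weights from a checkpoint."""


class ToyLinearEncoder(Encoder):
    """Stand-in encoder: a single linear map, just to exercise the contract."""

    def __init__(self, in_dim: int, out_dim: int):
        self._w = [[0.0] * in_dim for _ in range(out_dim)]

    @property
    def embedding_dim(self) -> int:
        return len(self._w)

    def embed(self, image: Sequence[float]) -> List[float]:
        return [sum(w * x for w, x in zip(row, image)) for row in self._w]

    def state_dict(self) -> Dict[str, List[float]]:
        return {f"row{i}": list(row) for i, row in enumerate(self._w)}

    def load_state_dict(self, state: Dict[str, List[float]]) -> None:
        self._w = [list(state[f"row{i}"]) for i in range(len(state))]
```

In this sketch, SSL methods and classifiers only ever call `embedding_dim` and `embed`. That restriction is what lets the same checkpoint move from pretraining into a classifier unchanged.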
Next, let’s take a look at the state transitions of image classifiers:
State of Classifiers
- Initialized with ImageNet pretrained models
- Initialized with SSL pretrained encoders
- Trained and checkpointed
- Restored from checkpoint
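The four states above can be modeled as a small state machine. The post lists the states but not the allowed transitions, so the transition table below is an assumption for illustration. A classifier starts from either ImageNet or SSL weights, gets trained and checkpointed, and can be restored and trained further.

```python
from enum import Enum, auto


class ClassifierState(Enum):
    INIT_IMAGENET = auto()  # initialized with ImageNet pretrained models
    INIT_SSL = auto()       # initialized with SSL pretrained encoders
    TRAINED = auto()        # trained and checkpointed
    RESTORED = auto()       # restored from checkpoint


# Hypothetical transition table; the allowed transitions are assumed, not from the post.
TRANSITIONS = {
    ClassifierState.INIT_IMAGENET: {ClassifierState.TRAINED},
    ClassifierState.INIT_SSL: {ClassifierState.TRAINED},
    ClassifierState.TRAINED: {ClassifierState.RESTORED},
    ClassifierState.RESTORED: {ClassifierState.TRAINED},
}


class ClassifierLifecycle:
    def __init__(self, start: ClassifierState):
        if start not in (ClassifierState.INIT_IMAGENET, ClassifierState.INIT_SSL):
            raise ValueError("a classifier must start from ImageNet or SSL weights")
        self.state = start
        self.history = [start]

    def advance(self, new_state: ClassifierState) -> None:
        """Move to `new_state`, rejecting transitions the table does not allow."""
        if new_state not in TRANSITIONS[self.state]:
            raise ValueError(f"illegal transition {self.state.name} -> {new_state.name}")
        self.state = new_state
        self.history.append(new_state)
```

Because both initial states feed into the same `TRAINED` state, everything downstream of initialization is shared. This mirrors the goal of comparing SSL and ImageNet starting weights with everything else held equal.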
Finally, here is a closer look at the anatomy of the classifier:
Anatomy of Classifier
- Encoder: output embedding
- Projector: output representation in hidden intermediate layer
- Predictor: output logits
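The three parts listed above compose in sequence: image to embedding, embedding to hidden representation, representation to logits. The minimal sketch below uses plain-Python linear layers. The ReLU between projector and predictor is an assumption, since the post names only the three parts and their outputs. The encoder is any callable, so an SSL pretrained or an ImageNet pretrained encoder can be dropped in.

```python
from typing import Callable, List

Vector = List[float]


def linear(weights: List[Vector], bias: Vector) -> Callable[[Vector], Vector]:
    """Return a function computing W x + b."""
    def apply(x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
    return apply


def relu(x: Vector) -> Vector:
    return [max(0.0, v) for v in x]


class Classifier:
    """Encoder -> projector -> predictor, mirroring the anatomy described above."""

    def __init__(self, encoder, projector, predictor):
        self.encoder = encoder      # pretrained (SSL or ImageNet); outputs the embedding
        self.projector = projector  # hidden intermediate layer; outputs a representation
        self.predictor = predictor  # final layer; outputs one logit per class

    def forward(self, image: Vector) -> Vector:
        embedding = self.encoder(image)
        representation = relu(self.projector(embedding))  # ReLU is an assumption
        return self.predictor(representation)
```

Keeping the encoder behind a single callable boundary is what allows the same projector and predictor code to be reused across every encoder in the model library.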
Validation and Business Impact
How do we measure the success of the SSL approach? When the rubber hits the road, it is quite simple: it comes down to a class of downstream tasks that benefit from SSL in terms of both performance and data efficiency. If SSL truly lives up to expectations, SSL encoders should produce better embeddings that capture the invariants of the catalog images. They should therefore require less adaptation for downstream tasks. If we have done the implementation right, the encoders are portable across SSL approaches. They could also outperform those produced by the original one-off implementations.
The first comparison is of classification performance between supervised and SSL classification methods. The SSL approach benefits a class of attribute classification tasks for two reasons. It generalizes image representations better, and it attends to individual image features instead of using labels as a proxy for image similarity, as supervised training does. SSL pretrained encoders outperform their supervised counterparts by up to 8% in top-1 accuracy.
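Top-1 accuracy, the metric used in this comparison, counts a prediction as correct only when the highest-scoring class equals the true label. A minimal sketch of the metric, with plain lists standing in for model outputs:

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the true label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need one non-empty logit vector per label")
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == y  # argmax of the row vs. true label
        for row, y in zip(logits, labels)
    )
    return correct / len(labels)
```

For a fair comparison, the metric should be computed on the same validation set for both the SSL pretrained and the supervised pretrained classifiers.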
The second benefit concerns the computing resources required for pretraining. The original SimCLR implementation from the research paper takes over 40 hours per epoch for 1.2 million training images. On the same hardware, we achieve a 7 to 15x reduction in time per epoch. This is due to several factors. First, we standardize on a training harness that allows gradients from smaller batches to be accumulated before each parameter update. Second, we use a smaller image resolution during pretraining, without compromising the performance of downstream classification. In a nutshell, the cost of computing resources is absorbed by a one-time pretraining of SSL encoders, which are then reused for a series of image classification tasks.
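The harness change described above can be illustrated with gradient accumulation. The sketch below is a stand-in, not the framework's harness. It uses a one-parameter least-squares model in plain Python, where gradients from several small micro-batches are summed and then applied in a single update. Each gradient is scaled by 1/len(micro_batches), so with equal-sized micro-batches the update matches one step on the full batch.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def mse_grad(w: float, batch: Batch) -> float:
    """d/dw of mean((w * x - y)^2) over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_step(w: float, micro_batches: List[Batch], lr: float) -> float:
    """One parameter update from gradients accumulated over several micro-batches."""
    grad = 0.0
    for mb in micro_batches:  # forward/backward on one small batch at a time
        grad += mse_grad(w, mb) / len(micro_batches)
    return w - lr * grad      # single update after accumulation
```

One caveat applies. This equivalence holds for losses that average independently over samples. SimCLR's contrastive loss draws its negatives from within the batch, so accumulating over small batches gives each sample fewer negatives. It is therefore not identical to training with one large batch. As for the speedup, taking the baseline as 40 hours, a 7 to 15x reduction corresponds to roughly 2.7 to 5.7 hours per epoch.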
Last but not least, when only a fraction of the labeled training data is available, the SSL approach tends to outperform the supervised one by wider margins. The following plot shows the drop in performance for an example use case when using 25% and 50% of the labeled training data, with the validation data held the same. It indicates that the SSL encoders are more data efficient than the ImageNet pretrained ones.
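Reduced-label experiments like this one subsample only the training set and leave the validation set untouched. The post does not say how its 25% and 50% subsets were drawn. The sketch below assumes stratified sampling with a fixed seed, so that class proportions in each subset stay roughly those of the full training set. The `(image_id, label)` record format is hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Example = Tuple[str, int]  # hypothetical (image_id, label) record


def stratified_fraction(train: Sequence[Example], fraction: float, seed: int = 0) -> List[Example]:
    """Keep `fraction` of each class; the validation set is never passed in here."""
    by_label: Dict[int, List[Example]] = defaultdict(list)
    for ex in train:
        by_label[ex[1]].append(ex)
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    subset: List[Example] = []
    for label in sorted(by_label):
        items = by_label[label]
        k = max(1, round(len(items) * fraction))  # keep at least one example per class
        subset.extend(rng.sample(items, k))
    return subset
```

Calling this with `fraction=0.25` and `fraction=0.5` on the same training set, while evaluating every model on the same validation set, isolates the effect of label volume.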
We discussed why pretrained SSL models make it possible to learn generalized representations of domain-specific images from large amounts of unlabeled data. We shared how pretrained SSL encoders are integrated into a multi-stage image classification framework. We also showed examples of a class of downstream tasks that benefit from SSL through both better performance and higher data efficiency. Finally, we discussed how to engineer the framework components to manage complexity.
Walmart has invested in supervised learning over the years, training best-in-class supervised models on labeled Walmart catalog images. We believe that the next wave of computer vision innovation will come from pretraining best-in-class SSL models on unlabeled Walmart catalog images. Realizing the full benefits of SSL pretrained models requires two key ingredients: a standardized library, and encoders that are reusable across a whole host of image classification tasks. A unified framework makes it possible to pretrain a library of SSL models, which are a must to:
- Scale and generalize representations of large amounts of unlabeled, domain-specific data (Walmart images)
- Enable a class of image classification tasks with better performance
Ultimately, a good framework is judged by whether it facilitates these goals:
- Keep things simple
- Keep costs down
- Standardize and reuse
This work is an ongoing joint project between the Catalog and Search teams from Walmart Global Tech. Special thanks go to Alessandro Magnani for his constant support and Brian Seaman for his guidance.
Application of Semi-Supervised Neural Net to eCommerce Image Classification was originally published in Walmart Global Tech Blog on Medium.