Structured Gradient Descent
A tale of a scientist
Shirley wanted to become a scientist. She wanted to find out about the jumping behavior of tree crickets, which are insects of the order Orthoptera. Specifically, she aimed to test the hypothesis that tree crickets understand human-issued commands ordering them “to jump”. She caught a bunch of crickets and started testing her idea by telling one group of crickets to jump while telling a control group nothing. Most crickets in both groups jumped, merrily and high. To her, these results weakly suggested that crickets understood the order to jump and acted accordingly. Then, keeping the same crickets in both conditions, she mercilessly cut their legs off. Again, she issued the command to jump to one group and said nothing to the control group. None of the crickets jumped. What did Shirley the scientist conclude? “Crickets stop understanding human language after having their legs chopped off.”
It’s not her fault: The parallel with how deep neural networks learn
Perhaps you’ve heard the story above in a course introducing the scientific method.
This tale is interesting because of its parallels with current deep neural networks. A network is analogous to Shirley, its parameters stand for what she learns, and her crickets stand for the training data. Given the training data, no one has any control over what she learns. When she concluded that “crickets can’t interpret English after having their legs cut off”, her interpretation was not “wrong” per se. Rather, this conclusion is a simple consequence of how she was designed and of coincidences in the training data (i.e., the crickets did stop jumping after their legs were chopped off, despite Shirley’s instructions to jump).
With deep neural networks, you can train a model to classify images in CIFAR-10 or FashionMNIST or, more generally, to generate such images. However, what sort of embeddings it learns and how it learns them, among other aspects of learning, are out of our control. As exemplified in Stammer et al. (2021), consider a ResNet that learns to classify digits in a ColorMNIST dataset: if, by chance, most nines appeared in purple, the ResNet would classify purple digits in the test set as 9, even when they have a different shape. In such a case, the model has learned that color defines the identity of a digit, since we have no control over which artifacts the ResNet picks up.
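The shortcut described above can be reproduced in miniature. The sketch below is an assumption-laden toy, not the ColorMNIST setup of Stammer et al. (2021): each “digit” is reduced to two hypothetical numeric features, a noisy but genuine shape cue and a color cue that, by chance, matches the label in the training set but is random at test time. A logistic-regression model trained with plain SGD stands in for the ResNet and is free to use either cue; the helper names and hyperparameters are illustrative choices.

```python
import numpy as np

def make_toy_colormnist(n, rng, color_tracks_label):
    """Hypothetical two-feature stand-in for ColorMNIST (not the real dataset).

    y = 1 means "this digit is a nine". Column 0 is a noisy but genuine
    shape cue; column 1 is a color cue (+1 roughly means "purple").
    """
    y = rng.integers(0, 2, size=n)
    sign = 2.0 * y - 1.0
    shape = sign + rng.normal(0.0, 1.0, size=n)          # informative, but noisy
    if color_tracks_label:
        color = sign + rng.normal(0.0, 0.1, size=n)      # by chance, nines are purple
    else:
        color = rng.choice([-1.0, 1.0], size=n) + rng.normal(0.0, 0.1, size=n)
    return np.column_stack([shape, color]), y

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train_logreg_sgd(X, y, lr=0.1, epochs=20, seed=0):
    """Plain per-example SGD on the logistic (cross-entropy) loss."""
    rng = np.random.default_rng(seed)
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(epochs):
        for i in rng.permutation(len(y)):
            g = sigmoid(X[i] @ w + b) - y[i]   # d(loss)/d(logit) for one example
            w -= lr * g * X[i]
            b -= lr * g
    return w, b

def accuracy(X, y, w, b):
    return float(np.mean((X @ w + b > 0) == (y == 1)))

rng = np.random.default_rng(42)
X_train, y_train = make_toy_colormnist(1000, rng, color_tracks_label=True)
X_test, y_test = make_toy_colormnist(2000, rng, color_tracks_label=False)

w, b = train_logreg_sgd(X_train, y_train)
train_acc = accuracy(X_train, y_train, w, b)
test_acc = accuracy(X_test, y_test, w, b)
shape_only_acc = float(np.mean((X_test[:, 0] > 0) == (y_test == 1)))  # rule: shape > 0

print("weights (shape, color):", np.round(w, 2))
print(f"train acc {train_acc:.2f}, test acc {test_acc:.2f}, "
      f"shape-only rule on test {shape_only_acc:.2f}")
```

Because color separates the training set almost perfectly, SGD puts far more weight on color than on shape. Once color stops tracking the label, test accuracy drops toward chance, below what the simple rule “shape > 0” achieves on its own.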
The main engine driving this learning is stochastic gradient descent (SGD), and it is partly due to its stochasticity that deep neural networks are black boxes.
The stochasticity of gradient descent derives mainly from the order in which training examples arrive and from what those examples contain. We thus have no control over what the model learns. To regain some of that control, the AI research community has proposed several approaches, such as learning disentangled representations of the data, which can be obtained through more intricate labeling of the dataset. Consider the CLEVR Diagnostic Dataset, in which every object in every training image is exquisitely labeled with its (x, y, z) coordinates, shape, color, size, material, and positional relations, among other attributes, and each image comes with questions, answers, and the functional programs that answer each question. GQA is a counterpart of such a dataset built on real-life scenes.
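To make the point about data order concrete, here is a minimal sketch, assuming a toy logistic regression on synthetic data (nothing here comes from a real benchmark). The data and the zero initialization are identical across runs; only the order in which SGD visits the examples changes. The function name `sgd_logreg` and all hyperparameters are illustrative choices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sgd_logreg(X, y, order_seed, lr=0.5, epochs=5):
    """Per-example SGD on logistic loss; the visiting order is the only randomness."""
    order_rng = np.random.default_rng(order_seed)
    w = np.zeros(X.shape[1])                 # same initialization in every run
    for _ in range(epochs):
        for i in order_rng.permutation(len(y)):
            w -= lr * (sigmoid(X[i] @ w) - y[i]) * X[i]
    return w

# Fixed synthetic data set (hypothetical): 200 points, 3 features, noisy labels.
data_rng = np.random.default_rng(0)
X = data_rng.normal(size=(200, 3))
true_w = np.array([2.0, -1.0, 0.5])
y = (X @ true_w + data_rng.normal(0.0, 1.0, size=200) > 0).astype(float)

w_a = sgd_logreg(X, y, order_seed=1)
w_b = sgd_logreg(X, y, order_seed=2)         # same data, same init, different order
w_a_again = sgd_logreg(X, y, order_seed=1)   # same order as w_a

print("order seed 1:      ", np.round(w_a, 3))
print("order seed 2:      ", np.round(w_b, 3))
print("order seed 1 again:", np.round(w_a_again, 3))
```

Runs with the same order seed reproduce each other exactly, while a different order ends at different weights. On this convex toy the gap is small; in a deep network, whose loss is non-convex, the same effect can steer training toward noticeably different solutions.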
Structured gradient descent
Given these precedents, I would like to describe my wish to see, in the future, a kind of “structured” gradient descent (which, in English, shares the acronym SGD with stochastic gradient descent).
Structured gradient descent, as its name suggests, would allow us to train a model while controlling what kind of representation it obtains from the supplied data. In terms of the tale above, it would be akin to explicitly telling Shirley that crickets can’t understand human language, and that her line of reasoning has to start from how the motility of their legs gives tree crickets the ability to jump. From there, explaining the crickets’ jumping behavior is up to her.
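How such an algorithm would work is left open. As a very loose sketch of the idea of telling the learner where its reasoning may not start, assume a toy two-feature stand-in for ColorMNIST (a noisy shape cue, plus a color cue that happens to track the label during training only) and encode the prior knowledge “color is irrelevant” as an extra penalty on the color weight. The function `train`, its `forbidden` and `lam` parameters, and every hyperparameter are hypothetical; this is ordinary regularized SGD, not structured gradient descent itself.

```python
import numpy as np

def make_toy_colormnist(n, rng, color_tracks_label):
    """Hypothetical stand-in for ColorMNIST: y=1 means "nine"; columns are (shape, color)."""
    y = rng.integers(0, 2, size=n)
    sign = 2.0 * y - 1.0
    shape = sign + rng.normal(0.0, 1.0, size=n)          # genuine but noisy cue
    if color_tracks_label:
        color = sign + rng.normal(0.0, 0.1, size=n)      # spurious: nines happen to be purple
    else:
        color = rng.choice([-1.0, 1.0], size=n) + rng.normal(0.0, 0.1, size=n)
    return np.column_stack([shape, color]), y

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def train(X, y, forbidden=(), lam=0.0, lr=0.1, epochs=20, seed=0):
    """Per-example SGD on: logistic loss + lam * sum of w[j]**2 over j in `forbidden`.

    `forbidden` (hypothetical) lists features the model is told not to rely on.
    For this linear model, w[j] equals d(logit)/d(x_j), so the penalty also
    limits how sensitive the prediction is to those inputs.
    """
    mask = np.zeros(X.shape[1])
    for j in forbidden:
        mask[j] = 1.0
    rng = np.random.default_rng(seed)
    w, b = np.zeros(X.shape[1]), 0.0
    for _ in range(epochs):
        for i in rng.permutation(len(y)):
            g = sigmoid(X[i] @ w + b) - y[i]             # d(log loss)/d(logit)
            w -= lr * (g * X[i] + 2.0 * lam * mask * w)  # data term + prior-knowledge term
            b -= lr * g
    return w, b

def accuracy(X, y, w, b):
    return float(np.mean((X @ w + b > 0) == (y == 1)))

rng = np.random.default_rng(42)
X_train, y_train = make_toy_colormnist(1000, rng, color_tracks_label=True)
X_test, y_test = make_toy_colormnist(2000, rng, color_tracks_label=False)

w_plain, b_plain = train(X_train, y_train)                        # no guidance
w_told, b_told = train(X_train, y_train, forbidden=[1], lam=1.0)  # "color is irrelevant"

plain_acc = accuracy(X_test, y_test, w_plain, b_plain)
told_acc = accuracy(X_test, y_test, w_told, b_told)
print("unguided weights (shape, color):", np.round(w_plain, 2), f"test acc {plain_acc:.2f}")
print("guided weights   (shape, color):", np.round(w_told, 2), f"test acc {told_acc:.2f}")
```

With the penalty, the model leans on the shape cue and keeps most of its accuracy when color is shuffled at test time, whereas the unguided run does not. Here the prior knowledge is a single feature index; expressing it for a deep network is a much harder problem.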
The open questions that follow include:
- How would structured gradient descent materialize in a learning algorithm?
- Will a grammar, prior knowledge in the form of logic, or external ontologies play a pivotal role in the design of such an algorithm?
- Will structured gradient descent allow a model to learn a disentangled representation, along with benefits such as lower demands on computational power, dataset size and model size?
- What would structured gradient descent need in order to run? More intricately labeled data?
- Most importantly, would more intricately labeled data compensate for a scarcity of data?