[link]
In 2018, a group including many of the authors of this updated paper argued for a theory of deep neural network optimization that they called the “Lottery Ticket Hypothesis” (LTH). It framed itself as a resolution of an otherwise confusing apparent contradiction: you can prune or compress trained networks down to a small percentage of their trained weights without loss of performance, but if you try to train a comparably small network (comparable to the post-training pruned network) from initialization, you can’t achieve similar performance. They showed that, at least for some set of tasks, you could in fact train a smaller network to equivalent performance, but only if you kept the same connectivity pattern as in the pruned network, and if you re-initialized the remaining weights to the same values they had at the start of the initial training. These lucky initialization patterns are the lottery tickets of the eponymous hypothesis: small subnetworks of well-initialized weights that are set up to train well.

This paper assesses whether and under what conditions the LTH holds on larger problems, and does a bit of a meta-analysis over alternate theories in this space. One such alternate theory, from Liu et al., proposes that there is in fact no value in re-initializing to the specific initial values, and that you can get away with random initialization as long as you keep the connectivity pattern of the pruned network. The “At Scale” paper compares the two methods over a wide range of pruning percentages, and convincingly shows that while random initialization with the same connectivity can perform well with up to 80% of the weights removed, its performance drops beyond that point, whereas the performance of the “winning ticket” approach remains comparable with full network training with up to 99% of the weights pruned.
This seems to provide support for the theory that there is value in re-initializing the weights to how they were, especially when you prune to very small subnetworks.

https://i.imgur.com/9O2aAIT.png

The core of the current paper focuses on a difficulty in the original LTH paper: the procedure of iterative pruning (train, then prune some weights, then train again) wasn’t able to reliably find “winning tickets” for deep networks of the type needed to solve ImageNet or CIFAR. To be precise, in these networks, re-initializing pruned networks to their original values did no better than initializing them randomly. To get these winning tickets to perform well, the original authors had to use a somewhat arcane process called “warmup”, in which the learning rate starts very small and is scaled up. Neither paper gave a clear intuition as to why this would be the case, but the updated paper found that they could avoid the need for warmup if, instead of re-initializing weights to their original values, they set them to the values they had a small number of iterations into training. They justify this by showing that performance under this new initialization is related to something they call “stability to pruning,” which measures how close the weights learned by the pruned subnetwork after re-initialization are to the weights learned in full-model training. And, while the weights of deeper networks are unstable (by this metric) when first initialized, they become stable fairly early on.
I was a little confused by this framing, since it seemed fairly tautological to me: you’re using “how close do the weights stay to the original weights” as a way to explain “when can you recover performance comparable to the original performance.” This was framed as a mechanistic explanation of why you see a lottery ticket phenomenon to some extent, but only if you do a “late reset” to several iterations after initialization; it didn’t feel quite mechanistically satisfying to me.

That said, I think this is overall an intriguing idea, and I’d love to see more papers discuss it. In particular, I’d love to read more qualitative analysis of whether there are any notable patterns shared by “winning tickets”.
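To make the procedure in this summary concrete, here is a minimal sketch of iterative magnitude pruning with a "late reset". It is not the authors' code: the `train` function is a toy stand-in (gradient descent on a made-up quadratic objective), and the names `find_ticket`, `prune_smallest` and `rewind_step` are hypothetical. Setting `rewind_step=0` corresponds to resetting to the original initialization, as in the original LTH paper.

```python
import random


def train(weights, mask, targets, steps, lr=0.1):
    """Toy stand-in for training: minimise sum((w - target)^2) over unmasked weights."""
    w = [wi * mi for wi, mi in zip(weights, mask)]
    for _ in range(steps):
        # Gradient of (w - t)^2 is 2 * (w - t); pruned weights are held at zero.
        w = [(wi - lr * 2.0 * (wi - ti)) * mi for wi, ti, mi in zip(w, targets, mask)]
    return w


def prune_smallest(weights, mask, prune_fraction):
    """Zero out the given fraction of surviving weights with the smallest magnitude."""
    alive = [i for i, m in enumerate(mask) if m == 1]
    n_prune = int(len(alive) * prune_fraction)
    new_mask = list(mask)
    for i in sorted(alive, key=lambda i: abs(weights[i]))[:n_prune]:
        new_mask[i] = 0
    return new_mask


def find_ticket(init_weights, targets, rounds, prune_fraction, rewind_step, steps):
    """Iteratively train, prune, and reset the survivors to their rewind values."""
    mask = [1] * len(init_weights)
    # rewind_step = 0 resets to the original initialisation;
    # rewind_step > 0 is the "late reset" to weights from early in training.
    rewind_weights = train(init_weights, mask, targets, rewind_step)
    for _ in range(rounds):
        trained = train(rewind_weights, mask, targets, steps)
        mask = prune_smallest(trained, mask, prune_fraction)
    ticket = [w * m for w, m in zip(rewind_weights, mask)]
    return mask, ticket


rng = random.Random(0)
init = [rng.uniform(-1, 1) for _ in range(10)]
targets = [rng.uniform(-1, 1) for _ in range(10)]
mask, ticket = find_ticket(init, targets, rounds=3, prune_fraction=0.2,
                           rewind_step=2, steps=50)
```

The only difference between the two variants discussed above is the value the surviving weights are reset to before retraining; the pruning loop itself is unchanged.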
[link]
The Transformer, a somewhat confusingly named model structure that uses attention mechanisms to aggregate information for understanding or generating data, has been having a real moment in the last year or so, with GPT-2 being only the most well-publicized tip of that iceberg. It has lots of advantages: the obvious attraction of strong performance, as well as the ability to train in parallel across parts of a sequence, which RNNs can’t do because of the need to build up and maintain state. However, a problematic fact about the Transformer approach is how it scales to long sequences of input data. Because attention is based on performing pairwise queries between each point in the data sequence and each other point, to allow for aggregation of information from places throughout the sequence, it scales as O(N^2): each of the N elements in the sequence has to be compared against all N elements. This makes it resource-intensive to run Transformer models on long sequences.

The Sparse Transformer design proposed in this OpenAI paper tries to cut down on this resource cost by loosening the requirement that, in every attention operation, each element directly pulls information from every other element. In this new system, each point doesn’t get information about every other point in a single operation; instead, chaining two such limited operations in a row provides that global visibility. This is done in one of two ways.

(1) The first, called the “strided” version, performs two operations in a row: one masked attention that only looks at the last k timesteps (for example, the last 7), and then a second masked attention that only looks at every kth timestep.
So, at the end of the second operation, each point has pulled information from the checkpoints 7, 14, 21 (and so on) steps back, and each of these has pulled information from the window between it and its preceding checkpoint, giving visibility into a full global receptive field over the course of two operations.

(2) The second, called the “fixed” version, uses a similar sort of logic, but instead of having the “window accumulation points” be defined relative to the point doing the querying, you have fixed accumulation points responsible for gathering information from the windows around them. So, using the example given in the paper, if you imagine a window of size 128, and 8 “accumulation points per window”, then the points at indices 120-128 (say) would have visibility into points 0-128. That represents the first operation; in the second one, all other points in the sequence pull in information by querying the designated accumulation points of all the windows that aren’t masked for them.

The paper argues that, between these two systems, the strided system should work best when the data has some inherent periodicity, but I don’t know that I particularly follow that intuition. I have some sense that the important distinction here is that in the strided case you have many points of accumulation, each with not much context, whereas in the fixed case you have very few accumulation points, each with a larger window, but I don’t know exactly what performance differences I’d expect these mechanical differences to predict.

This whole project of reducing access to global information seems initially a little counterintuitive, since the whole point of a Transformer design, in some sense, was its ability to gain global context in a single layer, as opposed to a convnet needing multiple layers to build up its receptive field, or an RNN needing to maintain state throughout the sequence.
However, I think this paper makes the most sense as a way of interpolating the space between something like a CNN and a full attention design, for the sake of efficiency. With a CNN, you have a fixed kernel, and so as your sequence gets longer, you need to add more and more layers in order for any given point to be able to incorporate into its representation information from the far side of the sequence. With an RNN, as your sequence gets longer, you pay the cost of needing to backpropagate through state farther. So, by contrast, even though the Sparse Transformer seems to be giving up its signal advantage, it’s really just trading one constant number of steps to achieve global visibility (1) for another (2 in this paper, though conceptually it could be more), in a way that’s still constant with respect to the length of the data. And, in exchange for this trade, they get very sparse, heavily masked operations, where many of the multiplications involved in these big query calculations can be skipped, making for faster computation.

On the datasets tried, the Sparse Transformer increased speed and, in I think all cases, also improved performance. The performance gain by itself isn’t that dramatic, but given that you’d expect, if anything, worse performance as a result of limiting model structure, it’s encouraging and interesting that performance either stays about the same or possibly improves.
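The two masking schemes described in this summary can be sketched as boolean masks, which also makes it easy to check the "global visibility within two operations" claim on a small example. This is my own illustrative reading of the description, assuming causal (look-back only) attention; the function names and the `n_summary` parameter are hypothetical, and plain Python lists are used to keep the sketch dependency-free.

```python
def strided_masks(n, stride):
    """mask[i][j] is True when position i may attend to position j."""
    # Operation 1: only the most recent `stride` positions (including i itself).
    local = [[j <= i and i - j < stride for j in range(n)] for i in range(n)]
    # Operation 2: only every `stride`-th position, counting back from i.
    strided = [[j <= i and (i - j) % stride == 0 for j in range(n)] for i in range(n)]
    return local, strided


def fixed_masks(n, window, n_summary):
    # Operation 1: earlier positions inside the same fixed window.
    block = [[j <= i and j // window == i // window for j in range(n)] for i in range(n)]
    # Operation 2: the last `n_summary` positions of every window (the accumulation points).
    summary = [[j <= i and j % window >= window - n_summary for j in range(n)]
               for i in range(n)]
    return block, summary


def covered_within_two_steps(first, second):
    """True at [i][j] if i sees j directly via `first`, or via `second` then `first`."""
    n = len(first)
    return [[first[i][j] or any(second[i][k] and first[k][j] for k in range(n))
             for j in range(n)] for i in range(n)]


def count_pairs(mask):
    """Number of (query, key) pairs that a mask keeps."""
    return sum(sum(row) for row in mask)


n, stride = 16, 4
causal = [[j <= i for j in range(n)] for i in range(n)]  # full causal attention
local, strided = strided_masks(n, stride)
block, summary = fixed_masks(n, window=stride, n_summary=1)
```

Comparing `count_pairs(local) + count_pairs(strided)` against `count_pairs(causal)` for growing `n` gives a feel for how much of the quadratic cost the sparse patterns avoid.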
[link]
This paper focuses on taking advances from the (mostly heuristic, practical) world of data augmentation for supervised learning, and applying them to the unsupervised setting, as a way of inducing better performance in a semi-supervised environment (with many unlabeled points, and few labeled ones).

Data augmentation has been a mostly behind-the-scenes implementation detail in modern deep learning: minor modifications like shifting an image by a few pixels, rotating it slightly, or flipping it horizontally, to generate additional pseudo-examples for the model to train on. The core idea motivating such approaches is that these modifications should not change the class label in a platonic, ground-truth sense, which allows us to use the perturbed images as training examples with the same class label as the original image. In addition to just giving the network generically more data, this approach communicates to the network the specific kinds of modifications that can be made to an image while leaving it in the same class.

The Unsupervised Data Augmentation (UDA) tactic from this paper notices two things:

(1) Within the sphere of supervised learning, there has been dataset-specific innovation in generating augmented data that will be particularly helpful for a given dataset. In language modeling, an example of this is passing a sentence into another language and back again through two well-trained translation networks, and using the resulting sentence as another example of the same class. For ImageNet, there’s an approach called AutoAugment that uses reinforcement learning on a validation set to learn a policy of image operations (rotate, shear, change color) that increases validation accuracy. [I am a little confused about this as an approach, since I worry about essentially overfitting to the validation set.
That said, I don’t have time to delve specifically into the AutoAugment paper today, so I’ll just leave this here as a caveat.]

(2) Within semi-supervised learning, there’s a growing tendency to use a consistency loss as a way of making use of unlabeled data. The basic idea of a consistency loss is that, even if you don’t know the class of a given datapoint, if you modify it in some small way, you can be confident that the model’s prediction should be consistent between the point and its perturbation. Often, systems like this have been designed using simple Gaussian noise on top of the original unlabeled images.

The key proposal of this paper is to replace this simpler perturbation procedure with the augmentation approaches being iterated on in supervised learning, since the goal of both, to modify inputs so as to capture ways in which they might differ while still notionally belonging to the same class, is nearly identical.

On top of this core idea, the UDA paper also proposes an additional clever training tactic: in cases where you have many unlabeled examples and few labeled ones, you may need a large model to capture the information from the unlabeled examples, but this may result in overfitting on the labeled ones. To avoid this, they use an approach called “Training Signal Annealing,” where at each point in training they remove from the loss calculation any labeled examples that the model is already particularly confident about, meaning those where the predicted probability of the true class is above some threshold eta. As training goes on, the threshold is raised, so the network is gradually allowed to see more of the training signal. In this kind of framework, the model can’t overfit as easily, because once it starts getting the right answer on supervised examples, they drop out of the loss calculation.
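Both ingredients described above fit in a few lines, so here is a rough sketch of each. It is an illustration under assumptions rather than the paper's implementation: the consistency term is written as a KL divergence between predictions on an unlabeled input and on its augmented version, the threshold follows a simple linear schedule from 1/K to 1 (K being the number of classes), and the function names are made up for this example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-probabilities from raw logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def consistency_loss(logits_original, logits_augmented):
    """KL(p_original || p_augmented); needs no label for the input."""
    log_p = log_softmax(logits_original)
    log_q = log_softmax(logits_augmented)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def tsa_threshold(step, total_steps, num_classes):
    """Assumed linear schedule: eta rises from 1/K at step 0 to 1 at the end."""
    start = 1.0 / num_classes
    return start + (step / total_steps) * (1.0 - start)


def tsa_supervised_loss(batch_logits, labels, threshold):
    """Cross-entropy over labeled examples, skipping the ones already above eta."""
    kept = []
    for logits, label in zip(batch_logits, labels):
        log_p_true = log_softmax(logits)[label]
        if math.exp(log_p_true) <= threshold:
            kept.append(-log_p_true)
    # Average over the examples that survive the filter (an assumption of this sketch).
    return sum(kept) / len(kept) if kept else 0.0
```

In a training loop, the total loss would be the filtered supervised term on the labeled batch plus a weighted consistency term on the unlabeled batch; the weighting is left out here.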
In terms of empirical results, the authors find that with UDA they’re able to improve on many semi-supervised benchmarks using quite small numbers of labeled examples. At one point, they use as a baseline a BERT model that was fine-tuned in an unsupervised way prior to its semi-supervised training, and show that their augmentation method adds value even on top of what the unsupervised pre-training provides.
[link]
The Magenta group at Google is a consistent source of really interesting problems for machine learning to solve, in the vein of creative generation of art and music, as well as of mathematically creative ways to solve those problems. In this paper, they tackle a new problem with some interesting model-structural implications: generating Bach chorales composed of polyphonic multi-instrument arrangements. On one level, this is similar to music generation problems that have been studied before, in that generating a musically coherent sequence requires learning both local and larger-scale structure between time steps in the music sequence. However, an additional element here is that multiple instruments’ notes depend on one another at a given time step, so, in addition to generating time steps conditional on one another, you ideally want to learn how to model certain notes in a given harmony conditional on the other notes already present there.

Understanding the specifics of the approach was one of those scenarios where the mathematical arguments were somewhat opaque, but the actual mechanical description of the model gave a lot of clarity. I find this is frequently the case with machine learning, where there’s a strange set of dual incentives: the engineering impulse towards designing effective systems, and the academic need to connect the approach to a more theoretical mathematical foundation.

The approach taken here has a lot in common with the autoregressive model structures used in PixelCNN or WaveNet. These are all based, theoretically speaking, on the autoregressive factorization of joint probability distributions: they can be sampled from by sampling first from the prior over the first variable (or pixel, or wave value), then the second conditional on the first, then the third conditional on the first two, and so on.
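Written out, the factorization being described is just the chain rule of probability. The notation below ($x_1, \dots, x_n$ for the variables in some chosen order) is mine, not the paper's:

$$p(x_1, \dots, x_n) = \prod_{i=1}^{n} p(x_i \mid x_1, \dots, x_{i-1})$$

Sampling then proceeds left to right: draw $x_1$ from $p(x_1)$, then $x_2$ from $p(x_2 \mid x_1)$, and so on. The identity holds for any ordering of the variables, which is what leaves room for a model that does not commit to a single ordering.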
In practice, autoregressive models don’t necessarily condition on the *entire* preceding part of the input when generating a conditional distribution for a new point (for example, because they use a convolutional structure whose receptive field isn’t big enough to reach back through the entire previous sequence), but they are based on that idea.

A unique aspect of this model is that, instead of defining one specific conditional dependence relationship (where value J is conditioned on values J-5 through J-1, or some such), the authors argue that they learn conditional relationships over any possible autoregressive ordering of both time steps and instrument IDs. This is a bit of a strange idea that, like I mentioned, is made simpler by going through the mechanics. The model works in a way strikingly similar to recent large-scale language modeling: for each sample, it masks some random subset of the tokens, and asks the model to predict the masked values given the unmasked ones. In this case, an interesting nuance is that the values to be masked are randomly sampled across both time step and instrument, such that in some cases you’ll have a prior time step but no other instruments at your time step, or other instruments at your time step but no prior time steps to work from, and so on. The model needs to flexibly use various kinds of local context to predict the notes that are masked. (As an aside, in addition to the actual values, the network is given the actual 0/1 mask, so it can better distinguish between “0, no information” and “0 because in the actual data sample there wasn’t a pitch here”.) The paper refers to these unmasked points as “context points”.

An interesting capacity that this gives the model, and which the authors use as their sampling technique, is to create songs that are hybrids of existing chorales by randomly keeping some chunks and dropping out others, and using the model to interpolate through the missing bits.
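The masking mechanics described above can be sketched as follows. This is a simplified illustration, not the paper's data pipeline: it stores one integer pitch per (instrument, time step) cell rather than a one-hot piano roll, and the function names and the toy chorale are invented for the example.

```python
import random


def sample_context_mask(num_instruments, num_steps, keep_prob, rng):
    """1 = the note is visible context, 0 = the note is hidden and must be predicted."""
    return [[1 if rng.random() < keep_prob else 0 for _ in range(num_steps)]
            for _ in range(num_instruments)]


def make_example(pitches, mask):
    """Build the (masked values, mask) input pair and the list of prediction targets."""
    # Hidden cells are zeroed out; the mask is passed alongside so that a zero
    # meaning "no information" stays distinguishable from real data.
    visible = [[p * m for p, m in zip(p_row, m_row)]
               for p_row, m_row in zip(pitches, mask)]
    targets = [(i, t, pitches[i][t])
               for i, row in enumerate(mask)
               for t, m in enumerate(row) if m == 0]
    return (visible, mask), targets


rng = random.Random(0)
# Hypothetical 4-voice, 8-step excerpt with made-up MIDI pitch numbers.
chorale = [[60 + voice * 4 + step % 3 for step in range(8)] for voice in range(4)]
mask = sample_context_mask(num_instruments=4, num_steps=8, keep_prob=0.5, rng=rng)
(visible, mask_channel), targets = make_example(chorale, mask)
```

Because the mask is sampled independently per cell, a hidden note may have neighbors in time, neighbors across instruments, both, or neither, which is the variety of contexts the summary describes.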
[link]
Attention mechanisms are a common subcomponent within language models, initially as a part of recurrent models, and more recently as their own form of aggregating information over sequences, independent of the recurrence structure. Attention works by taking as input some sequence of representations, in the most typical case embedded representations of words in a sentence, and learning a distribution of weights over those representations, which allows the network to aggregate them, typically by taking a weighted sum. One effect of using an attention mechanism is that, for each instance being predicted, the network produces this weight distribution over inputs, which intuitively feels like the network demonstrating which input words were most important in reaching its decision. As a result, uses of attention have often been accompanied by figures showing attention distributions for particular examples, implicitly using them as a form of interpretability or model explanation.

This paper has the goal of understanding whether attention distributions can be seen as a valid form of feature importance, and takes the position that they shouldn’t be. At a high level, I think the paper makes some valid criticisms, but ultimately I didn’t find the evidence it presented quite as strong as I would have liked. The paper performs two primary analyses of the attention distributions produced by a trained LSTM model:

(1) It calculates the level of correlation between the importance that would be implied by attention weights and the importance as calculated using more canonical gradient-based methods (generally things in the shape of “which words contributed the most towards the prediction being what it was”). It finds correlation values that vary across random seeds, but are generally centered around 0.5.
The paper frames this as a negative result, implying that, if attention were a valid form of importance, the correlation with existing metrics would be higher. I definitely follow the intuition that you would expect there to be a significant and positive correlation between methods in this class, but it’s unclear to me what a priori reasoning sets the threshold for “significant” in a way that makes 0.5 fall below it. It just feels like one of those cases where I could imagine someone showing the same plots and coming to a different interpretation, and it’s not clear to me what criteria support one threshold of magnitude over another.

(2) It measures how much it can permute the weights of an attention distribution and have the prediction made by the network not change in a meaningful way. It does this both by random tests, and also by measuring the maximum adversarial perturbation: the farthest-away distribution (in terms of Jensen-Shannon distance) that still produces a prediction within some epsilon of the original prediction.

There are a few concerns I have about this as an analysis. First off, it makes a bit of an assumption that attention can only be a valid form of explanation if it’s a causal mechanism within the model. You could imagine that attention distributions still give you information about the internal state of the model, even if they are just reporting that state rather than directly influencing it. Secondly, it seems possible to me that you could get a relatively high Jensen-Shannon distance from an initial distribution just by permuting the indices of the low-value weights, and shifting distributional weight between them, in a way that doesn’t fundamentally change what the network is primarily attending to.
Even if this is not the case in this paper, I’d love to see an example or some kind of quantitative measure showing that the Jensen-Shannon distances they demonstrate require a substantive change in weight priorities, rather than a trivial one.

Another general critique is that the experiments in this paper only focus on attention within an LSTM structure, where the embedding associated with each word isn’t strictly an embedding of that specific word, but also contains a lot of information about the words before and after it, because of the nature of a BiLSTM. So, there is some specificity in the embedding corresponding to just that word, but a lot less than in a pure attention model, like some being used in NLP these days, where you’re learning an attention distribution over the raw, non-LSTM-ed representations. In the LSTM case, it makes sense that attention would be blurry, and not map exactly to our notions of which words are more important, since the word-level representations are themselves already aggregations. I think it’s totally fair to only focus on the LSTM case, but I would prefer that the paper scoped its claims in better accordance with its empirical results.

I feel a bit bad: overall, I really approve of papers like this being done to put a critical empirical frame on ML’s tendency to get conceptually ahead of itself. And I do think that the evidentiary standard for “prove that X metric isn’t a form of interpretability” shouldn’t be that high, because on priors, I would expect most things not to be. I think that they may well be right in their assessment; I would just like a more surefooted set of analyses and interpretation behind it.
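The kind of quantitative check asked for above is cheap to set up. The sketch below computes the Jensen-Shannon divergence between an attention distribution and two perturbed versions: one that only reshuffles the low-value tail, and one that swaps the two largest weights. The attention vector is made up for illustration, and `permute_tail` and `swap_top_two` are hypothetical helpers, not anything from the paper.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) in nats; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log); its square root is the JS distance."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)


def permute_tail(weights, top_k):
    """Keep the top_k largest weights in place; reverse the rest among their slots."""
    order = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
    tail = order[top_k:]
    permuted = list(weights)
    for src, dst in zip(tail, reversed(tail)):
        permuted[dst] = weights[src]
    return permuted


def swap_top_two(weights):
    """Exchange the two largest weights, changing what is attended to most."""
    order = sorted(range(len(weights)), key=lambda i: weights[i], reverse=True)
    a, b = order[0], order[1]
    swapped = list(weights)
    swapped[a], swapped[b] = weights[b], weights[a]
    return swapped


# Made-up attention distribution over seven tokens.
attention = [0.60, 0.25, 0.05, 0.04, 0.03, 0.02, 0.01]
tail_jsd = js_divergence(attention, permute_tail(attention, top_k=2))
head_jsd = js_divergence(attention, swap_top_two(attention))
```

Comparing `tail_jsd` and `head_jsd` against the divergences reported for the adversarial distributions would indicate whether those distances could be reached by a trivial reshuffle or require a real change in which tokens dominate.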