Grouped Loss may disfavor discontinuous capabilities
Jul 9, 2022 08:21 · 1144 words · 6 minute read
(Cross-posted from the AI Alignment Forum)
Thanks to Evan Hubinger and Beth Barnes for comments on these ideas.
Language models exhibit clear scaling laws, where the loss is a power-law in model size. This offers a lot of predictive power, and seems like a useful thing to know. By contrast, individual capabilities often exhibit sharp discontinuities in performance as a function of model size and training time.
It would be great if individual capabilities just gradually improved like the broader loss. Then we wouldn’t have to worry quite so much about surprising new capabilities emerging suddenly during training.
Is there a way to change the loss function so that it incentivizes more gradual capability improvements?
Grouped Loss
Imagine grouping training examples by the kind of capability they exhibit. For instance, arithmetic problems could go in one group, “parse JSON” in another, and so on. With these groups, we could define a new loss function
$$ L = \sum_{g\in \mathrm{groups}} \ell_{g}^2 $$
where $\ell$ is the loss function we originally used (e.g. cross-entropy loss) and $\ell_g$ means to compute $\ell$ just on examples from group $g$, e.g.
$$ \ell_g = \frac{1}{\mathrm{len}(g)}\sum_{\mathrm{example}\in g} \ell(\mathrm{example}) $$
which may be estimated by using random examples drawn from $g$.
Because we have squared the group losses, the overall loss is dominated by the worst group. As a result, the model is incentivized to develop capabilities in each group at comparable rates, and so has little incentive to e.g. finely hone its poetry skills while being unable to multiply numbers.
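To make this concrete, here is a minimal sketch of the grouped loss in PyTorch. The `model` and the `batches_by_group` dictionary (mapping an illustrative group name like `"arithmetic"` to a batch of examples drawn from that group) are assumed placeholders, with cross-entropy standing in for $\ell$:

```python
import torch
import torch.nn.functional as F

def grouped_loss(model, batches_by_group):
    """L = sum over groups of (mean per-group loss)^2."""
    L = 0.0
    for group_name, (inputs, targets) in batches_by_group.items():
        # ell_g: the original loss (here cross-entropy), averaged over
        # a batch of examples drawn from this group.
        ell_g = F.cross_entropy(model(inputs), targets)
        # Squaring each group's loss makes the total dominated by the
        # worst group, pushing all capabilities forward at similar rates.
        L = L + ell_g ** 2
    return L
```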
Challenge: Defining Groups
It’s possible that using a grouped loss results in smooth development of capabilities that aren’t represented in the groups. For instance, it seems plausible that if “adding Arabic numerals” and “translating words into Arabic numerals” are two groups but “adding numbers written as words” is not, performance on the latter could nonetheless develop smoothly as the model gets better at the others. It would certainly be weird if performance on “adding numbers written as words” advanced in a sudden leap in this case.
This points to a general problem though, which is that if we have to define the groups manually we have to foresee the capabilities we’re worried about. That seems bad.
Gradient Cluster Grouping
If we could automatically group examples we wouldn’t need to do it manually. How could we do this?
I think the key feature of a group is that when the model updates, the loss of most examples in a group changes in a similar way. When that happens, it seems intuitive to say that there’s a discrete capability somewhere in the model and that those examples all depend on it.
This suggests looking for examples where the loss has similar gradients, because these probably make use of similar machinery in the model.
Concretely, I’m imagining the following procedure (a rough code sketch follows the list):
- Draw $N$ examples from the training set.
- Evaluate the gradient of $\ell$ for each example.
- Group the examples by clustering their gradients, evaluate the grouped loss, and perform gradient descent on that.
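Here is a rough sketch of the clustering step, assuming a `model`, a per-example `loss_fn`, and scikit-learn’s `KMeans`. Note that materializing full per-example gradients is only feasible for small models, so this is illustrative rather than practical:

```python
import torch
from sklearn.cluster import KMeans

def cluster_examples_by_gradient(model, examples, loss_fn, n_groups):
    """Group examples by clustering their per-example loss gradients."""
    grads = []
    for x, y in examples:
        ell = loss_fn(model(x), y)
        # Gradient of this single example's loss w.r.t. all parameters.
        g = torch.autograd.grad(ell, list(model.parameters()))
        grads.append(torch.cat([gi.reshape(-1) for gi in g]))
    G = torch.stack(grads)  # (N, M): one row per example
    # Examples whose gradients point in similar directions plausibly
    # rely on the same internal machinery, so cluster the rows.
    return KMeans(n_clusters=n_groups).fit_predict(G.numpy())
```

The resulting labels would then define the groups fed into the grouped loss above.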
As a technical note: in this approach, the grouped loss is a moving target. As the model learns and capabilities form, the groups shift. This means that SGD is no longer minimizing a fixed loss. I don’t think that’s a problem, in part because all versions of the loss agree when the model has reached zero loss, so the different iterations of the loss function all point towards better capabilities.
Challenge: How many groups?
I don’t know of a principled way to pick the number of groups to cluster examples into, and that seems like a problem. Guessing too many groups loses the advantage of grouping because each group reflects an extremely narrow task. Guessing too few groups also loses the advantage of grouping, because then the capabilities that show gradual improvements will be very broad ones, and narrow capabilities will still show discontinuous improvements.
SVD-Grouped Loss
(Note that I don’t think this specific loss is necessarily the best idea, but I think it illustrates the kind of approach that might solve the challenge of identifying appropriate groups.)
An improvement over clustering by gradients is to use the singular value decomposition (SVD), which provides a more continuous way to talk about the similarity between gradients.
The idea here is that the SVD of the gradients of different examples will identify the most important directions in loss-space, which I (weakly) suspect correspond to directions that improve distinct capabilities.
Construction
We begin as before by drawing $N$ examples from the training set and evaluating the gradient of $\ell$ for each example. Each gradient has length equal to the number of parameters $M$ in the model. Combining the gradients, we form the $N\times M$ matrix $G$.
We next compute the SVD of $G$. This produces singular values $\sigma_i$ and pairs of singular vectors $(v_i,w_i)$, where $v_i$ has length $N$ and $w_i$ has length $M$. Importantly, $w_i$ lives in the same space as the loss gradients, and the set $\{w_i\}$ spans the space of the gradients. As such, we can write each gradient in terms of the $w_i$ as:
$$ \nabla \ell(\mathrm{example}) = \sum_i \alpha_i(\mathrm{example})\, w_i $$
We can then define the loss
$$ L=\sum_i\left( \sigma_i\sum_{\rm example} \ell(\mathrm{example}) \hat{\alpha}_i(\mathrm{example})\right)^2 $$
where $\hat{\alpha}_i = \alpha_i / \sqrt{\sum_j \alpha_j^2}$ is just a component of the normalized gradient. For purposes of evaluating gradients of the loss we treat the $\sigma_i$’s and $\hat{\alpha}_i$’s as constants. This should not be a problem because, regardless of the values of $\sigma_i$ and $\hat{\alpha}_i$, all versions of the loss agree when the model has reached zero loss. So, as in the Grouped Loss case, the different iterations of the loss function all point towards better capabilities.
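Putting the construction together, here is a sketch under the same assumptions as above (a `model`, a per-example `loss_fn`, and gradients small enough to materialize), using `torch.linalg.svd`:

```python
import torch

def svd_grouped_loss(model, examples, loss_fn):
    """L = sum_i ( sigma_i * sum_ex ell(ex) * alpha_hat_i(ex) )^2."""
    losses, grads = [], []
    for x, y in examples:
        ell = loss_fn(model(x), y)
        g = torch.autograd.grad(ell, list(model.parameters()),
                                retain_graph=True)
        grads.append(torch.cat([gi.reshape(-1) for gi in g]))
        losses.append(ell)
    G = torch.stack(grads)                    # (N, M) gradient matrix
    _, S, Vh = torch.linalg.svd(G, full_matrices=False)
    # alpha[n, i]: gradient of example n projected onto w_i (rows of Vh).
    alpha = G @ Vh.T
    alpha_hat = alpha / alpha.norm(dim=1, keepdim=True)
    # sigma_i and alpha_hat_i are treated as constants when this loss
    # is differentiated, so gradients flow only through ell(example).
    S, alpha_hat = S.detach(), alpha_hat.detach()
    ell_vec = torch.stack(losses)             # (N,)
    per_group = S * (alpha_hat * ell_vec[:, None]).sum(dim=0)
    return (per_group ** 2).sum()
```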
Interpretation
In this loss function, groups correspond to singular vectors and are weighted by their singular values. Examples are attributed continuously to groups (i.e. each example belongs to multiple groups to varying degrees) in accordance with how much their gradients align with the groups’ singular vectors.
My intuition is that singular vectors with large singular values correspond to individual capabilities, because they are directions in gradient-space that improve many examples (the more examples a direction improves, the larger its singular value).
Summary
I would like to see capabilities arise more gradually during training, rather than through sudden grokking. That could make it easier to notice dangerous capabilities developing.
I think grouped loss functions are one way to do this, and they work (if they work) by making SGD care most about the model’s weakest capability at all times.
Assuming grouped losses are feasible to implement and indeed behave this way, they would also provide a weak guarantee that the model’s performance on one task is representative of its performance on other tasks (so long as both tasks appeared during training). This seems like a really useful (if unintended) property, because it means that we can understand a model’s capabilities with much sparser testing.