Token Weighting
Overview
Token Weighting is a technique for adjusting each sample's contribution to the loss function when the training data mixes tasks with widely varying output lengths. During Molmo2 training, data with diverse output lengths coexist, ranging from single-token multiple-choice answers to video captions exceeding 4,000 tokens.
Problem Setting
Loss Domination by Long Outputs
When training data has large disparities in output length, the following problems arise:
- Token count imbalance: Tasks with long outputs (e.g., video captioning) dominate the majority of loss tokens during training, even when their sampling frequency is low
- Degradation of short-output tasks: Performance on multiple-choice questions and tasks requiring short answers deteriorates
- Task balance collapse: The model becomes overly optimized for tasks that generate long outputs, sacrificing performance on other tasks
When training on a single video captioning sample (4,000 tokens) and 100 multiple-choice samples (1 token each) in the same batch, the video caption accounts for approximately 97.6% of the loss tokens (4,000 / (4,000 + 100) ≈ 0.976).
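The arithmetic above can be checked directly. The token counts are the ones assumed in the text (one 4,000-token caption, 100 one-token multiple-choice answers):

```python
# Illustration of loss domination by output length, using the
# per-sample token counts assumed in the text above.
caption_tokens = 4000
mc_samples = 100
mc_tokens_each = 1

total_tokens = caption_tokens + mc_samples * mc_tokens_each
caption_share = caption_tokens / total_tokens

print(f"Caption share of loss tokens: {caption_share:.1%}")  # → 97.6%
```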
Molmo2’s Weighting Strategy
In Molmo2, the weight applied to each sample’s loss is adjusted according to the task type and output length.
Weighting Scheme
| Task Type | Weight | Reason |
|---|---|---|
| Video captioning | 0.1 | Produces very long, dense outputs (4,000+ tokens) |
| Pointing | 0.2 | Produces long, dense outputs due to coordinate enumeration |
| Other tasks | \(\frac{4}{\sqrt{n}}\) | Adaptive weighting based on output length \(n\) |
Formal Definition
The loss weight \(w_i\) for sample \(i\) is defined as follows:
\[ w_i = \begin{cases} 0.1 & \text{if task is video captioning} \\ 0.2 & \text{if task is pointing} \\ \frac{4}{\sqrt{n_i}} & \text{otherwise} \end{cases} \]
where \(n_i\) is the number of answer tokens for sample \(i\).
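The piecewise definition translates directly into code. The sketch below is illustrative; the task-name strings and the function name are assumptions, not identifiers from an actual Molmo2 codebase:

```python
import math

def loss_weight(task: str, n_answer_tokens: int) -> float:
    """Per-sample loss weight following the piecewise scheme above.

    Task-name strings here are illustrative placeholders.
    """
    if task == "video_captioning":
        return 0.1
    if task == "pointing":
        return 0.2
    # Adaptive weighting for all other tasks: 4 / sqrt(n)
    return 4.0 / math.sqrt(n_answer_tokens)

print(loss_weight("other", 1))                 # 4.0
print(loss_weight("other", 16))                # 1.0
print(loss_weight("video_captioning", 4000))   # 0.1
```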
Effect of Adaptive Weighting
The \(\frac{4}{\sqrt{n}}\) weighting for other tasks has the following properties:
- Short outputs: \(w = 4.0\) when \(n = 1\), \(w = 1.0\) when \(n = 16\)
- Medium outputs: \(w = 0.4\) when \(n = 100\)
- Long outputs: \(w = 0.2\) when \(n = 400\)
This square-root decay dampens the influence of long-output samples without ignoring them entirely. In particular, the weight \(\frac{4}{\sqrt{n}}\) halves whenever the output length quadruples, which keeps learning balanced across short and long outputs.
For example:
- 1-token output: weight 4.0
- 4-token output: weight 2.0
- 16-token output: weight 1.0
- 64-token output: weight 0.5
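The halving-on-quadrupling progression in the list above follows mechanically from the formula:

```python
import math

# Quadrupling n halves the weight 4 / sqrt(n):
# sqrt(4n) = 2 * sqrt(n), so 4 / sqrt(4n) = (4 / sqrt(n)) / 2.
for n in (1, 4, 16, 64):
    print(n, 4.0 / math.sqrt(n))  # 4.0, 2.0, 1.0, 0.5
```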
Effects
The introduction of Token Weighting yields the following effects:
- Performance maintenance across diverse tasks: Good performance is achieved on both tasks requiring short answers (e.g., multiple-choice questions) and tasks generating long outputs (e.g., video captioning)
- Improved training stability: Prevents the loss from being dominated by specific tasks, enabling more stable training
- Efficient data utilization: Diverse tasks with different output lengths can be effectively learned by a single model
Implementation Considerations
Token Weighting is applied during loss computation:
\[ \mathcal{L} = \frac{1}{B} \sum_{i=1}^{B} w_i \cdot \mathcal{L}_i \]
where:
- \(B\) is the batch size
- \(\mathcal{L}_i\) is the unadjusted loss of sample \(i\)
- \(w_i\) is the weight of sample \(i\)
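A minimal sketch of this weighted batch loss, assuming per-sample losses are already available (the helper name and task strings are hypothetical, not from the Molmo2 codebase):

```python
import math

def weighted_batch_loss(per_sample_losses, tasks, answer_lengths):
    """Compute L = (1/B) * sum_i w_i * L_i under the weighting scheme above."""
    assert len(per_sample_losses) == len(tasks) == len(answer_lengths)
    total = 0.0
    for loss, task, n in zip(per_sample_losses, tasks, answer_lengths):
        if task == "video_captioning":
            w = 0.1
        elif task == "pointing":
            w = 0.2
        else:
            w = 4.0 / math.sqrt(n)  # adaptive weight for other tasks
        total += w * loss
    return total / len(per_sample_losses)

# Example: one "other" sample (n=16, w=1.0) and one pointing sample (w=0.2)
print(weighted_batch_loss([2.0, 0.5], ["other", "pointing"], [16, 300]))  # 1.05
```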
Because each weight \(w_i\) depends only on its own sample's task type and answer length, it can be applied per sample at loss-computation time with no cross-sample bookkeeping, so the scheme is straightforward to integrate into existing training pipelines.