Cut Your Losses in Large-Vocabulary Language Models
Сегодня разберём статью, в которой описывается эффективный метод фьюза LM-головы и кросс-энтропии.
Авторы формулируют проблему чрезмерного потребления памяти на слое кросс-энтропии при обучении LLM с крупными словарями: материализация логитов размера |V|×N доминирует и может занимать до ~90% памяти, что ограничивает батч и масштаб обучения.
Инженеры предлагают метод Cut Cross-Entropy (CCE), который предполагает вычисление лосса без сохранения всех логитов в глобальной памяти. Нужно брать только логит правильного токена и выполнять log-sum-exp «на лету» в SRAM; на примере Gemma-2 на 2 миллиарда параметров память на вычисление лосса сокращается примерно с 24 ГБ до 1 МБ, а общий след classifier-head при обучении — с 28 ГБ до 1 ГБ, без потерь по скорости или сходимости.
Лосс для всех токенов в последовательности считается по формуле ℓ = (CᵀE)_x − log∑_j exp(CⱼᵀE). Первая часть реализована как матричное умножение в едином CUDA/Triton-ядре с загрузкой нужного столбца классификатора и эмбеддинга в SRAM и немедленным скалярным произведением.
Вторая — как блочно-параллельный linear-log-sum-exp, комбинирующий матричное умножение и редукцию с потокобезопасным log-add-exp, также без промежуточных логитов в DRAM. В обратном проходе CᵀE перевычисляется в общей памяти. Градиенты считаются с учётом разреженности softmax: элементы ниже порога ε=2⁻¹² (bf16) отбрасываются, а словарь переупорядочивается по среднему логиту для уплотнения полезных блоков. Это даёт до ускорение примерно в 3,5 раза на бэкворде при том, что фактически ненулевых значений <0,02%.
CCE чуть быстрее torch.compile на форварде и сопоставим по суммарному времени, обеспечивая на порядок меньший след памяти. Дополнительно показывают, что CCE увеличивает достижимый размер батча на 16 GPU в 1,5–10 раз в зависимости от модели, а кривые обучения при файнтюнинге совпадают с torch.compile. Для претрейнинга точность выравнивается вариантом CCE-Kahan-FullC, ценой временных буферов и большего времени на бэкворде.
Душный NLP
Сегодня разберём статью, в которой описывается эффективный метод фьюза LM-головы и кросс-энтропии.
Авторы формулируют проблему чрезмерного потребления памяти на слое кросс-энтропии при обучении LLM с крупными словарями: материализация логитов размера |V|×N доминирует и может занимать до ~90% памяти, что ограничивает батч и масштаб обучения.
Инженеры предлагают метод Cut Cross-Entropy (CCE), который предполагает вычисление лосса без сохранения всех логитов в глобальной памяти. Нужно брать только логит правильного токена и выполнять log-sum-exp «на лету» в SRAM; на примере Gemma-2 на 2 миллиарда параметров память на вычисление лосса сокращается примерно с 24 ГБ до 1 МБ, а общий след classifier-head при обучении — с 28 ГБ до 1 ГБ, без потерь по скорости или сходимости.
Лосс для всех токенов в последовательности считается по формуле ℓ = (CᵀE)_x − log∑_j exp(CⱼᵀE). Первая часть реализована как матричное умножение в едином CUDA/Triton-ядре с загрузкой нужного столбца классификатора и эмбеддинга в SRAM и немедленным скалярным произведением.
Вторая — как блочно-параллельный linear-log-sum-exp, комбинирующий матричное умножение и редукцию с потокобезопасным log-add-exp, также без промежуточных логитов в DRAM. В обратном проходе CᵀE перевычисляется в общей памяти. Градиенты считаются с учётом разреженности softmax: элементы ниже порога ε=2⁻¹² (bf16) отбрасываются, а словарь переупорядочивается по среднему логиту для уплотнения полезных блоков. Это даёт до ускорение примерно в 3,5 раза на бэкворде при том, что фактически ненулевых значений <0,02%.
CCE чуть быстрее torch.compile на форварде и сопоставим по суммарному времени, обеспечивая на порядок меньший след памяти. Дополнительно показывают, что CCE увеличивает достижимый размер батча на 16 GPU в 1,5–10 раз в зависимости от модели, а кривые обучения при файнтюнинге совпадают с torch.compile. Для претрейнинга точность выравнивается вариантом CCE-Kahan-FullC, ценой временных буферов и большего времени на бэкворде.
Душный NLP
👍21❤7🔥2