HALO: Hadamard-Assisted Lower-Precision Optimization for LLMs

[TL;DR] The paper proposes a quantized fine-tuning framework that exploits low bit-width matrix multiplication hardware units by quantizing the inputs, weights, outputs, and their gradients online, using Hadamard transforms to smooth out outliers before quantization.

Highlights

  • Provide and experiment with three settings, HALO-0, HALO-1, and HALO-2 (speedup: HALO-0 > HALO-1 > HALO-2; accuracy: HALO-2 > HALO-1 > HALO-0)
  • Support and experiment with different data types: FP8, INT8, and FP6 for forward and backward passes, exploiting acceleration from low bit-width matrix multiplication hardware units
  • Support Fully Sharded Data Parallel (FSDP) scheme to enable further savings by performing low-precision communication
  • Support full fine-tuning (FFT) and parameter-efficient fine-tuning (PEFT) methods such as LoRA
  • Demonstrate accuracy improvements on downstream tasks for LLAMA-family models in quantized fine-tuning
  • Demonstrate practical speedups in both FFT and FSDP cases
  • Will open-source the kernel implementation at https://github.com/IST-DASLab/HALO

Summary

  • Observation 1: Quantization errors in the forward activations cause the gradient direction to deviate (compare Figures (b) and (c)).


  • Observation 2: The outliers in the gradient with respect to the layer output, $E_Y$, can only be eliminated by a left-hand Hadamard transform (see the figure on the right).


  • The problem statement: the large outliers in forward activations are difficult to represent in low bit-width data types.
  • The solution: The paper addresses this issue by transforming the forward activations online into a smoother space using Hadamard matrices (a minimal sketch of this effect appears after this list).
  • The quantized fine-tuning framework: The proposed method quantizes the transformed inputs, weights, outputs, and their gradients with low bit-width data types (e.g., FP8, INT8, and FP6) in an online fashion to leverage acceleration from low bit-width matrix multiplication hardware units for forward and backward passes.


  • Different strategies: The paper studies the placement of Hadamard rotations in both the forward and backward passes. In the illustrated configuration, the left-hand Hadamard transform (highlighted by the red rectangle in the figure) is applied to $E_Y$.
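
A minimal NumPy sketch of this effect (my own illustration, not the paper's kernels): rotating an outlier-heavy activation vector with a normalized Hadamard matrix spreads the outlier's energy across all coordinates, so a symmetric INT8 quantizer can use a much smaller scale and loses far less precision. The `int8_quantize` helper, sizes, and values below are illustrative assumptions.

```python
import numpy as np
from scipy.linalg import hadamard

d = 256
rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=d)
x[7] = 100.0                                      # a single large activation outlier

H = hadamard(d).astype(np.float64) / np.sqrt(d)   # normalized, so H @ H.T == I

def int8_quantize(v):
    """Symmetric per-tensor INT8 round-to-nearest quantization (illustrative)."""
    scale = np.abs(v).max() / 127.0
    return np.round(v / scale).clip(-127, 127) * scale

x_rot = x @ H                                     # quantize in the rotated ("smoother") space
err_direct  = np.linalg.norm(int8_quantize(x) - x)
err_rotated = np.linalg.norm(int8_quantize(x_rot) @ H.T - x)

print(f"max|x| = {np.abs(x).max():.1f}, max|xH| = {np.abs(x_rot).max():.1f}")
print(f"INT8 error, direct:  {err_direct:.3f}")
print(f"INT8 error, rotated: {err_rotated:.3f}")  # typically much smaller
```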


Experiments

Compared with Baselines

  • Left: Accuracy, Right: Relative speedup

HALO Levels

  • Accuracy
  • Relative speedup


Notations

The forward and backward passes

  • Let matrices $X \in \mathbb{R}^{b \times m}$, $W \in \mathbb{R}^{n \times m}$, and $Y \in \mathbb{R}^{b \times n}$ be the inputs, weights, and outputs, with batch size $b$.
  • $E_X$, $G$, and $E_Y$ are the gradients w.r.t. the inputs, weights, and outputs, respectively. The authors also refer to $E_X$ and $E_Y$ as errors.
  • The forward and backward calculations are defined as (1) $Y = X W^\top$ (Forward), (2) $G = E_Y^\top X$ (Gradient), and (3) $E_X = E_Y W$ (Error).
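
As a quick sanity check on this notation (a sketch assuming the standard PyTorch `nn.Linear` convention $Y = XW^\top$, not code from the paper), autograd reproduces both gradient formulas:

```python
import torch

b, m, n = 4, 8, 16
X = torch.randn(b, m, requires_grad=True)
W = torch.randn(n, m, requires_grad=True)

Y = X @ W.T                  # forward: Y in R^{b x n}
E_Y = torch.randn(b, n)      # stand-in for the upstream gradient w.r.t. the output
Y.backward(E_Y)

assert torch.allclose(X.grad, E_Y @ W)     # E_X = E_Y W
assert torch.allclose(W.grad, E_Y.T @ X)   # G   = E_Y^T X
```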

Quantization

  • Given half-precision (FP16) matrices $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$
  • $A_Q$ and $B_Q$ are the quantized copies of $A$ and $B$.
  • The low-precision matrix multiplication is $Y = A_Q B_Q$
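
A minimal NumPy emulation of this pattern, assuming symmetric per-tensor round-to-nearest INT8 quantization (the real speedups come from dispatching the integer matmul to FP8/INT8/FP6 hardware units; `quantize_int8` and the matrix sizes here are illustrative assumptions):

```python
import numpy as np

def quantize_int8(M):
    """Symmetric per-tensor quantization: INT8 codes plus a float scale."""
    scale = np.abs(M).max() / 127.0
    q = np.clip(np.round(M / scale), -127, 127).astype(np.int8)
    return q, scale

m, k, n = 64, 128, 32
rng = np.random.default_rng(0)
A = rng.normal(size=(m, k)).astype(np.float16)   # half-precision operands
B = rng.normal(size=(k, n)).astype(np.float16)

A_q, s_a = quantize_int8(A.astype(np.float32))
B_q, s_b = quantize_int8(B.astype(np.float32))

# Integer matmul (the part low-bit-width hardware accelerates), then rescale.
Y = (A_q.astype(np.int32) @ B_q.astype(np.int32)).astype(np.float32) * (s_a * s_b)

ref = A.astype(np.float32) @ B.astype(np.float32)
print("relative error:", np.linalg.norm(Y - ref) / np.linalg.norm(ref))
```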

Hadamard Transforms

Here $H_d$ denotes a (normalized) $d \times d$ Hadamard matrix satisfying $H_d H_d^\top = I$, so in each case the rotations cancel and $Y_H \approx AB$ up to quantization error.

  • No Hadamard Case: $Y = A_Q B_Q$
  • Left Case: $Y_H = H_m (H_m^\top A)_Q B_Q$
  • Right Case: $Y_H = A_Q (B H_n)_Q H_n^\top$
  • Middle Case: $Y_H = (A H_k)_Q (H_k^\top B)_Q$
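
A small NumPy check (my own sketch, not the paper's implementation) that each placement collapses back to $AB$ once the Hadamard factors cancel, and how the placements behave when an emulated 8-bit quantizer meets an outlier in $A$. The `quantize` helper and matrix sizes are assumptions for illustration:

```python
import numpy as np
from scipy.linalg import hadamard

def H(d):
    """Normalized d x d Hadamard matrix (d must be a power of two)."""
    return hadamard(d).astype(np.float64) / np.sqrt(d)

def quantize(M, bits=8):
    """Emulated symmetric round-to-nearest quantization to `bits` bits."""
    levels = 2 ** (bits - 1) - 1
    scale = np.abs(M).max() / levels
    return np.clip(np.round(M / scale), -levels, levels) * scale

m, k, n = 64, 128, 32
rng = np.random.default_rng(0)
A = rng.normal(size=(m, k))
A[0, 0] = 50.0                                   # inject an outlier into A
B = rng.normal(size=(k, n))
ref = A @ B

cases = {
    "none":   quantize(A) @ quantize(B),
    "left":   H(m) @ quantize(H(m).T @ A) @ quantize(B),
    "right":  quantize(A) @ quantize(B @ H(n)) @ H(n).T,
    "middle": quantize(A @ H(k)) @ quantize(H(k).T @ B),
}
for name, Y in cases.items():
    print(f"{name:6s} relative error: {np.linalg.norm(Y - ref) / np.linalg.norm(ref):.4f}")
```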