HALO: Hadamard-Assisted Lower-Precision Optimization for LLMs
[TL;DR] The paper proposes a quantized fine-tuning framework that exploits low bit-width matrix multiplication hardware units: the inputs, weights, outputs, and their gradients are transformed with Hadamard matrices and quantized online.
Highlights
- Provide and experiment with three settings that trade accuracy for speed: HALO-0, HALO-1, and HALO-2 (speedup: HALO-0 > HALO-1 > HALO-2; accuracy: HALO-2 > HALO-1 > HALO-0)
- Support and experiment with different data types: FP8, INT8, and FP6 for forward and backward passes, exploiting acceleration from low bit-width matrix multiplication hardware units
- Support Fully Sharded Data Parallel (FSDP) scheme to enable further savings by performing low-precision communication
- Support full fine-tuning (FFT) and parameter-efficient fine-tuning (PEFT) methods such as LoRA
- Demonstrate accuracy improvements on downstream tasks for LLAMA-family models in quantized fine-tuning
- Demonstrate practical speedups in both FFT and FSDP cases
- Will open-source the kernel implementation at https://github.com/IST-DASLab/HALO
Summary
- Observation 1: Quantization errors in the forward activations cause the gradient direction to deviate (compare Figures (b) and (c)).

- Observation 2: The outliers in the gradient with respect to the layer output, \(\mathbf{E}_{\mathbf{Y}}\), can only be eliminated by a left-hand Hadamard transformation (see the figure on the right).

- The problem statement: the large outliers in forward activations are difficult to represent in low bit-width data types.
- The solution: The paper addresses this issue by transforming the forward activations online into a smoother space using Hadamard matrices (see the sketch after this list).
- The quantized fine-tuning framework: The proposed method quantizes the transformed inputs, weights, outputs, and their gradients with low bit-width data types (e.g., FP8, INT8, and FP6) in an online fashion to leverage acceleration from low bit-width matrix multiplication hardware units for forward and backward passes.

- Different strategies: The paper studies the placement of Hadamard rotations in both forward and backward passes. The left-hand Hadamard transformation (in the red rectangle) is applied to \(\mathbf{E}_{\mathbf{Y}}\).
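To make the smoothing effect concrete, here is a minimal sketch (not the authors' kernels): an orthonormal Hadamard rotation spreads a single activation outlier across all coordinates, shrinking the dynamic range a quantizer has to cover. The dimension and outlier value below are arbitrary choices for illustration.

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    """Orthonormal n x n Hadamard matrix via Sylvester's construction (n a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5  # normalized so that H @ H.T == I

d = 64
x = torch.randn(d)
x[3] = 100.0                        # inject one large activation outlier
x_rot = hadamard(d) @ x             # rotate into the "smoother" space

print(f"max|x|  = {x.abs().max().item():.1f}")      # ~100, dominated by the outlier
print(f"max|Hx| = {x_rot.abs().max().item():.1f}")  # ~100/sqrt(64) plus noise
print(torch.allclose(hadamard(d).T @ x_rot, x, atol=1e-4))  # the rotation is lossless
```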

Experiments
Compared with Baselines
- Left: Accuracy, Right: Relative speedup


HALO Levels
- Accuracy

- Relative speedup

Notations
The forward and backward passes
- Let matrices \(\mathbf{X} \in \mathbb{R}^{b \times m}\), \(\mathbf{W} \in \mathbb{R}^{n \times m}\), and \(\mathbf{Y} \in \mathbb{R}^{b \times n}\) be the inputs, weights, and outputs with batch size \(b\).
- \(\mathbf{E}_{\mathbf{X}}\), \(\mathbf{G}\), and \(\mathbf{E}_{\mathbf{Y}}\) are the gradients w.r.t. the inputs, weights, and outputs, respectively. The authors also refer to \(\mathbf{E}_{\mathbf{X}}\) and \(\mathbf{E}_{\mathbf{Y}}\) as errors.
- The forward and backward calculations are defined as \(\begin{align} \mathbf{Y} &= \mathbf{X} \cdot \mathbf{W}^{\top} \quad &\textbf{(Forward)} \\ \mathbf{G} &= \mathbf{E}_{\mathbf{Y}}^{\top} \cdot \mathbf{X} \quad &\textbf{(Gradient)} \\ \mathbf{E}_{\mathbf{X}} &= \mathbf{E}_{\mathbf{Y}} \cdot \mathbf{W} \quad &\textbf{(Error)} \end{align}\)
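A quick sanity check of this notation against PyTorch autograd (a sketch; the shapes below are arbitrary choices, not values from the paper):

```python
import torch

b, m, n = 4, 8, 16
X = torch.randn(b, m, requires_grad=True)
W = torch.randn(n, m, requires_grad=True)

Y = X @ W.T                  # forward: shape (b, n)
E_Y = torch.randn(b, n)      # stand-in for the gradient flowing back into Y
Y.backward(E_Y)

G   = E_Y.T @ X              # gradient w.r.t. the weights, shape (n, m)
E_X = E_Y @ W                # gradient w.r.t. the inputs,  shape (b, m)

print(torch.allclose(W.grad, G, atol=1e-5))   # True
print(torch.allclose(X.grad, E_X, atol=1e-5)) # True
```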
Quantization
- Given half-precision (FP16) matrices \(\mathbf{A} \in \mathbb{R}^{m \times k}\) and \(\mathbf{B} \in \mathbb{R}^{k \times n}\)
- \(\mathbf{A}_Q\), \(\mathbf{B}_Q\) are the quantized copies of \(\mathbf{A}\) and \(\mathbf{B}\).
- Low-precision matrix multiplication: \(\mathbf{Y} = \mathbf{A}_Q \mathbf{B}_Q\)
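As an illustration (not HALO's actual quantizer), a symmetric per-tensor absmax INT8 scheme and the rescaled low-precision product could look like the sketch below; the integer accumulation is emulated in FP32 here, whereas real low bit-width hardware units accumulate in INT32.

```python
import torch

def quantize_absmax_int8(A: torch.Tensor):
    """Symmetric per-tensor absmax quantization; returns the INT8 copy and its scale."""
    scale = A.abs().max() / 127.0
    A_q = torch.round(A / scale).clamp(-127, 127).to(torch.int8)
    return A_q, scale

m, k, n = 16, 32, 8
A = torch.randn(m, k)
B = torch.randn(k, n)

A_q, s_a = quantize_absmax_int8(A)
B_q, s_b = quantize_absmax_int8(B)

# Emulate the INT8 GEMM in FP32 (exact here, since all partial sums stay well below 2^24);
# a low bit-width matrix multiplication unit would execute this product natively.
Y_q = (A_q.float() @ B_q.float()) * (s_a * s_b)
Y   = A @ B
print((Y - Y_q).abs().max().item())   # small quantization error
```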
Hadamard Transforms
- No Hadamard Case: \(\mathbf{Y} = \mathbf{A}_Q \mathbf{B}_Q\)
- Left Case: \({}^{\mathbf{H}}\mathbf{Y} = \mathbf{H}_m (\mathbf{H}_m^{\top} \mathbf{A})_Q \mathbf{B}_Q\)
- Right Case: \(\mathbf{Y}^{\mathbf{H}} = \mathbf{A}_Q (\mathbf{B} \mathbf{H}_n)_Q \mathbf{H}_n^{\top}\)
- Middle Case: \(\overset{\mathbf{H}}{\mathbf{Y}} = (\mathbf{A} \mathbf{H}_k)_Q (\mathbf{H}_k^{\top} \mathbf{B})_Q\)
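A minimal sketch of the three placements (illustrative, not the HALO kernels): since an orthonormal Hadamard matrix satisfies \(\mathbf{H} \mathbf{H}^{\top} = \mathbf{I}\), every case recovers \(\mathbf{Y} = \mathbf{A}\mathbf{B}\) up to quantization error, and only the middle case needs no extra rotation applied to the output. The quantizer is the same per-tensor absmax INT8 scheme sketched above.

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    """Orthonormal Hadamard matrix (Sylvester construction, n a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return H / H.shape[0] ** 0.5

def qmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """(A)_Q (B)_Q with per-tensor absmax INT8 quantization, emulated in FP32."""
    s_a, s_b = A.abs().max() / 127.0, B.abs().max() / 127.0
    A_q = torch.round(A / s_a).clamp(-127, 127)
    B_q = torch.round(B / s_b).clamp(-127, 127)
    return (A_q @ B_q) * (s_a * s_b)

m, k, n = 32, 64, 16
A, B = torch.randn(m, k), torch.randn(k, n)
H_m, H_k, H_n = hadamard(m), hadamard(k), hadamard(n)

Y        = A @ B                          # full-precision reference
Y_none   = qmm(A, B)                      # no Hadamard
Y_left   = H_m @ qmm(H_m.T @ A, B)        # left case
Y_right  = qmm(A, B @ H_n) @ H_n.T        # right case
Y_middle = qmm(A @ H_k, H_k.T @ B)        # middle case

for name, Yq in [("none", Y_none), ("left", Y_left),
                 ("right", Y_right), ("middle", Y_middle)]:
    print(f"{name:>6}: mean abs error = {(Y - Yq).abs().mean().item():.4f}")
```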