HALO: Hadamard-Assisted Lower-Precision Optimization for LLMs
[TL;DR] The paper proposes a quantized fine-tuning framework that exploits low bit-width matrix multiplication hardware units: the inputs, weights, outputs, and their gradients are transformed with Hadamard matrices and quantized online.
Highlights
- Provide and experiment with three settings that trade accuracy for speed: HALO-0, HALO-1, and HALO-2 (speedup: HALO-0 > HALO-1 > HALO-2; accuracy: HALO-2 > HALO-1 > HALO-0)
- Support and experiment with different data types: FP8, INT8, and FP6 for forward and backward passes, exploiting acceleration from low bit-width matrix multiplication hardware units
- Support Fully Sharded Data Parallel (FSDP) scheme to enable further savings by performing low-precision communication
- Support full fine-tuning (FFT) and parameter-efficient fine-tuning (PEFT) methods such as LoRA
- Demonstrate accuracy improvements on downstream tasks for LLAMA-family models in quantized fine-tuning
- Demonstrate practical speedups in both FFT and FSDP cases
- Will open-source the kernel implementation at https://github.com/IST-DASLab/HALO
Summary
- Observation 1: Quantization errors in the forward activations cause the gradient direction to deviate (compare Figures (b) and (c)).

- Observation 2: The outliers in the gradient with respect to the layer output, \(\mathbf{E}_{\mathbf{Y}}\), can only be eliminated by a left-hand Hadamard transformation (see the figure on the right).

- The problem statement: the large outliers in forward activations are difficult to represent in low bit-width data types.
- The solution: The paper addresses this issue by transforming the forward activations online into a smoother space using Hadamard matrices (see the sketch after this list).
- The quantized fine-tuning framework: The proposed method quantizes the transformed inputs, weights, outputs, and their gradients with low bit-width data types (e.g., FP8, INT8, and FP6) in an online fashion to leverage acceleration from low bit-width matrix multiplication hardware units for forward and backward passes.

- Different strategies: The paper studies the placement of Hadamard rotations in both forward and backward passes. The left-hand Hadamard transformation (in the red rectangle) is applied to \(\mathbf{E}_{\mathbf{Y}}\).
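To make the smoothing effect concrete, here is a minimal sketch (not the authors' kernels): an orthonormal Hadamard rotation spreads a single activation outlier across all coordinates, shrinking the dynamic range a quantizer has to cover. The dimension and outlier value below are arbitrary choices for illustration.

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    """Orthonormal n x n Hadamard matrix via Sylvester's construction (n a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5  # normalized so that H @ H.T == I

d = 64
x = torch.randn(d)
x[3] = 100.0                        # inject one large activation outlier
x_rot = hadamard(d) @ x             # rotate into the "smoother" space

print(f"max|x|  = {x.abs().max().item():.1f}")      # ~100, dominated by the outlier
print(f"max|Hx| = {x_rot.abs().max().item():.1f}")  # ~100/sqrt(64) plus noise
print(torch.allclose(hadamard(d).T @ x_rot, x, atol=1e-4))  # the rotation is lossless
```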

Experiments
Compared with Baselines
- Left: Accuracy, Right: Relative speedup


HALO Levels
- Accuracy

- Relative speedup

Notations
The forward and backward passes
- Let matrices \(\mathbf{X} \in \mathbb{R}^{b \times m}\), \(\mathbf{W} \in \mathbb{R}^{n \times m}\), and \(\mathbf{Y} \in \mathbb{R}^{b \times n}\) be the inputs, weights, and outputs with batch size \(b\).
- \(\mathbf{E}_{\mathbf{X}}\), \(\mathbf{G}\), and \(\mathbf{E}_{\mathbf{Y}}\) are the gradients w.r.t. the inputs, weights, and outputs, respectively. The authors also refer to \(\mathbf{E}_{\mathbf{X}}\) and \(\mathbf{E}_{\mathbf{Y}}\) as errors.
- The forward and backward calculations are defined as \(\begin{align} \mathbf{Y} &= \mathbf{X} \cdot \mathbf{W}^{\top} \quad &\textbf{(Forward)} \\ \mathbf{G} &= \mathbf{E}_{\mathbf{Y}}^{\top} \cdot \mathbf{X} \quad &\textbf{(Gradient)} \\ \mathbf{E}_{\mathbf{X}} &= \mathbf{E}_{\mathbf{Y}} \cdot \mathbf{W} \quad &\textbf{(Error)} \end{align}\)
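A quick sanity check of this notation against PyTorch autograd (a sketch; the shapes below are arbitrary choices, not values from the paper):

```python
import torch

b, m, n = 4, 8, 16
X = torch.randn(b, m, requires_grad=True)
W = torch.randn(n, m, requires_grad=True)

Y = X @ W.T                  # forward: shape (b, n)
E_Y = torch.randn(b, n)      # stand-in for the gradient flowing back into Y
Y.backward(E_Y)

G   = E_Y.T @ X              # gradient w.r.t. the weights, shape (n, m)
E_X = E_Y @ W                # gradient w.r.t. the inputs,  shape (b, m)

print(torch.allclose(W.grad, G, atol=1e-5))   # True
print(torch.allclose(X.grad, E_X, atol=1e-5)) # True
```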
Quantization
- Given half-precision (FP16) matrices \(\mathbf{A} \in \mathbb{R}^{m \times k}\) and \(\mathbf{B} \in \mathbb{R}^{k \times n}\)
- \(\mathbf{A}_Q\), \(\mathbf{B}_Q\) are the quantized copies of \(\mathbf{A}\) and \(\mathbf{B}\).
- Low-precision matrix multiplication: \(\mathbf{Y} = \mathbf{A}_Q \mathbf{B}_Q\)
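As an illustration (not HALO's actual quantizer), a symmetric per-tensor absmax INT8 scheme and the rescaled low-precision product could look like the sketch below; the integer accumulation is emulated in FP32 here, whereas real low bit-width hardware units accumulate in INT32.

```python
import torch

def quantize_absmax_int8(A: torch.Tensor):
    """Symmetric per-tensor absmax quantization; returns the INT8 copy and its scale."""
    scale = A.abs().max() / 127.0
    A_q = torch.round(A / scale).clamp(-127, 127).to(torch.int8)
    return A_q, scale

m, k, n = 16, 32, 8
A = torch.randn(m, k)
B = torch.randn(k, n)

A_q, s_a = quantize_absmax_int8(A)
B_q, s_b = quantize_absmax_int8(B)

# Emulate the INT8 GEMM in FP32 (exact here, since all partial sums stay well below 2^24);
# a low bit-width matrix multiplication unit would execute this product natively.
Y_q = (A_q.float() @ B_q.float()) * (s_a * s_b)
Y   = A @ B
print((Y - Y_q).abs().max().item())   # small quantization error
```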
Hadamard Transforms
- No Hadamard Case: \(\mathbf{Y} = \mathbf{A}_Q \mathbf{B}_Q\)
- Left Case: \({}^{\mathbf{H}}\mathbf{Y} = \mathbf{H}_m (\mathbf{H}_m^{\top} \mathbf{A})_Q \mathbf{B}_Q\)
- Right Case: \(\mathbf{Y}^{\mathbf{H}} = \mathbf{A}_Q (\mathbf{B} \mathbf{H}_n)_Q \mathbf{H}_n^{\top}\)
- Middle Case: \(\overset{\mathbf{H}}{\mathbf{Y}} = (\mathbf{A} \mathbf{H}_k)_Q (\mathbf{H}_k^{\top} \mathbf{B})_Q\)
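A minimal sketch of the three placements (illustrative, not the HALO kernels): since an orthonormal Hadamard matrix satisfies \(\mathbf{H} \mathbf{H}^{\top} = \mathbf{I}\), every case recovers \(\mathbf{Y} = \mathbf{A}\mathbf{B}\) up to quantization error, and only the middle case needs no extra rotation applied to the output. The quantizer is the same per-tensor absmax INT8 scheme sketched above.

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    """Orthonormal Hadamard matrix (Sylvester construction, n a power of two)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], 1), torch.cat([H, -H], 1)], 0)
    return H / H.shape[0] ** 0.5

def qmm(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """(A)_Q (B)_Q with per-tensor absmax INT8 quantization, emulated in FP32."""
    s_a, s_b = A.abs().max() / 127.0, B.abs().max() / 127.0
    A_q = torch.round(A / s_a).clamp(-127, 127)
    B_q = torch.round(B / s_b).clamp(-127, 127)
    return (A_q @ B_q) * (s_a * s_b)

m, k, n = 32, 64, 16
A, B = torch.randn(m, k), torch.randn(k, n)
H_m, H_k, H_n = hadamard(m), hadamard(k), hadamard(n)

Y        = A @ B                          # full-precision reference
Y_none   = qmm(A, B)                      # no Hadamard
Y_left   = H_m @ qmm(H_m.T @ A, B)        # left case
Y_right  = qmm(A, B @ H_n) @ H_n.T        # right case
Y_middle = qmm(A @ H_k, H_k.T @ B)        # middle case

for name, Yq in [("none", Y_none), ("left", Y_left),
                 ("right", Y_right), ("middle", Y_middle)]:
    print(f"{name:>6}: mean abs error = {(Y - Yq).abs().mean().item():.4f}")
```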