Logit Masking Guarantees Valid Outputs
LLMs generate tokens autoregressively: at each step the model emits a logit vector over a vocabulary of roughly 32,000-100,000 tokens, which softmax converts to a probability distribution. Any token with a finite logit has nonzero probability, so the model can hallucinate near-miss labels (e.g., "Techology" instead of "Technology"). Standard fixes (prompt instructions, string matching, retries) fail because they act after generation.
Constrained decoding intervenes before sampling: set the logits of invalid tokens to -∞, which yields exactly zero probability after softmax; the remaining valid logits renormalize to sum to 1. This works under any sampling strategy (greedy, temperature, top-p, top-k), since a zero-probability token can never be selected. In code: logits[~valid_token_mask] = float('-inf').
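A minimal sketch of the masking step in plain Python, with toy logits over a four-token vocabulary (values and IDs are illustrative, not from any real model):

```python
import math

def masked_softmax(logits, valid_ids):
    # Set invalid logits to -inf, then softmax: invalid tokens get
    # exactly zero probability and the valid ones renormalize to 1.
    masked = [x if i in valid_ids else float('-inf')
              for i, x in enumerate(logits)]
    m = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]          # toy vocab of 4 tokens
probs = masked_softmax(logits, valid_ids={0, 2})
# tokens 1 and 3 are masked out entirely; 0 and 2 share all the mass
```

Because math.exp(-inf) is exactly 0.0, the masked tokens cannot be selected by any sampler, greedy or stochastic.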
Validity depends on prior tokens. A trie (prefix tree) encodes all taxonomy labels as token paths. Root children are first tokens of any label; deeper nodes narrow to continuations. After prefix " Tech" (token ID 8987), only "nology" (ID 1366) is valid. At end nodes, only EOS is valid, terminating the label.
Tokenization nuance: BPE splits depend on context, so tokenize labels as continuations with a leading space (" " + label, add_special_tokens=False). For example, Qwen2.5 tokenizes " Sports" to ID 22470, while "Sports" without the space becomes 51660. Verify the round trip: tokenizer.decode(token_ids) == " " + label. Tiktoken (the GPT-4 family) bakes leading whitespace directly into token boundaries rather than using a ▁ marker.
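The round-trip check can be wrapped in a small helper. This is a sketch assuming a HuggingFace-style tokenizer with encode/decode; a stub tokenizer stands in here so the check itself is concrete (real token IDs will of course differ):

```python
def verify_label_roundtrip(tokenizer, labels):
    """Return labels whose continuation encoding does not decode back
    to ' ' + label; these would break the trie at inference time."""
    bad = []
    for label in labels:
        ids = tokenizer.encode(" " + label, add_special_tokens=False)
        if tokenizer.decode(ids) != " " + label:
            bad.append(label)
    return bad

# Stub standing in for a real HuggingFace tokenizer, just to
# exercise the check (one "token" per character):
class StubTokenizer:
    def encode(self, text, add_special_tokens=False):
        return [ord(c) for c in text]
    def decode(self, ids):
        return "".join(chr(i) for i in ids)

bad = verify_label_roundtrip(StubTokenizer(), ["Sports", "Technology"])
# → [] : every label round-trips cleanly under the stub
```

Run this once per taxonomy per tokenizer; any label it flags needs special handling before insertion into the trie.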
Trie and Logits Processor Implementation
Build trie from labels:
class TrieNode:
    def __init__(self):
        self.children = {}  # token_id → TrieNode
        self.is_end = False

class ConstrainedTrie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, token_ids):
        node = self.root
        for tid in token_ids:
            if tid not in node.children:
                node.children[tid] = TrieNode()
            node = node.children[tid]
        node.is_end = True

    def get_valid_next_tokens(self, prefix):
        node = self.root
        for tid in prefix:
            if tid not in node.children:
                return set()
            node = node.children[tid]
        return set(node.children.keys())

    def is_complete(self, prefix):
        node = self.root
        for tid in prefix:
            if tid not in node.children:
                return False
            node = node.children[tid]
        return node.is_end
Insert: token_ids = tokenizer.encode(" " + label, add_special_tokens=False); trie.insert(token_ids). Rebuild the trie on taxonomy changes; this takes milliseconds for hundreds to thousands of labels.
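To make the queries concrete, here is the same structure as a compact dict-of-dicts trie with made-up token IDs (a toy stand-in for real tokenizer output; the second "Tech…" label is hypothetical):

```python
END = None  # sentinel key marking a complete label

def insert(trie, token_ids):
    node = trie
    for tid in token_ids:
        node = node.setdefault(tid, {})
    node[END] = True

def valid_next(trie, prefix):
    node = trie
    for tid in prefix:
        if tid not in node:
            return set()   # dead prefix: nothing is valid
        node = node[tid]
    return {k for k in node if k is not END}

trie = {}
insert(trie, [8987, 1366])   # " Tech" + "nology"   (toy IDs)
insert(trie, [8987, 2042])   # " Tech" + "nical"    (hypothetical)
insert(trie, [22470])        # " Sports"
```

At the root both first tokens are valid; after 8987 only the two continuations remain; after a complete label the set is empty, which is exactly where the processor adds EOS.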
HuggingFace LogitsProcessor:
import torch
from transformers import LogitsProcessor

class TrieLogitsProcessor(LogitsProcessor):
    def __init__(self, trie, prompt_length, eos_token_id):
        self.trie = trie
        self.prompt_length = prompt_length
        self.eos = eos_token_id

    def __call__(self, input_ids, scores):
        # Assumes batch size 1: slice off the prompt to get generated tokens.
        generated = input_ids[0, self.prompt_length:].tolist()
        valid = self.trie.get_valid_next_tokens(generated)
        if self.trie.is_complete(generated):
            valid.add(self.eos)
        # Mask everything, then copy back the scores of valid tokens.
        masked = torch.full_like(scores, float('-inf'))
        for tid in valid:
            masked[0, tid] = scores[0, tid]
        return masked
Generate: model.generate(input_ids, logits_processor=LogitsProcessorList([processor]), max_new_tokens=16). The output decodes to an exact taxonomy label.
Multi-Label, Hierarchies, and Broader Applications
For multi-label output: after reaching an end node, allow either EOS or a separator token (e.g., " |" or ","). Parse the generated tokens into already-emitted labels and the current prefix. Back at the root, exclude a first token only once every label beginning with it has been emitted (precompute label groups keyed by first token). Hierarchies work the same way: insert full paths such as "Technology > AI > NLP"; shared prefixes compress naturally.
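The parsing step can be sketched as a simple split on the separator token. IDs here are toy values; a real setup would use the tokenizer's ID for the chosen separator (SEP_ID is a hypothetical name):

```python
SEP_ID = 999  # hypothetical token ID for the separator, e.g. " |"

def split_generated(generated_ids, sep_id=SEP_ID):
    """Split the generated sequence into completed labels and the
    label currently being decoded. The trie is consulted only on the
    current prefix; completed labels are excluded back at the root."""
    completed, current = [], []
    for tid in generated_ids:
        if tid == sep_id:
            completed.append(tuple(current))
            current = []
        else:
            current.append(tid)
    return completed, current

done, prefix = split_generated([10, 11, SEP_ID, 20])
# done = [(10, 11)], prefix = [20]
```

A multi-label logits processor would call this on every step, query the trie with prefix, and drop the first tokens of exhausted label groups from the valid set.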
Edge cases: even when the model is uncertain, masking concentrates all probability mass on valid tokens, so it always emits some valid label; fine-tuning improves which one. Long labels create narrow trie paths where early tokens lock in the rest; fine-tuning helps here too. And rebuild the trie whenever the taxonomy changes.
Proof of correctness rests on two invariants: (1) forward invariant: every emitted token extends a valid prefix in the trie; (2) termination invariant: EOS is allowed only at end nodes. Verify by enumerating all root-to-end paths in the trie against the label set. The guarantee is independent of model, temperature, and sampling method.
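The enumeration check can be written directly: walk every root-to-end path and confirm the trie emits exactly the inserted token sequences (dict-based trie, toy token IDs):

```python
END = None  # sentinel key marking a complete label

def insert(trie, ids):
    node = trie
    for t in ids:
        node = node.setdefault(t, {})
    node[END] = True

def all_paths(trie, prefix=()):
    """Yield every complete token path the trie can ever emit."""
    for key, child in trie.items():
        if key is END:
            yield prefix
        else:
            yield from all_paths(child, prefix + (key,))

labels = {(8987, 1366), (22470,)}   # toy token sequences
trie = {}
for ids in labels:
    insert(trie, ids)
```

If set(all_paths(trie)) equals the inserted label set, the two invariants together guarantee the decoded output is always one of the labels, regardless of what the model scores.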
Limitations: requires logit access (open models like Qwen2.5, not closed OpenAI APIs); masking redistributes probability, so outputs are structurally valid but can still be semantically wrong; and it does not boost accuracy on its own, so pair it with fine-tuning.
Generalizes to JSON (trie encodes schema), SQL (grammar FSM), agents (tool names). Enforces structure without prompt/model changes.