Improving LlamaIndex's Code Chunker by Cleaning Tree-Sitter CSTs
Kevin Lu - Aug 12th
Demo: https://huggingface.co/spaces/sweepai/chunker
Notebook: https://github.com/sweepai/sweep/blob/main/notebooks/chunking.ipynb
Last time, I detailed the design of our syntax tree-based chunking algorithm for code search at Sweep. It was also recently implemented in LlamaIndex as its default chunker for code!
However, for brevity, I focused on the core ideas behind the algorithm and left out some key implementation details that I'll address in this post, including denoising the concrete syntax trees generated by community-maintained tree-sitter parsers and working around Modal rate limits.
Meet the Span
As a brief helper data structure, we first implemented the following dataclass for representing a slice of a string:
from __future__ import annotations
from dataclasses import dataclass
@dataclass
class Span:
    start: int
    end: int
    def extract(self, s: str) -> str:
        # Grab the corresponding substring of s (offsets are tree-sitter byte positions)
        return s[self.start: self.end]
    def __add__(self, other: Span | int) -> Span:
        # e.g. Span(1, 2) + Span(2, 4) = Span(1, 4) (concatenation)
        # There are no safety checks: Span(a, b) + Span(c, d) = Span(a, d)
        # and there is no requirement that b == c.
        if isinstance(other, int):
            # Adding an int shifts both ends by that offset
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()
    def __len__(self) -> int:
        # i.e. len(Span(a, b)) = b - a
        return self.end - self.start
If you're unfamiliar with dataclasses, they're basically like structs in C/C++ and automatically define methods like __init__, __repr__ and __eq__ for you based on the class variables (start and end in this case) to minimize boilerplate.
Note that we don't validate Span additions because consecutive spans sometimes have a small gap between them, e.g. a node from 1-5 followed by a node from 6-10, which we'll handle next.
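To make the gap behaviour concrete, here is a minimal sketch that repeats the Span definition (trimmed to the two operators) and shows what concatenation and shifting produce. The offsets are made up for illustration and match the 1-5 / 6-10 example above.

```python
from __future__ import annotations
from dataclasses import dataclass


@dataclass
class Span:
    start: int
    end: int

    def __add__(self, other: Span | int) -> Span:
        if isinstance(other, int):
            # Shift both ends by a constant offset
            return Span(self.start + other, self.end + other)
        # Concatenate: keep our start, take the other span's end
        return Span(self.start, other.end)

    def __len__(self) -> int:
        return self.end - self.start


# Illustrative offsets: two nodes separated by a one-character gap
merged = Span(1, 5) + Span(6, 10)
shifted = Span(1, 5) + 3
print(merged, len(merged))  # Span(start=1, end=10) 9
print(shifted)              # Span(start=4, end=8)
```

The merged span silently absorbs the missing character at offset 5, which is the behaviour the rest of the post relies on.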
Filling in the Gaps
Aside: For the earlier examples, I'm going to use a max character count of 600 to show smaller breaks and switch it back to 1500 at the end for the final chunking.
The original implementation actually has bugs, yielding something like this:
defdownload_logs(repo_full_name: str, run_id: int, installation_id: int):
====================
headers = {
"Accept": "application/vnd.github+json",
"Authorization": f"Bearer {get_token(installation_id)}",
"X-GitHub-Api-Version": "2022-11-28"
}response = requests.get(f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}/logs",
headers=headers)logs_str = ""
As you can see, whitespace is missing between keywords and delimiters. Tree-sitter parsers are mostly community-maintained, so the CSTs they produce are noisy: the end byte of one node does not always line up with the start byte of the next node, like so:
expression_statement:431-695
  assignment:431-695
    string:445-695
function_definition:697-1501
  block:776-1501
    expression_statement:776-952
      assignment:776-952
        dictionary:786-952
...
Notice that the "expression_statement" ends on byte 695 and the "function_definition" starts on byte 697, skipping the byte corresponding to a newline. A partial solution is to stop slicing from node.start_byte to node.end_byte and instead replace end_byte with the start_byte of the next node, like so:
from tree_sitter import Node
from dataclasses import dataclass, field
@dataclass
class MockNode:
    # Sentinel standing in for "the node after the last child"
    start_byte: int = 0
    end_byte: int = 0
    children: list[MockNode] = field(default_factory=list)
def chunk_node(
    node: Node,
    text: str,
    MAX_CHARS: int = 600
) -> list[str]:
    chunks = []
    current_chunk: str = ""
    node_children = node.children + [MockNode(node.end_byte, node.end_byte)]
    for child, next_child in zip(node_children[:-1], node_children[1:]):
        if child.end_byte - child.start_byte > MAX_CHARS:
            # Child is too big on its own: flush and recurse into it
            chunks.append(current_chunk)
            current_chunk = ""
            chunks.extend(chunk_node(child, text, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            # Child would overflow the current chunk: start a new one
            chunks.append(current_chunk)
            current_chunk = text[child.start_byte: next_child.start_byte]
        else:
            # Slice up to the next node's start so the gap is included
            current_chunk += text[child.start_byte: next_child.start_byte]
    chunks.append(current_chunk)
    return chunks
We can clean up the implementation by using Spans, taking advantage of how concatenation works, as follows.
def chunk_node(
    node: Node,
    MAX_CHARS: int = 600,
) -> list[Span]:
    new_chunks: list[Span] = []
    current_chunk: Span = Span(node.start_byte, node.start_byte)
    for child in node.children:
        if child.end_byte - child.start_byte > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = Span(child.end_byte, child.end_byte)
            new_chunks.extend(chunk_node(child, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = Span(child.start_byte, child.end_byte)
        else:
            current_chunk += Span(child.start_byte, child.end_byte)
    new_chunks.append(current_chunk)
    return new_chunks
This yields significantly better results but still has an issue: consecutive spans still have gaps between them.
Span(start=1, end=430)
Span(start=432, end=696)
Span(start=698, end=772)
Span(start=777, end=1122)
Span(start=1502, end=1502)
...
We can also eliminate the case where consecutive chunks are disconnected by adding a simple post-processing script.
def connect_chunks(chunks: list[Span]) -> list[Span]:
    # Stretch each span so it ends exactly where the next one starts
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    return chunks
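As a quick check, the sketch below runs this post-processing step on the first few spans from the output above. The Span class is reduced to its two fields to keep the example self-contained.

```python
from __future__ import annotations
from dataclasses import dataclass


@dataclass
class Span:
    start: int
    end: int


def connect_chunks(chunks: list[Span]) -> list[Span]:
    # Stretch each span so it ends exactly where the next one starts
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    return chunks


# The first few spans from the output above
spans = [Span(1, 430), Span(432, 696), Span(698, 772), Span(777, 1122)]
connected = connect_chunks(spans)
for span in connected:
    print(span)
# Span(start=1, end=432)
# Span(start=432, end=698)
# Span(start=698, end=777)
# Span(start=777, end=1122)
```

Every span now ends where its successor begins. The last span is left untouched because it has no successor to borrow a start from.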
Coalescing
Occasionally, we get very small chunks right before or after another chunk, likely because some tree-sitter grammars treat the function header as its own node:
def download_logs(repo_full_name: str, run_id: int, installation_id: int):
====================
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {get_token(installation_id)}",
        "X-GitHub-Api-Version": "2022-11-28"
    }
    response = requests.get(f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}/logs",
                            headers=headers)
    logs_str = ""
Regardless, we can fix this by fusing one-line chunks with the next chunk.
def coalesce_chunks(chunks: list[Span], source_code: str, coalesce: int = 50) -> list[Span]:
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        # Only emit once the running chunk is long enough and spans more than one line
        if len(current_chunk) > coalesce and "\n" in current_chunk.extract(source_code):
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)
    return new_chunks
There is a small edge case, though: for curly-bracket languages like TypeScript and C++, we get something like the following:
int main() {
    cout << "Hello world";
    return 0;
====================
}
The closing brace ends up fused with the next block. Fortunately, this does not lose any important semantic meaning and only creates extra noise that the language model filters out, so the issue remains unaddressed. By contrast, separating a function header from its body is significantly worse.
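The header case can be reproduced on a toy input. The sketch below is an illustration under assumptions rather than the production setup: the `greet` function, the hand-built spans, and the lowered `coalesce=20` threshold are all made up so the example stays short.

```python
from __future__ import annotations
from dataclasses import dataclass


@dataclass
class Span:
    start: int
    end: int

    def extract(self, s: str) -> str:
        return s[self.start:self.end]

    def __add__(self, other: Span) -> Span:
        return Span(self.start, other.end)

    def __len__(self) -> int:
        return self.end - self.start


def coalesce_chunks(chunks: list[Span], source_code: str, coalesce: int = 50) -> list[Span]:
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        if len(current_chunk) > coalesce and "\n" in current_chunk.extract(source_code):
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)
    return new_chunks


# Hypothetical input: a header chunk followed by a body chunk
source = 'def greet(name):\n    message = "Hello, " + name\n    return message\n'
header_end = source.index("\n") + 1
chunks = [Span(0, header_end), Span(header_end, len(source))]

merged = coalesce_chunks(chunks, source, coalesce=20)
print(len(chunks), "->", len(merged))  # 2 -> 1
print(merged[0].extract(source))
```

The header is shorter than the threshold, so it is held back and emitted together with the body as a single chunk.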
Skipping Whitespace when Measuring Size
Another problem we noticed was that at deeper levels of indentation, every indent character counts toward the chunk size, so we sometimes get really small chunks that are not helpful to our language model. A simple solution is to ignore whitespace when measuring the size of a chunk, like so:
import re
def char_len(s: str) -> int: # old len function
    return len(s)
def non_whitespace_len(s: str) -> int: # new len function
    # Raw string so that \s is passed to the regex engine unchanged
    return len(re.sub(r"\s", "", s))
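A small comparison shows why this matters. The two strings below are made-up examples holding the same statement at different indent depths.

```python
import re


def char_len(s: str) -> int:
    return len(s)


def non_whitespace_len(s: str) -> int:
    return len(re.sub(r"\s", "", s))


shallow = "return total\n"
deep = " " * 12 + shallow  # same statement, three 4-space indent levels deeper

print(char_len(shallow), char_len(deep))                      # 13 25
print(non_whitespace_len(shallow), non_whitespace_len(deep))  # 11 11
```

With the plain character count, the deeply nested line looks almost twice as large even though it carries the same amount of code. The whitespace-free count treats both lines identically.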
Fixing to Lines
Instead of binding chunks to character indices (byte offsets, in fact), we can bind them to line numbers, since we eventually store only the line numbers in the vector DB anyway. This eliminates errors with non-ASCII or non-UTF-8 encodings and ensures that we drop any chunk shorter than one line.
This also means that we have to run another loop deleting any empty chunks. Overall we use a helper function like:
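def get_line_number(index: int, source_code: str) -> int:
    # Returns the zero-based index of the line that contains offset `index`
    total_chars = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    # Offsets at or past the end of the source map to the total number of lines
    return line_number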
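To see what the helper returns, here is a short worked example on a made-up three-line source. Note that the result is a zero-based line index, and that an offset at the very end of the source maps to the line count, which makes it usable as an exclusive end when slicing.

```python
def get_line_number(index: int, source_code: str) -> int:
    total_chars = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number


# Hypothetical source: three lines of six characters each (newline included)
source = "a = 1\nb = 2\nc = 3\n"

print(get_line_number(0, source))   # 0: first character of the first line
print(get_line_number(5, source))   # 0: the newline still belongs to the first line
print(get_line_number(6, source))   # 1: first character of the second line
print(get_line_number(18, source))  # 3: end of the source, one past the last line
```

A character span of (6, 18) therefore becomes the line span (1, 3), which covers the second and third lines.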
We can also add a method to our Span class to extract the right lines of code using this. Fortunately, Spans are unit-agnostic, so we can decide at extraction time which units (bytes or lines) we want.
@dataclass
class Span:
    start: int
    end: int
    def extract(self, s: str) -> str:
        return s[self.start: self.end]
    def extract_lines(self, s: str) -> str:
        # Same span, but start and end are interpreted as line indices
        return "\n".join(s.splitlines()[self.start:self.end])
    def __add__(self, other: Span | int) -> Span:
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()
    def __len__(self) -> int:
        return self.end - self.start
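The unit-agnostic idea can be illustrated with two spans that select the same text, one in character offsets and one in line indices. This is a small sketch on a made-up source string, with the Span class trimmed to its two extraction methods.

```python
from dataclasses import dataclass


@dataclass
class Span:
    start: int
    end: int

    def extract(self, s: str) -> str:
        # Interpret start/end as character offsets
        return s[self.start:self.end]

    def extract_lines(self, s: str) -> str:
        # Interpret start/end as zero-based line indices
        return "\n".join(s.splitlines()[self.start:self.end])


source = "a = 1\nb = 2\nc = 3\n"  # hypothetical source

by_chars = Span(6, 11).extract(source)
by_lines = Span(1, 2).extract_lines(source)
print(by_chars)  # b = 2
print(by_lines)  # b = 2
```

The class itself never records which unit it holds, so the caller has to pick the matching extraction method.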
Final New Algorithm
The final pipeline from the syntax tree to the chunks looks something like this:
- Generate the chunks (at this point the spans refer to byte indices)
- Fill in the gaps between consecutive chunks
- Coalesce small chunks with the larger chunks that follow
- Translate byte indices to line numbers
- Delete empty chunks
And the final algorithm looks like the following:
from tree_sitter import Node, Tree
def chunker(
    tree: Tree,
    source_code: bytes,
    MAX_CHARS=512 * 3,
    coalesce=50 # Any chunk less than 50 characters long gets coalesced with the next chunk
) -> list[Span]:
    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)
    def chunk_node(node: Node) -> list[Span]:
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        node_children = node.children
        for child in node_children:
            if child.end_byte - child.start_byte > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks
    chunks = chunk_node(tree.root_node)
    # 2. Filling in the gaps
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    # Extend the last chunk to the end of the file
    chunks[-1].end = tree.root_node.end_byte
    # 3. Combining small chunks with bigger ones
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        # source_code is bytes, so decode the slice before the string checks
        chunk_text = source_code[current_chunk.start:current_chunk.end].decode("utf-8", errors="ignore")
        if non_whitespace_len(chunk_text) > coalesce and "\n" in chunk_text:
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)
    # 4. Changing line numbers
    line_chunks = [Span(get_line_number(chunk.start, source_code),
                        get_line_number(chunk.end, source_code)) for chunk in new_chunks]
    # 5. Eliminating empty chunks
    line_chunks = [chunk for chunk in line_chunks if len(chunk) > 0]
    return line_chunks
This yields the final chunked file you see at https://gist.github.com/kevinlu1248/49a72a1978868775109c5627677dc512#file-on_check_suite-py-md (uses a max character count of 1500 instead of 600).
Bonus: Modal Rate Limits
On an unrelated note, we were hitting Modal's rate limits because we suddenly spun up tens of thousands of containers to chunk a new user's repository, and again whenever users updated their codebase. We fixed this with the following changes:
- We batched files into groups of 30, so for a 30k-file repo we spin up only a thousand containers, each sequentially chunking 30 files, instead of 30k containers chunking one file each.
- We passively re-indexed only for very active users and paying users.
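The batching arithmetic can be sketched in plain Python. This is only an illustration of the grouping step: the `batch` helper and the file names are hypothetical, and no Modal API is used. In practice each batch would be handed to one container.

```python
from __future__ import annotations


def batch(items: list[str], batch_size: int = 30) -> list[list[str]]:
    # Split items into consecutive groups of at most batch_size elements
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


# Hypothetical 30k-file repository
file_paths = [f"file_{i}.py" for i in range(30_000)]
batches = batch(file_paths)

print(len(batches))     # 1000
print(len(batches[0]))  # 30
```

When the file count is not a multiple of 30, the last batch is simply smaller.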