
Improving LlamaIndexโ€™s Code Chunker by Cleaning Tree-Sitter CSTs

Kevin Lu - Aug 12th


Demo: https://huggingface.co/spaces/sweepai/chunker

Notebook: https://github.com/sweepai/sweep/blob/main/notebooks/chunking.ipynb

Last time, I detailed the design of our syntax tree-based chunking algorithm for code search at Sweep. It was also recently implemented in LlamaIndex as its default chunker for code!

However, for brevity, I focused on the core ideas behind the algorithm and left out some key implementation details, which I'll address in this post: denoising the concrete syntax trees generated by community-maintained tree-sitter parsers, and working around Modal rate limits.

Meet the Span

First, we implemented the following helper dataclass for representing a slice of a string:

from __future__ import annotations
from dataclasses import dataclass
 
@dataclass
class Span:
    start: int
    end: int
 
    def extract(self, s: str) -> str:
        # Grab the corresponding substring of string s by bytes
        return s[self.start: self.end]
 
    def __add__(self, other: Span | int) -> Span:
        # e.g. Span(1, 2) + Span(2, 4) = Span(1, 4) (concatenation)
        # There are no safety checks: Span(a, b) + Span(c, d) = Span(a, d)
        # and there are no requirements for b = c.
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()
 
    def __len__(self) -> int:
        # i.e. Span(a, b) = b - a
        return self.end - self.start

If you're unfamiliar with dataclasses: they're basically like structs in C/C++ and automatically define methods like __init__, __repr__, and __str__ for you based on the class variables (start and end in this case) to minimize boilerplate.

Note that we don't validate Span additions, because consecutive spans sometimes have a small gap: e.g. a node spanning bytes 1-5 followed by a node spanning bytes 6-10. We'll handle these gaps shortly.
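
To make these operators concrete, here is a quick demo (pure Python, nothing assumed beyond the class above):

text = "hello world"
print(Span(0, 5) + Span(6, 11))        # Span(start=0, end=11): the gap at index 5 is absorbed
print((Span(0, 5) + 6).extract(text))  # "world": adding an int shifts the span
print(len(Span(6, 11)))                # 5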

Filling in the Gaps

Aside: for the examples below, I'm going to use a max character count of 600 to show smaller breaks, and switch it back to 1500 at the end for the final chunking.

The original implementation actually has bugs, yielding something like this:

defdownload_logs(repo_full_name: str, run_id: int, installation_id: int):
 
====================
 
headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {get_token(installation_id)}",
        "X-GitHub-Api-Version": "2022-11-28"
    }response = requests.get(f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}/logs",
                            headers=headers)logs_str = ""

As you can see, there's missing whitespace between keywords and delimiters. Due to the noisiness of tree-sitter CST parsers, which are mostly community-maintained, there are cases where the end byte of one node does not correspond with the start byte of the next node, like so:

expression_statement:431-695
  assignment:431-695
    string:445-695
function_definition:697-1501
  block:776-1501
    expression_statement:776-952
      assignment:776-952
        dictionary:786-952
...

Notice that the expression_statement ends on byte 695 and the function_definition starts on byte 697, skipping the byte corresponding to a newline. A partial solution is, instead of using node.start_byte and node.end_byte, to replace each node's end_byte with the start_byte of the next node, like so:

from tree_sitter import Node
from dataclasses import field
 
@dataclass
class MockNode:
    start_byte: int = 0
    end_byte: int = 0
    children: list[MockNode] = field(default_factory=list)
 
def chunk_node(
    node: Node,
    text: str,
    MAX_CHARS: int = 600
) -> list[str]:
    chunks: list[str] = []
    current_chunk: str = ""
    node_children = node.children + [MockNode(node.end_byte, node.end_byte)]
 
    for child, next_child in zip(node_children[:-1], node_children[1:]):
        if child.end_byte - child.start_byte > MAX_CHARS:
            chunks.append(current_chunk)
            current_chunk = ""
            chunks.extend(chunk_node(child, text, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            chunks.append(current_chunk)
            current_chunk = text[child.start_byte: next_child.start_byte]
        else:
            current_chunk += text[child.start_byte: next_child.start_byte]
    chunks.append(current_chunk)
 
    return chunks

We can clean up the implementation by using Spans, taking advantage of how concatenation works, as follows.

def chunk_node(
    node: Node,
    MAX_CHARS: int = 600,
) -> list[Span]:
    new_chunks: list[Span] = []
    current_chunk: Span = Span(node.start_byte, node.start_byte)
    for child in node.children:
        if child.end_byte - child.start_byte > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = Span(child.end_byte, child.end_byte)
            new_chunks.extend(chunk_node(child, MAX_CHARS))
        elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
            new_chunks.append(current_chunk)
            current_chunk = Span(child.start_byte, child.end_byte)
        else:
            current_chunk += Span(child.start_byte, child.end_byte)
    new_chunks.append(current_chunk)
    return new_chunks
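
To run this on a real file, here's a minimal usage sketch, assuming the third-party tree_sitter_languages package for a prebuilt parser (the input file name is hypothetical):

from tree_sitter_languages import get_parser

parser = get_parser("python")
source_bytes = open("on_check_suite.py", "rb").read()  # hypothetical input file
tree = parser.parse(source_bytes)
chunks = chunk_node(tree.root_node)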

This yields significantly better results, but there is still an issue: consecutive spans still have gaps.

Span(start=1, end=430)
Span(start=432, end=696)
Span(start=698, end=772)
Span(start=777, end=1122)
Span(start=1502, end=1502)
...

We can eliminate the case where consecutive chunks are disconnected with a simple post-processing step.

def connect_chunks(chunks: list[Span]) -> list[Span]:
    # Stretch each chunk to reach the start of the next one
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    return chunks
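
Running this over the gapped spans above stretches each one to meet its successor:

chunks = [Span(1, 430), Span(432, 696), Span(698, 772)]
print(connect_chunks(chunks))
# [Span(start=1, end=432), Span(start=432, end=698), Span(start=698, end=772)]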

Coalescing

Occasionally, we get very small chunks that sit right before or after another chunk, likely because some tree-sitter grammars treat the function header as its own node:

def download_logs(repo_full_name: str, run_id: int, installation_id: int):
 
====================
 
    headers = {
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {get_token(installation_id)}",
        "X-GitHub-Api-Version": "2022-11-28"
    }
    response = requests.get(f"https://api.github.com/repos/{repo_full_name}/actions/runs/{run_id}/logs",
                            headers=headers)
 
    logs_str = ""

Regardless, we can fix this by fusing one-line chunks with the next chunk.

def coalesce_chunks(chunks: list[Span], source_code: str, coalesce: int = 50) -> list[Span]:
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        if len(current_chunk) > coalesce and "\n" in current_chunk.extract(source_code):
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)
    return new_chunks
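
For instance, a tiny one-line header chunk gets fused into the chunk that follows it:

source = "def f():\n    return 1\n"
header = Span(0, 9)           # "def f():\n", a one-line chunk
body = Span(9, len(source))   # the function body
print(coalesce_chunks([header, body], source))
# [Span(start=0, end=22)]: the header is fused with its body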

There is a small edge case to this, though: for curly-brace-based languages like TypeScript and C++, we get something like the following:

int main() {
	cout << "Hello world";
	return 0;
 
====================
 
}

The closing brace ends up fused with the next block. Fortunately, this does not lose important semantic meaning and only creates extra noise that the language model filters out, so the issue remains unaddressed. On the other hand, a function header being separated from its implementation is significantly worse.

Skipping Whitespace when Measuring Size

Another problem we noticed is that at higher levels of indentation, because each indent character counts toward a chunk's length, we sometimes get really small chunks that are not helpful to our language model. The simple solution is to skip whitespace when measuring a chunk's size, like so:

import re

def char_len(s: str) -> int:  # old len function
    return len(s)

def non_whitespace_len(s: str) -> int:  # new len function
    return len(re.sub(r"\s", "", s))
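
For example, a deeply indented one-liner now measures much smaller:

line = " " * 24 + "return logs_str"
print(char_len(line))            # 39
print(non_whitespace_len(line))  # 14: the indentation (and inner space) no longer count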

Fixing to Lines

Instead of binding chunks to character indices (byte offsets, in fact), we can bind them to line numbers, since we eventually just store the line numbers in the vector DB anyway. This sidesteps encoding issues (byte offsets and character offsets diverge for non-ASCII text) and ensures that we eliminate any chunk less than a line long.

This also means that we have to run another pass deleting any empty chunks. Overall, we use a helper function like:

def get_line_number(index: int, source_code: str) -> int:
    # Returns the 0-indexed line number containing the character at `index`
    line_number = 0  # fallback for empty source
    total_chars = 0
    for line_number, line in enumerate(source_code.splitlines(keepends=True), start=1):
        total_chars += len(line)
        if total_chars > index:
            return line_number - 1
    return line_number
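
A quick sanity check of the line numbering (note that it is 0-indexed, matching Python's list slicing):

source = "line one\nline two\nline three\n"
print(get_line_number(0, source))   # 0: index 0 is on the first line
print(get_line_number(10, source))  # 1: index 10 falls inside "line two"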

We can also add a method to our Span class to finally extract the right lines of code. Fortunately, Spans are unit-agnostic, so we can decide at extraction time which units (bytes or lines) we want.

@dataclass
class Span:
    start: int
    end: int
 
    def extract(self, s: str) -> str:
        # Slice by byte/character offsets
        return s[self.start: self.end]

    def extract_lines(self, s: str) -> str:
        # Slice by 0-indexed line numbers
        return "\n".join(s.splitlines()[self.start:self.end])
 
    def __add__(self, other: Span | int) -> Span:
        if isinstance(other, int):
            return Span(self.start + other, self.end + other)
        elif isinstance(other, Span):
            return Span(self.start, other.end)
        else:
            raise NotImplementedError()
 
    def __len__(self) -> int:
        return self.end - self.start
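
Putting the two together, a byte-offset span can be translated into a line span and extracted (this mirrors step 4 of the final pipeline below):

source = "a = 1\nb = 2\nc = 3\n"
byte_span = Span(6, 18)  # bytes covering the last two lines
line_span = Span(get_line_number(byte_span.start, source),
                 get_line_number(byte_span.end, source))
print(line_span)                        # Span(start=1, end=3)
print(line_span.extract_lines(source))  # "b = 2\nc = 3"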

Final New Algorithm

The final pipeline from the syntax tree to the chunks looks as follows:

  1. Recursively generate the chunks (the spans refer to byte offsets at this stage)
  2. Fill in the gaps between consecutive chunks
  3. Coalesce small chunks with the larger chunks that follow them
  4. Translate byte offsets to line numbers
  5. Delete empty chunks

And the final algorithm looks like the following:

from tree_sitter import Node, Tree

def chunker(
    tree: Tree,
    source_code: str,  # decoded source text; the line translation step smooths over byte vs. character offsets
    MAX_CHARS: int = 512 * 3,
    coalesce: int = 50  # Any chunk with fewer than 50 non-whitespace characters gets coalesced with the next chunk
) -> list[Span]:
 
    # 1. Recursively form chunks based on the last post (https://docs.sweep.dev/blogs/chunking-2m-files)
    def chunk_node(node: Node) -> list[Span]:
        chunks: list[Span] = []
        current_chunk: Span = Span(node.start_byte, node.start_byte)
        node_children = node.children
        for child in node_children:
            if child.end_byte - child.start_byte > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.end_byte, child.end_byte)
                chunks.extend(chunk_node(child))
            elif child.end_byte - child.start_byte + len(current_chunk) > MAX_CHARS:
                chunks.append(current_chunk)
                current_chunk = Span(child.start_byte, child.end_byte)
            else:
                current_chunk += Span(child.start_byte, child.end_byte)
        chunks.append(current_chunk)
        return chunks
    chunks = chunk_node(tree.root_node)
 
    # 2. Filling in the gaps between consecutive chunks
    for prev, curr in zip(chunks[:-1], chunks[1:]):
        prev.end = curr.start
    chunks[-1].end = tree.root_node.end_byte  # extend the final chunk to the end of the file
 
    # 3. Combining small chunks with bigger ones
    new_chunks = []
    current_chunk = Span(0, 0)
    for chunk in chunks:
        current_chunk += chunk
        if non_whitespace_len(current_chunk.extract(source_code)) > coalesce \
            and "\n" in current_chunk.extract(source_code):
            new_chunks.append(current_chunk)
            current_chunk = Span(chunk.end, chunk.end)
    if len(current_chunk) > 0:
        new_chunks.append(current_chunk)
 
    # 4. Translating byte offsets to line numbers
    line_chunks = [Span(get_line_number(chunk.start, source_code),
                    get_line_number(chunk.end, source_code)) for chunk in new_chunks]
 
    # 5. Eliminating empty chunks
    line_chunks = [chunk for chunk in line_chunks if len(chunk) > 0]
 
    return line_chunks
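
Tying it all together, a minimal end-to-end sketch (again assuming the tree_sitter_languages package; the file path is hypothetical):

from tree_sitter_languages import get_parser

parser = get_parser("python")
source_bytes = open("on_check_suite.py", "rb").read()  # hypothetical input file
source_code = source_bytes.decode("utf-8")

for span in chunker(parser.parse(source_bytes), source_code):
    print(f"lines {span.start}-{span.end}")
    print(span.extract_lines(source_code))
    print("=" * 20)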

This yields the final chunked file you see at https://gist.github.com/kevinlu1248/49a72a1978868775109c5627677dc512#file-on_check_suite-py-md (using a max character count of 1500 instead of 600).

Bonus: Modal Rate Limits

On an unrelated note, we had an issue with hitting rate limits on Modal, since we were spinning up tens of thousands of containers all at once to chunk a new user's repository, and again whenever users updated their codebase. We fixed this with the following:

  1. We batched files into groups of 30, so for a 30k-file repo we would spin up only a thousand containers, each sequentially chunking 30 files, instead of 30k containers chunking one file each (a minimal sketch of this batching is below).
  2. We passively re-indexed only for very active users and paying users.
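
As an illustration of the batching in item 1, a minimal sketch (the helper is hypothetical, not Sweep's actual Modal code):

def batch(items: list[str], size: int = 30) -> list[list[str]]:
    # Split a flat file list into groups of `size` so each container chunks one group
    return [items[i : i + size] for i in range(0, len(items), size)]

# e.g. a 30k-file repo becomes 1,000 batches of 30 files each
print(len(batch([f"file_{i}.py" for i in range(30_000)])))  # 1000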