Skip to content

Commit

Permalink
feat(repo_map): add elixir support (#894)
Browse files Browse the repository at this point in the history
  • Loading branch information
radwo authored Nov 24, 2024
1 parent e60ccd2 commit 890fd92
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 16 deletions.
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/avante-repo-map/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tree-sitter-lua = "0.2"
tree-sitter-ruby = "0.23"
tree-sitter-zig = "1.0.2"
tree-sitter-scala = "0.23"
tree-sitter-elixir = "0.3.1"

[lints]
workspace = true
Expand Down
21 changes: 21 additions & 0 deletions crates/avante-repo-map/queries/tree-sitter-elixir-defs.scm
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; * modules and protocols
(call
target: (identifier) @ignore
(arguments (alias) @class)
(#match? @ignore "^(defmodule|defprotocol)$"))

; * functions
(call
target: (identifier) @ignore
(arguments
[
; zero-arity functions with no parentheses
(identifier) @method
; regular function clause
(call target: (identifier) @method)
; function clause with a guard clause
(binary_operator
left: (call target: (identifier) @method)
operator: "when")
])
(#match? @ignore "^(def|defdelegate|defguard|defn)$"))
93 changes: 85 additions & 8 deletions crates/avante-repo-map/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct Func {

#[derive(Debug, Clone)]
pub struct Class {
pub type_name: String,
pub name: String,
pub methods: Vec<Func>,
pub properties: Vec<Variable>,
Expand Down Expand Up @@ -61,6 +62,7 @@ fn get_ts_language(language: &str) -> Option<LanguageFn> {
"ruby" => Some(tree_sitter_ruby::LANGUAGE),
"zig" => Some(tree_sitter_zig::LANGUAGE),
"scala" => Some(tree_sitter_scala::LANGUAGE),
"elixir" => Some(tree_sitter_elixir::LANGUAGE),
_ => None,
}
}
Expand All @@ -76,6 +78,7 @@ const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm");
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm");
const ELIXIR_QUERY: &str = include_str!("../queries/tree-sitter-elixir-defs.scm");

fn get_definitions_query(language: &str) -> Result<Query, String> {
let ts_language = get_ts_language(language);
Expand All @@ -95,6 +98,7 @@ fn get_definitions_query(language: &str) -> Result<Query, String> {
"typescript" => TYPESCRIPT_QUERY,
"ruby" => RUBY_QUERY,
"scala" => SCALA_QUERY,
"elixir" => ELIXIR_QUERY,
_ => return Err(format!("Unsupported language: {language}")),
};
let query = Query::new(&ts_language.into(), contents)
Expand Down Expand Up @@ -185,6 +189,23 @@ fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option<Strin
None
}

fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8],) -> Option<String> {
let mut parent = node.parent();
while let Some(parent_node) = parent {
if parent_node.kind() == "call" {
let text = get_node_text(&parent_node, source);
if text.starts_with("defmodule ") {
let arguments_node = find_child_by_type(&parent_node, "arguments");
if let Some(arguments_node) = arguments_node {
return Some(get_node_text(&arguments_node, source));
}
}
}
parent = parent_node.parent();
}
None
}

fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
node.utf8_text(source).unwrap_or_default().to_string()
}
Expand Down Expand Up @@ -235,9 +256,14 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let mut enum_def_map: BTreeMap<String, RefCell<Enum>> = BTreeMap::new();
let mut union_def_map: BTreeMap<String, RefCell<Union>> = BTreeMap::new();

let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
let ensure_class_def = |language: &str, name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
let mut type_name = "class";
if language == "elixir" {
type_name = "module"
}
class_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Class {
type_name: type_name.to_string(),
name: name.to_string(),
methods: vec![],
properties: vec![],
Expand Down Expand Up @@ -337,7 +363,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&name, &mut class_def_map);
ensure_class_def(&language, &name, &mut class_def_map);
let visibility_modifier_node =
find_child_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
Expand Down Expand Up @@ -449,12 +475,18 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.child_by_field_name("parameters")
.or_else(|| find_descendant_by_type(&node, "parameter_list"));

let function_node = find_ancestor_by_type(&node, "function_declaration");
let zig_function_node = find_ancestor_by_type(&node, "function_declaration");
if language == "zig" {
params_node = function_node
params_node = zig_function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "parameters"));
}
let ex_function_node = find_ancestor_by_type(&node, "call");
if language == "elixir" {
params_node = ex_function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "arguments"));
}

let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
Expand All @@ -480,6 +512,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
return_type_node = node.child_by_field_name("result");
}
let mut return_type = "void".to_string();
if language == "elixir" {
return_type = "".to_string();
}
if return_type_node.is_some() {
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
if return_type.is_empty() {
Expand All @@ -496,6 +531,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let class_name = if language == "zig" {
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "elixir" {
ex_find_parent_module_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "cpp" {
find_ancestor_by_type(&node, "class_specifier")
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
Expand Down Expand Up @@ -524,7 +562,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
continue;
}

ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();

let accessibility_modifier_node =
Expand Down Expand Up @@ -569,7 +607,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if class_name.is_empty() {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: left.to_string(),
Expand Down Expand Up @@ -623,7 +661,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: name.to_string(),
Expand Down Expand Up @@ -888,7 +926,7 @@ fn stringify_union_item(item: &Variable) -> String {
}

fn stringify_class(class: &Class) -> String {
let mut res = format!("class {}{{", class.name);
let mut res = format!("{} {}{{", class.type_name, class.name);
for method in &class.methods {
let method_str = stringify_function(method);
res = format!("{res}{method_str}");
Expand Down Expand Up @@ -1428,6 +1466,45 @@ mod tests {
assert_eq!(stringified, expected);
}

#[test]
fn test_elixir() {
let source = r#"
defmodule TestModule do
@moduledoc """
This is a test module
"""
@test_const "test"
@other_const 123
def test_func(a, b) do
a + b
end
defp private_func(x) do
x * 2
end
defmacro test_macro(expr) do
quote do
unquote(expr)
end
end
end
defmodule AnotherModule do
def another_func() do
:ok
end
end
"#;
let definitions = extract_definitions("elixir", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "module AnotherModule{func another_func();};module TestModule{func test_func(a, b);};";
assert_eq!(stringified, expected);
}

#[test]
fn test_unsupported_language() {
let source = "print('Hello, world!')";
Expand Down
14 changes: 6 additions & 8 deletions lua/avante/repo_map.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ local Popup = require("nui.popup")
local Utils = require("avante.utils")
local event = require("nui.utils.autocmd").event
local Config = require("avante.config")
local fn = vim.fn

local filetype_map = {
["javascriptreact"] = "javascript",
Expand Down Expand Up @@ -34,14 +33,13 @@ function RepoMap.get_ts_lang(filepath)
end

function RepoMap.get_filetype(filepath)
local filetype = vim.filetype.match({ filename = filepath })
-- TypeScript files are sometimes not detected correctly
-- Some files are sometimes not detected correctly when buffer is not included
-- https://github.com/neovim/neovim/issues/27265
if not filetype then
local ext = fn.fnamemodify(filepath, ":e")
if ext == "tsx" then filetype = "typescriptreact" end
if ext == "ts" then filetype = "typescript" end
end

local buf = vim.api.nvim_create_buf(false, true)
local filetype = vim.filetype.match({ filename = filepath, buf = buf })
vim.api.nvim_buf_delete(buf, { force = true })

return filetype
end

Expand Down

0 comments on commit 890fd92

Please sign in to comment.