Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(repo_map): add elixir support #894

Merged
merged 1 commit into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
feat(repo_map): add elixir support
  • Loading branch information
radwo committed Nov 23, 2024
commit c33845c57b6fdfe9730351113dc2613af1043432
11 changes: 11 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/avante-repo-map/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ tree-sitter-lua = "0.2"
tree-sitter-ruby = "0.23"
tree-sitter-zig = "1.0.2"
tree-sitter-scala = "0.23"
tree-sitter-elixir = "0.3.1"

[lints]
workspace = true
Expand Down
21 changes: 21 additions & 0 deletions crates/avante-repo-map/queries/tree-sitter-elixir-defs.scm
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
; * modules and protocols
(call
target: (identifier) @ignore
(arguments (alias) @class)
(#match? @ignore "^(defmodule|defprotocol)$"))

; * functions
(call
target: (identifier) @ignore
(arguments
[
; zero-arity functions with no parentheses
(identifier) @method
; regular function clause
(call target: (identifier) @method)
; function clause with a guard clause
(binary_operator
left: (call target: (identifier) @method)
operator: "when")
])
(#match? @ignore "^(def|defdelegate|defguard|defn)$"))
93 changes: 85 additions & 8 deletions crates/avante-repo-map/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct Func {

#[derive(Debug, Clone)]
pub struct Class {
pub type_name: String,
pub name: String,
pub methods: Vec<Func>,
pub properties: Vec<Variable>,
Expand Down Expand Up @@ -61,6 +62,7 @@ fn get_ts_language(language: &str) -> Option<LanguageFn> {
"ruby" => Some(tree_sitter_ruby::LANGUAGE),
"zig" => Some(tree_sitter_zig::LANGUAGE),
"scala" => Some(tree_sitter_scala::LANGUAGE),
"elixir" => Some(tree_sitter_elixir::LANGUAGE),
_ => None,
}
}
Expand All @@ -76,6 +78,7 @@ const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm");
const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm");
const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm");
const SCALA_QUERY: &str = include_str!("../queries/tree-sitter-scala-defs.scm");
const ELIXIR_QUERY: &str = include_str!("../queries/tree-sitter-elixir-defs.scm");

fn get_definitions_query(language: &str) -> Result<Query, String> {
let ts_language = get_ts_language(language);
Expand All @@ -95,6 +98,7 @@ fn get_definitions_query(language: &str) -> Result<Query, String> {
"typescript" => TYPESCRIPT_QUERY,
"ruby" => RUBY_QUERY,
"scala" => SCALA_QUERY,
"elixir" => ELIXIR_QUERY,
_ => return Err(format!("Unsupported language: {language}")),
};
let query = Query::new(&ts_language.into(), contents)
Expand Down Expand Up @@ -185,6 +189,23 @@ fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option<Strin
None
}

fn ex_find_parent_module_declaration_name<'a>(node: &'a Node, source: &'a [u8],) -> Option<String> {
let mut parent = node.parent();
while let Some(parent_node) = parent {
if parent_node.kind() == "call" {
let text = get_node_text(&parent_node, source);
if text.starts_with("defmodule ") {
let arguments_node = find_child_by_type(&parent_node, "arguments");
if let Some(arguments_node) = arguments_node {
return Some(get_node_text(&arguments_node, source));
}
}
}
parent = parent_node.parent();
}
None
}

fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String {
node.utf8_text(source).unwrap_or_default().to_string()
}
Expand Down Expand Up @@ -235,9 +256,14 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let mut enum_def_map: BTreeMap<String, RefCell<Enum>> = BTreeMap::new();
let mut union_def_map: BTreeMap<String, RefCell<Union>> = BTreeMap::new();

let ensure_class_def = |name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
let ensure_class_def = |language: &str, name: &str, class_def_map: &mut BTreeMap<String, RefCell<Class>>| {
let mut type_name = "class";
if language == "elixir" {
type_name = "module"
}
class_def_map.entry(name.to_string()).or_insert_with(|| {
RefCell::new(Class {
type_name: type_name.to_string(),
name: name.to_string(),
methods: vec![],
properties: vec![],
Expand Down Expand Up @@ -337,7 +363,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&name, &mut class_def_map);
ensure_class_def(&language, &name, &mut class_def_map);
let visibility_modifier_node =
find_child_by_type(&node, "visibility_modifier");
let visibility_modifier = visibility_modifier_node
Expand Down Expand Up @@ -449,12 +475,18 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
.child_by_field_name("parameters")
.or_else(|| find_descendant_by_type(&node, "parameter_list"));

let function_node = find_ancestor_by_type(&node, "function_declaration");
let zig_function_node = find_ancestor_by_type(&node, "function_declaration");
if language == "zig" {
params_node = function_node
params_node = zig_function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "parameters"));
}
let ex_function_node = find_ancestor_by_type(&node, "call");
if language == "elixir" {
params_node = ex_function_node
.as_ref()
.and_then(|n| find_child_by_type(n, "arguments"));
}

let params = params_node
.map(|n| n.utf8_text(source.as_bytes()).unwrap())
Expand All @@ -480,6 +512,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
return_type_node = node.child_by_field_name("result");
}
let mut return_type = "void".to_string();
if language == "elixir" {
return_type = "".to_string();
}
if return_type_node.is_some() {
return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes());
if return_type.is_empty() {
Expand All @@ -496,6 +531,9 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
let class_name = if language == "zig" {
zig_find_parent_variable_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "elixir" {
ex_find_parent_module_declaration_name(&node, source.as_bytes())
.unwrap_or_default()
} else if language == "cpp" {
find_ancestor_by_type(&node, "class_specifier")
.or_else(|| find_ancestor_by_type(&node, "struct_specifier"))
Expand Down Expand Up @@ -524,7 +562,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
continue;
}

ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();

let accessibility_modifier_node =
Expand Down Expand Up @@ -569,7 +607,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if class_name.is_empty() {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: left.to_string(),
Expand Down Expand Up @@ -623,7 +661,7 @@ fn extract_definitions(language: &str, source: &str) -> Result<Vec<Definition>,
if !name.is_empty() && language == "go" && !is_first_letter_uppercase(&name) {
continue;
}
ensure_class_def(&class_name, &mut class_def_map);
ensure_class_def(&language, &class_name, &mut class_def_map);
let class_def = class_def_map.get_mut(&class_name).unwrap();
let variable = Variable {
name: name.to_string(),
Expand Down Expand Up @@ -888,7 +926,7 @@ fn stringify_union_item(item: &Variable) -> String {
}

fn stringify_class(class: &Class) -> String {
let mut res = format!("class {}{{", class.name);
let mut res = format!("{} {}{{", class.type_name, class.name);
for method in &class.methods {
let method_str = stringify_function(method);
res = format!("{res}{method_str}");
Expand Down Expand Up @@ -1428,6 +1466,45 @@ mod tests {
assert_eq!(stringified, expected);
}

#[test]
fn test_elixir() {
let source = r#"
defmodule TestModule do
@moduledoc """
This is a test module
"""

@test_const "test"
@other_const 123

def test_func(a, b) do
a + b
end

defp private_func(x) do
x * 2
end

defmacro test_macro(expr) do
quote do
unquote(expr)
end
end
end

defmodule AnotherModule do
def another_func() do
:ok
end
end
"#;
let definitions = extract_definitions("elixir", source).unwrap();
let stringified = stringify_definitions(&definitions);
println!("{stringified}");
let expected = "module AnotherModule{func another_func();};module TestModule{func test_func(a, b);};";
assert_eq!(stringified, expected);
}

#[test]
fn test_unsupported_language() {
let source = "print('Hello, world!')";
Expand Down
14 changes: 6 additions & 8 deletions lua/avante/repo_map.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ local Popup = require("nui.popup")
local Utils = require("avante.utils")
local event = require("nui.utils.autocmd").event
local Config = require("avante.config")
local fn = vim.fn

local filetype_map = {
["javascriptreact"] = "javascript",
Expand Down Expand Up @@ -34,14 +33,13 @@ function RepoMap.get_ts_lang(filepath)
end

function RepoMap.get_filetype(filepath)
local filetype = vim.filetype.match({ filename = filepath })
-- TypeScript files are sometimes not detected correctly
-- Some files are sometimes not detected correctly when buffer is not included
-- https://github.com/neovim/neovim/issues/27265
if not filetype then
local ext = fn.fnamemodify(filepath, ":e")
if ext == "tsx" then filetype = "typescriptreact" end
if ext == "ts" then filetype = "typescript" end
end

local buf = vim.api.nvim_create_buf(false, true)
local filetype = vim.filetype.match({ filename = filepath, buf = buf })
vim.api.nvim_buf_delete(buf, { force = true })

return filetype
end

Expand Down