From bac46cee83ce532d597e6ee7da844d448f8c24cc Mon Sep 17 00:00:00 2001 From: Boy Maas Date: Sun, 29 Sep 2024 19:27:10 +0200 Subject: [PATCH] feat(repo-map): zig support (#663) * feature: zig support for repo map * Update crates/avante-repo-map/Cargo.toml Co-authored-by: yetone * fix: update lint error Signed-off-by: Aaron Pham --------- Signed-off-by: Aaron Pham Co-authored-by: Aaron Pham Co-authored-by: yetone Co-authored-by: Aaron Pham --- Cargo.lock | 11 + crates/avante-repo-map/Cargo.toml | 1 + .../queries/tree-sitter-zig-defs.scm | 23 ++ crates/avante-repo-map/src/lib.rs | 260 +++++++++++++++++- 4 files changed, 288 insertions(+), 7 deletions(-) create mode 100644 crates/avante-repo-map/queries/tree-sitter-zig-defs.scm diff --git a/Cargo.lock b/Cargo.lock index 9c2922a72..571e3d908 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,7 @@ dependencies = [ "tree-sitter-ruby", "tree-sitter-rust", "tree-sitter-typescript", + "tree-sitter-zig", ] [[package]] @@ -1404,6 +1405,16 @@ dependencies = [ "tree-sitter-language", ] +[[package]] +name = "tree-sitter-zig" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2994e37b8ef1f715b931a5ff084a1b1713b1bc56e7aaebd148cc3efe0bf29ad9" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "typeid" version = "1.0.2" diff --git a/crates/avante-repo-map/Cargo.toml b/crates/avante-repo-map/Cargo.toml index bcf40d7c3..2e2e2983c 100644 --- a/crates/avante-repo-map/Cargo.toml +++ b/crates/avante-repo-map/Cargo.toml @@ -26,6 +26,7 @@ tree-sitter-c = "0.23" tree-sitter-cpp = "0.23" tree-sitter-lua = "0.2" tree-sitter-ruby = "0.23" +tree-sitter-zig = "1.0.2" [lints] workspace = true diff --git a/crates/avante-repo-map/queries/tree-sitter-zig-defs.scm b/crates/avante-repo-map/queries/tree-sitter-zig-defs.scm new file mode 100644 index 000000000..f2b1215ca --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-zig-defs.scm @@ -0,0 +1,23 @@ + ;; Capture functions, structs, methods, variable definitions, and unions in Zig +(variable_declaration (identifier) + (struct_declaration + (container_field) @class_variable)) + +(variable_declaration (identifier) + (struct_declaration + (function_declaration + name: (identifier) @method))) + +(variable_declaration (identifier) + (enum_declaration + (container_field + type: (identifier) @enum_item))) + +(variable_declaration (identifier) + (union_declaration + (container_field + name: (identifier) @union_item))) + +(source_file (function_declaration) @function) + +(source_file (variable_declaration (identifier) @variable)) diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs index 9e93956e4..5ad8973ec 100644 --- a/crates/avante-repo-map/src/lib.rs +++ b/crates/avante-repo-map/src/lib.rs @@ -26,6 +26,12 @@ pub struct Enum { pub items: Vec, } +#[derive(Debug, Clone)] +pub struct Union { + pub name: String, + pub items: Vec, +} + #[derive(Debug, Clone)] pub struct Variable { pub name: String, @@ -38,6 +44,7 @@ pub enum Definition { Class(Class), Enum(Enum), Variable(Variable), + Union(Union), } fn get_ts_language(language: &str) -> Option { @@ -51,6 +58,7 @@ fn get_ts_language(language: &str) -> Option { "cpp" => Some(tree_sitter_cpp::LANGUAGE), "lua" => Some(tree_sitter_lua::LANGUAGE), "ruby" => Some(tree_sitter_ruby::LANGUAGE), + "zig" => Some(tree_sitter_zig::LANGUAGE), _ => None, } } @@ -62,6 +70,7 @@ const JAVASCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-javascript-d const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm"); const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm"); const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm"); +const ZIG_QUERY: &str = include_str!("../queries/tree-sitter-zig-defs.scm"); const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm"); const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm"); @@ -79,6 +88,7 @@ fn get_definitions_query(language: &str) -> Result { "lua" => LUA_QUERY, "python" => PYTHON_QUERY, "rust" => RUST_QUERY, + "zig" => ZIG_QUERY, "typescript" => TYPESCRIPT_QUERY, "ruby" => RUBY_QUERY, _ => return Err(format!("Unsupported language: {language}")), @@ -128,6 +138,49 @@ fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option> .find(|child| child.kind() == child_type) } +// Zig-specific function to find the parent variable declaration +fn zig_find_parent_variable_declaration_name<'a>( + node: &'a Node, + source: &'a [u8], +) -> Option { + let vardec = find_ancestor_by_type(node, "variable_declaration"); + if let Some(vardec) = vardec { + // Find the identifier child node, which represents the class name + let identifier_node = find_child_by_type(&vardec, "identifier"); + if let Some(identifier_node) = identifier_node { + return Some(get_node_text(&identifier_node, source)); + } + } + None +} + +fn zig_is_declaration_public<'a>(node: &'a Node, declaration_type: &str, source: &'a [u8]) -> bool { + let declaration = find_ancestor_by_type(node, declaration_type); + if let Some(declaration) = declaration { + let declaration_text = get_node_text(&declaration, source); + return declaration_text.starts_with("pub"); + } + false +} + +fn zig_is_variable_declaration_public<'a>(node: &'a Node, source: &'a [u8]) -> bool { + zig_is_declaration_public(node, "variable_declaration", source) +} + +fn zig_is_function_declaration_public<'a>(node: &'a Node, source: &'a [u8]) -> bool { + zig_is_declaration_public(node, "function_declaration", source) +} + +fn zig_find_type_in_parent<'a>(node: &'a Node, source: &'a [u8]) -> Option { + // First go to the parent and then get the child_by_field_name "type" + if let Some(parent) = node.parent() { + if let Some(type_node) = parent.child_by_field_name("type") { + return Some(get_node_text(&type_node, source)); + } + } + None +} + fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String { node.utf8_text(source).unwrap_or_default().to_string() } @@ -176,6 +229,7 @@ fn extract_definitions(language: &str, source: &str) -> Result, let mut class_def_map: HashMap> = HashMap::new(); let mut enum_def_map: HashMap> = HashMap::new(); + let mut union_def_map: HashMap> = HashMap::new(); let ensure_class_def = |name: &str, class_def_map: &mut HashMap>| { class_def_map.entry(name.to_string()).or_insert_with(|| { @@ -197,14 +251,26 @@ fn extract_definitions(language: &str, source: &str) -> Result, }); }; + let ensure_union_def = |name: &str, union_def_map: &mut HashMap>| { + union_def_map.entry(name.to_string()).or_insert_with(|| { + RefCell::new(Union { + name: name.to_string(), + items: vec![], + }) + }); + }; + for (m, _) in captures { for capture in m.captures { let capture_name = &query.capture_names()[capture.index as usize]; let node = capture.node; + let node_text = node.utf8_text(source.as_bytes()).unwrap(); + let name_node = node.child_by_field_name("name"); let name = name_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) - .unwrap_or(""); + .unwrap_or(node_text); + match *capture_name { "class" => { if !name.is_empty() { @@ -235,7 +301,17 @@ fn extract_definitions(language: &str, source: &str) -> Result, if language == "rust" && !visibility_modifier.contains("pub") { continue; } - let enum_name = get_closest_ancestor_name(&node, source); + if language == "zig" + && !zig_is_variable_declaration_public(&node, source.as_bytes()) + { + continue; + } + let mut enum_name = get_closest_ancestor_name(&node, source); + if language == "zig" { + enum_name = + zig_find_parent_variable_declaration_name(&node, source.as_bytes()) + .unwrap_or_default(); + } if !enum_name.is_empty() && language == "go" && !is_first_letter_uppercase(&enum_name) @@ -254,6 +330,28 @@ fn extract_definitions(language: &str, source: &str) -> Result, }; enum_def.borrow_mut().items.push(variable); } + "union_item" => { + if language != "zig" { + continue; + } + if !zig_is_variable_declaration_public(&node, source.as_bytes()) { + continue; + } + let union_name = + zig_find_parent_variable_declaration_name(&node, source.as_bytes()) + .unwrap_or_default(); + ensure_union_def(&union_name, &mut union_def_map); + let union_def = union_def_map.get_mut(&union_name).unwrap(); + let union_type_node = find_descendant_by_type(&node, "type_identifier"); + let union_type = union_type_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + let variable = Variable { + name: name.to_string(), + value_type: union_type.to_string(), + }; + union_def.borrow_mut().items.push(variable); + } "method" => { let visibility_modifier_node = find_descendant_by_type(&node, "visibility_modifier"); @@ -263,10 +361,25 @@ fn extract_definitions(language: &str, source: &str) -> Result, if language == "rust" && !visibility_modifier.contains("pub") { continue; } + if language == "zig" + && !(zig_is_function_declaration_public(&node, source.as_bytes()) + && zig_is_variable_declaration_public(&node, source.as_bytes())) + { + continue; + } + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { continue; } - let params_node = node.child_by_field_name("parameters"); + let mut params_node = node.child_by_field_name("parameters"); + + let function_node = find_ancestor_by_type(&node, "function_declaration"); + if language == "zig" { + params_node = function_node + .as_ref() + .and_then(|n| find_child_by_type(n, "parameters")); + } + let params = params_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or("()"); @@ -288,7 +401,10 @@ fn extract_definitions(language: &str, source: &str) -> Result, let impl_item_node = find_ancestor_by_type(&node, "impl_item"); let receiver_node = node.child_by_field_name("receiver"); - let class_name = if let Some(impl_item) = impl_item_node { + let class_name = if language == "zig" { + zig_find_parent_variable_declaration_name(&node, source.as_bytes()) + .unwrap_or_default() + } else if let Some(impl_item) = impl_item_node { let impl_type_node = impl_item.child_by_field_name("type"); impl_type_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) @@ -371,8 +487,22 @@ fn extract_definitions(language: &str, source: &str) -> Result, if language == "rust" && !visibility_modifier.contains("pub") { continue; } + let value_type = get_node_type(&node, source.as_bytes()); - let class_name = get_closest_ancestor_name(&node, source); + + if language == "zig" { + // when top level class is not public, skip + if !zig_is_variable_declaration_public(&node, source.as_bytes()) { + continue; + } + } + + let mut class_name = get_closest_ancestor_name(&node, source); + if language == "zig" { + class_name = + zig_find_parent_variable_declaration_name(&node, source.as_bytes()) + .unwrap_or_default(); + } if !class_name.is_empty() && language == "go" && !is_first_letter_uppercase(&class_name) @@ -399,9 +529,19 @@ fn extract_definitions(language: &str, source: &str) -> Result, let visibility_modifier = visibility_modifier_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { continue; } + + if language == "zig" { + let variable_declaration_text = + node.utf8_text(source.as_bytes()).unwrap_or(""); + if !variable_declaration_text.contains("pub") { + continue; + } + } + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { continue; } @@ -493,9 +633,17 @@ fn extract_definitions(language: &str, source: &str) -> Result, let visibility_modifier = visibility_modifier_node .map(|n| n.utf8_text(source.as_bytes()).unwrap()) .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { continue; } + + if language == "zig" + && !zig_is_variable_declaration_public(&node, source.as_bytes()) + { + continue; + } + let impl_item_node = find_ancestor_by_type(&node, "impl_item") .or_else(|| find_ancestor_by_type(&node, "class_declaration")) .or_else(|| find_ancestor_by_type(&node, "class_definition")); @@ -532,7 +680,15 @@ fn extract_definitions(language: &str, source: &str) -> Result, continue; } } - let value_type = get_node_type(&node, source.as_bytes()); + + let mut value_type = get_node_type(&node, source.as_bytes()); + if language == "zig" { + if let Some(zig_type) = zig_find_type_in_parent(&node, source.as_bytes()) { + value_type = zig_type; + } else { + continue; + }; + } if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { continue; } @@ -563,6 +719,9 @@ fn extract_definitions(language: &str, source: &str) -> Result, for (_, def) in enum_def_map { definitions.push(Definition::Enum(def.into_inner())); } + for (_, def) in union_def_map { + definitions.push(Definition::Union(def.into_inner())); + } Ok(definitions) } @@ -599,6 +758,14 @@ fn stringify_enum_item(item: &Variable) -> String { format!("{res};") } +fn stringify_union_item(item: &Variable) -> String { + let mut res = item.name.clone(); + if !item.value_type.is_empty() { + res = format!("{res}:{}", item.value_type); + } + format!("{res};") +} + fn stringify_class(class: &Class) -> String { let mut res = format!("class {}{{", class.name); for method in &class.methods { @@ -620,6 +787,14 @@ fn stringify_enum(enum_def: &Enum) -> String { } format!("{res}}};") } +fn stringify_union(union_def: &Union) -> String { + let mut res = format!("union {}{{", union_def.name); + for item in &union_def.items { + let item_str = stringify_union_item(item); + res = format!("{res}{item_str}"); + } + format!("{res}}};") +} fn stringify_definitions(definitions: &Vec) -> String { let mut res = String::new(); @@ -627,6 +802,7 @@ fn stringify_definitions(definitions: &Vec) -> String { match definition { Definition::Class(class) => res = format!("{res}{}", stringify_class(class)), Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)), + Definition::Union(union_def) => res = format!("{res}{}", stringify_union(union_def)), Definition::Func(func) => res = format!("{res}{}", stringify_function(func)), Definition::Variable(variable) => { let variable_str = stringify_variable(variable); @@ -660,7 +836,6 @@ fn avante_repo_map(lua: &Lua) -> LuaResult { mod tests { use super::*; - #[test] fn test_rust() { let source = r#" // This is a test comment @@ -718,6 +893,77 @@ mod tests { assert_eq!(stringified, expected); } + #[test] + fn test_zig() { + let source = r#" + // This is a test comment + pub const TEST_CONST: u32 = 1; + pub var TEST_VAR: u32 = 2; + const INNER_TEST_CONST: u32 = 3; + var INNER_TEST_VAR: u32 = 4; + pub const TestStruct = struct { + test_field: []const u8, + test_field2: u64, + + pub fn test_method(_: *TestStruct, a: u32, b: u32) u32 { + return a + b; + } + + fn inner_test_method(_: *TestStruct, a: u32, b: u32) u32 { + return a + b; + } + }; + const InnerTestStruct = struct { + test_field: []const u8, + test_field2: u64, + + pub fn test_method(_: *InnerTestStruct, a: u32, b: u32) u32 { + return a + b; + } + + fn inner_test_method(_: *InnerTestStruct, a: u32, b: u32) u32 { + return a + b; + } + }; + pub const TestEnum = enum { + TestEnumField1, + TestEnumField2, + }; + const InnerTestEnum = enum { + InnerTestEnumField1, + InnerTestEnumField2, + }; + + pub const TestUnion = union { + TestUnionField1: u32, + TestUnionField2: u64, + }; + + const InnerTestUnion = union { + InnerTestUnionField1: u32, + InnerTestUnionField2: u64, + }; + + pub fn test_fn(a: u32, b: u32) u32 { + const inner_var_in_func = 1; + const InnerStructInFunc = struct { + c: u32, + }; + _ = InnerStructInFunc; + return a + b + inner_var_in_func; + } + fn inner_test_fn(a: u32, b: u32) u32 { + return a + b; + } + "#; + + let definitions = extract_definitions("zig", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var TEST_CONST:u32;var TEST_VAR:u32;func test_fn() -> void;class TestStruct{func test_method(_: *TestStruct, a: u32, b: u32) -> void;var test_field:[]const u8;var test_field2:u64;};enum TestEnum{TestEnumField1;TestEnumField2;};union TestUnion{TestUnionField1;TestUnionField2;};"; + assert_eq!(stringified, expected); + } + #[test] fn test_go() { let source = r#"