diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7a10b1425..8837a46ab 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,6 +29,18 @@ jobs: uses: lunarmodules/luacheck@v1 with: args: ./lua/ + rust-tests: + name: Run Rust tests + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: Swatinem/rust-cache@v2 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: clippy, rustfmt + - name: Run rust tests + run: cargo test --features luajit rust: name: Check Rust style runs-on: ubuntu-latest diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 9dfbe03f4..2743836a6 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -74,6 +74,7 @@ jobs: fi cp target/release/libavante_templates.$EXT results/avante_templates.$EXT cp target/release/libavante_tokenizers.$EXT results/avante_tokenizers.$EXT + cp target/release/libavante_repo_map.$EXT results/avante_repo_map.$EXT cd results tar zcvf avante_lib-${{ matrix.os }}-${{ matrix.feature }}.tar.gz *.${EXT} @@ -85,6 +86,7 @@ jobs: Copy-Item -Path "target\release\avante_templates.dll" -Destination "results\avante_templates.dll" Copy-Item -Path "target\release\avante_tokenizers.dll" -Destination "results\avante_tokenizers.dll" + Copy-Item -Path "target\release\avante_repo_map.dll" -Destination "results\avante_repo_map.dll" Set-Location -Path results diff --git a/Cargo.lock b/Cargo.lock index 78470b2a3..9c2922a72 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -29,6 +29,27 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "avante-repo-map" +version = "0.1.0" +dependencies = [ + "cc", + "minijinja", + "mlua", + "serde", + "tree-sitter", + "tree-sitter-c", + "tree-sitter-cpp", + "tree-sitter-go", + "tree-sitter-javascript", + "tree-sitter-language", + "tree-sitter-lua", + "tree-sitter-python", + "tree-sitter-ruby", + "tree-sitter-rust", + "tree-sitter-typescript", +] + [[package]] name = "avante-templates" version = "0.1.0" @@ -117,9 +138,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.1.15" +version = "1.1.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57b6a275aa2903740dc87da01c62040406b8812552e97129a63ea8850a17c6e6" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" dependencies = [ "shlex", ] @@ -1275,6 +1296,114 @@ dependencies = [ "unicode_categories", ] +[[package]] +name = "tree-sitter" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20f4cd3642c47a85052a887d86704f4eac272969f61b686bdd3f772122aabaff" +dependencies = [ + "cc", + "regex", + "regex-syntax", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-c" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8b3fb515e498e258799a31d78e6603767cd6892770d9e2290ec00af5c3ad80b" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-cpp" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d67e862242878d6ee50e1e5814f267ee3eea0168aea2cdbd700ccfb4c74b6d3" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-go" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "caf57626e4c9b6d6efaf8a8d5ee1241c5f178ae7bfdf693713ae6a774f01424e" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-javascript" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59e1f62f8babb640b909f30675d1addeb1f17802f2a4d2af287569753b243977" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-language" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2545046bd1473dac6c626659cc2567c6c0ff302fc8b84a56c4243378276f7f57" + +[[package]] +name = "tree-sitter-lua" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cdb9adf0965fec58e7660cbb3a059dbb12ebeec9459e6dcbae3db004739641e" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-python" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65661b1a3e24139e2e54207e47d910ab07e28790d78efc7d5dc3a11ce2a110eb" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-ruby" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ec5ee842e27791e0adffa0b2a177614de51d2a26e5c7e84d014ed7f097e5ed0" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-rust" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffbbcb780348fbae8395742ae5b34c1fd794e4085d43aac9f259387f9a84dc8" +dependencies = [ + "cc", + "tree-sitter-language", +] + +[[package]] +name = "tree-sitter-typescript" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aecf1585ae2a9dddc2b1d4c0e2140b2ec9876e2a25fd79de47fcf7dae0384685" +dependencies = [ + "cc", + "tree-sitter-language", +] + [[package]] name = "typeid" version = "1.0.2" diff --git a/Cargo.toml b/Cargo.toml index de6aead55..ced96b30e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ version = "0.1.0" [workspace.dependencies] avante-tokenizers = { path = "crates/avante-tokenizers" } avante-templates = { path = "crates/avante-templates" } +avante-repo-map = { path = "crates/avante-repo-map" } minijinja = { version = "2.2.0", features = [ "loader", "json", diff --git a/Makefile b/Makefile index b231f5d05..b80436482 100644 --- a/Makefile +++ b/Makefile @@ -22,13 +22,15 @@ all: luajit define make_definitions ifeq ($(BUILD_FROM_SOURCE),true) ifeq ($(TARGET_LIBRARY), all) -$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) +$1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) $(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT) else ifeq ($(TARGET_LIBRARY), tokenizers) $1: $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT) else ifeq ($(TARGET_LIBRARY), templates) $1: $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT) +else ifeq ($(TARGET_LIBRARY), repo-map) +$1: $(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT) else - $$(error TARGET_LIBRARY must be one of all, tokenizers, templates) + $$(error TARGET_LIBRARY must be one of all, tokenizers, templates, repo-map) endif else $1: @@ -41,16 +43,18 @@ $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call make_definitions,$(lua_versi define build_package $1-$2: cargo build --release --features=$1 -p avante-$2 - cp target/release/libavante_$2.$(EXT) $(BUILD_DIR)/avante_$2.$(EXT) + cp target/release/libavante_$(shell echo $2 | tr - _).$(EXT) $(BUILD_DIR)/avante_$(shell echo $2 | tr - _).$(EXT) endef define build_targets $(BUILD_DIR)/libAvanteTokenizers-$1.$(EXT): $(BUILD_DIR) $1-tokenizers $(BUILD_DIR)/libAvanteTemplates-$1.$(EXT): $(BUILD_DIR) $1-templates +$(BUILD_DIR)/libAvanteRepoMap-$1.$(EXT): $(BUILD_DIR) $1-repo-map endef $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),tokenizers))) $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),templates))) +$(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_package,$(lua_version),repo-map))) $(foreach lua_version,$(LUA_VERSIONS),$(eval $(call build_targets,$(lua_version)))) $(BUILD_DIR): diff --git a/crates/avante-repo-map/Cargo.toml b/crates/avante-repo-map/Cargo.toml new file mode 100644 index 000000000..bcf40d7c3 --- /dev/null +++ b/crates/avante-repo-map/Cargo.toml @@ -0,0 +1,38 @@ +[lib] +crate-type = ["cdylib"] + +[package] +name = "avante-repo-map" +edition.workspace = true +rust-version.workspace = true +license.workspace = true +version.workspace = true + +[build-dependencies] +cc="*" + +[dependencies] +mlua = { workspace = true } +minijinja = { workspace = true } +serde = { workspace = true, features = ["derive"] } +tree-sitter = "0.23" +tree-sitter-language = "0.1" +tree-sitter-rust = "0.23" +tree-sitter-python = "0.23" +tree-sitter-javascript = "0.23" +tree-sitter-typescript = "0.23" +tree-sitter-go = "0.23" +tree-sitter-c = "0.23" +tree-sitter-cpp = "0.23" +tree-sitter-lua = "0.2" +tree-sitter-ruby = "0.23" + +[lints] +workspace = true + +[features] +lua51 = ["mlua/lua51"] +lua52 = ["mlua/lua52"] +lua53 = ["mlua/lua53"] +lua54 = ["mlua/lua54"] +luajit = ["mlua/luajit"] diff --git a/crates/avante-repo-map/queries/tree-sitter-c-defs.scm b/crates/avante-repo-map/queries/tree-sitter-c-defs.scm new file mode 100644 index 000000000..0e1a485aa --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-c-defs.scm @@ -0,0 +1,11 @@ +;; Capture extern functions, variables, public classes, and methods +(function_definition + (storage_class_specifier) @extern +) @function +(class_specifier + (public) @class + (function_definition) @method +) @class +(declaration + (storage_class_specifier) @extern +) @variable diff --git a/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm b/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm new file mode 100644 index 000000000..0e1a485aa --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-cpp-defs.scm @@ -0,0 +1,11 @@ +;; Capture extern functions, variables, public classes, and methods +(function_definition + (storage_class_specifier) @extern +) @function +(class_specifier + (public) @class + (function_definition) @method +) @class +(declaration + (storage_class_specifier) @extern +) @variable diff --git a/crates/avante-repo-map/queries/tree-sitter-go-defs.scm b/crates/avante-repo-map/queries/tree-sitter-go-defs.scm new file mode 100644 index 000000000..89f5a8b14 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-go-defs.scm @@ -0,0 +1,18 @@ +;; Capture top-level functions and struct definitions +(var_declaration + (var_spec) @variable +) +(const_declaration + (const_spec) @variable +) +(function_declaration) @function +(type_declaration + (type_spec (struct_type)) @class +) +(type_declaration + (type_spec + (struct_type + (field_declaration_list + (field_declaration) @class_variable))) +) +(method_declaration) @method diff --git a/crates/avante-repo-map/queries/tree-sitter-javascript-defs.scm b/crates/avante-repo-map/queries/tree-sitter-javascript-defs.scm new file mode 100644 index 000000000..5698d96cf --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-javascript-defs.scm @@ -0,0 +1,23 @@ +;; Capture exported functions, arrow functions, variables, classes, and method definitions +(export_statement + declaration: (lexical_declaration + (variable_declarator) @variable + ) +) +(export_statement + declaration: (function_declaration) @function +) +(export_statement + declaration: (class_declaration + body: (class_body + (field_definition) @class_variable + ) + ) +) +(export_statement + declaration: (class_declaration + body: (class_body + (method_definition) @method + ) + ) +) diff --git a/crates/avante-repo-map/queries/tree-sitter-lua-defs.scm b/crates/avante-repo-map/queries/tree-sitter-lua-defs.scm new file mode 100644 index 000000000..73e27ca59 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-lua-defs.scm @@ -0,0 +1,3 @@ +;; Capture function and method definitions +(variable_list) @variable +(function_declaration) @function diff --git a/crates/avante-repo-map/queries/tree-sitter-python-defs.scm b/crates/avante-repo-map/queries/tree-sitter-python-defs.scm new file mode 100644 index 000000000..40475abc0 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-python-defs.scm @@ -0,0 +1,25 @@ +;; Capture top-level functions, class, and method definitions +(module + (expression_statement + (assignment) @assignment + ) +) +(module + (function_definition) @function +) +(module + (class_definition + body: (block + (expression_statement + (assignment) @class_assignment + ) + ) + ) +) +(module + (class_definition + body: (block + (function_definition) @method + ) + ) +) diff --git a/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm b/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm new file mode 100644 index 000000000..8f0117ae8 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-ruby-defs.scm @@ -0,0 +1,16 @@ +;; Capture top-level methods, class definitions, and methods within classes +(program + (class + (body_statement + (call) @class_call + (assignment) @class_assignment + (method) @method + ) + ) @class +) +(program + (method) @function +) +(program + (assignment) @assignment +) diff --git a/crates/avante-repo-map/queries/tree-sitter-rust-defs.scm b/crates/avante-repo-map/queries/tree-sitter-rust-defs.scm new file mode 100644 index 000000000..ecf2f92e9 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-rust-defs.scm @@ -0,0 +1,20 @@ +;; Capture public functions, structs, methods, and variable definitions +(function_item) @function +(impl_item + body: (declaration_list + (function_item) @method + ) +) +(struct_item) @class +(struct_item + body: (field_declaration_list + (field_declaration) @class_variable + ) +) +(enum_item + body: (enum_variant_list + (enum_variant) @enum_item + ) +) +(const_item) @variable +(static_item) @variable diff --git a/crates/avante-repo-map/queries/tree-sitter-typescript-defs.scm b/crates/avante-repo-map/queries/tree-sitter-typescript-defs.scm new file mode 100644 index 000000000..dee2e9520 --- /dev/null +++ b/crates/avante-repo-map/queries/tree-sitter-typescript-defs.scm @@ -0,0 +1,33 @@ +;; Capture exported functions, arrow functions, variables, classes, and method definitions +(export_statement + declaration: (lexical_declaration + (variable_declarator) @variable + ) +) +(export_statement + declaration: (function_declaration) @function +) +(export_statement + declaration: (class_declaration + body: (class_body + (public_field_definition) @class_variable + ) + ) +) +(interface_declaration + body: (interface_body + (property_signature) @class_variable + ) +) +(type_alias_declaration + value: (object_type + (property_signature) @class_variable + ) +) +(export_statement + declaration: (class_declaration + body: (class_body + (method_definition) @method + ) + ) +) diff --git a/crates/avante-repo-map/src/lib.rs b/crates/avante-repo-map/src/lib.rs new file mode 100644 index 000000000..86dd47495 --- /dev/null +++ b/crates/avante-repo-map/src/lib.rs @@ -0,0 +1,894 @@ +use mlua::prelude::*; +use std::cell::RefCell; +use std::collections::HashMap; +use tree_sitter::{Node, Parser, Query, QueryCursor}; +use tree_sitter_language::LanguageFn; + +#[derive(Debug, Clone)] +pub struct Func { + pub name: String, + pub params: String, + pub return_type: String, + pub accessibility_modifier: Option, +} + +#[derive(Debug, Clone)] +pub struct Class { + pub name: String, + pub methods: Vec, + pub properties: Vec, + pub visibility_modifier: Option, +} + +#[derive(Debug, Clone)] +pub struct Enum { + pub name: String, + pub items: Vec, +} + +#[derive(Debug, Clone)] +pub struct Variable { + pub name: String, + pub value_type: String, +} + +#[derive(Debug, Clone)] +pub enum Definition { + Func(Func), + Class(Class), + Enum(Enum), + Variable(Variable), +} + +fn get_ts_language(language: &str) -> Result { + match language { + "rust" => Ok(tree_sitter_rust::LANGUAGE), + "python" => Ok(tree_sitter_python::LANGUAGE), + "javascript" => Ok(tree_sitter_javascript::LANGUAGE), + "typescript" => Ok(tree_sitter_typescript::LANGUAGE_TSX), + "go" => Ok(tree_sitter_go::LANGUAGE), + "c" => Ok(tree_sitter_c::LANGUAGE), + "cpp" => Ok(tree_sitter_cpp::LANGUAGE), + "lua" => Ok(tree_sitter_lua::LANGUAGE), + "ruby" => Ok(tree_sitter_ruby::LANGUAGE), + _ => Err(format!("Unsupported language: {language}")), + } +} + +const C_QUERY: &str = include_str!("../queries/tree-sitter-c-defs.scm"); +const CPP_QUERY: &str = include_str!("../queries/tree-sitter-cpp-defs.scm"); +const GO_QUERY: &str = include_str!("../queries/tree-sitter-go-defs.scm"); +const JAVASCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-javascript-defs.scm"); +const LUA_QUERY: &str = include_str!("../queries/tree-sitter-lua-defs.scm"); +const PYTHON_QUERY: &str = include_str!("../queries/tree-sitter-python-defs.scm"); +const RUST_QUERY: &str = include_str!("../queries/tree-sitter-rust-defs.scm"); +const TYPESCRIPT_QUERY: &str = include_str!("../queries/tree-sitter-typescript-defs.scm"); +const RUBY_QUERY: &str = include_str!("../queries/tree-sitter-ruby-defs.scm"); + +fn get_definitions_query(language: &str) -> Result { + let ts_language = get_ts_language(language)?; + let contents = match language { + "c" => C_QUERY, + "cpp" => CPP_QUERY, + "go" => GO_QUERY, + "javascript" => JAVASCRIPT_QUERY, + "lua" => LUA_QUERY, + "python" => PYTHON_QUERY, + "rust" => RUST_QUERY, + "typescript" => TYPESCRIPT_QUERY, + "ruby" => RUBY_QUERY, + _ => return Err(format!("Unsupported language: {language}")), + }; + let query = Query::new(&ts_language.into(), contents) + .unwrap_or_else(|_| panic!("Failed to parse query for {language}")); + Ok(query) +} + +fn get_closest_ancestor_name(node: &Node, source: &str) -> String { + let mut parent = node.parent(); + while let Some(parent_node) = parent { + let name_node = parent_node.child_by_field_name("name"); + if let Some(name_node) = name_node { + return get_node_text(&name_node, source.as_bytes()).to_string(); + } + parent = parent_node.parent(); + } + String::new() +} + +fn find_ancestor_by_type<'a>(node: &'a Node, parent_type: &str) -> Option> { + let mut parent = node.parent(); + while let Some(parent_node) = parent { + if parent_node.kind() == parent_type { + return Some(parent_node); + } + parent = parent_node.parent(); + } + None +} + +fn find_descendant_by_type<'a>(node: &'a Node, child_type: &str) -> Option> { + let mut cursor = node.walk(); + for i in 0..node.descendant_count() { + cursor.goto_descendant(i); + let node = cursor.node(); + if node.kind() == child_type { + return Some(node); + } + } + None +} + +fn find_child_by_type<'a>(node: &'a Node, child_type: &str) -> Option> { + node.children(&mut node.walk()) + .find(|child| child.kind() == child_type) +} + +fn get_node_text<'a>(node: &'a Node, source: &'a [u8]) -> String { + node.utf8_text(source).unwrap_or_default().to_string() +} + +fn get_node_type<'a>(node: &'a Node, source: &'a [u8]) -> String { + let predefined_type_node = find_descendant_by_type(node, "predefined_type"); + if let Some(type_node) = predefined_type_node { + return type_node.utf8_text(source).unwrap().to_string(); + } + let value_type_node = node.child_by_field_name("type"); + value_type_node + .map(|n| n.utf8_text(source).unwrap().to_string()) + .unwrap_or_default() +} + +fn is_first_letter_uppercase(name: &str) -> bool { + if name.is_empty() { + return false; + } + name.chars().next().unwrap().is_uppercase() +} + +// Given a language, parse the given source code and return exported definitions +fn extract_definitions(language: &str, source: &str) -> Result, String> { + let ts_language = get_ts_language(language)?; + + let mut definitions = Vec::new(); + let mut parser = Parser::new(); + parser + .set_language(&ts_language.into()) + .unwrap_or_else(|_| panic!("Failed to set language for {language}")); + let tree = parser + .parse(source, None) + .unwrap_or_else(|| panic!("Failed to parse source code for {language}")); + let root_node = tree.root_node(); + + let query = get_definitions_query(language)?; + let mut query_cursor = QueryCursor::new(); + let captures = query_cursor.captures(&query, root_node, source.as_bytes()); + + let mut class_def_map: HashMap> = HashMap::new(); + let mut enum_def_map: HashMap> = HashMap::new(); + + let ensure_class_def = |name: &str, class_def_map: &mut HashMap>| { + class_def_map.entry(name.to_string()).or_insert_with(|| { + RefCell::new(Class { + name: name.to_string(), + methods: vec![], + properties: vec![], + visibility_modifier: None, + }) + }); + }; + + let ensure_enum_def = |name: &str, enum_def_map: &mut HashMap>| { + enum_def_map.entry(name.to_string()).or_insert_with(|| { + RefCell::new(Enum { + name: name.to_string(), + items: vec![], + }) + }); + }; + + for (m, _) in captures { + for capture in m.captures { + let capture_name = &query.capture_names()[capture.index as usize]; + let node = capture.node; + let name_node = node.child_by_field_name("name"); + let name = name_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + match *capture_name { + "class" => { + if !name.is_empty() { + if language == "go" && !is_first_letter_uppercase(name) { + continue; + } + ensure_class_def(name, &mut class_def_map); + let visibility_modifier_node = + find_child_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + let class_def = class_def_map.get_mut(name).unwrap(); + class_def.borrow_mut().visibility_modifier = + if visibility_modifier.is_empty() { + None + } else { + Some(visibility_modifier.to_string()) + }; + } + } + "enum_item" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + let enum_name = get_closest_ancestor_name(&node, source); + if !enum_name.is_empty() + && language == "go" + && !is_first_letter_uppercase(&enum_name) + { + continue; + } + ensure_enum_def(&enum_name, &mut enum_def_map); + let enum_def = enum_def_map.get_mut(&enum_name).unwrap(); + let enum_type_node = find_descendant_by_type(&node, "type_identifier"); + let enum_type = enum_type_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + let variable = Variable { + name: name.to_string(), + value_type: enum_type.to_string(), + }; + enum_def.borrow_mut().items.push(variable); + } + "method" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + continue; + } + let params_node = node.child_by_field_name("parameters"); + let params = params_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or("()"); + let mut return_type_node = node.child_by_field_name("return_type"); + if return_type_node.is_none() { + return_type_node = node.child_by_field_name("result"); + } + let mut return_type = "void".to_string(); + if return_type_node.is_some() { + return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes()); + if return_type.is_empty() { + return_type = return_type_node + .unwrap() + .utf8_text(source.as_bytes()) + .unwrap_or("void") + .to_string(); + } + } + + let impl_item_node = find_ancestor_by_type(&node, "impl_item"); + let receiver_node = node.child_by_field_name("receiver"); + let class_name = if let Some(impl_item) = impl_item_node { + let impl_type_node = impl_item.child_by_field_name("type"); + impl_type_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or("") + .to_string() + } else if let Some(receiver) = receiver_node { + let type_identifier_node = + find_descendant_by_type(&receiver, "type_identifier"); + type_identifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or("") + .to_string() + } else { + get_closest_ancestor_name(&node, source).to_string() + }; + + if language == "go" && !is_first_letter_uppercase(&class_name) { + continue; + } + + ensure_class_def(&class_name, &mut class_def_map); + let class_def = class_def_map.get_mut(&class_name).unwrap(); + + let accessibility_modifier_node = + find_descendant_by_type(&node, "accessibility_modifier"); + let accessibility_modifier = accessibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + + let func = Func { + name: name.to_string(), + params: params.to_string(), + return_type: return_type.to_string(), + accessibility_modifier: if accessibility_modifier.is_empty() { + None + } else { + Some(accessibility_modifier.to_string()) + }, + }; + class_def.borrow_mut().methods.push(func); + } + "class_assignment" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + let left_node = node.child_by_field_name("left"); + let left = left_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + let value_type = get_node_type(&node, source.as_bytes()); + let class_name = get_closest_ancestor_name(&node, source); + if !class_name.is_empty() + && language == "go" + && !is_first_letter_uppercase(&class_name) + { + continue; + } + if class_name.is_empty() { + continue; + } + ensure_class_def(&class_name, &mut class_def_map); + let class_def = class_def_map.get_mut(&class_name).unwrap(); + let variable = Variable { + name: left.to_string(), + value_type: value_type.to_string(), + }; + class_def.borrow_mut().properties.push(variable); + } + "class_variable" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + let value_type = get_node_type(&node, source.as_bytes()); + let class_name = get_closest_ancestor_name(&node, source); + if !class_name.is_empty() + && language == "go" + && !is_first_letter_uppercase(&class_name) + { + continue; + } + if class_name.is_empty() { + continue; + } + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + continue; + } + ensure_class_def(&class_name, &mut class_def_map); + let class_def = class_def_map.get_mut(&class_name).unwrap(); + let variable = Variable { + name: name.to_string(), + value_type: value_type.to_string(), + }; + class_def.borrow_mut().properties.push(variable); + } + "function" | "arrow_function" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + continue; + } + let impl_item_node = find_ancestor_by_type(&node, "impl_item"); + if impl_item_node.is_some() { + continue; + } + let function_node = find_ancestor_by_type(&node, "function_declaration") + .or_else(|| find_ancestor_by_type(&node, "function_definition")); + if function_node.is_some() { + continue; + } + let params_node = node.child_by_field_name("parameters"); + let params = params_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or("()"); + let mut return_type_node = node.child_by_field_name("return_type"); + if return_type_node.is_none() { + return_type_node = node.child_by_field_name("result"); + } + let mut return_type = "void".to_string(); + if return_type_node.is_some() { + return_type = get_node_type(&return_type_node.unwrap(), source.as_bytes()); + if return_type.is_empty() { + return_type = return_type_node + .unwrap() + .utf8_text(source.as_bytes()) + .unwrap_or("void") + .to_string(); + } + } + + let accessibility_modifier_node = + find_descendant_by_type(&node, "accessibility_modifier"); + let accessibility_modifier = accessibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + + let func = Func { + name: name.to_string(), + params: params.to_string(), + return_type: return_type.to_string(), + accessibility_modifier: if accessibility_modifier.is_empty() { + None + } else { + Some(accessibility_modifier.to_string()) + }, + }; + definitions.push(Definition::Func(func)); + } + "assignment" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + let impl_item_node = find_ancestor_by_type(&node, "impl_item") + .or_else(|| find_ancestor_by_type(&node, "class_declaration")) + .or_else(|| find_ancestor_by_type(&node, "class_definition")); + if impl_item_node.is_some() { + continue; + } + let function_node = find_ancestor_by_type(&node, "function_declaration") + .or_else(|| find_ancestor_by_type(&node, "function_definition")); + if function_node.is_some() { + continue; + } + let left_node = node.child_by_field_name("left"); + let left = left_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if !left.is_empty() && language == "go" && !is_first_letter_uppercase(left) { + continue; + } + + let value_type = get_node_type(&node, source.as_bytes()); + let variable = Variable { + name: left.to_string(), + value_type: value_type.to_string(), + }; + definitions.push(Definition::Variable(variable)); + } + "variable" => { + let visibility_modifier_node = + find_descendant_by_type(&node, "visibility_modifier"); + let visibility_modifier = visibility_modifier_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or(""); + if language == "rust" && !visibility_modifier.contains("pub") { + continue; + } + let impl_item_node = find_ancestor_by_type(&node, "impl_item") + .or_else(|| find_ancestor_by_type(&node, "class_declaration")) + .or_else(|| find_ancestor_by_type(&node, "class_definition")); + if impl_item_node.is_some() { + continue; + } + let function_node = find_ancestor_by_type(&node, "function_declaration") + .or_else(|| find_ancestor_by_type(&node, "function_definition")); + if function_node.is_some() { + continue; + } + let value_node = node.child_by_field_name("value"); + if value_node.is_some() { + let value_type = value_node.unwrap().kind(); + if value_type == "arrow_function" { + let params_node = value_node.unwrap().child_by_field_name("parameters"); + let params = params_node + .map(|n| n.utf8_text(source.as_bytes()).unwrap()) + .unwrap_or("()"); + let mut return_type = "void".to_string(); + let return_type_node = + value_node.unwrap().child_by_field_name("return_type"); + if return_type_node.is_some() { + return_type = + get_node_type(&return_type_node.unwrap(), source.as_bytes()); + } + let func = Func { + name: name.to_string(), + params: params.to_string(), + return_type, + accessibility_modifier: None, + }; + definitions.push(Definition::Func(func)); + continue; + } + } + let value_type = get_node_type(&node, source.as_bytes()); + if !name.is_empty() && language == "go" && !is_first_letter_uppercase(name) { + continue; + } + let variable = Variable { + name: name.to_string(), + value_type: value_type.to_string(), + }; + definitions.push(Definition::Variable(variable)); + } + _ => {} + } + } + } + + for (_, def) in class_def_map { + let class_def = def.into_inner(); + if language == "rust" { + if let Some(visibility_modifier) = &class_def.visibility_modifier { + if visibility_modifier.contains("pub") { + definitions.push(Definition::Class(class_def)); + } + } + } else { + definitions.push(Definition::Class(class_def)); + } + } + + for (_, def) in enum_def_map { + definitions.push(Definition::Enum(def.into_inner())); + } + + Ok(definitions) +} + +fn stringify_function(func: &Func) -> String { + let mut res = format!("func {}", func.name); + if func.params.is_empty() { + res = format!("{res}()"); + } else { + res = format!("{res}{}", func.params); + } + if !func.return_type.is_empty() { + res = format!("{res} -> {}", func.return_type); + } + if let Some(modifier) = &func.accessibility_modifier { + res = format!("{modifier} {res}"); + } + format!("{res};") +} + +fn stringify_variable(variable: &Variable) -> String { + let mut res = format!("var {}", variable.name); + if !variable.value_type.is_empty() { + res = format!("{res}:{}", variable.value_type); + } + format!("{res};") +} + +fn stringify_enum_item(item: &Variable) -> String { + let mut res = item.name.clone(); + if !item.value_type.is_empty() { + res = format!("{res}:{}", item.value_type); + } + format!("{res};") +} + +fn stringify_class(class: &Class) -> String { + let mut res = format!("class {}{{", class.name); + for method in &class.methods { + let method_str = stringify_function(method); + res = format!("{res}{method_str}"); + } + for property in &class.properties { + let property_str = stringify_variable(property); + res = format!("{res}{property_str}"); + } + format!("{res}}};") +} + +fn stringify_enum(enum_def: &Enum) -> String { + let mut res = format!("enum {}{{", enum_def.name); + for item in &enum_def.items { + let item_str = stringify_enum_item(item); + res = format!("{res}{item_str}"); + } + format!("{res}}};") +} + +fn stringify_definitions(definitions: &Vec) -> String { + let mut res = String::new(); + for definition in definitions { + match definition { + Definition::Class(class) => res = format!("{res}{}", stringify_class(class)), + Definition::Enum(enum_def) => res = format!("{res}{}", stringify_enum(enum_def)), + Definition::Func(func) => res = format!("{res}{}", stringify_function(func)), + Definition::Variable(variable) => { + let variable_str = stringify_variable(variable); + res = format!("{res}{variable_str}"); + } + } + } + res +} + +pub fn get_definitions_string(language: &str, source: &str) -> LuaResult { + let definitions = + extract_definitions(language, source).map_err(|e| LuaError::RuntimeError(e.to_string()))?; + let stringified = stringify_definitions(&definitions); + Ok(stringified) +} + +#[mlua::lua_module] +fn avante_repo_map(lua: &Lua) -> LuaResult { + let exports = lua.create_table()?; + exports.set( + "stringify_definitions", + lua.create_function(move |_, (language, source): (String, String)| { + get_definitions_string(language.as_str(), source.as_str()) + })?, + )?; + Ok(exports) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rust() { + let source = r#" + // This is a test comment + pub const TEST_CONST: u32 = 1; + pub static TEST_STATIC: u32 = 2; + const INNER_TEST_CONST: u32 = 3; + static INNER_TEST_STATIC: u32 = 4; + pub(crate) struct TestStruct { + pub test_field: String, + inner_test_field: String, + } + impl TestStruct { + pub fn test_method(&self, a: u32, b: u32) -> u32 { + a + b + } + fn inner_test_method(&self, a: u32, b: u32) -> u32 { + a + b + } + } + struct InnerTestStruct { + pub test_field: String, + inner_test_field: String, + } + impl InnerTestStruct { + pub fn test_method(&self, a: u32, b: u32) -> u32 { + a + b + } + fn inner_test_method(&self, a: u32, b: u32) -> u32 { + a + b + } + } + pub enum TestEnum { + TestEnumField1, + TestEnumField2, + } + enum InnerTestEnum { + InnerTestEnumField1, + InnerTestEnumField2, + } + pub fn test_fn(a: u32, b: u32) -> u32 { + a + b + } + fn inner_test_fn(a: u32, b: u32) -> u32 { + a + b + } + "#; + let definitions = extract_definitions("rust", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var TEST_CONST:u32;var TEST_STATIC:u32;func test_fn(a: u32, b: u32) -> u32;class TestStruct{func test_method(&self, a: u32, b: u32) -> u32;var test_field:String;};"; + assert_eq!(stringified, expected); + } + + #[test] + fn test_go() { + let source = r#" + // This is a test comment + package main + import "fmt" + const TestConst string = "test" + const innerTestConst string = "test" + var TestVar string + var innerTestVar string + type TestStruct struct { + TestField string + innerTestField string + } + func (t *TestStruct) TestMethod(a int, b int) (int, error) { + return a + b, nil + } + func (t *TestStruct) innerTestMethod(a int, b int) (int, error) { + return a + b, nil + } + type innerTestStruct struct { + innerTestField string + } + func (t *innerTestStruct) testMethod(a int, b int) (int, error) { + return a + b, nil + } + func (t *innerTestStruct) innerTestMethod(a int, b int) (int, error) { + return a + b, nil + } + func TestFunc(a int, b int) (int, error) { + return a + b, nil + } + func innerTestFunc(a int, b int) (int, error) { + return a + b, nil + } + "#; + let definitions = extract_definitions("go", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var TestConst:string;var TestVar:string;func TestFunc(a int, b int) -> (int, error);class TestStruct{func TestMethod(a int, b int) -> (int, error);var TestField:string;};"; + assert_eq!(stringified, expected); + } + + #[test] + fn test_python() { + let source = r#" + # This is a test comment + test_var: str = "test" + class TestClass: + def __init__(self, a, b): + self.a = a + self.b = b + def test_method(self, a: int, b: int) -> int: + return a + b + def test_func(a: int, b: int) -> int: + return a + b + "#; + let definitions = extract_definitions("python", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var test_var:str;func test_func(a: int, b: int) -> int;class TestClass{func __init__(self, a, b) -> void;func test_method(self, a: int, b: int) -> int;};"; + assert_eq!(stringified, expected); + } + + #[test] + fn test_typescript() { + let source = r#" + // This is a test comment + export const testVar: string = "test"; + const innerTestVar: string = "test"; + export class TestClass { + a: number; + b: number; + constructor(a: number, b: number) { + this.a = a; + this.b = b; + } + testMethod(a: number, b: number): number { + return a + b; + } + } + class InnerTestClass { + a: number; + b: number; + } + export function testFunc(a: number, b: number) { + return a + b; + } + export const testFunc2 = (a: number, b: number) => { + return a + b; + } + export const testFunc3 = (a: number, b: number): number => { + return a + b; + } + function innerTestFunc(a: number, b: number) { + return a + b; + } + "#; + let definitions = extract_definitions("typescript", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var testVar:string;func testFunc(a: number, b: number) -> void;func testFunc2(a: number, b: number) -> void;func testFunc3(a: number, b: number) -> number;class TestClass{func constructor(a: number, b: number) -> void;func testMethod(a: number, b: number) -> number;var a:number;var b:number;};" +; + assert_eq!(stringified, expected); + } + + #[test] + fn test_javascript() { + let source = r#" + // This is a test comment + export const testVar = "test"; + const innerTestVar = "test"; + export class TestClass { + constructor(a, b) { + this.a = a; + this.b = b; + } + testMethod(a, b) { + return a + b; + } + } + class InnerTestClass { + constructor(a, b) { + this.a = a; + this.b = b; + } + } + export const testFunc = function(a, b) { + return a + b; + } + export const testFunc2 = (a, b) => { + return a + b; + } + export const testFunc3 = (a, b) => a + b; + function innerTestFunc(a, b) { + return a + b; + } + "#; + let definitions = extract_definitions("javascript", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var testVar;var testFunc;func testFunc2(a, b) -> void;func testFunc3(a, b) -> void;class TestClass{func constructor(a, b) -> void;func testMethod(a, b) -> void;};"; + assert_eq!(stringified, expected); + } + + #[test] + fn test_ruby() { + let source = r#" + # This is a test comment + test_var = "test" + def test_func(a, b) + return a + b + end + class TestClass + attr_accessor :a, :b + def initialize(a, b) + @a = a + @b = b + end + def test_method(a, b) + return a + b + end + end + "#; + let definitions = extract_definitions("ruby", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + // FIXME: + let expected = "var test_var;func test_func(a, b) -> void;"; + assert_eq!(stringified, expected); + } + + #[test] + fn test_lua() { + let source = r#" + -- This is a test comment + local test_var = "test" + function test_func(a, b) + return a + b + end + "#; + let definitions = extract_definitions("lua", source).unwrap(); + let stringified = stringify_definitions(&definitions); + println!("{stringified}"); + let expected = "var test_var;func test_func(a, b) -> void;"; + assert_eq!(stringified, expected); + } +} diff --git a/lua/avante/init.lua b/lua/avante/init.lua index 08461cb9b..0af23eeda 100644 --- a/lua/avante/init.lua +++ b/lua/avante/init.lua @@ -397,6 +397,7 @@ function M.setup(opts) if M.did_setup then return end + require("avante.repo_map").setup() require("avante.path").setup() require("avante.highlights").setup() require("avante.diff").setup() diff --git a/lua/avante/llm.lua b/lua/avante/llm.lua index 434104d92..2aff73cf1 100644 --- a/lua/avante/llm.lua +++ b/lua/avante/llm.lua @@ -85,6 +85,7 @@ M.stream = function(opts) :totable() Utils.debug(user_prompts) + -- print(user_prompts[1]) ---@type AvantePromptOptions local code_opts = { diff --git a/lua/avante/repo_map.lua b/lua/avante/repo_map.lua new file mode 100644 index 000000000..f07dc731a --- /dev/null +++ b/lua/avante/repo_map.lua @@ -0,0 +1,147 @@ +local Utils = require("avante.utils") + +local filetype_map = { + ["javascriptreact"] = "javascript", + ["typescriptreact"] = "typescript", +} + +---@class AvanteRepoMap +---@field stringify_definitions fun(lang: string, source: string): string +local repo_map_lib = nil + +---@class avante.utils.repo_map +local RepoMap = {} + +function RepoMap.setup() + vim.defer_fn(function() + local ok, core = pcall(require, "avante_repo_map") + if not ok then + error("Failed to load avante_repo_map") + return + end + + if repo_map_lib == nil then repo_map_lib = core end + end, 1000) +end + +function RepoMap.get_ts_lang(filepath) + local filetype = vim.filetype.match({ filename = filepath }) + return filetype_map[filetype] or filetype +end + +function RepoMap.get_filetype(filepath) return vim.filetype.match({ filename = filepath }) end + +function RepoMap._build_repo_map(project_root, file_ext) + local output = {} + local gitignore_path = project_root .. "/.gitignore" + local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path) + local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns) + vim.iter(filepaths):each(function(filepath) + if not Utils.is_same_file_ext(file_ext, filepath) then return end + local definitions = + repo_map_lib.stringify_definitions(RepoMap.get_ts_lang(filepath), Utils.file.read_content(filepath) or "") + if definitions == "" then return end + table.insert(output, { + path = Utils.relative_path(filepath), + lang = RepoMap.get_filetype(filepath), + defs = definitions, + }) + end) + return output +end + +local cache = {} + +function RepoMap.get_repo_map(file_ext) + file_ext = file_ext or vim.fn.expand("%:e") + local project_root = Utils.root.get() + local cache_key = project_root .. "." .. file_ext + local cached = cache[cache_key] + if cached then return cached end + + local PPath = require("plenary.path") + local Path = require("avante.path") + local repo_map + + local function build_and_save() + repo_map = RepoMap._build_repo_map(project_root, file_ext) + cache[cache_key] = repo_map + Path.repo_map.save(project_root, file_ext, repo_map) + end + + repo_map = Path.repo_map.load(project_root, file_ext) + + if not repo_map or next(repo_map) == nil then + build_and_save() + if not repo_map then return end + else + local timer = vim.loop.new_timer() + + if timer then + timer:start( + 0, + 0, + vim.schedule_wrap(function() + build_and_save() + timer:close() + end) + ) + end + end + + local update_repo_map = vim.schedule_wrap(function(rel_filepath) + if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then + local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute() + local definitions = repo_map_lib.stringify_definitions( + RepoMap.get_ts_lang(abs_filepath), + Utils.file.read_content(abs_filepath) or "" + ) + if definitions == "" then return end + local found = false + for _, m in ipairs(repo_map) do + if m.path == rel_filepath then + m.defs = definitions + found = true + break + end + end + if not found then + table.insert(repo_map, { + path = Utils.relative_path(abs_filepath), + lang = RepoMap.get_filetype(abs_filepath), + defs = definitions, + }) + end + cache[cache_key] = repo_map + Path.repo_map.save(project_root, file_ext, repo_map) + end + end) + + local handle = vim.loop.new_fs_event() + + if handle then + handle:start(project_root, { recursive = true }, function(err, rel_filepath) + if err then + print("Error watching directory " .. project_root .. ":", err) + return + end + + if rel_filepath then update_repo_map(rel_filepath) end + end) + end + + vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, { + callback = function(ev) + vim.defer_fn(function() + local filepath = vim.api.nvim_buf_get_name(ev.buf) + if not vim.startswith(filepath, project_root) then return end + local rel_filepath = Utils.relative_path(filepath) + update_repo_map(rel_filepath) + end, 0) + end, + }) + + return repo_map +end + +return RepoMap diff --git a/lua/avante/selection.lua b/lua/avante/selection.lua index fd2cf169b..5fc274123 100644 --- a/lua/avante/selection.lua +++ b/lua/avante/selection.lua @@ -2,6 +2,7 @@ local Utils = require("avante.utils") local Config = require("avante.config") local Llm = require("avante.llm") local Provider = require("avante.providers") +local RepoMap = require("avante.repo_map") local api = vim.api local fn = vim.fn @@ -394,7 +395,7 @@ function Selection:create_editing_input() local mentions = Utils.extract_mentions(input) input = mentions.new_content - local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil + local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil Llm.stream({ bufnr = code_bufnr, diff --git a/lua/avante/sidebar.lua b/lua/avante/sidebar.lua index 04f0e377f..1524e6b02 100644 --- a/lua/avante/sidebar.lua +++ b/lua/avante/sidebar.lua @@ -11,6 +11,7 @@ local Diff = require("avante.diff") local Llm = require("avante.llm") local Utils = require("avante.utils") local Highlights = require("avante.highlights") +local RepoMap = require("avante.repo_map") local RESULT_BUF_NAME = "AVANTE_RESULT" local VIEW_BUFFER_UPDATED_PATTERN = "AvanteViewBufferUpdated" @@ -1295,7 +1296,7 @@ function Sidebar:create_input(opts) local file_ext = api.nvim_buf_get_name(self.code.bufnr):match("^.+%.(.+)$") - local project_context = mentions.enable_project_context and Utils.repo_map.get_repo_map(file_ext) or nil + local project_context = mentions.enable_project_context and RepoMap.get_repo_map(file_ext) or nil Llm.stream({ bufnr = self.code.bufnr, diff --git a/lua/avante/utils/init.lua b/lua/avante/utils/init.lua index 8af2fc7a8..d8d8ed8be 100644 --- a/lua/avante/utils/init.lua +++ b/lua/avante/utils/init.lua @@ -6,6 +6,7 @@ local lsp = vim.lsp ---@field tokens avante.utils.tokens ---@field root avante.utils.root ---@field repo_map avante.utils.repo_map +---@field file avante.utils.file local M = {} setmetatable(M, { diff --git a/lua/avante/utils/repo_map.lua b/lua/avante/utils/repo_map.lua deleted file mode 100644 index 40252dab0..000000000 --- a/lua/avante/utils/repo_map.lua +++ /dev/null @@ -1,730 +0,0 @@ -local parsers = require("nvim-treesitter.parsers") -local Config = require("avante.config") - -local get_node_text = vim.treesitter.get_node_text - ----@class avante.utils.repo_map -local RepoMap = {} - -local dependencies_queries = { - lua = [[ - (function_call - name: (identifier) @function_name - arguments: (arguments - (string) @required_file)) - ]], - - python = [[ - (import_from_statement - module_name: (dotted_name) @import_module) - (import_statement - (dotted_name) @import_module) - ]], - - javascript = [[ - (import_statement - source: (string) @import_module) - (call_expression - function: (identifier) @function_name - arguments: (arguments - (string) @required_file)) - ]], - - typescript = [[ - (import_statement - source: (string) @import_module) - (call_expression - function: (identifier) @function_name - arguments: (arguments - (string) @required_file)) - ]], - - go = [[ - (import_spec - path: (interpreted_string_literal) @import_module) - ]], - - rust = [[ - (use_declaration - (scoped_identifier) @import_module) - (use_declaration - (identifier) @import_module) - ]], - - c = [[ - (preproc_include - (string_literal) @import_module) - (preproc_include - (system_lib_string) @import_module) - ]], - - cpp = [[ - (preproc_include - (string_literal) @import_module) - (preproc_include - (system_lib_string) @import_module) - ]], -} - -local definitions_queries = { - python = [[ - ;; Capture top-level functions, class, and method definitions - (module - (expression_statement - (assignment) @assignment - ) - ) - (module - (function_definition) @function - ) - (module - (class_definition - body: (block - (expression_statement - (assignment) @class_assignment - ) - ) - ) - ) - (module - (class_definition - body: (block - (function_definition) @method - ) - ) - ) - ]], - javascript = [[ - ;; Capture exported functions, arrow functions, variables, classes, and method definitions - (export_statement - declaration: (lexical_declaration - (variable_declarator) @variable - ) - ) - (export_statement - declaration: (function_declaration) @function - ) - (export_statement - declaration: (class_declaration - body: (class_body - (field_definition) @class_variable - ) - ) - ) - (export_statement - declaration: (class_declaration - body: (class_body - (method_definition) @method - ) - ) - ) - ]], - typescript = [[ - ;; Capture exported functions, arrow functions, variables, classes, and method definitions - (export_statement - declaration: (lexical_declaration - (variable_declarator) @variable - ) - ) - (export_statement - declaration: (function_declaration) @function - ) - (export_statement - declaration: (class_declaration - body: (class_body - (public_field_definition) @class_variable - ) - ) - ) - (interface_declaration - body: (interface_body - (property_signature) @class_variable - ) - ) - (type_alias_declaration - value: (object_type - (property_signature) @class_variable - ) - ) - (export_statement - declaration: (class_declaration - body: (class_body - (method_definition) @method - ) - ) - ) - ]], - rust = [[ - ;; Capture public functions, structs, methods, and variable definitions - (function_item) @function - (impl_item - body: (declaration_list - (function_item) @method - ) - ) - (struct_item - body: (field_declaration_list - (field_declaration) @class_variable - ) - ) - (enum_item - body: (enum_variant_list - (enum_variant) @enum_item - ) - ) - (const_item) @variable - ]], - go = [[ - ;; Capture top-level functions and struct definitions - (var_declaration - (var_spec) @variable - ) - (const_declaration - (const_spec) @variable - ) - (function_declaration) @function - (type_declaration - (type_spec (struct_type)) @class - ) - (type_declaration - (type_spec - (struct_type - (field_declaration_list - (field_declaration) @class_variable))) - ) - (method_declaration) @method - ]], - c = [[ - ;; Capture extern functions, variables, public classes, and methods - (function_definition - (storage_class_specifier) @extern - ) @function - (class_specifier - (public) @class - (function_definition) @method - ) @class - (declaration - (storage_class_specifier) @extern - ) @variable - ]], - cpp = [[ - ;; Capture extern functions, variables, public classes, and methods - (function_definition - (storage_class_specifier) @extern - ) @function - (class_specifier - (public) @class - (function_definition) @method - ) @class - (declaration - (storage_class_specifier) @extern - ) @variable - ]], - lua = [[ - ;; Capture function and method definitions - (variable_list) @variable - (function_declaration) @function - ]], - ruby = [[ - ;; Capture top-level methods, class definitions, and methods within classes - (method) @function - (assignment) @assignment - (class - body: (body_statement - (assignment) @class_assignment - (method) @method - ) - ) - ]], -} - -local queries_filetype_map = { - ["javascriptreact"] = "javascript", - ["typescriptreact"] = "typescript", -} - -local function get_query(queries, filetype) - filetype = queries_filetype_map[filetype] or filetype - return queries[filetype] -end - -local function get_ts_lang(bufnr) - local lang = parsers.get_buf_lang(bufnr) - return lang -end - -function RepoMap.get_parser(bufnr) - local lang = get_ts_lang(bufnr) - if not lang then return end - local parser = parsers.get_parser(bufnr, lang) - return parser, lang -end - -function RepoMap.extract_dependencies(bufnr) - local parser, lang = RepoMap.get_parser(bufnr) - if not lang or not parser or not dependencies_queries[lang] then - print("No parser or query available for this buffer's language: " .. (lang or "unknown")) - return {} - end - - local dependencies = {} - local tree = parser:parse()[1] - local root = tree:root() - local filetype = vim.api.nvim_get_option_value("filetype", { buf = bufnr }) - - local query = get_query(dependencies_queries, filetype) - if not query then return dependencies end - - local query_obj = vim.treesitter.query.parse(lang, query) - - for _, node, _ in query_obj:iter_captures(root, bufnr, 0, -1) do - -- local name = query.captures[id] - local required_file = vim.treesitter.get_node_text(node, bufnr):gsub('"', ""):gsub("'", "") - table.insert(dependencies, required_file) - end - - return dependencies -end - -function RepoMap.get_filetype_by_filepath(filepath) return vim.filetype.match({ filename = filepath }) end - -function RepoMap.parse_file(filepath) - local File = require("avante.utils.file") - local source = File.read_content(filepath) - - local filetype = RepoMap.get_filetype_by_filepath(filepath) - local lang = parsers.ft_to_lang(filetype) - if lang then - local ok, parser = pcall(vim.treesitter.get_string_parser, source, lang) - if ok then - local tree = parser:parse()[1] - local node = tree:root() - return { node = node, source = source } - else - print("parser error", parser) - end - end -end - -local function get_closest_parent_name(node, source) - local parent = node:parent() - while parent do - local name = parent:field("name")[1] - if name then return get_node_text(name, source) end - parent = parent:parent() - end - return "" -end - -local function find_parent_by_type(node, type) - local parent = node:parent() - while parent do - if parent:type() == type then return parent end - parent = parent:parent() - end - return nil -end - -local function find_child_by_type(node, type) - for child in node:iter_children() do - if child:type() == type then return child end - local res = find_child_by_type(child, type) - if res then return res end - end - return nil -end - -local function get_node_type(node, source) - local node_type - local predefined_type_node = find_child_by_type(node, "predefined_type") - if predefined_type_node then - node_type = get_node_text(predefined_type_node, source) - else - local value_type_node = node:field("type")[1] - node_type = value_type_node and get_node_text(value_type_node, source) or "" - end - return node_type -end - --- Function to extract definitions from the file -function RepoMap.extract_definitions(filepath) - local Utils = require("avante.utils") - - local filetype = RepoMap.get_filetype_by_filepath(filepath) - - if not filetype then return {} end - - -- Get the corresponding query for the detected language - local query = get_query(definitions_queries, filetype) - if not query then return {} end - - local parsed = RepoMap.parse_file(filepath) - if not parsed then return {} end - - -- Get the current buffer's syntax tree - local root = parsed.node - - local lang = parsers.ft_to_lang(filetype) - - -- Parse the query - local query_obj = vim.treesitter.query.parse(lang, query) - - -- Store captured results - local definitions = {} - - local class_def_map = {} - local enum_def_map = {} - - local function get_class_def(name) - local def = class_def_map[name] - if def == nil then - def = { - type = "class", - name = name, - methods = {}, - properties = {}, - } - class_def_map[name] = def - end - return def - end - - local function get_enum_def(name) - local def = enum_def_map[name] - if def == nil then - def = { - type = "enum", - name = name, - items = {}, - } - enum_def_map[name] = def - end - return def - end - - for _, captures, _ in query_obj:iter_matches(root, parsed.source) do - for id, node in pairs(captures) do - local type = query_obj.captures[id] - local name_node = node:field("name")[1] - local name = name_node and get_node_text(name_node, parsed.source) or "" - - if type == "class" then - if name ~= "" then get_class_def(name) end - elseif type == "enum_item" then - local enum_name = get_closest_parent_name(node, parsed.source) - if enum_name and filetype == "go" and not Utils.is_first_letter_uppercase(enum_name) then goto continue end - local enum_def = get_enum_def(enum_name) - local enum_type_node = find_child_by_type(node, "type_identifier") - local enum_type = enum_type_node and get_node_text(enum_type_node, parsed.source) or "" - table.insert(enum_def.items, { - name = name, - type = enum_type, - }) - elseif type == "method" then - if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end - local params_node = node:field("parameters")[1] - local params = params_node and get_node_text(params_node, parsed.source) or "()" - local return_type_node = node:field("return_type")[1] or node:field("result")[1] - local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void" - - local class_name - local impl_item_node = find_parent_by_type(node, "impl_item") - local receiver_node = node:field("receiver")[1] - if impl_item_node then - local impl_type_node = impl_item_node:field("type")[1] - class_name = impl_type_node and get_node_text(impl_type_node, parsed.source) or "" - elseif receiver_node then - local type_identifier_node = find_child_by_type(receiver_node, "type_identifier") - class_name = type_identifier_node and get_node_text(type_identifier_node, parsed.source) or "" - else - class_name = get_closest_parent_name(node, parsed.source) - end - local class_def = get_class_def(class_name) - - local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier") - local accessibility_modifier = accessibility_modifier_node - and get_node_text(accessibility_modifier_node, parsed.source) - or "" - - table.insert(class_def.methods, { - type = "function", - name = name, - params = params, - return_type = return_type, - accessibility_modifier = accessibility_modifier, - }) - elseif type == "class_assignment" then - local left_node = node:field("left")[1] - local left = left_node and get_node_text(left_node, parsed.source) or "" - - local value_type = get_node_type(node, parsed.source) - - local class_name = get_closest_parent_name(node, parsed.source) - if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end - - local class_def = get_class_def(class_name) - - table.insert(class_def.properties, { - type = "variable", - name = left, - value_type = value_type, - }) - elseif type == "class_variable" then - local value_type = get_node_type(node, parsed.source) - - local class_name = get_closest_parent_name(node, parsed.source) - if class_name and filetype == "go" and not Utils.is_first_letter_uppercase(class_name) then goto continue end - - local class_def = get_class_def(class_name) - - table.insert(class_def.properties, { - type = "variable", - name = name, - value_type = value_type, - }) - elseif type == "function" or type == "arrow_function" then - if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end - local impl_item_node = find_parent_by_type(node, "impl_item") - if impl_item_node then goto continue end - local function_node = find_parent_by_type(node, "function_declaration") - or find_parent_by_type(node, "function_definition") - if function_node then goto continue end - -- Extract function parameters and return type - local params_node = node:field("parameters")[1] - local params = params_node and get_node_text(params_node, parsed.source) or "()" - local return_type_node = node:field("return_type")[1] or node:field("result")[1] - local return_type = return_type_node and get_node_text(return_type_node, parsed.source) or "void" - - local accessibility_modifier_node = find_child_by_type(node, "accessibility_modifier") - local accessibility_modifier = accessibility_modifier_node - and get_node_text(accessibility_modifier_node, parsed.source) - or "" - - local def = { - type = "function", - name = name, - params = params, - return_type = return_type, - accessibility_modifier = accessibility_modifier, - } - table.insert(definitions, def) - elseif type == "assignment" then - local impl_item_node = find_parent_by_type(node, "impl_item") - or find_parent_by_type(node, "class_declaration") - or find_parent_by_type(node, "class_definition") - if impl_item_node then goto continue end - local function_node = find_parent_by_type(node, "function_declaration") - or find_parent_by_type(node, "function_definition") - if function_node then goto continue end - - local left_node = node:field("left")[1] - local left = left_node and get_node_text(left_node, parsed.source) or "" - - if left and filetype == "go" and not Utils.is_first_letter_uppercase(left) then goto continue end - - local value_type = get_node_type(node, parsed.source) - - local def = { - type = "variable", - name = left, - value_type = value_type, - } - table.insert(definitions, def) - elseif type == "variable" then - local impl_item_node = find_parent_by_type(node, "impl_item") - or find_parent_by_type(node, "class_declaration") - or find_parent_by_type(node, "class_definition") - if impl_item_node then goto continue end - local function_node = find_parent_by_type(node, "function_declaration") - or find_parent_by_type(node, "function_definition") - if function_node then goto continue end - - local value_type = get_node_type(node, parsed.source) - - if name and filetype == "go" and not Utils.is_first_letter_uppercase(name) then goto continue end - - local def = { type = "variable", name = name, value_type = value_type } - table.insert(definitions, def) - end - ::continue:: - end - end - - for _, def in pairs(class_def_map) do - table.insert(definitions, def) - end - - for _, def in pairs(enum_def_map) do - table.insert(definitions, def) - end - - return definitions -end - -local function stringify_function(def) - local res = "func " .. def.name .. def.params .. ":" .. def.return_type .. ";" - if def.accessibility_modifier and def.accessibility_modifier ~= "" then - res = def.accessibility_modifier .. " " .. res - end - return res -end - -local function stringify_variable(def) - local res = "var " .. def.name - if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end - return res .. ";" -end - -local function stringify_enum_item(def) - local res = def.name - if def.value_type and def.value_type ~= "" then res = res .. ":" .. def.value_type end - return res .. ";" -end - --- Function to load file content into a temporary buffer, process it, and then delete the buffer -function RepoMap.stringify_definitions(filepath) - if vim.endswith(filepath, "~") then return "" end - - -- Extract definitions - local definitions = RepoMap.extract_definitions(filepath) - - local output = "" - -- Print or process the definitions - for _, def in ipairs(definitions) do - if def.type == "class" then - output = output .. def.type .. " " .. def.name .. "{" - for _, property in ipairs(def.properties) do - output = output .. stringify_variable(property) - end - for _, method in ipairs(def.methods) do - output = output .. stringify_function(method) - end - output = output .. "}" - elseif def.type == "enum" then - output = output .. def.type .. " " .. def.name .. "{" - for _, item in ipairs(def.items) do - output = output .. stringify_enum_item(item) .. "" - end - output = output .. "}" - elseif def.type == "function" then - output = output .. stringify_function(def) - elseif def.type == "variable" then - output = output .. stringify_variable(def) - end - end - - return output -end - -function RepoMap._build_repo_map(project_root, file_ext) - local Utils = require("avante.utils") - local output = {} - local gitignore_path = project_root .. "/.gitignore" - local ignore_patterns, negate_patterns = Utils.parse_gitignore(gitignore_path) - local filepaths = Utils.scan_directory(project_root, ignore_patterns, negate_patterns) - vim.iter(filepaths):each(function(filepath) - if not Utils.is_same_file_ext(file_ext, filepath) then return end - local definitions = RepoMap.stringify_definitions(filepath) - if definitions == "" then return end - table.insert(output, { - path = Utils.relative_path(filepath), - lang = RepoMap.get_filetype_by_filepath(filepath), - defs = definitions, - }) - end) - return output -end - -local cache = {} - -function RepoMap.get_repo_map(file_ext) - file_ext = file_ext or vim.fn.expand("%:e") - local Utils = require("avante.utils") - local project_root = Utils.root.get() - local cache_key = project_root .. "." .. file_ext - local cached = cache[cache_key] - if cached then return cached end - - local PPath = require("plenary.path") - local Path = require("avante.path") - local repo_map - - local function build_and_save() - repo_map = RepoMap._build_repo_map(project_root, file_ext) - cache[cache_key] = repo_map - Path.repo_map.save(project_root, file_ext, repo_map) - end - - repo_map = Path.repo_map.load(project_root, file_ext) - - if not repo_map or next(repo_map) == nil then - build_and_save() - if not repo_map then return end - else - local timer = vim.loop.new_timer() - - if timer then - timer:start( - 0, - 0, - vim.schedule_wrap(function() - build_and_save() - timer:close() - end) - ) - end - end - - local update_repo_map = vim.schedule_wrap(function(rel_filepath) - if rel_filepath and Utils.is_same_file_ext(file_ext, rel_filepath) then - local abs_filepath = PPath:new(project_root):joinpath(rel_filepath):absolute() - local definitions = RepoMap.stringify_definitions(abs_filepath) - if definitions == "" then return end - local found = false - for _, m in ipairs(repo_map) do - if m.path == rel_filepath then - m.defs = definitions - found = true - break - end - end - if not found then - table.insert(repo_map, { - path = Utils.relative_path(abs_filepath), - lang = RepoMap.get_filetype_by_filepath(abs_filepath), - defs = definitions, - }) - end - cache[cache_key] = repo_map - Path.repo_map.save(project_root, file_ext, repo_map) - end - end) - - local handle = vim.loop.new_fs_event() - - if handle then - handle:start(project_root, { recursive = true }, function(err, rel_filepath) - if err then - print("Error watching directory " .. project_root .. ":", err) - return - end - - if rel_filepath then update_repo_map(rel_filepath) end - end) - end - - vim.api.nvim_create_autocmd({ "BufReadPost", "BufNewFile" }, { - callback = function(ev) - vim.defer_fn(function() - local filepath = vim.api.nvim_buf_get_name(ev.buf) - if not vim.startswith(filepath, project_root) then return end - local rel_filepath = Utils.relative_path(filepath) - update_repo_map(rel_filepath) - end, 0) - end, - }) - - return repo_map -end - -return RepoMap