diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 8d02a9edc..7a10b1425 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -41,3 +41,15 @@ jobs:
           components: clippy, rustfmt
       - name: Run rustfmt
         run: make ruststylecheck
+  rustlint:
+    name: Lint Rust
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: Swatinem/rust-cache@v2
+      - uses: dtolnay/rust-toolchain@master
+        with:
+          toolchain: stable
+          components: clippy, rustfmt
+      - name: Run rustfmt
+        run: make rustlint
diff --git a/crates/avante-templates/src/lib.rs b/crates/avante-templates/src/lib.rs
index 300aac9fd..81703fef0 100644
--- a/crates/avante-templates/src/lib.rs
+++ b/crates/avante-templates/src/lib.rs
@@ -15,7 +15,7 @@ impl<'a> State<'a> {
     }
 }
 
-#[derive(Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
 struct TemplateContext {
     use_xml_format: bool,
     ask: bool,
@@ -29,16 +29,17 @@ struct TemplateContext {
 
 // Given the file name registered after add, the context table in Lua, resulted in a formatted
 // Lua string
-fn render(state: &State, template: String, context: TemplateContext) -> LuaResult<String> {
+#[allow(clippy::needless_pass_by_value)]
+fn render(state: &State, template: &str, context: TemplateContext) -> LuaResult<String> {
     let environment = state.environment.lock().unwrap();
     match environment.as_ref() {
         Some(environment) => {
-            let template = environment
-                .get_template(&template)
+            let jinja_template = environment
+                .get_template(template)
                 .map_err(LuaError::external)
                 .unwrap();
-            Ok(template
+            Ok(jinja_template
                 .render(context! {
                     use_xml_format => context.use_xml_format,
                     ask => context.ask,
@@ -84,7 +85,7 @@ fn avante_templates(lua: &Lua) -> LuaResult<LuaTable> {
         "render",
         lua.create_function_mut(move |lua, (template, context): (String, LuaValue)| {
             let ctx = lua.from_value(context)?;
-            render(&state_clone, template, ctx)
+            render(&state_clone, template.as_str(), ctx)
         })?,
     )?;
     Ok(exports)
diff --git a/crates/avante-tokenizers/src/lib.rs b/crates/avante-tokenizers/src/lib.rs
index 39016fa4b..533b162a0 100644
--- a/crates/avante-tokenizers/src/lib.rs
+++ b/crates/avante-tokenizers/src/lib.rs
@@ -8,13 +8,13 @@ struct Tiktoken {
 }
 
 impl Tiktoken {
-    fn new(model: String) -> Self {
-        let bpe = get_bpe_from_model(&model).unwrap();
+    fn new(model: &str) -> Self {
+        let bpe = get_bpe_from_model(model).unwrap();
         Tiktoken { bpe }
     }
 
-    fn encode(&self, text: String) -> (Vec<usize>, usize, usize) {
-        let tokens = self.bpe.encode_with_special_tokens(&text);
+    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
+        let tokens = self.bpe.encode_with_special_tokens(text);
         let num_tokens = tokens.len();
         let num_chars = text.chars().count();
         (tokens, num_tokens, num_chars)
@@ -26,13 +26,17 @@ struct HuggingFaceTokenizer {
 }
 
 impl HuggingFaceTokenizer {
-    fn new(model: String) -> Self {
+    fn new(model: &str) -> Self {
         let tokenizer = Tokenizer::from_pretrained(model, None).unwrap();
         HuggingFaceTokenizer { tokenizer }
     }
 
-    fn encode(&self, text: String) -> (Vec<usize>, usize, usize) {
-        let encoding = self.tokenizer.encode(text, false).unwrap();
+    fn encode(&self, text: &str) -> (Vec<usize>, usize, usize) {
+        let encoding = self
+            .tokenizer
+            .encode(text, false)
+            .map_err(LuaError::external)
+            .unwrap();
         let tokens: Vec<usize> = encoding.get_ids().iter().map(|x| *x as usize).collect();
         let num_tokens = tokens.len();
         let num_chars = encoding.get_offsets().last().unwrap().1;
@@ -57,7 +61,7 @@ impl State {
     }
 }
 
-fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)> {
+fn encode(state: &State, text: &str) -> LuaResult<(Vec<usize>, usize, usize)> {
     let tokenizer = state.tokenizer.lock().unwrap();
     match tokenizer.as_ref() {
         Some(TokenizerType::Tiktoken(tokenizer)) => Ok(tokenizer.encode(text)),
@@ -68,9 +72,9 @@ fn encode(state: &State, text: String) -> LuaResult<(Vec<usize>, usize, usize)>
     }
 }
 
-fn from_pretrained(state: &State, model: String) {
+fn from_pretrained(state: &State, model: &str) {
     let mut tokenizer_mutex = state.tokenizer.lock().unwrap();
-    *tokenizer_mutex = Some(match model.as_str() {
+    *tokenizer_mutex = Some(match model {
         "gpt-4o" => TokenizerType::Tiktoken(Tiktoken::new(model)),
         _ => TokenizerType::HuggingFace(HuggingFaceTokenizer::new(model)),
     });
@@ -86,13 +90,13 @@ fn avante_tokenizers(lua: &Lua) -> LuaResult<LuaTable> {
     exports.set(
         "from_pretrained",
         lua.create_function(move |_, model: String| {
-            from_pretrained(&state, model);
+            from_pretrained(&state, model.as_str());
             Ok(())
         })?,
     )?;
     exports.set(
         "encode",
-        lua.create_function(move |_, text: String| encode(&state_clone, text))?,
+        lua.create_function(move |_, text: String| encode(&state_clone, text.as_str()))?,
     )?;
     Ok(exports)
 }