Skip to content

Commit

Permalink
Add DistString impl to Uniform and Slice (#1315)
Browse files Browse the repository at this point in the history
* Add impl for `DistString` to `Uniform` and `Slice`

* Fix `DistString` impl.
  • Loading branch information
aobatact authored Jul 14, 2023
1 parent c354b6a commit ee80b41
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/distributions/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
// except according to those terms.

use crate::distributions::{Distribution, Uniform};
#[cfg(feature = "alloc")]
use alloc::string::String;

/// A distribution to sample items uniformly from a slice.
///
Expand Down Expand Up @@ -115,3 +117,35 @@ impl core::fmt::Display for EmptySlice {

#[cfg(feature = "std")]
impl std::error::Error for EmptySlice {}

/// Note: capacity is reserved from an upper bound on the UTF-8 width of the
/// sampled chars, so the `String` is potentially left with excess capacity;
/// optionally the user may call `string.shrink_to_fit()` afterwards.
#[cfg(feature = "alloc")]
impl<'a> super::DistString for Slice<'a, char> {
    fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
        // Upper bound on the UTF-8 width of any char in the slice, used to
        // size the reservation. Long slices are not scanned: they simply
        // assume the 4-byte worst case.
        let max_char_len = if self.slice.len() < 200 {
            let mut widest = 1_usize;
            for ch in self.slice.iter() {
                widest = widest.max(ch.len_utf8());
                // 4 bytes is the widest possible encoding, so stop scanning.
                if widest >= 4 {
                    break;
                }
            }
            widest
        } else {
            4
        };

        // Append in several chunks so that capacity left unused by one chunk
        // (chars narrower than the bound) can be reused by the next one.
        // ASCII-only slices and short requests are appended in a single pass.
        let chunk_cap = if max_char_len == 1 || len < 100 {
            len
        } else {
            len / 4
        };
        let mut remaining = len;
        loop {
            let chunk = chunk_cap.min(remaining);
            if chunk == 0 {
                break;
            }
            string.reserve(max_char_len * chunk);
            string.extend(self.sample_iter(&mut *rng).take(chunk));
            remaining -= chunk;
        }
    }
}
34 changes: 34 additions & 0 deletions src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,24 @@ impl UniformSampler for UniformChar {
}
}

/// Note: the `String` is potentially left with excess capacity if the range
/// includes non-ASCII chars, because space is reserved for the widest UTF-8
/// encoding in the range; optionally the user may call
/// `string.shrink_to_fit()` afterwards.
#[cfg(feature = "alloc")]
impl super::DistString for Uniform<char> {
    fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) {
        // Recover the upper bound from the sampler's `low` and `range`; it is
        // only needed to estimate how many bytes to reserve.
        let raw_hi = self.0.sampler.low + self.0.sampler.range - 1;
        // Values at or beyond `CHAR_SURROGATE_START` are shifted up by
        // `CHAR_SURROGATE_LEN` (presumably the sampler stores chars with the
        // surrogate gap removed — confirm against `UniformChar`).
        let hi = if raw_hi >= CHAR_SURROGATE_START {
            raw_hi + CHAR_SURROGATE_LEN
        } else {
            raw_hi
        };
        // UTF-8 width of the upper bound; fall back to the 4-byte maximum if
        // `hi` is not a valid `char`.
        let max_char_len = match char::from_u32(hi) {
            Some(c) => c.len_utf8(),
            None => 4,
        };
        string.reserve(max_char_len * len);
        string.extend(self.sample_iter(rng).take(len));
    }
}

/// The back-end implementing [`UniformSampler`] for floating-point types.
///
/// Unless you are implementing [`UniformSampler`] for your own type, this type
Expand Down Expand Up @@ -1376,6 +1394,22 @@ mod tests {
let c = d.sample(&mut rng);
assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
}
#[cfg(feature = "alloc")]
{
use crate::distributions::DistString;
let string1 = d.sample_string(&mut rng, 100);
assert_eq!(string1.capacity(), 300);
let string2 = Uniform::new(
core::char::from_u32(0x0000).unwrap(),
core::char::from_u32(0x0080).unwrap(),
).unwrap().sample_string(&mut rng, 100);
assert_eq!(string2.capacity(), 100);
let string3 = Uniform::new_inclusive(
core::char::from_u32(0x0000).unwrap(),
core::char::from_u32(0x0080).unwrap(),
).unwrap().sample_string(&mut rng, 100);
assert_eq!(string3.capacity(), 200);
}
}

#[test]
Expand Down

0 comments on commit ee80b41

Please sign in to comment.