今回は機械学習でよく使われる学習アルゴリズムの一つ「ランダムフォレスト」を実装して、手描き文字の分類に挑戦してみましょう。このアルゴリズムは機械学習の分野でよく利用されるものですが、実際にRustを使ってゼロから実装することで、理解を深めましょう。
ランダムフォレストとは？
今回は、筆者が好きな機械学習アルゴリズムの一つ「ランダムフォレスト(Random Forest)」を取り上げて、Rustでゼロから実装してみましょう。
ランダムフォレストは、アンサンブル学習の一つです。アンサンブル学習とは、複数のモデルを組み合わせて、より高い予測精度を実現するための機械学習手法です。
ランダムフォレストでは、複数の「決定木(Decision tree)」を利用します。次の図のように、複数の決定木がそれぞれ予測を行った後で、多数決によって最終的な答えを決定します。
手描き数字の画像データセットMNISTについて
そして、単にランダムフォレストをライブラリとして実装しても面白くないので、これを実際に活用してみましょう。今回は、大量の手描き数字の画像を学習させて、未知の画像に、どんな数字が描かれているかを判定して、どのくらいの精度がでるのかを評価してみます。
大量の手描き数字のデータがインターネット上で公開されています。それが「MNIST」と呼ばれる有名なデータセットです。これは、次の画像のように、大量の手描き数字が収録されたデータセットです。
MNISTの画像データセットには、6万枚の学習用データと1万枚のテスト用データが含まれています。それぞれ、28x28のグレイスケール画像と、どの数字が描かれているのかを表すラベルから成り立っています。
ちなみに、MNISTのデータセット自体は、こちら(https://github.com/fgnt/mnist)で配布されているのですが、独自のデータ形式で配布されているため、データの読み込みが大変です。そこで、筆者が本連載のために作成した「mnist-reader」クレートを利用しましょう。このクレートを利用すると、MNISTのデータを自動でダウンロードして、手軽に読み込みを行うことができます。
プロジェクトを作成しよう
それでは、プロジェクトを作成して、必要なライブラリをプロジェクトに追加しましょう。ターミナル(WindowsではPowerShell、macOSではターミナル.app)を起動して、下記のコマンドを実行します。
# Create the project
mkdir rf_mnist
cd rf_mnist
cargo init
# Add the libraries used in this article
cargo add mnist_reader # for loading the MNIST dataset
cargo add lazyrand # for using random numbers
MNISTデータセットの読み込みテストをしてみよう
メインプログラム「src/main.rs」を以下のように書き換えて、MNISTデータセットのダウンロードと読み込みが成功するか確認してみましょう。
use mnist_reader::{MnistReader, print_image};
/// Loads the MNIST dataset through `MnistReader` (stored under `mnist-data`)
/// and prints the first training image together with its label.
fn main() {
    // Point the reader at the "mnist-data" directory and load the dataset.
    let mut reader = MnistReader::new("mnist-data");
    reader.load().unwrap();

    // Display the first training image.
    let images: Vec<Vec<f32>> = reader.train_data;
    let first_image = &images[0];
    print_image(first_image);

    // Display the label that belongs to that image.
    let labels: Vec<u8> = reader.train_labels;
    let first_label = labels[0];
    println!("labels[0]={}", first_label);
}
そして、ターミナルで以下のコマンドを実行してみましょう。
cargo run
すると、MNISTのデータセットをダウンロードして、最初のデータを表示します。
なお、今回のプロジェクト「rf_mnist」フォルダは、下記のような構造になります。mnist-dataフォルダは、プログラムを初回実行した時に作成されて、MNISTのデータセットが自動的にダウンロードされます。
.
├── Cargo.lock
├── Cargo.toml
├── <mnist-data>
│   ├── t10k-images.idx3-ubyte.gz
│   ├── t10k-labels.idx1-ubyte.gz
│   ├── train-images.idx3-ubyte.gz
│   └── train-labels.idx1-ubyte.gz
└── <src>
    ├── random_forest.rs
    └── main.rs
ランダムフォレストを実装してみよう
それでは、ランダムフォレストを実装してみましょう。「src/random_forest.rs」というファイルを作成しましょう。ランダムフォレストのアルゴリズムのプログラムが少し長いので、以下に抜粋したものを紹介します。
実際には、こちらのGistにアップしたプログラム( https://gist.github.com/kujirahand/b3c6a5b40310fb88964116fbe2665d9d )から、コードをコピーして貼り付けてください。
以下は、ランダムフォレストで決定木を実装するための構造体DecisionTreeの定義の抜粋です。
/// A node of the decision tree: either a leaf carrying a prediction, or a
/// decision node that holds a split condition and two boxed child nodes.
enum Node {
Leaf { prediction: u8 }, // predicted label for a sample
Decision {
feature_index: usize, // selects the condition used to split the data in two (presumably an index into a sample's features — confirm)
threshold: f32, // threshold value used for the split
left: Box<Node>, // child node used when the condition is met
right: Box<Node>, // child node used when the condition is not met
},
}
/// A simple decision tree.
pub struct DecisionTree {
max_depth: usize, // presumably the maximum depth the tree may reach — confirm in the training code
min_samples_split: usize, // presumably the minimum number of samples needed to split a node — confirm
max_features: usize, // presumably the number of features considered for a split — confirm
root: Option<Box<Node>>, // root node of the tree; NOTE(review): `None` presumably means "not trained yet" — confirm
}
注目したいのは、列挙型NodeのDecisionにある、leftとrightです。Rust の列挙型や構造体は、全フィールド型のサイズがコンパイル時に確定しなければなりません。このため、Box<Node>を使って子ノードをヒープ上に配置し、Node自身が再帰的にNodeを保持できるようにしています。
続いて、ランダムフォレスト本体の実装を確認してみましょう。
/// The random forest itself: a collection of decision trees.
pub struct RandomForest {
trees: Vec<DecisionTree>, // the decision trees that make up the forest
}
impl RandomForest {
/// readerのtrain_dataを用いて学習する
pub fn train(reader: &MnistReader, n_trees: usize, max_depth: usize, min_samples_split: usize) -> Self {
// 決定木を指定数だけ作成して学習を行う
「省略」
let mut trees = Vec::with_capacity(n_trees);
for i in 0..n_trees {
println!("+ training tree...{}/{}", i+1, n_trees);
「省略」
tree.train(&data, &labels);
trees.push(tree);
}
RandomForest { trees }
}
/// 1ãµã³ãã«ãäºæž¬
pub fn predict(&self, sample: &[f32]) -> u8 {
let mut votes = HashMap::new();
for tree in &self.trees { // æ±ºå®æšã®äºæž¬çµæãéããŠæç¥š
let pred = tree.predict(sample);
*votes.entry(pred).or_insert(0) += 1;
}
votes.into_iter().max_by_key(|&(_, c)| c).map(|(cls, _)| cls).unwrap_or(0)
}
/// テストデータで精度を計算
pub fn evaluate(&self, reader: &MnistReader) -> f32 { … }
}
ランダムフォレストの実装においては、入力データをサンプリングして複数の決定木を作成して、その後で投票（多数決）で最終的なクラスラベルを決定します。そして、決定木は、データを「ある特徴量がある値より小さいか大きいか」といった２択の判定を繰り返すことでラベルを予測します。
それで、上記のプログラムですが、データの学習を行うtrainメソッドでは、複数の決定木を作成しています。予測を行うpredictメソッドでは、生成した決定木に予測をしてもらって、最後にどの予測が正しいかを投票で決めるという流れになっています。
メインプログラムを完成させよう
それでは、上記のプログラム(src/random_forest.rs)を利用して、メインプログラム(src/main.rs)を完成させましょう。今回は、ランダムフォレストの実装がメインなので、実際に手描き数字を分類するUI画面などは持たせず、テスト用のデータを用いて、ランダムフォレストの性能を評価するだけに留めました。
次のようなプログラムになります。
mod random_forest;
use mnist_reader::MnistReader;
use random_forest::RandomForest;
/// Loads MNIST, trains a random forest on it, evaluates the forest with the
/// reader's data and prints the resulting accuracy as a percentage.
fn main() {
    // Hyperparameters handed to `RandomForest::train`.
    let n_trees = 10;
    let max_depth = 30;
    let min_samples_split = 10;

    // Load the MNIST dataset stored under "mnist-data".
    let mut dataset = MnistReader::new("mnist-data");
    println!("Loading MNIST dataset...");
    dataset.load().unwrap();

    // Train the random forest on the loaded data.
    println!("Training MNIST dataset...");
    let forest = RandomForest::train(&dataset, n_trees, max_depth, min_samples_split);

    // Evaluate the forest and report the accuracy.
    println!("Evaluating MNIST dataset...");
    let accuracy = forest.evaluate(&dataset);
    let percent = accuracy * 100.0;
    println!("Accuracy: {:.2}%", percent);
}
プログラムを確認してみましょう。モデルの学習を行うtrainメソッドでは、決定木を指定数(n_trees)だけ作成して学習を実行します。そして、予測を行うpredictメソッドでは、trainで作成した決定木を利用して、予測を行って、投票結果を元に最終的な判定結果を返します。
プログラムを実行するには、ターミナルで下記のコマンドを実行します。
cargo run
筆者のMacbook Pro M4を利用して測定したところ、実行に48秒かかり、判定精度は94.47%と表示されました。ただし、乱数を使って学習を行っているため、実行するたびに精度は変わります。
-

自作のランダムフォレストでMNISTデータセットを学習して精度をテストしたところ
なお、学習に時間はかかりますが、 RandomForest::trainメソッドのパラメータを変更すると、より良い精度がでます。データの学習を実行する部分を、下記のように書き換えると、判定精度が96.74%まで向上します。学習用に作った自作ライブラリとしては、この程度の精度が出れば十分でしょう。
let rf = RandomForest::train(&mnist, 100, 30, 10);
まとめ
以上、今回はゼロからランダムフォレストを実装して、手描き数字の画像データセットMNISTでどのくらいの精度が出るのかを試してみました。全貌を理解するには、機械学習の知識が必要となるのですが、ゼロからアルゴリズムを実装することで、その仕組みがなんとなく分かったのではないでしょうか。余力があれば、さらに精度を高める工夫をしたり、UIを作成して画像を入力すると結果が表示されるような仕組みを作ってみると良いでしょう。
自由型プログラマー。くじらはんどにて、プログラミングの楽しさを伝える活動をしている。代表作に、日本語プログラミング言語「なでしこ」 、テキスト音楽「サクラ」など。2001年オンラインソフト大賞入賞、2004年度未踏ユース スーパークリエータ認定、2010年 OSS貢献者章受賞。これまで50冊以上の技術書を執筆した。直近では、「大規模言語モデルを使いこなすためのプロンプトエンジニアリングの教科書(マイナビ出版)」「Pythonでつくるデスクトップアプリ(ソシム)」「実践力を身につける Pythonの教科書 第2版」「シゴトがはかどる Python自動処理の教科書(マイナビ出版)」など。



