autograph: Rust Machine Learning Library
autograph is a machine learning library for Rust.
To use autograph in your crate, add it as a dependency in Cargo.toml:
[dependencies]
autograph = "0.1.0"
Requirements
- Rust https://www.rust-lang.org/
- A device (typically a GPU) with drivers for a supported API:
  - Vulkan (all platforms) https://www.vulkan.org/
  - Metal (macOS / iOS) https://developer.apple.com/metal/
  - DX12 (Windows) https://docs.microsoft.com/windows/win32/directx
Examples
Machine Learning
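// Assumes this snippet runs inside an async fn that returns a Result, with autograph's
// Device, Iris, KMeans, KMeansTrainer and CowTensor types in scope, ndarray's s! macro
// imported, and a plot helper defined elsewhere (not shown here).
use std::iter::once;
// Create the device.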
let device = Device::new()?;
// Create the dataset.
let iris = Iris::new();
// The flower dimensions are the inputs to the model.
let x_array = iris.dimensions();
// Select only Petal Length + Petal Width (columns 2 and 3).
// These two features separate the species best, and two dimensions are easy to plot.
let x_array = x_array.slice(&s![.., 2..]);
// Create the KMeans model.
let kmeans = KMeans::new(iris.class_names().len())
    .into_device(device.clone())
    .await?;
// For small datasets, we can load the entire dataset into the device.
// For larger datasets, the data can be streamed as an iterator.
let x = CowTensor::from(x_array.view())
    .into_device(device)
    // Despite the await, this resolves immediately: host -> device transfers
    // are batched with other operations and run asynchronously on the device thread.
    .await?;
// Construct a trainer.
let mut trainer = KMeansTrainer::from(kmeans);
// Initialize the model (KMeans++).
// Here we provide an iterator of n iterators, such that the trainer can
// visit the data n times. In this case, once for each centroid.
trainer.init(|n| std::iter::from_fn(|| Some(once(Ok(x.view().into())))).take(n))?;
// Train the model (1 epoch).
trainer.train(once(Ok(x.view().into())))?;
// Get the model back.
let kmeans = KMeans::from(trainer);
// Get the trained centroids.
// For multiple reads, batch them by getting the futures first.
let centroids_fut = kmeans.centroids()
    // The centroids are stored in a FloatArcTensor, which can be either f32 or bf16.
    // This converts them to f32 if necessary.
    .cast_to::<f32>()?
    .read();
// Get the predicted classes.
let pred = kmeans.predict(&x.view().into())?
    .into_dimensionality()?
    .read()
    // Here we wait on all previous operations, including centroids_fut.
    .await?;
// This will resolve immediately.
let centroids = centroids_fut.await?;
// Get the flower classes from the dataset.
let classes = iris.classes().map(|c| *c as u32);
// Plot the results to "plot.png".
// Since KMeans is an unsupervised method, the predicted class indices are arbitrary and
// need not match the order of the true classes (i.e. the colors in the plot may differ).
plot(&x_array.view(), &classes.view(), &pred.as_array(), &centroids.as_array())?;
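In the example, trainer.init performs KMeans++ seeding and trainer.train then updates the centroids. To show what those two steps compute, here is a small CPU-only sketch on 2-D points written in plain Rust. It is not autograph's GPU implementation, and its internals may differ. The training step is modeled as a standard Lloyd update (assign each point to its nearest centroid, then average), which is an assumption about what one epoch does. The helper names (`kmeans_pp_init`, `lloyd_step`, `predict`, `XorShift`) and the toy data are hypothetical.

```rust
// Minimal CPU sketch of KMeans++ seeding plus Lloyd-style updates on 2-D points.
// Hypothetical helpers for illustration only; this is NOT autograph's API.

/// Squared Euclidean distance between two 2-D points.
fn dist2(a: [f32; 2], b: [f32; 2]) -> f32 {
    let (dx, dy) = (a[0] - b[0], a[1] - b[1]);
    dx * dx + dy * dy
}

/// Index of the centroid closest to `p` (ties go to the lower index).
fn nearest(p: [f32; 2], centroids: &[[f32; 2]]) -> usize {
    let mut best = 0;
    let mut best_d = f32::INFINITY;
    for (i, &c) in centroids.iter().enumerate() {
        let d = dist2(p, c);
        if d < best_d {
            best = i;
            best_d = d;
        }
    }
    best
}

/// Tiny xorshift64 PRNG so the sketch needs no external crates (seed must be non-zero).
struct XorShift(u64);

impl XorShift {
    /// Uniform sample in [0, 1) built from the top 24 bits of the state.
    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        (self.0 >> 40) as f32 / (1u64 << 24) as f32
    }
}

/// KMeans++ seeding: the first centroid is a uniformly random point. Each further
/// centroid is sampled with probability proportional to its squared distance to the
/// nearest centroid chosen so far, which takes one pass over the data per extra centroid.
fn kmeans_pp_init(x: &[[f32; 2]], k: usize, rng: &mut XorShift) -> Vec<[f32; 2]> {
    assert!(!x.is_empty() && k >= 1);
    let first = ((rng.next_f32() * x.len() as f32) as usize).min(x.len() - 1);
    let mut centroids = vec![x[first]];
    while centroids.len() < k {
        let d: Vec<f32> = x.iter().map(|&p| dist2(p, centroids[nearest(p, &centroids)])).collect();
        let total: f32 = d.iter().sum();
        let mut target = rng.next_f32() * total;
        // Fallback against float rounding: the last point not already at a centroid.
        let mut chosen = d.iter().rposition(|&di| di > 0.0).unwrap_or(x.len() - 1);
        for (i, &di) in d.iter().enumerate() {
            if target < di {
                chosen = i;
                break;
            }
            target -= di;
        }
        centroids.push(x[chosen]);
    }
    centroids
}

/// One Lloyd step: assign every point to its nearest centroid, then move each
/// centroid to the mean of its assigned points (empty clusters stay where they are).
fn lloyd_step(x: &[[f32; 2]], centroids: &mut [[f32; 2]]) {
    let k = centroids.len();
    let mut sums = vec![[0.0f32; 2]; k];
    let mut counts = vec![0usize; k];
    for &p in x {
        let c = nearest(p, centroids);
        sums[c][0] += p[0];
        sums[c][1] += p[1];
        counts[c] += 1;
    }
    for c in 0..k {
        if counts[c] > 0 {
            let n = counts[c] as f32;
            centroids[c] = [sums[c][0] / n, sums[c][1] / n];
        }
    }
}

/// Predicted class = index of the nearest centroid.
fn predict(x: &[[f32; 2]], centroids: &[[f32; 2]]) -> Vec<u32> {
    x.iter().map(|&p| nearest(p, centroids) as u32).collect()
}

fn main() {
    // Made-up toy data with two well-separated groups (not the Iris dataset).
    let x: [[f32; 2]; 6] = [[1.0, 0.2], [1.1, 0.3], [0.9, 0.1], [5.0, 1.8], [5.2, 2.0], [4.9, 1.9]];
    let mut rng = XorShift(42);
    let mut centroids = kmeans_pp_init(&x, 2, &mut rng);
    // Several full passes. Each call reassigns every point and recomputes the means.
    for _ in 0..10 {
        lloyd_step(&x, &mut centroids);
    }
    let pred = predict(&x, &centroids);
    println!("centroids = {centroids:?}");
    println!("pred = {pred:?}");
}
```

The autograph example trains for a single epoch on the device. This sketch loops a few times on the CPU. For well-separated data like this, the assignments typically stabilize within a few passes. As with the plot, the cluster indices in `pred` are arbitrary labels.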