Datarax
Data Pipeline Framework for JAX
An extensible data pipeline framework built for JAX-based ML workflows with JIT compilation, differentiable processing, and multi-device distribution.
Repository Coming Soon
This project is under active development and will be open-sourced soon
Overview
Datarax (Data + Array/JAX) is an extensible data pipeline framework built for JAX-based machine learning workflows. It leverages JAX's JIT compilation, automatic differentiation, and hardware acceleration to build data loading, preprocessing, and augmentation pipelines that run on CPUs, GPUs, and TPUs.
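To make the JIT benefit concrete, here is a plain-JAX sketch (generic JAX, not Datarax-specific API) of a preprocessing step compiled into a single XLA program:

```python
import jax
import jax.numpy as jnp

@jax.jit
def preprocess(batch):
    # Normalize images to [0, 1] and one-hot encode labels,
    # fused by XLA into one compiled program.
    images = batch["image"] / 255.0
    labels = jax.nn.one_hot(batch["label"], 10)
    return {"image": images, "label": labels}

batch = {
    "image": jnp.full((32, 28, 28, 1), 127.5),
    "label": jnp.zeros((32,), jnp.int32),
}
out = preprocess(batch)
```

Because the step is a pure function of arrays, it runs unchanged on CPU, GPU, or TPU.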
JAX has mature libraries for models (Flax), optimizers (Optax), and checkpointing (Orbax), but lacks a dedicated data pipeline framework that operates at the same level of abstraction. Existing options are either framework-agnostic loaders that return NumPy arrays (losing JIT/autodiff benefits) or wrappers around tf.data/PyTorch that introduce cross-framework overhead. Datarax fills this gap.
Every component — sources, operators, batchers, samplers, sharders — is a Flax NNX module. Pipeline state is managed through NNX's variable system, which means operators can hold learnable parameters, be serialized with Orbax, and participate in JAX transformations (jit, vmap, grad) without special handling. Because operators are NNX modules, gradients flow through the entire pipeline, enabling approaches like gradient-based augmentation search and task-optimized preprocessing that are not possible with standard data loaders.
Pipelines are directed acyclic graphs, not linear chains. The >> operator composes sequential steps, | creates parallel branches, and control-flow nodes (Branch, Merge, SplitField) handle conditional and multi-path logic. The DAG executor manages scheduling, caching, and rebatching across the graph.
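The composition semantics can be sketched with a toy `Node` class; this is illustrative only, not Datarax's executor, which operates on batched pytrees and adds scheduling and caching:

```python
class Node:
    """Toy sketch of DAG composition via operator overloading."""
    def __init__(self, fn):
        self.fn = fn

    def __rshift__(self, other):
        # a >> b: sequential -- feed a's output into b.
        return Node(lambda x: other.fn(self.fn(x)))

    def __or__(self, other):
        # a | b: parallel -- run both branches on the same input.
        return Node(lambda x: (self.fn(x), other.fn(x)))

    def __call__(self, x):
        return self.fn(x)

double = Node(lambda x: x * 2)
inc = Node(lambda x: x + 1)
pipe = double >> (inc | double)   # one input fans out to two branches
result = pipe(3)                  # 3*2 = 6, then branches: 6+1 and 6*2
```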
Key Features
JAX-Native Design
All core components are built on JAX using the Flax NNX module system. Pipelines are JIT-compiled via XLA, with built-in profiling and roofline analysis.
Differentiable Pipelines
Operators are NNX modules so gradients flow through the entire pipeline. Enables gradient-based augmentation search and task-optimized preprocessing.
DAG Execution Engine
Graph-based pipeline construction with branching, parallel execution, caching, and rebatching nodes. Use >> for sequential and | for parallel composition.
Multi-Device Distribution
Scale from laptop to cluster with device mesh sharding, ArraySharder, and JaxProcessSharder for multi-host training.
Deterministic Reproducibility
Reproducible pipelines by default using Grain's Feistel cipher shuffling (O(1) memory) with explicit RNG key threading through every stochastic operator.
Ecosystem Integration
Works with Flax, Optax, Orbax, HuggingFace Datasets, TensorFlow Datasets, and ArrayRecord. Built-in benchmarking against 12+ frameworks.
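The Feistel-cipher shuffle behind the deterministic-reproducibility feature can be sketched in a few lines: a keyed Feistel network is a deterministic, invertible permutation of indices, so a global shuffle needs no materialized index list. This is an illustrative sketch, not Grain's implementation (which, for example, handles dataset sizes that are not powers of two):

```python
def feistel_permute(index, keys, n_bits):
    """Deterministically map `index` to a unique position in [0, 2**n_bits)
    with a balanced Feistel network (one round per key). The mapping is
    invertible and uses O(1) memory: no shuffled index list is stored."""
    half = n_bits // 2
    mask = (1 << half) - 1
    left, right = index >> half, index & mask
    for key in keys:
        # Round function: any deterministic mix of (right half, key) works.
        left, right = right, left ^ (hash((right, key)) & mask)
    return (left << half) | right

# A different key tuple per epoch yields a different order over the same indices.
perm = [feistel_permute(i, keys=(3, 1, 4), n_bits=10) for i in range(1024)]
```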
Use Cases
High-throughput genomics data processing and quality control
Single-cell RNA-seq analysis pipelines with millions of cells
Gradient-based augmentation search replacing RL-based methods like AutoAugment
Task-optimized preprocessing by backpropagating task loss through processing stages
Multi-omics data integration and harmonization
Imaging data preprocessing for microscopy and spatial transcriptomics
Feature engineering for ML models in drug discovery
Distributed training data pipelines across GPUs and TPUs
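The distributed use case above builds on JAX's sharding machinery; a minimal sketch of placing a batch on a data-parallel device mesh in plain JAX (not the `ArraySharder` API) looks like this:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all local devices and shard the batch axis across it.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))

batch = jnp.arange(32.0).reshape(8, 4)     # (batch, features)
sharded = jax.device_put(batch, sharding)  # each device holds a batch slice
total = jnp.sum(sharded)                   # computation runs on the shards
```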
Installation
# Basic installation
pip install datarax
# With data loading support (HuggingFace, TFDS, audio/image libs)
pip install datarax[data]
# With GPU support (CUDA 12)
pip install datarax[gpu]
# Full development installation
pip install datarax[all]

Quick Start
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from datarax import from_source
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
# Create in-memory data source
data = {
    "image": np.random.randint(0, 255, (1000, 28, 28, 1)).astype(np.float32),
    "label": np.random.randint(0, 10, (1000,)).astype(np.int32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))
# Build pipeline with the DAG-based API
# (`normalizer` and `augmenter` are ElementOperator instances;
# their construction is omitted here)
pipeline = (
    from_source(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)
# Process batches
for i, batch in enumerate(pipeline):
    print(f"Batch {i}: images {batch['image'].shape}")

Built With
Ready to Get Started?
Explore the documentation, try examples, or contribute to the project.