This demo explores how to achieve high-performance AI on Google's Tensor Processing Units (TPUs) using the JAX ecosystem, with a specific focus on image recognition workflows. We'll begin with micro-benchmarks that highlight JAX's strengths for TPU-based computation. One example is its single-program, multiple-data (SPMD) programming model: a single program is compiled by XLA and partitioned across TPU chips, while the dense matrix operations at its core run on each chip's systolic-array matrix units. The setup is designed to integrate into larger, production-grade environments, such as those running on Kubernetes.
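To make the SPMD idea concrete, here is a minimal sketch (not the demo's actual code) of one jit-compiled function whose input batch is sharded across a device mesh, timed the way a micro-benchmark typically would be. The mesh axis name `"data"`, the array shapes, and the toy linear "classifier" are hypothetical choices for illustration. The warm-up call keeps compilation out of the measurement, and `block_until_ready()` is needed because JAX dispatches work asynchronously. The same script runs on a TPU VM (one mesh entry per TPU device) or on a single CPU device.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D mesh over all visible devices (TPU devices on a TPU VM,
# a single CPU device elsewhere). The axis name "data" is arbitrary.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Hypothetical toy batch of flattened 32x32 RGB "images"; the batch size
# is a multiple of the device count so it splits evenly.
batch = 8 * len(devices)
images = jnp.ones((batch, 32 * 32 * 3), dtype=jnp.float32)
weights = jnp.full((32 * 32 * 3, 10), 0.01, dtype=jnp.float32)

# Shard the batch dimension across the "data" axis; replicate the weights.
images = jax.device_put(images, NamedSharding(mesh, P("data", None)))
weights = jax.device_put(weights, NamedSharding(mesh, P()))

@jax.jit
def logits_fn(x, w):
    # One program: XLA partitions it so each device multiplies its own
    # slice of the batch by the replicated weights.
    return x @ w

# Warm-up call so compile time is excluded from the timing below.
logits_fn(images, weights).block_until_ready()

# Wait for the result before stopping the clock (dispatch is async).
start = time.perf_counter()
out = logits_fn(images, weights).block_until_ready()
elapsed = time.perf_counter() - start

print(out.shape, out.sharding, f"{elapsed * 1e6:.1f} us")
```

Keeping the weights replicated while sharding only the batch is the simplest data-parallel layout; larger models would typically also shard parameters along a second mesh axis.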

Ravi Mahendrakar
Ravi Mahendrakar is a Product Management leader at Google, focused on ML Frameworks & Ecosystems. He has over 20 years of experience, including product roles at AWS, Aerospike, VAST Data, Pure Storage, Veritas, and IBM. Ravi specializes in bringing innovative data and enterprise software solutions to market. He holds an MBA from Chicago Booth and a Master's in Computer Science from CSU Chico.