Interview Prep: The ML Grind
Important: don't skip this, don't reorder, don't try to be too creative, follow the sequence. It will save you weeks of struggling.
The Tricks
- Learn "einsum"
This eliminates 90% of the "why the hell is this reshape/view doing that?" confusion. It lets you write the math like the math, not like a puzzle of indexing tricks.
- Use Noam's suffix notation
Print the suffix table or write it by hand. Keep it next to you. The first time you implement a transformer block from scratch, the dimension chaos will break you unless you have this. Suffixes make everything obvious and mechanical.
These two together are learning multipliers. I genuinely would not have made it far without being pointed at both. Treat them as required, not "nice to have".
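To make the two tricks concrete, here is a minimal sketch that combines them. The suffix letters (B, L, D, F) and the sizes are picked for illustration only; use whatever table you keep next to you.

```python
import numpy as np

# Shape suffixes (illustrative choice of letters, not a fixed standard):
# B = batch, L = sequence length, D = model dim, F = hidden dim
B, L, D, F = 2, 5, 8, 16
rng = np.random.default_rng(0)

x_BLD = rng.standard_normal((B, L, D))
w_DF = rng.standard_normal((D, F))

# einsum reads like the math: "d" appears in both inputs but not in the
# output, so it is the dimension being summed over.
h_BLF = np.einsum("bld,df->blf", x_BLD, w_DF)

# The same computation written with manual reshapes, for comparison.
h_ref_BLF = (x_BLD.reshape(B * L, D) @ w_DF).reshape(B, L, F)

print(h_BLF.shape)
print(np.allclose(h_BLF, h_ref_BLF))
```

With the suffix on every name, a shape mismatch is visible in the variable names before you ever run the code.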
Practice
You will have to do both
- Slow & painful
  - Ask your favorite chatbot for a coding exercise, a set of test cases, and a .py scaffold with empty function definitions, using only NumPy and the Python standard library.
  - Drop it into your editor and implement everything yourself.
  - For the first couple of attempts, ask for small nudges if needed. After that, no references. Grind it out.
  - After a full implementation, send it back to the chatbot and ask it to validate the code and give feedback.
- Fast & painful
  - Pen + paper, every morning. Rewrite the core reference material from memory.
  - Mark mistakes and think about why you made them.
  - Repeat the next morning. This builds the recall you need in real interviews.
  - Highly recommended: a multi-color pen and loose A4/US-letter sheets with clamps. Easier to reorganize than a notebook and faster to flip through.
Core - Pure NumPy, Implement FWD + BWD
- 2-layer MLP with ReLU
- Add batch dimension
- Row-sharded (tensor-parallel) MLP
- Sharded column-parallel MLP
- Data-parallel MLP
These lock in your intuition for matmul flow, activation flow, and gradient flow. If you can't do these cleanly, everything later will feel like black magic.
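As a reference point for the first exercises, here is one possible sketch of the batched 2-layer MLP with a hand-written backward pass, a finite-difference gradient check, and a simulated tensor-parallel forward. The squared-error style loss, the sizes, and the choice of column-sharding the first layer and row-sharding the second are assumptions made for this example.

```python
import numpy as np

def mlp_forward(x_BD, w1_DF, w2_FD):
    z_BF = x_BD @ w1_DF               # pre-activation
    h_BF = np.maximum(z_BF, 0.0)      # ReLU
    y_BD = h_BF @ w2_FD
    return y_BD, (x_BD, z_BF, h_BF)

def mlp_backward(dy_BD, cache, w1_DF, w2_FD):
    x_BD, z_BF, h_BF = cache
    dw2_FD = h_BF.T @ dy_BD
    dh_BF = dy_BD @ w2_FD.T
    dz_BF = dh_BF * (z_BF > 0)        # ReLU passes gradient only where z > 0
    dw1_DF = x_BD.T @ dz_BF
    dx_BD = dz_BF @ w1_DF.T
    return dx_BD, dw1_DF, dw2_FD

rng = np.random.default_rng(0)
B, D, F = 4, 3, 8
x_BD = rng.standard_normal((B, D))
w1_DF = rng.standard_normal((D, F))
w2_FD = rng.standard_normal((F, D))

def loss_fn(w1):
    # Assumed toy loss: 0.5 * sum(y^2), so dL/dy = y.
    y, _ = mlp_forward(x_BD, w1, w2_FD)
    return 0.5 * np.sum(y ** 2)

y_BD, cache = mlp_forward(x_BD, w1_DF, w2_FD)
dx_BD, dw1_DF, dw2_FD = mlp_backward(y_BD, cache, w1_DF, w2_FD)

# Central finite differences on every entry of w1 to check the backward pass.
eps = 1e-6
num_dw1_DF = np.zeros_like(w1_DF)
for i in range(D):
    for j in range(F):
        bump = np.zeros_like(w1_DF)
        bump[i, j] = eps
        num_dw1_DF[i, j] = (loss_fn(w1_DF + bump) - loss_fn(w1_DF - bump)) / (2 * eps)
grad_ok = np.allclose(dw1_DF, num_dw1_DF, rtol=1e-4, atol=1e-5)

# Simulated tensor parallelism on one machine: shard w1 by columns and w2 by
# rows; each "rank" produces a partial output, and the sum stands in for the
# all-reduce.
n_shards = 2
w1_shards = np.split(w1_DF, n_shards, axis=1)
w2_shards = np.split(w2_FD, n_shards, axis=0)
partials = [np.maximum(x_BD @ w1_s, 0.0) @ w2_s
            for w1_s, w2_s in zip(w1_shards, w2_shards)]
y_tp_BD = sum(partials)
tp_ok = np.allclose(y_tp_BD, y_BD)

print(grad_ok, tp_ok)
```

The sharded version works without communication between the two layers because ReLU is elementwise, so each shard of the hidden dimension can be activated independently.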
Do the implementation if your role is more ML-heavy
- Implement full multi-head attention with suffix notation + einsum
- Fit the whole thing on one sheet: QKV projections, attention scores, heavy ops, FLOPs estimates.
- Be able to reason clearly about:
  - GQA
  - MLA
  - RoPE and positional embeddings in general
- Inference
  - KV cache
  - Understand speculative decoding and the perf improvement you can get out of it.
  - Continuous batching pipelines
Infra-leaning roles usually don't require writing this from scratch but you need to understand their impact on e2e performance.
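For the attention exercise, a forward-only sketch in NumPy with einsum and shape suffixes might look like the following. The suffix letters (B batch, L query positions, S key positions, D model dim, H heads, K per-head dim), the causal mask, and the absence of biases are choices made for this example, not requirements.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mha(x_BLD, wq_DHK, wk_DHK, wv_DHK, wo_HKD):
    K = wq_DHK.shape[-1]
    L = x_BLD.shape[1]
    # QKV projections (self-attention, so keys/values come from the same x).
    q_BLHK = np.einsum("bld,dhk->blhk", x_BLD, wq_DHK)
    k_BSHK = np.einsum("bsd,dhk->bshk", x_BLD, wk_DHK)
    v_BSHK = np.einsum("bsd,dhk->bshk", x_BLD, wv_DHK)
    # Scaled dot-product scores per head.
    scores_BHLS = np.einsum("blhk,bshk->bhls", q_BLHK, k_BSHK) / np.sqrt(K)
    # Causal mask: position l may only attend to positions s <= l.
    causal_LS = np.tril(np.ones((L, L), dtype=bool))
    scores_BHLS = np.where(causal_LS, scores_BHLS, -np.inf)
    probs_BHLS = softmax(scores_BHLS, axis=-1)
    # Weighted sum of values, then output projection back to D.
    ctx_BLHK = np.einsum("bhls,bshk->blhk", probs_BHLS, v_BSHK)
    y_BLD = np.einsum("blhk,hkd->bld", ctx_BLHK, wo_HKD)
    return y_BLD, probs_BHLS

rng = np.random.default_rng(0)
B, L, D, H, K = 2, 6, 16, 4, 4
x_BLD = rng.standard_normal((B, L, D))
wq_DHK = rng.standard_normal((D, H, K))
wk_DHK = rng.standard_normal((D, H, K))
wv_DHK = rng.standard_normal((D, H, K))
wo_HKD = rng.standard_normal((H, K, D))

y_BLD, probs_BHLS = mha(x_BLD, wq_DHK, wk_DHK, wv_DHK, wo_HKD)
print(y_BLD.shape, probs_BHLS.shape)
```

Keeping heads as their own axis in the weights (D, H, K) avoids the split-heads and merge-heads reshapes entirely, which is where einsum pays off most.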
As you work through the exercises above, make sure you can answer the actual interview questions that tie everything together:
- Given model dimensions, how many tokens should you train on?
  - Chinchilla-optimal: roughly 20 tokens per model parameter.
  - Figure out how to translate this to MoE params.
- Given model size + token count → how many GPUs + days to train?
  - FLOPs of a single transformer block (FWD); multiply by the number of blocks. Know the relationships:
    - BWD = 2×FWD
    - Full step = FWD + BWD = 3×FWD
- Given real hardware constraints → how would you shard the model?
  - Tensor Parallel
  - FSDP (ZeRO 1 / 2 / 3)
  - Pipeline Parallel
  - Gradient Checkpointing
  - Expert Parallel
  - Context Parallel
  - Communication math you need to internalize:
    - Which collectives each parallelism uses
    - The exact tensor shapes passed through those collectives
    - Bytes moved per rank
    - The ability to estimate per-collective cost without looking anything up
  - Explain the trade-offs of the various parallelisms. Why is FSDP ZeRO-3 rarely used in practice?
- Inference
  - Latency for prefill vs decode given the model dimensions.
  - How many H100s will you need to serve 10k users?
- RL training
  - Reason through async vs sync training.
  - Estimate throughput for both cases.
- More ML system design interview questions can be found here: huyenchip.com
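The "how many GPUs + days" question can be sketched as back-of-the-envelope arithmetic. The version below assumes the common approximation that a forward pass costs about 2 FLOPs per parameter per token (ignoring attention-score FLOPs), which combined with Full step = 3×FWD gives roughly 6 × params × tokens. The per-GPU peak and the utilization figure are hypothetical round numbers, not measurements of any specific hardware.

```python
def chinchilla_tokens(n_params, tokens_per_param=20.0):
    # Rule of thumb from the list above: ~20 tokens per parameter.
    return tokens_per_param * n_params

def train_flops(n_params, n_tokens):
    fwd = 2.0 * n_params * n_tokens   # assumed: ~2 FLOPs per param per token
    return 3.0 * fwd                  # FWD + BWD = 3 x FWD

def train_days(total_flops, n_gpus, peak_flops_per_gpu, mfu):
    # mfu = fraction of peak FLOP/s actually achieved (assumed value).
    seconds = total_flops / (n_gpus * peak_flops_per_gpu * mfu)
    return seconds / 86400.0

n_params = 7e9                                   # hypothetical 7B dense model
n_tokens = chinchilla_tokens(n_params)           # 1.4e11 tokens
flops = train_flops(n_params, n_tokens)          # 6 * 7e9 * 1.4e11
days = train_days(flops, n_gpus=64,
                  peak_flops_per_gpu=1e15,       # hypothetical round number
                  mfu=0.4)                       # hypothetical utilization

print(f"{n_tokens:.3g} tokens, {flops:.3g} FLOPs, {days:.2f} days")
```

In an interview the useful part is stating each assumption out loud (FLOPs per token, peak throughput, utilization) so the interviewer can swap in their own numbers.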
One-Pager References - don't underestimate these.
Make lots of small, dense one-pagers - color-coded, quick to flip through:
- suffix notation
- per-layer FLOPs
- the entire attention block
- each parallelism mode and its collectives
- byte-movement formulas
- memory breakdowns (params, grads, activations)
This accelerates repetition and makes the whole grind way faster.
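A byte-movement one-pager could start from something like this sketch. It assumes ring-style collectives, where each rank sends (n - 1) / n of the full tensor for an all-gather or reduce-scatter and twice that for an all-reduce. The function names, the link bandwidth, and the model size are hypothetical, and latency terms are ignored.

```python
def ring_all_gather_bytes(total_bytes, n_ranks):
    # total_bytes is the size of the full (gathered) tensor.
    return (n_ranks - 1) / n_ranks * total_bytes

def ring_reduce_scatter_bytes(total_bytes, n_ranks):
    # total_bytes is the size of the full (unreduced) tensor on each rank.
    return (n_ranks - 1) / n_ranks * total_bytes

def ring_all_reduce_bytes(total_bytes, n_ranks):
    # A ring all-reduce is a reduce-scatter followed by an all-gather.
    return (ring_reduce_scatter_bytes(total_bytes, n_ranks)
            + ring_all_gather_bytes(total_bytes, n_ranks))

def transfer_seconds(bytes_per_rank, link_bytes_per_s):
    # Bandwidth-only estimate; ignores per-message latency.
    return bytes_per_rank / link_bytes_per_s

# Example: data-parallel gradient all-reduce for a hypothetical 1B-param
# model with 2-byte gradients across 8 ranks.
grad_bytes = 1e9 * 2
sent = ring_all_reduce_bytes(grad_bytes, n_ranks=8)     # 2 * 7/8 * 2e9
secs = transfer_seconds(sent, link_bytes_per_s=100e9)   # hypothetical link

print(f"{sent:.3g} bytes per rank, {secs:.3f} s")
```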
NumPy Broadcasting
- Align dimensions from the right
- Fill missing leading dimensions with 1
- Each pair of aligned dims must be equal or contain a 1; the resulting dim is the max of the two
Shape A (9, 1, 3): 9 x 1 x 3
Shape B (4, 1): 1 x 4 x 1
After broadcast: 9 x 4 x 3
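The example above can be checked directly in NumPy. The incompatible pair at the end is an extra case added here for illustration.

```python
import numpy as np

a = np.zeros((9, 1, 3))
b = np.zeros((4, 1))

# Shape-only check, then the shape of an actual broadcast operation.
shape_only = np.broadcast_shapes(a.shape, b.shape)
result_shape = (a + b).shape
print(shape_only, result_shape)

# Aligned dims 2 and 4 are neither equal nor 1, so this pair cannot broadcast.
try:
    np.broadcast_shapes((9, 2, 3), (4, 1))
    incompatible_raised = False
except ValueError:
    incompatible_raised = True
print(incompatible_raised)
```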
Matrix Multiplication
"..." denotes any additional leading (batch) dimensions, which follow the broadcasting rules above.
| Expression | Result |
|---|---|
| (..., M, K) @ (..., K, N) | (..., M, N) |
| (K,) @ (K,) | scalar |
| (M, K) @ (K,) | (M,) |
| (K,) @ (K, N) | (N,) |
| (..., M, K) @ (K,) | (..., M) |
| (K,) @ (..., K, N) | (..., N) |
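Each row of the table can be verified with a quick NumPy check. The sizes and the (5, 6) batch shape standing in for "..." are arbitrary.

```python
import numpy as np

M, K, N = 2, 3, 4
batch = (5, 6)  # stands in for "..."

mat_mat = np.shape(np.ones(batch + (M, K)) @ np.ones(batch + (K, N)))
vec_vec = np.shape(np.ones(K) @ np.ones(K))            # () means scalar
mat_vec = np.shape(np.ones((M, K)) @ np.ones(K))
vec_mat = np.shape(np.ones(K) @ np.ones((K, N)))
batched_mat_vec = np.shape(np.ones(batch + (M, K)) @ np.ones(K))
vec_batched_mat = np.shape(np.ones(K) @ np.ones(batch + (K, N)))

for name, shape in [("mat @ mat", mat_mat), ("vec @ vec", vec_vec),
                    ("mat @ vec", mat_vec), ("vec @ mat", vec_mat),
                    ("batched mat @ vec", batched_mat_vec),
                    ("vec @ batched mat", vec_batched_mat)]:
    print(name, shape)
```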
Reshape/View
In the einsum section, I mentioned you don't have to deal with reshapes, but that's not entirely true. In many interviews, you'll encounter existing codebases full of reshapes, and you'll need to know how to reason through them.
view(): returns a view of the tensor with a new shape. It requires the new shape to be compatible with the tensor's memory layout (in practice, a contiguous tensor) and raises an error otherwise.
reshape(): returns a view when possible, otherwise a copy.
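A small sketch of the difference, using a transpose to produce a non-contiguous tensor. The variable names are just for this example.

```python
import torch

x = torch.arange(12).reshape(3, 4)   # contiguous, shape [3, 4]
t = x.t()                            # transposed view, shape [4, 3], not contiguous

flat_view = x.view(-1)               # works: shares storage with x
shares_storage = flat_view.data_ptr() == x.data_ptr()

try:
    t.view(-1)                       # incompatible with t's strides
    view_failed = False
except RuntimeError:
    view_failed = True

flat_copy = t.reshape(-1)            # falls back to copying the data
is_copy = flat_copy.data_ptr() != x.data_ptr()

flat_ok = t.contiguous().view(-1)    # explicit copy first, then view is fine

print(t.is_contiguous(), shares_storage, view_failed, is_copy)
```

When reading a codebase, a `.contiguous().view(...)` pair is usually a sign that a transpose or permute happened just before it.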
Flatten
x = torch.randn(3, 4, 5)
x.reshape(-1) # Shape: [60]
Flatten, keeping the batch dim
x = torch.randn(32, 3, 224, 224) # [B, C, H, W]
x.reshape(32, -1) # Shape: [32, 150528]
Add/Remove a dim
x = torch.randn(10, 20)
x.reshape(10, 20, 1) # Add a trailing dim of size 1
x.reshape(10, 4, 5) # Split the second dim (20 = 4 x 5)
x.reshape(2, 5, 20) # Split the first dim (10 = 2 x 5)
