From 77cb54d9111d85be2fba9887127be44c7ed43f59 Mon Sep 17 00:00:00 2001 From: Guan-Ming Chiu Date: Mon, 12 Jan 2026 21:23:28 +0800 Subject: [PATCH] Add PyTorch shape validation --- qdp/qdp-python/src/lib.rs | 66 ++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs index ba9bbd3ac..686795c5e 100644 --- a/qdp/qdp-python/src/lib.rs +++ b/qdp/qdp-python/src/lib.rs @@ -263,7 +263,7 @@ impl QdpEngine { match ndim { 1 => { - // 1D array: single sample encoding + // 1D array: single sample encoding (zero-copy if already contiguous) let array_1d = data.extract::>().map_err(|_| { PyRuntimeError::new_err( "Failed to extract 1D NumPy array. Ensure dtype is float64.", @@ -282,7 +282,7 @@ impl QdpEngine { }); } 2 => { - // 2D array: batch encoding + // 2D array: batch encoding (zero-copy if already contiguous) let array_2d = data.extract::>().map_err(|_| { PyRuntimeError::new_err( "Failed to extract 2D NumPy array. Ensure dtype is float64.", @@ -322,18 +322,56 @@ impl QdpEngine { // Check if it's a PyTorch tensor if is_pytorch_tensor(data)? { validate_tensor(data)?; - let vec_data: Vec = data - .call_method0("flatten")? - .call_method0("tolist")? - .extract()?; - let ptr = self - .engine - .encode(&vec_data, num_qubits, encoding_method) - .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed: {}", e)))?; - return Ok(QuantumTensor { - ptr, - consumed: false, - }); + // NOTE(perf): `tolist()` + `extract()` makes extra copies (Tensor -> Python list -> Vec). + // TODO: Follow-up PR can use `numpy()`/buffer protocol (and possibly pinned host memory) + // to reduce copy overhead. + let ndim: usize = data.call_method0("dim")?.extract()?; + + match ndim { + 1 => { + // 1D tensor: single sample encoding + let vec_data: Vec = data.call_method0("tolist")?.extract()?; + let ptr = self + .engine + .encode(&vec_data, num_qubits, encoding_method) + .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed: {}", e)))?; + return Ok(QuantumTensor { + ptr, + consumed: false, + }); + } + 2 => { + // 2D tensor: batch encoding + let shape: Vec = data.getattr("shape")?.extract()?; + let num_samples = shape[0] as usize; + let sample_size = shape[1] as usize; + let vec_data: Vec = data + .call_method0("flatten")? + .call_method0("tolist")? + .extract()?; + let ptr = self + .engine + .encode_batch( + &vec_data, + num_samples, + sample_size, + num_qubits, + encoding_method, + ) + .map_err(|e| PyRuntimeError::new_err(format!("Encoding failed: {}", e)))?; + return Ok(QuantumTensor { + ptr, + consumed: false, + }); + } + _ => { + return Err(PyRuntimeError::new_err(format!( + "Unsupported tensor shape: {}D. Expected 1D tensor for single sample \ + encoding or 2D tensor (batch_size, features) for batch encoding.", + ndim + ))); + } + } } // Fallback: try to extract as Vec (Python list)