[Bug]: Too many values to unpack in dispatch_cpu_unquantized_gemm [LiquidAI/LFM2] #25771

@littlechicks

Your current environment

The output of python collect_env.py

Collecting environment information...

==============================
System Info

OS : Debian GNU/Linux 12 (bookworm) (x86_64)
GCC version : (Debian 12.2.0-14+deb12u1) 12.2.0
Clang version : Could not collect
CMake version : version 4.1.0
Libc version : glibc-2.36

==============================
PyTorch Info

PyTorch version : 2.8.0+cpu
Is debug build : False
CUDA used to build PyTorch : None
ROCM used to build PyTorch : N/A

==============================
Python Environment

Python version : 3.12.11 (main, Jul 23 2025, 00:34:44) [Clang 20.1.4 ] (64-bit runtime)
Python platform : Linux-6.1.0-18-amd64-x86_64-with-glibc2.36

==============================
CUDA / GPU Info

Is CUDA available : False
CUDA runtime version : No CUDA
CUDA_MODULE_LOADING set to : N/A
GPU models and configuration : No CUDA
Nvidia driver version : No CUDA
cuDNN version : No CUDA
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True

==============================
CPU Info

Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 45 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: GenuineIntel
BIOS Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU E5-2630 v3 @ 2.40GHz
BIOS Model name: Intel(R) Xeon(R) CPU E5-2630 v3 @ 2.40GHz CPU @ 2.4GHz
BIOS CPU family: 2
CPU family: 6
Model: 63
Thread(s) per core: 1
Core(s) per socket: 8
Socket(s): 2
Stepping: 0
BogoMIPS: 4794.44
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 invpcid xsaveopt arat md_clear flush_l1d arch_capabilities
Hypervisor vendor: VMware
Virtualization type: full
L1d cache: 512 KiB (16 instances)
L1i cache: 512 KiB (16 instances)
L2 cache: 4 MiB (16 instances)
L3 cache: 40 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-7
NUMA node1 CPU(s): 8-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX unsupported
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

==============================
Versions of relevant libraries

[pip3] intel_extension_for_pytorch==2.8.0
[pip3] numpy==2.2.6
[pip3] pyzmq==27.1.0
[pip3] torch==2.8.0+cpu
[pip3] torchaudio==2.8.0+cpu
[pip3] torchvision==0.23.0+cpu
[pip3] transformers==4.56.2
[pip3] triton==3.2.0
[conda] Could not collect

==============================
vLLM Info

ROCM Version : Could not collect
vLLM Version : 0.11.0rc2.dev153+gdb1e42f62 (git sha: db1e42f62)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
Could not collect

==============================
Environment Variables

VLLM_TARGET_DEVICE=cpu
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1

🐛 Describe the bug

I installed vLLM by following the official documentation for a Python install on a CPU-only machine. I also tried serving the model directly with the CPU Docker build, with the same result.

Here is an example test script (vLLM CPU).

from vllm import LLM

def main():
    llm = LLM(model="LiquidAI/LFM2-1.2B", dtype="float32")
    prompt = "Bonjour, comment ça va ?"
    for output in llm.generate(prompt):
        print(output.outputs[0].text)  # RequestOutput keeps its completions in .outputs

if __name__ == "__main__":
    main()

Here is the output.

python test.py
[W926 18:06:18.057966816 OperatorEntry.cpp:218] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
    registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: AutocastCPU
  previous kernel: registered at /pytorch/aten/src/ATen/autocast_mode.cpp:327
       new kernel: registered at /opt/workspace/ipex-cpu-dev/csrc/cpu/autocast/autocast_mode.cpp:112 (function operator())
INFO 09-26 18:06:21 [__init__.py:216] Automatically detected platform cpu.
INFO 09-26 18:06:23 [utils.py:233] non-default args: {'dtype': 'float32', 'disable_log_stats': True, 'model': 'LiquidAI/LFM2-1.2B'}
INFO 09-26 18:06:23 [model.py:544] Resolved architecture: Lfm2ForCausalLM
`torch_dtype` is deprecated! Use `dtype` instead!
INFO 09-26 18:06:23 [model.py:1724] Upcasting torch.bfloat16 to torch.float32.
INFO 09-26 18:06:23 [model.py:1507] Using max model len 128000
WARNING 09-26 18:06:23 [logger.py:72] Environment variable VLLM_CPU_KVCACHE_SPACE (GiB) for CPU backend is not set, using 4 by default.
INFO 09-26 18:06:23 [scheduler.py:205] Chunked prefill is enabled with max_num_batched_tokens=4096.
INFO 09-26 18:06:23 [config.py:297] Hybrid or mamba-based model detected: disabling prefix caching since it is not yet supported.
INFO 09-26 18:06:23 [importing.py:43] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
INFO 09-26 18:06:23 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
INFO 09-26 18:06:23 [config.py:377] Setting attention block size to 16 tokens to ensure that attention page size is >= mamba page size.
INFO 09-26 18:06:23 [config.py:398] Padding mamba page size by 300.00% to ensure that mamba page size and attention page size are exactly equal.
[W926 18:06:28.717312259 OperatorEntry.cpp:218] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor
    registered at /pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: AutocastCPU
  previous kernel: registered at /pytorch/aten/src/ATen/autocast_mode.cpp:327
       new kernel: registered at /opt/workspace/ipex-cpu-dev/csrc/cpu/autocast/autocast_mode.cpp:112 (function operator())
INFO 09-26 18:06:31 [__init__.py:216] Automatically detected platform cpu.
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:33 [core.py:644] Waiting for init message from front-end.
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:33 [core.py:77] Initializing a V1 LLM engine (v0.11.0rc2.dev153+gdb1e42f62) with config: model='LiquidAI/LFM2-1.2B', speculative_config=None, tokenizer='LiquidAI/LFM2-1.2B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float32, max_seq_len=128000, download_dir=None, load_format=auto, tensor_parallel_size=1, pipeline_parallel_size=1, data_parallel_size=1, disable_custom_all_reduce=True, quantization=None, enforce_eager=False, kv_cache_dtype=auto, device_config=cpu, structured_outputs_config=StructuredOutputsConfig(backend='auto', disable_fallback=False, disable_any_whitespace=False, disable_additional_properties=False, reasoning_parser=''), observability_config=ObservabilityConfig(show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None), seed=0, served_model_name=LiquidAI/LFM2-1.2B, enable_prefix_caching=False, chunked_prefill_enabled=True, pooler_config=None, compilation_config={"level":2,"debug_dump_path":"","cache_dir":"","backend":"inductor","custom_ops":["none"],"splitting_ops":null,"use_inductor":true,"compile_sizes":null,"inductor_compile_config":{"enable_auto_functionalized_v2":false,"dce":true,"size_asserts":false,"nan_asserts":false,"epilogue_fusion":true},"inductor_passes":{},"cudagraph_mode":0,"use_cudagraph":true,"cudagraph_num_of_warmups":0,"cudagraph_capture_sizes":[],"cudagraph_copy_inputs":false,"full_cuda_graph":false,"use_inductor_graph_partition":false,"pass_config":{},"max_capture_size":null,"local_cache_dir":null}
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:33 [importing.py:43] Triton is installed but 0 active driver(s) found (expected 1). Disabling Triton to prevent runtime errors.
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:33 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.
(EngineCore_DP0 pid=971069) WARNING 09-26 18:06:33 [_logger.py:72] Pin memory is not supported on CPU.
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:33 [cpu_worker.py:154] auto thread-binding list (id, physical core): [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5), (6, 6), (7, 7)]
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66] OMP threads binding of Process 971069:
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971069, core 0
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971114, core 1
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971115, core 2
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971116, core 3
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971117, core 4
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971118, core 5
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971119, core 6
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]      OMP tid: 971120, core 7
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_worker.py:66]
[W926 18:06:34.229932146 ProcessGroupGloo.cpp:514] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
[Gloo] Rank 0 is connected to 0 peer ranks. Expected number of connected peer ranks is : 0
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [parallel_state.py:1201] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0, EP rank 0
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu_model_runner.py:106] Starting to load model LiquidAI/LFM2-1.2B...
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:34 [cpu.py:101] Using Torch SDPA backend.
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:35 [weight_utils.py:392] Using model weights format ['*.safetensors']
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:35 [weight_utils.py:450] No model.safetensors.index.json found in remote.
Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.84it/s]
Loading safetensors checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  1.84it/s]
(EngineCore_DP0 pid=971069)
(EngineCore_DP0 pid=971069) INFO 09-26 18:06:36 [default_loader.py:267] Loading weights took 0.77 seconds
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708] EngineCore failed to start.
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708] Traceback (most recent call last):
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 699, in run_engine_core
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 498, in __init__
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     super().__init__(vllm_config, executor_class, log_stats,
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 83, in __init__
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/executor/executor_base.py", line 54, in __init__
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     self._init_executor()
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/executor/uniproc_executor.py", line 55, in _init_executor
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     self.collective_rpc("load_model")
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/executor/uniproc_executor.py", line 83, in collective_rpc
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     return [run_method(self.driver_worker, method, args, kwargs)]
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/utils/__init__.py", line 3120, in run_method
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     return func(*args, **kwargs)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/worker/gpu_worker.py", line 213, in load_model
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/worker/cpu_model_runner.py", line 107, in load_model
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     self.model = get_model(vllm_config=self.vllm_config)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/model_loader/__init__.py", line 119, in get_model
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     return loader.load_model(vllm_config=vllm_config,
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/model_loader/base_loader.py", line 51, in load_model
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     process_weights_after_loading(model, model_config, target_device)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/model_loader/utils.py", line 112, in process_weights_after_loading
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     quant_method.process_weights_after_loading(module)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/layers/linear.py", line 229, in process_weights_after_loading
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/layers/utils.py", line 153, in dispatch_cpu_unquantized_gemm
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     N, K = layer.weight.size()
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708]     ^^^^
(EngineCore_DP0 pid=971069) ERROR 09-26 18:06:36 [core.py:708] ValueError: too many values to unpack (expected 2)
(EngineCore_DP0 pid=971069) Process EngineCore_DP0:
(EngineCore_DP0 pid=971069) Traceback (most recent call last):
(EngineCore_DP0 pid=971069)   File "/root/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_DP0 pid=971069)     self.run()
(EngineCore_DP0 pid=971069)   File "/root/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_DP0 pid=971069)     self._target(*self._args, **self._kwargs)
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 712, in run_engine_core
(EngineCore_DP0 pid=971069)     raise e
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 699, in run_engine_core
(EngineCore_DP0 pid=971069)     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_DP0 pid=971069)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 498, in __init__
(EngineCore_DP0 pid=971069)     super().__init__(vllm_config, executor_class, log_stats,
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core.py", line 83, in __init__
(EngineCore_DP0 pid=971069)     self.model_executor = executor_class(vllm_config)
(EngineCore_DP0 pid=971069)                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/executor/executor_base.py", line 54, in __init__
(EngineCore_DP0 pid=971069)     self._init_executor()
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/executor/uniproc_executor.py", line 55, in _init_executor
(EngineCore_DP0 pid=971069)     self.collective_rpc("load_model")
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/executor/uniproc_executor.py", line 83, in collective_rpc
(EngineCore_DP0 pid=971069)     return [run_method(self.driver_worker, method, args, kwargs)]
(EngineCore_DP0 pid=971069)             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/utils/__init__.py", line 3120, in run_method
(EngineCore_DP0 pid=971069)     return func(*args, **kwargs)
(EngineCore_DP0 pid=971069)            ^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/worker/gpu_worker.py", line 213, in load_model
(EngineCore_DP0 pid=971069)     self.model_runner.load_model(eep_scale_up=eep_scale_up)
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/worker/cpu_model_runner.py", line 107, in load_model
(EngineCore_DP0 pid=971069)     self.model = get_model(vllm_config=self.vllm_config)
(EngineCore_DP0 pid=971069)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/model_loader/__init__.py", line 119, in get_model
(EngineCore_DP0 pid=971069)     return loader.load_model(vllm_config=vllm_config,
(EngineCore_DP0 pid=971069)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/model_loader/base_loader.py", line 51, in load_model
(EngineCore_DP0 pid=971069)     process_weights_after_loading(model, model_config, target_device)
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/model_loader/utils.py", line 112, in process_weights_after_loading
(EngineCore_DP0 pid=971069)     quant_method.process_weights_after_loading(module)
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/layers/linear.py", line 229, in process_weights_after_loading
(EngineCore_DP0 pid=971069)     dispatch_cpu_unquantized_gemm(layer, remove_weight=True)
(EngineCore_DP0 pid=971069)   File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/model_executor/layers/utils.py", line 153, in dispatch_cpu_unquantized_gemm
(EngineCore_DP0 pid=971069)     N, K = layer.weight.size()
(EngineCore_DP0 pid=971069)     ^^^^
(EngineCore_DP0 pid=971069) ValueError: too many values to unpack (expected 2)
Traceback (most recent call last):
  File "/opt/testingliquid/test.py", line 10, in <module>
    main()
  File "/opt/testingliquid/test.py", line 4, in main
    llm = LLM(model="LiquidAI/LFM2-1.2B", dtype="float32")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/entrypoints/llm.py", line 293, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/llm_engine.py", line 177, in from_engine_args
    return cls(vllm_config=vllm_config,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/llm_engine.py", line 114, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core_client.py", line 80, in make_client
    return SyncMPClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core_client.py", line 602, in __init__
    super().__init__(
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/core_client.py", line 448, in __init__
    with launch_core_engines(vllm_config, executor_class,
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/.local/share/uv/python/cpython-3.12.11-linux-x86_64-gnu/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/utils.py", line 732, in launch_core_engines
    wait_for_engine_startup(
  File "/opt/liquid/.venv/lib/python3.12/site-packages/vllm-0.11.0rc2.dev153+gdb1e42f62.cpu-py3.12-linux-x86_64.egg/vllm/v1/engine/utils.py", line 785, in wait_for_engine_startup
    raise RuntimeError("Engine core initialization failed. "
RuntimeError: Engine core initialization failed. See root cause above. Failed core proc(s): {'EngineCore_DP0': 1}

I get the same error with all model variants.
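For readers following the traceback: the failing statement `N, K = layer.weight.size()` only works when the weight has exactly two dimensions, so the error means `size()` returned more than two values for at least one layer. Below is a minimal sketch of that failure using plain tuples in place of tensor shapes, so it runs without PyTorch. The shapes and the helper `unpack_gemm_dims` are hypothetical illustrations, not vLLM code and not the actual LFM2 layer shapes.

```python
def unpack_gemm_dims(shape):
    """Return (N, K) for a 2-D weight shape, or None for any other rank.

    Hypothetical helper: it only illustrates a rank check placed before
    the two-value unpacking that fails in the traceback.
    """
    if len(shape) != 2:
        return None
    n, k = shape
    return n, k


# A 2-D weight shape (the usual linear-layer case) unpacks without error.
linear_shape = (2048, 2048)  # assumed shape, for illustration only
print(unpack_gemm_dims(linear_shape))

# A 3-D shape (assumed here, e.g. a convolution-style weight) cannot be
# unpacked into two names and raises ValueError, as in the report.
three_d_shape = (2048, 1, 3)  # assumed shape, for illustration only
try:
    N, K = three_d_shape
except ValueError as exc:
    print(f"ValueError: {exc}")

# With the rank check, the non-2-D case is detected instead of raising.
print(unpack_gemm_dims(three_d_shape))
```

Whether the right fix is to skip such layers, reshape their weights, or route them to a different code path is a decision for the maintainers; the sketch only shows why the unpacking fails.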

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
