diff --git a/README.md b/README.md index cb81098..99ccaf9 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,8 @@ To more complete installation instructions and usage, please refer to the [Insta 4. **Optional Optimization**: Enable [custom kernels](docs/kernels.md) for faster inference and reduced memory usage +5. **Custom MSA Servers**: For enterprise or private deployments requiring authentication, see the [MSA Server Authentication](docs/usage.md#msa-server-authentication) section in our usage guide. + For comprehensive usage instructions and examples, refer to the [Usage Guide](docs/usage.md). diff --git a/docs/usage.md b/docs/usage.md index e44710f..bfa301c 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -26,6 +26,43 @@ intellifold predict ./examples/5S8I_A.yaml --out_dir ./output --only_run_data_pr ``` +## MSA Server Authentication + +When using custom MSA servers that require authentication (such as enterprise or private deployments), IntelliFold supports both basic authentication and API key authentication. 
+ +### Quick Setup + +**For API Key Authentication (recommended):** +```bash +# Set environment variable (secure) +export MSA_API_KEY_VALUE=your-api-key + +# Run with custom MSA server +intellifold predict input.yaml --out_dir ./output --use_msa_server \ + --msa_server_url https://your-msa-server.com \ + --api_key_header X-API-Key +``` + +**For Basic Authentication:** +```bash +# Set environment variables (secure) +export MSA_USERNAME=your-username +export MSA_PASSWORD=your-password + +# Run with custom MSA server +intellifold predict input.yaml --out_dir ./output --use_msa_server \ + --msa_server_url https://your-msa-server.com +``` + +### Authentication Options + +- **Environment variables** (recommended for security): `MSA_API_KEY_VALUE`, `MSA_USERNAME`, `MSA_PASSWORD` +- **Custom headers**: Use `--api_key_header` for different API key header names (e.g., `X-Gravitee-Api-Key`) +- **Multiple auth types**: Cannot use both basic auth and API key simultaneously +- **Backward compatibility**: All authentication is optional - existing workflows continue to work + +The public ColabFold server (`https://api.colabfold.com`) requires no authentication and remains the default. + ### Run with Bash Script The aurguments is the same as `intellifold predict`, and you can set the parameters in the script. @@ -60,6 +97,14 @@ Common arguments of this `scripts`/`intellifold predict` are explained as follow Whether to use the MMSeqs2 server for MSA generation. * `--msa_server_url` (`str`, default: `https://api.colabfold.com`) MSA server url. Used only if `--use_msa_server` is set. +* `--msa_server_username` (`str`, default: `None`) + Username for basic authentication to MSA server. Can use environment variable `MSA_USERNAME`. +* `--msa_server_password` (`str`, default: `None`) + Password for basic authentication to MSA server. Can use environment variable `MSA_PASSWORD`. +* `--api_key_header` (`str`, default: `X-API-Key`) + Header name for API key authentication to MSA server. 
+* `--api_key_value` (`str`, default: `None`) + API key value for authentication to MSA server. Can use environment variable `MSA_API_KEY_VALUE`. * `--msa_pairing_strategy` (`str`, default: `complete`) Pairing strategy to use. Used only if `--use_msa_server` is set. Options are 'greedy' and 'complete'. * `--no_pairing` (`FLAG`, default: `False`) diff --git a/intellifold/data/inference/data_tools.py b/intellifold/data/inference/data_tools.py index 755fc7d..7ca05c0 100644 --- a/intellifold/data/inference/data_tools.py +++ b/intellifold/data/inference/data_tools.py @@ -136,6 +136,10 @@ def compute_msa( msa_server_url: str, msa_pairing_strategy: str, use_pairing=True, + msa_server_username: str = None, + msa_server_password: str = None, + api_key_header: str = "X-API-Key", + api_key_value: str = None, ) -> None: """Compute the MSA for the input data. @@ -151,8 +155,20 @@ def compute_msa( The MSA server URL. msa_pairing_strategy : str The MSA pairing strategy. + use_pairing : bool, optional + Whether to use pairing, by default True. + msa_server_username : str, optional + Username for basic authentication to MSA server. + msa_server_password : str, optional + Password for basic authentication to MSA server. + api_key_header : str, optional + Header name for API key authentication, by default "X-API-Key". + api_key_value : str, optional + API key value for authentication to MSA server. 
""" + logger.info(f"Starting MSA generation for target '{target_id}' with {len(data)} sequences") + if len(data) > 1 and use_pairing: paired_msas = run_mmseqs2( list(data.values()), @@ -161,6 +177,10 @@ def compute_msa( use_pairing=True, host_url=msa_server_url, pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, ) else: paired_msas = [""] * len(data) @@ -172,6 +192,10 @@ def compute_msa( use_pairing=False, host_url=msa_server_url, pairing_strategy=msa_pairing_strategy, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, ) for idx, name in enumerate(data): @@ -328,6 +352,10 @@ def process_inputs( # noqa: C901, PLR0912, PLR0915 max_msa_seqs: int = 4096, use_msa_server: bool = False, use_pairing: bool = True, + msa_server_username: str = None, + msa_server_password: str = None, + api_key_header: str = "X-API-Key", + api_key_value: str = None, ) -> None: """Process the input data and output directory. @@ -343,6 +371,14 @@ def process_inputs( # noqa: C901, PLR0912, PLR0915 Max number of MSA sequences, by default 4096. use_msa_server : bool, optional Whether to use the MMSeqs2 server for MSA generation, by default False. + msa_server_username : str, optional + Username for basic authentication to MSA server. + msa_server_password : str, optional + Password for basic authentication to MSA server. + api_key_header : str, optional + Header name for API key authentication, by default "X-API-Key". + api_key_value : str, optional + API key value for authentication to MSA server. 
Returns ------- @@ -485,6 +521,10 @@ def process_inputs( # noqa: C901, PLR0912, PLR0915 msa_server_url=msa_server_url, msa_pairing_strategy=msa_pairing_strategy, use_pairing=use_pairing, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, ) # Parse MSA data diff --git a/intellifold/data/msa/mmseqs2.py b/intellifold/data/msa/mmseqs2.py index 92c6843..a967898 100644 --- a/intellifold/data/msa/mmseqs2.py +++ b/intellifold/data/msa/mmseqs2.py @@ -46,12 +46,54 @@ def run_mmseqs2( # noqa: PLR0912, D103, C901, PLR0915 use_pairing: bool = False, pairing_strategy: str = "greedy", host_url: str = "https://api.colabfold.com", + msa_server_username: str = None, + msa_server_password: str = None, + api_key_header: str = "X-API-Key", + api_key_value: str = None, ) -> tuple[list[str], list[str]]: + """ + Run MMSeqs2 server query for MSA generation. + + Args: + x: Input sequence(s) as string or list of strings. + prefix: Prefix for temporary files. + use_env: Whether to use environmental databases. + use_filter: Whether to use filtering. + use_pairing: Whether to use pairing mode. + pairing_strategy: Strategy for pairing ('greedy' or 'complete'). + host_url: URL of the MSA server. + msa_server_username: Username for basic authentication. + msa_server_password: Password for basic authentication. + api_key_header: Header name for API key authentication. + api_key_value: API key value for authentication. + + Returns: + Tuple of MSA results as list of strings. 
+ """ submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" + # Log MSA server information + auth_method = "no authentication" + if msa_server_username and msa_server_password: + auth_method = f"basic authentication (user: {msa_server_username})" + elif api_key_value: + auth_method = f"API key authentication (header: {api_key_header})" + + logger.info(f"Connecting to MSA server: {host_url} with {auth_method}") + # Set header agent as intellifold headers = {} headers["User-Agent"] = "intellifold" + + # Configure authentication + auth = None + if msa_server_username and msa_server_password: + # Basic authentication + from requests.auth import HTTPBasicAuth + auth = HTTPBasicAuth(msa_server_username, msa_server_password) + elif api_key_value: + # API key authentication + headers[api_key_header] = api_key_value def submit(seqs, mode, N=101): n, query = N, "" @@ -69,10 +111,15 @@ def submit(seqs, mode, N=101): data={"q": query, "mode": mode}, timeout=6.02, headers=headers, + auth=auth, ) + res.raise_for_status() # Raises HTTPError for non-2xx status codes except requests.exceptions.Timeout: logger.warning("Timeout while submitting to MSA server. Retrying...") continue + except requests.exceptions.HTTPError as e: + logger.error(f"MSA server error {res.status_code}: {res.text} - {e}") + raise Exception(f"MSA server error {res.status_code}") from e except Exception as e: error_count += 1 logger.warning( @@ -97,13 +144,17 @@ def status(ID): error_count = 0 try: res = requests.get( - f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers + f"{host_url}/ticket/{ID}", timeout=6.02, headers=headers, auth=auth ) + res.raise_for_status() # Raises HTTPError for non-2xx status codes except requests.exceptions.Timeout: logger.warning( "Timeout while fetching status from MSA server. Retrying..." 
) continue + except requests.exceptions.HTTPError as e: + logger.error(f"MSA server error {res.status_code} while checking status: {e}") + raise Exception(f"MSA server error {res.status_code}") from e except Exception as e: error_count += 1 logger.warning( @@ -127,13 +178,17 @@ def download(ID, path): while True: try: res = requests.get( - f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers + f"{host_url}/result/download/{ID}", timeout=6.02, headers=headers, auth=auth ) + res.raise_for_status() # Raises HTTPError for non-2xx status codes except requests.exceptions.Timeout: logger.warning( "Timeout while fetching result from MSA server. Retrying..." ) continue + except requests.exceptions.HTTPError as e: + logger.error(f"MSA server error {res.status_code} while downloading results: {e}") + raise Exception(f"MSA server error {res.status_code} while downloading results") from e except Exception as e: error_count += 1 logger.warning( @@ -217,7 +272,7 @@ def download(ID, path): pbar.set_description(out["status"]) while out["status"] in ["UNKNOWN", "RUNNING", "PENDING"]: t = 5 + random.randint(0, 5) - logger.error(f"Sleeping for {t}s. Reason: {out['status']}") + logger.info(f"MSA processing in progress, waiting {t}s. 
Status: {out.get('status', 'UNKNOWN')}") time.sleep(t) out = status(ID) pbar.set_description(out["status"]) diff --git a/run_intellifold.py b/run_intellifold.py index 74d50b2..d8946bc 100644 --- a/run_intellifold.py +++ b/run_intellifold.py @@ -151,6 +151,30 @@ def main(args): # # DO SOME INITIAL SETUP # #~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ init_logging() + + # Validate authentication arguments if MSA server is used + if args.use_msa_server: + # Get credentials from environment variables if not provided via CLI + msa_username = args.msa_server_username or os.getenv("MSA_USERNAME") + msa_password = args.msa_server_password or os.getenv("MSA_PASSWORD") + api_key_value = args.api_key_value or os.getenv("MSA_API_KEY_VALUE") + + # Check if both basic auth and API key are provided + has_basic_auth = msa_username and msa_password + has_api_key = api_key_value + + if has_basic_auth and has_api_key: + raise ValueError( + "Only one authentication method (basic or API key) can be used at a time. " + "Please provide either --msa_server_username/--msa_server_password OR " + "--api_key_value, but not both." + ) + + # Update args with environment variables if needed + args.msa_server_username = msa_username + args.msa_server_password = msa_password + args.api_key_value = api_key_value + seeds = args.seed seeds = list(map(int, seeds.split(","))) set_seed(seeds[0]) @@ -204,6 +228,10 @@ def main(args): msa_pairing_strategy=args.msa_pairing_strategy, max_msa_seqs=16384, use_pairing=not args.no_pairing, + msa_server_username=args.msa_server_username, + msa_server_password=args.msa_server_password, + api_key_header=args.api_key_header, + api_key_value=args.api_key_value, ) if args.return_similar_seq: compute_similar_sequence( @@ -463,6 +491,30 @@ def main(args): help="MSA server url. Used only if --use_msa_server is set.", default="https://api.colabfold.com", ) + parser.add_argument( + "--msa_server_username", + type=str, + help="Username for basic authentication to MSA server. 
Can also be set via MSA_USERNAME environment variable.", + default=None, + ) + parser.add_argument( + "--msa_server_password", + type=str, + help="Password for basic authentication to MSA server. Can also be set via MSA_PASSWORD environment variable (recommended).", + default=None, + ) + parser.add_argument( + "--api_key_header", + type=str, + help="Header name for API key authentication to MSA server.", + default="X-API-Key", + ) + parser.add_argument( + "--api_key_value", + type=str, + help="API key value for authentication to MSA server. Can also be set via MSA_API_KEY_VALUE environment variable (recommended).", + default=None, + ) parser.add_argument( "--msa_pairing_strategy", type=str, diff --git a/runner/intellifold_inference.py b/runner/intellifold_inference.py index fd9812f..5b29bc5 100644 --- a/runner/intellifold_inference.py +++ b/runner/intellifold_inference.py @@ -204,6 +204,10 @@ def main(args): msa_pairing_strategy=args.msa_pairing_strategy, max_msa_seqs=16384, use_pairing=not args.no_pairing, + msa_server_username=args.msa_server_username, + msa_server_password=args.msa_server_password, + api_key_header=args.api_key_header, + api_key_value=args.api_key_value, ) if args.return_similar_seq: compute_similar_sequence( @@ -456,6 +460,30 @@ def intellifold_cli(): help="MSA server url. Used only if --use_msa_server is set.", default="https://api.colabfold.com", ) +@click.option( + "--msa_server_username", + type=str, + help="Username for basic authentication to MSA server. Can also be set via MSA_USERNAME environment variable.", + default=None, +) +@click.option( + "--msa_server_password", + type=str, + help="Password for basic authentication to MSA server. 
Can also be set via MSA_PASSWORD environment variable (recommended).", + default=None, +) +@click.option( + "--api_key_header", + type=str, + help="Header name for API key authentication to MSA server.", + default="X-API-Key", +) +@click.option( + "--api_key_value", + type=str, + help="API key value for authentication to MSA server. Can also be set via MSA_API_KEY_VALUE environment variable (recommended).", + default=None, +) @click.option( "--msa_pairing_strategy", type=str, @@ -496,6 +524,10 @@ def predict( override: bool, use_msa_server: bool, msa_server_url: str, + msa_server_username: str, + msa_server_password: str, + api_key_header: str, + api_key_value: str, msa_pairing_strategy: str, no_pairing: bool, only_run_data_process: bool, @@ -517,11 +549,40 @@ def predict( override=override, use_msa_server=use_msa_server, msa_server_url=msa_server_url, + msa_server_username=msa_server_username, + msa_server_password=msa_server_password, + api_key_header=api_key_header, + api_key_value=api_key_value, msa_pairing_strategy=msa_pairing_strategy, no_pairing=no_pairing, only_run_data_process=only_run_data_process, return_similar_seq=return_similar_seq, ) + + # Validate authentication arguments if MSA server is used + if args.use_msa_server: + import os + + # Get credentials from environment variables if not provided via CLI + msa_username = args.msa_server_username or os.getenv("MSA_USERNAME") + msa_password = args.msa_server_password or os.getenv("MSA_PASSWORD") + api_key_value = args.api_key_value or os.getenv("MSA_API_KEY_VALUE") + + # Check if both basic auth and API key are provided + has_basic_auth = msa_username and msa_password + has_api_key = api_key_value + + if has_basic_auth and has_api_key: + raise ValueError( + "Only one authentication method (basic or API key) can be used at a time. " + "Please provide either --msa_server_username/--msa_server_password OR " + "--api_key_value, but not both." 
+ ) + + # Update args with environment variables if needed + args.msa_server_username = msa_username + args.msa_server_password = msa_password + args.api_key_value = api_key_value main(args=args) intellifold_cli.add_command(predict)