diff --git a/examples/benchmarks/LightGBM/multi_freq_handler.py b/examples/benchmarks/LightGBM/multi_freq_handler.py index 1d4ba2b82b..5b0569f136 100644 --- a/examples/benchmarks/LightGBM/multi_freq_handler.py +++ b/examples/benchmarks/LightGBM/multi_freq_handler.py @@ -4,18 +4,27 @@ import pandas as pd from qlib.data.dataset.loader import QlibDataLoader -from qlib.contrib.data.handler import DataHandlerLP, _DEFAULT_LEARN_PROCESSORS, check_transform_proc +from qlib.contrib.data.handler import ( + DataHandlerLP, + _DEFAULT_LEARN_PROCESSORS, + check_transform_proc, +) class Avg15minLoader(QlibDataLoader): def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: df = super(Avg15minLoader, self).load(instruments, start_time, end_time) if self.is_group: - # feature_day(day freq) and feature_15min(1min freq, Average every 15 minutes) renamed feature - df.columns = df.columns.map(lambda x: ("feature", x[1]) if x[0].startswith("feature") else x) + # Normalize feature_day (day freq) and feature_15min (1min freq) to unified "feature" group + df.columns = df.columns.map( + lambda x: ("feature", x[1]) + if isinstance(x, tuple) and len(x) >= 2 and isinstance(x[0], str) and x[0].startswith("feature") + else x + ) return df + class Avg15minHandler(DataHandlerLP): def __init__( self, @@ -32,10 +41,17 @@ def __init__( inst_processors=None, **kwargs, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = Avg15minLoader( - config=self.loader_config(), filter_pipe=filter_pipe, freq=freq, inst_processors=inst_processors + config=self.loader_config(), + filter_pipe=filter_pipe, + freq=freq, + inst_processors=inst_processors, ) super().__init__( instruments=instruments, @@ -123,7 +139,10 @@ def loader_config(self): tmp_names = [] for i, _f in enumerate(fields): _fields = [f"Ref(Mean({_f}, 15), {j * 15})" for j in range(1, 240 // 15)] - _names = [f"{names[i][:-1]}{int(names[i][-1])+j}" for j in range(240 // 15 - 1, 0, -1)] + _names = [ + f"{names[i][:-1]}{int(names[i][-1])+j}" + for j in range(240 // 15 - 1, 0, -1) + ] _fields.append(f"Mean({_f}, 15)") _names.append(f"{names[i][:-1]}{int(names[i][-1])+240 // 15}") tmp_fields += _fields diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py index 9cdce6382d..1bf3fb34b0 100644 --- a/examples/benchmarks/TFT/data_formatters/base.py +++ b/examples/benchmarks/TFT/data_formatters/base.py @@ -142,7 +142,11 @@ def _check_single_column(input_type): length = len([tup for tup in column_definition if tup[2] == input_type]) if length != 1: - raise ValueError("Illegal number of inputs ({}) of type {}".format(length, input_type)) + raise ValueError( + "Illegal number of inputs ({}) of type {}".format( + length, input_type + ) + ) _check_single_column(InputTypes.ID) _check_single_column(InputTypes.TIME) @@ -152,45 +156,66 @@ def _check_single_column(input_type): real_inputs = [ tup for tup in column_definition - if tup[1] == DataTypes.REAL_VALUED and tup[2] not in {InputTypes.ID, InputTypes.TIME} + if tup[1] == DataTypes.REAL_VALUED + and tup[2] not in {InputTypes.ID, InputTypes.TIME} ] categorical_inputs = [ tup for tup in column_definition - if tup[1] == 
DataTypes.CATEGORICAL and tup[2] not in {InputTypes.ID, InputTypes.TIME} + if tup[1] == DataTypes.CATEGORICAL + and tup[2] not in {InputTypes.ID, InputTypes.TIME} ] return identifier + time + real_inputs + categorical_inputs def _get_input_columns(self): """Returns names of all input columns.""" - return [tup[0] for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME}] + return [ + tup[0] + for tup in self.get_column_definition() + if tup[2] not in {InputTypes.ID, InputTypes.TIME} + ] def _get_tft_input_indices(self): """Returns the relevant indexes and input sizes required by TFT.""" # Functions def _extract_tuples_from_data_type(data_type, defn): - return [tup for tup in defn if tup[1] == data_type and tup[2] not in {InputTypes.ID, InputTypes.TIME}] + return [ + tup + for tup in defn + if tup[1] == data_type + and tup[2] not in {InputTypes.ID, InputTypes.TIME} + ] def _get_locations(input_types, defn): return [i for i, tup in enumerate(defn) if tup[2] in input_types] # Start extraction column_definition = [ - tup for tup in self.get_column_definition() if tup[2] not in {InputTypes.ID, InputTypes.TIME} + tup + for tup in self.get_column_definition() + if tup[2] not in {InputTypes.ID, InputTypes.TIME} ] - categorical_inputs = _extract_tuples_from_data_type(DataTypes.CATEGORICAL, column_definition) - real_inputs = _extract_tuples_from_data_type(DataTypes.REAL_VALUED, column_definition) + categorical_inputs = _extract_tuples_from_data_type( + DataTypes.CATEGORICAL, column_definition + ) + real_inputs = _extract_tuples_from_data_type( + DataTypes.REAL_VALUED, column_definition + ) locations = { "input_size": len(self._get_input_columns()), "output_size": len(_get_locations({InputTypes.TARGET}, column_definition)), "category_counts": self.num_classes_per_cat_input, "input_obs_loc": _get_locations({InputTypes.TARGET}, column_definition), - "static_input_loc": _get_locations({InputTypes.STATIC_INPUT}, column_definition), - "known_regular_inputs": _get_locations({InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs), + "static_input_loc": _get_locations( + {InputTypes.STATIC_INPUT}, column_definition + ), + "known_regular_inputs": _get_locations( + {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, real_inputs + ), "known_categorical_inputs": _get_locations( {InputTypes.STATIC_INPUT, InputTypes.KNOWN_INPUT}, categorical_inputs ), @@ -213,7 +238,9 @@ def get_experiment_params(self): for k in required_keys: if k not in fixed_params: - raise ValueError("Field {}".format(k) + " missing from fixed parameter definitions!") + raise ValueError( + "Field {}".format(k) + " missing from fixed parameter definitions!" 
+ ) fixed_params["column_definition"] = self.get_column_definition() diff --git a/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py index a2afcc8142..9032d6b1c3 100644 --- a/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py +++ b/examples/benchmarks/TFT/data_formatters/qlib_Alpha158.py @@ -110,8 +110,12 @@ def set_scalers(self, df): print("Setting scalers with training data...") column_definitions = self.get_column_definition() - id_column = utils.get_single_col_by_input_type(InputTypes.ID, column_definitions) - target_column = utils.get_single_col_by_input_type(InputTypes.TARGET, column_definitions) + id_column = utils.get_single_col_by_input_type( + InputTypes.ID, column_definitions + ) + target_column = utils.get_single_col_by_input_type( + InputTypes.TARGET, column_definitions + ) # Extract identifiers in case required self.identifiers = list(df[id_column].unique()) @@ -137,7 +141,9 @@ def set_scalers(self, df): for col in categorical_inputs: # Set all to str so that we don't have mixed integer/string columns srs = df[col].apply(str) - categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit(srs.values) + categorical_scalers[col] = sklearn.preprocessing.LabelEncoder().fit( + srs.values + ) num_classes.append(srs.nunique()) # Set categorical scaler outputs diff --git a/examples/benchmarks/TFT/expt_settings/configs.py b/examples/benchmarks/TFT/expt_settings/configs.py index 55eb32a0b1..20c426d58f 100644 --- a/examples/benchmarks/TFT/expt_settings/configs.py +++ b/examples/benchmarks/TFT/expt_settings/configs.py @@ -54,7 +54,9 @@ def __init__(self, experiment="volatility", root_folder=None): # Defines all relevant paths if root_folder is None: - root_folder = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "outputs") + root_folder = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "..", "outputs" + ) print("Using root folder {}".format(root_folder)) self.root_folder = root_folder @@ -64,7 +66,12 @@ def __init__(self, experiment="volatility", root_folder=None): self.results_folder = os.path.join(root_folder, "results", experiment) # Creates folders if they don't exist - for relevant_directory in [self.root_folder, self.data_folder, self.model_folder, self.results_folder]: + for relevant_directory in [ + self.root_folder, + self.data_folder, + self.model_folder, + self.results_folder, + ]: if not os.path.exists(relevant_directory): os.makedirs(relevant_directory) diff --git a/examples/benchmarks/TFT/libs/hyperparam_opt.py b/examples/benchmarks/TFT/libs/hyperparam_opt.py index 86f587d7db..1dfa6dcbe7 100644 --- a/examples/benchmarks/TFT/libs/hyperparam_opt.py +++ b/examples/benchmarks/TFT/libs/hyperparam_opt.py @@ -48,7 +48,9 @@ class HyperparamOptManager: hyperparam_folder: Where to save optimisation outputs. """ - def __init__(self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True): + def __init__( + self, param_ranges, fixed_params, model_folder, override_w_fixed_params=True + ): """Instantiates model. 
Args: @@ -136,9 +138,17 @@ def _check_params(self, params): missing_fields = [k for k in valid_fields if k not in params] if invalid_fields: - raise ValueError("Invalid Fields Found {} - Valid ones are {}".format(invalid_fields, valid_fields)) + raise ValueError( + "Invalid Fields Found {} - Valid ones are {}".format( + invalid_fields, valid_fields + ) + ) if missing_fields: - raise ValueError("Missing Fields Found {} - Valid ones are {}".format(missing_fields, valid_fields)) + raise ValueError( + "Missing Fields Found {} - Valid ones are {}".format( + missing_fields, valid_fields + ) + ) def _get_name(self, params): """Returns a unique key for the supplied set of params.""" @@ -168,7 +178,9 @@ def get_next_parameters(self, ranges_to_skip=None): def _get_next(): """Returns next hyperparameter set per try.""" - parameters = {k: np.random.choice(self.param_ranges[k]) for k in param_range_keys} + parameters = { + k: np.random.choice(self.param_ranges[k]) for k in param_range_keys + } # Adds fixed params for k in self.fixed_params: @@ -265,7 +277,9 @@ def __init__( # Sanity checks if worker_number > max_workers: raise ValueError( - "Worker number ({}) cannot be larger than the total number of workers!".format(max_workers) + "Worker number ({}) cannot be larger than the total number of workers!".format( + max_workers + ) ) if worker_number > search_iterations: raise ValueError( @@ -274,10 +288,16 @@ def __init__( ) ) - print("*** Creating hyperparameter manager for worker {} ***".format(worker_number)) + print( + "*** Creating hyperparameter manager for worker {} ***".format( + worker_number + ) + ) hyperparam_folder = os.path.join(root_model_folder, str(worker_number)) - super().__init__(param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True) + super().__init__( + param_ranges, fixed_params, hyperparam_folder, override_w_fixed_params=True + ) serialised_ranges_folder = os.path.join(root_model_folder, "hyperparams") if clear_serialised_params: @@ -287,7 +307,9 @@ def __init__( utils.create_folder_if_not_exist(serialised_ranges_folder) - self.serialised_ranges_path = os.path.join(serialised_ranges_folder, "ranges_{}.csv".format(search_iterations)) + self.serialised_ranges_path = os.path.join( + serialised_ranges_folder, "ranges_{}.csv".format(search_iterations) + ) self.hyperparam_folder = hyperparam_folder # override self.worker_num = worker_number self.total_search_iterations = search_iterations @@ -421,7 +443,12 @@ def assign_worker_numbers(self, df): max_worker_num = int(np.ceil(n / batch_size)) - worker_idx = np.concatenate([np.tile(i + 1, self.num_iterations_per_worker) for i in range(max_worker_num)]) + worker_idx = np.concatenate( + [ + np.tile(i + 1, self.num_iterations_per_worker) + for i in range(max_worker_num) + ] + ) output["worker"] = worker_idx[: len(output)] diff --git a/examples/benchmarks/TFT/libs/tft_model.py b/examples/benchmarks/TFT/libs/tft_model.py index f3b6dda34f..65624455d2 100644 --- a/examples/benchmarks/TFT/libs/tft_model.py +++ b/examples/benchmarks/TFT/libs/tft_model.py @@ -68,7 +68,12 @@ def linear_layer(size, activation=None, use_time_distributed=False, use_bias=Tru def apply_mlp( - inputs, hidden_size, output_size, output_activation=None, hidden_activation="tanh", use_time_distributed=False + inputs, + hidden_size, + output_size, + output_activation=None, + hidden_activation="tanh", + use_time_distributed=False, ): """Applies simple feed-forward network to an input. @@ -84,16 +89,22 @@ def apply_mlp( Tensor for MLP outputs. 
""" if use_time_distributed: - hidden = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_size, activation=hidden_activation))( + hidden = tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(hidden_size, activation=hidden_activation) + )(inputs) + return tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(output_size, activation=output_activation) + )(hidden) + else: + hidden = tf.keras.layers.Dense(hidden_size, activation=hidden_activation)( inputs ) - return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(output_size, activation=output_activation))(hidden) - else: - hidden = tf.keras.layers.Dense(hidden_size, activation=hidden_activation)(inputs) return tf.keras.layers.Dense(output_size, activation=output_activation)(hidden) -def apply_gating_layer(x, hidden_layer_size, dropout_rate=None, use_time_distributed=True, activation=None): +def apply_gating_layer( + x, hidden_layer_size, dropout_rate=None, use_time_distributed=True, activation=None +): """Applies a Gated Linear Unit (GLU) to an input. Args: @@ -115,9 +126,13 @@ def apply_gating_layer(x, hidden_layer_size, dropout_rate=None, use_time_distrib activation_layer = tf.keras.layers.TimeDistributed( tf.keras.layers.Dense(hidden_layer_size, activation=activation) )(x) - gated_layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid"))(x) + gated_layer = tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid") + )(x) else: - activation_layer = tf.keras.layers.Dense(hidden_layer_size, activation=activation)(x) + activation_layer = tf.keras.layers.Dense( + hidden_layer_size, activation=activation + )(x) gated_layer = tf.keras.layers.Dense(hidden_layer_size, activation="sigmoid")(x) return tf.keras.layers.Multiply()([activation_layer, gated_layer]), gated_layer @@ -172,16 +187,27 @@ def gated_residual_network( skip = linear(x) # Apply feedforward network - hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(x) + hidden = linear_layer( + hidden_layer_size, activation=None, use_time_distributed=use_time_distributed + )(x) if additional_context is not None: hidden = hidden + linear_layer( - hidden_layer_size, activation=None, use_time_distributed=use_time_distributed, use_bias=False + hidden_layer_size, + activation=None, + use_time_distributed=use_time_distributed, + use_bias=False, )(additional_context) hidden = tf.keras.layers.Activation("elu")(hidden) - hidden = linear_layer(hidden_layer_size, activation=None, use_time_distributed=use_time_distributed)(hidden) + hidden = linear_layer( + hidden_layer_size, activation=None, use_time_distributed=use_time_distributed + )(hidden) gating_layer, gate = apply_gating_layer( - hidden, output_size, dropout_rate=dropout_rate, use_time_distributed=use_time_distributed, activation=None + hidden, + output_size, + dropout_rate=dropout_rate, + use_time_distributed=use_time_distributed, + activation=None, ) if return_gate: @@ -229,9 +255,13 @@ def __call__(self, q, k, v, mask): Tuple of (layer outputs, attention weights) """ temper = tf.sqrt(tf.cast(tf.shape(k)[-1], dtype="float32")) - attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / temper)([q, k]) # shape=(batch, q, k) + attn = Lambda(lambda x: K.batch_dot(x[0], x[1], axes=[2, 2]) / temper)( + [q, k] + ) # shape=(batch, q, k) if mask is not None: - mmask = Lambda(lambda x: (-1e9) * (1.0 - K.cast(x, "float32")))(mask) # setting to infinity + mmask = Lambda(lambda x: (-1e9) * (1.0 - 
K.cast(x, "float32")))( + mask + ) # setting to infinity attn = Add()([attn, mmask]) attn = self.activation(attn) attn = self.dropout(attn) @@ -403,7 +433,9 @@ def __init__(self, raw_params, use_cudnn=False): self._input_obs_loc = json.loads(str(params["input_obs_loc"])) self._static_input_loc = json.loads(str(params["static_input_loc"])) self._known_regular_input_idx = json.loads(str(params["known_regular_inputs"])) - self._known_categorical_input_idx = json.loads(str(params["known_categorical_inputs"])) + self._known_categorical_input_idx = json.loads( + str(params["known_categorical_inputs"]) + ) self.column_definition = params["column_definition"] @@ -471,7 +503,9 @@ def get_tft_embeddings(self, all_inputs): num_categorical_variables = len(self.category_counts) num_regular_variables = self.input_size - num_categorical_variables - embedding_sizes = [self.hidden_layer_size for i, size in enumerate(self.category_counts)] + embedding_sizes = [ + self.hidden_layer_size for i, size in enumerate(self.category_counts) + ] embeddings = [] for i in range(num_categorical_variables): @@ -479,7 +513,10 @@ def get_tft_embeddings(self, all_inputs): [ tf.keras.layers.InputLayer([time_steps]), tf.keras.layers.Embedding( - self.category_counts[i], embedding_sizes[i], input_length=time_steps, dtype=tf.float32 + self.category_counts[i], + embedding_sizes[i], + input_length=time_steps, + dtype=tf.float32, ), ] ) @@ -490,12 +527,17 @@ def get_tft_embeddings(self, all_inputs): all_inputs[:, :, num_regular_variables:], ) - embedded_inputs = [embeddings[i](categorical_inputs[Ellipsis, i]) for i in range(num_categorical_variables)] + embedded_inputs = [ + embeddings[i](categorical_inputs[Ellipsis, i]) + for i in range(num_categorical_variables) + ] # Static inputs if self._static_input_loc: static_inputs = [ - tf.keras.layers.Dense(self.hidden_layer_size)(regular_inputs[:, 0, i : i + 1]) + tf.keras.layers.Dense(self.hidden_layer_size)( + regular_inputs[:, 0, i : i + 1] + ) for i in range(num_regular_variables) if i in self._static_input_loc ] + [ @@ -510,17 +552,26 @@ def get_tft_embeddings(self, all_inputs): def convert_real_to_embedding(x): """Applies linear transformation for time-varying inputs.""" - return tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.hidden_layer_size))(x) + return tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(self.hidden_layer_size) + )(x) # Targets obs_inputs = tf.keras.backend.stack( - [convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1]) for i in self._input_obs_loc], axis=-1 + [ + convert_real_to_embedding(regular_inputs[Ellipsis, i : i + 1]) + for i in self._input_obs_loc + ], + axis=-1, ) # Observed (a prioir unknown) inputs wired_embeddings = [] for i in range(num_categorical_variables): - if i not in self._known_categorical_input_idx and i + num_regular_variables not in self._input_obs_loc: + if ( + i not in self._known_categorical_input_idx + and i + num_regular_variables not in self._input_obs_loc + ): e = embeddings[i](categorical_inputs[:, :, i]) wired_embeddings.append(e) @@ -531,7 +582,9 @@ def convert_real_to_embedding(x): unknown_inputs.append(e) if unknown_inputs + wired_embeddings: - unknown_inputs = tf.keras.backend.stack(unknown_inputs + wired_embeddings, axis=-1) + unknown_inputs = tf.keras.backend.stack( + unknown_inputs + wired_embeddings, axis=-1 + ) else: unknown_inputs = None @@ -547,7 +600,9 @@ def convert_real_to_embedding(x): if i + num_regular_variables not in self._static_input_loc ] - known_combined_layer = 
tf.keras.backend.stack(known_regular_inputs + known_categorical_inputs, axis=-1) + known_combined_layer = tf.keras.backend.stack( + known_regular_inputs + known_categorical_inputs, axis=-1 + ) return unknown_inputs, known_combined_layer, obs_inputs, static_inputs @@ -571,7 +626,9 @@ def cache_batched_data(self, data, cache_key, num_samples=-1): """ if num_samples > 0: - TFTDataCache.update(self._batch_sampled_data(data, max_samples=num_samples), cache_key) + TFTDataCache.update( + self._batch_sampled_data(data, max_samples=num_samples), cache_key + ) else: TFTDataCache.update(self._batch_data(data), cache_key) @@ -589,7 +646,9 @@ def _batch_sampled_data(self, data, max_samples): """ if max_samples < 1: - raise ValueError("Illegal number of samples specified! samples={}".format(max_samples)) + raise ValueError( + "Illegal number of samples specified! samples={}".format(max_samples) + ) id_col = self._get_single_col_by_type(InputTypes.ID) time_col = self._get_single_col_by_type(InputTypes.TIME) @@ -604,7 +663,8 @@ def _batch_sampled_data(self, data, max_samples): num_entries = len(df) if num_entries >= self.time_steps: valid_sampling_locations += [ - (identifier, self.time_steps + i) for i in range(num_entries - self.time_steps + 1) + (identifier, self.time_steps + i) + for i in range(num_entries - self.time_steps + 1) ] split_data_map[identifier] = df @@ -617,22 +677,34 @@ def _batch_sampled_data(self, data, max_samples): print("Extracting {} samples...".format(max_samples)) ranges = [ valid_sampling_locations[i] - for i in np.random.choice(len(valid_sampling_locations), max_samples, replace=False) + for i in np.random.choice( + len(valid_sampling_locations), max_samples, replace=False + ) ] else: - print("Max samples={} exceeds # available segments={}".format(max_samples, len(valid_sampling_locations))) + print( + "Max samples={} exceeds # available segments={}".format( + max_samples, len(valid_sampling_locations) + ) + ) ranges = valid_sampling_locations id_col = self._get_single_col_by_type(InputTypes.ID) time_col = self._get_single_col_by_type(InputTypes.TIME) target_col = self._get_single_col_by_type(InputTypes.TARGET) - input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}] + input_cols = [ + tup[0] + for tup in self.column_definition + if tup[2] not in {InputTypes.ID, InputTypes.TIME} + ] for i, tup in enumerate(ranges): if (i + 1 % 1000) == 0: print(i + 1, "of", max_samples, "samples done...") identifier, start_idx = tup - sliced = split_data_map[identifier].iloc[start_idx - self.time_steps : start_idx] + sliced = split_data_map[identifier].iloc[ + start_idx - self.time_steps : start_idx + ] inputs[i, :, :] = sliced[input_cols] outputs[i, :, :] = sliced[[target_col]] time[i, :, 0] = sliced[time_col] @@ -667,7 +739,9 @@ def _batch_single_entity(input_data): lags = self.time_steps x = input_data.values if time_steps >= lags: - return np.stack([x[i : time_steps - (lags - 1) + i, :] for i in range(lags)], axis=1) + return np.stack( + [x[i : time_steps - (lags - 1) + i, :] for i in range(lags)], axis=1 + ) else: return None @@ -675,11 +749,20 @@ def _batch_single_entity(input_data): id_col = self._get_single_col_by_type(InputTypes.ID) time_col = self._get_single_col_by_type(InputTypes.TIME) target_col = self._get_single_col_by_type(InputTypes.TARGET) - input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}] + input_cols = [ + tup[0] + for tup in self.column_definition + if tup[2] not in 
{InputTypes.ID, InputTypes.TIME} + ] data_map = {} for _, sliced in data.groupby(id_col, group_keys=False): - col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols} + col_mappings = { + "identifier": [id_col], + "time": [time_col], + "outputs": [target_col], + "inputs": input_cols, + } for k in col_mappings: cols = col_mappings[k] @@ -693,7 +776,9 @@ def _batch_single_entity(input_data): # Combine all data for k in data_map: # Wendi: Avoid returning None when the length is not enough - data_map[k] = np.concatenate([i for i in data_map[k] if i is not None], axis=0) + data_map[k] = np.concatenate( + [i for i in data_map[k] if i is not None], axis=0 + ) # Shorten target so we only get decoder steps data_map["outputs"] = data_map["outputs"][:, self.num_encoder_steps :, :] @@ -726,7 +811,9 @@ def _build_base_graph(self): ) ) - unknown_inputs, known_combined_layer, obs_inputs, static_inputs = self.get_tft_embeddings(all_inputs) + unknown_inputs, known_combined_layer, obs_inputs, static_inputs = ( + self.get_tft_embeddings(all_inputs) + ) # Isolate known and observed historical inputs. if unknown_inputs is not None: @@ -740,7 +827,11 @@ def _build_base_graph(self): ) else: historical_inputs = concat( - [known_combined_layer[:, :encoder_steps, :], obs_inputs[:, :encoder_steps, :]], axis=-1 + [ + known_combined_layer[:, :encoder_steps, :], + obs_inputs[:, :encoder_steps, :], + ], + axis=-1, ) # Isolate only known future inputs. @@ -786,7 +877,9 @@ def static_combine_and_mask(embedding): transformed_embedding = concat(trans_emb_list, axis=1) - combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding]) + combined = tf.keras.layers.Multiply()( + [sparse_weights, transformed_embedding] + ) static_vec = K.sum(combined, axis=1) @@ -795,16 +888,28 @@ def static_combine_and_mask(embedding): static_encoder, static_weights = static_combine_and_mask(static_inputs) static_context_variable_selection = gated_residual_network( - static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False + static_encoder, + self.hidden_layer_size, + dropout_rate=self.dropout_rate, + use_time_distributed=False, ) static_context_enrichment = gated_residual_network( - static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False + static_encoder, + self.hidden_layer_size, + dropout_rate=self.dropout_rate, + use_time_distributed=False, ) static_context_state_h = gated_residual_network( - static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False + static_encoder, + self.hidden_layer_size, + dropout_rate=self.dropout_rate, + use_time_distributed=False, ) static_context_state_c = gated_residual_network( - static_encoder, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=False + static_encoder, + self.hidden_layer_size, + dropout_rate=self.dropout_rate, + use_time_distributed=False, ) def lstm_combine_and_mask(embedding): @@ -822,7 +927,9 @@ def lstm_combine_and_mask(embedding): flatten = K.reshape(embedding, [-1, time_steps, embedding_dim * num_inputs]) - expanded_static_context = K.expand_dims(static_context_variable_selection, axis=1) + expanded_static_context = K.expand_dims( + static_context_variable_selection, axis=1 + ) # Variable selection weights mlp_outputs, static_gate = gated_residual_network( @@ -851,12 +958,16 @@ def lstm_combine_and_mask(embedding): transformed_embedding = stack(trans_emb_list, axis=-1) - 
combined = tf.keras.layers.Multiply()([sparse_weights, transformed_embedding]) + combined = tf.keras.layers.Multiply()( + [sparse_weights, transformed_embedding] + ) temporal_ctx = K.sum(combined, axis=-1) return temporal_ctx, sparse_weights, static_gate - historical_features, historical_flags, _ = lstm_combine_and_mask(historical_inputs) + historical_features, historical_flags, _ = lstm_combine_and_mask( + historical_inputs + ) future_features, future_flags, _ = lstm_combine_and_mask(future_inputs) # LSTM layer @@ -886,17 +997,22 @@ def get_lstm(return_state): return lstm history_lstm, state_h, state_c = get_lstm(return_state=True)( - historical_features, initial_state=[static_context_state_h, static_context_state_c] + historical_features, + initial_state=[static_context_state_h, static_context_state_c], ) - future_lstm = get_lstm(return_state=False)(future_features, initial_state=[state_h, state_c]) + future_lstm = get_lstm(return_state=False)( + future_features, initial_state=[state_h, state_c] + ) lstm_layer = concat([history_lstm, future_lstm], axis=1) # Apply gated skip connection input_embeddings = concat([historical_features, future_features], axis=1) - lstm_layer, _ = apply_gating_layer(lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None) + lstm_layer, _ = apply_gating_layer( + lstm_layer, self.hidden_layer_size, self.dropout_rate, activation=None + ) temporal_feature_layer = add_and_norm([lstm_layer, input_embeddings]) # Static enrichment layers @@ -918,16 +1034,23 @@ def get_lstm(return_state): mask = get_decoder_mask(enriched) x, self_att = self_attn_layer(enriched, enriched, enriched, mask=mask) - x, _ = apply_gating_layer(x, self.hidden_layer_size, dropout_rate=self.dropout_rate, activation=None) + x, _ = apply_gating_layer( + x, self.hidden_layer_size, dropout_rate=self.dropout_rate, activation=None + ) x = add_and_norm([x, enriched]) # Nonlinear processing on outputs decoder = gated_residual_network( - x, self.hidden_layer_size, dropout_rate=self.dropout_rate, use_time_distributed=True + x, + self.hidden_layer_size, + dropout_rate=self.dropout_rate, + use_time_distributed=True, ) # Final skip connection - decoder, _ = apply_gating_layer(decoder, self.hidden_layer_size, activation=None) + decoder, _ = apply_gating_layer( + decoder, self.hidden_layer_size, activation=None + ) transformer_layer = add_and_norm([decoder, temporal_feature_layer]) # Attention components for explainability @@ -952,15 +1075,19 @@ def build_model(self): """ with tf.variable_scope(self.name): - transformer_layer, all_inputs, attention_components = self._build_base_graph() - - outputs = tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(self.output_size * len(self.quantiles)))( - transformer_layer[Ellipsis, self.num_encoder_steps :, :] + transformer_layer, all_inputs, attention_components = ( + self._build_base_graph() ) + outputs = tf.keras.layers.TimeDistributed( + tf.keras.layers.Dense(self.output_size * len(self.quantiles)) + )(transformer_layer[Ellipsis, self.num_encoder_steps :, :]) + self._attention_components = attention_components - adam = tf.keras.optimizers.Adam(lr=self.learning_rate, clipnorm=self.max_gradient_norm) + adam = tf.keras.optimizers.Adam( + lr=self.learning_rate, clipnorm=self.max_gradient_norm + ) model = tf.keras.Model(inputs=all_inputs, outputs=outputs) @@ -1005,7 +1132,9 @@ def quantile_loss(self, a, b): quantile_loss = QuantileLossCalculator(valid_quantiles).quantile_loss - model.compile(loss=quantile_loss, optimizer=adam, 
sample_weight_mode="temporal") + model.compile( + loss=quantile_loss, optimizer=adam, sample_weight_mode="temporal" + ) self._input_placeholder = all_inputs @@ -1023,7 +1152,11 @@ def fit(self, train_df=None, valid_df=None): # Add relevant callbacks callbacks = [ - tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=self.early_stopping_patience, min_delta=1e-4), + tf.keras.callbacks.EarlyStopping( + monitor="val_loss", + patience=self.early_stopping_patience, + min_delta=1e-4, + ), tf.keras.callbacks.ModelCheckpoint( filepath=self.get_keras_saved_path(self._temp_folder), monitor="val_loss", @@ -1049,7 +1182,11 @@ def fit(self, train_df=None, valid_df=None): print("Using keras standard fit") def _unpack(data): - return data["inputs"], data["outputs"], self._get_active_locations(data["active_entries"]) + return ( + data["inputs"], + data["outputs"], + self._get_active_locations(data["active_entries"]), + ) # Unpack without sample weights data, labels, active_flags = _unpack(train_data) @@ -1063,7 +1200,11 @@ def _unpack(data): sample_weight=active_flags, epochs=self.num_epochs, batch_size=self.minibatch_size, - validation_data=(val_data, np.concatenate([val_labels, val_labels, val_labels], axis=-1), val_flags), + validation_data=( + val_data, + np.concatenate([val_labels, val_labels, val_labels], axis=-1), + val_flags, + ), callbacks=all_callbacks, shuffle=True, use_multiprocessing=True, @@ -1130,7 +1271,9 @@ def predict(self, df, return_targets=False): identifier = data["identifier"] outputs = data["outputs"] - combined = self.model.predict(inputs, workers=16, use_multiprocessing=True, batch_size=self.minibatch_size) + combined = self.model.predict( + inputs, workers=16, use_multiprocessing=True, batch_size=self.minibatch_size + ) # Format output_csv if self.output_size != 1: @@ -1140,7 +1283,11 @@ def format_outputs(prediction): """Returns formatted dataframes for prediction.""" flat_prediction = pd.DataFrame( - prediction[:, :, 0], columns=["t+{}".format(i) for i in range(self.time_steps - self.num_encoder_steps)] + prediction[:, :, 0], + columns=[ + "t+{}".format(i) + for i in range(self.time_steps - self.num_encoder_steps) + ], ) cols = list(flat_prediction.columns) flat_prediction["forecast_time"] = time[:, self.num_encoder_steps - 1, 0] @@ -1151,7 +1298,9 @@ def format_outputs(prediction): # Extract predictions for each quantile into different entries process_map = { - "p{}".format(int(q * 100)): combined[Ellipsis, i * self.output_size : (i + 1) * self.output_size] + "p{}".format(int(q * 100)): combined[ + Ellipsis, i * self.output_size : (i + 1) * self.output_size + ] for i, q in enumerate(self.quantiles) } @@ -1183,7 +1332,8 @@ def get_batch_attention_weights(input_batch): attention_weights = {} for k in self._attention_components: attention_weight = tf.keras.backend.get_session().run( - self._attention_components[k], {input_placeholder: input_batch.astype(np.float32)} + self._attention_components[k], + {input_placeholder: input_batch.astype(np.float32)}, ) attention_weights[k] = attention_weight return attention_weights @@ -1196,10 +1346,15 @@ def get_batch_attention_weights(input_batch): num_batches += 1 # Split up inputs into batches - batched_inputs = [inputs[i * batch_size : (i + 1) * batch_size, Ellipsis] for i in range(num_batches)] + batched_inputs = [ + inputs[i * batch_size : (i + 1) * batch_size, Ellipsis] + for i in range(num_batches) + ] # Get attention weights, while avoiding large memory increases - attention_by_batch = [get_batch_attention_weights(batch) 
for batch in batched_inputs] + attention_by_batch = [ + get_batch_attention_weights(batch) for batch in batched_inputs + ] attention_weights = {} for k in self._attention_components: attention_weights[k] = [] @@ -1242,7 +1397,12 @@ def save(self, model_folder): # issue with Keras that leads to different performance evaluation results # when model is reloaded (https://github.com/keras-team/keras/issues/4875). - utils.save(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name) + utils.save( + tf.keras.backend.get_session(), + model_folder, + cp_name=self.name, + scope=self.name, + ) def load(self, model_folder, use_keras_loadings=False): """Loads TFT weights. @@ -1261,7 +1421,12 @@ def load(self, model_folder, use_keras_loadings=False): self.model.load_weights(serialisation_path) else: # Loads tensorflow graph for optimal models. - utils.load(tf.keras.backend.get_session(), model_folder, cp_name=self.name, scope=self.name) + utils.load( + tf.keras.backend.get_session(), + model_folder, + cp_name=self.name, + scope=self.name, + ) @classmethod def get_hyperparm_choices(cls): diff --git a/examples/benchmarks/TFT/libs/utils.py b/examples/benchmarks/TFT/libs/utils.py index 4682434d63..8252b68f04 100644 --- a/examples/benchmarks/TFT/libs/utils.py +++ b/examples/benchmarks/TFT/libs/utils.py @@ -52,7 +52,11 @@ def extract_cols_from_data_type(data_type, column_definition, excluded_input_typ Returns: List of names for columns with data type specified. """ - return [tup[0] for tup in column_definition if tup[1] == data_type and tup[2] not in excluded_input_types] + return [ + tup[0] + for tup in column_definition + if tup[1] == data_type and tup[2] not in excluded_input_types + ] # Loss functions. @@ -73,12 +77,16 @@ def tensorflow_quantile_loss(y, y_pred, quantile): # Checks quantile if quantile < 0 or quantile > 1: - raise ValueError("Illegal quantile value={}! Values should be between 0 and 1.".format(quantile)) + raise ValueError( + "Illegal quantile value={}! Values should be between 0 and 1.".format( + quantile + ) + ) prediction_underflow = y - y_pred - q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * tf.maximum( - -prediction_underflow, 0.0 - ) + q_loss = quantile * tf.maximum(prediction_underflow, 0.0) + ( + 1.0 - quantile + ) * tf.maximum(-prediction_underflow, 0.0) return tf.reduce_sum(q_loss, axis=-1) @@ -98,9 +106,9 @@ def numpy_normalised_quantile_loss(y, y_pred, quantile): Float for normalised quantile loss. 
""" prediction_underflow = y - y_pred - weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + (1.0 - quantile) * np.maximum( - -prediction_underflow, 0.0 - ) + weighted_errors = quantile * np.maximum(prediction_underflow, 0.0) + ( + 1.0 - quantile + ) * np.maximum(-prediction_underflow, 0.0) quantile_loss = weighted_errors.mean() normaliser = y.abs().mean() @@ -168,7 +176,9 @@ def save(tf_session, model_folder, cp_name, scope=None): var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=scope) saver = tf.train.Saver(var_list=var_list, max_to_keep=100000) - save_path = saver.save(tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name))) + save_path = saver.save( + tf_session, os.path.join(model_folder, "{0}.ckpt".format(cp_name)) + ) print("Model saved to: {0}".format(save_path)) @@ -221,4 +231,6 @@ def print_weights_in_checkpoint(model_folder, cp_name): """ load_path = os.path.join(model_folder, "{0}.ckpt".format(cp_name)) - print_tensors_in_checkpoint_file(file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True) + print_tensors_in_checkpoint_file( + file_name=load_path, tensor_name="", all_tensors=True, all_tensor_names=True + ) diff --git a/examples/benchmarks/TFT/tft.py b/examples/benchmarks/TFT/tft.py index 633a875c0f..ef76fc529d 100644 --- a/examples/benchmarks/TFT/tft.py +++ b/examples/benchmarks/TFT/tft.py @@ -78,14 +78,20 @@ def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"): - return data_df[[col_shift]].groupby("instrument", group_keys=False).apply(lambda df: df.shift(shifts)) + return ( + data_df[[col_shift]] + .groupby("instrument", group_keys=False) + .apply(lambda df: df.shift(shifts)) + ) def fill_test_na(test_df): test_df_res = test_df.copy() feature_cols = ~test_df_res.columns.str.contains("label", case=False) test_feature_fna = ( - test_df_res.loc[:, feature_cols].groupby("datetime", group_keys=False).apply(lambda df: df.fillna(df.mean())) + test_df_res.loc[:, feature_cols] + .groupby("datetime", group_keys=False) + .apply(lambda df: df.fillna(df.mean())) ) test_df_res.loc[:, feature_cols] = test_feature_fna return test_df_res @@ -132,7 +138,13 @@ def process_predicted(df, col_name): """ df_res = df.copy() - df_res = df_res.rename(columns={"forecast_time": "datetime", "identifier": "instrument", "t+4": col_name}) + df_res = df_res.rename( + columns={ + "forecast_time": "datetime", + "identifier": "instrument", + "t+4": col_name, + } + ) df_res = df_res.set_index(["datetime", "instrument"]).sort_index() df_res = df_res[[col_name]] return df_res @@ -161,21 +173,31 @@ def __init__(self, **kwargs): def _prepare_data(self, dataset: DatasetH): df_train, df_valid = dataset.prepare( - ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ["train", "valid"], + col_set=["feature", "label"], + data_key=DataHandlerLP.DK_L, ) return transform_df(df_train), transform_df(df_valid) - def fit(self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, **kwargs): + def fit( + self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, **kwargs + ): DATASET = self.params["DATASET"] LABEL_SHIFT = self.params["label_shift"] LABEL_COL = DATASET_SETTING[DATASET]["label_col"] if DATASET not in ALLOW_DATASET: - raise AssertionError("The dataset is not supported, please make a new formatter to fit this dataset") + raise AssertionError( + "The dataset is not supported, please make a new formatter to fit this dataset" + ) dtrain, dvalid = self._prepare_data(dataset) - dtrain.loc[:, 
LABEL_COL] = get_shifted_label(dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL) - dvalid.loc[:, LABEL_COL] = get_shifted_label(dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL) + dtrain.loc[:, LABEL_COL] = get_shifted_label( + dtrain, shifts=LABEL_SHIFT, col_shift=LABEL_COL + ) + dvalid.loc[:, LABEL_COL] = get_shifted_label( + dvalid, shifts=LABEL_SHIFT, col_shift=LABEL_COL + ) train = process_qlib_data(dtrain, DATASET, fillna=True).dropna() valid = process_qlib_data(dvalid, DATASET, fillna=True).dropna() @@ -192,7 +214,9 @@ def fit(self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, ** use_gpu = (True, self.gpu_id) # ===========================Training Process=========================== ModelClass = libs.tft_model.TemporalFusionTransformer - if not isinstance(self.data_formatter, data_formatters.base.GenericDataFormatter): + if not isinstance( + self.data_formatter, data_formatters.base.GenericDataFormatter + ): raise ValueError( "Data formatters should inherit from" + "AbstractDataFormatter! Type={}".format(type(self.data_formatter)) @@ -201,7 +225,9 @@ def fit(self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, ** default_keras_session = tf.keras.backend.get_session() if use_gpu[0]: - self.tf_config = utils.get_default_tensorflow_config(tf_device="gpu", gpu_id=use_gpu[1]) + self.tf_config = utils.get_default_tensorflow_config( + tf_device="gpu", gpu_id=use_gpu[1] + ) else: self.tf_config = utils.get_default_tensorflow_config(tf_device="cpu") @@ -237,7 +263,13 @@ def fit(self, dataset: DatasetH, MODEL_FOLDER="qlib_tft_model", USE_GPU_ID=0, ** def extract_numerical_data(data): """Strips out forecast time and identifier columns.""" - return data[[col for col in data.columns if col not in {"forecast_time", "identifier"}]] + return data[ + [ + col + for col in data.columns + if col not in {"forecast_time", "identifier"} + ] + ] # p50_loss = utils.numpy_normalised_quantile_loss( # extract_numerical_data(targets), extract_numerical_data(p50_forecast), @@ -254,7 +286,9 @@ def predict(self, dataset): raise ValueError("model is not fitted yet!") d_test = dataset.prepare("test", col_set=["feature", "label"]) d_test = transform_df(d_test) - d_test.loc[:, self.label_col] = get_shifted_label(d_test, shifts=self.label_shift, col_shift=self.label_col) + d_test.loc[:, self.label_col] = get_shifted_label( + d_test, shifts=self.label_shift, col_shift=self.label_col + ) test = process_qlib_data(d_test, self.expt_name, fillna=True).dropna() use_gpu = (True, self.gpu_id) diff --git a/examples/benchmarks/TRA/example.py b/examples/benchmarks/TRA/example.py index f7e16ddee4..de12a674cc 100644 --- a/examples/benchmarks/TRA/example.py +++ b/examples/benchmarks/TRA/example.py @@ -14,7 +14,10 @@ def main(seed, config_file="configs/config_alstm.yaml"): # seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}" seed_suffix = "" config["task"]["model"]["kwargs"].update( - {"seed": seed, "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix} + { + "seed": seed, + "logdir": config["task"]["model"]["kwargs"]["logdir"] + seed_suffix, + } ) # initialize workflow @@ -33,6 +36,11 @@ def main(seed, config_file="configs/config_alstm.yaml"): # set params from cmd parser = argparse.ArgumentParser(allow_abbrev=False) parser.add_argument("--seed", type=int, default=1000, help="random seed") - parser.add_argument("--config_file", type=str, default="configs/config_alstm.yaml", help="config file") + parser.add_argument( + "--config_file", + type=str, + 
default="configs/config_alstm.yaml", + help="config file", + ) args = parser.parse_args() main(**vars(args)) diff --git a/examples/benchmarks/TRA/src/dataset.py b/examples/benchmarks/TRA/src/dataset.py index 47cde9f3fd..c16bf746e3 100644 --- a/examples/benchmarks/TRA/src/dataset.py +++ b/examples/benchmarks/TRA/src/dataset.py @@ -29,7 +29,9 @@ def _create_ts_slices(index, seq_len): assert index.is_lexsorted(), "index should be sorted" # number of dates for each code - sample_count_by_codes = pd.Series(0, index=index).groupby(level=0, group_keys=False).size().values + sample_count_by_codes = ( + pd.Series(0, index=index).groupby(level=0, group_keys=False).size().values + ) # start_index for each code start_index_of_codes = np.roll(np.cumsum(sample_count_by_codes), 1) @@ -123,7 +125,9 @@ def setup_data(self, handler_kwargs: dict = None, **kwargs): self._index = df.index # add memory to feature - self._data = np.c_[self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32)] + self._data = np.c_[ + self._data, np.zeros((len(self._data), self.num_states), dtype=np.float32) + ] # padding tensor self.zeros = np.zeros((self.seq_len, self._data.shape[1]), dtype=np.float32) @@ -234,12 +238,20 @@ def __iter__(self): label = [] index = [] for slc in slices_subset: - _data = self._data[slc].clone() if self.pin_memory else self._data[slc].copy() + _data = ( + self._data[slc].clone() + if self.pin_memory + else self._data[slc].copy() + ) if len(_data) != self.seq_len: if self.pin_memory: - _data = torch.cat([self.zeros[: self.seq_len - len(_data)], _data], axis=0) + _data = torch.cat( + [self.zeros[: self.seq_len - len(_data)], _data], axis=0 + ) else: - _data = np.concatenate([self.zeros[: self.seq_len - len(_data)], _data], axis=0) + _data = np.concatenate( + [self.zeros[: self.seq_len - len(_data)], _data], axis=0 + ) if self.num_states > 0: _data[-self.horizon :, -self.num_states :] = 0 data.append(_data) diff --git a/examples/benchmarks/TRA/src/model.py b/examples/benchmarks/TRA/src/model.py index ebafd6a521..f2ec529a2c 100644 --- a/examples/benchmarks/TRA/src/model.py +++ b/examples/benchmarks/TRA/src/model.py @@ -53,17 +53,25 @@ def __init__( self.model = eval(model_type)(**model_config).to(device) if model_init_state: - self.model.load_state_dict(torch.load(model_init_state, map_location="cpu")["model"]) + self.model.load_state_dict( + torch.load(model_init_state, map_location="cpu")["model"] + ) if freeze_model: for param in self.model.parameters(): param.requires_grad_(False) else: - self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters()])) + self.logger.info( + "# model params: %d" % sum([p.numel() for p in self.model.parameters()]) + ) self.tra = TRA(self.model.output_size, **tra_config).to(device) - self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters()])) + self.logger.info( + "# tra params: %d" % sum([p.numel() for p in self.tra.parameters()]) + ) - self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=lr) + self.optimizer = optim.Adam( + list(self.model.parameters()) + list(self.tra.parameters()), lr=lr + ) self.model_config = model_config self.tra_config = tra_config @@ -283,9 +291,10 @@ def fit(self, dataset, evals_result=dict()): if self.logdir: self.logger.info("save model & pred to local directory") - pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv( - self.logdir + "/logs.csv", index=False - ) + pd.concat( + {name: 
pd.DataFrame(evals_result[name]) for name in evals_result}, + axis=1, + ).to_csv(self.logdir + "/logs.csv", index=False) torch.save(best_params, self.logdir + "/model.bin") @@ -401,7 +410,9 @@ def __init__(self, d_model, dropout=0.1, max_len=5000): pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) @@ -450,7 +461,10 @@ def __init__( self.pe = PositionalEncoding(input_size, dropout) layer = nn.TransformerEncoderLayer( - nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4 + nhead=num_heads, + dropout=dropout, + d_model=hidden_size, + dim_feedforward=hidden_size * 4, ) self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers) @@ -486,7 +500,9 @@ class TRA(nn.Module): tau (float): gumbel softmax temperature """ - def __init__(self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE"): + def __init__( + self, input_size, num_states=1, hidden_size=8, tau=1.0, src_info="LR_TPE" + ): super().__init__() self.num_states = num_states diff --git a/examples/benchmarks_dynamic/DDG-DA/workflow.py b/examples/benchmarks_dynamic/DDG-DA/workflow.py index 8209e0e906..c226def724 100644 --- a/examples/benchmarks_dynamic/DDG-DA/workflow.py +++ b/examples/benchmarks_dynamic/DDG-DA/workflow.py @@ -23,10 +23,14 @@ class DDGDABench(DDGDA): DEFAULT_CONF = CONF_LIST[0] # Linear by default due to efficiency - def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs) -> None: + def __init__( + self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs + ) -> None: # This code is for being compatible with the previous old code conf_path = Path(conf_path) - super().__init__(conf_path=conf_path, horizon=horizon, working_dir=DIRNAME, **kwargs) + super().__init__( + conf_path=conf_path, horizon=horizon, working_dir=DIRNAME, **kwargs + ) for f in self.CONF_LIST: if conf_path.samefile(f): diff --git a/examples/benchmarks_dynamic/baseline/rolling_benchmark.py b/examples/benchmarks_dynamic/baseline/rolling_benchmark.py index 02b7ed4650..9a3de543ad 100644 --- a/examples/benchmarks_dynamic/baseline/rolling_benchmark.py +++ b/examples/benchmarks_dynamic/baseline/rolling_benchmark.py @@ -15,11 +15,16 @@ class RollingBenchmark(Rolling): # The config in the README.md - CONF_LIST = [DIRNAME / "workflow_config_linear_Alpha158.yaml", DIRNAME / "workflow_config_lightgbm_Alpha158.yaml"] + CONF_LIST = [ + DIRNAME / "workflow_config_linear_Alpha158.yaml", + DIRNAME / "workflow_config_lightgbm_Alpha158.yaml", + ] DEFAULT_CONF = CONF_LIST[0] - def __init__(self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs) -> None: + def __init__( + self, conf_path: Union[str, Path] = DEFAULT_CONF, horizon=20, **kwargs + ) -> None: # This code is for being compatible with the previous old code conf_path = Path(conf_path) super().__init__(conf_path=conf_path, horizon=horizon, **kwargs) diff --git a/examples/data_demo/data_cache_demo.py b/examples/data_demo/data_cache_demo.py index dd65e3168e..307a7e625d 100644 --- a/examples/data_demo/data_cache_demo.py +++ b/examples/data_demo/data_cache_demo.py @@ -23,7 +23,9 @@ if __name__ == "__main__": init() - config_path = DIRNAME.parent / 
"benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml" + config_path = ( + DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml" + ) # 1) show original time with TimeInspector.logt("The original time without handler cache:"): diff --git a/examples/data_demo/data_mem_resuse_demo.py b/examples/data_demo/data_mem_resuse_demo.py index 9853fe4ae3..f8f7822467 100644 --- a/examples/data_demo/data_mem_resuse_demo.py +++ b/examples/data_demo/data_mem_resuse_demo.py @@ -27,12 +27,16 @@ repeat = 2 exp_name = "data_mem_reuse_demo" - config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml" + config_path = ( + DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml" + ) yaml = YAML(typ="safe", pure=True) task_config = yaml.load(config_path.open()) # 1) without using processed data in memory - with TimeInspector.logt("The original time without reusing processed data in memory:"): + with TimeInspector.logt( + "The original time without reusing processed data in memory:" + ): for i in range(repeat): task_train(task_config["task"], experiment_name=exp_name) diff --git a/examples/highfreq/highfreq_handler.py b/examples/highfreq/highfreq_handler.py index 7df564b7b9..5f310434bc 100644 --- a/examples/highfreq/highfreq_handler.py +++ b/examples/highfreq/highfreq_handler.py @@ -14,8 +14,12 @@ def __init__( fit_end_time=None, drop_raw=True, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", @@ -50,7 +54,11 @@ def get_normalized_price_feature(price_field, shift=0): if shift == 0: template_norm = "Cut({0}/Ref(DayLast({1}), 240), 240, None)" else: - template_norm = "Cut(Ref({0}, " + str(shift) + ")/Ref(DayLast({1}), 240), 240, None)" + template_norm = ( + "Cut(Ref({0}, " + + str(shift) + + ")/Ref(DayLast({1}), 240), 240, None)" + ) feature_ops = template_norm.format( template_if.format( @@ -133,7 +141,9 @@ def get_feature_config(self): # Because there is no vwap field in the yahoo data, a method similar to Simpson integration is used to approximate vwap simpson_vwap = "($open + 2*$high + 2*$low + $close)/6" fields += [ - "Cut({0}, 240, None)".format(template_fillnan.format(template_paused.format("$close"))), + "Cut({0}, 240, None)".format( + template_fillnan.format(template_paused.format("$close")) + ), ] names += ["$close0"] fields += [ diff --git a/examples/highfreq/highfreq_ops.py b/examples/highfreq/highfreq_ops.py index 36e15a3755..95714e3e30 100644 --- a/examples/highfreq/highfreq_ops.py +++ b/examples/highfreq/highfreq_ops.py @@ -25,7 +25,9 @@ class DayLast(ElemOperator): def _load_internal(self, instrument, start_index, end_index, freq): _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) - return series.groupby(_calendar[series.index], group_keys=False).transform("last") + return series.groupby(_calendar[series.index], group_keys=False).transform( + "last" + ) class FFillNan(ElemOperator): @@ -104,8 +106,12 @@ class Select(PairOperator): """ def _load_internal(self, instrument, start_index, end_index, freq): - series_condition = self.feature_left.load(instrument, start_index, end_index, freq) - 
series_feature = self.feature_right.load(instrument, start_index, end_index, freq) + series_condition = self.feature_left.load( + instrument, start_index, end_index, freq + ) + series_feature = self.feature_right.load( + instrument, start_index, end_index, freq + ) return series_feature.loc[series_condition] diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/highfreq_processor.py index 26e0fdd0f5..18a21ebf5e 100644 --- a/examples/highfreq/highfreq_processor.py +++ b/examples/highfreq/highfreq_processor.py @@ -11,7 +11,9 @@ def __init__(self, fit_start_time, fit_end_time): self.fit_end_time = fit_end_time def fit(self, df_features): - fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime") + fetch_df = fetch_df_by_index( + df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime" + ) del df_features df_values = fetch_df.values names = { @@ -28,14 +30,18 @@ def fit(self, df_features): part_values = np.log1p(part_values) self.feature_med[name] = np.nanmedian(part_values) part_values = part_values - self.feature_med[name] - self.feature_std[name] = np.nanmedian(np.absolute(part_values)) * 1.4826 + EPS + self.feature_std[name] = ( + np.nanmedian(np.absolute(part_values)) * 1.4826 + EPS + ) part_values = part_values / self.feature_std[name] self.feature_vmax[name] = np.nanmax(part_values) self.feature_vmin[name] = np.nanmin(part_values) def __call__(self, df_features): df_features["date"] = pd.to_datetime( - df_features.index.get_level_values(level="datetime").to_series().dt.date.values + df_features.index.get_level_values(level="datetime") + .to_series() + .dt.date.values ) df_features.set_index("date", append=True, drop=True, inplace=True) df_values = df_features.values @@ -55,11 +61,17 @@ def __call__(self, df_features): slice3 = df_values[:, name_val] < -3.5 df_values[:, name_val][slice0] = ( - 3.0 + (df_values[:, name_val][slice0] - 3.0) / (self.feature_vmax[name] - 3) * 0.5 + 3.0 + + (df_values[:, name_val][slice0] - 3.0) + / (self.feature_vmax[name] - 3) + * 0.5 ) df_values[:, name_val][slice1] = 3.5 df_values[:, name_val][slice2] = ( - -3.0 - (df_values[:, name_val][slice2] + 3.0) / (self.feature_vmin[name] + 3) * 0.5 + -3.0 + - (df_values[:, name_val][slice2] + 3.0) + / (self.feature_vmin[name] + 3) + * 0.5 ) df_values[:, name_val][slice3] = -3.5 idx = df_features.index.droplevel("datetime").drop_duplicates() diff --git a/examples/highfreq/workflow.py b/examples/highfreq/workflow.py index 02948c5a12..bb06598cab 100644 --- a/examples/highfreq/workflow.py +++ b/examples/highfreq/workflow.py @@ -14,11 +14,23 @@ from qlib.data.data import Cal from qlib.tests.data import GetData -from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut +from highfreq_ops import ( + get_calendar_day, + DayLast, + FFillNan, + BFillNan, + Date, + Select, + IsNull, + Cut, +) class HighfreqWorkflow: - SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None} + SPEC_CONF = { + "custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], + "expression_cache": None, + } MARKET = "all" @@ -33,7 +45,9 @@ class HighfreqWorkflow: "fit_start_time": start_time, "fit_end_time": train_end_time, "instruments": MARKET, - "infer_processors": [{"class": "HighFreqNorm", "module_path": "highfreq_processor"}], + "infer_processors": [ + {"class": "HighFreqNorm", "module_path": "highfreq_processor"} + ], } DATA_HANDLER_CONFIG1 = { 
"start_time": start_time, @@ -85,7 +99,9 @@ def _init_qlib(self): # use cn_data_1min data QLIB_INIT_CONFIG = {**HIGH_FREQ_CONFIG, **self.SPEC_CONF} provider_uri = QLIB_INIT_CONFIG.get("provider_uri") - GetData().qlib_data(target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True) + GetData().qlib_data( + target_dir=provider_uri, interval="1min", region=REG_CN, exists_skip=True + ) qlib.init(**QLIB_INIT_CONFIG) def _prepare_calender_cache(self): diff --git a/examples/hyperparameter/LightGBM/hyperparameter_360.py b/examples/hyperparameter/LightGBM/hyperparameter_360.py index 7ba28c78fe..87bfe7ff78 100644 --- a/examples/hyperparameter/LightGBM/hyperparameter_360.py +++ b/examples/hyperparameter/LightGBM/hyperparameter_360.py @@ -5,7 +5,9 @@ from qlib.tests.data import GetData from qlib.tests.config import get_dataset_config, CSI300_MARKET, DATASET_ALPHA360_CLASS -DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS) +DATASET_CONFIG = get_dataset_config( + market=CSI300_MARKET, dataset_class=DATASET_ALPHA360_CLASS +) def objective(trial): diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 2fb7c85b56..4acbd11ff2 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -18,7 +18,10 @@ from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup from qlib.model.trainer import TrainerR, TrainerRM, task_train -from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG +from qlib.tests.config import ( + CSI100_RECORD_LGB_TASK_CONFIG, + CSI100_RECORD_XGBOOST_TASK_CONFIG, +) class RollingTaskExample: @@ -36,7 +39,10 @@ def __init__( ): # TaskManager config if task_config is None: - task_config = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG] + task_config = [ + CSI100_RECORD_XGBOOST_TASK_CONFIG, + CSI100_RECORD_LGB_TASK_CONFIG, + ] mongo_conf = { "task_url": task_url, "task_db_name": task_db_name, diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index 1d602d7fe0..0185185513 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -223,13 +223,21 @@ class NestedDecisionExecutionWorkflow: def _init_qlib(self): """initialize qlib""" provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir - GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) + GetData().qlib_data( + target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True + ) provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") GetData().qlib_data( - target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True + target_dir=provider_uri_1min, + interval="1min", + region=REG_CN, + version="v2", + exists_skip=True, ) provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} - qlib.init(provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None) + qlib.init( + provider_uri=provider_uri_map, dataset_cache=None, expression_cache=None + ) def _train_model(self, model, dataset): with R.start(experiment_name=self.exp_name): @@ -290,7 +298,9 @@ def collect_data(self): "n_drop": 5, }, } - data_generator = collect_data(executor=executor_config, strategy=strategy_config, **backtest_config) + data_generator = collect_data( + executor=executor_config, 
strategy=strategy_config, **backtest_config + ) for trade_decision in data_generator: print(trade_decision) @@ -306,13 +316,17 @@ def collect_data(self): def check_diff_freq(self): self._init_qlib() exp = R.get_exp(experiment_name="backtest") - rec = next(iter(exp.list_recorders().values())) # assuming this will get the latest recorder + rec = next( + iter(exp.list_recorders().values()) + ) # assuming this will get the latest recorder for check_key in "account", "total_turnover", "total_cost": check_key = "total_cost" acc_dict = {} for freq in ["30minute", "5minute", "1day"]: - acc_dict[freq] = rec.load_object(f"portfolio_analysis/report_normal_{freq}.pkl")[check_key] + acc_dict[freq] = rec.load_object( + f"portfolio_analysis/report_normal_{freq}.pkl" + )[check_key] acc_df = pd.DataFrame(acc_dict) acc_resam = acc_df.resample("1d").last().dropna() assert (acc_resam["30minute"] == acc_resam["1day"]).all() diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index dccc56b682..90006dca8e 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -14,7 +14,10 @@ from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager -from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE +from qlib.tests.config import ( + CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, + CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, +) import pandas as pd from qlib.contrib.evaluate import backtest_daily from qlib.contrib.evaluate import risk_analysis @@ -52,7 +55,10 @@ def __init__( tasks (dict or list[dict]): a set of the task config waiting for rolling and training """ if tasks is None: - tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE] + tasks = [ + CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, + CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, + ] self.exp_name = exp_name self.task_pool = task_pool self.start_time = start_time @@ -73,7 +79,9 @@ def __init__( # TODO: support all the trainers: TrainerR, TrainerRM, DelayTrainerR raise NotImplementedError(f"This type of input is not supported") self.rolling_online_manager = OnlineManager( - RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen), + RollingStrategy( + exp_name, task_template=tasks, rolling_gen=self.rolling_gen + ), trainer=self.trainer, begin_time=self.start_time, ) @@ -113,7 +121,9 @@ def main(self): strategy=strategy_obj, ) analysis = dict() - analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_without_cost"] = risk_analysis( + report_normal["return"] - report_normal["bench"] + ) analysis["excess_return_with_cost"] = risk_analysis( report_normal["return"] - report_normal["bench"] - report_normal["cost"] ) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 6abbbfb0e8..6170780077 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -13,12 +13,22 @@ import os import fire import qlib -from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train +from qlib.model.trainer import ( + DelayTrainerR, + DelayTrainerRM, + TrainerR, + TrainerRM, + end_task_train, + task_train, +) from qlib.workflow import R 
from qlib.workflow.online.strategy import RollingStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.online.manager import OnlineManager -from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING +from qlib.tests.config import ( + CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, + CSI100_RECORD_LGB_TASK_CONFIG_ROLLING, +) from qlib.workflow.task.manage import TaskManager @@ -48,7 +58,9 @@ def __init__( self.rolling_step = rolling_step strategies = [] for task in tasks: - name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy + name_id = task["model"][ + "class" + ] # NOTE: Assumption: The model class can specify only one strategy strategies.append( RollingStrategy( name_id, @@ -59,9 +71,7 @@ def __init__( self.trainer = trainer self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer) - _ROLLING_MANAGER_PATH = ( - ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. - ) + _ROLLING_MANAGER_PATH = ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. def worker(self): # train tasks by other progress or machines for multiprocessing @@ -113,7 +123,9 @@ def add_strategy(self): print("========== add strategy ==========") strategies = [] for task in self.add_tasks: - name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy + name_id = task["model"][ + "class" + ] # NOTE: Assumption: The model class can specify only one strategy strategies.append( RollingStrategy( name_id, diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index faeec24da7..87bcca994b 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -25,7 +25,11 @@ class UpdatePredExample: def __init__( - self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task + self, + provider_uri="~/.qlib/qlib_data/cn_data", + region=REG_CN, + experiment_name="online_srv", + task_config=task, ): qlib.init(provider_uri=provider_uri, region=region) self.experiment_name = experiment_name diff --git a/examples/orderbook_data/create_dataset.py b/examples/orderbook_data/create_dataset.py index b94e76c430..9cbaca1428 100755 --- a/examples/orderbook_data/create_dataset.py +++ b/examples/orderbook_data/create_dataset.py @@ -77,11 +77,15 @@ def format_time(day, hms): hms = str(hms) if hms[0] == "1": # >=10, return ( - "-".join([day[0:4], day[4:6], day[6:8]]) + " " + ":".join([hms[:2], hms[2:4], hms[4:6] + "." + hms[6:]]) + "-".join([day[0:4], day[4:6], day[6:8]]) + + " " + + ":".join([hms[:2], hms[2:4], hms[4:6] + "." + hms[6:]]) ) else: return ( - "-".join([day[0:4], day[4:6], day[6:8]]) + " " + ":".join([hms[:1], hms[1:3], hms[3:5] + "." + hms[5:]]) + "-".join([day[0:4], day[4:6], day[6:8]]) + + " " + + ":".join([hms[:1], hms[1:3], hms[3:5] + "." + hms[5:]]) ) ## Discard the entire row if wrong data timestamp encoutered. 
@@ -102,7 +106,10 @@ def format_time(day, hms): timestamp = list(zip(list(df["date"]), list(df["time"]))) ## The cleaned timestamp # generate timestamp pd_timestamp = pd.DatetimeIndex( - [pd.Timestamp(format_time(timestamp[i][0], timestamp[i][1])) for i in range(len(df["date"]))] + [ + pd.Timestamp(format_time(timestamp[i][0], timestamp[i][1])) + for i in range(len(df["date"])) + ] ) df = df.drop(columns=["date", "time", "name", "code", "wind_code"]) # df = pd.DataFrame(data=df.to_dict("list"), index=pd_timestamp) @@ -112,7 +119,9 @@ def format_time(day, hms): if str.lower(type) == "orderqueue": ## extract ab1~ab50 df["ab"] = [ - ",".join([str(int(row["ab" + str(i + 1)])) for i in range(0, row["ab_items"])]) + ",".join( + [str(int(row["ab" + str(i + 1)])) for i in range(0, row["ab_items"])] + ) for timestamp, row in df.iterrows() ] df = df.drop(columns=["ab" + str(i) for i in range(1, 51)]) @@ -140,16 +149,33 @@ def add_one_stock_daily_data_wrapper(filepath, type, exchange_place, index, date try: if index % 100 == 0: print("index = {}, filepath = {}".format(index, filepath)) - error_index_list = add_one_stock_daily_data(filepath, type, exchange_place, arc, date) + error_index_list = add_one_stock_daily_data( + filepath, type, exchange_place, arc, date + ) if error_index_list is not None and len(error_index_list) > 0: - f = open(os.path.join(LOG_FILE_PATH, "temp_timestamp_error_{0}_{1}_{2}.txt".format(pid, date, type)), "a+") - f.write("{}, {}, {}\n".format(filepath, error_index_list, exchange_place + "_" + code)) + f = open( + os.path.join( + LOG_FILE_PATH, + "temp_timestamp_error_{0}_{1}_{2}.txt".format(pid, date, type), + ), + "a+", + ) + f.write( + "{}, {}, {}\n".format( + filepath, error_index_list, exchange_place + "_" + code + ) + ) f.close() except Exception as e: info = traceback.format_exc() print("error:" + str(e)) - f = open(os.path.join(LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, date, type)), "a+") + f = open( + os.path.join( + LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, date, type) + ), + "a+", + ) f.write("fail:" + str(filepath) + "\n" + str(e) + "\n" + str(info) + "\n") f.close() @@ -165,7 +191,9 @@ def add_data(tick_date, doc_type, stock_name_dict): return try: begin_time = time.time() - os.system(f"cp {DATABASE_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} {DATA_PATH}/") + os.system( + f"cp {DATABASE_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} {DATA_PATH}/" + ) os.system( f"tar -xvzf {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)} -C {DATA_PATH}/ {tick_date + '_' + doc_type}/SH" @@ -182,13 +210,26 @@ def add_data(tick_date, doc_type, stock_name_dict): print("tick_date={}".format(tick_date)) - temp_data_path_sh = os.path.join(DATA_PATH, tick_date + "_" + doc_type, "SH", tick_date) - temp_data_path_sz = os.path.join(DATA_PATH, tick_date + "_" + doc_type, "SZ", tick_date) - is_files_exist = {"sh": os.path.exists(temp_data_path_sh), "sz": os.path.exists(temp_data_path_sz)} + temp_data_path_sh = os.path.join( + DATA_PATH, tick_date + "_" + doc_type, "SH", tick_date + ) + temp_data_path_sz = os.path.join( + DATA_PATH, tick_date + "_" + doc_type, "SZ", tick_date + ) + is_files_exist = { + "sh": os.path.exists(temp_data_path_sh), + "sz": os.path.exists(temp_data_path_sz), + } sz_files = ( ( - set([i.split(".csv")[0] for i in os.listdir(temp_data_path_sz) if i[:2] == "30" or i[0] == "0"]) + set( + [ + i.split(".csv")[0] + for i in os.listdir(temp_data_path_sz) + if i[:2] == "30" or i[0] == "0" + ] + ) & set(stock_name_dict["SZ"]) ) if 
is_files_exist["sz"] @@ -197,7 +238,13 @@ def add_data(tick_date, doc_type, stock_name_dict): sz_file_nums = len(sz_files) if is_files_exist["sz"] else 0 sh_files = ( ( - set([i.split(".csv")[0] for i in os.listdir(temp_data_path_sh) if i[0] == "6"]) + set( + [ + i.split(".csv")[0] + for i in os.listdir(temp_data_path_sh) + if i[0] == "6" + ] + ) & set(stock_name_dict["SH"]) ) if is_files_exist["sh"] @@ -206,15 +253,24 @@ def add_data(tick_date, doc_type, stock_name_dict): sh_file_nums = len(sh_files) if is_files_exist["sh"] else 0 print("sz_file_nums:{}, sh_file_nums:{}".format(sz_file_nums, sh_file_nums)) - f = (DATA_INFO_PATH / "data_info_log_{}_{}".format(doc_type, tick_date)).open("w+") - f.write("sz:{}, sh:{}, date:{}:".format(sz_file_nums, sh_file_nums, tick_date) + "\n") + f = (DATA_INFO_PATH / "data_info_log_{}_{}".format(doc_type, tick_date)).open( + "w+" + ) + f.write( + "sz:{}, sh:{}, date:{}:".format(sz_file_nums, sh_file_nums, tick_date) + + "\n" + ) f.close() if sh_file_nums > 0: # write is not thread-safe, update may be thread-safe Parallel(n_jobs=N_JOBS)( delayed(add_one_stock_daily_data_wrapper)( - os.path.join(temp_data_path_sh, name + ".csv"), doc_type, "SH", index, tick_date + os.path.join(temp_data_path_sh, name + ".csv"), + doc_type, + "SH", + index, + tick_date, ) for index, name in enumerate(list(sh_files)) ) @@ -222,7 +278,11 @@ def add_data(tick_date, doc_type, stock_name_dict): # write is not thread-safe, update may be thread-safe Parallel(n_jobs=N_JOBS)( delayed(add_one_stock_daily_data_wrapper)( - os.path.join(temp_data_path_sz, name + ".csv"), doc_type, "SZ", index, tick_date + os.path.join(temp_data_path_sz, name + ".csv"), + doc_type, + "SZ", + index, + tick_date, ) for index, name in enumerate(list(sz_files)) ) @@ -230,14 +290,28 @@ def add_data(tick_date, doc_type, stock_name_dict): os.system(f"rm -f {DATA_PATH}/{tick_date + '_{}.tar.gz'.format(doc_type)}") os.system(f"rm -rf {DATA_PATH}/{tick_date + '_' + doc_type}") total_time = time.time() - begin_time - f = (DATA_FINISH_INFO_PATH / "data_info_finish_log_{}_{}".format(doc_type, tick_date)).open("w+") - f.write("finish: date:{}, consume_time:{}, end_time: {}".format(tick_date, total_time, time.time()) + "\n") + f = ( + DATA_FINISH_INFO_PATH + / "data_info_finish_log_{}_{}".format(doc_type, tick_date) + ).open("w+") + f.write( + "finish: date:{}, consume_time:{}, end_time: {}".format( + tick_date, total_time, time.time() + ) + + "\n" + ) f.close() except Exception as e: info = traceback.format_exc() print("date error:" + str(e)) - f = open(os.path.join(LOG_FILE_PATH, "temp_fail_{0}_{1}_{2}.txt".format(pid, tick_date, doc_type)), "a+") + f = open( + os.path.join( + LOG_FILE_PATH, + "temp_fail_{0}_{1}_{2}.txt".format(pid, tick_date, doc_type), + ), + "a+", + ) f.write("fail:" + str(tick_date) + "\n" + str(e) + "\n" + str(info) + "\n") f.close() @@ -273,7 +347,15 @@ def import_data(self, doc_type_l=["Tick", "Transaction", "Order"]): # doc_type = 'Day' for doc_type in doc_type_l: - date_list = list(set([int(path.split("_")[0]) for path in os.listdir(DATABASE_PATH) if doc_type in path])) + date_list = list( + set( + [ + int(path.split("_")[0]) + for path in os.listdir(DATABASE_PATH) + if doc_type in path + ] + ) + ) date_list.sort() date_list = [str(date) for date in date_list] @@ -281,8 +363,16 @@ def import_data(self, doc_type_l=["Tick", "Transaction", "Order"]): stock_name_list = [lines.split("\t")[0] for lines in f.readlines()] f.close() stock_name_dict = { - "SH": [stock_name[2:] for stock_name in 
stock_name_list if "SH" in stock_name], - "SZ": [stock_name[2:] for stock_name in stock_name_list if "SZ" in stock_name], + "SH": [ + stock_name[2:] + for stock_name in stock_name_list + if "SH" in stock_name + ], + "SZ": [ + stock_name[2:] + for stock_name in stock_name_list + if "SZ" in stock_name + ], } lib_name = get_library_name(doc_type) diff --git a/examples/orderbook_data/example.py b/examples/orderbook_data/example.py index f8bd84ea78..2febb680ec 100644 --- a/examples/orderbook_data/example.py +++ b/examples/orderbook_data/example.py @@ -24,7 +24,10 @@ def setUp(self): mem_cache_size_limit=1024**3 * 2, mem_cache_type="sizeof", kernels=1, - expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}}, + expression_provider={ + "class": "LocalExpressionProvider", + "kwargs": {"time2idx": False}, + }, feature_provider={ "class": "ArcticFeatureProvider", "module_path": "qlib.contrib.data.data", @@ -87,7 +90,9 @@ def test_basic03(self): # Here are some popular expressions for high-frequency # 1) some shared expression - expr_sum_buy_ask_1 = "(TResample($ask1, '1min', 'last') + TResample($bid1, '1min', 'last'))" + expr_sum_buy_ask_1 = ( + "(TResample($ask1, '1min', 'last') + TResample($bid1, '1min', 'last'))" + ) total_volume = ( "TResample(" + "+".join([f"${name}{i}" for i in range(1, 11) for name in ["asize", "bsize"]]) @@ -96,14 +101,20 @@ def test_basic03(self): @staticmethod def total_func(name, method): - return "TResample(" + "+".join([f"${name}{i}" for i in range(1, 11)]) + ",'1min', '{}')".format(method) + return ( + "TResample(" + + "+".join([f"${name}{i}" for i in range(1, 11)]) + + ",'1min', '{}')".format(method) + ) def test_exp_01(self): exprs = [] names = [] for name in ["asize", "bsize"]: for i in range(1, 11): - exprs.append(f"TResample(${name}{i}, '1min', 'mean') / ({self.total_volume})") + exprs.append( + f"TResample(${name}{i}, '1min', 'mean') / ({self.total_volume})" + ) names.append(f"v_{name}_{i}") df = D.features(self.stocks_list, fields=exprs, freq="ticks") df.columns = names @@ -145,7 +156,9 @@ def test_exp_04(self): exprs = [] names = [] for name in ["asize", "bsize"]: - exprs.append(f"(({ self.total_func(name, 'mean')}) / 10) / {self.total_volume}") + exprs.append( + f"(({ self.total_func(name, 'mean')}) / 10) / {self.total_volume}" + ) names.append(f"v_avg_{name}") df = D.features(self.stocks_list, fields=exprs, freq="ticks") @@ -180,7 +193,9 @@ def test_exp_06(self): for i in range(1, 11): for name in ["asize", "bsize"]: - exprs.append(f"TResample({expr6_price_func(name, i, 'mean')}, '1min', 'mean') / {self.total_volume}") + exprs.append( + f"TResample({expr6_price_func(name, i, 'mean')}, '1min', 'mean') / {self.total_volume}" + ) names.append(f"v_diff_{name}{i}_{t}s") df = D.features(self.stocks_list, fields=exprs, freq="ticks") @@ -227,7 +242,11 @@ def test_exp_07_2(self): for funccode in ["B", "S"]: for ordercode in ["0", "1"]: exprs.append(expr7(funccode, ordercode, "3")) - names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_intensity_3s") + names.append( + self.trans_dict[ordercode] + + self.trans_dict[funccode] + + "_intensity_3s" + ) df = D.features(self.stocks_list, fields=exprs, freq="transaction") df.columns = names print(df) @@ -248,7 +267,11 @@ def test_exp_08_1(self): for funccode in ["B", "S"]: for ordercode in ["0", "1"]: exprs.append(expr8_1(funccode, ordercode, "10", "900")) - names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_relative_intensity_10s_900s") + names.append( + 
self.trans_dict[ordercode] + + self.trans_dict[funccode] + + "_relative_intensity_10s_900s" + ) df = D.features(self.stocks_list, fields=exprs, freq="order") df.columns = names @@ -290,7 +313,11 @@ def test_exp_09_order(self): exprs.append( f'TResample(Div(Sub(TResample({self.expr7_init(funccode, ordercode, "3")}, "3s", "last"), Ref(TResample({self.expr7_init(funccode, ordercode, "3")},"3s", "last"), 1)), 3) ,"1min", "mean")' ) - names.append(self.trans_dict[ordercode] + self.trans_dict[funccode] + "_diff_intensity_3s_3s") + names.append( + self.trans_dict[ordercode] + + self.trans_dict[funccode] + + "_diff_intensity_3s_3s" + ) df = D.features(self.stocks_list, fields=exprs, freq="order") df.columns = names print(df) diff --git a/examples/portfolio/prepare_riskdata.py b/examples/portfolio/prepare_riskdata.py index e502a1ff78..41e08d85b6 100644 --- a/examples/portfolio/prepare_riskdata.py +++ b/examples/portfolio/prepare_riskdata.py @@ -9,10 +9,16 @@ def prepare_data(riskdata_root="./riskdata", T=240, start_time="2016-01-01"): - universe = D.features(D.instruments("csi300"), ["$close"], start_time=start_time).swaplevel().sort_index() + universe = ( + D.features(D.instruments("csi300"), ["$close"], start_time=start_time) + .swaplevel() + .sort_index() + ) price_all = ( - D.features(D.instruments("all"), ["$close"], start_time=start_time).squeeze().unstack(level="instrument") + D.features(D.instruments("all"), ["$close"], start_time=start_time) + .squeeze() + .unstack(level="instrument") ) # StructuredCovEstimator is a statistical risk model @@ -32,7 +38,9 @@ def prepare_data(riskdata_root="./riskdata", T=240, start_time="2016-01-01"): ret.clip(ret.quantile(0.025), ret.quantile(0.975), axis=1, inplace=True) # run risk model - F, cov_b, var_u = riskmodel.predict(ret, is_price=False, return_decomposed_components=True) + F, cov_b, var_u = riskmodel.predict( + ret, is_price=False, return_decomposed_components=True + ) # save risk data root = riskdata_root + "/" + date.strftime("%Y%m%d") diff --git a/examples/rl_order_execution/scripts/gen_pickle_data.py b/examples/rl_order_execution/scripts/gen_pickle_data.py index 75810bddc0..9782de2af3 100755 --- a/examples/rl_order_execution/scripts/gen_pickle_data.py +++ b/examples/rl_order_execution/scripts/gen_pickle_data.py @@ -15,7 +15,13 @@ parser = argparse.ArgumentParser() parser.add_argument("-c", "--config", type=str, default="config.yml") parser.add_argument("-d", "--dest", type=str, default=".") - parser.add_argument("-s", "--split", type=str, choices=["none", "date", "stock", "both"], default="stock") + parser.add_argument( + "-s", + "--split", + type=str, + choices=["none", "date", "stock", "both"], + default="stock", + ) args = parser.parse_args() conf = yaml.load(open(args.config), Loader=loader) @@ -31,8 +37,12 @@ if "backtest_conf" in conf: backtest = provider._gen_dataframe(deepcopy(provider.backtest_conf)) - provider.feature_conf["path"] = os.path.splitext(provider.feature_conf["path"])[0] + "/" - provider.backtest_conf["path"] = os.path.splitext(provider.backtest_conf["path"])[0] + "/" + provider.feature_conf["path"] = ( + os.path.splitext(provider.feature_conf["path"])[0] + "/" + ) + provider.backtest_conf["path"] = ( + os.path.splitext(provider.backtest_conf["path"])[0] + "/" + ) # Split by date if args.split == "date" or args.split == "both": provider._gen_day_dataset(deepcopy(provider.feature_conf), "feature") diff --git a/examples/rl_order_execution/scripts/gen_training_orders.py 
b/examples/rl_order_execution/scripts/gen_training_orders.py index b03ce6e5a8..7bf76b0f8a 100755 --- a/examples/rl_order_execution/scripts/gen_training_orders.py +++ b/examples/rl_order_execution/scripts/gen_training_orders.py @@ -19,7 +19,11 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool: df["date"] = df["datetime"].dt.date.astype("datetime64") df = df.set_index(["instrument", "datetime", "date"]) - df = df.groupby("date", group_keys=False).take(range(start_idx, end_idx)).droplevel(level=0) + df = ( + df.groupby("date", group_keys=False) + .take(range(start_idx, end_idx)) + .droplevel(level=0) + ) order_all = pd.DataFrame(df.groupby(level=(2, 0), group_keys=False).mean().dropna()) order_all["amount"] = np.random.lognormal(-3.28, 1.14) * order_all["$volume0"] @@ -27,12 +31,23 @@ def generate_order(stock: str, start_idx: int, end_idx: int) -> bool: order_all["order_type"] = 0 order_all = order_all.drop(columns=["$volume0"]) - order_train = order_all[order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30")] - order_test = order_all[order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30")] - order_valid = order_test[order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30")] - order_test = order_test[order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30")] + order_train = order_all[ + order_all.index.get_level_values(0) <= pd.Timestamp("2021-06-30") + ] + order_test = order_all[ + order_all.index.get_level_values(0) > pd.Timestamp("2021-06-30") + ] + order_valid = order_test[ + order_test.index.get_level_values(0) <= pd.Timestamp("2021-09-30") + ] + order_test = order_test[ + order_test.index.get_level_values(0) > pd.Timestamp("2021-09-30") + ] - for order, tag in zip((order_train, order_valid, order_test, order_all), ("train", "valid", "test", "all")): + for order, tag in zip( + (order_train, order_valid, order_test, order_all), + ("train", "valid", "test", "all"), + ): path = OUTPUT_PATH / tag os.makedirs(path, exist_ok=True) if len(order) > 0: diff --git a/examples/rl_order_execution/scripts/merge_orders.py b/examples/rl_order_execution/scripts/merge_orders.py index 64a684e07b..4edf451912 100755 --- a/examples/rl_order_execution/scripts/merge_orders.py +++ b/examples/rl_order_execution/scripts/merge_orders.py @@ -12,4 +12,6 @@ dfs.append(df) total_df = pd.concat(dfs) - pickle.dump(total_df, open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb")) + pickle.dump( + total_df, open(os.path.join("data", "orders", f"{tag}_orders.pkl"), "wb") + ) diff --git a/examples/rolling_process_data/rolling_handler.py b/examples/rolling_process_data/rolling_handler.py index 13b399afd8..d5dffcc73c 100644 --- a/examples/rolling_process_data/rolling_handler.py +++ b/examples/rolling_process_data/rolling_handler.py @@ -14,8 +14,12 @@ def __init__( fit_end_time=None, data_loader_kwargs={}, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "DataLoaderDH", diff --git a/examples/rolling_process_data/workflow.py b/examples/rolling_process_data/workflow.py index d1c03866a4..3051158b7f 100644 --- a/examples/rolling_process_data/workflow.py +++ b/examples/rolling_process_data/workflow.py @@ -70,11 +70,17 @@ 
def rolling_process(self): "fit_start_time": datetime(*train_start_time), "fit_end_time": datetime(*train_end_time), "infer_processors": [ - {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}}, + { + "class": "RobustZScoreNorm", + "kwargs": {"fields_group": "feature"}, + }, ], "learn_processors": [ {"class": "DropnaLabel"}, - {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, + { + "class": "CSZScoreNorm", + "kwargs": {"fields_group": "label"}, + }, ], "data_loader_kwargs": { "handler_config": pre_handler, @@ -96,25 +102,49 @@ def rolling_process(self): if rolling_offset: dataset.config( handler_kwargs={ - "start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), - "end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), + "start_time": datetime( + train_start_time[0] + rolling_offset, *train_start_time[1:] + ), + "end_time": datetime( + test_end_time[0] + rolling_offset, *test_end_time[1:] + ), "processor_kwargs": { - "fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), - "fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), + "fit_start_time": datetime( + train_start_time[0] + rolling_offset, + *train_start_time[1:], + ), + "fit_end_time": datetime( + train_end_time[0] + rolling_offset, *train_end_time[1:] + ), }, }, segments={ "train": ( - datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]), - datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]), + datetime( + train_start_time[0] + rolling_offset, + *train_start_time[1:], + ), + datetime( + train_end_time[0] + rolling_offset, *train_end_time[1:] + ), ), "valid": ( - datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]), - datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]), + datetime( + valid_start_time[0] + rolling_offset, + *valid_start_time[1:], + ), + datetime( + valid_end_time[0] + rolling_offset, *valid_end_time[1:] + ), ), "test": ( - datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]), - datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]), + datetime( + test_start_time[0] + rolling_offset, + *test_start_time[1:], + ), + datetime( + test_end_time[0] + rolling_offset, *test_end_time[1:] + ), ), }, ) diff --git a/examples/run_all_model.py b/examples/run_all_model.py index 70571556b1..53220284f6 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -35,7 +35,10 @@ def _return_wrapped(*args, **kwargs): valid_names.remove("self") for arg_name in kwargs: if arg_name not in valid_names: - raise ValueError("Unknown argument seen '%s', expected: [%s]" % (arg_name, ", ".join(valid_names))) + raise ValueError( + "Unknown argument seen '%s', expected: [%s]" + % (arg_name, ", ".join(valid_names)) + ) return function_to_decorate(*args, **kwargs) return _return_wrapped @@ -55,8 +58,16 @@ def cal_mean_std(results) -> dict: for fn in results: mean_std[fn] = dict() for metric in results[fn]: - mean = statistics.mean(results[fn][metric]) if len(results[fn][metric]) > 1 else results[fn][metric][0] - std = statistics.stdev(results[fn][metric]) if len(results[fn][metric]) > 1 else 0 + mean = ( + statistics.mean(results[fn][metric]) + if len(results[fn][metric]) > 1 + else results[fn][metric][0] + ) + std = ( + statistics.stdev(results[fn][metric]) + if len(results[fn][metric]) > 1 + else 0 + ) mean_std[fn][metric] = [mean, std] return mean_std @@ -71,14 +82,18 @@ def create_env(): python_path = env_path 
/ "bin" / "python" # TODO: FIX ME! sys.stderr.write("\n") # get anaconda activate path - conda_activate = Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" # TODO: FIX ME! + conda_activate = ( + Path(os.environ["CONDA_PREFIX"]) / "bin" / "activate" + ) # TODO: FIX ME! return temp_dir, env_path, python_path, conda_activate # function to execute the cmd def execute(cmd, wait_when_err=False, raise_err=True): print("Running CMD:", cmd) - with subprocess.Popen(cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True) as p: + with subprocess.Popen( + cmd, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True, shell=True + ) as p: for line in p.stdout: sys.stdout.write(line.split("\b")[0]) if "\b" in line: @@ -107,7 +122,9 @@ def get_all_folders(models, exclude) -> dict: elif models is None: models = [f.name.lower() for f in os.scandir("benchmarks")] else: - raise ValueError("Input models type is not supported. Please provide str or list without space.") + raise ValueError( + "Input models type is not supported. Please provide str or list without space." + ) for f in os.scandir("benchmarks"): add = xor(bool(f.name.lower() in models), bool(exclude)) if add: @@ -155,9 +172,15 @@ def get_all_results(folders) -> dict: if "1day.excess_return_with_cost.annualized_return" not in metrics: print(f"{recorder_id} is skipped due to incomplete result") continue - result["annualized_return_with_cost"].append(metrics["1day.excess_return_with_cost.annualized_return"]) - result["information_ratio_with_cost"].append(metrics["1day.excess_return_with_cost.information_ratio"]) - result["max_drawdown_with_cost"].append(metrics["1day.excess_return_with_cost.max_drawdown"]) + result["annualized_return_with_cost"].append( + metrics["1day.excess_return_with_cost.annualized_return"] + ) + result["information_ratio_with_cost"].append( + metrics["1day.excess_return_with_cost.information_ratio"] + ) + result["max_drawdown_with_cost"].append( + metrics["1day.excess_return_with_cost.max_drawdown"] + ) result["ic"].append(metrics["IC"]) result["icir"].append(metrics["ICIR"]) result["rank_ic"].append(metrics["Rank IC"]) @@ -324,14 +347,18 @@ def run( if "torch" in content: # automatically install pytorch according to nvidia's version execute( - f"{python_path} -m pip install light-the-torch", wait_when_err=wait_when_err + f"{python_path} -m pip install light-the-torch", + wait_when_err=wait_when_err, ) # for automatically installing torch according to the nvidia driver execute( f"{env_path / 'bin' / 'ltt'} install --install-cmd '{python_path} -m pip install {{packages}}' -- -r {req_path}", wait_when_err=wait_when_err, ) else: - execute(f"{python_path} -m pip install -r {req_path}", wait_when_err=wait_when_err) + execute( + f"{python_path} -m pip install -r {req_path}", + wait_when_err=wait_when_err, + ) sys.stderr.write("\n") # read yaml, remove seed kwargs of model, and then save file in the temp_dir @@ -345,8 +372,14 @@ def run( sys.stderr.write("\n") # install qlib sys.stderr.write("Installing qlib...\n") - execute(f"{python_path} -m pip install --upgrade pip", wait_when_err=wait_when_err) # TODO: FIX ME! - execute(f"{python_path} -m pip install --upgrade cython", wait_when_err=wait_when_err) # TODO: FIX ME! + execute( + f"{python_path} -m pip install --upgrade pip", + wait_when_err=wait_when_err, + ) # TODO: FIX ME! + execute( + f"{python_path} -m pip install --upgrade cython", + wait_when_err=wait_when_err, + ) # TODO: FIX ME! 
if fn == "TFT": execute( f"cd {env_path} && {python_path} -m pip install --upgrade --force-reinstall --ignore-installed PyYAML -e {qlib_uri}", @@ -395,8 +428,15 @@ def _collect_results(self, exp_folder_name, dataset): sys.stderr.write("\n") sys.stderr.write("\n") # move results folder - shutil.move(exp_folder_name, exp_folder_name + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}") - shutil.move("table.md", f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md") + shutil.move( + exp_folder_name, + exp_folder_name + + f"_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}", + ) + shutil.move( + "table.md", + f"table_{dataset}_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}.md", + ) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index fe48d090c0..2d43a9a98a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ dependencies = [ "pymongo", "loguru", "lightgbm", - "gym", + "gymnasium", "cvxpy", "joblib", "matplotlib", diff --git a/qlib/__init__.py b/qlib/__init__.py index 687e317ced..22e7893037 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -66,7 +66,9 @@ def init(default_conf="client", **kwargs): f"Invalid provider uri: {provider_uri}, please check if a valid provider uri has been set. This path does not exist." ) else: - logger.warning(f"auto_path is False, please make sure {mount_path} is mounted") + logger.warning( + f"auto_path is False, please make sure {mount_path} is mounted" + ) elif uri_type == C.NFS_URI: _mount_nfs_uri(provider_uri, C.dpm.get_data_uri(_freq), C["auto_mount"]) else: @@ -77,7 +79,9 @@ def init(default_conf="client", **kwargs): if "flask_server" in C: logger.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") logger.info("qlib successfully initialized based on %s settings." % default_conf) - data_path = {_freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys()} + data_path = { + _freq: C.dpm.get_data_uri(_freq) for _freq in C.dpm.provider_uri.keys() + } logger.info(f"data_path={data_path}") @@ -123,7 +127,9 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): else: # system: linux/Unix/Mac # check mount - _remote_uri = provider_uri[:-1] if provider_uri.endswith("/") else provider_uri + _remote_uri = ( + provider_uri[:-1] if provider_uri.endswith("/") else provider_uri + ) # `mount a /b/c` is different from `mount a /b/c/`. 
So we convert it into string to make sure handling it accurately mount_path = str(mount_path) _mount_path = mount_path[:-1] if mount_path.endswith("/") else mount_path @@ -137,14 +143,20 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): stderr=subprocess.STDOUT, ) as shell_r: _command_log = shell_r.stdout.readlines() - _command_log = [line for line in _command_log if _remote_uri in line] + _command_log = [ + line for line in _command_log if _remote_uri in line + ] if len(_command_log) > 0: for _c in _command_log: if isinstance(_c, str): _temp_mount = _c.split(" ")[2] else: _temp_mount = _c.decode("utf-8").split(" ")[2] - _temp_mount = _temp_mount[:-1] if _temp_mount.endswith("/") else _temp_mount + _temp_mount = ( + _temp_mount[:-1] + if _temp_mount.endswith("/") + else _temp_mount + ) if _temp_mount == _mount_path: _is_mount = True break @@ -166,16 +178,24 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False): command_res = os.popen("dpkg -l | grep nfs-common") command_res = command_res.readlines() if not command_res: - raise OSError("nfs-common is not found, please install it by execute: sudo apt install nfs-common") + raise OSError( + "nfs-common is not found, please install it by execute: sudo apt install nfs-common" + ) # manually mount try: - subprocess.run(mount_command, check=True, capture_output=True, text=True) + subprocess.run( + mount_command, check=True, capture_output=True, text=True + ) LOG.info("Mount finished.") except subprocess.CalledProcessError as e: if e.returncode == 256: - raise OSError("Mount failed: requires sudo or permission denied") from e + raise OSError( + "Mount failed: requires sudo or permission denied" + ) from e elif e.returncode == 32512: - raise OSError(f"mount {provider_uri} on {mount_path} error! Command error") from e + raise OSError( + f"mount {provider_uri} on {mount_path} error! Command error" + ) from e else: raise OSError(f"Mount failed: {e.stderr}") from e else: @@ -199,7 +219,9 @@ def init_from_yaml_conf(conf_path, **kwargs): init(default_conf, **config) -def get_project_path(config_name="config.yaml", cur_path: Union[Path, str, None] = None) -> Path: +def get_project_path( + config_name="config.yaml", cur_path: Union[Path, str, None] = None +) -> Path: """ If users are building a project follow the following pattern. 
- Qlib is a sub folder in project path @@ -307,7 +329,9 @@ def auto_init(**kwargs): qlib_conf_update = conf.get("qlib_cfg_update", {}) for k, v in kwargs.items(): if k in qlib_conf_update: - logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'") + logger.warning( + f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'" + ) qlib_conf_update.update(kwargs) init_from_yaml_conf(qlib_conf_path, **qlib_conf_update) diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index 9daba91153..038174b908 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -205,7 +205,9 @@ def get_strategy_executor( exchange_kwargs["end_time"] = end_time trade_exchange = get_exchange(**exchange_kwargs) - common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange) + common_infra = CommonInfrastructure( + trade_account=trade_account, trade_exchange=trade_exchange + ) trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy) trade_strategy.reset_common_infra(common_infra) trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor) @@ -306,7 +308,9 @@ def collect_data( exchange_kwargs, pos_type=pos_type, ) - yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value=return_value) + yield from collect_data_loop( + start_time, end_time, trade_strategy, trade_executor, return_value=return_value + ) def format_decisions( @@ -340,9 +344,16 @@ def format_decisions( last_dec_idx = 0 for i, dec in enumerate(decisions[1:], 1): if dec.strategy.trade_calendar.get_freq() == cur_freq: - res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 : i]))) + res[1].append( + ( + decisions[last_dec_idx], + format_decisions(decisions[last_dec_idx + 1 : i]), + ) + ) last_dec_idx = i - res[1].append((decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :]))) + res[1].append( + (decisions[last_dec_idx], format_decisions(decisions[last_dec_idx + 1 :])) + ) return res diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index b0e416f8f4..4dd0294b9a 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -108,7 +108,9 @@ def __init__( self.benchmark_config: dict = {} # avoid no attribute error self.init_vars(init_cash, position_dict, freq, benchmark_config) - def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None: + def init_vars( + self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict + ) -> None: # 1) the following variables are shared by multiple layers # - you will see a shallow copy instead of deepcopy in the NestedExecutor; self.init_cash = init_cash @@ -146,14 +148,22 @@ def reset_report(self, freq: str, benchmark_config: dict) -> None: # fill stock value # The frequency of account may not align with the trading frequency. # This may result in obscure bugs when data quality is low. - if isinstance(self.benchmark_config, dict) and "start_time" in self.benchmark_config: - self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq) + if ( + isinstance(self.benchmark_config, dict) + and "start_time" in self.benchmark_config + ): + self.current_position.fill_stock_value( + self.benchmark_config["start_time"], self.freq + ) # trading related metrics(e.g. 
high-frequency trading) self.indicator = Indicator() def reset( - self, freq: str | None = None, benchmark_config: dict | None = None, port_metr_enabled: bool | None = None + self, + freq: str | None = None, + benchmark_config: dict | None = None, + port_metr_enabled: bool | None = None, ) -> None: """reset freq and report of account @@ -180,7 +190,9 @@ def get_hist_positions(self) -> Dict[pd.Timestamp, BasePosition]: def get_cash(self) -> float: return self.current_position.get_cash() - def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: + def _update_state_from_order( + self, order: Order, trade_val: float, cost: float, trade_price: float + ) -> None: if self.is_port_metr_enabled(): # update turnover self.accum_info.add_turnover(trade_val) @@ -191,16 +203,29 @@ def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_amount = trade_val / trade_price if order.direction == Order.SELL: # 0 for sell # when sell stock, get profit from price change - profit = trade_val - self.current_position.get_stock_price(order.stock_id) * trade_amount - self.accum_info.add_return_value(profit) # note here do not consider cost + profit = ( + trade_val + - self.current_position.get_stock_price(order.stock_id) + * trade_amount + ) + self.accum_info.add_return_value( + profit + ) # note here do not consider cost elif order.direction == Order.BUY: # 1 for buy # when buy stock, we get return for the rtn computing method # profit in buy order is to make rtn is consistent with earning at the end of bar - profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val - self.accum_info.add_return_value(profit) # note here do not consider cost - - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: + profit = ( + self.current_position.get_stock_price(order.stock_id) * trade_amount + - trade_val + ) + self.accum_info.add_return_value( + profit + ) # note here do not consider cost + + def update_order( + self, order: Order, trade_val: float, cost: float, trade_price: float + ) -> None: if self.current_position.skip_update(): # TODO: supporting polymorphism for account # updating order for infinite position is meaningless @@ -239,15 +264,22 @@ def update_current_position( stock_list = self.current_position.get_stock_list() for code in stock_list: # if suspended, no new price to be updated, profit is 0 - if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): + if trade_exchange.check_stock_suspended( + code, trade_start_time, trade_end_time + ): continue - bar_close = cast(float, trade_exchange.get_close(code, trade_start_time, trade_end_time)) + bar_close = cast( + float, + trade_exchange.get_close(code, trade_start_time, trade_end_time), + ) self.current_position.update_stock_price(stock_id=code, price=bar_close) # update holding day count # NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy self.current_position.add_count_all(bar=self.freq) - def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None: + def update_portfolio_metrics( + self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp + ) -> None: """update portfolio_metrics""" # calculate earning # account_value - last_account_value @@ -330,7 +362,9 @@ def update_indicator( ) # aggregate all the order metrics a single step - self.indicator.cal_trade_indicators(trade_start_time, 
self.freq, indicator_config) + self.indicator.cal_trade_indicators( + trade_start_time, self.freq, indicator_config + ) # record the metrics self.indicator.record(trade_start_time) @@ -380,7 +414,9 @@ def update_bar_end( if atomic is True and trade_info is None: raise ValueError("trade_info is necessary in atomic executor") elif atomic is False and inner_order_indicators is None: - raise ValueError("inner_order_indicators is necessary in un-atomic executor") + raise ValueError( + "inner_order_indicators is necessary in un-atomic executor" + ) # update current position and hold bar count in each bar end self.update_current_position(trade_start_time, trade_end_time, trade_exchange) @@ -406,11 +442,15 @@ def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]: """get the history portfolio_metrics and positions instance""" if self.is_port_metr_enabled(): assert self.portfolio_metrics is not None - _portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe() + _portfolio_metrics = ( + self.portfolio_metrics.generate_portfolio_metrics_dataframe() + ) _positions = self.get_hist_positions() return _portfolio_metrics, _positions else: - raise ValueError("generate_portfolio_metrics should be True if you want to generate portfolio_metrics") + raise ValueError( + "generate_portfolio_metrics should be True if you want to generate portfolio_metrics" + ) def get_trade_indicator(self) -> Indicator: """get the trade indicator instance, which has pa/pos/ffr info.""" diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 5e5edacafd..4d4f060875 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -41,7 +41,9 @@ def backtest_loop( it computes the trading indicator """ return_value: dict = {} - for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value): + for _decision in collect_data_loop( + start_time, end_time, trade_strategy, trade_executor, return_value + ): pass portfolio_dict = cast(PORT_METRIC, return_value.get("portfolio_dict")) @@ -83,11 +85,17 @@ def collect_data_loop( trade_executor.reset(start_time=start_time, end_time=end_time) trade_strategy.reset(level_infra=trade_executor.get_level_infra()) - with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar: + with tqdm( + total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop" + ) as bar: _execute_result = None while not trade_executor.finished(): - _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result) - _execute_result = yield from trade_executor.collect_data(_trade_decision, level=0) + _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision( + _execute_result + ) + _execute_result = yield from trade_executor.collect_data( + _trade_decision, level=0 + ) trade_strategy.post_exe_step(_execute_result) bar.update(1) trade_strategy.post_upper_level_exe_step() @@ -103,8 +111,12 @@ def collect_data_loop( if executor.trade_account.is_port_metr_enabled(): portfolio_dict[key] = executor.trade_account.get_portfolio_metrics() - indicator_df = executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + indicator_df = ( + executor.trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + ) indicator_obj = executor.trade_account.get_trade_indicator() indicator_dict[key] = (indicator_df, indicator_obj) - return_value.update({"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict}) + 
return_value.update( + {"portfolio_dict": portfolio_dict, "indicator_dict": indicator_dict} + ) diff --git a/qlib/backtest/decision.py b/qlib/backtest/decision.py index 7188bec7a5..311eb4e4e7 100644 --- a/qlib/backtest/decision.py +++ b/qlib/backtest/decision.py @@ -8,7 +8,18 @@ from enum import IntEnum # try to fix circular imports when enabling type hints -from typing import TYPE_CHECKING, Any, ClassVar, Generic, List, Optional, Tuple, TypeVar, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + cast, +) from qlib.backtest.utils import TradeCalendarManager from qlib.data.data import Cal @@ -82,7 +93,9 @@ class Order: def __post_init__(self) -> None: if self.direction not in {Order.SELL, Order.BUY}: - raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") + raise NotImplementedError( + "direction not supported, `Order.SELL` for sell, `Order.BUY` for buy" + ) self.deal_amount = 0.0 self.factor = None @@ -114,7 +127,9 @@ def sign(self) -> int: return self.direction * 2 - 1 @staticmethod - def parse_dir(direction: Union[str, int, np.integer, OrderDir, np.ndarray]) -> Union[OrderDir, np.ndarray]: + def parse_dir( + direction: Union[str, int, np.integer, OrderDir, np.ndarray], + ) -> Union[OrderDir, np.ndarray]: if isinstance(direction, OrderDir): return direction elif isinstance(direction, (int, float, np.integer, np.floating)): @@ -232,7 +247,9 @@ def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]: raise NotImplementedError(f"Please implement the `__call__` method") @abstractmethod - def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: + def clip_time_range( + self, start_time: pd.Timestamp, end_time: pd.Timestamp + ) -> Tuple[pd.Timestamp, pd.Timestamp]: """ Parameters ---------- @@ -254,10 +271,14 @@ def __init__(self, start_idx: int, end_idx: int) -> None: self._start_idx = start_idx self._end_idx = end_idx - def __call__(self, trade_calendar: TradeCalendarManager | None = None) -> Tuple[int, int]: + def __call__( + self, trade_calendar: TradeCalendarManager | None = None + ) -> Tuple[int, int]: return self._start_idx, self._end_idx - def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: + def clip_time_range( + self, start_time: pd.Timestamp, end_time: pd.Timestamp + ) -> Tuple[pd.Timestamp, pd.Timestamp]: raise NotImplementedError @@ -279,21 +300,35 @@ def __init__(self, start_time: str | time, end_time: str | time) -> None: end_time : str | time e.g. "14:30" """ - self.start_time = pd.Timestamp(start_time).time() if isinstance(start_time, str) else start_time - self.end_time = pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time + self.start_time = ( + pd.Timestamp(start_time).time() + if isinstance(start_time, str) + else start_time + ) + self.end_time = ( + pd.Timestamp(end_time).time() if isinstance(end_time, str) else end_time + ) assert self.start_time < self.end_time def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]: if trade_calendar is None: - raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.") + raise NotImplementedError( + "trade_calendar is necessary for getting TradeRangeByTime." 
+ ) start_date = trade_calendar.start_time.date() - val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time) + val_start, val_end = concat_date_time( + start_date, self.start_time + ), concat_date_time(start_date, self.end_time) return trade_calendar.get_range_idx(val_start, val_end) - def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]: + def clip_time_range( + self, start_time: pd.Timestamp, end_time: pd.Timestamp + ) -> Tuple[pd.Timestamp, pd.Timestamp]: start_date = start_time.date() - val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time) + val_start, val_end = concat_date_time( + start_date, self.start_time + ), concat_date_time(start_date, self.end_time) # NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day # Assumption: start_time and end_time is for intra-day trading. So it is OK for only using start_date return max(val_start, start_time), min(val_end, end_time) @@ -315,7 +350,11 @@ class BaseTradeDecision(Generic[DecisionType]): 2. Same as `case 1.3` """ - def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange, None] = None) -> None: + def __init__( + self, + strategy: BaseStrategy, + trade_range: Union[Tuple[int, int], TradeRange, None] = None, + ) -> None: """ Parameters ---------- @@ -358,7 +397,9 @@ def get_decision(self) -> List[DecisionType]: """ raise NotImplementedError(f"This type of input is not supported") - def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]: + def update( + self, trade_calendar: TradeCalendarManager + ) -> Optional[BaseTradeDecision]: """ Be called at the **start** of each step. 
@@ -384,7 +425,9 @@ def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDeci def _get_range_limit(self, **kwargs: Any) -> Tuple[int, int]: if self.trade_range is not None: - return self.trade_range(trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar"))) + return self.trade_range( + trade_calendar=cast(TradeCalendarManager, kwargs.get("inner_calendar")) + ) else: raise NotImplementedError("The decision didn't provide an index range") @@ -434,7 +477,9 @@ def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]: return kwargs["default_value"] else: # Default to get full index - raise NotImplementedError(f"The decision didn't provide an index range") from e + raise NotImplementedError( + f"The decision didn't provide an index range" + ) from e # clip index if getattr(self, "total_step", None) is not None: @@ -446,10 +491,14 @@ def get_range_limit(self, **kwargs: Any) -> Tuple[int, int]: logger.warning( f"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.", ) - _start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx) + _start_idx, _end_idx = max(0, _start_idx), min( + self.total_step - 1, _end_idx + ) return _start_idx, _end_idx - def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = False) -> Tuple[int, int]: + def get_data_cal_range_limit( + self, rtype: str = "full", raise_error: bool = False + ) -> Tuple[int, int]: """ get the range limit based on data calendar @@ -489,7 +538,9 @@ def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = Fals day_start = pd.Timestamp(self.start_time.date()) day_end = epsilon_change(day_start + pd.Timedelta(days=1)) freq = self.strategy.trade_exchange.freq - _, _, day_start_idx, day_end_idx = Cal.locate_index(day_start, day_end, freq=freq) + _, _, day_start_idx, day_end_idx = Cal.locate_index( + day_start, day_end, freq=freq + ) if self.trade_range is None: if raise_error: raise NotImplementedError(f"There is no trade_range in this case") @@ -497,9 +548,13 @@ def get_data_cal_range_limit(self, rtype: str = "full", raise_error: bool = Fals return 0, day_end_idx - day_start_idx else: if rtype == "full": - val_start, val_end = self.trade_range.clip_time_range(day_start, day_end) + val_start, val_end = self.trade_range.clip_time_range( + day_start, day_end + ) elif rtype == "step": - val_start, val_end = self.trade_range.clip_time_range(self.start_time, self.end_time) + val_start, val_end = self.trade_range.clip_time_range( + self.start_time, self.end_time + ) else: raise ValueError(f"This type of input {rtype} is not supported") _, _, start_idx, end_index = Cal.locate_index(val_start, val_end, freq=freq) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 69262fcbba..553315e359 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -149,10 +149,14 @@ def __init__( self.limit_type = self._get_limit_type(limit_threshold) if limit_threshold is None: if C.region in [REG_CN, REG_TW]: - self.logger.warning(f"limit_threshold not set. The stocks hit the limit may be bought/sold") + self.logger.warning( + f"limit_threshold not set. 
The stocks hit the limit may be bought/sold" + ) elif self.limit_type == self.LT_FLT and abs(cast(float, limit_threshold)) > 0.1: if C.region in [REG_CN, REG_TW]: - self.logger.warning(f"limit_threshold may not be set to a reasonable value") + self.logger.warning( + f"limit_threshold may not be set to a reasonable value" + ) if isinstance(deal_price, str): if deal_price[0] != "$": @@ -173,9 +177,18 @@ def __init__( # $change is for calculating the limit of the stock #  get volume limit from kwargs - self.buy_vol_limit, self.sell_vol_limit, vol_lt_fields = self._get_vol_limit(volume_threshold) + self.buy_vol_limit, self.sell_vol_limit, vol_lt_fields = self._get_vol_limit( + volume_threshold + ) - necessary_fields = {self.buy_price, self.sell_price, "$close", "$change", "$factor", "$volume"} + necessary_fields = { + self.buy_price, + self.sell_price, + "$close", + "$change", + "$factor", + "$volume", + } if self.limit_type == self.LT_TP_EXP: assert isinstance(limit_threshold, tuple) for exp in limit_threshold: @@ -223,9 +236,13 @@ def get_quote_from_qlib(self) -> None: # The 'factor.day.bin' file not exists, and `factor` field contains `nan` # Use adjusted price self.trade_w_adj_price = True - self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") + self.logger.warning( + "factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price." + ) if self.trade_unit is not None: - self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") + self.logger.warning( + f"trade unit {self.trade_unit} is not supported in adjusted_price mode." + ) else: # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` # Use normal price @@ -242,20 +259,34 @@ def get_quote_from_qlib(self) -> None: pstr = getattr(self, attr) # price string if pstr not in self.extra_quote.columns: self.extra_quote[pstr] = self.extra_quote["$close"] - self.logger.warning(f"No {pstr} set for extra_quote. Use $close as {pstr}.") + self.logger.warning( + f"No {pstr} set for extra_quote. Use $close as {pstr}." + ) if "$factor" not in self.extra_quote.columns: self.extra_quote["$factor"] = 1.0 - self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") + self.logger.warning( + "No $factor set for extra_quote. Use 1.0 as $factor." + ) if "limit_sell" not in self.extra_quote.columns: self.extra_quote["limit_sell"] = False - self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") + self.logger.warning( + "No limit_sell set for extra_quote. All stock will be able to be sold." + ) if "limit_buy" not in self.extra_quote.columns: self.extra_quote["limit_buy"] = False - self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") - assert set(self.extra_quote.columns) == set(self.quote_df.columns) - {"$change"} - self.quote_df = pd.concat([self.quote_df, self.extra_quote], sort=False, axis=0) + self.logger.warning( + "No limit_buy set for extra_quote. All stock will be able to be bought." + ) + assert set(self.extra_quote.columns) == set(self.quote_df.columns) - { + "$change" + } + self.quote_df = pd.concat( + [self.quote_df, self.extra_quote], sort=False, axis=0 + ) - LT_TP_EXP = "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression. + LT_TP_EXP = ( + "(exp)" # Tuple[str, str]: the limitation is calculated by a Qlib expression. 
+ ) LT_FLT = "float" # float: the trading limitation is based on `abs($change) < limit_threshold` LT_NONE = "none" # none: there is no trading limitation @@ -268,7 +299,9 @@ def _get_limit_type(self, limit_threshold: Union[tuple, float, None]) -> str: elif limit_threshold is None: return self.LT_NONE else: - raise NotImplementedError(f"This type of `limit_threshold` is not supported") + raise NotImplementedError( + f"This type of `limit_threshold` is not supported" + ) def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None: # $close may contain NaN, the nan indicates that the stock is not tradable at that timestamp @@ -282,17 +315,25 @@ def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None: # set limit limit_threshold = cast(tuple, limit_threshold) # astype bool is necessary, because quote_df is an expression and could be float - self.quote_df["limit_buy"] = self.quote_df[limit_threshold[0]].astype("bool") | suspended - self.quote_df["limit_sell"] = self.quote_df[limit_threshold[1]].astype("bool") | suspended + self.quote_df["limit_buy"] = ( + self.quote_df[limit_threshold[0]].astype("bool") | suspended + ) + self.quote_df["limit_sell"] = ( + self.quote_df[limit_threshold[1]].astype("bool") | suspended + ) elif limit_type == self.LT_FLT: limit_threshold = cast(float, limit_threshold) - self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold) | suspended + self.quote_df["limit_buy"] = ( + self.quote_df["$change"].ge(limit_threshold) | suspended + ) self.quote_df["limit_sell"] = ( self.quote_df["$change"].le(-limit_threshold) | suspended ) # pylint: disable=E1130 @staticmethod - def _get_vol_limit(volume_threshold: Union[tuple, dict, None]) -> Tuple[Optional[list], Optional[list], set]: + def _get_vol_limit( + volume_threshold: Union[tuple, dict, None], + ) -> Tuple[Optional[list], Optional[list], set]: """ preprocess the volume limit. get the fields need to get from qlib. 
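[Editor's illustrative aside, not part of the patch] With a float limit_threshold, the _update_limit hunk above flags a bar as limit_buy when $change >= threshold (the price appears to have hit its upper limit, so buying is treated as blocked) and as limit_sell when $change <= -threshold, while suspended bars (NaN $close) are blocked in both directions. A small pandas sketch of that masking logic on made-up sample data:

import numpy as np
import pandas as pd

quote = pd.DataFrame({
    "$close": [10.0, np.nan, 11.0],   # NaN close marks a suspended bar
    "$change": [0.101, 0.0, -0.105],
})
limit_threshold = 0.1

suspended = quote["$close"].isna()
quote["limit_buy"] = quote["$change"].ge(limit_threshold) | suspended
quote["limit_sell"] = quote["$change"].le(-limit_threshold) | suspended
# row 0: limit_buy only; row 1: both flags (suspended); row 2: limit_sell only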
@@ -365,13 +406,27 @@ def check_stock_limit( if direction is None: # The trading limitation is related to the trading direction # if the direction is not provided, then any limitation from buy or sell will result in trading limitation - buy_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all") - sell_limit = self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all") + buy_limit = self.quote.get_data( + stock_id, start_time, end_time, field="limit_buy", method="all" + ) + sell_limit = self.quote.get_data( + stock_id, start_time, end_time, field="limit_sell", method="all" + ) return bool(buy_limit or sell_limit) elif direction == Order.BUY: - return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_buy", method="all")) + return cast( + bool, + self.quote.get_data( + stock_id, start_time, end_time, field="limit_buy", method="all" + ), + ) elif direction == Order.SELL: - return cast(bool, self.quote.get_data(stock_id, start_time, end_time, field="limit_sell", method="all")) + return cast( + bool, + self.quote.get_data( + stock_id, start_time, end_time, field="limit_sell", method="all" + ), + ) else: raise ValueError(f"direction {direction} is not supported!") @@ -416,7 +471,9 @@ def is_stock_tradable( def check_order(self, order: Order) -> bool: # check limit and suspended - return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction) + return self.is_stock_tradable( + order.stock_id, order.start_time, order.end_time, order.direction + ) def deal_order( self, @@ -456,9 +513,19 @@ def deal_order( # 1) some stock with 0 value in the position # 2) `trade_unit` of trade_cost will be lost in user account if trade_account: - trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) + trade_account.update_order( + order=order, + trade_val=trade_val, + cost=trade_cost, + trade_price=trade_price, + ) elif position: - position.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) + position.update_order( + order=order, + trade_val=trade_val, + cost=trade_cost, + trade_price=trade_price, + ) return trade_val, trade_cost, trade_price @@ -470,7 +537,9 @@ def get_quote_info( field: str, method: str = "ts_data_last", ) -> Union[None, int, float, bool, IndexData]: - return self.quote.get_data(stock_id, start_time, end_time, field=field, method=method) + return self.quote.get_data( + stock_id, start_time, end_time, field=field, method=method + ) def get_close( self, @@ -479,7 +548,9 @@ def get_close( end_time: pd.Timestamp, method: str = "ts_data_last", ) -> Union[None, int, float, bool, IndexData]: - return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method) + return self.quote.get_data( + stock_id, start_time, end_time, field="$close", method=method + ) def get_volume( self, @@ -489,7 +560,9 @@ def get_volume( method: Optional[str] = "sum", ) -> Union[None, int, float, bool, IndexData]: """get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)""" - return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method) + return self.quote.get_data( + stock_id, start_time, end_time, field="$volume", method=method + ) def get_deal_price( self, @@ -506,9 +579,15 @@ def get_deal_price( else: raise NotImplementedError(f"This type of input is not supported") - deal_price = self.quote.get_data(stock_id, start_time, 
end_time, field=pstr, method=method) - if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08): - self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!") + deal_price = self.quote.get_data( + stock_id, start_time, end_time, field=pstr, method=method + ) + if method is not None and ( + deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08 + ): + self.logger.warning( + f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!" + ) self.logger.warning(f"setting deal_price to close price") deal_price = self.get_close(stock_id, start_time, end_time, method) return deal_price @@ -526,10 +605,14 @@ def get_factor( `None`: if the stock is suspended `None` may be returned `float`: return factor if the factor exists """ - assert start_time is not None and end_time is not None, "the time range must be given" + assert ( + start_time is not None and end_time is not None + ), "the time range must be given" if stock_id not in self.quote.get_all_stock(): return None - return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last") + return self.quote.get_data( + stock_id, start_time, end_time, field="$factor", method="ts_data_last" + ) def generate_amount_position_from_weight_position( self, @@ -555,16 +638,21 @@ def generate_amount_position_from_weight_position( # calculate the total weight of tradable value tradable_weight = 0.0 for stock_id, wp in weight_position.items(): - if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): + if self.is_stock_tradable( + stock_id=stock_id, start_time=start_time, end_time=end_time + ): # weight_position must be greater than 0 and less than 1 if wp < 0 or wp > 1: raise ValueError( - "weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp), + "weight_position is {}, " + "weight_position is not in the range of (0, 1).".format(wp), ) tradable_weight += wp if tradable_weight - 1.0 >= 1e-5: - raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight)) + raise ValueError( + "tradable_weight is {}, can not greater than 1.".format(tradable_weight) + ) amount_dict = {} for stock_id in weight_position: @@ -586,7 +674,9 @@ def generate_amount_position_from_weight_position( ) return amount_dict - def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float | None = None) -> float: + def get_real_deal_amount( + self, current_amount: float, target_amount: float, factor: float | None = None + ) -> float: """ Calculate the real adjust deal amount when considering the trading unit :param current_amount: @@ -634,19 +724,25 @@ def generate_order_for_target_amount_position( # results of the same parameter are different; # so here we sort stock_id, and then randomly shuffle the order of stock_id # because the same random seed is used, the final stock_id order is fixed - sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys()))) + sorted_ids = sorted( + set(list(current_position.keys()) + list(target_position.keys())) + ) random.seed(0) random.shuffle(sorted_ids) for stock_id in sorted_ids: # Do not generate order for the non-tradable stocks - if not self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time): + if not self.is_stock_tradable( + stock_id=stock_id, start_time=start_time, end_time=end_time + ): continue target_amount = 
target_position.get(stock_id, 0) current_amount = current_position.get(stock_id, 0) factor = self.get_factor(stock_id, start_time=start_time, end_time=end_time) - deal_amount = self.get_real_deal_amount(current_amount, target_amount, factor) + deal_amount = self.get_real_deal_amount( + current_amount, target_amount, factor + ) if deal_amount == 0: continue if deal_amount > 0: @@ -695,8 +791,12 @@ def calculate_amount_position_value( value = 0 for stock_id in amount_dict: if not only_tradable or ( - not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) - and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) + not self.check_stock_suspended( + stock_id=stock_id, start_time=start_time, end_time=end_time + ) + and not self.check_stock_limit( + stock_id=stock_id, start_time=start_time, end_time=end_time + ) ): value += ( self.get_deal_price( @@ -719,9 +819,13 @@ def _get_factor_or_raise_error( """Please refer to the docs of get_amount_of_trade_unit""" if factor is None: if stock_id is not None and start_time is not None and end_time is not None: - factor = self.get_factor(stock_id=stock_id, start_time=start_time, end_time=end_time) + factor = self.get_factor( + stock_id=stock_id, start_time=start_time, end_time=end_time + ) else: - raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None") + raise ValueError( + f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None" + ) assert factor is not None return factor @@ -780,10 +884,17 @@ def round_amount_by_trade_unit( start_time=start_time, end_time=end_time, ) - return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor + return ( + (deal_amount * factor + 0.1) + // self.trade_unit + * self.trade_unit + / factor + ) return deal_amount - def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Optional[float]: + def _clip_amount_by_volume( + self, order: Order, dealt_order_amount: dict + ) -> Optional[float]: """parse the capacity limit string and return the actual amount of orders that can be executed. NOTE: this function will change the order.deal_amount **inplace** @@ -795,7 +906,9 @@ def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Opti dealt_order_amount : dict :param dealt_order_amount: the dealt order amount dict with the format of {stock_id: float} """ - vol_limit = self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit + vol_limit = ( + self.buy_vol_limit if order.direction == Order.BUY else self.sell_vol_limit + ) if vol_limit is None: return order.deal_amount @@ -827,11 +940,15 @@ def _clip_amount_by_volume(self, order: Order, dealt_order_amount: dict) -> Opti orig_deal_amount = order.deal_amount order.deal_amount = max(min(vol_limit_min, orig_deal_amount), 0) if vol_limit_min < orig_deal_amount: - self.logger.debug(f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}") + self.logger.debug( + f"Order clipped due to volume limitation: {order}, {list(zip(vol_limit_num, vol_limit))}" + ) return None - def _get_buy_amount_by_cash_limit(self, trade_price: float, cash: float, cost_ratio: float) -> float: + def _get_buy_amount_by_cash_limit( + self, trade_price: float, cash: float, cost_ratio: float + ) -> float: """return the real order amount after cash limit for buying. 
Parameters ---------- @@ -872,9 +989,19 @@ def _calc_trade_info_by_order( """ trade_price = cast( float, - self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction), + self.get_deal_price( + order.stock_id, + order.start_time, + order.end_time, + direction=order.direction, + ), + ) + total_trade_val = ( + cast( + float, self.get_volume(order.stock_id, order.start_time, order.end_time) + ) + * trade_price ) - total_trade_val = cast(float, self.get_volume(order.stock_id, order.start_time, order.end_time)) * trade_price order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time) order.deal_amount = order.amount # set to full amount and clip it step by step # Clipping amount first @@ -899,7 +1026,9 @@ def _calc_trade_info_by_order( if position is not None: # TODO: make the trading shortable current_amount = ( - position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0 + position.get_stock_amount(order.stock_id) + if position.check_stock(order.stock_id) + else 0 ) if not np.isclose(order.deal_amount, current_amount): # when not selling last stock. rounding is necessary @@ -925,10 +1054,14 @@ def _calc_trade_info_by_order( if cash < max(trade_val * cost_ratio, self.min_cost): # cash cannot cover cost order.deal_amount = 0 - self.logger.debug(f"Order clipped due to cost higher than cash: {order}") + self.logger.debug( + f"Order clipped due to cost higher than cash: {order}" + ) elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost): # The money is not enough - max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio) + max_buy_amount = self._get_buy_amount_by_cash_limit( + trade_price, cash, cost_ratio + ) order.deal_amount = self.round_amount_by_trade_unit( min(max_buy_amount, order.deal_amount), order.factor, @@ -936,13 +1069,19 @@ def _calc_trade_info_by_order( self.logger.debug(f"Order clipped due to cash limitation: {order}") else: # The money is enough - order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) + order.deal_amount = self.round_amount_by_trade_unit( + order.deal_amount, order.factor + ) else: # Unknown amount of money. 
Just round the amount - order.deal_amount = self.round_amount_by_trade_unit(order.deal_amount, order.factor) + order.deal_amount = self.round_amount_by_trade_unit( + order.deal_amount, order.factor + ) else: - raise NotImplementedError("order direction {} error".format(order.direction)) + raise NotImplementedError( + "order direction {} error".format(order.direction) + ) trade_val = order.deal_amount * trade_price trade_cost = max(trade_val * cost_ratio, self.min_cost) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index b5d4326a71..673d6cc3cf 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -16,7 +16,12 @@ from ..utils import init_instance_by_config from .decision import BaseTradeDecision, Order from .exchange import Exchange -from .utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, get_start_end_idx +from .utils import ( + CommonInfrastructure, + LevelInfrastructure, + TradeCalendarManager, + get_start_end_idx, +) class BaseExecutor: @@ -118,13 +123,17 @@ def __init__( self._settle_type = settle_type self.reset(start_time=start_time, end_time=end_time, common_infra=common_infra) if common_infra is None: - get_module_logger("BaseExecutor").warning(f"`common_infra` is not set for {self}") + get_module_logger("BaseExecutor").warning( + f"`common_infra` is not set for {self}" + ) # record deal order amount in one day self.dealt_order_amount: Dict[str, float] = defaultdict(float) self.deal_day = None - def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None: + def reset_common_infra( + self, common_infra: CommonInfrastructure, copy_trade_account: bool = False + ) -> None: """ reset infrastructure for trading - reset trade_account @@ -146,12 +155,17 @@ def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_acco if copy_trade_account else common_infra.get("trade_account") ) - self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics) + self.trade_account.reset( + freq=self.time_per_step, + port_metr_enabled=self.generate_portfolio_metrics, + ) @property def trade_exchange(self) -> Exchange: """get trade exchange in a prioritized order""" - return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange") + return getattr(self, "_trade_exchange", None) or self.common_infra.get( + "trade_exchange" + ) @property def trade_calendar(self) -> TradeCalendarManager: @@ -161,7 +175,9 @@ def trade_calendar(self) -> TradeCalendarManager: """ return self.level_infra.get("trade_calendar") - def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) -> None: + def reset( + self, common_infra: CommonInfrastructure | None = None, **kwargs: Any + ) -> None: """ - reset `start_time` and `end_time`, used in trade calendar - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc @@ -170,7 +186,9 @@ def reset(self, common_infra: CommonInfrastructure | None = None, **kwargs: Any) if "start_time" in kwargs or "end_time" in kwargs: start_time = kwargs.get("start_time") end_time = kwargs.get("end_time") - self.level_infra.reset_cal(freq=self.time_per_step, start_time=start_time, end_time=end_time) + self.level_infra.reset_cal( + freq=self.time_per_step, start_time=start_time, end_time=end_time + ) if common_infra is not None: self.reset_common_infra(common_infra) @@ -180,7 +198,9 @@ def get_level_infra(self) -> LevelInfrastructure: def finished(self) -> bool: return 
self.trade_calendar.finished() - def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]: + def execute( + self, trade_decision: BaseTradeDecision, level: int = 0 + ) -> List[object]: """execute the trade decision and return the executed result NOTE: this function is never used directly in the framework. Should we delete it? @@ -198,7 +218,9 @@ def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[obj the executed result for trade decision """ return_value: dict = {} - for _decision in self.collect_data(trade_decision, return_value=return_value, level=level): + for _decision in self.collect_data( + trade_decision, return_value=return_value, level=level + ): pass return cast(list, return_value.get("execute_result")) @@ -207,7 +229,9 @@ def _collect_data( self, trade_decision: BaseTradeDecision, level: int = 0, - ) -> Union[Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict]]: + ) -> Union[ + Generator[Any, Any, Tuple[List[object], dict]], Tuple[List[object], dict] + ]: """ Please refer to the doc of collect_data The only difference between `_collect_data` and `collect_data` is that some common steps are moved into @@ -262,7 +286,9 @@ def collect_data( if self.track_data: yield trade_decision - atomic = not issubclass(self.__class__, NestedExecutor) # issubclass(A, A) is True + atomic = not issubclass( + self.__class__, NestedExecutor + ) # issubclass(A, A) is True if atomic and trade_decision.get_range_limit(default_value=None) is not None: raise ValueError("atomic executor doesn't support specify `range_limit`") @@ -372,7 +398,9 @@ def __init__( **kwargs, ) - def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None: + def reset_common_infra( + self, common_infra: CommonInfrastructure, copy_trade_account: bool = False + ) -> None: """ reset infrastructure for trading - reset inner_strategy and inner_executor common infra @@ -380,7 +408,9 @@ def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_acco # NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account` # The first level follow the `copy_trade_account` from the upper level - super(NestedExecutor, self).reset_common_infra(common_infra, copy_trade_account=copy_trade_account) + super(NestedExecutor, self).reset_common_infra( + common_infra, copy_trade_account=copy_trade_account + ) # The lower level have to copy the trade_account self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True) @@ -391,16 +421,24 @@ def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None: self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time) sub_level_infra = self.inner_executor.get_level_infra() self.level_infra.set_sub_level_infra(sub_level_infra) - self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision) + self.inner_strategy.reset( + level_infra=sub_level_infra, outer_trade_decision=trade_decision + ) - def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision: + def _update_trade_decision( + self, trade_decision: BaseTradeDecision + ) -> BaseTradeDecision: # outer strategy have chance to update decision each iterator - updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar) + updated_trade_decision = trade_decision.update( + self.inner_executor.trade_calendar + ) if updated_trade_decision is not None: # TODO: always 
is None for now? trade_decision = updated_trade_decision # NEW UPDATE # create a hook for inner strategy to update outer decision - trade_decision = self.inner_strategy.alter_outer_trade_decision(trade_decision) + trade_decision = self.inner_strategy.alter_outer_trade_decision( + trade_decision + ) return trade_decision def _collect_data( @@ -430,7 +468,10 @@ def _collect_data( # NOTE: make sure get_start_end_idx is after `self._update_trade_decision` start_idx, end_idx = get_start_end_idx(sub_cal, trade_decision) - if not self._align_range_limit or start_idx <= sub_cal.get_trade_step() <= end_idx: + if ( + not self._align_range_limit + or start_idx <= sub_cal.get_trade_step() <= end_idx + ): # if force align the range limit, skip the steps outside the decision range limit res = self.inner_strategy.generate_trade_decision(_inner_execute_result) @@ -456,7 +497,9 @@ def _collect_data( _inner_trade_decision: BaseTradeDecision = res - trade_decision.mod_inner_decision(_inner_trade_decision) # propagate part of decision information + trade_decision.mod_inner_decision( + _inner_trade_decision + ) # propagate part of decision information # NOTE sub_cal.get_step_time() must be called before collect_data in case of step shifting decision_list.append((_inner_trade_decision, *sub_cal.get_step_time())) @@ -471,7 +514,9 @@ def _collect_data( execute_result.extend(_inner_execute_result) inner_order_indicators.append( - self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True), + self.inner_executor.trade_account.get_trade_indicator().get_order_indicator( + raw=True + ), ) else: # do nothing and just step forward @@ -480,7 +525,10 @@ def _collect_data( # Let inner strategy know that the outer level execution is done. self.inner_strategy.post_upper_level_exe_step() - return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list} + return execute_result, { + "inner_order_indicators": inner_order_indicators, + "decision_list": decision_list, + } def post_inner_exe_step(self, inner_exe_res: List[object]) -> None: """ @@ -587,7 +635,9 @@ def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]: raise NotImplementedError(f"This type of input is not supported") return order_it - def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]: + def _collect_data( + self, trade_decision: BaseTradeDecision, level: int = 0 + ) -> Tuple[List[object], dict]: trade_start_time, _ = self.trade_calendar.get_step_time() execute_result: list = [] diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index f149f13dd5..dc90b341b1 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -104,7 +104,9 @@ class PandasQuote(BaseQuote): def __init__(self, quote_df: pd.DataFrame, freq: str) -> None: super().__init__(quote_df=quote_df, freq=freq) quote_dict = {} - for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False): + for stock_id, stock_val in quote_df.groupby( + level="instrument", group_keys=False + ): quote_dict[stock_id] = stock_val.droplevel(level="instrument") self.data = quote_dict @@ -114,7 +116,9 @@ def get_all_stock(self): def get_data(self, stock_id, start_time, end_time, field, method=None): if method == "ts_data_last": method = ts_data_last - stock_data = resam_ts_data(self.data[stock_id][field], start_time, end_time, method=method) + stock_data = resam_ts_data( + 
self.data[stock_id][field], start_time, end_time, method=method + ) if stock_data is None: return None elif isinstance(stock_data, (bool, np.bool_, int, float, np.number)): @@ -122,7 +126,9 @@ def get_data(self, stock_id, start_time, end_time, field, method=None): elif isinstance(stock_data, pd.Series): return idd.SingleData(stock_data) else: - raise ValueError(f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame") + raise ValueError( + f"stock data from resam_ts_data must be a number, pd.Series or pd.DataFrame" + ) class NumpyQuote(BaseQuote): @@ -137,9 +143,15 @@ def __init__(self, quote_df: pd.DataFrame, freq: str, region: str = "cn") -> Non """ super().__init__(quote_df=quote_df, freq=freq) quote_dict = {} - for stock_id, stock_val in quote_df.groupby(level="instrument", group_keys=False): - quote_dict[stock_id] = idd.MultiData(stock_val.droplevel(level="instrument")) - quote_dict[stock_id].sort_index() # To support more flexible slicing, we must sort data first + for stock_id, stock_val in quote_df.groupby( + level="instrument", group_keys=False + ): + quote_dict[stock_id] = idd.MultiData( + stock_val.droplevel(level="instrument") + ) + quote_dict[ + stock_id + ].sort_index() # To support more flexible slicing, we must sort data first self.data = quote_dict n, unit = Freq.parse(freq) @@ -242,7 +254,9 @@ def __rsub__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetr def __mul__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__mul__` method") - def __truediv__(self, other: Union[BaseSingleMetric, int, float]) -> BaseSingleMetric: + def __truediv__( + self, other: Union[BaseSingleMetric, int, float] + ) -> BaseSingleMetric: raise NotImplementedError(f"Please implement the `__truediv__` method") def __eq__(self, other: object) -> BaseSingleMetric: @@ -277,7 +291,9 @@ def empty(self) -> bool: raise NotImplementedError(f"Please implement the `empty` method") - def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric: + def add( + self, other: BaseSingleMetric, fill_value: float = None + ) -> BaseSingleMetric: """Replace np.nan with fill_value in two metrics and add them.""" raise NotImplementedError(f"Please implement the `add` method") @@ -331,7 +347,9 @@ def assign(self, col: str, metric: Union[dict, pd.Series]) -> None: raise NotImplementedError(f"Please implement the 'assign' method") - def transfer(self, func: Callable, new_col: str = None) -> Optional[BaseSingleMetric]: + def transfer( + self, func: Callable, new_col: str = None + ) -> Optional[BaseSingleMetric]: """compute new metric with existing metrics. 
Parameters @@ -536,7 +554,9 @@ def empty(self): def index(self): return list(self.metric.index) - def add(self, other: BaseSingleMetric, fill_value: float = None) -> PandasSingleMetric: + def add( + self, other: BaseSingleMetric, fill_value: float = None + ) -> PandasSingleMetric: other = cast(PandasSingleMetric, other) return self.__class__(self.metric.add(other.metric, fill_value=fill_value)) diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index e6f46279f3..b93e6b1457 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -23,7 +23,9 @@ def __init__(self, *args: Any, cash: float = 0.0, **kwargs: Any) -> None: self._settle_type = self.ST_NO self.position: dict = {} - def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None: + def fill_stock_value( + self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30 + ) -> None: pass def skip_update(self) -> bool: @@ -54,7 +56,9 @@ def check_stock(self, stock_id: str) -> bool: """ raise NotImplementedError(f"Please implement the `check_stock` method") - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: + def update_order( + self, order: Order, trade_val: float, cost: float, trade_price: float + ) -> None: """ Parameters ---------- @@ -92,7 +96,9 @@ def calculate_stock_value(self) -> float: float: the value(money) of all the stock """ - raise NotImplementedError(f"Please implement the `calculate_stock_value` method") + raise NotImplementedError( + f"Please implement the `calculate_stock_value` method" + ) def calculate_value(self) -> float: raise NotImplementedError(f"Please implement the `calculate_value` method") @@ -154,7 +160,9 @@ def get_stock_amount_dict(self) -> dict: Dict: {stock_id : amount of stock} """ - raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method") + raise NotImplementedError( + f"Please implement the `get_stock_amount_dict` method" + ) def get_stock_weight_dict(self, only_stock: bool = False) -> dict: """ @@ -173,7 +181,9 @@ def get_stock_weight_dict(self, only_stock: bool = False) -> dict: Dict: {stock_id : value weight of stock in the position} """ - raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method") + raise NotImplementedError( + f"Please implement the `get_stock_weight_dict` method" + ) def add_count_all(self, bar: str) -> None: """ @@ -242,7 +252,11 @@ class Position(BasePosition): } """ - def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None: + def __init__( + self, + cash: float = 0, + position_dict: Dict[str, Union[Dict[str, float], float]] = {}, + ) -> None: """Init position by cash and position_dict. Parameters @@ -277,7 +291,9 @@ def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, flo except KeyError: pass - def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None: + def fill_stock_value( + self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30 + ) -> None: """fill the stock value by the close price of latest last_days from qlib. 
Parameters @@ -311,17 +327,25 @@ def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last freq=freq, disk_cache=True, ).dropna() - price_dict = price_df.groupby(["instrument"], group_keys=False).tail(1)["$close"].to_dict() + price_dict = ( + price_df.groupby(["instrument"], group_keys=False) + .tail(1)["$close"] + .to_dict() + ) if len(price_dict) < len(stock_list): lack_stock = set(stock_list) - set(price_dict) - raise ValueError(f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days") + raise ValueError( + f"{lack_stock} doesn't have close price in qlib in the latest {last_days} days" + ) for stock in stock_list: self.position[stock]["price"] = price_dict[stock] self.position["now_account_value"] = self.calculate_value() - def _init_stock(self, stock_id: str, amount: float, price: float | None = None) -> None: + def _init_stock( + self, stock_id: str, amount: float, price: float | None = None + ) -> None: """ initialization the stock in current position @@ -337,9 +361,13 @@ def _init_stock(self, stock_id: str, amount: float, price: float | None = None) self.position[stock_id] = {} self.position[stock_id]["amount"] = amount self.position[stock_id]["price"] = price - self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date + self.position[stock_id][ + "weight" + ] = 0 # update the weight in the end of the trade date - def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None: + def _buy_stock( + self, stock_id: str, trade_val: float, cost: float, trade_price: float + ) -> None: trade_amount = trade_val / trade_price if stock_id not in self.position: self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price) @@ -349,7 +377,9 @@ def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: self.position["cash"] -= trade_val + cost - def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None: + def _sell_stock( + self, stock_id: str, trade_val: float, cost: float, trade_price: float + ) -> None: trade_amount = trade_val / trade_price if stock_id not in self.position: raise KeyError("{} not in current position".format(stock_id)) @@ -387,7 +417,9 @@ def _del_stock(self, stock_id: str) -> None: def check_stock(self, stock_id: str) -> bool: return stock_id in self.position - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: + def update_order( + self, order: Order, trade_val: float, cost: float, trade_price: float + ) -> None: # handle order, order is a order class, defined in exchange.py if order.direction == Order.BUY: # BUY @@ -396,12 +428,16 @@ def update_order(self, order: Order, trade_val: float, cost: float, trade_price: # SELL self._sell_stock(order.stock_id, trade_val, cost, trade_price) else: - raise NotImplementedError("do not support order direction {}".format(order.direction)) + raise NotImplementedError( + "do not support order direction {}".format(order.direction) + ) def update_stock_price(self, stock_id: str, price: float) -> None: self.position[stock_id]["price"] = price - def update_stock_count(self, stock_id: str, bar: str, count: float) -> None: # TODO: check type of `bar` + def update_stock_count( + self, stock_id: str, bar: str, count: float + ) -> None: # TODO: check type of `bar` self.position[stock_id][f"count_{bar}"] = count def update_stock_weight(self, stock_id: str, weight: float) -> None: @@ -411,7 +447,9 @@ def 
calculate_stock_value(self) -> float: stock_list = self.get_stock_list() value = 0 for stock_id in stock_list: - value += self.position[stock_id]["amount"] * self.position[stock_id]["price"] + value += ( + self.position[stock_id]["amount"] * self.position[stock_id]["price"] + ) return value def calculate_value(self) -> float: @@ -420,7 +458,9 @@ def calculate_value(self) -> float: return value def get_stock_list(self) -> List[str]: - stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"}) + stock_list = list( + set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"} + ) return stock_list def get_stock_price(self, code: str) -> float: @@ -468,7 +508,11 @@ def get_stock_weight_dict(self, only_stock: bool = False) -> dict: d = {} stock_list = self.get_stock_list() for stock_code in stock_list: - d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value + d[stock_code] = ( + self.position[stock_code]["amount"] + * self.position[stock_code]["price"] + / position_value + ) return d def add_count_all(self, bar: str) -> None: @@ -485,7 +529,9 @@ def update_weight_all(self) -> None: self.update_stock_weight(stock_code, weight) def settle_start(self, settle_type: str) -> None: - assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!" + assert ( + self._settle_type == self.ST_NO + ), "Currently, settlement can't be nested!!!!!" self._settle_type = settle_type if settle_type == self.ST_CASH: self.position["cash_delay"] = 0.0 @@ -515,7 +561,9 @@ def check_stock(self, stock_id: str) -> bool: # InfPosition always have any stocks return True - def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None: + def update_order( + self, order: Order, trade_val: float, cost: float, trade_price: float + ) -> None: pass def update_stock_price(self, stock_id: str, price: float) -> None: diff --git a/qlib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py index 05ca867065..ac6635bd56 100644 --- a/qlib/backtest/profit_attribution.py +++ b/qlib/backtest/profit_attribution.py @@ -39,7 +39,12 @@ def get_benchmark_weight( """ if not path: - path = Path(C.dpm.get_data_uri(freq)).expanduser() / "raw" / "AIndexMembers" / "weights.csv" + path = ( + Path(C.dpm.get_data_uri(freq)).expanduser() + / "raw" + / "AIndexMembers" + / "weights.csv" + ) # TODO: the storage of weights should be implemented in a more elegent way # TODO: The benchmark is not consistent with the filename in instruments. 
bench_weight_df = pd.read_csv(path, usecols=["code", "date", "index", "weight"]) @@ -49,7 +54,10 @@ def get_benchmark_weight( bench_weight_df = bench_weight_df[bench_weight_df.date >= start_date] if end_date is not None: bench_weight_df = bench_weight_df[bench_weight_df.date <= end_date] - bench_stock_weight = bench_weight_df.pivot_table(index="date", columns="code", values="weight") / 100.0 + bench_stock_weight = ( + bench_weight_df.pivot_table(index="date", columns="code", values="weight") + / 100.0 + ) return bench_stock_weight @@ -103,7 +111,9 @@ def decompose_portofolio_weight(stock_weight_df, stock_group_df): for group_key in all_group: group_mask = stock_group_df == group_key group_weight[group_key] = stock_weight_df[group_mask].sum(axis=1) - stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide(group_weight[group_key], axis=0) + stock_weight_in_group[group_key] = stock_weight_df[group_mask].divide( + group_weight[group_key], axis=0 + ) return group_weight, stock_weight_in_group @@ -155,7 +165,9 @@ def decompose_portofolio(stock_weight_df, stock_group_df, stock_ret_df): all_group = np.unique(stock_group_df.values.flatten()) all_group = all_group[~np.isnan(all_group)] - group_weight, stock_weight_in_group = decompose_portofolio_weight(stock_weight_df, stock_group_df) + group_weight, stock_weight_in_group = decompose_portofolio_weight( + stock_weight_df, stock_group_df + ) group_ret = {} for group_key, val in stock_weight_in_group.items(): @@ -194,15 +206,21 @@ def get_daily_bin_group(bench_values, stock_values, group_n): stock_group = stock_values.copy() # get the bin split points based on the daily proportion of benchmark - split_points = np.percentile(bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1)) + split_points = np.percentile( + bench_values[~bench_values.isna()], np.linspace(0, 100, group_n + 1) + ) # Modify the biggest uppper bound and smallest lowerbound split_points[0], split_points[-1] = -np.inf, np.inf for i, (lb, up) in enumerate(zip(split_points, split_points[1:])): - stock_group.loc[stock_values[(stock_values >= lb) & (stock_values < up)].index] = group_n - i + stock_group.loc[ + stock_values[(stock_values >= lb) & (stock_values < up)].index + ] = (group_n - i) return stock_group -def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, group_n=None): +def get_stock_group( + stock_group_field_df, bench_stock_weight_df, group_method, group_n=None +): if group_method == "category": # use the value of the benchmark as the category return stock_group_field_df @@ -284,7 +302,9 @@ def brinson_pa( stock_group_field = stock_group_field.ffill() stock_group_field = stock_group_field.loc[start_date:end_date] - stock_group = get_stock_group(stock_group_field, bench_stock_weight, group_method, group_n) + stock_group = get_stock_group( + stock_group_field, bench_stock_weight, group_method, group_n + ) deal_price_df = stock_df["deal_price"].unstack().T deal_price_df = deal_price_df.ffill() @@ -298,8 +318,12 @@ def brinson_pa( port_stock_weight_df = get_stock_weight_df(positions) # decomposing the portofolio - port_group_weight_df, port_group_ret_df = decompose_portofolio(port_stock_weight_df, stock_group, stock_ret) - bench_group_weight_df, bench_group_ret_df = decompose_portofolio(bench_stock_weight, stock_group, stock_ret) + port_group_weight_df, port_group_ret_df = decompose_portofolio( + port_stock_weight_df, stock_group, stock_ret + ) + bench_group_weight_df, bench_group_ret_df = decompose_portofolio( + 
bench_stock_weight, stock_group, stock_ret + ) # if the group return of the portofolio is NaN, replace it with the market # value diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index f1016e24e2..f1ca42b078 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -16,7 +16,11 @@ from ..tests.config import CSI300_BENCH from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data -from .high_performance_ds import BaseOrderIndicator, BaseSingleMetric, NumpyOrderIndicator +from .high_performance_ds import ( + BaseOrderIndicator, + BaseSingleMetric, + NumpyOrderIndicator, +) class PortfolioMetrics: @@ -76,7 +80,9 @@ def __init__(self, freq: str = "day", benchmark_config: dict = {}) -> None: self.init_bench(freq=freq, benchmark_config=benchmark_config) def init_vars(self) -> None: - self.accounts: dict = OrderedDict() # account position value for each trade time + self.accounts: dict = ( + OrderedDict() + ) # account position value for each trade time self.returns: dict = OrderedDict() # daily return rate for each trade time self.total_turnovers: dict = OrderedDict() # total turnover for each trade time self.turnovers: dict = OrderedDict() # turnover for each trade time @@ -87,14 +93,18 @@ def init_vars(self) -> None: self.benches: dict = OrderedDict() self.latest_pm_time: Optional[pd.TimeStamp] = None - def init_bench(self, freq: str | None = None, benchmark_config: dict | None = None) -> None: + def init_bench( + self, freq: str | None = None, benchmark_config: dict | None = None + ) -> None: if freq is not None: self.freq = freq self.benchmark_config = benchmark_config self.bench = self._cal_benchmark(self.benchmark_config, self.freq) @staticmethod - def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.Series]: + def _cal_benchmark( + benchmark_config: Optional[dict], freq: str + ) -> Optional[pd.Series]: if benchmark_config is None: return None benchmark = benchmark_config.get("benchmark", CSI300_BENCH) @@ -111,11 +121,17 @@ def _cal_benchmark(benchmark_config: Optional[dict], freq: str) -> Optional[pd.S raise ValueError("benchmark freq can't be None!") _codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark] fields = ["$close/Ref($close,1)-1"] - _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq) + _temp_result, _ = get_higher_eq_freq_feature( + _codes, fields, start_time, end_time, freq=freq + ) if len(_temp_result) == 0: - raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark") + raise ValueError( + f"The benchmark {_codes} does not exist. Please provide the right benchmark" + ) return ( - _temp_result.groupby(level="datetime", group_keys=False)[_temp_result.columns.tolist()[0]] + _temp_result.groupby(level="datetime", group_keys=False)[ + _temp_result.columns.tolist()[0] + ] .mean() .fillna(0) ) @@ -182,9 +198,13 @@ def update_portfolio_metrics_record( ) if trade_end_time is None and bench_value is None: - raise ValueError("Both trade_end_time and bench_value is None, benchmark is not usable.") + raise ValueError( + "Both trade_end_time and bench_value is None, benchmark is not usable." 
+ ) elif bench_value is None: - bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time) + bench_value = self._sample_benchmark( + self.bench, trade_start_time, trade_end_time + ) # update pm data self.accounts[trade_start_time] = account_value @@ -275,7 +295,9 @@ class Indicator: """ - def __init__(self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator) -> None: + def __init__( + self, order_indicator_cls: Type[BaseOrderIndicator] = NumpyOrderIndicator + ) -> None: self.order_indicator_cls = order_indicator_cls # order indicator is metrics for a single order for a specific step @@ -298,7 +320,9 @@ def record(self, trade_start_time: Union[str, pd.Timestamp]) -> None: self.order_indicator_his[trade_start_time] = self.get_order_indicator() self.trade_indicator_his[trade_start_time] = self.get_trade_indicator() - def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, float]]) -> None: + def _update_order_trade_info( + self, trade_info: List[Tuple[Order, float, float, float]] + ) -> None: amount = dict() deal_amount = dict() trade_price = dict() @@ -336,11 +360,15 @@ def func(deal_amount, amount): self.order_indicator.transfer(func, "ffr") - def update_order_indicators(self, trade_info: List[Tuple[Order, float, float, float]]) -> None: + def update_order_indicators( + self, trade_info: List[Tuple[Order, float, float, float]] + ) -> None: self._update_order_trade_info(trade_info=trade_info) self._update_order_fulfill_rate() - def _agg_order_trade_info(self, inner_order_indicators: List[BaseOrderIndicator]) -> None: + def _agg_order_trade_info( + self, inner_order_indicators: List[BaseOrderIndicator] + ) -> None: # calculate total trade amount with each inner order indicator. def trade_amount_func(deal_amount, trade_price): return deal_amount * trade_price @@ -349,7 +377,14 @@ def trade_amount_func(deal_amount, trade_price): indicator.transfer(trade_amount_func, "trade_price") # sum inner order indicators with same metric. 
- all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"] + all_metric = [ + "inner_amount", + "deal_amount", + "trade_price", + "trade_value", + "trade_cost", + "trade_dir", + ] self.order_indicator_cls.sum_all_indicators( self.order_indicator, inner_order_indicators, @@ -375,7 +410,9 @@ def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision) -> None: if len(decision) == 0: self.order_indicator.assign("amount", {}) else: - self.order_indicator.assign("amount", {order.stock_id: order.amount_delta for order in decision}) + self.order_indicator.assign( + "amount", {order.stock_id: order.amount_delta for order in decision} + ) def _get_base_vol_pri( self, @@ -437,7 +474,9 @@ def _get_base_vol_pri( assert isinstance(price_s, idd.SingleData) if agg == "vwap": - volume_s = trade_exchange.get_volume(inst, trade_start_time, trade_end_time, method=None) + volume_s = trade_exchange.get_volume( + inst, trade_start_time, trade_end_time, method=None + ) if isinstance(volume_s, (int, float, np.number)): volume_s = idd.SingleData(volume_s, [trade_start_time]) assert isinstance(volume_s, idd.SingleData) @@ -491,7 +530,9 @@ def _agg_base_price( bv_s = oi.get_index_data("base_volume").reindex(trade_dir.index) bp_new, bv_new = {}, {} - for pr, v, (inst, direction) in zip(bp_s.data, bv_s.data, zip(trade_dir.index, trade_dir.data)): + for pr, v, (inst, direction) in zip( + bp_s.data, bv_s.data, zip(trade_dir.index, trade_dir.data) + ): if np.isnan(pr): bp_tmp, bv_tmp = self._get_base_vol_pri( inst, @@ -518,7 +559,9 @@ def _agg_base_price( self.order_indicator.assign("base_volume", base_volume.to_dict()) self.order_indicator.assign( "base_price", - ((bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume).to_dict(), + ( + (bp_all_multi_data * bv_all_multi_data).sum(axis=1) / base_volume + ).to_dict(), ) def _agg_order_price_advantage(self) -> None: @@ -548,35 +591,45 @@ def agg_order_indicators( self._update_trade_amount(outer_trade_decision) self._update_order_fulfill_rate() pa_config = indicator_config.get("pa_config", {}) - self._agg_base_price(inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config) # TODO + self._agg_base_price( + inner_order_indicators, decision_list, trade_exchange, pa_config=pa_config + ) # TODO self._agg_order_price_advantage() - def _cal_trade_fulfill_rate(self, method: str = "mean") -> Optional[BaseSingleMetric]: + def _cal_trade_fulfill_rate( + self, method: str = "mean" + ) -> Optional[BaseSingleMetric]: if method == "mean": return self.order_indicator.transfer( lambda ffr: ffr.mean(), ) elif method == "amount_weighted": return self.order_indicator.transfer( - lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() / (deal_amount.abs().sum()), + lambda ffr, deal_amount: (ffr * deal_amount.abs()).sum() + / (deal_amount.abs().sum()), ) elif method == "value_weighted": return self.order_indicator.transfer( - lambda ffr, trade_value: (ffr * trade_value.abs()).sum() / (trade_value.abs().sum()), + lambda ffr, trade_value: (ffr * trade_value.abs()).sum() + / (trade_value.abs().sum()), ) else: raise ValueError(f"method {method} is not supported!") - def _cal_trade_price_advantage(self, method: str = "mean") -> Optional[BaseSingleMetric]: + def _cal_trade_price_advantage( + self, method: str = "mean" + ) -> Optional[BaseSingleMetric]: if method == "mean": return self.order_indicator.transfer(lambda pa: pa.mean()) elif method == "amount_weighted": return self.order_indicator.transfer( - lambda pa, 
deal_amount: (pa * deal_amount.abs()).sum() / (deal_amount.abs().sum()), + lambda pa, deal_amount: (pa * deal_amount.abs()).sum() + / (deal_amount.abs().sum()), ) elif method == "value_weighted": return self.order_indicator.transfer( - lambda pa, trade_value: (pa * trade_value.abs()).sum() / (trade_value.abs().sum()), + lambda pa, trade_value: (pa * trade_value.abs()).sum() + / (trade_value.abs().sum()), ) else: raise ValueError(f"method {method} is not supported!") @@ -614,8 +667,12 @@ def cal_trade_indicators( show_indicator = indicator_config.get("show_indicator", False) ffr_config = indicator_config.get("ffr_config", {}) pa_config = indicator_config.get("pa_config", {}) - fulfill_rate = self._cal_trade_fulfill_rate(method=ffr_config.get("weight_method", "mean")) - price_advantage = self._cal_trade_price_advantage(method=pa_config.get("weight_method", "mean")) + fulfill_rate = self._cal_trade_fulfill_rate( + method=ffr_config.get("weight_method", "mean") + ) + price_advantage = self._cal_trade_price_advantage( + method=pa_config.get("weight_method", "mean") + ) positive_rate = self._cal_trade_positive_rate() deal_amount = self._cal_deal_amount() trade_value = self._cal_trade_value() @@ -641,7 +698,9 @@ def cal_trade_indicators( ), ) - def get_order_indicator(self, raw: bool = True) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]: + def get_order_indicator( + self, raw: bool = True + ) -> Union[BaseOrderIndicator, Dict[Text, pd.Series]]: return self.order_indicator if raw else self.order_indicator.to_series() def get_trade_indicator(self) -> Dict[str, Optional[BaseSingleMetric]]: diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py index cedc9bb175..ce20024398 100644 --- a/qlib/backtest/signal.py +++ b/qlib/backtest/signal.py @@ -22,7 +22,9 @@ class Signal(metaclass=abc.ABCMeta): """ @abc.abstractmethod - def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame, None]: + def get_signal( + self, start_time: pd.Timestamp, end_time: pd.Timestamp + ) -> Union[pd.Series, pd.DataFrame, None]: """ get the signal at the end of the decision step(from `start_time` to `end_time`) @@ -57,11 +59,15 @@ def __init__(self, signal: Union[pd.Series, pd.DataFrame]) -> None: """ self.signal_cache = convert_index_format(signal, level="datetime") - def get_signal(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Union[pd.Series, pd.DataFrame]: + def get_signal( + self, start_time: pd.Timestamp, end_time: pd.Timestamp + ) -> Union[pd.Series, pd.DataFrame]: # the frequency of the signal may not align with the decision frequency of strategy # so resampling from the data is necessary # the latest signal leverage more recent data and therefore is used in trading. 
- signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last") + signal = resam_ts_data( + self.signal_cache, start_time=start_time, end_time=end_time, method="last" + ) return signal @@ -86,7 +92,9 @@ def _update_model(self) -> None: def create_signal_from( - obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame], + obj: Union[ + Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame + ], ) -> Signal: """ create signal from diverse information diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 4210c9548a..841ca9b742 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -69,7 +69,9 @@ def reset( _calendar = Cal.calendar(freq=freq, future=True) assert isinstance(_calendar, np.ndarray) self._calendar = _calendar - _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True) + _, _, _start_index, _end_index = Cal.locate_index( + start_time, end_time, freq=freq, future=True + ) self.start_index = _start_index self.end_index = _end_index self.trade_len = _end_index - _start_index + 1 @@ -86,7 +88,9 @@ def finished(self) -> bool: def step(self) -> None: if self.finished(): - raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!") + raise RuntimeError( + f"The calendar is finished, please reset it if you want to call it!" + ) self.trade_step += 1 def get_freq(self) -> str: @@ -99,7 +103,9 @@ def get_trade_len(self) -> int: def get_trade_step(self) -> int: return self.trade_step - def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]: + def get_step_time( + self, trade_step: int | None = None, shift: int = 0 + ) -> Tuple[pd.Timestamp, pd.Timestamp]: """ Get the left and right endpoints of the trade_step'th trading interval @@ -128,7 +134,9 @@ def get_step_time(self, trade_step: int | None = None, shift: int = 0) -> Tuple[ if trade_step is None: trade_step = self.get_trade_step() calendar_index = self.start_index + trade_step - shift - return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1]) + return self._calendar[calendar_index], epsilon_change( + self._calendar[calendar_index + 1] + ) def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: """ @@ -156,9 +164,13 @@ def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]: _, _, day_start_idx, _ = Cal.locate_index(day_start, day_end, freq=freq) if rtype == "full": - _, _, start_idx, end_index = Cal.locate_index(self.start_time, self.end_time, freq=freq) + _, _, start_idx, end_index = Cal.locate_index( + self.start_time, self.end_time, freq=freq + ) elif rtype == "step": - _, _, start_idx, end_index = Cal.locate_index(*self.get_step_time(), freq=freq) + _, _, start_idx, end_index = Cal.locate_index( + *self.get_step_time(), freq=freq + ) else: raise ValueError(f"This type of input {rtype} is not supported") @@ -169,7 +181,9 @@ def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]: return self.start_time, self.end_time # helper functions - def get_range_idx(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[int, int]: + def get_range_idx( + self, start_time: pd.Timestamp, end_time: pd.Timestamp + ) -> Tuple[int, int]: """ get the range index which involve start_time~end_time (both sides are closed) @@ -228,7 +242,11 @@ def has(self, infra_name: str) -> bool: def update(self, other: BaseInfrastructure) -> None: support_infra = 
other.get_support_infra() - infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)} + infra_dict = { + _infra: getattr(other, _infra) + for _infra in support_infra + if hasattr(other, _infra) + } self.reset_infra(**infra_dict) @@ -257,10 +275,14 @@ def reset_cal( ) -> None: """reset trade calendar manager""" if self.has("trade_calendar"): - self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time) + self.get("trade_calendar").reset( + freq, start_time=start_time, end_time=end_time + ) else: self.reset_infra( - trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self), + trade_calendar=TradeCalendarManager( + freq, start_time=start_time, end_time=end_time, level_infra=self + ), ) def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None: @@ -268,7 +290,9 @@ def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None: self.reset_infra(sub_level_infra=sub_level_infra) -def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]: +def get_start_end_idx( + trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision +) -> Tuple[int, int]: """ A helper function for getting the decision-level index range limitation for inner strategy - NOTE: this function is not applicable to order-level diff --git a/qlib/cli/run.py b/qlib/cli/run.py index eb6ed971be..f12b2b2906 100644 --- a/qlib/cli/run.py +++ b/qlib/cli/run.py @@ -120,11 +120,15 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): f"Can't find BASE_CONFIG_PATH base on: {Path.cwd()}, " f"try using relative path to config path: {Path(config_path).absolute()}" ) - relative_path = Path(config_path).absolute().parent.joinpath(base_config_path) + relative_path = ( + Path(config_path).absolute().parent.joinpath(base_config_path) + ) if relative_path.exists(): path = relative_path else: - raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}") + raise FileNotFoundError( + f"Can't find the BASE_CONFIG file: {base_config_path}" + ) with open(path) as fp: yaml = YAML(typ="safe", pure=True) @@ -139,7 +143,9 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): qlib.init(**config.get("qlib_init")) else: exp_manager = C["exp_manager"] - exp_manager["kwargs"]["uri"] = "file:" + str(Path(os.getcwd()).resolve() / uri_folder) + exp_manager["kwargs"]["uri"] = "file:" + str( + Path(os.getcwd()).resolve() / uri_folder + ) qlib.init(**config.get("qlib_init"), exp_manager=exp_manager) if "experiment_name" in config: diff --git a/qlib/config.py b/qlib/config.py index 4e5d62564f..13e54f238c 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -62,7 +62,9 @@ class QSettings(BaseSettings): class Config: def __init__(self, default_conf): - self.__dict__["_default_config"] = copy.deepcopy(default_conf) # avoiding conflicts with __getattr__ + self.__dict__["_default_config"] = copy.deepcopy( + default_conf + ) # avoiding conflicts with __getattr__ self.reset() def __getitem__(self, key): @@ -209,7 +211,13 @@ def register_from_C(config, skip_register=True): # However, due to bug in pytest, it requires log message to propagate to root logger to be captured by `caplog` [2]. 
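Editorial note: the `Config.__init__` hunk above writes the default dict through `self.__dict__` precisely to sidestep the class's own `__getattr__`. A minimal sketch of that pattern under assumed names (`MiniConfig` is made up for illustration, not qlib's `Config`):

# Writing through self.__dict__ bypasses attribute hooks, so __getattr__ can
# consult the stored config without recursing into itself.
import copy

class MiniConfig:
    def __init__(self, default_conf: dict):
        self.__dict__["_default_config"] = copy.deepcopy(default_conf)
        self.__dict__["_config"] = copy.deepcopy(default_conf)

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails
        try:
            return self.__dict__["_config"][name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setitem__(self, key, value):
        self.__dict__["_config"][key] = value

cfg = MiniConfig({"region": "cn", "logging_level": "INFO"})
print(cfg.region)    # -> "cn"
cfg["region"] = "us"
print(cfg.region)    # -> "us"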
# [1] https://github.com/microsoft/qlib/pull/1661 # [2] https://github.com/pytest-dev/pytest/issues/3697 - "loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"], "propagate": False}}, + "loggers": { + "qlib": { + "level": logging.DEBUG, + "handlers": ["console"], + "propagate": False, + } + }, # To let qlib work with other packages, we shouldn't disable existing loggers. # Note that this param is default to True according to the documentation of logging. "disable_existing_loggers": False, @@ -328,7 +336,11 @@ class DataPathManager: - some helper functions to process uri. """ - def __init__(self, provider_uri: Union[str, Path, dict], mount_path: Union[str, Path, dict]): + def __init__( + self, + provider_uri: Union[str, Path, dict], + mount_path: Union[str, Path, dict], + ): """ The relation of `provider_uri` and `mount_path` - `mount_path` is used only if provider_uri is an NFS path @@ -347,14 +359,19 @@ def format_provider_uri(provider_uri: Union[str, dict, Path]) -> dict: else: raise TypeError(f"provider_uri does not support {type(provider_uri)}") for freq, _uri in provider_uri.items(): - if QlibConfig.DataPathManager.get_uri_type(_uri) == QlibConfig.LOCAL_URI: + if ( + QlibConfig.DataPathManager.get_uri_type(_uri) + == QlibConfig.LOCAL_URI + ): provider_uri[freq] = str(Path(_uri).expanduser().resolve()) return provider_uri @staticmethod def get_uri_type(uri: Union[str, Path]): uri = uri if isinstance(uri, str) else str(uri.expanduser().resolve()) - is_win = re.match("^[a-zA-Z]:.*", uri) is not None # such as 'C:\\data', 'D:' + is_win = ( + re.match("^[a-zA-Z]:.*", uri) is not None + ) # such as 'C:\\data', 'D:' # such as 'host:/data/' (User may define short hostname by themselves or use localhost) is_nfs_or_win = re.match("^[^/]+:.+", uri) is not None @@ -415,7 +432,9 @@ def resolve_path(self): for _freq in _provider_uri.keys(): # mount_path _mount_path[_freq] = ( - _mount_path[_freq] if _mount_path[_freq] is None else str(Path(_mount_path[_freq]).expanduser()) + _mount_path[_freq] + if _mount_path[_freq] is None + else str(Path(_mount_path[_freq]).expanduser()) ) self["provider_uri"] = _provider_uri self["mount_path"] = _mount_path @@ -438,7 +457,11 @@ def set(self, default_conf: str = "client", **kwargs): default_conf : str the default config template chosen by user: "server", "client" """ - from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415 + from .utils import ( + set_log_with_config, + get_module_logger, + can_use_cache, + ) # pylint: disable=C0415 self.reset() @@ -448,11 +471,15 @@ def set(self, default_conf: str = "client", **kwargs): if _logging_config: set_log_with_config(_logging_config) - logger = get_module_logger("Initialization", kwargs.get("logging_level", self.logging_level)) + logger = get_module_logger( + "Initialization", kwargs.get("logging_level", self.logging_level) + ) logger.info(f"default_conf: {default_conf}.") self.set_mode(default_conf) - self.set_region(kwargs.get("region", self["region"] if "region" in self else REG_CN)) + self.set_region( + kwargs.get("region", self["region"] if "region" in self else REG_CN) + ) for k, v in kwargs.items(): if k not in self: @@ -471,7 +498,11 @@ def set(self, default_conf: str = "client", **kwargs): self["expression_cache"] = None # check dataset cache if self.is_depend_redis(self["dataset_cache"]): - log_str += f" and {self['dataset_cache']}" if log_str else self["dataset_cache"] + log_str += ( + f" and {self['dataset_cache']}" + if log_str + else 
self["dataset_cache"] + ) self["dataset_cache"] = None if log_str: logger.warning( diff --git a/qlib/contrib/data/data.py b/qlib/contrib/data/data.py index c153cfb8f6..8404792128 100644 --- a/qlib/contrib/data/data.py +++ b/qlib/contrib/data/data.py @@ -18,7 +18,10 @@ class ArcticFeatureProvider(FeatureProvider): def __init__( - self, uri="127.0.0.1", retry_time=0, market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")] + self, + uri="127.0.0.1", + retry_time=0, + market_transaction_time_list=[("09:15", "11:30"), ("13:00", "15:00")], ): super().__init__() self.uri = uri @@ -42,7 +45,9 @@ def feature(self, instrument, field, start_index, end_index, freq): # instruments does not exist return pd.Series() else: - df = arctic[freq].read(instrument, columns=[field], chunk_range=(start_index, end_index)) + df = arctic[freq].read( + instrument, columns=[field], chunk_range=(start_index, end_index) + ) s = df[field] if not s.empty: diff --git a/qlib/contrib/data/dataset.py b/qlib/contrib/data/dataset.py index 812e2cc713..b5195170e0 100644 --- a/qlib/contrib/data/dataset.py +++ b/qlib/contrib/data/dataset.py @@ -17,7 +17,9 @@ def _to_tensor(x): if not isinstance(x, torch.Tensor): - return torch.tensor(x, dtype=torch.float, device=device) # pylint: disable=E1101 + return torch.tensor( + x, dtype=torch.float, device=device + ) # pylint: disable=E1101 return x @@ -34,7 +36,9 @@ def _create_ts_slices(index, seq_len): assert index.is_monotonic_increasing, "index should be sorted" # number of dates for each instrument - sample_count_by_insts = index.to_series().groupby(level=0, group_keys=False).size().values + sample_count_by_insts = ( + index.to_series().groupby(level=0, group_keys=False).size().values + ) # start index for each instrument start_index_of_insts = np.roll(np.cumsum(sample_count_by_insts), 1) @@ -140,13 +144,19 @@ def __init__( label = handler.data_loader.fields["label"][0][0] horizon = guess_horizon([label]) - assert num_states == 0 or horizon > 0, "please specify `horizon` to avoid data leakage" + assert ( + num_states == 0 or horizon > 0 + ), "please specify `horizon` to avoid data leakage" assert memory_mode in ["sample", "daily"], "unsupported memory mode" - assert memory_mode == "sample" or batch_size < 0, "daily memory requires daily sampling (`batch_size < 0`)" + assert ( + memory_mode == "sample" or batch_size < 0 + ), "daily memory requires daily sampling (`batch_size < 0`)" assert batch_size != 0, "invalid batch size" if batch_size > 0 and n_samples is not None: - warnings.warn("`n_samples` can only be used for daily sampling (`batch_size < 0`)") + warnings.warn( + "`n_samples` can only be used for daily sampling (`batch_size < 0`)" + ) self.seq_len = seq_len self.horizon = horizon @@ -157,7 +167,12 @@ def __init__( self.shuffle = shuffle self.drop_last = drop_last self.input_size = input_size - self.params = (batch_size, n_samples, drop_last, shuffle) # for train/eval switch + self.params = ( + batch_size, + n_samples, + drop_last, + shuffle, + ) # for train/eval switch super().__init__(handler, segments, **kwargs) @@ -180,34 +195,50 @@ def setup_data(self, handler_kwargs: dict = None, **kwargs): # convert to numpy self._data = df["feature"].values.astype("float32") - np.nan_to_num(self._data, copy=False) # NOTE: fillna in case users forget using the fillna processor + np.nan_to_num( + self._data, copy=False + ) # NOTE: fillna in case users forget using the fillna processor self._label = df["label"].squeeze().values.astype("float32") self._index = df.index if 
self.input_size is not None and self.input_size != self._data.shape[1]: - warnings.warn("the data has different shape from input_size and the data will be reshaped") - assert self._data.shape[1] % self.input_size == 0, "data mismatch, please check `input_size`" + warnings.warn( + "the data has different shape from input_size and the data will be reshaped" + ) + assert ( + self._data.shape[1] % self.input_size == 0 + ), "data mismatch, please check `input_size`" # create batch slices self._batch_slices = _create_ts_slices(self._index, self.seq_len) # create daily slices - daily_slices = {date: [] for date in sorted(self._index.unique(level=1))} # sorted by date + daily_slices = { + date: [] for date in sorted(self._index.unique(level=1)) + } # sorted by date for i, (code, date) in enumerate(self._index): daily_slices[date].append(self._batch_slices[i]) self._daily_slices = np.array(list(daily_slices.values()), dtype="object") - self._daily_index = pd.Series(list(daily_slices.keys())) # index is the original date index + self._daily_index = pd.Series( + list(daily_slices.keys()) + ) # index is the original date index # add memory (sample wise and daily) if self.memory_mode == "sample": - self._memory = np.zeros((len(self._data), self.num_states), dtype=np.float32) + self._memory = np.zeros( + (len(self._data), self.num_states), dtype=np.float32 + ) elif self.memory_mode == "daily": - self._memory = np.zeros((len(self._daily_index), self.num_states), dtype=np.float32) + self._memory = np.zeros( + (len(self._daily_index), self.num_states), dtype=np.float32 + ) else: raise ValueError(f"invalid memory_mode `{self.memory_mode}`") # padding tensor - self._zeros = np.zeros((self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32) + self._zeros = np.zeros( + (self.seq_len, max(self.num_states, self._data.shape[1])), dtype=np.float32 + ) def _prepare_seg(self, slc, **kwargs): fn = _get_date_parse_fn(self._index[0][1]) @@ -228,8 +259,12 @@ def _prepare_seg(self, slc, **kwargs): obj._zeros = self._zeros # update index for this batch date_index = self._index.get_level_values(1) - obj._batch_slices = self._batch_slices[(date_index >= start_date) & (date_index <= end_date)] - mask = (self._daily_index.values >= start_date) & (self._daily_index.values <= end_date) + obj._batch_slices = self._batch_slices[ + (date_index >= start_date) & (date_index <= end_date) + ] + mask = (self._daily_index.values >= start_date) & ( + self._daily_index.values <= end_date + ) obj._daily_slices = self._daily_slices[mask] obj._daily_index = self._daily_index[mask] return obj @@ -305,18 +340,29 @@ def __iter__(self): # NOTE: daily sampling is used in 1) eval mode, 2) train mode with self.batch_size < 0 if self.batch_size < 0: # store daily index - idx = self._daily_index.index[j] # daily_index.index is the index of the original data + idx = self._daily_index.index[ + j + ] # daily_index.index is the index of the original data daily_index.append(idx) # store daily memory if specified # NOTE: daily memory always requires daily sampling (self.batch_size < 0) if self.memory_mode == "daily": - slc = slice(max(idx - self.seq_len - self.horizon, 0), max(idx - self.horizon, 0)) - state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)) + slc = slice( + max(idx - self.seq_len - self.horizon, 0), + max(idx - self.horizon, 0), + ) + state.append( + _maybe_padding(self._memory[slc], self.seq_len, self._zeros) + ) # down-sample stocks and store count - if self.n_samples and 0 < self.n_samples < 
len(slices_subset): # intraday subsample - slices_subset = np.random.choice(slices_subset, self.n_samples, replace=False) + if self.n_samples and 0 < self.n_samples < len( + slices_subset + ): # intraday subsample + slices_subset = np.random.choice( + slices_subset, self.n_samples, replace=False + ) daily_count.append(len(slices_subset)) # normal sampling @@ -328,12 +374,20 @@ def __iter__(self): for slc in slices_subset: # legacy support for Alpha360 data by `input_size` if self.input_size: - data.append(self._data[slc.stop - 1].reshape(self.input_size, -1).T) + data.append( + self._data[slc.stop - 1].reshape(self.input_size, -1).T + ) else: - data.append(_maybe_padding(self._data[slc], self.seq_len, self._zeros)) + data.append( + _maybe_padding(self._data[slc], self.seq_len, self._zeros) + ) if self.memory_mode == "sample": - state.append(_maybe_padding(self._memory[slc], self.seq_len, self._zeros)[: -self.horizon]) + state.append( + _maybe_padding( + self._memory[slc], self.seq_len, self._zeros + )[: -self.horizon] + ) label.append(self._label[slc.stop - 1]) index.append(slc.stop - 1) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 2fe5258daa..4a06597142 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -60,8 +60,12 @@ def __init__( inst_processors=None, **kwargs, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", @@ -111,8 +115,12 @@ def __init__( inst_processors=None, **kwargs, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", diff --git a/qlib/contrib/data/highfreq_handler.py b/qlib/contrib/data/highfreq_handler.py index 8eed4814f2..84a2d8710b 100644 --- a/qlib/contrib/data/highfreq_handler.py +++ b/qlib/contrib/data/highfreq_handler.py @@ -17,8 +17,12 @@ def __init__( fit_end_time=None, drop_raw=True, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", @@ -56,7 +60,10 @@ def get_normalized_price_feature(price_field, shift=0): # calculate -> ffill -> remove paused feature_ops = template_paused.format( template_fillnan.format( - template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close")) + template_norm.format( + template_if.format("$close", price_field), + template_fillnan.format("$close"), + ) ) ) return feature_ops @@ -80,7 +87,9 @@ def get_normalized_price_feature(price_field, shift=0): fields += [ template_gzero.format( template_paused.format( - "If(IsNull({0}), 0, 
{0})".format("{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume")) + "If(IsNull({0}), 0, {0})".format( + "{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume") + ) ) ) ] @@ -90,7 +99,9 @@ def get_normalized_price_feature(price_field, shift=0): template_gzero.format( template_paused.format( "If(IsNull({0}), 0, {0})".format( - "Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format("$volume") + "Ref({0}, 240)/Ref(DayLast(Mean({0}, 7200)), 240)".format( + "$volume" + ) ) ) ) @@ -119,8 +130,12 @@ def __init__( self.day_length = day_length self.columns = columns - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", @@ -153,13 +168,20 @@ def get_normalized_price_feature(price_field, shift=0): if shift == 0: template_norm = f"{{0}}/DayLast(Ref({{1}}, {self.day_length * 2}))" else: - template_norm = f"Ref({{0}}, " + str(shift) + f")/DayLast(Ref({{1}}, {self.day_length}))" + template_norm = ( + f"Ref({{0}}, " + + str(shift) + + f")/DayLast(Ref({{1}}, {self.day_length}))" + ) template_fillnan = "FFillNan({0})" # calculate -> ffill -> remove paused feature_ops = template_paused.format( template_fillnan.format( - template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close")) + template_norm.format( + template_if.format("$close", price_field), + template_fillnan.format("$close"), + ) ) ) return feature_ops @@ -176,7 +198,9 @@ def get_normalized_price_feature(price_field, shift=0): fields += [ template_paused.format( "If(IsNull({0}), 0, {0})".format( - f"{{0}}/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})".format("$volume") + f"{{0}}/Ref(DayLast(Mean({{0}}, {self.day_length * 30})), {self.day_length})".format( + "$volume" + ) ) ) ] @@ -293,12 +317,16 @@ def get_feature_config(self): if "$vwap" in self.columns: fields += [ - template_paused.format(template_if.format(template_fillnan.format("$close"), "$vwap")), + template_paused.format( + template_if.format(template_fillnan.format("$close"), "$vwap") + ), ] names += ["$vwap0"] if "$volume" in self.columns: - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$volume"))] + fields += [ + template_paused.format("If(IsNull({0}), 0, {0})".format("$volume")) + ] names += ["$volume0"] return fields, names @@ -317,8 +345,12 @@ def __init__( inst_processors=None, drop_raw=True, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", @@ -358,7 +390,10 @@ def get_normalized_price_feature(price_field, shift=0): # calculate -> ffill -> remove paused feature_ops = template_paused.format( template_fillnan.format( - template_norm.format(template_if.format("$close", price_field), template_fillnan.format("$close")) + template_norm.format( + template_if.format("$close", price_field), + template_fillnan.format("$close"), + ) ) ) return feature_ops @@ -375,7 +410,9 @@ def 
get_normalized_vwap_price_feature(price_field, shift=0): feature_ops = template_paused.format( template_fillnan.format( template_norm.format( - template_if.format("$close", template_ifinf.format("$close", price_field)), + template_if.format( + "$close", template_ifinf.format("$close", price_field) + ), template_fillnan.format("$close"), ) ) @@ -413,7 +450,9 @@ def get_volume_feature(volume_field, shift=0): template_paused.format( "If(IsInf({0}), 0, {0})".format( "If(IsNull({0}), 0, {0})".format( - "{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format(volume_field) + "{0}/Ref(DayLast(Mean({0}, 7200)), 240)".format( + volume_field + ) ) ) ) @@ -423,7 +462,9 @@ def get_volume_feature(volume_field, shift=0): template_paused.format( "If(IsInf({0}), 0, {0})".format( "If(IsNull({0}), 0, {0})".format( - f"Ref({{0}}, {shift})/Ref(DayLast(Mean({{0}}, 7200)), 240)".format(volume_field) + f"Ref({{0}}, {shift})/Ref(DayLast(Mean({{0}}, 7200)), 240)".format( + volume_field + ) ) ) ) @@ -444,7 +485,16 @@ def get_volume_feature(volume_field, shift=0): fields += [get_volume_feature("$askV1", 0)] fields += [get_volume_feature("$askV3", 0)] fields += [get_volume_feature("$askV5", 0)] - names += ["$bidV", "$bidV1", "$bidV3", "$bidV5", "$askV", "$askV1", "$askV3", "$askV5"] + names += [ + "$bidV", + "$bidV1", + "$bidV3", + "$bidV5", + "$askV", + "$askV1", + "$askV3", + "$askV5", + ] fields += [get_volume_feature("$bidV", 240)] fields += [get_volume_feature("$bidV1", 240)] @@ -454,7 +504,16 @@ def get_volume_feature(volume_field, shift=0): fields += [get_volume_feature("$askV1", 240)] fields += [get_volume_feature("$askV3", 240)] fields += [get_volume_feature("$askV5", 240)] - names += ["$bidV_1", "$bidV1_1", "$bidV3_1", "$bidV5_1", "$askV_1", "$askV1_1", "$askV3_1", "$askV5_1"] + names += [ + "$bidV_1", + "$bidV1_1", + "$bidV3_1", + "$bidV5_1", + "$askV_1", + "$askV1_1", + "$askV3_1", + "$askV5_1", + ] return fields, names @@ -518,22 +577,34 @@ def get_feature_config(self): fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$askV"))] names += ["$askV0"] - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("($bid + $ask) / 2"))] + fields += [ + template_paused.format( + "If(IsNull({0}), 0, {0})".format("($bid + $ask) / 2") + ) + ] names += ["$median0"] fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$factor"))] names += ["$factor0"] - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$downlimitmarket"))] + fields += [ + template_paused.format("If(IsNull({0}), 0, {0})".format("$downlimitmarket")) + ] names += ["$downlimitmarket0"] - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$uplimitmarket"))] + fields += [ + template_paused.format("If(IsNull({0}), 0, {0})".format("$uplimitmarket")) + ] names += ["$uplimitmarket0"] - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$highmarket"))] + fields += [ + template_paused.format("If(IsNull({0}), 0, {0})".format("$highmarket")) + ] names += ["$highmarket0"] - fields += [template_paused.format("If(IsNull({0}), 0, {0})".format("$lowmarket"))] + fields += [ + template_paused.format("If(IsNull({0}), 0, {0})".format("$lowmarket")) + ] names += ["$lowmarket0"] return fields, names diff --git a/qlib/contrib/data/highfreq_processor.py b/qlib/contrib/data/highfreq_processor.py index db2a6e39b4..2fffdd963c 100644 --- a/qlib/contrib/data/highfreq_processor.py +++ b/qlib/contrib/data/highfreq_processor.py @@ -35,10 +35,15 @@ def __init__( self.norm_groups = norm_groups def 
fit(self, df_features) -> None: - if os.path.exists(self.feature_save_dir) and len(os.listdir(self.feature_save_dir)) != 0: + if ( + os.path.exists(self.feature_save_dir) + and len(os.listdir(self.feature_save_dir)) != 0 + ): return os.makedirs(self.feature_save_dir) - fetch_df = fetch_df_by_index(df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime") + fetch_df = fetch_df_by_index( + df_features, slice(self.fit_start_time, self.fit_end_time), level="datetime" + ) del df_features index = 0 names = {} @@ -76,5 +81,7 @@ def __call__(self, df_features): df_values[:, name_val] = np.log1p(df_values[:, name_val]) df_values[:, name_val] -= feature_mean df_values[:, name_val] /= feature_std - df_features = pd.DataFrame(data=df_values, index=df_features.index, columns=df_features.columns) + df_features = pd.DataFrame( + data=df_values, index=df_features.index, columns=df_features.columns + ) return df_features.fillna(0) diff --git a/qlib/contrib/data/highfreq_provider.py b/qlib/contrib/data/highfreq_provider.py index 611e30d861..b857240f2b 100644 --- a/qlib/contrib/data/highfreq_provider.py +++ b/qlib/contrib/data/highfreq_provider.py @@ -10,7 +10,17 @@ from qlib.utils import init_instance_by_config from qlib.data.dataset.handler import DataHandlerLP from qlib.data.data import Cal -from qlib.contrib.ops.high_freq import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut +from qlib.contrib.ops.high_freq import ( + get_calendar_day, + DayLast, + FFillNan, + BFillNan, + Date, + Select, + IsNull, + IsInf, + Cut, +) import pickle as pkl from joblib import Parallel, delayed @@ -137,7 +147,9 @@ def _gen_dataframe(self, config, datasets=["train", "valid", "test"]): res = [data[i] for i in datasets] else: res = data.prepare(datasets) - self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}") + self.logger.info( + f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}" + ) else: if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) @@ -160,7 +172,9 @@ def _gen_dataframe(self, config, datasets=["train", "valid", "test"]): with open(path[:-4] + "test.pkl", "wb") as f: pkl.dump(testset, f) res = [data[i] for i in datasets] - self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}") + self.logger.info( + f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}" + ) return res def _gen_data(self, config, datasets=["train", "valid", "test"]): @@ -179,7 +193,9 @@ def _gen_data(self, config, datasets=["train", "valid", "test"]): res = [data[i] for i in datasets] else: res = data.prepare(datasets) - self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}") + self.logger.info( + f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}" + ) else: if not os.path.exists(os.path.dirname(path)): os.makedirs(os.path.dirname(path)) @@ -190,7 +206,9 @@ def _gen_data(self, config, datasets=["train", "valid", "test"]): dataset.config(dump_all=True, recursive=True) dataset.to_pickle(path) res = dataset.prepare(datasets) - self.logger.info(f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}") + self.logger.info( + f"[{__name__}]Data generated, time cost: {(time.time() - start_time):.2f}" + ) return res def _gen_dataset(self, config): @@ -204,7 +222,9 @@ def _gen_dataset(self, config): with open(path, "rb") as f: dataset = pkl.load(f) - self.logger.info(f"[{__name__}]Data loaded, time cost: {time.time() - 
start:.2f}") + self.logger.info( + f"[{__name__}]Data loaded, time cost: {time.time() - start:.2f}" + ) else: start = time.time() if not os.path.exists(os.path.dirname(path)): @@ -212,9 +232,13 @@ def _gen_dataset(self, config): self.logger.info(f"[{__name__}]Generating dataset") self._prepare_calender_cache() dataset = init_instance_by_config(config) - self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}") + self.logger.info( + f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}" + ) dataset.prepare(["train", "valid", "test"]) - self.logger.info(f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}") + self.logger.info( + f"[{__name__}]Dataset prepared, time cost: {time.time() - start:.2f}" + ) dataset.config(dump_all=True, recursive=True) dataset.to_pickle(path) return dataset @@ -235,14 +259,18 @@ def _gen_day_dataset(self, config, conf_type): self.logger.info(f"[{__name__}]Generating dataset") self._prepare_calender_cache() dataset = init_instance_by_config(config) - self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}") + self.logger.info( + f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}" + ) dataset.config(dump_all=False, recursive=True) dataset.to_pickle(path + "tmp_dataset.pkl") with open(path + "tmp_dataset.pkl", "rb") as f: new_dataset = pkl.load(f) - time_list = D.calendar(start_time=self.start_time, end_time=self.end_time, freq=self.freq)[::240] + time_list = D.calendar( + start_time=self.start_time, end_time=self.end_time, freq=self.freq + )[::240] def generate_dataset(times): if os.path.isfile(path + times.strftime("%Y-%m-%d") + ".pkl"): @@ -276,7 +304,9 @@ def _gen_stock_dataset(self, config, conf_type): self.logger.info(f"[{__name__}]Generating dataset") self._prepare_calender_cache() dataset = init_instance_by_config(config) - self.logger.info(f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}") + self.logger.info( + f"[{__name__}]Dataset init, time cost: {time.time() - start:.2f}" + ) dataset.config(dump_all=False, recursive=True) dataset.to_pickle(path + "tmp_dataset.pkl") @@ -285,7 +315,11 @@ def _gen_stock_dataset(self, config, conf_type): instruments = D.instruments(market="all") stock_list = D.list_instruments( - instruments=instruments, start_time=self.start_time, end_time=self.end_time, freq=self.freq, as_list=True + instruments=instruments, + start_time=self.start_time, + end_time=self.end_time, + freq=self.freq, + as_list=True, ) def generate_dataset(stock): diff --git a/qlib/contrib/data/loader.py b/qlib/contrib/data/loader.py index 4d11f3a34c..2acecd398a 100644 --- a/qlib/contrib/data/loader.py +++ b/qlib/contrib/data/loader.py @@ -126,14 +126,30 @@ def get_feature_config( ] if "price" in config: windows = config["price"].get("windows", range(5)) - feature = config["price"].get("feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"]) + feature = config["price"].get( + "feature", ["OPEN", "HIGH", "LOW", "CLOSE", "VWAP"] + ) for field in feature: field = field.lower() - fields += ["Ref($%s, %d)/$close" % (field, d) if d != 0 else "$%s/$close" % field for d in windows] + fields += [ + ( + "Ref($%s, %d)/$close" % (field, d) + if d != 0 + else "$%s/$close" % field + ) + for d in windows + ] names += [field.upper() + str(d) for d in windows] if "volume" in config: windows = config["volume"].get("windows", range(5)) - fields += ["Ref($volume, %d)/($volume+1e-12)" % d if d != 0 else "$volume/($volume+1e-12)" for d in windows] + fields += [ + ( + 
"Ref($volume, %d)/($volume+1e-12)" % d + if d != 0 + else "$volume/($volume+1e-12)" + ) + for d in windows + ] names += ["VOLUME" + str(d) for d in windows] if "rolling" in config: windows = config["rolling"].get("windows", [5, 10, 20, 30, 60]) @@ -197,7 +213,11 @@ def use(x): names += ["RANK%d" % d for d in windows] if use("RSV"): # Represent the price position between upper and lower resistent price for past d days. - fields += ["($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" % (d, d, d) for d in windows] + fields += [ + "($close-Min($low, %d))/(Max($high, %d)-Min($low, %d)+1e-12)" + % (d, d, d) + for d in windows + ] names += ["RSV%d" % d for d in windows] if use("IMAX"): # The number of days between current date and previous highest price date. @@ -216,7 +236,10 @@ def use(x): if use("IMXD"): # The time period between previous lowest-price date occur after highest price date. # Large value suggest downward momemtum. - fields += ["(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) for d in windows] + fields += [ + "(IdxMax($high, %d)-IdxMin($low, %d))/%d" % (d, d, d) + for d in windows + ] names += ["IMXD%d" % d for d in windows] if use("CORR"): # The correlation between absolute close price and log scaled trading volume @@ -224,7 +247,10 @@ def use(x): names += ["CORR%d" % d for d in windows] if use("CORD"): # The correlation between price change ratio and volume change ratio - fields += ["Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d for d in windows] + fields += [ + "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), %d)" % d + for d in windows + ] names += ["CORD%d" % d for d in windows] if use("CNTP"): # The percentage of days in past d days that price go up. @@ -236,13 +262,18 @@ def use(x): names += ["CNTN%d" % d for d in windows] if use("CNTD"): # The diff between past up day and past down day - fields += ["Mean($close>Ref($close, 1), %d)-Mean($closeRef($close, 1), %d)-Mean($close Tuple[pd.Series, pd.Series]: """ calculate the precision for long and short operation @@ -131,10 +136,14 @@ def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datet """ if isinstance(pred, pd.DataFrame): pred = pred.iloc[:, 0] - get_module_logger("pred_autocorr").warning(f"Only the first column in {pred.columns} of `pred` is kept") + get_module_logger("pred_autocorr").warning( + f"Only the first column in {pred.columns} of `pred` is kept" + ) pred_ustk = pred.sort_index().unstack(inst_col) corr_s = {} - for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()): + for (idx, cur), (_, prev) in zip( + pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows() + ): corr_s[idx] = cur.corr(prev) corr_s = pd.Series(corr_s).sort_index() return corr_s @@ -157,7 +166,9 @@ def pred_autocorr_all(pred_dict, n_jobs=-1, **kwargs): return complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), ac_dict) -def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series): +def calc_ic( + pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False +) -> (pd.Series, pd.Series): """calc_ic. 
Parameters @@ -175,8 +186,12 @@ def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False ic and rank ic """ df = pd.DataFrame({"pred": pred, "label": label}) - ic = df.groupby(date_col, group_keys=False).apply(lambda df: df["pred"].corr(df["label"])) - ric = df.groupby(date_col, group_keys=False).apply(lambda df: df["pred"].corr(df["label"], method="spearman")) + ic = df.groupby(date_col, group_keys=False).apply( + lambda df: df["pred"].corr(df["label"]) + ) + ric = df.groupby(date_col, group_keys=False).apply( + lambda df: df["pred"].corr(df["label"], method="spearman") + ) if dropna: return ic.dropna(), ric.dropna() else: @@ -210,6 +225,9 @@ def calc_all_ic(pred_dict_all, label, date_col="datetime", dropna=False, n_jobs= """ pred_all_ics = {} for k, pred in pred_dict_all.items(): - pred_all_ics[k] = DelayedDict(["ic", "ric"], delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna)) + pred_all_ics[k] = DelayedDict( + ["ic", "ric"], + delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna), + ) pred_all_ics = complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), pred_all_ics) return pred_all_ics diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index e0bacfca85..2ad296e296 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -13,7 +13,12 @@ from ..utils import get_date_range from ..utils.resam import Freq from ..strategy.base import BaseStrategy -from ..backtest import get_exchange, position, backtest as backtest_func, executor as _executor +from ..backtest import ( + get_exchange, + position, + backtest as backtest_func, + executor as _executor, +) from ..data import D @@ -24,7 +29,9 @@ logger = get_module_logger("Evaluate") -def risk_analysis(r, N: int = None, freq: str = "day", mode: Literal["sum", "product"] = "sum"): +def risk_analysis( + r, N: int = None, freq: str = "day", mode: Literal["sum", "product"] = "sum" +): """Risk Analysis NOTE: The calculation of annualized return is different from the definition of annualized return. @@ -80,7 +87,9 @@ def cal_risk_analysis_scaler(freq): # max percentage drawdown from peak cumulative product max_drawdown = (cumulative_curve / cumulative_curve.cummax() - 1).min() else: - raise ValueError(f"risk_analysis accumulation mode {mode} is not supported. Expected `sum` or `product`.") + raise ValueError( + f"risk_analysis accumulation mode {mode} is not supported. Expected `sum` or `product`." 
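Editorial note on the `risk_analysis` hunk above: for mode="product" it takes the maximum drawdown from the cumulative-product curve and scales the information ratio by sqrt(N). A toy sketch of those two formulas, assuming `cumulative_curve = (1 + r).cumprod()` and using 252 only as a placeholder scaler (qlib derives the actual N from `freq` via `cal_risk_analysis_scaler`):

# Max drawdown from the compounded curve and sqrt(N)-scaled information ratio.
import numpy as np
import pandas as pd

r = pd.Series(np.random.default_rng(0).normal(5e-4, 0.01, 250))  # toy daily returns
N = 252  # assumed annualization scaler for this sketch
mean, std = r.mean(), r.std()
cumulative_curve = (1 + r).cumprod()
max_drawdown = (cumulative_curve / cumulative_curve.cummax() - 1).min()
information_ratio = mean / std * np.sqrt(N)
print(max_drawdown, information_ratio)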
+ ) information_ratio = mean / std * np.sqrt(N) data = { @@ -339,7 +348,10 @@ def long_short_backtest( _pred_dates = pred.index.get_level_values(level="datetime") predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max()) - trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift)) + trade_dates = np.append( + predict_dates[shift:], + get_date_range(predict_dates[-1], left_shift=1, right_shift=shift), + ) long_returns = {} short_returns = {} @@ -361,7 +373,9 @@ def long_short_backtest( for stock in long_stocks: if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): continue - profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str) + profit = trade_exchange.get_quote_info( + stock_id=stock, start_time=date, end_time=date, field=profit_str + ) if np.isnan(profit): long_profit.append(0) else: @@ -370,7 +384,9 @@ def long_short_backtest( for stock in short_stocks: if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date): continue - profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str) + profit = trade_exchange.get_quote_info( + stock_id=stock, start_time=date, end_time=date, field=profit_str + ) if np.isnan(profit): short_profit.append(0) else: @@ -380,7 +396,9 @@ def long_short_backtest( # exclude the suspend stock if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date): continue - profit = trade_exchange.get_quote_info(stock_id=stock, start_time=date, end_time=date, field=profit_str) + profit = trade_exchange.get_quote_info( + stock_id=stock, start_time=date, end_time=date, field=profit_str + ) if np.isnan(profit): all_profit.append(0) else: @@ -409,7 +427,9 @@ def t_run(): "n_drop": 5, "signal": pred, } - report_df, positions = backtest_daily(start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_config) + report_df, positions = backtest_daily( + start_time="2017-01-01", end_time="2020-08-01", strategy=strategy_config + ) print(report_df.head()) print(positions.keys()) print(positions[list(positions.keys())[0]]) diff --git a/qlib/contrib/evaluate_portfolio.py b/qlib/contrib/evaluate_portfolio.py index 0c598e2fa8..12640756e3 100644 --- a/qlib/contrib/evaluate_portfolio.py +++ b/qlib/contrib/evaluate_portfolio.py @@ -97,7 +97,9 @@ def get_position_list_value(positions): # return dict for time:position_value value_dict = OrderedDict() for day, position in positions.items(): - value = _get_position_value_from_df(evaluate_date=day, position=position, close_data_df=close_data_df) + value = _get_position_value_from_df( + evaluate_date=day, position=position, close_data_df=close_data_df + ) value_dict[day] = value return value_dict @@ -187,7 +189,9 @@ def get_max_drawdown_from_series(r): """ # mdd = ((r.cumsum() - r.cumsum().cummax()) / (1 + r.cumsum().cummax())).min() - mdd = (((1 + r).cumprod() - (1 + r).cumprod().cummax()) / ((1 + r).cumprod().cummax())).min() + mdd = ( + ((1 + r).cumprod() - (1 + r).cumprod().cummax()) / ((1 + r).cumprod().cummax()) + ).min() return mdd diff --git a/qlib/contrib/meta/data_selection/dataset.py b/qlib/contrib/meta/data_selection/dataset.py index 61efdd63cf..2647364bbe 100644 --- a/qlib/contrib/meta/data_selection/dataset.py +++ b/qlib/contrib/meta/data_selection/dataset.py @@ -48,19 +48,29 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}): """ # 1) prepare the prediction of proxy models - perf_task_tpl = 
deepcopy(self.task_tpl) # this task is supposed to contains no complicated objects + perf_task_tpl = deepcopy( + self.task_tpl + ) # this task is supposed to contains no complicated objects # The only thing we want to save is the prediction perf_task_tpl["record"] = ["qlib.workflow.record_temp.SignalRecord"] - trainer = auto_filter_kwargs(trainer)(experiment_name=self.exp_name, **trainer_kwargs) + trainer = auto_filter_kwargs(trainer)( + experiment_name=self.exp_name, **trainer_kwargs + ) # NOTE: # The handler is initialized for only once. if not trainer.has_worker(): self.dh = init_task_handler(perf_task_tpl) - self.dh.config(dump_all=False) # in some cases, the data handler are saved to disk with `dump_all=True` + self.dh.config( + dump_all=False + ) # in some cases, the data handler are saved to disk with `dump_all=True` else: - self.dh = init_instance_by_config(perf_task_tpl["dataset"]["kwargs"]["handler"]) - assert self.dh.dump_all is False # otherwise, it will save all the detailed data + self.dh = init_instance_by_config( + perf_task_tpl["dataset"]["kwargs"]["handler"] + ) + assert ( + self.dh.dump_all is False + ) # otherwise, it will save all the detailed data seg = perf_task_tpl["dataset"]["kwargs"]["segments"] @@ -73,7 +83,12 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}): # NOTE: # we play a trick here # treat the training segments as test to create the rolling tasks - rg = RollingGen(step=self.step, test_key="train", train_key=None, task_copy_func=deepcopy_basic_type) + rg = RollingGen( + step=self.step, + test_key="train", + train_key=None, + task_copy_func=deepcopy_basic_type, + ) gen_task = task_generator(perf_task_tpl, [rg]) recorders = R.list_recorders(experiment_name=self.exp_name) @@ -81,7 +96,9 @@ def setup(self, trainer=TrainerR, trainer_kwargs={}): get_module_logger("Internal Data").info("the data has been initialized") else: # train new models - assert 0 == len(recorders), "An empty experiment is required for setup `InternalData`" + assert 0 == len( + recorders + ), "An empty experiment is required for setup `InternalData`" trainer.train(gen_task) # 2) extract the similarity matrix @@ -121,7 +138,13 @@ def update(self): class MetaTaskDS(MetaTask): """Meta Task for Data Selection""" - def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PROC_MODE_FULL, fill_method="max"): + def __init__( + self, + task: dict, + meta_info: pd.DataFrame, + mode: str = MetaTask.PROC_MODE_FULL, + fill_method="max", + ): """ The description of the processed data @@ -153,12 +176,16 @@ def __init__(self, task: dict, meta_info: pd.DataFrame, mode: str = MetaTask.PRO ds = self.get_dataset() # these three lines occupied 70% of the time of initializing MetaTaskDS - d_train, d_test = ds.prepare(["train", "test"], col_set=["feature", "label"]) + d_train, d_test = ds.prepare( + ["train", "test"], col_set=["feature", "label"] + ) prev_size = d_test.shape[0] d_train = d_train.dropna(axis=0) d_test = d_test.dropna(axis=0) if prev_size == 0 or d_test.shape[0] / prev_size <= 0.1: - raise ValueError(f"Most of samples are dropped. Please check this task: {task}") + raise ValueError( + f"Most of samples are dropped. 
Please check this task: {task}" + ) assert ( d_test.groupby("datetime", group_keys=False).size().shape[0] >= 5 @@ -195,7 +222,12 @@ def _get_processed_meta_info(self): if suffix == "seg": fill_value = {} for col in meta_info_norm.columns: - fill_value[col] = meta_info_norm.loc[meta_info_norm[col].isna(), :].dropna(axis=1).mean().max() + fill_value[col] = ( + meta_info_norm.loc[meta_info_norm[col].isna(), :] + .dropna(axis=1) + .mean() + .max() + ) fill_value = pd.Series(fill_value).sort_index() # The NaN Values are filled segment-wise. Below is an exampleof fill_value # 2009-01-05 2009-02-06 0.145809 @@ -290,7 +322,9 @@ def __init__( else: self.internal_data = InternalData(task_tpl, step=step, exp_name=exp_name) self.internal_data.setup() - self.task_tpl = deepcopy(task_tpl) # FIXME: if the handler is shared, how to avoid the explosion of the memroy. + self.task_tpl = deepcopy( + task_tpl + ) # FIXME: if the handler is shared, how to avoid the explosion of the memroy. self.trunc_days = trunc_days self.hist_step_n = hist_step_n self.step = step @@ -304,7 +338,9 @@ def __init__( self.ta = TimeAdjuster(future=True) for t in task_iter: t["dataset"]["kwargs"]["segments"]["test"] = self.ta.shift( - t["dataset"]["kwargs"]["segments"]["test"], step=rolling_ext_days, rtype=RollingGen.ROLL_EX + t["dataset"]["kwargs"]["segments"]["test"], + step=rolling_ext_days, + rtype=RollingGen.ROLL_EX, ) if task_mode == MetaTask.PROC_MODE_FULL: # Only pre initializing the task when full task is req @@ -321,12 +357,19 @@ def __init__( for t in tqdm(task_iter, desc="creating meta tasks"): try: self.meta_task_l.append( - MetaTaskDS(t, meta_info=self._prepare_meta_ipt(t), mode=task_mode, fill_method=fill_method) + MetaTaskDS( + t, + meta_info=self._prepare_meta_ipt(t), + mode=task_mode, + fill_method=fill_method, + ) ) self.task_list.append(t) except ValueError as e: logger.warning(f"ValueError: {e}") - assert len(self.meta_task_l) > 0, "No meta tasks found. Please check the data and setting" + assert ( + len(self.meta_task_l) > 0 + ), "No meta tasks found. Please check the data and setting" def _prepare_meta_ipt(self, task) -> pd.DataFrame: """ @@ -371,7 +414,9 @@ def mask_overlap(s): Approximately the diagnal + horizon length of data are masked. 
""" start, end = s.name - end = get_date_by_shift(trading_date=end, shift=self.trunc_days - 1, future=True) + end = get_date_by_shift( + trading_date=end, shift=self.trunc_days - 1, future=True + ) return s.mask((s.index >= start) & (s.index <= end)) ic_df_avail = ic_df_avail.apply(mask_overlap) # apply to each col @@ -388,11 +433,15 @@ def _prepare_seg(self, segment: Text) -> List[MetaTask]: train_task_n = int(len(self.meta_task_l) * self.segments) if segment == "train": train_tasks = self.meta_task_l[:train_task_n] - get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}") + get_module_logger("MetaDatasetDS").info( + f"The first train meta task: {train_tasks[0]}" + ) return train_tasks elif segment == "test": test_tasks = self.meta_task_l[train_task_n:] - get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}") + get_module_logger("MetaDatasetDS").info( + f"The first test meta task: {test_tasks[0]}" + ) return test_tasks else: raise NotImplementedError(f"This type of input is not supported") @@ -401,12 +450,18 @@ def _prepare_seg(self, segment: Text) -> List[MetaTask]: test_tasks = [] for t in self.meta_task_l: test_end = t.task["dataset"]["kwargs"]["segments"]["test"][1] - if test_end is None or pd.Timestamp(test_end) < pd.Timestamp(self.segments): + if test_end is None or pd.Timestamp(test_end) < pd.Timestamp( + self.segments + ): train_tasks.append(t) else: test_tasks.append(t) - get_module_logger("MetaDatasetDS").info(f"The first train meta task: {train_tasks[0]}") - get_module_logger("MetaDatasetDS").info(f"The first test meta task: {test_tasks[0]}") + get_module_logger("MetaDatasetDS").info( + f"The first train meta task: {train_tasks[0]}" + ) + get_module_logger("MetaDatasetDS").info( + f"The first test meta task: {test_tasks[0]}" + ) if segment == "train": return train_tasks elif segment == "test": diff --git a/qlib/contrib/meta/data_selection/model.py b/qlib/contrib/meta/data_selection/model.py index ed3ff9397e..a765af3315 100644 --- a/qlib/contrib/meta/data_selection/model.py +++ b/qlib/contrib/meta/data_selection/model.py @@ -98,7 +98,9 @@ def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False): try: loss = criterion(pred, meta_input["y_test"], meta_input["test_idx"]) except ValueError as e: - get_module_logger("MetaModelDS").warning(f"Exception `{e}` when calculating IC loss") + get_module_logger("MetaModelDS").warning( + f"Exception `{e}` when calculating IC loss" + ) continue else: raise ValueError(f"Unknown criterion: {self.criterion}") @@ -115,8 +117,13 @@ def run_epoch(self, phase, task_list, epoch, opt, loss_l, ignore_weight=False): pred_y_all.append( pd.DataFrame( { - "pred": pd.Series(pred.detach().cpu().numpy(), index=meta_input["test_idx"]), - "label": pd.Series(meta_input["y_test"].detach().cpu().numpy(), index=meta_input["test_idx"]), + "pred": pd.Series( + pred.detach().cpu().numpy(), index=meta_input["test_idx"] + ), + "label": pd.Series( + meta_input["y_test"].detach().cpu().numpy(), + index=meta_input["test_idx"], + ), } ) ) @@ -145,7 +152,17 @@ def fit(self, meta_dataset: MetaDatasetDS): """ if not self.fitted: - for k in set(["lr", "step", "hist_step_n", "clip_method", "clip_weight", "criterion", "max_epoch"]): + for k in set( + [ + "lr", + "step", + "hist_step_n", + "clip_method", + "clip_weight", + "criterion", + "max_epoch", + ] + ): R.log_params(**{k: getattr(self, k)}) # FIXME: get test tasks for just checking the performance @@ -154,7 +171,11 @@ def fit(self, 
meta_dataset: MetaDatasetDS): if len(meta_tasks_l[1]): R.log_params( - **dict(proxy_test_begin=meta_tasks_l[1][0].task["dataset"]["kwargs"]["segments"]["test"]) + **dict( + proxy_test_begin=meta_tasks_l[1][0].task["dataset"]["kwargs"][ + "segments" + ]["test"] + ) ) # debug: record when the test phase starts self.tn = PredNet( @@ -169,7 +190,9 @@ def fit(self, meta_dataset: MetaDatasetDS): # run weight with no weight for phase, task_list in zip(phases, meta_tasks_l): - self.run_epoch(f"{phase}_noweight", task_list, 0, opt, {}, ignore_weight=True) + self.run_epoch( + f"{phase}_noweight", task_list, 0, opt, {}, ignore_weight=True + ) self.run_epoch(f"{phase}_init", task_list, 0, opt, {}) # run training @@ -184,7 +207,9 @@ def _prepare_task(self, task: MetaTask) -> dict: meta_ipt = task.get_meta_input() weights = self.tn.twm(meta_ipt["time_perf"]) - weight_s = pd.Series(weights.detach().cpu().numpy(), index=task.meta_info.columns) + weight_s = pd.Series( + weights.detach().cpu().numpy(), index=task.meta_info.columns + ) task = copy.copy(task.task) # NOTE: this is a shallow copy. task["reweighter"] = TimeReweighter(weight_s) return task diff --git a/qlib/contrib/meta/data_selection/net.py b/qlib/contrib/meta/data_selection/net.py index fce19df3e2..a56d564121 100644 --- a/qlib/contrib/meta/data_selection/net.py +++ b/qlib/contrib/meta/data_selection/net.py @@ -18,7 +18,9 @@ def __init__(self, hist_step_n, clip_weight=None, clip_method="clamp"): def forward(self, time_perf, time_belong=None, return_preds=False): hist_step_n = self.linear.in_features # NOTE: the reshape order is very important - time_perf = time_perf.reshape(hist_step_n, time_perf.shape[0] // hist_step_n, *time_perf.shape[1:]) + time_perf = time_perf.reshape( + hist_step_n, time_perf.shape[0] // hist_step_n, *time_perf.shape[1:] + ) time_perf = torch.mean(time_perf, dim=1, keepdim=False) preds = [] @@ -33,7 +35,9 @@ def forward(self, time_perf, time_belong=None, return_preds=False): else: return time_belong @ preds else: - weights = preds_to_weight_with_clamp(preds, self.clip_weight, self.clip_method) + weights = preds_to_weight_with_clamp( + preds, self.clip_weight, self.clip_method + ) if time_belong is None: return weights else: @@ -41,7 +45,14 @@ def forward(self, time_perf, time_belong=None, return_preds=False): class PredNet(nn.Module): - def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alpha: float = 0.0): + def __init__( + self, + step, + hist_step_n, + clip_weight=None, + clip_method="tanh", + alpha: float = 0.0, + ): """ Parameters ---------- @@ -50,7 +61,9 @@ def __init__(self, step, hist_step_n, clip_weight=None, clip_method="tanh", alph """ super().__init__() self.step = step - self.twm = TimeWeightMeta(hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method) + self.twm = TimeWeightMeta( + hist_step_n=hist_step_n, clip_weight=clip_weight, clip_method=clip_method + ) self.init_paramters(hist_step_n) self.alpha = alpha @@ -64,11 +77,15 @@ def get_sample_weights(self, X, time_perf, time_belong, ignore_weight=False): def forward(self, X, y, time_perf, time_belong, X_test, ignore_weight=False): """Please refer to the docs of MetaTaskDS for the description of the variables""" - weights = self.get_sample_weights(X, time_perf, time_belong, ignore_weight=ignore_weight) + weights = self.get_sample_weights( + X, time_perf, time_belong, ignore_weight=ignore_weight + ) X_w = X.T * weights.view(1, -1) theta = torch.inverse(X_w @ X + self.alpha * torch.eye(X_w.shape[0])) @ X_w @ y return 
X_test @ theta, weights def init_paramters(self, hist_step_n): - self.twm.linear.weight.data = 1.0 / hist_step_n + self.twm.linear.weight.data * 0.01 + self.twm.linear.weight.data = ( + 1.0 / hist_step_n + self.twm.linear.weight.data * 0.01 + ) self.twm.linear.bias.data.fill_(0.0) diff --git a/qlib/contrib/meta/data_selection/utils.py b/qlib/contrib/meta/data_selection/utils.py index 2fddb00963..df7b57b56b 100644 --- a/qlib/contrib/meta/data_selection/utils.py +++ b/qlib/contrib/meta/data_selection/utils.py @@ -49,7 +49,9 @@ def forward(self, pred, y, idx): continue ic_day = torch.dot( - (pred_focus - pred_focus.mean()) / np.sqrt(pred_focus.shape[0]) / pred_focus.std(), + (pred_focus - pred_focus.mean()) + / np.sqrt(pred_focus.shape[0]) + / pred_focus.std(), (y_focus - y_focus.mean()) / np.sqrt(y_focus.shape[0]) / y_focus.std(), ) ic_all += ic_day @@ -87,7 +89,9 @@ def preds_to_weight_with_clamp(preds, clip_weight=None, clip_method="tanh"): weights = torch.ones_like(preds) else: sm = nn.Sigmoid() - weights = sm(preds) * clip_weight # TODO: The clip_weight is useless here. + weights = ( + sm(preds) * clip_weight + ) # TODO: The clip_weight is useless here. weights = weights / torch.sum(weights) * weights.numel() else: raise ValueError("Unknown clip_method") diff --git a/qlib/contrib/model/__init__.py b/qlib/contrib/model/__init__.py index 5d4d5f2e69..836f3ebab8 100644 --- a/qlib/contrib/model/__init__.py +++ b/qlib/contrib/model/__init__.py @@ -4,7 +4,9 @@ from .catboost_model import CatBoostModel except ModuleNotFoundError: CatBoostModel = None - print("ModuleNotFoundError. CatBoostModel are skipped. (optional: maybe installing CatBoostModel can fix it.)") + print( + "ModuleNotFoundError. CatBoostModel are skipped. (optional: maybe installing CatBoostModel can fix it.)" + ) try: from .double_ensemble import DEnsembleModel from .gbdt import LGBModel @@ -17,12 +19,16 @@ from .xgboost import XGBModel except ModuleNotFoundError: XGBModel = None - print("ModuleNotFoundError. XGBModel is skipped(optional: maybe installing xgboost can fix it).") + print( + "ModuleNotFoundError. XGBModel is skipped(optional: maybe installing xgboost can fix it)." + ) try: from .linear import LinearModel except ModuleNotFoundError: LinearModel = None - print("ModuleNotFoundError. LinearModel is skipped(optional: maybe installing scipy and sklearn can fix it).") + print( + "ModuleNotFoundError. LinearModel is skipped(optional: maybe installing scipy and sklearn can fix it)." + ) # import pytorch models try: from .pytorch_alstm import ALSTM @@ -35,9 +41,27 @@ from .pytorch_tcn import TCN from .pytorch_add import ADD - pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model, TCN, ADD) + pytorch_classes = ( + ALSTM, + GATs, + GRU, + LSTM, + DNNModelPytorch, + TabnetModel, + SFM_Model, + TCN, + ADD, + ) except ModuleNotFoundError: pytorch_classes = () - print("ModuleNotFoundError. PyTorch models are skipped (optional: maybe installing pytorch can fix it).") + print( + "ModuleNotFoundError. PyTorch models are skipped (optional: maybe installing pytorch can fix it)." 
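Editorial note: the `PredNet.forward` hunk a few hunks above solves a sample-weighted ridge regression in closed form, theta = (X^T W X + alpha I)^(-1) X^T W y, with W the diagonal matrix of meta-learned sample weights. A minimal torch sketch of that formula on toy data (illustrative only, not qlib code):

# Closed-form weighted ridge regression mirroring the X_w trick in the diff.
import torch

torch.manual_seed(0)
n, d, alpha = 200, 5, 0.1
X = torch.randn(n, d)
true_theta = torch.randn(d)
y = X @ true_theta + 0.05 * torch.randn(n)
weights = torch.rand(n)               # stand-in for meta-model sample weights

X_w = X.T * weights.view(1, -1)       # X^T W
theta = torch.inverse(X_w @ X + alpha * torch.eye(d)) @ X_w @ y
print(theta)                          # close to true_theta for small alpha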
+ ) -all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes +all_model_classes = ( + CatBoostModel, + DEnsembleModel, + LGBModel, + XGBModel, + LinearModel, +) + pytorch_classes diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py index 4fc1c6f893..cfac627d2c 100644 --- a/qlib/contrib/model/catboost_model.py +++ b/qlib/contrib/model/catboost_model.py @@ -41,13 +41,17 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] # CatBoost needs 1D array as its label if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values) + y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze( + y_valid.values + ) else: raise ValueError("CatBoost doesn't support multi-label training") @@ -80,7 +84,9 @@ def fit( def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) return pd.Series(self.model.predict(x_test.values), index=x_test.index) def get_feature_importance(self, *args, **kwargs) -> pd.Series: @@ -92,7 +98,8 @@ def get_feature_importance(self, *args, **kwargs) -> pd.Series: https://catboost.ai/docs/concepts/python-reference_catboost_get_feature_importance.html#python-reference_catboost_get_feature_importance """ return pd.Series( - data=self.model.get_feature_importance(*args, **kwargs), index=self.model.feature_names_ + data=self.model.get_feature_importance(*args, **kwargs), + index=self.model.feature_names_, ).sort_values(ascending=False) diff --git a/qlib/contrib/model/double_ensemble.py b/qlib/contrib/model/double_ensemble.py index 85d4418f4d..d002ab281e 100644 --- a/qlib/contrib/model/double_ensemble.py +++ b/qlib/contrib/model/double_ensemble.py @@ -33,7 +33,9 @@ def __init__( early_stopping_rounds=None, **kwargs, ): - self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" + self.base_model = ( + base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" + ) self.num_models = num_models # the number of sub-models self.enable_sr = enable_sr self.enable_fs = enable_fs @@ -55,8 +57,12 @@ def __init__( self.epochs = epochs self.logger = get_module_logger("DEnsembleModel") self.logger.info("Double Ensemble Model...") - self.ensemble = [] # the current ensemble model, a list contains all the sub-models - self.sub_features = [] # the features for each sub model in the form of pandas.Index + self.ensemble = ( + [] + ) # the current ensemble model, a list contains all the sub-models + self.sub_features = ( + [] + ) # the features for each sub model in the form of pandas.Index self.params = {"objective": loss} self.params.update(kwargs) self.loss = loss @@ -64,21 +70,29 @@ def __init__( def fit(self, dataset: DatasetH): df_train, df_valid = dataset.prepare( - ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ["train", "valid"], + col_set=["feature", "label"], + data_key=DataHandlerLP.DK_L, ) if df_train.empty 
or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] # initialize the sample weights N, F = x_train.shape weights = pd.Series(np.ones(N, dtype=float)) # initialize the features features = x_train.columns - pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index) + pred_sub = pd.DataFrame( + np.zeros((N, self.num_models), dtype=float), index=x_train.index + ) # train sub-models for k in range(self.num_models): self.sub_features.append(features) - self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models)) + self.logger.info( + "Training sub-model: ({}/{})".format(k + 1, self.num_models) + ) model_k = self.train_submodel(df_train, df_valid, weights, features) self.ensemble.append(model_k) # no further sample re-weight and feature selection needed for the last sub-model @@ -89,10 +103,12 @@ def fit(self, dataset: DatasetH): loss_curve = self.retrieve_loss_curve(model_k, df_train, features) pred_k = self.predict_sub(model_k, df_train, features) pred_sub.iloc[:, k] = pred_k - pred_ensemble = (pred_sub.iloc[:, : k + 1] * self.sub_weights[0 : k + 1]).sum(axis=1) / np.sum( - self.sub_weights[0 : k + 1] + pred_ensemble = ( + pred_sub.iloc[:, : k + 1] * self.sub_weights[0 : k + 1] + ).sum(axis=1) / np.sum(self.sub_weights[0 : k + 1]) + loss_values = pd.Series( + self.get_loss(y_train.values.squeeze(), pred_ensemble.values) ) - loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values)) if self.enable_sr: self.logger.info("Sample re-weighting...") @@ -190,17 +206,24 @@ def feature_selection(self, df_train, loss_values): # shuffle specific columns and calculate g-value for each feature x_train_tmp = x_train.copy() for i_f, feat in enumerate(features): - x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values) + x_train_tmp.loc[:, feat] = np.random.permutation( + x_train_tmp.loc[:, feat].values + ) pred = pd.Series(np.zeros(N), index=x_train_tmp.index) for i_s, submodel in enumerate(self.ensemble): pred += ( pd.Series( - submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index + submodel.predict( + x_train_tmp.loc[:, self.sub_features[i_s]].values + ), + index=x_train_tmp.index, ) / M ) loss_feat = self.get_loss(y_train.values.squeeze(), pred.values) - g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7) + g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / ( + np.std(loss_feat - loss_values) + 1e-7 + ) x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy() # one column in train features is all-nan # if g['g_value'].isna().any() @@ -215,7 +238,10 @@ def feature_selection(self, df_train, loss_values): for i_b, b in enumerate(sorted_bins): b_feat = features[g["bins"] == b] num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat))) - res_feat = res_feat + np.random.choice(b_feat, size=num_feat, replace=False).tolist() + res_feat = ( + res_feat + + np.random.choice(b_feat, size=num_feat, replace=False).tolist() + ) return pd.Index(set(res_feat)) def get_loss(self, label, pred): @@ -238,7 +264,9 @@ def retrieve_loss_curve(self, model, df_train, features): loss_curve = pd.DataFrame(np.zeros((N, num_trees))) pred_tree = np.zeros(N, dtype=float) for i_tree in range(num_trees): - pred_tree += model.predict(x_train.values, 
start_iteration=i_tree, num_iteration=1) + pred_tree += model.predict( + x_train.values, start_iteration=i_tree, num_iteration=1 + ) loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree) else: raise ValueError("not implemented yet") @@ -247,12 +275,16 @@ def retrieve_loss_curve(self, model, df_train, features): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.ensemble is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index) for i_sub, submodel in enumerate(self.ensemble): feat_sub = self.sub_features[i_sub] pred += ( - pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index) + pd.Series( + submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index + ) * self.sub_weights[i_sub] ) pred = pred / np.sum(self.sub_weights) @@ -273,5 +305,13 @@ def get_feature_importance(self, *args, **kwargs) -> pd.Series: """ res = [] for _model, _weight in zip(self.ensemble, self.sub_weights): - res.append(pd.Series(_model.feature_importance(*args, **kwargs), index=_model.feature_name()) * _weight) - return pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False) + res.append( + pd.Series( + _model.feature_importance(*args, **kwargs), + index=_model.feature_name(), + ) + * _weight + ) + return ( + pd.concat(res, axis=1, sort=False).sum(axis=1).sort_values(ascending=False) + ) diff --git a/qlib/contrib/model/gbdt.py b/qlib/contrib/model/gbdt.py index f14205f888..006b77ac4c 100644 --- a/qlib/contrib/model/gbdt.py +++ b/qlib/contrib/model/gbdt.py @@ -16,7 +16,9 @@ class LGBModel(ModelFT, LightGBMFInt): """LightGBM Model""" - def __init__(self, loss="mse", early_stopping_rounds=50, num_boost_round=1000, **kwargs): + def __init__( + self, loss="mse", early_stopping_rounds=50, num_boost_round=1000, **kwargs + ): if loss not in {"mse", "binary"}: raise NotImplementedError self.params = {"objective": loss, "verbosity": -1} @@ -25,7 +27,9 @@ def __init__(self, loss="mse", early_stopping_rounds=50, num_boost_round=1000, * self.num_boost_round = num_boost_round self.model = None - def _prepare_data(self, dataset: DatasetH, reweighter=None) -> List[Tuple[lgb.Dataset, str]]: + def _prepare_data( + self, dataset: DatasetH, reweighter=None + ) -> List[Tuple[lgb.Dataset, str]]: """ The motivation of current version is to make validation optional - train segment is necessary; @@ -34,9 +38,13 @@ def _prepare_data(self, dataset: DatasetH, reweighter=None) -> List[Tuple[lgb.Da assert "train" in dataset.segments for key in ["train", "valid"]: if key in dataset.segments: - df = dataset.prepare(key, col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + df = dataset.prepare( + key, col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if df.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x, y = df["feature"], df["label"] # Lightgbm need 1D array as its label @@ -69,7 +77,9 @@ def fit( ds_l = self._prepare_data(dataset, reweighter) ds, names = list(zip(*ds_l)) early_stopping_callback = lgb.early_stopping( - self.early_stopping_rounds if early_stopping_rounds is None else early_stopping_rounds + self.early_stopping_rounds + if early_stopping_rounds is None + else early_stopping_rounds ) # NOTE: if you encounter error here. Please upgrade your lightgbm verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) @@ -77,10 +87,16 @@ def fit( self.model = lgb.train( self.params, ds[0], # training dataset - num_boost_round=self.num_boost_round if num_boost_round is None else num_boost_round, + num_boost_round=( + self.num_boost_round if num_boost_round is None else num_boost_round + ), valid_sets=ds, valid_names=names, - callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback], + callbacks=[ + early_stopping_callback, + verbose_eval_callback, + evals_result_callback, + ], **kwargs, ) for k in names: @@ -92,10 +108,14 @@ def fit( def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) return pd.Series(self.model.predict(x_test.values), index=x_test.index) - def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20, reweighter=None): + def finetune( + self, dataset: DatasetH, num_boost_round=10, verbose_eval=20, reweighter=None + ): """ finetune model @@ -111,7 +131,9 @@ def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20, rewei # Based on existing model and finetune by train more rounds dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632 if dtrain.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) verbose_eval_callback = lgb.log_evaluation(period=verbose_eval) self.model = lgb.train( self.params, diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index ad0641136f..6d76a531ed 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -31,7 +31,9 @@ def _cal_signal_metrics(self, y_test, l_cut, r_cut): for date in y_test.index.get_level_values(0).unique(): df_res = y_test.loc[date].sort_values("pred") if int(l_cut * len(df_res)) < 10: - warnings.warn("Warning: threhold is too low or instruments number is not enough") + warnings.warn( + "Warning: threhold is too low or instruments number is not enough" + ) continue top = df_res.iloc[: int(l_cut * len(df_res))] bottom = df_res.iloc[int(r_cut * len(df_res)) :] @@ -60,30 +62,44 @@ def hf_signal_test(self, dataset: DatasetH, threhold=0.2): """ if self.model is None: raise ValueError("Model hasn't been trained yet") - df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + df_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) df_test.dropna(inplace=True) x_test, y_test = df_test["feature"], df_test["label"] # Convert label into alpha - y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0) + y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[ + y_test.columns[0] + ].mean(level=0) res = pd.Series(self.model.predict(x_test.values), index=x_test.index) y_test["pred"] = res - up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold) + up_p, down_p, up_a, down_a = self._cal_signal_metrics( + y_test, threhold, 1 - threhold + ) print("===============================") print("High frequency signal test") print("===============================") print("Test set precision: ") print("Positive precision: {}, Negative precision: {}".format(up_p, down_p)) print("Test Alpha Average in test set: ") - print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a)) + print( + "Positive average alpha: {}, Negative average alpha: {}".format( + up_a, down_a + ) + ) def _prepare_data(self, dataset: DatasetH): df_train, df_valid = dataset.prepare( - ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ["train", "valid"], + col_set=["feature", "label"], + data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -92,11 +108,15 @@ def _prepare_data(self, dataset: DatasetH): # Convert label into alpha df_train.loc[:, ("label", l_name)] = ( df_train.loc[:, ("label", l_name)] - - df_train.loc[:, ("label", l_name)].groupby(level=0, group_keys=False).mean() + - df_train.loc[:, ("label", l_name)] + .groupby(level=0, group_keys=False) + .mean() ) df_valid.loc[:, ("label", l_name)] = ( df_valid.loc[:, ("label", l_name)] - - df_valid.loc[:, ("label", l_name)].groupby(level=0, group_keys=False).mean() + - df_valid.loc[:, ("label", l_name)] + .groupby(level=0, group_keys=False) + .mean() ) def mapping_fn(x): @@ -133,7 +153,11 @@ def fit( num_boost_round=num_boost_round, valid_sets=[dtrain, dvalid], valid_names=["train", "valid"], - callbacks=[early_stopping_callback, verbose_eval_callback, evals_result_callback], + callbacks=[ + early_stopping_callback, + verbose_eval_callback, + evals_result_callback, + ], ) evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] diff --git a/qlib/contrib/model/linear.py b/qlib/contrib/model/linear.py index 15cdb739e9..a3b4f65fa8 100644 --- a/qlib/contrib/model/linear.py +++ b/qlib/contrib/model/linear.py @@ -30,7 +30,13 @@ class LinearModel(Model): RIDGE = "ridge" LASSO = "lasso" - def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_valid: bool = False): + def __init__( + self, + estimator="ols", + alpha=0.0, + fit_intercept=False, + include_valid: bool = False, + ): """ Parameters ---------- @@ -44,10 +50,18 @@ def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_vali Should the validation data be included for training? The validation data should be included """ - assert estimator in [self.OLS, self.NNLS, self.RIDGE, self.LASSO], f"unsupported estimator `{estimator}`" + assert estimator in [ + self.OLS, + self.NNLS, + self.RIDGE, + self.LASSO, + ], f"unsupported estimator `{estimator}`" self.estimator = estimator - assert alpha == 0 or estimator in [self.RIDGE, self.LASSO], f"alpha is only supported in `ridge`&`lasso`" + assert alpha == 0 or estimator in [ + self.RIDGE, + self.LASSO, + ], f"alpha is only supported in `ridge`&`lasso`" self.alpha = alpha self.fit_intercept = fit_intercept @@ -56,16 +70,24 @@ def __init__(self, estimator="ols", alpha=0.0, fit_intercept=False, include_vali self.include_valid = include_valid def fit(self, dataset: DatasetH, reweighter: Reweighter = None): - df_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + df_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if self.include_valid: try: - df_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + df_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) df_train = pd.concat([df_train, df_valid]) except KeyError: - get_module_logger("LinearModel").info("include_valid=True, but valid does not exist") + get_module_logger("LinearModel").info( + "include_valid=True, but valid does not exist" + ) df_train = df_train.dropna() if df_train.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) if reweighter is not None: w: pd.Series = reweighter.reweight(df_train) w = w.values @@ -109,5 +131,9 @@ def _fit_nnls(self, X, y, w=None): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.coef_ is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) - return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) + return pd.Series( + x_test.values @ self.coef_ + self.intercept_, index=x_test.index + ) diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index c1585a6ac0..10e8d33343 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -81,7 +81,9 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_splits = n_splits - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -141,7 +143,9 @@ def __init__( elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.model.to(self.device) @@ -162,7 +166,10 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None) list_label = [] for data in data_all: # feature :[36, 24, 6] - feature, label_reg = data[0].to(self.device).float(), data[1].to(self.device).float() + feature, label_reg = ( + data[0].to(self.device).float(), + data[1].to(self.device).float(), + ) list_feat.append(feature) list_label.append(label_reg) flag = False @@ -185,11 +192,13 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None) feature_all = torch.cat((feature_s, feature_t), 0) if epoch < self.pre_epoch: - pred_all, loss_transfer, out_weight_list = self.model.forward_pre_train( - feature_all, len_win=self.len_win + pred_all, loss_transfer, out_weight_list = ( + self.model.forward_pre_train(feature_all, len_win=self.len_win) ) else: - pred_all, loss_transfer, dist, weight_mat = self.model.forward_Boosting(feature_all, weight_mat) + pred_all, loss_transfer, dist, weight_mat = ( + self.model.forward_Boosting(feature_all, weight_mat) + ) dist_mat = dist_mat + dist pred_s = pred_all[0 : feature_s.size(0)] pred_t = pred_all[feature_s.size(0) :] @@ -204,7 +213,9 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None) self.train_optimizer.step() if epoch >= self.pre_epoch: if epoch > self.pre_epoch: - weight_mat = self.model.update_weight_Boosting(weight_mat, dist_old, dist_mat) + weight_mat = self.model.update_weight_Boosting( + weight_mat, dist_old, dist_mat + ) return weight_mat, dist_mat else: weight_mat = self.transform_type(out_weight_list) @@ -214,7 +225,9 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None) def calc_all_metrics(pred): """pred is a pandas dataframe that has two attributes: score (pred) and label (real)""" res = {} - ic = pred.groupby(level="datetime", group_keys=False).apply(lambda x: x.label.corr(x.score)) + ic = pred.groupby(level="datetime", group_keys=False).apply( + lambda x: x.label.corr(x.score) + ) 
rank_ic = pred.groupby(level="datetime", group_keys=False).apply( lambda x: x.label.corr(x.score, method="spearman") ) @@ -254,7 +267,9 @@ def fit( days = df_train.index.get_level_values(level=0).unique() train_splits = np.array_split(days, self.n_splits) train_splits = [df_train[s[0] : s[-1]] for s in train_splits] - train_loader_list = [get_stock_loader(df, self.batch_size) for df in train_splits] + train_loader_list = [ + get_stock_loader(df, self.batch_size) for df in train_splits + ] save_path = get_or_create_path(save_path) stop_steps = 0 @@ -271,7 +286,9 @@ def fit( for step in range(self.n_epochs): self.logger.info("Epoch%d:", step) self.logger.info("training...") - weight_mat, dist_mat = self.train_AdaRNN(train_loader_list, step, dist_mat, weight_mat) + weight_mat, dist_mat = self.train_AdaRNN( + train_loader_list, step, dist_mat, weight_mat + ) self.logger.info("evaluating...") train_metrics = self.test_epoch(df_train) valid_metrics = self.test_epoch(df_valid) @@ -304,7 +321,9 @@ def fit( def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) return self.infer(x_test) def infer(self, x_test): @@ -344,9 +363,12 @@ def __init__(self, df): self.df_label_reg = df["label"] self.df_index = df.index self.df_feature = torch.tensor( - self.df_feature.values.reshape(-1, 6, 60).transpose(0, 2, 1), dtype=torch.float32 + self.df_feature.values.reshape(-1, 6, 60).transpose(0, 2, 1), + dtype=torch.float32, + ) + self.df_label_reg = torch.tensor( + self.df_label_reg.values.reshape(-1), dtype=torch.float32 ) - self.df_label_reg = torch.tensor(self.df_label_reg.values.reshape(-1), dtype=torch.float32) def __getitem__(self, index): sample, label_reg = self.df_feature[index], self.df_label_reg[index] @@ -396,12 +418,20 @@ def __init__( self.model_type = model_type self.trans_loss = trans_loss self.len_seq = len_seq - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) in_size = self.n_input features = nn.ModuleList() for hidden in n_hiddens: - rnn = nn.GRU(input_size=in_size, num_layers=1, hidden_size=hidden, batch_first=True, dropout=dropout) + rnn = nn.GRU( + input_size=in_size, + num_layers=1, + hidden_size=hidden, + batch_first=True, + dropout=dropout, + ) features.append(rnn) in_size = hidden self.features = nn.Sequential(*features) @@ -455,7 +485,9 @@ def forward_pre_train(self, x, len_win=0): out_list_s, out_list_t = self.get_features(out_list_all) loss_transfer = torch.zeros((1,)).to(self.device) for i, n in enumerate(out_list_s): - criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2]) + criterion_transder = TransferLoss( + loss_type=self.trans_loss, input_dim=n.shape[2] + ) h_start = 0 for j in range(h_start, self.len_seq, 1): i_start = j - len_win if j - len_win >= 0 else 0 @@ -517,14 +549,20 @@ def forward_Boosting(self, x, weight_mat=None): out_list_s, out_list_t = self.get_features(out_list_all) loss_transfer = torch.zeros((1,)).to(self.device) if weight_mat is None: - weight = (1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq)).to(self.device) + weight = ( + 1.0 / self.len_seq * torch.ones(self.num_layers, self.len_seq) + ).to(self.device) 
else: weight = weight_mat dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device) for i, n in enumerate(out_list_s): - criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2]) + criterion_transder = TransferLoss( + loss_type=self.trans_loss, input_dim=n.shape[2] + ) for j in range(self.len_seq): - loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :]) + loss_trans = criterion_transder.compute( + n[:, j, :], out_list_t[i][:, j, :] + ) loss_transfer = loss_transfer + weight[i, j] * loss_trans dist_mat[i, j] = loss_trans return fc_out, loss_transfer, dist_mat, weight @@ -535,7 +573,9 @@ def update_weight_Boosting(self, weight_mat, dist_old, dist_new): dist_old = dist_old.detach() dist_new = dist_new.detach() ind = dist_new > dist_old + epsilon - weight_mat[ind] = weight_mat[ind] * (1 + torch.sigmoid(dist_new[ind] - dist_old[ind])) + weight_mat[ind] = weight_mat[ind] * ( + 1 + torch.sigmoid(dist_new[ind] - dist_old[ind]) + ) weight_norm = torch.norm(weight_mat, dim=1, p=1) weight_mat = weight_mat / weight_norm.t().unsqueeze(1).repeat(1, self.len_seq) return weight_mat @@ -558,7 +598,9 @@ def __init__(self, loss_type="cosine", input_dim=512, GPU=0): """ self.loss_type = loss_type self.input_dim = input_dim - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) def compute(self, X, Y): """Compute adaptation loss @@ -583,7 +625,9 @@ def compute(self, X, Y): elif self.loss_type == "js": loss = js(X, Y) elif self.loss_type == "mine": - mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to(self.device) + mine_model = Mine_estimator(input_dim=self.input_dim, hidden_dim=60).to( + self.device + ) loss = mine_model(X, Y) elif self.loss_type == "adv": loss = adv(X, Y, self.device, input_dim=self.input_dim, hidden_dim=32) @@ -637,12 +681,16 @@ def adv(source, target, device, input_dim=256, hidden_dim=512): adv_net = Discriminator(input_dim, hidden_dim).to(device) domain_src = torch.ones(len(source)).to(device) domain_tar = torch.zeros(len(target)).to(device) - domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view(domain_tar.shape[0], 1) + domain_src, domain_tar = domain_src.view(domain_src.shape[0], 1), domain_tar.view( + domain_tar.shape[0], 1 + ) reverse_src = ReverseLayerF.apply(source, 1) reverse_tar = ReverseLayerF.apply(target, 1) pred_src = adv_net(reverse_src) pred_tar = adv_net(reverse_tar) - loss_s, loss_t = domain_loss(pred_src, domain_src), domain_loss(pred_tar, domain_tar) + loss_s, loss_t = domain_loss(pred_src, domain_src), domain_loss( + pred_tar, domain_tar + ) loss = loss_s + loss_t return loss @@ -678,8 +726,12 @@ def __init__(self, kernel_type="linear", kernel_mul=2.0, kernel_num=5): def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): n_samples = int(source.size()[0]) + int(target.size()[0]) total = torch.cat([source, target], dim=0) - total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) - total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) + total0 = total.unsqueeze(0).expand( + int(total.size(0)), int(total.size(0)), int(total.size(1)) + ) + total1 = total.unsqueeze(1).expand( + int(total.size(0)), int(total.size(0)), int(total.size(1)) + ) L2_distance = ((total0 - total1) ** 2).sum(2) if fix_sigma: bandwidth = 
fix_sigma @@ -687,7 +739,10 @@ def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples) bandwidth /= kernel_mul ** (kernel_num // 2) bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)] - kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] + kernel_val = [ + torch.exp(-L2_distance / bandwidth_temp) + for bandwidth_temp in bandwidth_list + ] return sum(kernel_val) @staticmethod @@ -702,7 +757,11 @@ def forward(self, source, target): elif self.kernel_type == "rbf": batch_size = int(source.size()[0]) kernels = self.guassian_kernel( - source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma + source, + target, + kernel_mul=self.kernel_mul, + kernel_num=self.kernel_num, + fix_sigma=self.fix_sigma, ) with torch.no_grad(): XX = torch.mean(kernels[:batch_size, :batch_size]) diff --git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py index c94a03ecc3..645fda0906 100644 --- a/qlib/contrib/model/pytorch_add.py +++ b/qlib/contrib/model/pytorch_add.py @@ -83,7 +83,9 @@ def __init__( self.optimizer = optimizer.lower() self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.gamma = gamma @@ -148,14 +150,18 @@ def __init__( gamma_clip=self.gamma_clip, ) self.logger.info("model:\n{:}".format(self.ADD_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ADD_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.ADD_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.ADD_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.ADD_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.ADD_model.to(self.device) @@ -177,10 +183,12 @@ def loss_pre_market(self, pred_market, label_market, record=None): record["pre_market_loss"] = pre_market_loss.item() return pre_market_loss - def loss_pre(self, pred_excess, label_excess, pred_market, label_market, record=None): - pre_loss = self.loss_pre_excess(pred_excess, label_excess, record) + self.loss_pre_market( - pred_market, label_market, record - ) + def loss_pre( + self, pred_excess, label_excess, pred_market, label_market, record=None + ): + pre_loss = self.loss_pre_excess( + pred_excess, label_excess, record + ) + self.loss_pre_market(pred_market, label_market, record) if record is not None: record["pre_loss"] = pre_loss.item() return pre_loss @@ -199,17 +207,25 @@ def loss_adv_market(self, adv_market, label_market, record=None): return adv_market_loss def loss_adv(self, adv_excess, label_excess, adv_market, label_market, record=None): - adv_loss = self.loss_adv_excess(adv_excess, label_excess, record) + self.loss_adv_market( - adv_market, label_market, record - ) + adv_loss = self.loss_adv_excess( + adv_excess, label_excess, record + ) + self.loss_adv_market(adv_market, label_market, record) if record is not None: record["adv_loss"] = adv_loss.item() return adv_loss def loss_fn(self, x, preds, label_excess, label_market, 
record=None): loss = ( - self.loss_pre(preds["excess"], label_excess, preds["market"], label_market, record) - + self.loss_adv(preds["adv_excess"], label_excess, preds["adv_market"], label_market, record) + self.loss_pre( + preds["excess"], label_excess, preds["market"], label_market, record + ) + + self.loss_adv( + preds["adv_excess"], + label_excess, + preds["adv_market"], + label_market, + record, + ) + self.mu * self.loss_rec(x, preds["reconstructed_feature"], record) ) if record is not None: @@ -288,8 +304,12 @@ def train_epoch(self, x_train_values, y_train_values, m_train_values): break batch = indices[i : i + self.batch_size] feature = torch.from_numpy(x_train_values[batch]).float().to(self.device) - label_excess = torch.from_numpy(y_train_values[batch]).float().to(self.device) - label_market = torch.from_numpy(m_train_values[batch]).long().to(self.device) + label_excess = ( + torch.from_numpy(y_train_values[batch]).float().to(self.device) + ) + label_market = ( + torch.from_numpy(m_train_values[batch]).long().to(self.device) + ) preds = self.ADD_model(feature) @@ -344,7 +364,9 @@ def bootstrap_fit(self, x_train, y_train, m_train, x_valid, y_valid, m_valid): break self.ADD_model.before_adv_excess.step_alpha() self.ADD_model.before_adv_market.step_alpha() - self.logger.info("bootstrap_fit best score: {:.6f} @ {}".format(best_score, best_epoch)) + self.logger.info( + "bootstrap_fit best score: {:.6f} @ {}".format(best_score, best_epoch) + ) self.ADD_model.load_state_dict(best_param) return best_score @@ -357,7 +379,9 @@ def gen_market_label(self, df, raw_label): return df def fit_thresh(self, train_label): - market_label = train_label.groupby("datetime", group_keys=False).mean().squeeze() + market_label = ( + train_label.groupby("datetime", group_keys=False).mean().squeeze() + ) self.lo, self.hi = market_label.quantile([1 / 3, 2 / 3]) def fit( @@ -380,8 +404,16 @@ def fit( df_train = self.gen_market_label(df_train, label_train) df_valid = self.gen_market_label(df_valid, label_valid) - x_train, y_train, m_train = df_train["feature"], df_train["label"], df_train["market_return"] - x_valid, y_valid, m_valid = df_valid["feature"], df_valid["label"], df_valid["market_return"] + x_train, y_train, m_train = ( + df_train["feature"], + df_train["label"], + df_train["market_return"], + ) + x_valid, y_valid, m_valid = ( + df_valid["feature"], + df_valid["label"], + df_valid["market_return"], + ) evals_result["train"] = [] evals_result["valid"] = [] @@ -396,14 +428,24 @@ def fit( if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) + pretrained_model.load_state_dict( + torch.load(self.model_path, map_location=self.device) + ) model_dict = self.ADD_model.enc_excess.state_dict() - pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict} + pretrained_dict = { + k: v + for k, v in pretrained_model.rnn.state_dict().items() + if k in model_dict + } model_dict.update(pretrained_dict) self.ADD_model.enc_excess.load_state_dict(model_dict) model_dict = self.ADD_model.enc_market.state_dict() - pretrained_dict = {k: v for k, v in pretrained_model.rnn.state_dict().items() if k in model_dict} + pretrained_dict = { + k: v + for k, v in pretrained_model.rnn.state_dict().items() + if k in model_dict + } model_dict.update(pretrained_dict) self.ADD_model.enc_market.load_state_dict(model_dict) self.logger.info("Loading pretrained model Done...") @@ -417,7 
+459,9 @@ def fit( torch.cuda.empty_cache() def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.ADD_model.eval() x_values = x_test.values @@ -482,14 +526,26 @@ def __init__( ctx_size = hidden_size * num_layers self.pred_excess, self.adv_excess = [ - nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 1)) + nn.Sequential( + nn.Linear(ctx_size, ctx_size), + nn.BatchNorm1d(ctx_size), + nn.Tanh(), + nn.Linear(ctx_size, 1), + ) for _ in range(2) ] self.adv_market, self.pred_market = [ - nn.Sequential(nn.Linear(ctx_size, ctx_size), nn.BatchNorm1d(ctx_size), nn.Tanh(), nn.Linear(ctx_size, 3)) + nn.Sequential( + nn.Linear(ctx_size, ctx_size), + nn.BatchNorm1d(ctx_size), + nn.Tanh(), + nn.Linear(ctx_size, 3), + ) for _ in range(2) ] - self.before_adv_market, self.before_adv_excess = [RevGrad(gamma, gamma_clip) for _ in range(2)] + self.before_adv_market, self.before_adv_excess = [ + RevGrad(gamma, gamma_clip) for _ in range(2) + ] def forward(self, x): x = x.reshape(len(x), self.d_feat, -1) @@ -509,9 +565,13 @@ def forward(self, x): predicts["excess"] = self.pred_excess(feature_excess).squeeze(1) predicts["market"] = self.pred_market(feature_market) predicts["adv_market"] = self.adv_market(self.before_adv_market(feature_excess)) - predicts["adv_excess"] = self.adv_excess(self.before_adv_excess(feature_market).squeeze(1)) + predicts["adv_excess"] = self.adv_excess( + self.before_adv_excess(feature_market).squeeze(1) + ) if self.base_model == "LSTM": - hidden = [torch.cat([hidden_excess[i], hidden_market[i]], -1) for i in range(2)] + hidden = [ + torch.cat([hidden_excess[i], hidden_market[i]], -1) for i in range(2) + ] else: hidden = torch.cat([hidden_excess, hidden_market], -1) x = torch.zeros_like(x[:, 1, :]) @@ -525,7 +585,9 @@ def forward(self, x): class Decoder(nn.Module): - def __init__(self, d_feat=6, hidden_size=128, num_layers=1, dropout=0.5, base_model="GRU"): + def __init__( + self, d_feat=6, hidden_size=128, num_layers=1, dropout=0.5, base_model="GRU" + ): super().__init__() self.base_model = base_model if base_model == "GRU": @@ -590,7 +652,10 @@ def __init__(self, gamma=0.1, gamma_clip=0.4, *args, **kwargs): def step_alpha(self): self._p += 1 self._alpha = min( - self.gamma_clip, torch.tensor(2 / (1 + math.exp(-self.gamma * self._p)) - 1, requires_grad=False) + self.gamma_clip, + torch.tensor( + 2 / (1 + math.exp(-self.gamma * self._p)) - 1, requires_grad=False + ), ) def forward(self, input_): diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index d1c619ebf4..244f95a859 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -70,7 +70,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -117,14 +119,18 @@ def __init__( dropout=self.dropout, ) self.logger.info("model:\n{:}".format(self.ALSTM_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model))) + self.logger.info( + "model size: {:.4f} 
MB".format(count_parameters(self.ALSTM_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.ALSTM_model.to(self.device) @@ -166,8 +172,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.ALSTM_model(feature) loss = self.loss_fn(pred, label) @@ -193,8 +207,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) with torch.no_grad(): pred = self.ALSTM_model(feature) @@ -218,7 +240,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -268,7 +292,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.ALSTM_model.eval() x_values = x_test.values @@ -292,7 +318,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class ALSTMModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU"): + def __init__( + self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU" + ): super().__init__() self.hid_size = hidden_size self.input_size = d_feat @@ -307,7 +335,9 @@ def _build_model(self): except Exception as e: raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e self.net = nn.Sequential() - self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size)) + self.net.add_module( + "fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size) + ) self.net.add_module("act", nn.Tanh()) self.rnn = klass( input_size=self.hid_size, @@ -333,8 +363,12 @@ def _build_model(self): def forward(self, inputs): # inputs: [batch_size, input_size*input_day] inputs = inputs.view(len(inputs), self.input_size, -1) - inputs = inputs.permute(0, 2, 1) # [batch, input_size, seq_len] -> [batch, seq_len, input_size] - rnn_out, _ = self.rnn(self.net(inputs)) # [batch, seq_len, num_directions * hidden_size] + inputs = inputs.permute( + 0, 2, 1 + ) # [batch, input_size, seq_len] -> [batch, seq_len, input_size] + rnn_out, _ = self.rnn( + self.net(inputs) + ) # [batch, seq_len, num_directions * hidden_size] attention_score = self.att_net(rnn_out) # [batch, seq_len, 1] out_att = torch.mul(rnn_out, attention_score) out_att = torch.sum(out_att, dim=1) diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 95b5cf95d8..afd0a33144 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -74,7 +74,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.n_jobs = n_jobs self.seed = seed @@ -124,14 +126,18 @@ def __init__( dropout=self.dropout, ) self.logger.info("model:\n{:}".format(self.ALSTM_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.ALSTM_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.ALSTM_model.to(self.device) @@ -210,10 +216,16 @@ def fit( save_path=None, reweighter=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = 
dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -288,9 +300,13 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.ALSTM_model.eval() preds = [] @@ -306,7 +322,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class ALSTMModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU"): + def __init__( + self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, rnn_type="GRU" + ): super().__init__() self.hid_size = hidden_size self.input_size = d_feat @@ -321,7 +339,9 @@ def _build_model(self): except Exception as e: raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e self.net = nn.Sequential() - self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size)) + self.net.add_module( + "fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size) + ) self.net.add_module("act", nn.Tanh()) self.rnn = klass( input_size=self.hid_size, @@ -345,7 +365,9 @@ def _build_model(self): self.att_net.add_module("att_softmax", nn.Softmax(dim=1)) def forward(self, inputs): - rnn_out, _ = self.rnn(self.net(inputs)) # [batch, seq_len, num_directions * hidden_size] + rnn_out, _ = self.rnn( + self.net(inputs) + ) # [batch, seq_len, num_directions * hidden_size] attention_score = self.att_net(rnn_out) # [batch, seq_len, 1] out_att = torch.mul(rnn_out, attention_score) out_att = torch.sum(out_att, dim=1) diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 16439b3783..4aaf98cb7d 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -75,7 +75,9 @@ def __init__( self.loss = loss self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -125,14 +127,18 @@ def __init__( base_model=self.base_model, ) self.logger.info("model:\n{:}".format(self.GAT_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.GAT_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": 
self.train_optimizer = optim.SGD(self.GAT_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.GAT_model.to(self.device) @@ -233,7 +239,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -255,11 +263,15 @@ def fit( if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) + pretrained_model.load_state_dict( + torch.load(self.model_path, map_location=self.device) + ) model_dict = self.GAT_model.state_dict() pretrained_dict = { - k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135 + k: v + for k, v in pretrained_model.state_dict().items() + if k in model_dict # pylint: disable=E1135 } model_dict.update(pretrained_dict) self.GAT_model.load_state_dict(model_dict) @@ -324,7 +336,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class GATModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"): + def __init__( + self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU" + ): super().__init__() if base_model == "GRU": diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 09f0ac08b2..693fce449b 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -28,9 +28,14 @@ def __init__(self, data_source): self.data_source = data_source # calculate number of samples in each batch self.daily_count = ( - pd.Series(index=self.data_source.get_index()).groupby("datetime", group_keys=False).size().values + pd.Series(index=self.data_source.get_index()) + .groupby("datetime", group_keys=False) + .size() + .values ) - self.daily_index = np.roll(np.cumsum(self.daily_count), 1) # calculate begin index of each batch + self.daily_index = np.roll( + np.cumsum(self.daily_count), 1 + ) # calculate begin index of each batch self.daily_index[0] = 0 def __iter__(self): @@ -94,7 +99,9 @@ def __init__( self.loss = loss self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.n_jobs = n_jobs self.seed = seed @@ -145,14 +152,18 @@ def __init__( base_model=self.base_model, ) self.logger.info("model:\n{:}".format(self.GAT_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.GAT_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.GAT_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False 
self.GAT_model.to(self.device) @@ -236,10 +247,16 @@ def fit( evals_result=dict(), save_path=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -247,8 +264,12 @@ def fit( sampler_train = DailyBatchSampler(dl_train) sampler_valid = DailyBatchSampler(dl_valid) - train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True) - valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True) + train_loader = DataLoader( + dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True + ) + valid_loader = DataLoader( + dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True + ) save_path = get_or_create_path(save_path) @@ -261,19 +282,31 @@ def fit( # load pretrained base_model if self.base_model == "LSTM": - pretrained_model = LSTMModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers) + pretrained_model = LSTMModel( + d_feat=self.d_feat, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + ) elif self.base_model == "GRU": - pretrained_model = GRUModel(d_feat=self.d_feat, hidden_size=self.hidden_size, num_layers=self.num_layers) + pretrained_model = GRUModel( + d_feat=self.d_feat, + hidden_size=self.hidden_size, + num_layers=self.num_layers, + ) else: raise ValueError("unknown base model name `%s`" % self.base_model) if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) + pretrained_model.load_state_dict( + torch.load(self.model_path, map_location=self.device) + ) model_dict = self.GAT_model.state_dict() pretrained_dict = { - k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135 + k: v + for k, v in pretrained_model.state_dict().items() + if k in model_dict # pylint: disable=E1135 } model_dict.update(pretrained_dict) self.GAT_model.load_state_dict(model_dict) @@ -316,7 +349,9 @@ def predict(self, dataset): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") sampler_test = DailyBatchSampler(dl_test) test_loader = DataLoader(dl_test, sampler=sampler_test, num_workers=self.n_jobs) @@ -336,7 +371,9 @@ def predict(self, dataset): class GATModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"): + def __init__( + self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU" + ): super().__init__() if base_model == "GRU": diff --git a/qlib/contrib/model/pytorch_general_nn.py 
b/qlib/contrib/model/pytorch_general_nn.py index 503c5a2a50..90a7bd44ff 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -83,12 +83,16 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.weight_decay = weight_decay - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.n_jobs = n_jobs self.seed = seed self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs - self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs}) + self.dnn_model = init_instance_by_config( + {"class": pt_model_uri, "kwargs": pt_model_kwargs} + ) self.logger.info( "GeneralPTNN parameters setting:" @@ -128,18 +132,31 @@ def __init__( torch.manual_seed(self.seed) self.logger.info("model:\n{:}".format(self.dnn_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.dnn_model)) + ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay) + self.train_optimizer = optim.Adam( + self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay) + self.train_optimizer = optim.SGD( + self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) # === ReduceLROnPlateau learning rate scheduler === self.lr_scheduler = ReduceLROnPlateau( - self.train_optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6, threshold=1e-5 + self.train_optimizer, + mode="min", + factor=0.5, + patience=5, + min_lr=1e-6, + threshold=1e-5, ) self.fitted = False self.dnn_model.to(self.device) @@ -241,12 +258,18 @@ def fit( ): ists = isinstance(dataset, TSDatasetH) # is this time series dataset - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) self.logger.info(f"Train samples: {len(dl_train)}") self.logger.info(f"Valid samples: {len(dl_valid)}") if dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) if reweighter is None: wl_train = np.ones(len(dl_train)) @@ -259,8 +282,12 @@ def fit( # Preprocess for data. 
To align to Dataset Interface for DataLoader if ists: - dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader - dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader + dl_train.config( + fillna_type="ffill+bfill" + ) # process nan brought by dataloader + dl_valid.config( + fillna_type="ffill+bfill" + ) # process nan brought by dataloader else: # If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader dl_train = dl_train.values @@ -302,7 +329,9 @@ def fit( self.logger.info("evaluating...") train_loss, train_score = self.test_epoch(train_loader) val_loss, val_score = self.test_epoch(valid_loader) - self.logger.info("Epoch%d: train %.6f, valid %.6f" % (step, train_score, val_score)) + self.logger.info( + "Epoch%d: train %.6f, valid %.6f" % (step, train_score, val_score) + ) evals_result["train"].append(train_score) evals_result["valid"].append(val_score) @@ -340,18 +369,24 @@ def predict( if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) self.logger.info(f"Test samples: {len(dl_test)}") if isinstance(dataset, TSDatasetH): - dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader + dl_test.config( + fillna_type="ffill+bfill" + ) # process nan brought by dataloader index = dl_test.get_index() else: # If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader index = dl_test.index dl_test = dl_test.values - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.dnn_model.eval() preds = [] diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 06aa6810b8..54af7f742b 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -70,7 +70,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -117,14 +119,18 @@ def __init__( dropout=self.dropout, ) self.logger.info("model:\n{:}".format(self.gru_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.gru_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.gru_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.gru_model.to(self.device) @@ -166,8 +172,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + 
.float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.gru_model(feature) loss = self.loss_fn(pred, label) @@ -193,8 +207,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) with torch.no_grad(): pred = self.gru_model(feature) @@ -222,11 +244,15 @@ def fit( for k in ["train", "valid"] if k in dataset.segments } - df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get("valid", pd.DataFrame()) + df_train, df_valid = dfs.get("train", pd.DataFrame()), dfs.get( + "valid", pd.DataFrame() + ) # check if training data is empty if df_train.empty: - raise ValueError("Empty training data from dataset, please check your dataset config.") + raise ValueError( + "Empty training data from dataset, please check your dataset config." + ) df_train = df_train.dropna() x_train, y_train = df_train["feature"], df_train["label"] @@ -293,7 +319,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.gru_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 65da5ac4b4..bca049f569 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -72,7 +72,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.n_jobs = n_jobs self.seed = seed @@ -122,14 +124,18 @@ def __init__( dropout=self.dropout, ) self.logger.info("model:\n{:}".format(self.GRU_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.GRU_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.GRU_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.GRU_model.to(self.device) @@ -204,10 +210,16 @@ def fit( save_path=None, reweighter=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if 
dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -282,9 +294,13 @@ def predict(self, dataset): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.GRU_model.eval() preds = [] diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 779cde9c85..30b0d81b74 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -80,7 +80,9 @@ def __init__( self.model_path = model_path self.stock2concept = stock2concept self.stock_index = stock_index - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -132,13 +134,17 @@ def __init__( base_model=self.base_model, ) self.logger.info("model:\n{:}".format(self.HIST_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.HIST_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.HIST_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.HIST_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.HIST_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.HIST_model.to(self.device) @@ -168,7 +174,9 @@ def metric_fn(self, pred, label): vx = x - torch.mean(x) vy = y - torch.mean(y) - return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2))) + return torch.sum(vx * vy) / ( + torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)) + ) if self.metric == ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) @@ -201,7 +209,11 @@ def train_epoch(self, x_train, y_train, stock_index): for idx, count in zip(daily_index, daily_count): batch = slice(idx, idx + count) feature = torch.from_numpy(x_train_values[batch]).float().to(self.device) - concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device) + concept_matrix = ( + torch.from_numpy(stock2concept_matrix[stock_index[batch]]) + .float() + .to(self.device) + ) label = torch.from_numpy(y_train_values[batch]).float().to(self.device) pred = self.HIST_model(feature, concept_matrix) loss = self.loss_fn(pred, label) @@ -229,7 +241,11 @@ def test_epoch(self, data_x, data_y, stock_index): for idx, count in zip(daily_index, daily_count): batch = slice(idx, idx + count) feature = torch.from_numpy(x_values[batch]).float().to(self.device) - concept_matrix = torch.from_numpy(stock2concept_matrix[stock_index[batch]]).float().to(self.device) + concept_matrix = ( 
+ torch.from_numpy(stock2concept_matrix[stock_index[batch]]) + .float() + .to(self.device) + ) label = torch.from_numpy(y_values[batch]).float().to(self.device) with torch.no_grad(): pred = self.HIST_model(feature, concept_matrix) @@ -253,7 +269,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) if not os.path.exists(self.stock2concept): url = "https://github.com/SunsetWolf/qlib_dataset/releases/download/v0/qlib_csi300_stock2concept.npy" @@ -261,12 +279,24 @@ def fit( stock_index = np.load(self.stock_index, allow_pickle=True).item() df_train["stock_index"] = 733 - df_train["stock_index"] = df_train.index.get_level_values("instrument").map(stock_index) + df_train["stock_index"] = df_train.index.get_level_values("instrument").map( + stock_index + ) df_valid["stock_index"] = 733 - df_valid["stock_index"] = df_valid.index.get_level_values("instrument").map(stock_index) + df_valid["stock_index"] = df_valid.index.get_level_values("instrument").map( + stock_index + ) - x_train, y_train, stock_index_train = df_train["feature"], df_train["label"], df_train["stock_index"] - x_valid, y_valid, stock_index_valid = df_valid["feature"], df_valid["label"], df_valid["stock_index"] + x_train, y_train, stock_index_train = ( + df_train["feature"], + df_train["label"], + df_train["stock_index"], + ) + x_valid, y_valid, stock_index_valid = ( + df_valid["feature"], + df_valid["label"], + df_valid["stock_index"], + ) save_path = get_or_create_path(save_path) @@ -290,7 +320,9 @@ def fit( model_dict = self.HIST_model.state_dict() pretrained_dict = { - k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135 + k: v + for k, v in pretrained_model.state_dict().items() + if k in model_dict # pylint: disable=E1135 } model_dict.update(pretrained_dict) self.HIST_model.load_state_dict(model_dict) @@ -306,7 +338,9 @@ def fit( self.train_epoch(x_train, y_train, stock_index_train) self.logger.info("evaluating...") - train_loss, train_score = self.test_epoch(x_train, y_train, stock_index_train) + train_loss, train_score = self.test_epoch( + x_train, y_train, stock_index_train + ) val_loss, val_score = self.test_epoch(x_valid, y_valid, stock_index_valid) self.logger.info("train %.6f, valid %.6f" % (train_score, val_score)) evals_result["train"].append(train_score) @@ -333,9 +367,13 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): stock2concept_matrix = np.load(self.stock2concept) stock_index = np.load(self.stock_index, allow_pickle=True).item() - df_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + df_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) df_test["stock_index"] = 733 - df_test["stock_index"] = df_test.index.get_level_values("instrument").map(stock_index) + df_test["stock_index"] = df_test.index.get_level_values("instrument").map( + stock_index + ) stock_index_test = df_test["stock_index"].values stock_index_test[np.isnan(stock_index_test)] = 733 stock_index_test = stock_index_test.astype("int") @@ -352,7 +390,11 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): for idx, count in zip(daily_index, daily_count): batch = slice(idx, idx + count) x_batch = torch.from_numpy(x_values[batch]).float().to(self.device) - concept_matrix = 
torch.from_numpy(stock2concept_matrix[stock_index_test[batch]]).float().to(self.device) + concept_matrix = ( + torch.from_numpy(stock2concept_matrix[stock_index_test[batch]]) + .float() + .to(self.device) + ) with torch.no_grad(): pred = self.HIST_model(x_batch, concept_matrix).detach().cpu().numpy() @@ -363,7 +405,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class HISTModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"): + def __init__( + self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU" + ): super().__init__() self.d_feat = d_feat @@ -440,7 +484,11 @@ def forward(self, x, concept_matrix): stock_to_concept = concept_matrix - stock_to_concept_sum = torch.sum(stock_to_concept, 0).reshape(1, -1).repeat(stock_to_concept.shape[0], 1) + stock_to_concept_sum = ( + torch.sum(stock_to_concept, 0) + .reshape(1, -1) + .repeat(stock_to_concept.shape[0], 1) + ) stock_to_concept_sum = stock_to_concept_sum.mul(concept_matrix) stock_to_concept_sum = stock_to_concept_sum + ( @@ -467,14 +515,18 @@ def forward(self, x, concept_matrix): i_stock_to_concept = self.cal_cos_similarity(i_shared_info, hidden) dim = i_stock_to_concept.shape[0] diag = i_stock_to_concept.diagonal(0) - i_stock_to_concept = i_stock_to_concept * (torch.ones(dim, dim) - torch.eye(dim)).to(device) + i_stock_to_concept = i_stock_to_concept * ( + torch.ones(dim, dim) - torch.eye(dim) + ).to(device) row = torch.linspace(0, dim - 1, dim).to(device).long() column = i_stock_to_concept.max(1)[1].long() value = i_stock_to_concept.max(1)[0] i_stock_to_concept[row, column] = 10 i_stock_to_concept[i_stock_to_concept != 10] = 0 i_stock_to_concept[row, column] = value - i_stock_to_concept = i_stock_to_concept + torch.diag_embed((i_stock_to_concept.sum(0) != 0).float() * diag) + i_stock_to_concept = i_stock_to_concept + torch.diag_embed( + (i_stock_to_concept.sum(0) != 0).float() * diag + ) hidden = torch.t(i_shared_info).mm(i_stock_to_concept).t() hidden = hidden[hidden.sum(1) != 0] diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 0bddc5a0f5..7343fb9fa9 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -74,7 +74,9 @@ def __init__( self.loss = loss self.base_model = base_model self.model_path = model_path - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -124,14 +126,18 @@ def __init__( base_model=self.base_model, ) self.logger.info("model:\n{:}".format(self.igmtf_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.igmtf_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.igmtf_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.igmtf_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.igmtf_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.igmtf_model.to(self.device) @@ -161,7 +167,9 @@ def metric_fn(self, pred, label): vx = x - torch.mean(x) vy = y - torch.mean(y) - return torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * 
torch.sqrt(torch.sum(vy**2))) + return torch.sum(vx * vy) / ( + torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)) + ) if self.metric == ("", "loss"): return -self.loss_fn(pred[mask], label[mask]) @@ -211,7 +219,9 @@ def train_epoch(self, x_train, y_train, train_hidden, train_hidden_day): batch = slice(idx, idx + count) feature = torch.from_numpy(x_train_values[batch]).float().to(self.device) label = torch.from_numpy(y_train_values[batch]).float().to(self.device) - pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day) + pred = self.igmtf_model( + feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day + ) loss = self.loss_fn(pred, label) self.train_optimizer.zero_grad() @@ -236,7 +246,9 @@ def test_epoch(self, data_x, data_y, train_hidden, train_hidden_day): feature = torch.from_numpy(x_values[batch]).float().to(self.device) label = torch.from_numpy(y_values[batch]).float().to(self.device) - pred = self.igmtf_model(feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day) + pred = self.igmtf_model( + feature, train_hidden=train_hidden, train_hidden_day=train_hidden_day + ) loss = self.loss_fn(pred, label) losses.append(loss.item()) @@ -257,7 +269,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -280,11 +294,15 @@ def fit( if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) + pretrained_model.load_state_dict( + torch.load(self.model_path, map_location=self.device) + ) model_dict = self.igmtf_model.state_dict() pretrained_dict = { - k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135 + k: v + for k, v in pretrained_model.state_dict().items() + if k in model_dict # pylint: disable=E1135 } model_dict.update(pretrained_dict) self.igmtf_model.load_state_dict(model_dict) @@ -300,8 +318,12 @@ def fit( train_hidden, train_hidden_day = self.get_train_hidden(x_train) self.train_epoch(x_train, y_train, train_hidden, train_hidden_day) self.logger.info("evaluating...") - train_loss, train_score = self.test_epoch(x_train, y_train, train_hidden, train_hidden_day) - val_loss, val_score = self.test_epoch(x_valid, y_valid, train_hidden, train_hidden_day) + train_loss, train_score = self.test_epoch( + x_train, y_train, train_hidden, train_hidden_day + ) + val_loss, val_score = self.test_epoch( + x_valid, y_valid, train_hidden, train_hidden_day + ) self.logger.info("train %.6f, valid %.6f" % (train_score, val_score)) evals_result["train"].append(train_score) evals_result["valid"].append(val_score) @@ -327,9 +349,13 @@ def fit( def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_train = dataset.prepare("train", col_set="feature", data_key=DataHandlerLP.DK_L) + x_train = dataset.prepare( + "train", col_set="feature", data_key=DataHandlerLP.DK_L + ) train_hidden, train_hidden_day = self.get_train_hidden(x_train) - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", 
data_key=DataHandlerLP.DK_I + ) index = x_test.index self.igmtf_model.eval() x_values = x_test.values @@ -343,7 +369,11 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): with torch.no_grad(): pred = ( - self.igmtf_model(x_batch, train_hidden=train_hidden, train_hidden_day=train_hidden_day) + self.igmtf_model( + x_batch, + train_hidden=train_hidden, + train_hidden_day=train_hidden_day, + ) .detach() .cpu() .numpy() @@ -355,7 +385,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class IGMTFModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU"): + def __init__( + self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0, base_model="GRU" + ): super().__init__() if base_model == "GRU": @@ -401,7 +433,15 @@ def sparse_dense_mul(self, s, d): dv = d[i[0, :], i[1, :]] # get values from relevant entries of dense matrix return torch.sparse.FloatTensor(i, v * dv, s.size()) - def forward(self, x, get_hidden=False, train_hidden=None, train_hidden_day=None, k_day=10, n_neighbor=10): + def forward( + self, + x, + get_hidden=False, + train_hidden=None, + train_hidden_day=None, + k_day=10, + n_neighbor=10, + ): # x: [N, F*T] device = x.device x = x.reshape(len(x), self.d_feat, -1) # [N, F, T] @@ -414,12 +454,16 @@ def forward(self, x, get_hidden=False, train_hidden=None, train_hidden_day=None, return mini_batch_out mini_batch_out_day = torch.mean(mini_batch_out, dim=0).unsqueeze(0) - day_similarity = self.cal_cos_similarity(mini_batch_out_day, train_hidden_day.to(device)) + day_similarity = self.cal_cos_similarity( + mini_batch_out_day, train_hidden_day.to(device) + ) day_index = torch.topk(day_similarity, k_day, dim=1)[1] sample_train_hidden = train_hidden[day_index.long().cpu()].squeeze() sample_train_hidden = torch.cat(list(sample_train_hidden)).to(device) sample_train_hidden = self.lins(sample_train_hidden) - cos_similarity = self.cal_cos_similarity(self.project1(mini_batch_out), self.project2(sample_train_hidden)) + cos_similarity = self.cal_cos_similarity( + self.project1(mini_batch_out), self.project2(sample_train_hidden) + ) row = ( torch.linspace(0, x.shape[0] - 1, x.shape[0]) diff --git a/qlib/contrib/model/pytorch_krnn.py b/qlib/contrib/model/pytorch_krnn.py index d97920b4dc..517246e1c6 100644 --- a/qlib/contrib/model/pytorch_krnn.py +++ b/qlib/contrib/model/pytorch_krnn.py @@ -47,7 +47,9 @@ def __init__(self, input_dim, output_dim, kernel_size, device): # set padding to ensure the same length # it is correct only when kernel_size is odd, dilation is 1, stride is 1 - self.conv = nn.Conv1d(input_dim, output_dim, kernel_size, padding=(kernel_size - 1) // 2) + self.conv = nn.Conv1d( + input_dim, output_dim, kernel_size, padding=(kernel_size - 1) // 2 + ) def forward(self, x): """ @@ -97,7 +99,11 @@ def __init__(self, input_dim, output_dim, dup_num, rnn_layers, dropout, device): self.rnn_modules = nn.ModuleList() for _ in range(dup_num): - self.rnn_modules.append(nn.GRU(input_dim, output_dim, num_layers=self.rnn_layers, dropout=dropout)) + self.rnn_modules.append( + nn.GRU( + input_dim, output_dim, num_layers=self.rnn_layers, dropout=dropout + ) + ) def forward(self, x): """ @@ -135,7 +141,15 @@ def forward(self, x): class CNNKRNNEncoder(nn.Module): def __init__( - self, cnn_input_dim, cnn_output_dim, cnn_kernel_size, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device + self, + cnn_input_dim, + cnn_output_dim, + cnn_kernel_size, + rnn_output_dim, + rnn_dup_num, + rnn_layers, + 
dropout, + device, ): """Build an encoder composed of CNN and KRNN @@ -156,8 +170,12 @@ def __init__( """ super().__init__() - self.cnn_encoder = CNNEncoderBase(cnn_input_dim, cnn_output_dim, cnn_kernel_size, device) - self.krnn_encoder = KRNNEncoderBase(cnn_output_dim, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device) + self.cnn_encoder = CNNEncoderBase( + cnn_input_dim, cnn_output_dim, cnn_kernel_size, device + ) + self.krnn_encoder = KRNNEncoderBase( + cnn_output_dim, rnn_output_dim, rnn_dup_num, rnn_layers, dropout, device + ) def forward(self, x): """ @@ -180,7 +198,18 @@ def forward(self, x): class KRNNModel(nn.Module): - def __init__(self, fea_dim, cnn_dim, cnn_kernel_size, rnn_dim, rnn_dups, rnn_layers, dropout, device, **params): + def __init__( + self, + fea_dim, + cnn_dim, + cnn_kernel_size, + rnn_dim, + rnn_dups, + rnn_layers, + dropout, + device, + **params, + ): """Build a KRNN model Parameters @@ -276,7 +305,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -337,7 +368,9 @@ def __init__( elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.krnn_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.krnn_model.to(self.device) @@ -390,8 +423,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.krnn_model(feature) loss = self.loss_fn(pred, label) @@ -417,8 +458,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.krnn_model(feature) loss = self.loss_fn(pred, label) @@ -441,7 +490,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -491,7 +542,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.krnn_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_localformer.py b/qlib/contrib/model/pytorch_localformer.py index 42851dd6a2..0908c72ccf 100644 --- a/qlib/contrib/model/pytorch_localformer.py +++ b/qlib/contrib/model/pytorch_localformer.py @@ -58,22 +58,36 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger = get_module_logger("TransformerModel") - self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)) + self.logger.info( + "Naive Transformer:" + "\nbatch_size : {}" + "\ndevice : {}".format(self.batch_size, self.device) + ) if self.seed is not None: np.random.seed(self.seed) torch.manual_seed(self.seed) - self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device) + self.model = Transformer( + d_feat, d_model, nhead, num_layers, dropout, self.device + ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.Adam( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.SGD( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.model.to(self.device) @@ -115,8 +129,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.model(feature) loss = self.loss_fn(pred, label) @@ -142,8 +164,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) with torch.no_grad(): pred = self.model(feature) @@ -167,7 +197,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty 
or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -217,7 +249,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.model.eval() x_values = x_test.values @@ -245,7 +279,9 @@ def __init__(self, d_model, max_len=1000): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) @@ -284,7 +320,9 @@ def forward(self, src, mask): class Transformer(nn.Module): - def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None): + def __init__( + self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None + ): super(Transformer, self).__init__() self.rnn = nn.GRU( input_size=d_model, @@ -295,8 +333,12 @@ def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, devi ) self.feature_layer = nn.Linear(d_feat, d_model) self.pos_encoder = PositionalEncoding(d_model) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout) - self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model) + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, nhead=nhead, dropout=dropout + ) + self.transformer_encoder = LocalformerEncoder( + self.encoder_layer, num_layers=num_layers, d_model=d_model + ) self.decoder_layer = nn.Linear(d_model, 1) self.device = device self.d_feat = d_feat diff --git a/qlib/contrib/model/pytorch_localformer_ts.py b/qlib/contrib/model/pytorch_localformer_ts.py index ae60a39968..b725a33784 100644 --- a/qlib/contrib/model/pytorch_localformer_ts.py +++ b/qlib/contrib/model/pytorch_localformer_ts.py @@ -56,24 +56,36 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger = get_module_logger("TransformerModel") self.logger.info( - "Improved Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device) + "Improved Transformer:" + "\nbatch_size : {}" + "\ndevice : {}".format(self.batch_size, self.device) ) if self.seed is not None: np.random.seed(self.seed) torch.manual_seed(self.seed) - self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device) + self.model = Transformer( + d_feat, d_model, nhead, num_layers, dropout, self.device + ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.Adam( + 
self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.SGD( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.model.to(self.device) @@ -143,19 +155,33 @@ def fit( evals_result=dict(), save_path=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader train_loader = DataLoader( - dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True + dl_train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.n_jobs, + drop_last=True, ) valid_loader = DataLoader( - dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True + dl_valid, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.n_jobs, + drop_last=True, ) save_path = get_or_create_path(save_path) @@ -204,9 +230,13 @@ def predict(self, dataset): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.model.eval() preds = [] @@ -226,7 +256,9 @@ def __init__(self, d_model, max_len=1000): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) @@ -265,7 +297,9 @@ def forward(self, src, mask): class Transformer(nn.Module): - def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None): + def __init__( + self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None + ): super(Transformer, self).__init__() self.rnn = nn.GRU( input_size=d_model, @@ -276,8 +310,12 @@ def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, devi ) self.feature_layer = nn.Linear(d_feat, d_model) self.pos_encoder = PositionalEncoding(d_model) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, 
nhead=nhead, dropout=dropout) - self.transformer_encoder = LocalformerEncoder(self.encoder_layer, num_layers=num_layers, d_model=d_model) + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, nhead=nhead, dropout=dropout + ) + self.transformer_encoder = LocalformerEncoder( + self.encoder_layer, num_layers=num_layers, d_model=d_model + ) self.decoder_layer = nn.Linear(d_model, 1) self.device = device self.d_feat = d_feat diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 3ba09097ac..246fd4d8ee 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -69,7 +69,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -120,7 +122,9 @@ def __init__( elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.lstm_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.lstm_model.to(self.device) @@ -162,8 +166,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.lstm_model(feature) loss = self.loss_fn(pred, label) @@ -189,8 +201,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.lstm_model(feature) loss = self.loss_fn(pred, label) @@ -213,7 +233,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -263,7 +285,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.lstm_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index a0fc34d583..7bcb597b47 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -71,7 +71,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.n_jobs = n_jobs self.seed = seed @@ -125,7 +127,9 @@ def __init__( elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.LSTM_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.LSTM_model.to(self.device) @@ -199,10 +203,16 @@ def fit( save_path=None, reweighter=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader @@ -277,9 +287,13 @@ def predict(self, dataset): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.LSTM_model.eval() preds = [] diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 190d1ba45a..d4a4d578a2 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -66,7 +66,9 @@ def __init__( seed=None, weight_decay=0.0, data_parall=False, - scheduler: Optional[Union[Callable]] = "default", # when it is Callable, it accept one argument named optimizer + scheduler: Optional[ + Union[Callable] + ] = "default", # when it is Callable, it accept one argument named optimizer init_model=None, eval_train_metric=False, pt_model_uri="qlib.contrib.model.pytorch_nn.Net", @@ -92,7 +94,9 @@ def __init__( if isinstance(GPU, str): self.device = torch.device(GPU) else: - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.weight_decay = weight_decay self.data_parall = data_parall @@ -128,7 +132,9 @@ def __init__( self._scorer = mean_squared_error if loss == "mse" else roc_auc_score if init_model is None: - self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs}) + self.dnn_model = init_instance_by_config( + {"class": pt_model_uri, "kwargs": pt_model_kwargs} + ) if self.data_parall: self.dnn_model = DataParallel(self.dnn_model).to(self.device) @@ -136,31 +142,41 @@ def __init__( self.dnn_model = init_model self.logger.info("model:\n{:}".format(self.dnn_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.dnn_model)) + ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + self.train_optimizer = optim.Adam( + self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + self.train_optimizer = optim.SGD( + self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) if scheduler == "default": # In torch version 2.7.0, the verbose parameter has been removed. 
Reference Link: # https://github.com/pytorch/pytorch/pull/147301/files#diff-036a7470d5307f13c9a6a51c3a65dd014f00ca02f476c545488cd856bea9bcf2L1313 if str(torch.__version__).split("+", maxsplit=1)[0] <= "2.6.0": # Reduce learning rate when loss has stopped decrease - self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123 - self.train_optimizer, - mode="min", - factor=0.5, - patience=10, - verbose=True, - threshold=0.0001, - threshold_mode="rel", - cooldown=0, - min_lr=0.00001, - eps=1e-08, + self.scheduler = ( + torch.optim.lr_scheduler.ReduceLROnPlateau( # pylint: disable=E1123 + self.train_optimizer, + mode="min", + factor=0.5, + patience=10, + verbose=True, + threshold=0.0001, + threshold_mode="rel", + cooldown=0, + min_lr=0.00001, + eps=1e-08, + ) ) else: self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( @@ -203,12 +219,18 @@ def fit( if seg in dataset.segments: # df_train df_valid df = dataset.prepare( - seg, col_set=["feature", "label"], data_key=self.valid_key if seg == "valid" else DataHandlerLP.DK_L + seg, + col_set=["feature", "label"], + data_key=self.valid_key if seg == "valid" else DataHandlerLP.DK_L, ) all_df["x"][seg] = df["feature"] - all_df["y"][seg] = df["label"].copy() # We have to use copy to remove the reference to release mem + all_df["y"][seg] = df[ + "label" + ].copy() # We have to use copy to remove the reference to release mem if reweighter is None: - all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values), index=df.index) + all_df["w"][seg] = pd.DataFrame( + np.ones_like(all_df["y"][seg].values), index=df.index + ) elif isinstance(reweighter, Reweighter): all_df["w"][seg] = pd.DataFrame(reweighter.reweight(df)) else: @@ -218,7 +240,9 @@ def fit( for v in vars: all_t[v][seg] = torch.from_numpy(all_df[v][seg].values).float() # if seg == "valid": # accelerate the eval of validation - all_t[v][seg] = all_t[v][seg].to(self.device) # This will consume a lot of memory !!!! + all_t[v][seg] = all_t[v][seg].to( + self.device + ) # This will consume a lot of memory !!!! 
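Below is a minimal, self-contained sketch of the version-gated ReduceLROnPlateau pattern used in the pytorch_nn.py hunk above: torch 2.7.0 removed the `verbose` argument, so it may only be passed on older releases. The toy model, optimizer settings, and training loop are illustrative placeholders rather than Qlib code; only the public PyTorch API is assumed.

import torch
import torch.nn as nn
import torch.optim as optim

# Stand-ins for the real DNN and its optimizer settings (placeholders only).
model = nn.Linear(10, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Pass `verbose` only on torch < 2.7, mirroring the version check in the patch.
torch_major_minor = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
plateau_kwargs = dict(mode="min", factor=0.5, patience=10, threshold=1e-4, min_lr=1e-5)
if torch_major_minor < (2, 7):
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, **plateau_kwargs)
else:
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **plateau_kwargs)

# ReduceLROnPlateau is metric-driven: pass the validation loss to step(); after
# `patience` epochs without improvement it multiplies the learning rate by `factor`.
for epoch in range(3):
    val_loss = torch.rand(1).item()  # placeholder for a real validation loss
    scheduler.step(val_loss)
    print(epoch, optimizer.param_groups[0]["lr"])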
evals_result[seg] = [] # free memory @@ -271,11 +295,18 @@ def fit( # forward preds = self._nn_predict(all_t["x"]["valid"], return_cpu=False) - cur_loss_val = self.get_loss(preds, all_t["w"]["valid"], all_t["y"]["valid"], self.loss_type) + cur_loss_val = self.get_loss( + preds, + all_t["w"]["valid"], + all_t["y"]["valid"], + self.loss_type, + ) loss_val = cur_loss_val.item() metric_val = ( self.get_metric( - preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["y"]["valid"].index + preds.reshape(-1), + all_t["y"]["valid"].reshape(-1), + all_df["y"]["valid"].index, ) .detach() .cpu() @@ -288,7 +319,9 @@ def fit( if self.eval_train_metric: metric_train = ( self.get_metric( - self._nn_predict(all_t["x"]["train"], return_cpu=False), + self._nn_predict( + all_t["x"]["train"], return_cpu=False + ), all_t["y"]["train"].reshape(-1), all_df["y"]["train"].index, ) @@ -321,7 +354,9 @@ def fit( train_loss = 0 # update learning rate if self.scheduler is not None: - auto_filter_kwargs(self.scheduler.step, warning=False)(metrics=cur_loss_val, epoch=step) + auto_filter_kwargs(self.scheduler.step, warning=False)( + metrics=cur_loss_val, epoch=step + ) R.log_metrics(lr=self.get_lr(), step=step) else: # retraining mode @@ -330,7 +365,9 @@ def fit( if has_valid: # restore the optimal parameters after training - self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device)) + self.dnn_model.load_state_dict( + torch.load(save_path, map_location=self.device) + ) if self.use_gpu: torch.cuda.empty_cache() @@ -381,7 +418,9 @@ def _nn_predict(self, data, return_cpu=True): def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test_pd = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) preds = self._nn_predict(x_test_pd) return pd.Series(preds.reshape(-1), index=x_test_pd.index) @@ -394,12 +433,16 @@ def save(self, filename, **kwargs): def load(self, buffer, **kwargs): with unpack_archive_with_buffer(buffer) as model_dir: # Get model name - _model_name = os.path.splitext(list(filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)))[0])[ - 0 - ] + _model_name = os.path.splitext( + list( + filter(lambda x: x.startswith("model.bin"), os.listdir(model_dir)) + )[0] + )[0] _model_path = os.path.join(model_dir, _model_name) # Load model - self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device)) + self.dnn_model.load_state_dict( + torch.load(_model_path, map_location=self.device) + ) self.fitted = True @@ -453,7 +496,9 @@ def __init__(self, input_dim, output_dim=1, layers=(256,), act="LeakyReLU"): def _weight_init(self): for m in self.modules(): if isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, a=0.1, mode="fan_in", nonlinearity="leaky_relu") + nn.init.kaiming_normal_( + m.weight, a=0.1, mode="fan_in", nonlinearity="leaky_relu" + ) def forward(self, x): cur_output = x diff --git a/qlib/contrib/model/pytorch_sandwich.py b/qlib/contrib/model/pytorch_sandwich.py index 344368143f..15f468dc35 100644 --- a/qlib/contrib/model/pytorch_sandwich.py +++ b/qlib/contrib/model/pytorch_sandwich.py @@ -152,7 +152,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if 
torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -215,11 +217,17 @@ def __init__( device=self.device, ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.sandwich_model.parameters(), lr=self.lr) + self.train_optimizer = optim.Adam( + self.sandwich_model.parameters(), lr=self.lr + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.sandwich_model.parameters(), lr=self.lr) + self.train_optimizer = optim.SGD( + self.sandwich_model.parameters(), lr=self.lr + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.sandwich_model.to(self.device) @@ -260,8 +268,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.sandwich_model(feature) loss = self.loss_fn(pred, label) @@ -287,8 +303,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.sandwich_model(feature) loss = self.loss_fn(pred, label) @@ -311,7 +335,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -361,7 +387,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.sandwich_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index c971f1a58c..58bb7e4bbc 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -41,30 +41,52 @@ def __init__( self.hidden_dim = hidden_size self.device = device - self.W_i = nn.Parameter(init.xavier_uniform_(torch.empty((self.input_dim, self.hidden_dim)))) - self.U_i = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim))) + self.W_i = nn.Parameter( + init.xavier_uniform_(torch.empty((self.input_dim, self.hidden_dim))) + ) + self.U_i = nn.Parameter( + init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)) + ) self.b_i = nn.Parameter(torch.zeros(self.hidden_dim)) - self.W_ste = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim))) - self.U_ste = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim))) + self.W_ste = nn.Parameter( + init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)) + ) + self.U_ste = nn.Parameter( + init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)) + ) self.b_ste = nn.Parameter(torch.ones(self.hidden_dim)) - self.W_fre = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.freq_dim))) - self.U_fre = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.freq_dim))) + self.W_fre = nn.Parameter( + init.xavier_uniform_(torch.empty(self.input_dim, self.freq_dim)) + ) + self.U_fre = nn.Parameter( + init.orthogonal_(torch.empty(self.hidden_dim, self.freq_dim)) + ) self.b_fre = nn.Parameter(torch.ones(self.freq_dim)) - self.W_c = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim))) - self.U_c = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim))) + self.W_c = nn.Parameter( + init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)) + ) + self.U_c = nn.Parameter( + init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)) + ) self.b_c = nn.Parameter(torch.zeros(self.hidden_dim)) - self.W_o = nn.Parameter(init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim))) - self.U_o = nn.Parameter(init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim))) + self.W_o = nn.Parameter( + init.xavier_uniform_(torch.empty(self.input_dim, self.hidden_dim)) + ) + self.U_o = nn.Parameter( + init.orthogonal_(torch.empty(self.hidden_dim, self.hidden_dim)) + ) self.b_o = nn.Parameter(torch.zeros(self.hidden_dim)) self.U_a = nn.Parameter(init.orthogonal_(torch.empty(self.freq_dim, 1))) self.b_a = nn.Parameter(torch.zeros(self.hidden_dim)) - self.W_p = nn.Parameter(init.xavier_uniform_(torch.empty(self.hidden_dim, self.output_dim))) + self.W_p = nn.Parameter( + init.xavier_uniform_(torch.empty(self.hidden_dim, self.output_dim)) + ) self.b_p = nn.Parameter(torch.zeros(self.output_dim)) self.activation = nn.Tanh() @@ -100,8 +122,12 @@ def forward(self, input): x_o = torch.matmul(x * B_W[0], self.W_o) + self.b_o i = self.inner_activation(x_i + 
torch.matmul(h_tm1 * B_U[0], self.U_i)) - ste = self.inner_activation(x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste)) - fre = self.inner_activation(x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre)) + ste = self.inner_activation( + x_ste + torch.matmul(h_tm1 * B_U[0], self.U_ste) + ) + fre = self.inner_activation( + x_fre + torch.matmul(h_tm1 * B_U[0], self.U_fre) + ) ste = torch.reshape(ste, (-1, self.hidden_dim, 1)) fre = torch.reshape(fre, (-1, 1, self.freq_dim)) @@ -233,7 +259,9 @@ def __init__( self.eval_steps = eval_steps self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -289,14 +317,18 @@ def __init__( device=self.device, ) self.logger.info("model:\n{:}".format(self.sfm_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.sfm_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.sfm_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.sfm_model.to(self.device) @@ -321,8 +353,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.sfm_model(feature) loss = self.loss_fn(pred, label) @@ -346,8 +386,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.sfm_model(feature) loss = self.loss_fn(pred, label) @@ -369,7 +417,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -437,7 +487,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.sfm_model.eval() x_values = x_test.values diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 3c698edade..d17769e636 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -69,7 +69,9 @@ def __init__( self.n_epochs = n_epochs self.logger = get_module_logger("TabNet") self.pretrain_n_epochs = pretrain_n_epochs - self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + self.device = ( + "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.loss = loss self.metric = metric self.early_stop = early_stop @@ -86,24 +88,42 @@ def __init__( np.random.seed(self.seed) torch.manual_seed(self.seed) - self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device) - self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device) - self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder)) - self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder]))) + self.tabnet_model = TabNet( + inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax + ).to(self.device) + self.tabnet_decoder = TabNet_Decoder( + self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps + ).to(self.device) + self.logger.info( + "model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder) + ) + self.logger.info( + "model size: {:.4f} MB".format( + count_parameters([self.tabnet_model, self.tabnet_decoder]) + ) + ) if optimizer.lower() == "adam": self.pretrain_optimizer = optim.Adam( - list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr + list(self.tabnet_model.parameters()) + + list(self.tabnet_decoder.parameters()), + lr=self.lr, + ) + self.train_optimizer = optim.Adam( + self.tabnet_model.parameters(), lr=self.lr ) - self.train_optimizer = optim.Adam(self.tabnet_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.pretrain_optimizer = optim.SGD( - list(self.tabnet_model.parameters()) + list(self.tabnet_decoder.parameters()), lr=self.lr + list(self.tabnet_model.parameters()) + + list(self.tabnet_decoder.parameters()), + lr=self.lr, ) self.train_optimizer = optim.SGD(self.tabnet_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) @property def use_gpu(self): @@ -159,17 +179,23 @@ def fit( self.logger.info("Pretrain...") self.pretrain_fn(dataset, self.pretrain_file) self.logger.info("Load Pretrain model") - self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device)) + self.tabnet_model.load_state_dict( + torch.load(self.pretrain_file, map_location=self.device) + ) # adding one more linear layer to fit the final output dimension - self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, 
self.tabnet_model).to(self.device) + self.tabnet_model = FinetuneModel( + self.out_dim, self.final_out_dim, self.tabnet_model + ).to(self.device) df_train, df_valid = dataset.prepare( ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) df_train.fillna(df_train.mean(), inplace=True) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -218,7 +244,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.tabnet_model.eval() x_values = torch.from_numpy(x_test.values) @@ -285,8 +313,12 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = x_train_values[indices[i : i + self.batch_size]].float().to(self.device) - label = y_train_values[indices[i : i + self.batch_size]].float().to(self.device) + feature = ( + x_train_values[indices[i : i + self.batch_size]].float().to(self.device) + ) + label = ( + y_train_values[indices[i : i + self.batch_size]].float().to(self.device) + ) priors = torch.ones(self.batch_size, self.d_feat).to(self.device) pred = self.tabnet_model(feature, priors) loss = self.loss_fn(pred, label) @@ -309,7 +341,9 @@ def pretrain_epoch(self, x_train): if len(indices) - i < self.batch_size: break - S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps)) + S_mask = torch.bernoulli( + torch.empty(self.batch_size, self.d_feat).fill_(self.ps) + ) x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask) y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask) @@ -339,7 +373,9 @@ def pretrain_test_epoch(self, x_train): if len(indices) - i < self.batch_size: break - S_mask = torch.bernoulli(torch.empty(self.batch_size, self.d_feat).fill_(self.ps)) + S_mask = torch.bernoulli( + torch.empty(self.batch_size, self.d_feat).fill_(self.ps) + ) x_train_values = train_set[indices[i : i + self.batch_size]] * (1 - S_mask) y_train_values = train_set[indices[i : i + self.batch_size]] * (S_mask) @@ -418,7 +454,9 @@ def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps): self.shared = nn.ModuleList() self.shared.append(nn.Linear(inp_dim, 2 * out_dim)) for x in range(n_shared - 1): - self.shared.append(nn.Linear(out_dim, 2 * out_dim)) # preset the linear function we will use + self.shared.append( + nn.Linear(out_dim, 2 * out_dim) + ) # preset the linear function we will use else: self.shared = None self.n_steps = n_steps @@ -434,7 +472,18 @@ def forward(self, x): class TabNet(nn.Module): - def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024): + def __init__( + self, + inp_dim=6, + out_dim=6, + n_d=64, + n_a=64, + n_shared=2, + n_ind=2, + n_steps=5, + relax=1.2, + vbs=1024, + ): """ TabNet AKA the original encoder @@ -454,14 +503,20 @@ def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_ self.shared = nn.ModuleList() self.shared.append(nn.Linear(inp_dim, 2 * (n_d + n_a))) for x in range(n_shared - 1): - 
self.shared.append(nn.Linear(n_d + n_a, 2 * (n_d + n_a))) # preset the linear function we will use + self.shared.append( + nn.Linear(n_d + n_a, 2 * (n_d + n_a)) + ) # preset the linear function we will use else: self.shared = None - self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs) + self.first_step = FeatureTransformer( + inp_dim, n_d + n_a, self.shared, n_ind, vbs + ) self.steps = nn.ModuleList() for x in range(n_steps - 1): - self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs)) + self.steps.append( + DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs) + ) self.fc = nn.Linear(n_d, out_dim) self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01) self.n_d = n_d @@ -474,7 +529,9 @@ def forward(self, x, priors): out = torch.zeros(x.size(0), self.n_d).to(x.device) for step in self.steps: x_te, loss = step(x, x_a, priors) - out += F.relu(x_te[:, : self.n_d]) # split the feature from feat_transformer + out += F.relu( + x_te[:, : self.n_d] + ) # split the feature from feat_transformer x_a = x_te[:, self.n_d :] sparse_loss.append(loss) return self.fc(out), sum(sparse_loss) diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py index f6e7e953a0..bb0060c3c2 100755 --- a/qlib/contrib/model/pytorch_tcn.py +++ b/qlib/contrib/model/pytorch_tcn.py @@ -75,7 +75,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger.info( @@ -125,14 +127,18 @@ def __init__( dropout=self.dropout, ) self.logger.info("model:\n{:}".format(self.tcn_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.tcn_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.tcn_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.tcn_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.tcn_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.tcn_model.to(self.device) @@ -174,8 +180,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.tcn_model(feature) loss = self.loss_fn(pred, label) @@ -200,8 +214,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) 
with torch.no_grad(): pred = self.tcn_model(feature) @@ -273,7 +295,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.tcn_model.eval() x_values = x_test.values @@ -300,7 +324,9 @@ class TCNModel(nn.Module): def __init__(self, num_input, output_size, num_channels, kernel_size, dropout): super().__init__() self.num_input = num_input - self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout) + self.tcn = TemporalConvNet( + num_input, num_channels, kernel_size, dropout=dropout + ) self.linear = nn.Linear(num_channels[-1], output_size) def forward(self, x): diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index a6cc38885c..6772c5ec14 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -73,7 +73,9 @@ def __init__( self.early_stop = early_stop self.optimizer = optimizer.lower() self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.n_jobs = n_jobs self.seed = seed @@ -126,14 +128,18 @@ def __init__( dropout=self.dropout, ) self.logger.info("model:\n{:}".format(self.TCN_model)) - self.logger.info("model size: {:.4f} MB".format(count_parameters(self.TCN_model))) + self.logger.info( + "model size: {:.4f} MB".format(count_parameters(self.TCN_model)) + ) if optimizer.lower() == "adam": self.train_optimizer = optim.Adam(self.TCN_model.parameters(), lr=self.lr) elif optimizer.lower() == "gd": self.train_optimizer = optim.SGD(self.TCN_model.parameters(), lr=self.lr) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.TCN_model.to(self.device) @@ -206,8 +212,12 @@ def fit( evals_result=dict(), save_path=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) # process nan brought by dataloader dl_train.config(fillna_type="ffill+bfill") @@ -215,10 +225,18 @@ def fit( dl_valid.config(fillna_type="ffill+bfill") train_loader = DataLoader( - dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True + dl_train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.n_jobs, + drop_last=True, ) valid_loader = DataLoader( - dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True + dl_valid, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.n_jobs, + drop_last=True, ) save_path = get_or_create_path(save_path) @@ -267,9 +285,13 @@ def predict(self, dataset): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", 
col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.TCN_model.eval() preds = [] @@ -288,7 +310,9 @@ class TCNModel(nn.Module): def __init__(self, num_input, output_size, num_channels, kernel_size, dropout): super().__init__() self.num_input = num_input - self.tcn = TemporalConvNet(num_input, num_channels, kernel_size, dropout=dropout) + self.tcn = TemporalConvNet( + num_input, num_channels, kernel_size, dropout=dropout + ) self.linear = nn.Linear(num_channels[-1], output_size) def forward(self, x): diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index d8736627c2..d577204df9 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -73,7 +73,9 @@ def __init__( self.batch_size = batch_size self.early_stop = early_stop self.loss = loss - self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu") + self.device = torch.device( + "cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu" + ) self.use_gpu = torch.cuda.is_available() self.seed = seed self.input_dim = input_dim @@ -159,14 +161,29 @@ def train_epoch(self, x_train, y_train, x_valid, y_valid): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) init_pred = init_fore_model(feature) pred = self.fore_model(feature) dis = init_pred - label.transpose(0, 1) weight_feature = torch.cat( - (feature, dis.transpose(0, 1), label, init_pred.view(-1, 1), task_embedding), 1 + ( + feature, + dis.transpose(0, 1), + label, + init_pred.view(-1, 1), + task_embedding, + ), + 1, ) weight = self.weight_model(weight_feature) @@ -192,16 +209,29 @@ def train_epoch(self, x_train, y_train, x_valid, y_valid): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.fore_model(feature) dis = pred - label.transpose(0, 1) - weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1), task_embedding), 1) + weight_feature = torch.cat( + (feature, dis.transpose(0, 1), label, pred.view(-1, 1), task_embedding), + 1, + ) weight = self.weight_model(weight_feature) loc = torch.argmax(weight, 1) valid_loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2) - loss = torch.mean(valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc])) + loss = torch.mean( + valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]) + ) self.weight_optimizer.zero_grad() loss.backward() @@ -223,8 +253,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < 
self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.fore_model(feature) loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2) @@ -244,7 +282,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -263,7 +303,14 @@ def fit( torch.manual_seed(self.seed) best_loss = self.training( - x_train, y_train, x_valid, y_valid, x_test, y_test, verbose=verbose, save_path=save_path + x_train, + y_train, + x_valid, + y_valid, + x_test, + y_test, + verbose=verbose, + save_path=save_path, ) def training( @@ -291,17 +338,29 @@ def training( output_dim=self.output_dim, ) if self._fore_optimizer.lower() == "adam": - self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr) + self.fore_optimizer = optim.Adam( + self.fore_model.parameters(), lr=self.fore_lr + ) elif self._fore_optimizer.lower() == "gd": - self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr) + self.fore_optimizer = optim.SGD( + self.fore_model.parameters(), lr=self.fore_lr + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(self._fore_optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(self._fore_optimizer) + ) if self._weight_optimizer.lower() == "adam": - self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr) + self.weight_optimizer = optim.Adam( + self.weight_model.parameters(), lr=self.weight_lr + ) elif self._weight_optimizer.lower() == "gd": - self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr) + self.weight_optimizer = optim.SGD( + self.weight_model.parameters(), lr=self.weight_lr + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(self._weight_optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(self._weight_optimizer) + ) self.fitted = False self.fore_model.to(self.device) @@ -327,8 +386,14 @@ def training( best_loss = val_loss stop_round = 0 best_epoch = epoch - torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin") - torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin") + torch.save( + copy.deepcopy(self.fore_model.state_dict()), + save_path + "_fore_model.bin", + ) + torch.save( + copy.deepcopy(self.weight_model.state_dict()), + save_path + "_weight_model.bin", + ) else: stop_round += 1 @@ -339,7 +404,9 @@ def training( print("best loss:", best_loss, "@", best_epoch) best_param = torch.load(save_path + "_fore_model.bin", map_location=self.device) self.fore_model.load_state_dict(best_param) - best_param = torch.load(save_path + "_weight_model.bin", map_location=self.device) + best_param = torch.load( + save_path + "_weight_model.bin", map_location=self.device + ) self.weight_model.load_state_dict(best_param) self.fitted = True @@ -379,7 
+446,9 @@ def predict(self, dataset): class MLPModel(nn.Module): - def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1): + def __init__( + self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1 + ): super().__init__() self.mlp = nn.Sequential() @@ -388,7 +457,9 @@ def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_di for i in range(num_layers): if i > 0: self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout)) - self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size)) + self.mlp.add_module( + "fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size) + ) self.mlp.add_module("relu_%d" % i, nn.ReLU()) self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim)) diff --git a/qlib/contrib/model/pytorch_tra.py b/qlib/contrib/model/pytorch_tra.py index bc9a6aa977..fd0648eec0 100644 --- a/qlib/contrib/model/pytorch_tra.py +++ b/qlib/contrib/model/pytorch_tra.py @@ -87,8 +87,14 @@ def __init__( self.logger = get_module_logger("TRA") assert memory_mode in ["sample", "daily"], "invalid memory mode" - assert transport_method in ["none", "router", "oracle"], f"invalid transport method {transport_method}" - assert transport_method == "none" or tra_config["num_states"] > 1, "optimal transport requires `num_states` > 1" + assert transport_method in [ + "none", + "router", + "oracle", + ], f"invalid transport method {transport_method}" + assert ( + transport_method == "none" or tra_config["num_states"] > 1 + ), "optimal transport requires `num_states` > 1" assert ( memory_mode != "daily" or tra_config["src_info"] == "TPE" ), "daily transport can only support TPE as `src_info`" @@ -122,7 +128,9 @@ def __init__( self.freeze_predictors = freeze_predictors self.transport_method = transport_method self.use_daily_transport = memory_mode == "daily" - self.transport_fn = transport_daily if self.use_daily_transport else transport_sample + self.transport_fn = ( + transport_daily if self.use_daily_transport else transport_sample + ) self._writer = None if self.logdir is not None: @@ -165,10 +173,18 @@ def _init_model(self): for param in self.tra.predictors.parameters(): param.requires_grad_(False) - self.logger.info("# model params: %d" % sum(p.numel() for p in self.model.parameters() if p.requires_grad)) - self.logger.info("# tra params: %d" % sum(p.numel() for p in self.tra.parameters() if p.requires_grad)) + self.logger.info( + "# model params: %d" + % sum(p.numel() for p in self.model.parameters() if p.requires_grad) + ) + self.logger.info( + "# tra params: %d" + % sum(p.numel() for p in self.tra.parameters() if p.requires_grad) + ) - self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr) + self.optimizer = optim.Adam( + list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr + ) self.fitted = False self.global_step = -1 @@ -185,7 +201,9 @@ def train_epoch(self, epoch, data_set, is_pretrain=False): max_steps = len(data_set) if self.max_steps_per_epoch is not None: if epoch == 0 and self.max_steps_per_epoch < max_steps: - self.logger.info(f"max steps updated from {max_steps} to {self.max_steps_per_epoch}") + self.logger.info( + f"max steps updated from {max_steps} to {self.max_steps_per_epoch}" + ) max_steps = min(self.max_steps_per_epoch, max_steps) cur_step = 0 @@ -199,7 +217,12 @@ def train_epoch(self, epoch, data_set, is_pretrain=False): if not is_pretrain: self.global_step += 1 - data, state, label, count = batch["data"], 
batch["state"], batch["label"], batch["daily_count"] + data, state, label, count = ( + batch["data"], + batch["state"], + batch["label"], + batch["daily_count"], + ) index = batch["daily_index"] if self.use_daily_transport else batch["index"] with torch.set_grad_enabled(not self.freeze_model): @@ -223,18 +246,30 @@ def train_epoch(self, epoch, data_set, is_pretrain=False): data_set.assign_data(index, L) # save loss to memory if self.use_daily_transport: # only save for daily transport P_all.append(pd.DataFrame(P.detach().cpu().numpy(), index=index)) - prob_all.append(pd.DataFrame(prob.detach().cpu().numpy(), index=index)) - choice_all.append(pd.DataFrame(choice.detach().cpu().numpy(), index=index)) + prob_all.append( + pd.DataFrame(prob.detach().cpu().numpy(), index=index) + ) + choice_all.append( + pd.DataFrame(choice.detach().cpu().numpy(), index=index) + ) decay = self.rho ** (self.global_step // 100) # decay every 100 steps lamb = 0 if is_pretrain else self.lamb * decay - reg = prob.log().mul(P).sum(dim=1).mean() # train router to predict TO assignment + reg = ( + prob.log().mul(P).sum(dim=1).mean() + ) # train router to predict TO assignment if self._writer is not None and not is_pretrain: - self._writer.add_scalar("training/router_loss", -reg.item(), self.global_step) - self._writer.add_scalar("training/reg_loss", loss.item(), self.global_step) + self._writer.add_scalar( + "training/router_loss", -reg.item(), self.global_step + ) + self._writer.add_scalar( + "training/reg_loss", loss.item(), self.global_step + ) self._writer.add_scalar("training/lamb", lamb, self.global_step) if not self.use_daily_transport: P_mean = P.mean(axis=0).detach() - self._writer.add_scalar("training/P", P_mean.max() / P_mean.min(), self.global_step) + self._writer.add_scalar( + "training/P", P_mean.max() / P_mean.min(), self.global_step + ) loss = loss - lamb * reg else: pred = all_preds.mean(dim=1) @@ -246,7 +281,9 @@ def train_epoch(self, epoch, data_set, is_pretrain=False): self.optimizer.zero_grad() if self._writer is not None and not is_pretrain: - self._writer.add_scalar("training/total_loss", loss.item(), self.global_step) + self._writer.add_scalar( + "training/total_loss", loss.item(), self.global_step + ) total_loss += loss.item() total_count += 1 @@ -261,7 +298,9 @@ def train_epoch(self, epoch, data_set, is_pretrain=False): if not is_pretrain: self._writer.add_image("P", plot(P_all), epoch, dataformats="HWC") self._writer.add_image("prob", plot(prob_all), epoch, dataformats="HWC") - self._writer.add_image("choice", plot(choice_all), epoch, dataformats="HWC") + self._writer.add_image( + "choice", plot(choice_all), epoch, dataformats="HWC" + ) total_loss /= total_count @@ -270,7 +309,9 @@ def train_epoch(self, epoch, data_set, is_pretrain=False): return total_loss - def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False): + def test_epoch( + self, epoch, data_set, return_pred=False, prefix="test", is_pretrain=False + ): self.model.eval() self.tra.eval() data_set.eval() @@ -280,7 +321,12 @@ def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretr P_all = [] metrics = [] for batch in tqdm(data_set): - data, state, label, count = batch["data"], batch["state"], batch["label"], batch["daily_count"] + data, state, label, count = ( + batch["data"], + batch["state"], + batch["label"], + batch["daily_count"], + ) index = batch["daily_index"] if self.use_daily_transport else batch["index"] with torch.no_grad(): @@ -306,7 +352,9 @@ def test_epoch(self, 
epoch, data_set, return_pred=False, prefix="test", is_pretr pred = all_preds.mean(dim=1) X = np.c_[pred.cpu().numpy(), label.cpu().numpy(), all_preds.cpu().numpy()] - columns = ["score", "label"] + ["score_%d" % d for d in range(all_preds.shape[1])] + columns = ["score", "label"] + [ + "score_%d" % d for d in range(all_preds.shape[1]) + ] pred = pd.DataFrame(X, index=batch["index"], columns=columns) metrics.append(evaluate(pred)) @@ -315,7 +363,9 @@ def test_epoch(self, epoch, data_set, return_pred=False, prefix="test", is_pretr preds.append(pred) if prob is not None: columns = ["prob_%d" % d for d in range(all_preds.shape[1])] - probs.append(pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns)) + probs.append( + pd.DataFrame(prob.cpu().numpy(), index=index, columns=columns) + ) metrics = pd.DataFrame(metrics) metrics = { @@ -376,18 +426,26 @@ def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True): self.logger.info("evaluating...") # NOTE: during evaluating, the whole memory will be refreshed - if not is_pretrain and (self.transport_method == "router" or self.eval_train): + if not is_pretrain and ( + self.transport_method == "router" or self.eval_train + ): train_set.clear_memory() # NOTE: clear the shared memory - train_metrics = self.test_epoch(epoch, train_set, is_pretrain=is_pretrain, prefix="train")[0] + train_metrics = self.test_epoch( + epoch, train_set, is_pretrain=is_pretrain, prefix="train" + )[0] evals_result["train"].append(train_metrics) self.logger.info("train metrics: %s" % train_metrics) - valid_metrics = self.test_epoch(epoch, valid_set, is_pretrain=is_pretrain, prefix="valid")[0] + valid_metrics = self.test_epoch( + epoch, valid_set, is_pretrain=is_pretrain, prefix="valid" + )[0] evals_result["valid"].append(valid_metrics) self.logger.info("valid metrics: %s" % valid_metrics) if self.eval_test: - test_metrics = self.test_epoch(epoch, test_set, is_pretrain=is_pretrain, prefix="test")[0] + test_metrics = self.test_epoch( + epoch, test_set, is_pretrain=is_pretrain, prefix="test" + )[0] evals_result["test"].append(test_metrics) self.logger.info("test metrics: %s" % test_metrics) @@ -414,7 +472,9 @@ def _fit(self, train_set, valid_set, test_set, evals_result, is_pretrain=True): return best_score def fit(self, dataset, evals_result=dict()): - assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`" + assert isinstance( + dataset, MTSDatasetH + ), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`" train_set, valid_set, test_set = dataset.prepare(["train", "valid", "test"]) @@ -428,34 +488,49 @@ def fit(self, dataset, evals_result=dict()): if self.pretrain: self.logger.info("pretraining...") self.optimizer = optim.Adam( - list(self.model.parameters()) + list(self.tra.predictors.parameters()), lr=self.lr + list(self.model.parameters()) + list(self.tra.predictors.parameters()), + lr=self.lr, ) self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=True) # reset optimizer - self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr) + self.optimizer = optim.Adam( + list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr + ) self.logger.info("training...") - best_score = self._fit(train_set, valid_set, test_set, evals_result, is_pretrain=False) + best_score = self._fit( + train_set, valid_set, test_set, evals_result, is_pretrain=False + ) self.logger.info("inference") - train_metrics, train_preds, train_probs, train_P = 
self.test_epoch(-1, train_set, return_pred=True) + train_metrics, train_preds, train_probs, train_P = self.test_epoch( + -1, train_set, return_pred=True + ) self.logger.info("train metrics: %s" % train_metrics) - valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch(-1, valid_set, return_pred=True) + valid_metrics, valid_preds, valid_probs, valid_P = self.test_epoch( + -1, valid_set, return_pred=True + ) self.logger.info("valid metrics: %s" % valid_metrics) - test_metrics, test_preds, test_probs, test_P = self.test_epoch(-1, test_set, return_pred=True) + test_metrics, test_preds, test_probs, test_P = self.test_epoch( + -1, test_set, return_pred=True + ) self.logger.info("test metrics: %s" % test_metrics) if self.logdir: self.logger.info("save model & pred to local directory") - pd.concat({name: pd.DataFrame(evals_result[name]) for name in evals_result}, axis=1).to_csv( - self.logdir + "/logs.csv", index=False - ) + pd.concat( + {name: pd.DataFrame(evals_result[name]) for name in evals_result}, + axis=1, + ).to_csv(self.logdir + "/logs.csv", index=False) - torch.save({"model": self.model.state_dict(), "tra": self.tra.state_dict()}, self.logdir + "/model.bin") + torch.save( + {"model": self.model.state_dict(), "tra": self.tra.state_dict()}, + self.logdir + "/model.bin", + ) train_preds.to_pickle(self.logdir + "/train_pred.pkl") valid_preds.to_pickle(self.logdir + "/valid_pred.pkl") @@ -491,13 +566,19 @@ def fit(self, dataset, evals_result=dict()): "use_daily_transport": self.use_daily_transport, }, "best_eval_metric": -best_score, # NOTE: -1 for minimize - "metrics": {"train": train_metrics, "valid": valid_metrics, "test": test_metrics}, + "metrics": { + "train": train_metrics, + "valid": valid_metrics, + "test": test_metrics, + }, } with open(self.logdir + "/info.json", "w") as f: json.dump(info, f) def predict(self, dataset, segment="test"): - assert isinstance(dataset, MTSDatasetH), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`" + assert isinstance( + dataset, MTSDatasetH + ), "TRAModel only supports `qlib.contrib.data.dataset.MTSDatasetH`" if not self.fitted: raise ValueError("model is not fitted yet!") @@ -588,7 +669,9 @@ def __init__(self, d_model, dropout=0.1, max_len=5000): pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) @@ -630,7 +713,10 @@ def __init__( self.pe = PositionalEncoding(input_size, dropout) layer = nn.TransformerEncoderLayer( - nhead=num_heads, dropout=dropout, d_model=hidden_size, dim_feedforward=hidden_size * 4 + nhead=num_heads, + dropout=dropout, + d_model=hidden_size, + dim_feedforward=hidden_size * 4, ) self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers) @@ -692,7 +778,10 @@ def __init__( batch_first=True, dropout=dropout, ) - self.fc = nn.Linear(hidden_size + input_size if "LR" in src_info else hidden_size, num_states) + self.fc = nn.Linear( + hidden_size + input_size if "LR" in src_info else hidden_size, + num_states, + ) else: self.fc = nn.Linear(input_size, num_states) @@ -781,7 +870,17 @@ def minmax_norm(x): return x -def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False): +def 
transport_sample( + all_preds, + label, + choice, + prob, + hist_loss, + count, + transport_method, + alpha, + training=False, +): """ sample-wise transport @@ -826,7 +925,17 @@ def transport_sample(all_preds, label, choice, prob, hist_loss, count, transport return loss, pred, L, P -def transport_daily(all_preds, label, choice, prob, hist_loss, count, transport_method, alpha, training=False): +def transport_daily( + all_preds, + label, + choice, + prob, + hist_loss, + count, + transport_method, + alpha, + training=False, +): """ daily transport @@ -901,7 +1010,13 @@ def load_state_dict_unsafe(model, state_dict): def load(module, prefix=""): local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs + state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs, ) for name, child in module._modules.items(): if child is not None: @@ -910,7 +1025,11 @@ def load(module, prefix=""): load(model) load = None # break load->load reference cycle - return {"unexpected_keys": unexpected_keys, "missing_keys": missing_keys, "error_msgs": error_msgs} + return { + "unexpected_keys": unexpected_keys, + "missing_keys": missing_keys, + "error_msgs": error_msgs, + } def plot(P): diff --git a/qlib/contrib/model/pytorch_transformer.py b/qlib/contrib/model/pytorch_transformer.py index d05b9f4cad..69e8ca6785 100644 --- a/qlib/contrib/model/pytorch_transformer.py +++ b/qlib/contrib/model/pytorch_transformer.py @@ -57,22 +57,36 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger = get_module_logger("TransformerModel") - self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)) + self.logger.info( + "Naive Transformer:" + "\nbatch_size : {}" + "\ndevice : {}".format(self.batch_size, self.device) + ) if self.seed is not None: np.random.seed(self.seed) torch.manual_seed(self.seed) - self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device) + self.model = Transformer( + d_feat, d_model, nhead, num_layers, dropout, self.device + ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.Adam( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.SGD( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.model.to(self.device) @@ -114,8 +128,16 @@ def train_epoch(self, x_train, y_train): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + 
) + label = ( + torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) pred = self.model(feature) loss = self.loss_fn(pred, label) @@ -141,8 +163,16 @@ def test_epoch(self, data_x, data_y): if len(indices) - i < self.batch_size: break - feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device) - label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device) + feature = ( + torch.from_numpy(x_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) + label = ( + torch.from_numpy(y_values[indices[i : i + self.batch_size]]) + .float() + .to(self.device) + ) with torch.no_grad(): pred = self.model(feature) @@ -166,7 +196,9 @@ def fit( data_key=DataHandlerLP.DK_L, ) if df_train.empty or df_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." + ) x_train, y_train = df_train["feature"], df_train["label"] x_valid, y_valid = df_valid["feature"], df_valid["label"] @@ -216,7 +248,9 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if not self.fitted: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) index = x_test.index self.model.eval() x_values = x_test.values @@ -244,7 +278,9 @@ def __init__(self, d_model, max_len=1000): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) @@ -256,12 +292,18 @@ def forward(self, x): class Transformer(nn.Module): - def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None): + def __init__( + self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None + ): super(Transformer, self).__init__() self.feature_layer = nn.Linear(d_feat, d_model) self.pos_encoder = PositionalEncoding(d_model) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout) - self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, nhead=nhead, dropout=dropout + ) + self.transformer_encoder = nn.TransformerEncoder( + self.encoder_layer, num_layers=num_layers + ) self.decoder_layer = nn.Linear(d_model, 1) self.device = device self.d_feat = d_feat diff --git a/qlib/contrib/model/pytorch_transformer_ts.py b/qlib/contrib/model/pytorch_transformer_ts.py index 70590e03e5..e6c6b5918d 100644 --- a/qlib/contrib/model/pytorch_transformer_ts.py +++ b/qlib/contrib/model/pytorch_transformer_ts.py @@ -55,22 +55,36 @@ def __init__( self.optimizer = optimizer.lower() self.loss = loss self.n_jobs = n_jobs - self.device = torch.device("cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.device = torch.device( + "cuda:%d" % GPU if torch.cuda.is_available() and GPU >= 0 else "cpu" + ) self.seed = seed self.logger = get_module_logger("TransformerModel") - 
self.logger.info("Naive Transformer:" "\nbatch_size : {}" "\ndevice : {}".format(self.batch_size, self.device)) + self.logger.info( + "Naive Transformer:" + "\nbatch_size : {}" + "\ndevice : {}".format(self.batch_size, self.device) + ) if self.seed is not None: np.random.seed(self.seed) torch.manual_seed(self.seed) - self.model = Transformer(d_feat, d_model, nhead, num_layers, dropout, self.device) + self.model = Transformer( + d_feat, d_model, nhead, num_layers, dropout, self.device + ) if optimizer.lower() == "adam": - self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.Adam( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) elif optimizer.lower() == "gd": - self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr, weight_decay=self.reg) + self.train_optimizer = optim.SGD( + self.model.parameters(), lr=self.lr, weight_decay=self.reg + ) else: - raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + raise NotImplementedError( + "optimizer {} is not supported!".format(optimizer) + ) self.fitted = False self.model.to(self.device) @@ -140,20 +154,34 @@ def fit( evals_result=dict(), save_path=None, ): - dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) - dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_train = dataset.prepare( + "train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + dl_valid = dataset.prepare( + "valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) if dl_train.empty or dl_valid.empty: - raise ValueError("Empty data from dataset, please check your dataset config.") + raise ValueError( + "Empty data from dataset, please check your dataset config." 
+ ) dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader train_loader = DataLoader( - dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True + dl_train, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.n_jobs, + drop_last=True, ) valid_loader = DataLoader( - dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True + dl_valid, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.n_jobs, + drop_last=True, ) save_path = get_or_create_path(save_path) @@ -202,9 +230,13 @@ def predict(self, dataset): if not self.fitted: raise ValueError("model is not fitted yet!") - dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + dl_test = dataset.prepare( + "test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I + ) dl_test.config(fillna_type="ffill+bfill") - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader( + dl_test, batch_size=self.batch_size, num_workers=self.n_jobs + ) self.model.eval() preds = [] @@ -224,7 +256,9 @@ def __init__(self, d_model, max_len=1000): super(PositionalEncoding, self).__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) @@ -236,12 +270,18 @@ def forward(self, x): class Transformer(nn.Module): - def __init__(self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None): + def __init__( + self, d_feat=6, d_model=8, nhead=4, num_layers=2, dropout=0.5, device=None + ): super(Transformer, self).__init__() self.feature_layer = nn.Linear(d_feat, d_model) self.pos_encoder = PositionalEncoding(d_model) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout) - self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) + self.encoder_layer = nn.TransformerEncoderLayer( + d_model=d_model, nhead=nhead, dropout=dropout + ) + self.transformer_encoder = nn.TransformerEncoder( + self.encoder_layer, num_layers=num_layers + ) self.decoder_layer = nn.Linear(d_model, 1) self.device = device self.d_feat = d_feat diff --git a/qlib/contrib/model/tcn.py b/qlib/contrib/model/tcn.py index 173404b2b8..f89457f0b5 100644 --- a/qlib/contrib/model/tcn.py +++ b/qlib/contrib/model/tcn.py @@ -14,26 +14,51 @@ def forward(self, x): class TemporalBlock(nn.Module): - def __init__(self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2): + def __init__( + self, n_inputs, n_outputs, kernel_size, stride, dilation, padding, dropout=0.2 + ): super(TemporalBlock, self).__init__() self.conv1 = weight_norm( - nn.Conv1d(n_inputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation) + nn.Conv1d( + n_inputs, + n_outputs, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ) ) self.chomp1 = Chomp1d(padding) self.relu1 = nn.ReLU() self.dropout1 = nn.Dropout(dropout) self.conv2 = weight_norm( - nn.Conv1d(n_outputs, n_outputs, kernel_size, stride=stride, padding=padding, dilation=dilation) 
+ nn.Conv1d( + n_outputs, + n_outputs, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ) ) self.chomp2 = Chomp1d(padding) self.relu2 = nn.ReLU() self.dropout2 = nn.Dropout(dropout) self.net = nn.Sequential( - self.conv1, self.chomp1, self.relu1, self.dropout1, self.conv2, self.chomp2, self.relu2, self.dropout2 + self.conv1, + self.chomp1, + self.relu1, + self.dropout1, + self.conv2, + self.chomp2, + self.relu2, + self.dropout2, + ) + self.downsample = ( + nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None ) - self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None self.relu = nn.ReLU() self.init_weights() diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py index 634259aab1..aaf217755b 100755 --- a/qlib/contrib/model/xgboost.py +++ b/qlib/contrib/model/xgboost.py @@ -40,7 +40,9 @@ def fit( # Lightgbm need 1D array as its label if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: - y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze(y_valid.values) + y_train_1d, y_valid_1d = np.squeeze(y_train.values), np.squeeze( + y_valid.values + ) else: raise ValueError("XGBoost doesn't support multi-label training") @@ -71,7 +73,9 @@ def fit( def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): if self.model is None: raise ValueError("model is not fitted yet!") - x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I) + x_test = dataset.prepare( + segment, col_set="feature", data_key=DataHandlerLP.DK_I + ) return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index) def get_feature_importance(self, *args, **kwargs) -> pd.Series: @@ -82,4 +86,6 @@ def get_feature_importance(self, *args, **kwargs) -> pd.Series: parameters reference: https://xgboost.readthedocs.io/en/latest/python/python_api.html#xgboost.Booster.get_score """ - return pd.Series(self.model.get_score(*args, **kwargs)).sort_values(ascending=False) + return pd.Series(self.model.get_score(*args, **kwargs)).sort_values( + ascending=False + ) diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py index 7475bb6fc5..58fba60f32 100644 --- a/qlib/contrib/online/manager.py +++ b/qlib/contrib/online/manager.py @@ -125,8 +125,12 @@ def add_user(self, user_id, config_file, add_date): # save user user_path.mkdir() - save_instance(model, self.data_path / user_id / "model_{}.pickle".format(user_id)) - save_instance(strategy, self.data_path / user_id / "strategy_{}.pickle".format(user_id)) + save_instance( + model, self.data_path / user_id / "model_{}.pickle".format(user_id) + ) + save_instance( + strategy, self.data_path / user_id / "strategy_{}.pickle".format(user_id) + ) trade_account.save_account(self.data_path / user_id) user_record = pd.read_csv(self.users_file, index_col=0) user_record.loc[user_id] = [add_date] diff --git a/qlib/contrib/online/online_model.py b/qlib/contrib/online/online_model.py index 2dc9533df8..3282e26592 100644 --- a/qlib/contrib/online/online_model.py +++ b/qlib/contrib/online/online_model.py @@ -16,11 +16,15 @@ class ScoreFileModel(Model): """ def __init__(self, score_path): - pred_test = pd.read_csv(score_path, index_col=[0, 1], parse_dates=True, infer_datetime_format=True) + pred_test = pd.read_csv( + score_path, index_col=[0, 1], parse_dates=True, infer_datetime_format=True + ) self.pred = pred_test def get_data_with_date(self, date, **kwargs): - score = self.pred.loc(axis=0)[:, date] # (stock_id, trade_date) multi_index, 
score in pdate + score = self.pred.loc(axis=0)[ + :, date + ] # (stock_id, trade_date) multi_index, score in pdate score_series = score.reset_index(level="datetime", drop=True)[ "score" ] # pd.Series ; index:stock_id, data: score @@ -32,7 +36,9 @@ def predict(self, x_test, **kwargs): def score(self, x_test, **kwargs): return - def fit(self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs): + def fit( + self, x_train, y_train, x_valid, y_valid, w_train=None, w_valid=None, **kwargs + ): return def save(self, fname, **kwargs): diff --git a/qlib/contrib/online/operator.py b/qlib/contrib/online/operator.py index d5c9edd621..a6953c16da 100644 --- a/qlib/contrib/online/operator.py +++ b/qlib/contrib/online/operator.py @@ -60,7 +60,9 @@ def init(client, path, date=None): else: trade_date = pd.Timestamp(date) if not is_tradable_date(trade_date): - raise ValueError("trade date is not tradable date".format(trade_date.date())) + raise ValueError( + "trade date is not tradable date".format(trade_date.date()) + ) pred_date = get_pre_trading_date(trade_date, future=True) return um, pred_date, trade_date @@ -132,7 +134,9 @@ def generate(self, date, path): user_path=(pathlib.Path(path) / user_id), trade_date=trade_date, ) - self.logger.info("Generate order list at {} for {}".format(trade_date, user_id)) + self.logger.info( + "Generate order list at {} for {}".format(trade_date, user_id) + ) um.save_user_data(user_id) def execute(self, date, exchange_config, path): @@ -160,14 +164,20 @@ def execute(self, date, exchange_config, path): # load and execute the order list # will not modify the trade_account after executing - order_list = load_order_list(user_path=(pathlib.Path(path) / user_id), trade_date=trade_date) - trade_info = executor.execute(order_list=order_list, trade_account=user.account, trade_date=trade_date) + order_list = load_order_list( + user_path=(pathlib.Path(path) / user_id), trade_date=trade_date + ) + trade_info = executor.execute( + order_list=order_list, trade_account=user.account, trade_date=trade_date + ) executor.save_executed_file_from_trade_info( trade_info=trade_info, user_path=(pathlib.Path(path) / user_id), trade_date=trade_date, ) - self.logger.info("execute order list at {} for {}".format(trade_date.date(), user_id)) + self.logger.info( + "execute order list at {} for {}".format(trade_date.date(), user_id) + ) def update(self, date, path, type="SIM"): """Update account at 'date'. 
@@ -205,10 +215,14 @@ def update(self, date, path, type="SIM"): score_series = load_score_series((pathlib.Path(path) / user_id), trade_date) update_account(user.account, trade_info, trade_exchange, trade_date) - portfolio_metrics = user.account.portfolio_metrics.generate_portfolio_metrics_dataframe() + portfolio_metrics = ( + user.account.portfolio_metrics.generate_portfolio_metrics_dataframe() + ) self.logger.info(portfolio_metrics) um.save_user_data(user_id) - self.logger.info("Update account state {} for {}".format(trade_date, user_id)) + self.logger.info( + "Update account state {} for {}".format(trade_date, user_id) + ) def simulate(self, id, config, exchange_config, start, end, path, bench="SH000905"): """Run the ( generate_trade_decision -> execute_order_list -> update_account) process everyday @@ -265,18 +279,26 @@ def simulate(self, id, config, exchange_config, start, end, path, bench="SH00090 trade_exchange=trade_exchange, trade_date=trade_date, ) - save_order_list(order_list=order_list, user_path=user_path, trade_date=trade_date) + save_order_list( + order_list=order_list, user_path=user_path, trade_date=trade_date + ) # 4. auto execute order list order_list = load_order_list(user_path=user_path, trade_date=trade_date) - trade_info = executor.execute(trade_account=user.account, order_list=order_list, trade_date=trade_date) + trade_info = executor.execute( + trade_account=user.account, order_list=order_list, trade_date=trade_date + ) executor.save_executed_file_from_trade_info( trade_info=trade_info, user_path=user_path, trade_date=trade_date ) # 5. update account state - trade_info = executor.load_trade_info_from_executed_file(user_path=user_path, trade_date=trade_date) + trade_info = executor.load_trade_info_from_executed_file( + user_path=user_path, trade_date=trade_date + ) update_account(user.account, trade_info, trade_exchange, trade_date) - portfolio_metrics = user.account.portfolio_metrics.generate_portfolio_metrics_dataframe() + portfolio_metrics = ( + user.account.portfolio_metrics.generate_portfolio_metrics_dataframe() + ) self.logger.info(portfolio_metrics) um.save_user_data(id) self.show(id, path, bench) @@ -298,12 +320,18 @@ def show(self, id, path, bench="SH000905"): if id not in um.users: raise ValueError("Cannot find user ".format(id)) bench = D.features([bench], ["$change"]).loc[bench, "$change"] - portfolio_metrics = um.users[id].account.portfolio_metrics.generate_portfolio_metrics_dataframe() + portfolio_metrics = um.users[ + id + ].account.portfolio_metrics.generate_portfolio_metrics_dataframe() portfolio_metrics["bench"] = bench analysis_result = {} r = (portfolio_metrics["return"] - portfolio_metrics["bench"]).dropna() analysis_result["excess_return_without_cost"] = risk_analysis(r) - r = (portfolio_metrics["return"] - portfolio_metrics["bench"] - portfolio_metrics["cost"]).dropna() + r = ( + portfolio_metrics["return"] + - portfolio_metrics["bench"] + - portfolio_metrics["cost"] + ).dropna() analysis_result["excess_return_with_cost"] = risk_analysis(r) print("Result:") print("excess_return_without_cost:") diff --git a/qlib/contrib/online/user.py b/qlib/contrib/online/user.py index fa74831eed..7f49717c74 100644 --- a/qlib/contrib/online/user.py +++ b/qlib/contrib/online/user.py @@ -39,7 +39,9 @@ def init_state(self, date): date : pd.Timestamp """ self.account.init_state(today=date) - self.strategy.init_state(trade_date=date, model=self.model, account=self.account) + self.strategy.init_state( + trade_date=date, model=self.model, account=self.account + ) 
return def get_latest_trading_date(self): @@ -61,13 +63,25 @@ def showReport(self, benchmark="SH000905"): benchmark : string bench that to be compared, 'SH000905' for csi500 """ - bench = D.features([benchmark], ["$change"], disk_cache=True).loc[benchmark, "$change"] - portfolio_metrics = self.account.portfolio_metrics.generate_portfolio_metrics_dataframe() + bench = D.features([benchmark], ["$change"], disk_cache=True).loc[ + benchmark, "$change" + ] + portfolio_metrics = ( + self.account.portfolio_metrics.generate_portfolio_metrics_dataframe() + ) portfolio_metrics["bench"] = bench - analysis_result = {"pred": {}, "excess_return_without_cost": {}, "excess_return_with_cost": {}} + analysis_result = { + "pred": {}, + "excess_return_without_cost": {}, + "excess_return_with_cost": {}, + } r = (portfolio_metrics["return"] - portfolio_metrics["bench"]).dropna() analysis_result["excess_return_without_cost"][0] = risk_analysis(r) - r = (portfolio_metrics["return"] - portfolio_metrics["bench"] - portfolio_metrics["cost"]).dropna() + r = ( + portfolio_metrics["return"] + - portfolio_metrics["bench"] + - portfolio_metrics["cost"] + ).dropna() analysis_result["excess_return_with_cost"][0] = risk_analysis(r) self.logger.info("Result of porfolio:") self.logger.info("excess_return_without_cost:") diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py index dddf7f0d2a..e9ea51d000 100644 --- a/qlib/contrib/online/utils.py +++ b/qlib/contrib/online/utils.py @@ -79,7 +79,11 @@ def prepare(um, today, user_id, exchange_config=None): latest_trading_date = um.user_record.loc[user_id][0] if str(today.date()) < latest_trading_date: - log.warning("user_id:{}, last trading date {} after today {}".format(user_id, latest_trading_date, today)) + log.warning( + "user_id:{}, last trading date {} after today {}".format( + user_id, latest_trading_date, today + ) + ) return [pd.Timestamp(latest_trading_date)], None dates = D.calendar( diff --git a/qlib/contrib/ops/high_freq.py b/qlib/contrib/ops/high_freq.py index 51852b66cc..96798d5af0 100644 --- a/qlib/contrib/ops/high_freq.py +++ b/qlib/contrib/ops/high_freq.py @@ -31,7 +31,9 @@ def get_calendar_day(freq="1min", future=False): if flag in H["c"]: _calendar = H["c"][flag] else: - _calendar = np.array(list(map(lambda x: x.date(), Cal.load_calendar(freq, future)))) + _calendar = np.array( + list(map(lambda x: x.date(), Cal.load_calendar(freq, future))) + ) H["c"][flag] = _calendar return _calendar @@ -42,7 +44,9 @@ def get_calendar_minute(freq="day", future=False): if flag in H["c"]: _calendar = H["c"][flag] else: - _calendar = np.array(list(map(lambda x: x.minute // 30, Cal.load_calendar(freq, future)))) + _calendar = np.array( + list(map(lambda x: x.minute // 30, Cal.load_calendar(freq, future))) + ) H["c"][flag] = _calendar return _calendar @@ -70,7 +74,13 @@ class DayCumsum(ElemOperator): Otherwise, the value is zero. 
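To make the behaviour just described concrete: the cached day calendar from `get_calendar_day` maps every minute bar to its trading day, so the high-frequency operators can group and transform within one day. Below is a rough, self-contained approximation with toy 1-minute bars; it is an illustration only, not the qlib implementation, and the window boundaries are the defaults shown in the constructor that follows.

import numpy as np
import pandas as pd

# toy 1-minute bars for a single trading day
idx = pd.date_range("2020-01-02 09:30", "2020-01-02 15:00", freq="1min")
volume = pd.Series(np.random.default_rng(0).integers(100, 200, len(idx)).astype(float), index=idx)

# what get_calendar_day provides: a per-bar key identifying the trading day
day_of_bar = np.array([ts.date() for ts in idx])

def day_cumsum(s, start="9:30", end="14:59"):
    # cumulative sum inside [start, end]; bars outside the window are set to zero
    mask = (s.index.time >= pd.Timestamp(start).time()) & (s.index.time <= pd.Timestamp(end).time())
    return s.where(mask, 0.0).cumsum().where(mask, 0.0)

within_day = volume.groupby(day_of_bar, group_keys=False).transform(day_cumsum)   # analogous to DayCumsum
last_of_day = volume.groupby(day_of_bar, group_keys=False).transform("last")      # analogous to DayLast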
""" - def __init__(self, feature, start: str = "9:30", end: str = "14:59", data_granularity: int = 1): + def __init__( + self, + feature, + start: str = "9:30", + end: str = "14:59", + data_granularity: int = 1, + ): self.feature = feature self.start = datetime.strptime(start, "%H:%M") self.end = datetime.strptime(end, "%H:%M") @@ -96,7 +106,9 @@ def period_cusum(self, df): def _load_internal(self, instrument, start_index, end_index, freq): _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) - return series.groupby(_calendar[series.index], group_keys=False).transform(self.period_cusum) + return series.groupby(_calendar[series.index], group_keys=False).transform( + self.period_cusum + ) class DayLast(ElemOperator): @@ -116,7 +128,9 @@ class DayLast(ElemOperator): def _load_internal(self, instrument, start_index, end_index, freq): _calendar = get_calendar_day(freq=freq) series = self.feature.load(instrument, start_index, end_index, freq) - return series.groupby(_calendar[series.index], group_keys=False).transform("last") + return series.groupby(_calendar[series.index], group_keys=False).transform( + "last" + ) class FFillNan(ElemOperator): @@ -195,8 +209,12 @@ class Select(PairOperator): """ def _load_internal(self, instrument, start_index, end_index, freq): - series_condition = self.feature_left.load(instrument, start_index, end_index, freq) - series_feature = self.feature_right.load(instrument, start_index, end_index, freq) + series_condition = self.feature_left.load( + instrument, start_index, end_index, freq + ) + series_feature = self.feature_right.load( + instrument, start_index, end_index, freq + ) return series_feature.loc[series_condition] @@ -259,7 +277,9 @@ class Cut(ElemOperator): def __init__(self, feature, left=None, right=None): self.left = left self.right = right - if (self.left is not None and self.left <= 0) or (self.right is not None and self.right >= 0): + if (self.left is not None and self.left <= 0) or ( + self.right is not None and self.right >= 0 + ): raise ValueError("Cut operator l shoud > 0 and r should < 0") super(Cut, self).__init__(feature) diff --git a/qlib/contrib/report/analysis_model/analysis_model_performance.py b/qlib/contrib/report/analysis_model/analysis_model_performance.py index cac1f1b8ee..8a28ea9b3e 100644 --- a/qlib/contrib/report/analysis_model/analysis_model_performance.py +++ b/qlib/contrib/report/analysis_model/analysis_model_performance.py @@ -18,7 +18,9 @@ from ..utils import guess_plotly_rangebreaks -def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs) -> tuple: +def _group_return( + pred_label: pd.DataFrame = None, reverse: bool = False, N: int = 5, **kwargs +) -> tuple: """ :param pred_label: @@ -38,8 +40,12 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int t_df = pd.DataFrame( { "Group%d" - % (i + 1): pred_label_drop.groupby(level="datetime", group_keys=False)["label"].apply( - lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() # pylint: disable=W0640 + % (i + 1): pred_label_drop.groupby(level="datetime", group_keys=False)[ + "label" + ].apply( + lambda x: x[ + len(x) // N * i : len(x) // N * (i + 1) + ].mean() # pylint: disable=W0640 ) for i in range(N) } @@ -50,7 +56,10 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int t_df["long-short"] = t_df["Group1"] - t_df["Group%d" % N] # Long-Average - t_df["long-average"] = t_df["Group1"] - 
pred_label.groupby(level="datetime", group_keys=False)["label"].mean() + t_df["long-average"] = ( + t_df["Group1"] + - pred_label.groupby(level="datetime", group_keys=False)["label"].mean() + ) t_df = t_df.dropna(how="all") # for days which does not contain label # Cumulative Return By Group @@ -58,7 +67,12 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int t_df.cumsum(), layout=dict( title="Cumulative Return", - xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(t_df.index))), + xaxis=dict( + tickangle=45, + rangebreaks=kwargs.get( + "rangebreaks", guess_plotly_rangebreaks(t_df.index) + ), + ), ), ).figure @@ -117,7 +131,9 @@ def _plot_qq(data: pd.Series = None, dist=stats.norm) -> go.Figure: def _pred_ic( - pred_label: pd.DataFrame = None, methods: Sequence[Literal["IC", "Rank IC"]] = ("IC", "Rank IC"), **kwargs + pred_label: pd.DataFrame = None, + methods: Sequence[Literal["IC", "Rank IC"]] = ("IC", "Rank IC"), + **kwargs, ) -> tuple: """ @@ -146,7 +162,9 @@ def _corr_series(x, method): ) _ic = ic_df.iloc(axis=1)[0] - _index = _ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6) + _index = ( + _ic.index.get_level_values(0).astype("str").str.replace("-", "").str.slice(0, 6) + ) _monthly_ic = _ic.groupby(_index, group_keys=False).mean() _monthly_ic.index = pd.MultiIndex.from_arrays( [_monthly_ic.index.str.slice(0, 4), _monthly_ic.index.str.slice(4, 6)], @@ -174,7 +192,11 @@ def _corr_series(x, method): ic_heatmap_figure = HeatmapGraph( _monthly_ic.unstack(), - layout=dict(title="Monthly IC", xaxis=dict(dtick=1), yaxis=dict(tickformat="04d", dtick=1)), + layout=dict( + title="Monthly IC", + xaxis=dict(dtick=1), + yaxis=dict(tickformat="04d", dtick=1), + ), graph_kwargs=dict(xtype="array", ytype="array"), ).figure @@ -222,7 +244,9 @@ def _corr_series(x, method): def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple: pred = pred_label.copy() - pred["score_last"] = pred.groupby(level="instrument", group_keys=False)["score"].shift(lag) + pred["score_last"] = pred.groupby(level="instrument", group_keys=False)[ + "score" + ].shift(lag) ac = pred.groupby(level="datetime", group_keys=False).apply( lambda x: x["score"].rank(pct=True).corr(x["score_last"].rank(pct=True)) ) @@ -231,7 +255,12 @@ def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple: _df, layout=dict( title="Auto Correlation", - xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_df.index))), + xaxis=dict( + tickangle=45, + rangebreaks=kwargs.get( + "rangebreaks", guess_plotly_rangebreaks(_df.index) + ), + ), ), ).figure return (ac_figure,) @@ -239,10 +268,14 @@ def _pred_autocorr(pred_label: pd.DataFrame, lag=1, **kwargs) -> tuple: def _pred_turnover(pred_label: pd.DataFrame, N=5, lag=1, **kwargs) -> tuple: pred = pred_label.copy() - pred["score_last"] = pred.groupby(level="instrument", group_keys=False)["score"].shift(lag) + pred["score_last"] = pred.groupby(level="instrument", group_keys=False)[ + "score" + ].shift(lag) top = pred.groupby(level="datetime", group_keys=False).apply( lambda x: 1 - - x.nlargest(len(x) // N, columns="score").index.isin(x.nlargest(len(x) // N, columns="score_last").index).sum() + - x.nlargest(len(x) // N, columns="score") + .index.isin(x.nlargest(len(x) // N, columns="score_last").index) + .sum() / (len(x) // N) ) bottom = pred.groupby(level="datetime", group_keys=False).apply( @@ -262,7 +295,12 @@ def _pred_turnover(pred_label: 
pd.DataFrame, N=5, lag=1, **kwargs) -> tuple: r_df, layout=dict( title="Top-Bottom Turnover", - xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(r_df.index))), + xaxis=dict( + tickangle=45, + rangebreaks=kwargs.get( + "rangebreaks", guess_plotly_rangebreaks(r_df.index) + ), + ), ), ).figure return (turnover_figure,) @@ -284,7 +322,12 @@ def ic_figure(ic_df: pd.DataFrame, show_nature_day=True, **kwargs) -> go.Figure: ic_df, layout=dict( title="Information Coefficient (IC)", - xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(ic_df.index))), + xaxis=dict( + tickangle=45, + rangebreaks=kwargs.get( + "rangebreaks", guess_plotly_rangebreaks(ic_df.index) + ), + ), ), ).figure return ic_bar_figure @@ -331,7 +374,13 @@ def model_performance_graph( figure_list = [] for graph_name in graph_names: fun_res = eval(f"_{graph_name}")( - pred_label=pred_label, lag=lag, N=N, reverse=reverse, rank=rank, show_nature_day=show_nature_day, **kwargs + pred_label=pred_label, + lag=lag, + N=N, + reverse=reverse, + rank=rank, + show_nature_day=show_nature_day, + **kwargs, ) figure_list += fun_res diff --git a/qlib/contrib/report/analysis_position/__init__.py b/qlib/contrib/report/analysis_position/__init__.py index cfe51a2249..345602717c 100644 --- a/qlib/contrib/report/analysis_position/__init__.py +++ b/qlib/contrib/report/analysis_position/__init__.py @@ -8,4 +8,10 @@ from .risk_analysis import risk_analysis_graph -__all__ = ["cumulative_return_graph", "score_ic_graph", "report_graph", "rank_label_graph", "risk_analysis_graph"] +__all__ = [ + "cumulative_return_graph", + "score_ic_graph", + "report_graph", + "rank_label_graph", + "risk_analysis_graph", +] diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py index 4f325aa1a3..6c925efdc4 100644 --- a/qlib/contrib/report/analysis_position/cumulative_return.py +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -36,7 +36,9 @@ def _get_cum_return_data_with_position( end_date=end_date, ).copy() - _cumulative_return_df["label"] = _cumulative_return_df["label"] - _cumulative_return_df["bench"] + _cumulative_return_df["label"] = ( + _cumulative_return_df["label"] - _cumulative_return_df["bench"] + ) _cumulative_return_df = _cumulative_return_df.dropna() df_gp = _cumulative_return_df.groupby(level="datetime", group_keys=False) result_list = [] @@ -103,7 +105,9 @@ def _get_figure_with_position( :return: """ - cum_return_df = _get_cum_return_data_with_position(position, report_normal, label_data, start_date, end_date) + cum_return_df = _get_cum_return_data_with_position( + position, report_normal, label_data, start_date, end_date + ) cum_return_df = cum_return_df.set_index("date") # FIXME: support HIGH-FREQ cum_return_df.index = cum_return_df.index.strftime("%Y-%m-%d") @@ -113,10 +117,14 @@ def _get_figure_with_position( sub_graph_data = [ ( "cum_{}".format(_t_name), - dict(row=1, col=1, graph_kwargs={"mode": "lines+markers", "xaxis": "x3"}), + dict( + row=1, col=1, graph_kwargs={"mode": "lines+markers", "xaxis": "x3"} + ), ), ( - "{}_weight".format(_t_name.replace("minus", "plus") if "minus" in _t_name else _t_name), + "{}_weight".format( + _t_name.replace("minus", "plus") if "minus" in _t_name else _t_name + ), dict(row=2, col=1), ), ( @@ -266,7 +274,9 @@ def cumulative_return_graph( position = copy.deepcopy(position) report_normal = report_normal.copy() label_data.columns = ["label"] - _figures = 
_get_figure_with_position(position, report_normal, label_data, start_date, end_date) + _figures = _get_figure_with_position( + position, report_normal, label_data, start_date, end_date + ) if show_notebook: BaseGraph.show_graph_in_notebook(_figures) else: diff --git a/qlib/contrib/report/analysis_position/parse_position.py b/qlib/contrib/report/analysis_position/parse_position.py index 0f6510e818..d555c8c588 100644 --- a/qlib/contrib/report/analysis_position/parse_position.py +++ b/qlib/contrib/report/analysis_position/parse_position.py @@ -64,12 +64,15 @@ def parse_position(position: dict = None) -> pd.DataFrame: # Trading day sell if not result_df.empty: _trading_day_sell_df = result_df.loc[ - (result_df["date"] == previous_data["date"]) & (result_df.index.isin(_cur_day_sell)) + (result_df["date"] == previous_data["date"]) + & (result_df.index.isin(_cur_day_sell)) ].copy() if not _trading_day_sell_df.empty: _trading_day_sell_df["status"] = -1 _trading_day_sell_df["date"] = _trading_date - _trading_day_df = pd.concat([_trading_day_df, _trading_day_sell_df], sort=False) + _trading_day_df = pd.concat( + [_trading_day_df, _trading_day_sell_df], sort=False + ) result_df = pd.concat([result_df, _trading_day_df], sort=True) @@ -83,7 +86,9 @@ def parse_position(position: dict = None) -> pd.DataFrame: return result_df.set_index(["instrument", "datetime"]) -def _add_label_to_position(position_df: pd.DataFrame, label_data: pd.DataFrame) -> pd.DataFrame: +def _add_label_to_position( + position_df: pd.DataFrame, label_data: pd.DataFrame +) -> pd.DataFrame: """Concat position with custom label :param position_df: position DataFrame @@ -94,12 +99,16 @@ def _add_label_to_position(position_df: pd.DataFrame, label_data: pd.DataFrame) _start_time = position_df.index.get_level_values(level="datetime").min() _end_time = position_df.index.get_level_values(level="datetime").max() label_data = label_data.loc(axis=0)[:, pd.to_datetime(_start_time) :] - _result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex(label_data.index) + _result_df = pd.concat([position_df, label_data], axis=1, sort=True).reindex( + label_data.index + ) _result_df = _result_df.loc[_result_df.index.get_level_values(1) <= _end_time] return _result_df -def _add_bench_to_position(position_df: pd.DataFrame = None, bench: pd.Series = None) -> pd.DataFrame: +def _add_bench_to_position( + position_df: pd.DataFrame = None, bench: pd.Series = None +) -> pd.DataFrame: """Concat position with bench :param position_df: position DataFrame @@ -127,7 +136,9 @@ def _calculate_day_value(g_df: pd.DataFrame): # Sell: -1, Hold: 0, Buy: 1 for i in [-1, 0, 1]: - g_df.loc[g_df["status"] == i, "rank_label_mean"] = g_df[g_df["status"] == i]["rank_ratio"].mean() + g_df.loc[g_df["status"] == i, "rank_label_mean"] = g_df[ + g_df["status"] == i + ]["rank_ratio"].mean() g_df["excess_return"] = g_df[_label_name] - g_df[_label_name].mean() return g_df @@ -171,5 +182,7 @@ def get_position_data( _date_list = _position_df.index.get_level_values(level="datetime") start_date = _date_list.min() if start_date is None else start_date end_date = _date_list.max() if end_date is None else end_date - _position_df = _position_df.loc[(start_date <= _date_list) & (_date_list <= end_date)] + _position_df = _position_df.loc[ + (start_date <= _date_list) & (_date_list <= end_date) + ] return _position_df diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index daefb52955..c45f16981e 100644 --- 
a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -48,12 +48,20 @@ def _calculate_report_data(df: pd.DataFrame) -> pd.DataFrame: report_df["cum_return_w_cost"] = (df["return"] - df["cost"]).cumsum() # report_df['cum_return'] - report_df['cum_return'].cummax() report_df["return_wo_mdd"] = _calculate_mdd(report_df["cum_return_wo_cost"]) - report_df["return_w_cost_mdd"] = _calculate_mdd((df["return"] - df["cost"]).cumsum()) + report_df["return_w_cost_mdd"] = _calculate_mdd( + (df["return"] - df["cost"]).cumsum() + ) report_df["cum_ex_return_wo_cost"] = (df["return"] - df["bench"]).cumsum() - report_df["cum_ex_return_w_cost"] = (df["return"] - df["bench"] - df["cost"]).cumsum() - report_df["cum_ex_return_wo_cost_mdd"] = _calculate_mdd((df["return"] - df["bench"]).cumsum()) - report_df["cum_ex_return_w_cost_mdd"] = _calculate_mdd((df["return"] - df["cost"] - df["bench"]).cumsum()) + report_df["cum_ex_return_w_cost"] = ( + df["return"] - df["bench"] - df["cost"] + ).cumsum() + report_df["cum_ex_return_wo_cost_mdd"] = _calculate_mdd( + (df["return"] - df["bench"]).cumsum() + ) + report_df["cum_ex_return_w_cost_mdd"] = _calculate_mdd( + (df["return"] - df["cost"] - df["bench"]).cumsum() + ) # return_wo_mdd , return_w_cost_mdd, cum_ex_return_wo_cost_mdd, cum_ex_return_w report_df["turnover"] = df["turnover"] @@ -105,9 +113,21 @@ def _report_figure(df: pd.DataFrame) -> [list, tuple]: _subplot_layout = dict() for i in range(1, 8): # yaxis - _subplot_layout.update({"yaxis{}".format(i): dict(zeroline=True, showline=True, showticklabels=True)}) + _subplot_layout.update( + { + "yaxis{}".format(i): dict( + zeroline=True, showline=True, showticklabels=True + ) + } + ) _show_line = i == 7 - _subplot_layout.update({"xaxis{}".format(i): dict(showline=_show_line, type="category", tickangle=45)}) + _subplot_layout.update( + { + "xaxis{}".format(i): dict( + showline=_show_line, type="category", tickangle=45 + ) + } + ) _layout_style = dict( height=1200, diff --git a/qlib/contrib/report/analysis_position/risk_analysis.py b/qlib/contrib/report/analysis_position/risk_analysis.py index c7cb99c7a3..f7271ef6ae 100644 --- a/qlib/contrib/report/analysis_position/risk_analysis.py +++ b/qlib/contrib/report/analysis_position/risk_analysis.py @@ -32,9 +32,13 @@ def _get_risk_analysis_data_with_report( # analysis["pred_long_short"] = risk_analysis(report_long_short_df["long_short"]) if not report_normal_df.empty: - analysis["excess_return_without_cost"] = risk_analysis(report_normal_df["return"] - report_normal_df["bench"]) + analysis["excess_return_without_cost"] = risk_analysis( + report_normal_df["return"] - report_normal_df["bench"] + ) analysis["excess_return_with_cost"] = risk_analysis( - report_normal_df["return"] - report_normal_df["bench"] - report_normal_df["cost"] + report_normal_df["return"] + - report_normal_df["bench"] + - report_normal_df["cost"] ) analysis_df = pd.concat(analysis) # type: pd.DataFrame analysis_df["date"] = date @@ -54,7 +58,9 @@ def _get_all_risk_analysis(risk_df: pd.DataFrame) -> pd.DataFrame: return risk_df.drop("mean", axis=1) -def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd.DataFrame: +def _get_monthly_risk_analysis_with_report( + report_normal_df: pd.DataFrame, +) -> pd.DataFrame: """Get monthly analysis data :param report_normal_df: @@ -92,7 +98,9 @@ def _get_monthly_risk_analysis_with_report(report_normal_df: pd.DataFrame) -> pd return _monthly_df -def 
_get_monthly_analysis_with_feature(monthly_df: pd.DataFrame, feature: str = "annualized_return") -> pd.DataFrame: +def _get_monthly_analysis_with_feature( + monthly_df: pd.DataFrame, feature: str = "annualized_return" +) -> pd.DataFrame: """ :param monthly_df: @@ -102,7 +110,9 @@ def _get_monthly_analysis_with_feature(monthly_df: pd.DataFrame, feature: str = _monthly_df_gp = monthly_df.reset_index().groupby(["level_1"], group_keys=False) _name_df = _monthly_df_gp.get_group(feature).set_index(["level_0", "level_1"]) - _temp_df = _name_df.pivot_table(index="date", values=["risk"], columns=_name_df.index) + _temp_df = _name_df.pivot_table( + index="date", values=["risk"], columns=_name_df.index + ) _temp_df.columns = map(lambda x: "_".join(x[-1]), _temp_df.columns) _temp_df.index = _temp_df.index.strftime("%Y-%m") @@ -126,7 +136,9 @@ def _get_risk_analysis_figure(analysis_df: pd.DataFrame) -> Iterable[py.Figure]: return (_figure,) -def _get_monthly_risk_analysis_figure(report_normal_df: pd.DataFrame) -> Iterable[py.Figure]: +def _get_monthly_risk_analysis_figure( + report_normal_df: pd.DataFrame, +) -> Iterable[py.Figure]: """Get analysis monthly graph figure :param report_normal_df: diff --git a/qlib/contrib/report/analysis_position/score_ic.py b/qlib/contrib/report/analysis_position/score_ic.py index 52f45c9cba..f664c4b8c7 100644 --- a/qlib/contrib/report/analysis_position/score_ic.py +++ b/qlib/contrib/report/analysis_position/score_ic.py @@ -15,14 +15,18 @@ def _get_score_ic(pred_label: pd.DataFrame): """ concat_data = pred_label.copy() concat_data.dropna(axis=0, how="any", inplace=True) - _ic = concat_data.groupby(level="datetime", group_keys=False).apply(lambda x: x["label"].corr(x["score"])) + _ic = concat_data.groupby(level="datetime", group_keys=False).apply( + lambda x: x["label"].corr(x["score"]) + ) _rank_ic = concat_data.groupby(level="datetime", group_keys=False).apply( lambda x: x["label"].corr(x["score"], method="spearman") ) return pd.DataFrame({"ic": _ic, "rank_ic": _rank_ic}) -def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True, **kwargs) -> [list, tuple]: +def score_ic_graph( + pred_label: pd.DataFrame, show_notebook: bool = True, **kwargs +) -> [list, tuple]: """score IC Example: @@ -61,7 +65,12 @@ def score_ic_graph(pred_label: pd.DataFrame, show_notebook: bool = True, **kwarg _ic_df, layout=dict( title="Score IC", - xaxis=dict(tickangle=45, rangebreaks=kwargs.get("rangebreaks", guess_plotly_rangebreaks(_ic_df.index))), + xaxis=dict( + tickangle=45, + rangebreaks=kwargs.get( + "rangebreaks", guess_plotly_rangebreaks(_ic_df.index) + ), + ), ), graph_kwargs={"mode": "lines+markers"}, ).figure diff --git a/qlib/contrib/report/data/ana.py b/qlib/contrib/report/data/ana.py index e93b07612a..aea701a628 100644 --- a/qlib/contrib/report/data/ana.py +++ b/qlib/contrib/report/data/ana.py @@ -72,10 +72,14 @@ def calc_stat_values(self): self._val_cnt = {} for col, item in self._dataset.items(): if not super().skip(col): - self._val_cnt[col] = item.groupby(DT_COL_NAME, group_keys=False).apply(lambda s: len(s.unique())) + self._val_cnt[col] = item.groupby(DT_COL_NAME, group_keys=False).apply( + lambda s: len(s.unique()) + ) self._val_cnt = pd.DataFrame(self._val_cnt) if self.ratio: - self._val_cnt = self._val_cnt.div(self._dataset.groupby(DT_COL_NAME, group_keys=False).size(), axis=0) + self._val_cnt = self._val_cnt.div( + self._dataset.groupby(DT_COL_NAME, group_keys=False).size(), axis=0 + ) # TODO: transfer this feature to other analysers ymin, ymax = 
self._val_cnt.min().min(), self._val_cnt.max().max() @@ -98,7 +102,12 @@ def calc_stat_values(self): self._inf_cnt = {} for col, item in self._dataset.items(): if not super().skip(col): - self._inf_cnt[col] = item.apply(np.isinf).astype(np.int).groupby(DT_COL_NAME, group_keys=False).sum() + self._inf_cnt[col] = ( + item.apply(np.isinf) + .astype(np.int) + .groupby(DT_COL_NAME, group_keys=False) + .sum() + ) self._inf_cnt = pd.DataFrame(self._inf_cnt) def skip(self, col): @@ -111,7 +120,9 @@ def plot_single(self, col, ax): class FeaNanAna(FeaAnalyser): def calc_stat_values(self): - self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum() + self._nan_cnt = ( + self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum() + ) def skip(self, col): return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0) @@ -123,7 +134,9 @@ def plot_single(self, col, ax): class FeaNanAnaRatio(FeaAnalyser): def calc_stat_values(self): - self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum() + self._nan_cnt = ( + self._dataset.isna().groupby(DT_COL_NAME, group_keys=False).sum() + ) self._total_cnt = self._dataset.groupby(DT_COL_NAME, group_keys=False).size() def skip(self, col): diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 387a057a29..159ae6a7b5 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -18,7 +18,12 @@ class BaseGraph: _name = None def __init__( - self, df: pd.DataFrame = None, layout: dict = None, graph_kwargs: dict = None, name_dict: dict = None, **kwargs + self, + df: pd.DataFrame = None, + layout: dict = None, + graph_kwargs: dict = None, + name_dict: dict = None, + **kwargs, ): """ @@ -120,7 +125,11 @@ def _get_data(self) -> list: _data = [ self.get_instance_with_graph_parameters( - graph_type=self._graph_type, x=self._df.index, y=self._df[_col], name=_name, **self._graph_kwargs + graph_type=self._graph_type, + x=self._df.index, + y=self._df[_col], + name=_name, + **self._graph_kwargs, ) for _col, _name in self._name_dict.items() ] @@ -157,7 +166,9 @@ def _get_data(self): _t_df = self._df.dropna() _data_list = [_t_df[_col] for _col in self._name_dict] _label_list = list(self._name_dict.values()) - _fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs) + _fig = create_distplot( + _data_list, _label_list, show_rug=False, **self._graph_kwargs + ) return _fig["data"] @@ -192,7 +203,10 @@ def _get_data(self): """ _data = [ self.get_instance_with_graph_parameters( - graph_type=self._graph_type, x=self._df[_col], name=_name, **self._graph_kwargs + graph_type=self._graph_type, + x=self._df[_col], + name=_name, + **self._graph_kwargs, ) for _col, _name in self._name_dict.items() ] @@ -347,8 +361,12 @@ def _init_figure(self): _graph_obj = column_name elif isinstance(column_name, str): temp_name = column_map.get("name", column_name.replace("_", " ")) - kind = column_map.get("kind", self._kind_map.get("kind", "ScatterGraph")) - _graph_kwargs = column_map.get("graph_kwargs", self._kind_map.get("kwargs", {})) + kind = column_map.get( + "kind", self._kind_map.get("kind", "ScatterGraph") + ) + _graph_kwargs = column_map.get( + "graph_kwargs", self._kind_map.get("kwargs", {}) + ) _graph_obj = BaseGraph.get_instance_with_graph_parameters( kind, **dict( diff --git a/qlib/contrib/report/utils.py b/qlib/contrib/report/utils.py index 8d3d3fac9a..ad6d450fcb 100644 --- a/qlib/contrib/report/utils.py +++ b/qlib/contrib/report/utils.py @@ -4,7 +4,15 @@ import pandas 
as pd -def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False): +def sub_fig_generator( + sub_figsize=(3, 3), + col_n=10, + row_n=1, + wspace=None, + hspace=None, + sharex=False, + sharey=False, +): """sub_fig_generator. it will return a generator, each row contains sub graph @@ -33,7 +41,11 @@ def sub_fig_generator(sub_figsize=(3, 3), col_n=10, row_n=1, wspace=None, hspace while True: fig, axes = plt.subplots( - row_n, col_n, figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), sharex=sharex, sharey=sharey + row_n, + col_n, + figsize=(sub_figsize[0] * col_n, sub_figsize[1] * row_n), + sharex=sharex, + sharey=sharey, ) plt.subplots_adjust(wspace=wspace, hspace=hspace) axes = axes.reshape(row_n, col_n) @@ -71,4 +83,7 @@ def guess_plotly_rangebreaks(dt_index: pd.DatetimeIndex): for gap, d in zip(gaps, dt_idx[:-1]): if gap > min_gap: gaps_to_break.setdefault(gap - min_gap, []).append(d + min_gap) - return [dict(values=v, dvalue=int(k.total_seconds() * 1000)) for k, v in gaps_to_break.items()] + return [ + dict(values=v, dvalue=int(k.total_seconds() * 1000)) + for k, v in gaps_to_break.items() + ] diff --git a/qlib/contrib/rolling/base.py b/qlib/contrib/rolling/base.py index 5f17c05623..ea790fd14d 100644 --- a/qlib/contrib/rolling/base.py +++ b/qlib/contrib/rolling/base.py @@ -94,7 +94,9 @@ def __init__( self._rid = None # the final combined recorder id in `exp_name` self.step = step - assert horizon is not None, "Current version does not support extracting horizon from the underlying dataset" + assert ( + horizon is not None + ), "Current version does not support extracting horizon from the underlying dataset" self.horizon = horizon if rolling_exp is None: datetime_suffix = pd.Timestamp.now().strftime("%Y%m%d%H%M%S") @@ -135,11 +137,16 @@ def _replace_handler_with_cache(self, task: dict): def _update_start_end_time(self, task: dict): if self.train_start is not None: seg = task["dataset"]["kwargs"]["segments"]["train"] - task["dataset"]["kwargs"]["segments"]["train"] = pd.Timestamp(self.train_start), seg[1] + task["dataset"]["kwargs"]["segments"]["train"] = ( + pd.Timestamp(self.train_start), + seg[1], + ) if self.test_end is not None: seg = task["dataset"]["kwargs"]["segments"]["test"] - task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp(self.test_end) + task["dataset"]["kwargs"]["segments"]["test"] = seg[0], pd.Timestamp( + self.test_end + ) return task def basic_task(self, enable_handler_cache: Optional[bool] = True): @@ -161,15 +168,21 @@ def basic_task(self, enable_handler_cache: Optional[bool] = True): raise NotImplementedError(f"This type of input is not supported") else: if enable_handler_cache and self.h_path is not None: - self.logger.info("Fail to override the horizon due to data handler cache") + self.logger.info( + "Fail to override the horizon due to data handler cache" + ) else: self.logger.info("The prediction horizon is overrided") if isinstance(task["dataset"]["kwargs"]["handler"], dict): task["dataset"]["kwargs"]["handler"]["kwargs"]["label"] = [ - "Ref($close, -{}) / Ref($close, -1) - 1".format(self.horizon + 1) + "Ref($close, -{}) / Ref($close, -1) - 1".format( + self.horizon + 1 + ) ] else: - self.logger.warning("Try to automatically configure the lablel but failed.") + self.logger.warning( + "Try to automatically configure the lablel but failed." 
+ ) if self.h_path is not None or enable_handler_cache: # if we already have provided data source or we want to create one @@ -209,7 +222,9 @@ def _train_rolling_tasks(self): try: # TODO: mlflow does not support permanently delete experiment # it will be moved to .trash and prevents creating the experiments with the same name - R.delete_exp(experiment_name=self.rolling_exp) # We should remove the rolling experiments. + R.delete_exp( + experiment_name=self.rolling_exp + ) # We should remove the rolling experiments. except ValueError: self.logger.info("No previous rolling results") trainer = TrainerR(experiment_name=self.rolling_exp) @@ -248,7 +263,9 @@ def _update_rolling_rec(self): default_module="qlib.workflow.record_temp", ) r.generate() - print(f"Your evaluation results can be found in the experiment named `{self.exp_name}`.") + print( + f"Your evaluation results can be found in the experiment named `{self.exp_name}`." + ) def run(self): # the results will be save in mlruns. diff --git a/qlib/contrib/rolling/ddgda.py b/qlib/contrib/rolling/ddgda.py index b62820ccea..dcf5d8ffce 100644 --- a/qlib/contrib/rolling/ddgda.py +++ b/qlib/contrib/rolling/ddgda.py @@ -112,11 +112,15 @@ def __init__( # NOTE: # the horizon must match the meaning in the base task template self.meta_exp_name = "DDG-DA" - self.sim_task_model: UTIL_MODEL_TYPE = sim_task_model # The model to capture the distribution of data. + self.sim_task_model: UTIL_MODEL_TYPE = ( + sim_task_model # The model to capture the distribution of data. + ) self.alpha = alpha self.meta_1st_train_end = meta_1st_train_end super().__init__(**kwargs) - self.working_dir = self.conf_path.parent if working_dir is None else Path(working_dir) + self.working_dir = ( + self.conf_path.parent if working_dir is None else Path(working_dir) + ) self.proxy_hd = self.working_dir / "handler_proxy.pkl" self.fea_imp_n = fea_imp_n self.meta_data_proc = meta_data_proc @@ -169,7 +173,9 @@ def _get_feature_importance(self): fi = model.get_feature_importance() # Because the model use numpy instead of dataframe for training lightgbm # So the we must use following extra steps to get the right feature importance - df = dataset.prepare(segments=slice(None), col_set="feature", data_key=DataHandlerLP.DK_R) + df = dataset.prepare( + segments=slice(None), col_set="feature", data_key=DataHandlerLP.DK_R + ) cols = df.columns fi_named = {cols[int(k.split("_")[1])]: imp for k, imp in fi.to_dict().items()} @@ -184,7 +190,9 @@ def _dump_data_for_proxy_model(self): # NOTE: adjusting to `self.sim_task_model` just for aligning with previous implementation. # In previous version. 
The data for proxy model is using sim_task_model's way for processing - task = self._adjust_task(self.basic_task(enable_handler_cache=False), self.sim_task_model) + task = self._adjust_task( + self.basic_task(enable_handler_cache=False), self.sim_task_model + ) task = replace_task_handler_with_cache(task, self.working_dir) # if self.meta_data_proc is not None: # else: @@ -192,7 +200,9 @@ def _dump_data_for_proxy_model(self): # task = self.basic_task() dataset = init_instance_by_config(task["dataset"]) - prep_ds = dataset.prepare(slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + prep_ds = dataset.prepare( + slice(None), col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) feature_df = prep_ds["feature"] label_df = prep_ds["label"] @@ -205,9 +215,9 @@ def _dump_data_for_proxy_model(self): feature_selected = feature_df if self.meta_data_proc == "V01": - feature_selected = feature_selected.groupby("datetime", group_keys=False).apply( - lambda df: (df - df.mean()).div(df.std()) - ) + feature_selected = feature_selected.groupby( + "datetime", group_keys=False + ).apply(lambda df: (df - df.mean()).div(df.std())) feature_selected = feature_selected.fillna(0.0) df_all = { @@ -236,11 +246,15 @@ def _dump_meta_ipt(self): This function will dump the input data for meta model """ # According to the experiments, the choice of the model type is very important for achieving good results - sim_task = self._adjust_task(self.basic_task(enable_handler_cache=False), astype=self.sim_task_model) + sim_task = self._adjust_task( + self.basic_task(enable_handler_cache=False), astype=self.sim_task_model + ) sim_task = replace_task_handler_with_cache(sim_task, self.working_dir) if self.sim_task_model == "gbdt": - sim_task["model"].setdefault("kwargs", {}).update({"early_stopping_rounds": None, "num_boost_round": 150}) + sim_task["model"].setdefault("kwargs", {}).update( + {"early_stopping_rounds": None, "num_boost_round": 150} + ) exp_name_sim = f"data_sim_s{self.step}" @@ -263,8 +277,12 @@ def _train_meta_model(self, fill_method="max"): # But please select a right time to make sure the finnal rolling tasks are not leaked in the training data. # - The test_start is automatically aligned to the next day of test_end. Validation is ignored. train_start = "2008-01-01" if self.train_start is None else self.train_start - train_end = "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end - test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + train_end = ( + "2010-12-31" if self.meta_1st_train_end is None else self.meta_1st_train_end + ) + test_start = (pd.Timestamp(train_end) + pd.Timedelta(days=1)).strftime( + "%Y-%m-%d" + ) proxy_forecast_model_task = { # "model": "qlib.contrib.model.linear.LinearModel", "dataset": { @@ -273,7 +291,12 @@ def _train_meta_model(self, fill_method="max"): "handler": f"file://{(self.working_dir / self.proxy_hd).absolute()}", "segments": { "train": (train_start, train_end), - "test": (test_start, self.basic_task()["dataset"]["kwargs"]["segments"]["test"][1]), + "test": ( + test_start, + self.basic_task()["dataset"]["kwargs"]["segments"]["test"][ + 1 + ], + ), }, }, }, diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index ff51f484f5..fd6f48a2e6 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -37,7 +37,13 @@ def __init__( average_fill: assign the weight to the stocks rank high averagely. 
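A simplified, standalone sketch of the two `buy_method` policies described in this docstring; the real logic lives in `generate_target_weight_position` below, and the function name, stock ids, and numbers here are illustrative assumptions only.

def fill_weights(buy_signal_stocks, sold_stock_weight, final_stock_weight, topk, buy_method="rank_fill"):
    if buy_method == "rank_fill":
        # top-ranked stocks are topped up towards 1/topk first, until the freed weight runs out
        for stock_id in buy_signal_stocks:
            add_weight = min(max(1 / topk - final_stock_weight.get(stock_id, 0.0), 0.0), sold_stock_weight)
            final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight
            sold_stock_weight -= add_weight
    elif buy_method == "average_fill":
        # the freed weight is split evenly over all buy-signal stocks
        for stock_id in buy_signal_stocks:
            final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len(buy_signal_stocks)
    else:
        raise ValueError("Buy method not found")
    return final_stock_weight

# e.g. with topk=10 (per-stock cap 0.10) and 0.30 of freed weight:
print(fill_weights(["A", "B", "C"], 0.30, {"A": 0.05}, topk=10))
# rank_fill -> roughly {'A': 0.10, 'B': 0.10, 'C': 0.10}; 0.05 of the freed weight remains unassigned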
""" super(SoftTopkStrategy, self).__init__( - model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs + model, + dataset, + order_generator_cls_or_obj, + trade_exchange, + level_infra, + common_infra, + **kwargs, ) self.topk = topk self.max_sold_weight = max_sold_weight @@ -52,7 +58,9 @@ def get_risk_degree(self, trade_step=None): # It will use 95% amount of your total value by default return self.risk_degree - def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time): + def generate_target_weight_position( + self, score, current, trade_start_time, trade_end_time + ): """ Parameters ---------- @@ -70,7 +78,9 @@ def generate_target_weight_position(self, score, current, trade_start_time, trad # TODO: # If the current stock list is more than topk(eg. The weights are modified # by risk control), the weight will not be handled correctly. - buy_signal_stocks = set(score.sort_values(ascending=False).iloc[: self.topk].index) + buy_signal_stocks = set( + score.sort_values(ascending=False).iloc[: self.topk].index + ) cur_stock_weight = current.get_stock_weight_dict(only_stock=True) if len(cur_stock_weight) == 0: @@ -89,13 +99,15 @@ def generate_target_weight_position(self, score, current, trade_start_time, trad max(1 / self.topk - final_stock_weight.get(stock_id, 0), 0.0), sold_stock_weight, ) - final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + add_weight + final_stock_weight[stock_id] = ( + final_stock_weight.get(stock_id, 0.0) + add_weight + ) sold_stock_weight -= add_weight elif self.buy_method == "average_fill": for stock_id in buy_signal_stocks: - final_stock_weight[stock_id] = final_stock_weight.get(stock_id, 0.0) + sold_stock_weight / len( - buy_signal_stocks - ) + final_stock_weight[stock_id] = final_stock_weight.get( + stock_id, 0.0 + ) + sold_stock_weight / len(buy_signal_stocks) else: raise ValueError("Buy method not found") return final_stock_weight diff --git a/qlib/contrib/strategy/optimizer/enhanced_indexing.py b/qlib/contrib/strategy/optimizer/enhanced_indexing.py index 7e42856a2e..ceca2e2f9c 100644 --- a/qlib/contrib/strategy/optimizer/enhanced_indexing.py +++ b/qlib/contrib/strategy/optimizer/enhanced_indexing.py @@ -71,7 +71,9 @@ def __init__( assert delta >= 0, "turnover limit `delta` should be positive" self.delta = delta - assert b_dev is None or b_dev >= 0, "benchmark deviation limit `b_dev` should be positive" + assert ( + b_dev is None or b_dev >= 0 + ), "benchmark deviation limit `b_dev` should be positive" self.b_dev = b_dev if isinstance(f_dev, float): @@ -176,7 +178,9 @@ def __call__( # trial 2: remove turnover constraint if not success and len(t_cons): - logger.info("try removing turnover constraint as the last optimization failed") + logger.info( + "try removing turnover constraint as the last optimization failed" + ) try: w.value = wb prob = cp.Problem(obj, cons) diff --git a/qlib/contrib/strategy/optimizer/optimizer.py b/qlib/contrib/strategy/optimizer/optimizer.py index a5fb763127..9fc03bb1d4 100644 --- a/qlib/contrib/strategy/optimizer/optimizer.py +++ b/qlib/contrib/strategy/optimizer/optimizer.py @@ -47,7 +47,12 @@ def __init__( scale_return (bool): if to scale alpha to match the volatility of the covariance matrix tol (float): tolerance for optimization termination """ - assert method in [self.OPT_GMV, self.OPT_MVO, self.OPT_RP, self.OPT_INV], f"method `{method}` is not supported" + assert method in [ + self.OPT_GMV, + self.OPT_MVO, + self.OPT_RP, + self.OPT_INV, + ], 
f"method `{method}` is not supported" self.method = method assert lamb >= 0, f"risk aversion parameter `lamb` should be positive" @@ -111,7 +116,12 @@ def __call__( return w - def _optimize(self, S: np.ndarray, r: Optional[np.ndarray] = None, w0: Optional[np.ndarray] = None) -> np.ndarray: + def _optimize( + self, + S: np.ndarray, + r: Optional[np.ndarray] = None, + w0: Optional[np.ndarray] = None, + ) -> np.ndarray: # inverse volatility if self.method == self.OPT_INV: if r is not None: @@ -143,7 +153,9 @@ def _optimize_inv(self, S: np.ndarray) -> np.ndarray: w /= w.sum() return w - def _optimize_gmv(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.ndarray: + def _optimize_gmv( + self, S: np.ndarray, w0: Optional[np.ndarray] = None + ) -> np.ndarray: """optimize global minimum variance portfolio This method solves the following optimization problem @@ -151,10 +163,15 @@ def _optimize_gmv(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.nd s.t. w >= 0, sum(w) == 1 where `S` is the covariance matrix. """ - return self._solve(len(S), self._get_objective_gmv(S), *self._get_constrains(w0)) + return self._solve( + len(S), self._get_objective_gmv(S), *self._get_constrains(w0) + ) def _optimize_mvo( - self, S: np.ndarray, r: Optional[np.ndarray] = None, w0: Optional[np.ndarray] = None + self, + S: np.ndarray, + r: Optional[np.ndarray] = None, + w0: Optional[np.ndarray] = None, ) -> np.ndarray: """optimize mean-variance portfolio @@ -164,9 +181,13 @@ def _optimize_mvo( where `S` is the covariance matrix, `u` is the expected returns, and `lamb` is the risk aversion parameter. """ - return self._solve(len(S), self._get_objective_mvo(S, r), *self._get_constrains(w0)) + return self._solve( + len(S), self._get_objective_mvo(S, r), *self._get_constrains(w0) + ) - def _optimize_rp(self, S: np.ndarray, w0: Optional[np.ndarray] = None) -> np.ndarray: + def _optimize_rp( + self, S: np.ndarray, w0: Optional[np.ndarray] = None + ) -> np.ndarray: """optimize risk parity portfolio This method solves the following optimization problem @@ -234,11 +255,15 @@ def _get_constrains(self, w0: Optional[np.ndarray] = None): # turnover constraint if w0 is not None: - cons.append({"type": "ineq", "fun": lambda x: self.delta - np.sum(np.abs(x - w0))}) # >= 0 + cons.append( + {"type": "ineq", "fun": lambda x: self.delta - np.sum(np.abs(x - w0))} + ) # >= 0 return bounds, cons - def _solve(self, n: int, obj: Callable, bounds: so.Bounds, cons: List) -> np.ndarray: + def _solve( + self, n: int, obj: Callable, bounds: so.Bounds, cons: List + ) -> np.ndarray: """solve optimization Args: @@ -258,7 +283,9 @@ def opt_obj(x): # solve x0 = np.ones(n) / n # init results - sol = so.minimize(wrapped_obj, x0, bounds=bounds, constraints=cons, tol=self.tol) + sol = so.minimize( + wrapped_obj, x0, bounds=bounds, constraints=cons, tol=self.tol + ) if not sol.success: warnings.warn(f"optimization not success ({sol.status})") diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index fe0d048bfd..d0aa57f234 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -116,19 +116,25 @@ def generate_order_list_from_target_weight_position( # value. 
Then just sell all the stocks target_amount_dict = copy.deepcopy(current_amount_dict.copy()) for stock_id in list(target_amount_dict.keys()): - if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time): + if trade_exchange.is_stock_tradable( + stock_id, start_time=trade_start_time, end_time=trade_end_time + ): del target_amount_dict[stock_id] else: # consider cost rate - current_tradable_value /= 1 + max(trade_exchange.close_cost, trade_exchange.open_cost) + current_tradable_value /= 1 + max( + trade_exchange.close_cost, trade_exchange.open_cost + ) # strategy 1 : generate amount_position by weight_position # Use API in Exchange() - target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( - weight_position=target_weight_position, - cash=current_tradable_value, - start_time=trade_start_time, - end_time=trade_end_time, + target_amount_dict = ( + trade_exchange.generate_amount_position_from_weight_position( + weight_position=target_weight_position, + cash=current_tradable_value, + start_time=trade_start_time, + end_time=trade_end_time, + ) ) order_list = trade_exchange.generate_order_for_target_amount_position( target_position=target_amount_dict, @@ -198,14 +204,18 @@ def generate_order_list_from_target_weight_position( amount_dict[stock_id] = ( risk_total_value * target_weight_position[stock_id] - / trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time) + / trade_exchange.get_close( + stock_id, start_time=pred_start_time, end_time=pred_end_time + ) ) # TODO: Qlib use None to represent trading suspension. # So last close price can't be the estimated trading price. # Maybe a close price with forward fill will be a better solution. elif stock_id in current_stock: amount_dict[stock_id] = ( - risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id) + risk_total_value + * target_weight_position[stock_id] + / current.get_stock_price(stock_id) ) else: continue diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 2cac662f76..eda71ba0dc 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -34,7 +34,9 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): outer_trade_decision : BaseTradeDecision, optional """ - super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) + super(TWAPStrategy, self).reset( + outer_trade_decision=outer_trade_decision, **kwargs + ) if outer_trade_decision is not None: self.trade_amount_remain = {} for order in outer_trade_decision.get_decision(): @@ -53,14 +55,18 @@ def generate_trade_decision(self, execute_result=None): # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step - start_idx, end_idx = get_start_end_idx(self.trade_calendar, self.outer_trade_decision) + start_idx, end_idx = get_start_end_idx( + self.trade_calendar, self.outer_trade_decision + ) trade_len = end_idx - start_idx + 1 if trade_step < start_idx or trade_step > end_idx: # It is not time to start trading or trading has ended. 
return TradeDecisionWO(order_list=[], strategy=self) - rel_trade_step = trade_step - start_idx # trade_step relative to start_idx (number of steps has already passed) + rel_trade_step = ( + trade_step - start_idx + ) # trade_step relative to start_idx (number of steps has already passed) # update the order amount if execute_result is not None: @@ -75,7 +81,9 @@ def generate_trade_decision(self, execute_result=None): # - if stock is suspended, the quote values of stocks is NaN. The following code will raise error when # encountering NaN factor if self.trade_exchange.check_stock_suspended( - stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time + stock_id=order.stock_id, + start_time=trade_start_time, + end_time=trade_end_time, ): continue @@ -92,7 +100,9 @@ def generate_trade_decision(self, execute_result=None): amount_delta = amount_expect - amount_finished _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( - stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + stock_id=order.stock_id, + start_time=order.start_time, + end_time=order.end_time, ) # round the amount_delta by trade_unit and clip by remain @@ -102,7 +112,8 @@ def generate_trade_decision(self, execute_result=None): amount_delta_target = amount_delta else: amount_delta_target = min( - np.round(amount_delta / _amount_trade_unit) * _amount_trade_unit, amount_remain + np.round(amount_delta / _amount_trade_unit) * _amount_trade_unit, + amount_remain, ) # handle last step to make sure all positions have gone @@ -142,7 +153,9 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): ---------- outer_trade_decision : BaseTradeDecision, optional """ - super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) + super(SBBStrategyBase, self).reset( + outer_trade_decision=outer_trade_decision, **kwargs + ) if outer_trade_decision is not None: self.trade_trend = {} self.trade_amount = {} @@ -166,43 +179,57 @@ def generate_trade_decision(self, execute_result=None): self.trade_amount[order.stock_id] -= order.deal_amount trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) - pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time( + trade_step, shift=1 + ) order_list = [] # for each order in in self.outer_trade_decision for order in self.outer_trade_decision.get_decision(): # get the price trend if trade_step % 2 == 0: # in the first of two adjacent bars, predict the price trend - _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) + _pred_trend = self._pred_price_trend( + order.stock_id, pred_start_time, pred_end_time + ) else: # in the second of two adjacent bars, use the trend predicted in the first one _pred_trend = self.trade_trend[order.stock_id] # if not tradable, continue if not self.trade_exchange.is_stock_tradable( - stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time + stock_id=order.stock_id, + start_time=trade_start_time, + end_time=trade_end_time, ): if trade_step % 2 == 0: self.trade_trend[order.stock_id] = _pred_trend continue # get amount of one trade unit _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( - stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + stock_id=order.stock_id, + start_time=order.start_time, + end_time=order.end_time, ) if _pred_trend == self.TREND_MID: 
_order_amount = None # considering trade unit if _amount_trade_unit is None: # divide the order into equal parts, and trade one part - _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) + _order_amount = self.trade_amount[order.stock_id] / ( + trade_len - trade_step + ) # without considering trade unit else: # divide the order into equal parts, and trade one part # calculate the total count of trade units to trade - trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) + trade_unit_cnt = int( + self.trade_amount[order.stock_id] // _amount_trade_unit + ) # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( - (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit + (trade_unit_cnt + trade_len - trade_step - 1) + // (trade_len - trade_step) + * _amount_trade_unit ) if order.direction == order.SELL: # sell all amount at last @@ -228,11 +255,17 @@ def generate_trade_decision(self, execute_result=None): # considering trade unit if _amount_trade_unit is None: # N trade day left, divide the order into N + 1 parts, and trade 2 parts - _order_amount = 2 * self.trade_amount[order.stock_id] / (trade_len - trade_step + 1) + _order_amount = ( + 2 + * self.trade_amount[order.stock_id] + / (trade_len - trade_step + 1) + ) # without considering trade unit else: # cal how many trade unit - trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) + trade_unit_cnt = int( + self.trade_amount[order.stock_id] // _amount_trade_unit + ) # N trade day left, divide the order into N + 1 parts, and trade 2 parts _order_amount = ( (trade_unit_cnt + trade_len - trade_step) @@ -332,23 +365,37 @@ def __init__( self.instruments = instruments self.freq = freq super(SBBStrategyEMA, self).__init__( - outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs + outer_trade_decision, + level_infra, + common_infra, + trade_exchange=trade_exchange, + **kwargs, ) def _reset_signal(self): trade_len = self.trade_calendar.get_trade_len() fields = ["EMA($close, 10)-EMA($close, 20)"] signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1) - _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1) + _, signal_end_time = self.trade_calendar.get_step_time( + trade_step=trade_len - 1, shift=1 + ) signal_df = D.features( - self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq + self.instruments, + fields, + start_time=signal_start_time, + end_time=signal_end_time, + freq=self.freq, ) signal_df.columns = ["signal"] self.signal = {} if not signal_df.empty: - for stock_id, stock_val in signal_df.groupby(level="instrument", group_keys=False): - self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument") + for stock_id, stock_val in signal_df.groupby( + level="instrument", group_keys=False + ): + self.signal[stock_id] = stock_val["signal"].droplevel( + level="instrument" + ) def reset_level_infra(self, level_infra): """ @@ -370,7 +417,11 @@ def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): method=ts_data_last, ) # if EMA signal == 0 or None, return mid trend - if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0: + if ( + _sample_signal is None + or np.isnan(_sample_signal) + or _sample_signal 
== 0 + ): return self.TREND_MID # if EMA signal > 0, return long trend elif _sample_signal > 0: @@ -417,7 +468,11 @@ def __init__( self.instruments = D.instruments(instruments) self.freq = freq super(ACStrategy, self).__init__( - outer_trade_decision, level_infra, common_infra, trade_exchange=trade_exchange, **kwargs + outer_trade_decision, + level_infra, + common_infra, + trade_exchange=trade_exchange, + **kwargs, ) def _reset_signal(self): @@ -426,16 +481,26 @@ def _reset_signal(self): f"Power(Sum(Power(Log($close/Ref($close, 1)), 2), {self.window_size})/{self.window_size - 1}-Power(Sum(Log($close/Ref($close, 1)), {self.window_size}), 2)/({self.window_size}*{self.window_size - 1}), 0.5)" ] signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1) - _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1) + _, signal_end_time = self.trade_calendar.get_step_time( + trade_step=trade_len - 1, shift=1 + ) signal_df = D.features( - self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq + self.instruments, + fields, + start_time=signal_start_time, + end_time=signal_end_time, + freq=self.freq, ) signal_df.columns = ["volatility"] self.signal = {} if not signal_df.empty: - for stock_id, stock_val in signal_df.groupby(level="instrument", group_keys=False): - self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument") + for stock_id, stock_val in signal_df.groupby( + level="instrument", group_keys=False + ): + self.signal[stock_id] = stock_val["volatility"].droplevel( + level="instrument" + ) def reset_level_infra(self, level_infra): """ @@ -451,7 +516,9 @@ def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): ---------- outer_trade_decision : BaseTradeDecision, optional """ - super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) + super(ACStrategy, self).reset( + outer_trade_decision=outer_trade_decision, **kwargs + ) if outer_trade_decision is not None: self.trade_amount = {} # init the trade amount of order and predicted trade trend @@ -470,19 +537,28 @@ def generate_trade_decision(self, execute_result=None): self.trade_amount[order.stock_id] -= order.deal_amount trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) - pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time( + trade_step, shift=1 + ) order_list = [] for order in self.outer_trade_decision.get_decision(): # if not tradable, continue if not self.trade_exchange.is_stock_tradable( - stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time + stock_id=order.stock_id, + start_time=trade_start_time, + end_time=trade_end_time, ): continue _order_amount = None # considering trade unit sig_sam = ( - resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last) + resam_ts_data( + self.signal[order.stock_id], + pred_start_time, + pred_end_time, + method=ts_data_last, + ) if order.stock_id in self.signal else None ) @@ -490,35 +566,49 @@ def generate_trade_decision(self, execute_result=None): if sig_sam is None or np.isnan(sig_sam): # no signal, TWAP _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit( - stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + stock_id=order.stock_id, + start_time=order.start_time, + end_time=order.end_time, ) if _amount_trade_unit 
is None: # divide the order into equal parts, and trade one part - _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) + _order_amount = self.trade_amount[order.stock_id] / ( + trade_len - trade_step + ) else: # divide the order into equal parts, and trade one part # calculate the total count of trade units to trade - trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) + trade_unit_cnt = int( + self.trade_amount[order.stock_id] // _amount_trade_unit + ) # calculate the amount of one part, ceil the amount # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( - (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit + (trade_unit_cnt + trade_len - trade_step - 1) + // (trade_len - trade_step) + * _amount_trade_unit ) else: # VA strategy kappa_tild = self.lamb / self.eta * sig_sam * sig_sam kappa = np.arccosh(kappa_tild / 2 + 1) amount_ratio = ( - np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1)) + np.sinh(kappa * (trade_len - trade_step)) + - np.sinh(kappa * (trade_len - trade_step - 1)) ) / np.sinh(kappa * trade_len) _order_amount = order.amount * amount_ratio _order_amount = self.trade_exchange.round_amount_by_trade_unit( - _order_amount, stock_id=order.stock_id, start_time=order.start_time, end_time=order.end_time + _order_amount, + stock_id=order.stock_id, + start_time=order.start_time, + end_time=order.end_time, ) if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1): + if self.trade_amount[order.stock_id] > 1e-5 and ( + _order_amount < 1e-5 or trade_step == trade_len - 1 + ): _order_amount = self.trade_amount[order.stock_id] _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) @@ -539,7 +629,9 @@ def generate_trade_decision(self, execute_result=None): class RandomOrderStrategy(BaseStrategy): def __init__( self, - trade_range: Union[Tuple[int, int], TradeRange], # The range is closed on both left and right. + trade_range: Union[ + Tuple[int, int], TradeRange + ], # The range is closed on both left and right. 
sample_ratio: float = 1.0, volume_ratio: float = 0.01, market: str = "all", @@ -569,7 +661,10 @@ def __init__( exch: Exchange = self.common_infra.get("trade_exchange") # TODO: this can't be online self.volume = D.features( - D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time + D.instruments(market), + ["Mean(Ref($volume, 1), 10)"], + start_time=exch.start_time, + end_time=exch.end_time, ) self.volume_df = self.volume.iloc[:, 0].unstack() self.trade_range = trade_range @@ -580,7 +675,12 @@ def generate_trade_decision(self, execute_result=None): order_list = [] if step_time_start in self.volume_df: - for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items(): + for stock_id, volume in ( + self.volume_df[step_time_start] + .dropna() + .sample(frac=self.sample_ratio) + .items() + ): order_list.append( self.common_infra.get("trade_exchange") .get_order_helper() @@ -642,7 +742,9 @@ def __init__( self.order_df = self.order_df.set_index(["datetime", "instrument"]) # make sure the datetime is the first level for fast indexing - self.order_df = lazy_sort_index(convert_index_format(self.order_df, level="datetime")) + self.order_df = lazy_sort_index( + convert_index_format(self.order_df, level="datetime") + ) self.trade_range = trade_range def generate_trade_decision(self, execute_result=None) -> TradeDecisionWO: diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index bad19ddfdc..e8cbb9391a 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -26,7 +26,9 @@ class BaseSignalStrategy(BaseStrategy, ABC): def __init__( self, *, - signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None, + signal: Union[ + Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame + ] = None, model=None, dataset=None, risk_degree: float = 0.95, @@ -52,13 +54,20 @@ def __init__( - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. 
""" - super().__init__(level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs) + super().__init__( + level_infra=level_infra, + common_infra=common_infra, + trade_exchange=trade_exchange, + **kwargs, + ) self.risk_degree = risk_degree # This is trying to be compatible with previous version of qlib task config if model is not None and dataset is not None: - warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning) + warnings.warn( + "`model` `dataset` is deprecated; use `signal`.", DeprecationWarning + ) signal = model, dataset self.signal: Signal = create_signal_from(signal) @@ -139,8 +148,12 @@ def generate_trade_decision(self, execute_result=None): # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) - pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time( + trade_step, shift=1 + ) + pred_score = self.signal.get_signal( + start_time=pred_start_time, end_time=pred_end_time + ) # NOTE: the current version of topk dropout strategy can't handle pd.DataFrame(multiple signal) # So it only leverage the first col of signal if isinstance(pred_score, pd.DataFrame): @@ -155,7 +168,9 @@ def get_first_n(li, n, reverse=False): res = [] for si in reversed(li) if reverse else li: if self.trade_exchange.is_stock_tradable( - stock_id=si, start_time=trade_start_time, end_time=trade_end_time + stock_id=si, + start_time=trade_start_time, + end_time=trade_end_time, ): res.append(si) cur_n += 1 @@ -171,7 +186,9 @@ def filter_stock(li): si for si in li if self.trade_exchange.is_stock_tradable( - stock_id=si, start_time=trade_start_time, end_time=trade_end_time + stock_id=si, + start_time=trade_start_time, + end_time=trade_end_time, ) ] @@ -198,11 +215,15 @@ def filter_stock(li): # The new stocks today want to buy **at most** if self.method_buy == "top": today = get_first_n( - pred_score[~pred_score.index.isin(last)].sort_values(ascending=False).index, + pred_score[~pred_score.index.isin(last)] + .sort_values(ascending=False) + .index, self.n_drop + self.topk - len(last), ) elif self.method_buy == "random": - topk_candi = get_first_n(pred_score.sort_values(ascending=False).index, self.topk) + topk_candi = get_first_n( + pred_score.sort_values(ascending=False).index, self.topk + ) candi = list(filter(lambda x: x not in last, topk_candi)) n = self.n_drop + self.topk - len(last) try: @@ -213,7 +234,11 @@ def filter_stock(li): raise NotImplementedError(f"This type of input is not supported") # combine(new stocks + last stocks), we will drop stocks from this list # In case of dropping higher score stock and buying lower score stock. 
- comb = pred_score.reindex(last.union(pd.Index(today))).sort_values(ascending=False).index + comb = ( + pred_score.reindex(last.union(pd.Index(today))) + .sort_values(ascending=False) + .index + ) # Get the stock list we really want to sell (After filtering the case that we sell high and buy low) if self.method_sell == "bottom": @@ -221,7 +246,11 @@ def filter_stock(li): elif self.method_sell == "random": candi = filter_stock(last) try: - sell = pd.Index(np.random.choice(candi, self.n_drop, replace=False) if len(last) else []) + sell = pd.Index( + np.random.choice(candi, self.n_drop, replace=False) + if len(last) + else [] + ) except ValueError: # No enough candidates sell = candi else: @@ -240,7 +269,10 @@ def filter_stock(li): if code in sell: # check hold limit time_per_step = self.trade_calendar.get_freq() - if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh: + if ( + current_temp.get_stock_count(code, bar=time_per_step) + < self.hold_thresh + ): continue # sell order sell_amount = current_temp.get_stock_amount(code=code) @@ -279,11 +311,18 @@ def filter_stock(li): continue # buy order buy_price = self.trade_exchange.get_deal_price( - stock_id=code, start_time=trade_start_time, end_time=trade_end_time, direction=OrderDir.BUY + stock_id=code, + start_time=trade_start_time, + end_time=trade_end_time, + direction=OrderDir.BUY, ) buy_amount = value / buy_price - factor = self.trade_exchange.get_factor(stock_id=code, start_time=trade_start_time, end_time=trade_end_time) - buy_amount = self.trade_exchange.round_amount_by_trade_unit(buy_amount, factor) + factor = self.trade_exchange.get_factor( + stock_id=code, start_time=trade_start_time, end_time=trade_end_time + ) + buy_amount = self.trade_exchange.round_amount_by_trade_unit( + buy_amount, factor + ) buy_order = Order( stock_id=code, amount=buy_amount, @@ -327,7 +366,9 @@ def __init__( else: self.order_generator: OrderGenerator = order_generator_cls_or_obj - def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time): + def generate_target_weight_position( + self, score, current, trade_start_time, trade_end_time + ): """ Generate target position from score for this date and the current position.The cash is not considered in the position @@ -349,26 +390,35 @@ def generate_trade_decision(self, execute_result=None): # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) - pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time( + trade_step, shift=1 + ) + pred_score = self.signal.get_signal( + start_time=pred_start_time, end_time=pred_end_time + ) if pred_score is None: return TradeDecisionWO([], self) current_temp = copy.deepcopy(self.trade_position) assert isinstance(current_temp, Position) # Avoid InfPosition target_weight_position = self.generate_target_weight_position( - score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time - ) - order_list = self.order_generator.generate_order_list_from_target_weight_position( + score=pred_score, current=current_temp, - trade_exchange=self.trade_exchange, - risk_degree=self.get_risk_degree(trade_step), - 
target_weight_position=target_weight_position, - pred_start_time=pred_start_time, - pred_end_time=pred_end_time, trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) + order_list = ( + self.order_generator.generate_order_list_from_target_weight_position( + current=current_temp, + trade_exchange=self.trade_exchange, + risk_degree=self.get_risk_degree(trade_step), + target_weight_position=target_weight_position, + pred_start_time=pred_start_time, + pred_end_time=pred_end_time, + trade_start_time=trade_start_time, + trade_end_time=trade_end_time, + ) + ) return TradeDecisionWO(order_list, self) @@ -424,7 +474,9 @@ def __init__( self.factor_exp_path = name_mapping.get("factor_exp", self.FACTOR_EXP_NAME) self.factor_cov_path = name_mapping.get("factor_cov", self.FACTOR_COV_NAME) - self.specific_risk_path = name_mapping.get("specific_risk", self.SPECIFIC_RISK_NAME) + self.specific_risk_path = name_mapping.get( + "specific_risk", self.SPECIFIC_RISK_NAME + ) self.blacklist_path = name_mapping.get("blacklist", self.BLACKLIST_NAME) self.optimizer = EnhancedIndexingOptimizer(**optimizer_kwargs) @@ -443,11 +495,15 @@ def get_risk_data(self, date): factor_exp = load_dataset(root + "/" + self.factor_exp_path, index_col=[0]) factor_cov = load_dataset(root + "/" + self.factor_cov_path, index_col=[0]) - specific_risk = load_dataset(root + "/" + self.specific_risk_path, index_col=[0]) + specific_risk = load_dataset( + root + "/" + self.specific_risk_path, index_col=[0] + ) if not factor_exp.index.equals(specific_risk.index): # NOTE: for stocks missing specific_risk, we always assume it has the highest volatility - specific_risk = specific_risk.reindex(factor_exp.index, fill_value=specific_risk.max()) + specific_risk = specific_risk.reindex( + factor_exp.index, fill_value=specific_risk.max() + ) universe = factor_exp.index.tolist() @@ -455,18 +511,28 @@ def get_risk_data(self, date): if os.path.exists(root + "/" + self.blacklist_path): blacklist = load_dataset(root + "/" + self.blacklist_path).index.tolist() - self._riskdata_cache[date] = factor_exp.values, factor_cov.values, specific_risk.values, universe, blacklist + self._riskdata_cache[date] = ( + factor_exp.values, + factor_cov.values, + specific_risk.values, + universe, + blacklist, + ) return self._riskdata_cache[date] - def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time): + def generate_target_weight_position( + self, score, current, trade_start_time, trade_end_time + ): trade_date = trade_start_time pre_date = get_pre_trading_date(trade_date, future=True) # previous trade date # load risk data outs = self.get_risk_data(pre_date) if outs is None: - self.logger.warning(f"no risk data for {pre_date:%Y-%m-%d}, skip optimization") + self.logger.warning( + f"no risk data for {pre_date:%Y-%m-%d}, skip optimization" + ) return None factor_exp, factor_cov, specific_risk, universe, blacklist = outs @@ -479,26 +545,37 @@ def generate_target_weight_position(self, score, current, trade_start_time, trad cur_weight = current.get_stock_weight_dict(only_stock=False) cur_weight = np.array([cur_weight.get(stock, 0) for stock in universe]) assert all(cur_weight >= 0), "current weight has negative values" - cur_weight = cur_weight / self.get_risk_degree(trade_date) # sum of weight should be risk_degree + cur_weight = cur_weight / self.get_risk_degree( + trade_date + ) # sum of weight should be risk_degree if cur_weight.sum() > 1 and self.verbose: - self.logger.warning(f"previous total holdings excess risk degree 
(current: {cur_weight.sum()})") + self.logger.warning( + f"previous total holdings excess risk degree (current: {cur_weight.sum()})" + ) # load bench weight bench_weight = D.features( - D.instruments("all"), [f"${self.market}_weight"], start_time=pre_date, end_time=pre_date + D.instruments("all"), + [f"${self.market}_weight"], + start_time=pre_date, + end_time=pre_date, ).squeeze() bench_weight.index = bench_weight.index.droplevel(level="datetime") bench_weight = bench_weight.reindex(universe).fillna(0).values # whether stock tradable # NOTE: currently we use last day volume to check whether tradable - tradable = D.features(D.instruments("all"), ["$volume"], start_time=pre_date, end_time=pre_date).squeeze() + tradable = D.features( + D.instruments("all"), ["$volume"], start_time=pre_date, end_time=pre_date + ).squeeze() tradable.index = tradable.index.droplevel(level="datetime") tradable = tradable.reindex(universe).gt(0).values mask_force_hold = ~tradable # mask force sell - mask_force_sell = np.array([stock in blacklist for stock in universe], dtype=bool) + mask_force_sell = np.array( + [stock in blacklist for stock in universe], dtype=bool + ) # optimize weight = self.optimizer( @@ -512,11 +589,15 @@ def generate_target_weight_position(self, score, current, trade_start_time, trad mfs=mask_force_sell, ) - target_weight_position = {stock: weight for stock, weight in zip(universe, weight) if weight > 0} + target_weight_position = { + stock: weight for stock, weight in zip(universe, weight) if weight > 0 + } if self.verbose: self.logger.info("trade date: {:%Y-%m-%d}".format(trade_date)) - self.logger.info("number of holding stocks: {}".format(len(target_weight_position))) + self.logger.info( + "number of holding stocks: {}".format(len(target_weight_position)) + ) self.logger.info("total holding weight: {:.6f}".format(weight.sum())) return target_weight_position diff --git a/qlib/contrib/tuner/config.py b/qlib/contrib/tuner/config.py index 4cedd3642b..f6fb11c4f5 100644 --- a/qlib/contrib/tuner/config.py +++ b/qlib/contrib/tuner/config.py @@ -20,9 +20,13 @@ def __init__(self, config_path): config = yaml.load(fp) self.config = copy.deepcopy(config) - self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self) + self.pipeline_ex_config = PipelineExperimentConfig( + config.get("experiment", dict()), self + ) self.pipeline_config = config.get("tuner_pipeline", list()) - self.optim_config = OptimizationConfig(config.get("optimization_criteria", dict()), self) + self.optim_config = OptimizationConfig( + config.get("optimization_criteria", dict()), self + ) self.time_config = config.get("time_period", dict()) self.data_config = config.get("data", dict()) @@ -38,17 +42,25 @@ def __init__(self, config, TUNER_CONFIG_MANAGER): """ self.name = config.get("name", "tuner_experiment") # The dir of the config - self.global_dir = config.get("dir", os.path.dirname(TUNER_CONFIG_MANAGER.config_path)) + self.global_dir = config.get( + "dir", os.path.dirname(TUNER_CONFIG_MANAGER.config_path) + ) # The dir of the result of tuner experiment - self.tuner_ex_dir = config.get("tuner_ex_dir", os.path.join(self.global_dir, self.name)) + self.tuner_ex_dir = config.get( + "tuner_ex_dir", os.path.join(self.global_dir, self.name) + ) if not os.path.exists(self.tuner_ex_dir): os.makedirs(self.tuner_ex_dir) # The dir of the results of all estimator experiments - self.estimator_ex_dir = config.get("estimator_ex_dir", os.path.join(self.tuner_ex_dir, "estimator_experiment")) + self.estimator_ex_dir 
= config.get( + "estimator_ex_dir", os.path.join(self.tuner_ex_dir, "estimator_experiment") + ) if not os.path.exists(self.estimator_ex_dir): os.makedirs(self.estimator_ex_dir) # Get the tuner type - self.tuner_module_path = config.get("tuner_module_path", "qlib.contrib.tuner.tuner") + self.tuner_module_path = config.get( + "tuner_module_path", "qlib.contrib.tuner.tuner" + ) self.tuner_class = config.get("tuner_class", "QLibTuner") # Save the tuner experiment for further view tuner_ex_config_path = os.path.join(self.tuner_ex_dir, "tuner_config.yaml") diff --git a/qlib/contrib/tuner/launcher.py b/qlib/contrib/tuner/launcher.py index 352d2ca48d..7048f80ad7 100644 --- a/qlib/contrib/tuner/launcher.py +++ b/qlib/contrib/tuner/launcher.py @@ -30,7 +30,9 @@ def run(): # 1. Get pipeline class. - tuner_pipeline_class = getattr(importlib.import_module(".pipeline", package="qlib.contrib.tuner"), "Pipeline") + tuner_pipeline_class = getattr( + importlib.import_module(".pipeline", package="qlib.contrib.tuner"), "Pipeline" + ) # 2. Init tuner pipeline. tuner_pipeline = tuner_pipeline_class(TUNER_CONFIG_MANAGER) # 3. Begin to tune diff --git a/qlib/contrib/tuner/pipeline.py b/qlib/contrib/tuner/pipeline.py index 34977fa55f..86864872af 100644 --- a/qlib/contrib/tuner/pipeline.py +++ b/qlib/contrib/tuner/pipeline.py @@ -68,14 +68,18 @@ def init_tuner(self, tuner_index, tuner_config): tuner_config["trainer"].update({"args": self.time_config}) # 5. Import Tuner class - tuner_module = get_module_by_module_path(self.pipeline_ex_config.tuner_module_path) + tuner_module = get_module_by_module_path( + self.pipeline_ex_config.tuner_module_path + ) tuner_class = getattr(tuner_module, self.pipeline_ex_config.tuner_class) # 6. Return the specific tuner return tuner_class(tuner_config, self.optim_config) def save_tuner_exp_info(self): TimeInspector.set_time_mark() - save_path = os.path.join(self.pipeline_ex_config.tuner_ex_dir, Pipeline.GLOBAL_BEST_PARAMS_NAME) + save_path = os.path.join( + self.pipeline_ex_config.tuner_ex_dir, Pipeline.GLOBAL_BEST_PARAMS_NAME + ) with open(save_path, "w") as fp: json.dump(self.global_best_params, fp) TimeInspector.log_cost_time("Finished save global best tuner parameters.") diff --git a/qlib/contrib/tuner/tuner.py b/qlib/contrib/tuner/tuner.py index 7705ce8b73..2c88714002 100644 --- a/qlib/contrib/tuner/tuner.py +++ b/qlib/contrib/tuner/tuner.py @@ -51,7 +51,9 @@ def tune(self): ) self.logger.info("Local best params: {} ".format(self.best_params)) TimeInspector.log_cost_time( - "Finished searching best parameters in Tuner {}.".format(self.tuner_config["experiment"]["id"]) + "Finished searching best parameters in Tuner {}.".format( + self.tuner_config["experiment"]["id"] + ) ) self.save_local_best_params() @@ -94,10 +96,14 @@ def objective(self, params): self.logger.info("Searching params: {} ".format(params)) # 2. Use subprocess to do the estimator program, this process will wait until subprocess finish - sub_fails = subprocess.call("estimator -c {}".format(estimator_path), shell=True) + sub_fails = subprocess.call( + "estimator -c {}".format(estimator_path), shell=True + ) if sub_fails: # If this subprocess failed, ignore this evaluation step - self.logger.info("Estimator experiment failed when using this searching parameters") + self.logger.info( + "Estimator experiment failed when using this searching parameters" + ) return {"loss": np.nan, "status": STATUS_FAIL} # 3. 
Fetch the result of subprocess, and check whether the result is Nan @@ -133,13 +139,17 @@ def fetch_result(self): return np.abs(exp_info["performance"]["model_pearsonr"] - 1) # 3. Get backtest results - exp_result_dir = os.path.join(self.ex_dir, QLibTuner.EXP_RESULT_DIR.format(estimator_ex_id)) + exp_result_dir = os.path.join( + self.ex_dir, QLibTuner.EXP_RESULT_DIR.format(estimator_ex_id) + ) exp_result_path = os.path.join(exp_result_dir, QLibTuner.EXP_RESULT_NAME) with open(exp_result_path, "rb") as fp: analysis_df = pickle.load(fp) # 4. Get the backtest factor which user want to optimize, if user want to maximize the factor, then reverse the result - res = analysis_df.loc[self.optim_config.report_type].loc[self.optim_config.report_factor] + res = analysis_df.loc[self.optim_config.report_type].loc[ + self.optim_config.report_factor + ] # res = res.values[0] if self.optim_config.optim_type == 'min' else -res.values[0] if self.optim_config == "min": return res.values[0] @@ -207,9 +217,13 @@ def setup_space(self): def save_local_best_params(self): TimeInspector.set_time_mark() - local_best_params_path = os.path.join(self.ex_dir, QLibTuner.LOCAL_BEST_PARAMS_NAME) + local_best_params_path = os.path.join( + self.ex_dir, QLibTuner.LOCAL_BEST_PARAMS_NAME + ) with open(local_best_params_path, "w") as fp: json.dump(self.best_params, fp) TimeInspector.log_cost_time( - "Finished saving local best tuner parameters to: {} .".format(local_best_params_path) + "Finished saving local best tuner parameters to: {} .".format( + local_best_params_path + ) ) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 8d10b2ab48..c8313d5cde 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -25,7 +25,11 @@ class MultiSegRecord(RecordTemp): def __init__(self, model, dataset, recorder=None): super().__init__(recorder=recorder) if not isinstance(dataset, qlib_dataset.DatasetH): - raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset))) + raise ValueError( + "The type of dataset is not DatasetH instead of {:}".format( + type(dataset) + ) + ) self.model = model self.dataset = dataset @@ -35,11 +39,18 @@ def generate(self, segments: Dict[Text, Any], save: bool = False): if isinstance(predics, pd.Series): predics = predics.to_frame("score") labels = self.dataset.prepare( - segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R + segments=segment, + col_set="label", + data_key=qlib_dataset.handler.DataHandlerLP.DK_R, ) # Compute the IC and Rank IC ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0]) - results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()} + results = { + "all-IC": ic, + "mean-IC": ic.mean(), + "all-Rank-IC": ric, + "mean-Rank-IC": ric.mean(), + } logger.info("--- Results for {:} ({:}) ---".format(key, segment)) ic_x100, ric_x100 = ic * 100, ric * 100 logger.info("IC: {:.4f}%".format(ic_x100.mean())) diff --git a/qlib/data/base.py b/qlib/data/base.py index 496ae38ee2..d9cf0e567f 100644 --- a/qlib/data/base.py +++ b/qlib/data/base.py @@ -187,8 +187,14 @@ def load(self, instrument, start_index, end_index, *args): cache_key = str(self), instrument, start_index, end_index, *args if cache_key in H["f"]: return H["f"][cache_key] - if start_index is not None and end_index is not None and start_index > end_index: - raise ValueError("Invalid index range: {} {}".format(start_index, end_index)) + if ( + start_index is not None + 
and end_index is not None + and start_index > end_index + ): + raise ValueError( + "Invalid index range: {} {}".format(start_index, end_index) + ) try: series = self._load_internal(instrument, start_index, end_index, *args) except Exception as e: @@ -204,7 +210,9 @@ def load(self, instrument, start_index, end_index, *args): @abc.abstractmethod def _load_internal(self, instrument, start_index, end_index, *args) -> pd.Series: - raise NotImplementedError("This function must be implemented in your newly defined feature") + raise NotImplementedError( + "This function must be implemented in your newly defined feature" + ) @abc.abstractmethod def get_longest_back_rolling(self): @@ -217,7 +225,9 @@ def get_longest_back_rolling(self): So this will only used for detecting the length of historical data needed. """ # TODO: forward operator like Ref($close, -1) is not supported yet. - raise NotImplementedError("This function must be implemented in your newly defined feature") + raise NotImplementedError( + "This function must be implemented in your newly defined feature" + ) @abc.abstractmethod def get_extended_window_size(self): @@ -232,7 +242,9 @@ def get_extended_window_size(self): (int, int) lft_etd, rght_etd """ - raise NotImplementedError("This function must be implemented in your newly defined feature") + raise NotImplementedError( + "This function must be implemented in your newly defined feature" + ) class Feature(Expression): @@ -270,7 +282,9 @@ def __str__(self): def _load_internal(self, instrument, start_index, end_index, cur_time, period=None): from .data import PITD # pylint: disable=C0415 - return PITD.period_feature(instrument, str(self), start_index, end_index, cur_time, period) + return PITD.period_feature( + instrument, str(self), start_index, end_index, cur_time, period + ) class ExpressionOps(Expression): diff --git a/qlib/data/cache.py b/qlib/data/cache.py index 9ba87f3d26..e6671e5c61 100644 --- a/qlib/data/cache.py +++ b/qlib/data/cache.py @@ -147,7 +147,11 @@ def __init__(self, mem_cache_size_limit=None, limit_type="length"): length or sizeof; length(call fun: len), size(call fun: sys.getsizeof). 
""" - size_limit = C.mem_cache_size_limit if mem_cache_size_limit is None else mem_cache_size_limit + size_limit = ( + C.mem_cache_size_limit + if mem_cache_size_limit is None + else mem_cache_size_limit + ) limit_type = C.mem_cache_limit_type if limit_type is None else limit_type if limit_type == "length": @@ -155,7 +159,9 @@ def __init__(self, mem_cache_size_limit=None, limit_type="length"): elif limit_type == "sizeof": klass = MemCacheSizeofUnit else: - raise ValueError(f"limit_type must be length or sizeof, your limit_type is {limit_type}") + raise ValueError( + f"limit_type must be length or sizeof, your limit_type is {limit_type}" + ) self.__calendar_mem_cache = klass(size_limit) self.__instrument_mem_cache = klass(size_limit) @@ -234,7 +240,9 @@ def visit(cache_path: Union[str, Path]): raise KeyError("Unknown meta keyword") from key_e pickle.dump(d, f, protocol=C.dump_protocol_version) except Exception as e: - get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}") + get_module_logger("CacheUtils").warning( + f"visit {cache_path} cache error: {e}" + ) @staticmethod def acquire(lock, lock_name): @@ -283,7 +291,9 @@ def reader_lock(redis_t, lock_name: str): @staticmethod @contextlib.contextmanager def writer_lock(redis_t, lock_name): - current_cache_wlock = redis_lock.Lock(redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID) + current_cache_wlock = redis_lock.Lock( + redis_t, f"{lock_name}-wlock", id=CacheUtils.LOCK_ID + ) CacheUtils.acquire(current_cache_wlock, lock_name) try: yield @@ -302,7 +312,9 @@ def __getattr__(self, attr): return getattr(self.provider, attr) @staticmethod - def check_cache_exists(cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta")) -> bool: + def check_cache_exists( + cache_path: Union[str, Path], suffix_list: Iterable = (".index", ".meta") + ) -> bool: cache_path = Path(cache_path) for p in [cache_path] + [cache_path.with_suffix(_s) for _s in suffix_list]: if not p.exists(): @@ -342,21 +354,27 @@ def expression(self, instrument, field, start_time, end_time, freq): try: return self._expression(instrument, field, start_time, end_time, freq) except NotImplementedError: - return self.provider.expression(instrument, field, start_time, end_time, freq) + return self.provider.expression( + instrument, field, start_time, end_time, freq + ) def _uri(self, instrument, field, start_time, end_time, freq): """Get expression cache file uri. Override this method to define how to get expression cache file uri corresponding to users' own cache mechanism. """ - raise NotImplementedError("Implement this function to match your own cache mechanism") + raise NotImplementedError( + "Implement this function to match your own cache mechanism" + ) def _expression(self, instrument, field, start_time, end_time, freq): """Get expression data using cache. Override this method to define how to get expression data corresponding to users' own cache mechanism. """ - raise NotImplementedError("Implement this method if you want to use expression cache") + raise NotImplementedError( + "Implement this method if you want to use expression cache" + ) def update(self, cache_uri: Union[str, Path], freq: str = "day"): """Update expression cache to latest calendar. @@ -374,7 +392,9 @@ def update(self, cache_uri: Union[str, Path], freq: str = "day"): int 0(successful update)/ 1(no need to update)/ 2(update failure). 
""" - raise NotImplementedError("Implement this method if you want to make expression cache up to date") + raise NotImplementedError( + "Implement this method if you want to make expression cache up to date" + ) class DatasetCache(BaseProviderCache): @@ -388,7 +408,14 @@ class DatasetCache(BaseProviderCache): HDF_KEY = "df" def dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + inst_processors=[], ): """Get feature dataset. @@ -401,17 +428,33 @@ def dataset( if disk_cache == 0: # skip cache return self.provider.dataset( - instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + inst_processors=inst_processors, ) else: # use and replace cache try: return self._dataset( - instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + inst_processors=inst_processors, ) except NotImplementedError: return self.provider.dataset( - instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + inst_processors=inst_processors, ) def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs): @@ -419,19 +462,37 @@ def _uri(self, instruments, fields, start_time, end_time, freq, **kwargs): Override this method to define how to get dataset cache file uri corresponding to users' own cache mechanism. """ - raise NotImplementedError("Implement this function to match your own cache mechanism") + raise NotImplementedError( + "Implement this function to match your own cache mechanism" + ) def _dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + inst_processors=[], ): """Get feature dataset using cache. Override this method to define how to get feature dataset corresponding to users' own cache mechanism. """ - raise NotImplementedError("Implement this method if you want to use dataset feature cache") + raise NotImplementedError( + "Implement this method if you want to use dataset feature cache" + ) def _dataset_uri( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + inst_processors=[], ): """Get a uri of feature dataset using cache. 
specially: @@ -460,7 +521,9 @@ def update(self, cache_uri: Union[str, Path], freq: str = "day"): int 0(successful update)/ 1(no need to update)/ 2(update failure) """ - raise NotImplementedError("Implement this method if you want to make expression cache up to date") + raise NotImplementedError( + "Implement this method if you want to make expression cache up to date" + ) @staticmethod def cache_to_origin_data(data, fields): @@ -496,15 +559,25 @@ def __init__(self, provider, **kwargs): self.remote = kwargs.get("remote", False) def get_cache_dir(self, freq: str = None) -> Path: - return super(DiskExpressionCache, self).get_cache_dir(C.features_cache_dir_name, freq) + return super(DiskExpressionCache, self).get_cache_dir( + C.features_cache_dir_name, freq + ) def _uri(self, instrument, field, start_time, end_time, freq): field = remove_fields_space(field) instrument = str(instrument).lower() return hash_args(instrument, field, freq) - def _expression(self, instrument, field, start_time=None, end_time=None, freq="day"): - _cache_uri = self._uri(instrument=instrument, field=field, start_time=None, end_time=None, freq=freq) + def _expression( + self, instrument, field, start_time=None, end_time=None, freq="day" + ): + _cache_uri = self._uri( + instrument=instrument, + field=field, + start_time=None, + end_time=None, + freq=freq, + ) _instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower()) cache_path = _instrument_dir.joinpath(_cache_uri) # get calendar @@ -512,7 +585,9 @@ def _expression(self, instrument, field, start_time=None, end_time=None, freq="d _calendar = Cal.calendar(freq=freq) - _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq, future=False) + _, _, start_index, end_index = Cal.locate_index( + start_time, end_time, freq, future=False + ) if self.check_cache_exists(cache_path, suffix_list=[".meta"]): """ @@ -532,7 +607,9 @@ def _expression(self, instrument, field, start_time=None, end_time=None, freq="d return series except Exception: series = None - self.logger.error("reading %s file error : %s" % (cache_path, traceback.format_exc())) + self.logger.error( + "reading %s file error : %s" % (cache_path, traceback.format_exc()) + ) return series else: # normalize field @@ -543,10 +620,15 @@ def _expression(self, instrument, field, start_time=None, end_time=None, freq="d # When the expression is not a raw feature # generate expression cache if the feature is not a Feature # instance - series = self.provider.expression(instrument, field, _calendar[0], _calendar[-1], freq) + series = self.provider.expression( + instrument, field, _calendar[0], _calendar[-1], freq + ) if not series.empty: # This expression is empty, we don't generate any cache for it. 
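The expression cache above keys each on-disk file by a hash of the normalized query arguments (see _uri / hash_args). As a rough illustration of that keying scheme only, and not qlib's exact hashing, a stable filename can be derived from (instrument, field, freq) like this:

    import hashlib
    import json

    def cache_key(instrument: str, field: str, freq: str) -> str:
        # normalize the arguments (lowercase instrument, strip spaces from the field),
        # then hash them so equivalent queries map to the same cache file
        normalized = json.dumps([instrument.lower(), field.replace(" ", ""), freq])
        return hashlib.md5(normalized.encode("utf-8")).hexdigest()

    print(cache_key("SH600000", "Mean($close, 5)", "day"))
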
- with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:expression-{_cache_uri}"): + with CacheUtils.writer_lock( + self.r, + f"{str(C.dpm.get_data_uri(freq))}:expression-{_cache_uri}", + ): self.gen_expression_cache( expression_data=series, cache_path=cache_path, @@ -560,14 +642,23 @@ def _expression(self, instrument, field, start_time=None, end_time=None, freq="d return series else: # If the expression is a raw feature(such as $close, $open) - return self.provider.expression(instrument, field, start_time, end_time, freq) + return self.provider.expression( + instrument, field, start_time, end_time, freq + ) - def gen_expression_cache(self, expression_data, cache_path, instrument, field, freq, last_update): + def gen_expression_cache( + self, expression_data, cache_path, instrument, field, freq, last_update + ): """use bin file to save like feature-data.""" # Make sure the cache runs right when the directory is deleted # while running meta = { - "info": {"instrument": instrument, "field": field, "freq": freq, "last_update": last_update}, + "info": { + "instrument": instrument, + "field": field, + "freq": freq, + "last_update": last_update, + }, "meta": {"last_visit": time.time(), "visits": 1}, } self.logger.debug(f"generating expression cache: {meta}") @@ -586,11 +677,15 @@ def update(self, sid, cache_uri, freq: str = "day"): cp_cache_uri = self.get_cache_dir(freq).joinpath(sid).joinpath(cache_uri) meta_path = cp_cache_uri.with_suffix(".meta") if not self.check_cache_exists(cp_cache_uri, suffix_list=[".meta"]): - self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed") + self.logger.info( + f"The cache {cp_cache_uri} has corrupted. It will be removed" + ) self.clear_cache(cp_cache_uri) return 2 - with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:expression-{cache_uri}"): + with CacheUtils.writer_lock( + self.r, f"{str(C.dpm.get_data_uri())}:expression-{cache_uri}" + ): with meta_path.open("rb") as f: d = pickle.load(f) instrument = d["info"]["instrument"] @@ -603,7 +698,9 @@ def update(self, sid, cache_uri, freq: str = "day"): whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq) # calendar since last updated. 
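The update path below appends only data dated after the cached last_update stamp; when the calendar since that date has at most one entry, there is nothing to add. A simplified pandas sketch of the append-since-last-update pattern, using toy DataFrames rather than qlib's binary .bin cache layout:

    import pandas as pd

    def append_since(cached: pd.DataFrame, full: pd.DataFrame, last_update) -> pd.DataFrame:
        # rows of `full` dated strictly after the cache's last_update stamp
        new_rows = full.loc[full.index > pd.Timestamp(last_update)]
        if new_rows.empty:          # nothing newer than the cache -> no update needed
            return cached
        return pd.concat([cached, new_rows])

    idx = pd.date_range("2020-01-01", periods=5, freq="D")
    full = pd.DataFrame({"close": range(5)}, index=idx)
    cached = full.iloc[:3]
    print(append_since(cached, full, last_update="2020-01-03"))
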
- new_calendar = Cal.calendar(start_time=last_update_time, end_time=None, freq=freq) + new_calendar = Cal.calendar( + start_time=last_update_time, end_time=None, freq=freq + ) # get append data if len(new_calendar) <= 1: @@ -629,7 +726,11 @@ def update(self, sid, cache_uri, freq: str = "day"): remove_n = min(rght_etd, ele_n) assert new_calendar[1] == whole_calendar[current_index] data = self.provider.expression( - instrument, field, whole_calendar[current_index - remove_n], new_calendar[-1], freq + instrument, + field, + whole_calendar[current_index - remove_n], + new_calendar[-1], + freq, ) with open(cp_cache_uri, "ab") as f: data = np.array(data).astype(" Path: - return super(DiskDatasetCache, self).get_cache_dir(C.dataset_cache_dir_name, freq) + return super(DiskDatasetCache, self).get_cache_dir( + C.dataset_cache_dir_name, freq + ) @classmethod - def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time, fields): + def read_data_from_cache( + cls, cache_path: Union[str, Path], start_time, end_time, fields + ): """read_cache_from This function can read data from the disk cache dataset @@ -693,12 +811,24 @@ def read_data_from_cache(cls, cache_path: Union[str, Path], start_time, end_time return df def _dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + inst_processors=[], ): if disk_cache == 0: # In this case, data_set cache is configured but will not be used. return self.provider.dataset( - instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + inst_processors=inst_processors, ) # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date if inst_processors: @@ -724,9 +854,13 @@ def _dataset( if self.check_cache_exists(cache_path): if disk_cache == 1: # use cache - with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"): + with CacheUtils.reader_lock( + self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}" + ): CacheUtils.visit(cache_path) - features = self.read_data_from_cache(cache_path, start_time, end_time, fields) + features = self.read_data_from_cache( + cache_path, start_time, end_time, fields + ) elif disk_cache == 2: gen_flag = True else: @@ -734,7 +868,9 @@ def _dataset( if gen_flag: # cache unavailable, generate the cache - with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"): + with CacheUtils.writer_lock( + self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}" + ): features = self.gen_dataset_cache( cache_path=cache_path, instruments=instruments, @@ -747,14 +883,23 @@ def _dataset( return features def _dataset_uri( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + inst_processors=[], ): if disk_cache == 0: # In this case, server only checks the expression cache. # The client will load the cache data by itself. 
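Cache reads and cache (re)generation in _dataset above are guarded by reader/writer locks (Redis-backed via redis_lock in qlib). The following is a loose, single-process illustration of the same discipline only, assuming plain threading locks rather than Redis:

    import threading
    from contextlib import contextmanager

    class SimpleRWLock:
        """Minimal readers-writer lock: many readers OR one writer (no fairness)."""

        def __init__(self):
            self._readers = 0
            self._readers_mutex = threading.Lock()
            self._writer_mutex = threading.Lock()

        @contextmanager
        def reader(self):
            with self._readers_mutex:
                self._readers += 1
                if self._readers == 1:      # first reader blocks writers
                    self._writer_mutex.acquire()
            try:
                yield
            finally:
                with self._readers_mutex:
                    self._readers -= 1
                    if self._readers == 0:  # last reader lets writers in again
                        self._writer_mutex.release()

        @contextmanager
        def writer(self):
            with self._writer_mutex:        # exclusive access while (re)generating a cache
                yield

    lock = SimpleRWLock()
    with lock.writer():
        cache = {"dataset": [1, 2, 3]}      # generate the cache exclusively
    with lock.reader():
        print(cache["dataset"])             # concurrent readers would be allowed here
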
from .data import LocalDatasetProvider # pylint: disable=C0415 - LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq) + LocalDatasetProvider.multi_cache_walker( + instruments, fields, start_time, end_time, freq + ) return "" # FIXME: The cache after resample, when read again and intercepted with end_time, results in incomplete data date if inst_processors: @@ -774,13 +919,19 @@ def _dataset_uri( cache_path = self.get_cache_dir(freq).joinpath(_cache_uri) if self.check_cache_exists(cache_path): - self.logger.debug(f"The cache dataset has already existed {cache_path}. Return the uri directly") - with CacheUtils.reader_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"): + self.logger.debug( + f"The cache dataset has already existed {cache_path}. Return the uri directly" + ) + with CacheUtils.reader_lock( + self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}" + ): CacheUtils.visit(cache_path) return _cache_uri else: # cache unavailable, generate the cache - with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}"): + with CacheUtils.writer_lock( + self.r, f"{str(C.dpm.get_data_uri(freq))}:dataset-{_cache_uri}" + ): self.gen_dataset_cache( cache_path=cache_path, instruments=instruments, @@ -853,7 +1004,14 @@ def build_index_from_data(data, start_index=0): index_data += start_index return index_data - def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, freq, inst_processors=[]): + def gen_dataset_cache( + self, + cache_path: Union[str, Path], + instruments, + fields, + freq, + inst_processors=[], + ): """gen_dataset_cache .. note:: This function does not consider the cache read write lock. Please @@ -903,7 +1061,12 @@ def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, f self.clear_cache(cache_path) features = self.provider.dataset( - instruments, fields, _calendar[0], _calendar[-1], freq, inst_processors=inst_processors + instruments, + fields, + _calendar[0], + _calendar[-1], + freq, + inst_processors=inst_processors, ) if features.empty: @@ -914,9 +1077,15 @@ def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, f # write cache data with pd.HDFStore(str(cache_path.with_suffix(".data"))) as store: - cache_to_orig_map = dict(zip(remove_fields_space(features.columns), features.columns)) - orig_to_cache_map = dict(zip(features.columns, remove_fields_space(features.columns))) - cache_features = features[list(cache_to_orig_map.values())].rename(columns=orig_to_cache_map) + cache_to_orig_map = dict( + zip(remove_fields_space(features.columns), features.columns) + ) + orig_to_cache_map = dict( + zip(features.columns, remove_fields_space(features.columns)) + ) + cache_features = features[list(cache_to_orig_map.values())].rename( + columns=orig_to_cache_map + ) # cache columns cache_columns = sorted(cache_features.columns) cache_features = cache_features.loc[:, cache_columns] @@ -935,7 +1104,9 @@ def gen_dataset_cache(self, cache_path: Union[str, Path], instruments, fields, f } with cache_path.with_suffix(".meta").open("wb") as f: pickle.dump(meta, f, protocol=C.dump_protocol_version) - cache_path.with_suffix(".meta").chmod(stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH) + cache_path.with_suffix(".meta").chmod( + stat.S_IRWXU | stat.S_IRGRP | stat.S_IROTH + ) # write index file im = DiskDatasetCache.IndexManager(cache_path) index_data = im.build_index_from_data(features) @@ -952,12 +1123,16 @@ def update(self, cache_uri, freq: str 
= "day"): cp_cache_uri = self.get_cache_dir(freq).joinpath(cache_uri) meta_path = cp_cache_uri.with_suffix(".meta") if not self.check_cache_exists(cp_cache_uri): - self.logger.info(f"The cache {cp_cache_uri} has corrupted. It will be removed") + self.logger.info( + f"The cache {cp_cache_uri} has corrupted. It will be removed" + ) self.clear_cache(cp_cache_uri) return 2 im = DiskDatasetCache.IndexManager(cp_cache_uri) - with CacheUtils.writer_lock(self.r, f"{str(C.dpm.get_data_uri())}:dataset-{cache_uri}"): + with CacheUtils.writer_lock( + self.r, f"{str(C.dpm.get_data_uri())}:dataset-{cache_uri}" + ): with meta_path.open("rb") as f: d = pickle.load(f) instruments = d["info"]["instruments"] @@ -979,7 +1154,9 @@ def update(self, cache_uri, freq: str = "day"): whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq) # The calendar since last updated - new_calendar = Cal.calendar(start_time=last_update_time, end_time=None, freq=freq) + new_calendar = Cal.calendar( + start_time=last_update_time, end_time=None, freq=freq + ) # get append data if len(new_calendar) <= 1: @@ -1066,31 +1243,67 @@ class SimpleDatasetCache(DatasetCache): def __init__(self, provider): super(SimpleDatasetCache, self).__init__(provider) try: - self.local_cache_path: Path = Path(C["local_cache_path"]).expanduser().resolve() + self.local_cache_path: Path = ( + Path(C["local_cache_path"]).expanduser().resolve() + ) except (KeyError, TypeError): - self.logger.error("Assign a local_cache_path in config if you want to use this cache mechanism") + self.logger.error( + "Assign a local_cache_path in config if you want to use this cache mechanism" + ) raise self.logger.info( f"DatasetCache directory: {self.local_cache_path}, " f"modify the cache directory via the local_cache_path in the config" ) - def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs): + def _uri( + self, + instruments, + fields, + start_time, + end_time, + freq, + disk_cache=1, + inst_processors=[], + **kwargs, + ): instruments, fields, freq = self.normalize_uri_args(instruments, fields, freq) return hash_args( - instruments, fields, start_time, end_time, freq, disk_cache, str(self.local_cache_path), inst_processors + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + str(self.local_cache_path), + inst_processors, ) def _dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=1, + inst_processors=[], ): if disk_cache == 0: # In this case, data_set cache is configured but will not be used. 
- return self.provider.dataset(instruments, fields, start_time, end_time, freq) + return self.provider.dataset( + instruments, fields, start_time, end_time, freq + ) self.local_cache_path.mkdir(exist_ok=True, parents=True) cache_file = self.local_cache_path.joinpath( self._uri( - instruments, fields, start_time, end_time, freq, disk_cache=disk_cache, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + disk_cache=disk_cache, + inst_processors=inst_processors, ) ) gen_flag = False @@ -1108,7 +1321,12 @@ def _dataset( if gen_flag: data = self.provider.dataset( - instruments, normalize_cache_fields(fields), start_time, end_time, freq, inst_processors=inst_processors + instruments, + normalize_cache_fields(fields), + start_time, + end_time, + freq, + inst_processors=inst_processors, ) data.to_pickle(cache_file) return self.cache_to_origin_data(data, fields) @@ -1117,16 +1335,42 @@ def _dataset( class DatasetURICache(DatasetCache): """Prepared cache mechanism for server.""" - def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, inst_processors=[], **kwargs): - return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache, inst_processors) + def _uri( + self, + instruments, + fields, + start_time, + end_time, + freq, + disk_cache=1, + inst_processors=[], + **kwargs, + ): + return hash_args( + *self.normalize_uri_args(instruments, fields, freq), + disk_cache, + inst_processors, + ) def dataset( - self, instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=0, inst_processors=[] + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + disk_cache=0, + inst_processors=[], ): if "local" in C.dataset_provider.lower(): # use LocalDatasetProvider return self.provider.dataset( - instruments, fields, start_time, end_time, freq, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + inst_processors=inst_processors, ) if disk_cache == 0: @@ -1149,10 +1393,20 @@ def dataset( ) # use ClientDatasetProvider feature_uri = self._uri( - instruments, fields, None, None, freq, disk_cache=disk_cache, inst_processors=inst_processors + instruments, + fields, + None, + None, + freq, + disk_cache=disk_cache, + inst_processors=inst_processors, ) value, expire = MemCacheExpire.get_cache(H["f"], feature_uri) - mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath(C.dataset_cache_dir_name).joinpath(feature_uri) + mnt_feature_uri = ( + C.dpm.get_data_uri(freq) + .joinpath(C.dataset_cache_dir_name) + .joinpath(feature_uri) + ) if value is None or expire or not mnt_feature_uri.exists(): df, uri = self.provider.dataset( instruments, @@ -1170,7 +1424,9 @@ def dataset( # HZ['f'][uri] = df.copy() get_module_logger("cache").debug(f"get feature from {C.dataset_provider}") else: - df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) + df = DiskDatasetCache.read_data_from_cache( + mnt_feature_uri, start_time, end_time, fields + ) get_module_logger("cache").debug("get feature from uri cache") return df diff --git a/qlib/data/client.py b/qlib/data/client.py index a9b4b2edf7..1847a1a392 100644 --- a/qlib/data/client.py +++ b/qlib/data/client.py @@ -28,7 +28,9 @@ def __init__(self, host, port): # bind connect/disconnect callbacks self.sio.on( "connect", - lambda: self.logger.debug("Connect to server {}".format(self.sio.connection_url)), + lambda: self.logger.debug( + "Connect to server {}".format(self.sio.connection_url) + ), ) 
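The Client wiring here registers connect/disconnect handlers on a python-socketio client before any request is sent to qlib-server. A stripped-down sketch of that wiring only; the host/port below are placeholders and no server is contacted:

    import logging

    import socketio  # python-socketio client, as used by qlib's Client

    logger = logging.getLogger("demo_client")
    sio = socketio.Client()

    # log connection state changes, mirroring the callbacks registered above
    sio.on("connect", lambda: logger.debug("connected to %s", sio.connection_url))
    sio.on("disconnect", lambda: logger.debug("disconnected from server"))

    # sio.connect("ws://127.0.0.1:9710")  # placeholder host/port; requires a running server
    # sio.emit("calendar_request", {...})  # request/response event names are server-defined
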
self.sio.on("disconnect", lambda: self.logger.debug("Disconnect from server!")) @@ -37,7 +39,9 @@ def connect_server(self): try: self.sio.connect(f"ws://{self.server_host}:{self.server_port}") except socketio.exceptions.ConnectionError: - self.logger.error("Cannot connect to server - check your network or server status") + self.logger.error( + "Cannot connect to server - check your network or server status" + ) def disconnect(self): """Disconnect from server.""" @@ -46,7 +50,9 @@ def disconnect(self): except Exception as e: self.logger.error("Cannot disconnect from server : %s" % e) - def send_request(self, request_type, request_content, msg_queue, msg_proc_func=None): + def send_request( + self, request_type, request_content, msg_queue, msg_proc_func=None + ): """Send a certain request to server. Parameters @@ -76,7 +82,9 @@ def request_callback(*args): else: self.logger.info(msg["detailed_info"]) if msg["status"] != 0: - ex = ValueError(f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}") + ex = ValueError( + f"Bad response(status=={msg['status']}), detailed info: {msg['detailed_info']}" + ) msg_queue.put(ex) else: if msg_proc_func is not None: @@ -96,7 +104,10 @@ def request_callback(*args): self.logger.debug("connected") # The pickle is for passing some parameters with special type(such as # pd.Timestamp) - request_content = {"head": head_info, "body": pickle.dumps(request_content, protocol=C.dump_protocol_version)} + request_content = { + "head": head_info, + "body": pickle.dumps(request_content, protocol=C.dump_protocol_version), + } self.sio.on(request_type + "_response", request_callback) self.logger.debug("try sending") self.sio.emit(request_type + "_request", request_content) diff --git a/qlib/data/data.py b/qlib/data/data.py index aba75c0b1a..f18bd37705 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -109,7 +109,11 @@ def calendar(self, start_time=None, end_time=None, freq="day", future=False): return _calendar[si : ei + 1] def locate_index( - self, start_time: Union[pd.Timestamp, str], end_time: Union[pd.Timestamp, str], freq: str, future: bool = False + self, + start_time: Union[pd.Timestamp, str], + end_time: Union[pd.Timestamp, str], + freq: str, + future: bool = False, ): """Locate the start time index and end time index in a calendar under certain frequency. @@ -193,7 +197,9 @@ def load_calendar(self, freq, future): list list of timestamps """ - raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method") + raise NotImplementedError( + "Subclass of CalendarProvider must implement `load_calendar` method" + ) class InstrumentProvider(abc.ABC): @@ -203,7 +209,9 @@ class InstrumentProvider(abc.ABC): """ @staticmethod - def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None): + def instruments( + market: Union[List, str] = "all", filter_pipe: Union[List, None] = None + ): """Get the general config dictionary for a base market adding several dynamic filters. Parameters @@ -264,7 +272,9 @@ def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] return config @abc.abstractmethod - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): """List the instruments based on a certain stockpool config. 
Parameters @@ -283,9 +293,13 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da dict or list instruments list or dictionary with time spans """ - raise NotImplementedError("Subclass of InstrumentProvider must implement `list_instruments` method") + raise NotImplementedError( + "Subclass of InstrumentProvider must implement `list_instruments` method" + ) - def _uri(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def _uri( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): return hash_args(instruments, start_time, end_time, freq, as_list) # instruments type @@ -332,7 +346,9 @@ def feature(self, instrument, field, start_time, end_time, freq): pd.Series data of a certain feature """ - raise NotImplementedError("Subclass of FeatureProvider must implement `feature` method") + raise NotImplementedError( + "Subclass of FeatureProvider must implement `feature` method" + ) class PITProvider(abc.ABC): @@ -398,16 +414,21 @@ def get_expression_instance(self, field): self.expression_instance_cache[field] = expression except NameError as e: get_module_logger("data").exception( - "ERROR: field [%s] contains invalid operator/variable [%s]" % (str(field), str(e).split()[1]) + "ERROR: field [%s] contains invalid operator/variable [%s]" + % (str(field), str(e).split()[1]) ) raise except SyntaxError: - get_module_logger("data").exception("ERROR: field [%s] contains invalid syntax" % str(field)) + get_module_logger("data").exception( + "ERROR: field [%s] contains invalid syntax" % str(field) + ) raise return expression @abc.abstractmethod - def expression(self, instrument, field, start_time=None, end_time=None, freq="day") -> pd.Series: + def expression( + self, instrument, field, start_time=None, end_time=None, freq="day" + ) -> pd.Series: """Get Expression data. The responsibility of `expression` @@ -440,7 +461,9 @@ def expression(self, instrument, field, start_time=None, end_time=None, freq="da - because the datetime is not as good as """ - raise NotImplementedError("Subclass of ExpressionProvider must implement `Expression` method") + raise NotImplementedError( + "Subclass of ExpressionProvider must implement `Expression` method" + ) class DatasetProvider(abc.ABC): @@ -450,7 +473,15 @@ class DatasetProvider(abc.ABC): """ @abc.abstractmethod - def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day", inst_processors=[]): + def dataset( + self, + instruments, + fields, + start_time=None, + end_time=None, + freq="day", + inst_processors=[], + ): """Get dataset data. Parameters @@ -473,7 +504,9 @@ def dataset(self, instruments, fields, start_time=None, end_time=None, freq="day pd.DataFrame a pandas dataframe with index. 
""" - raise NotImplementedError("Subclass of DatasetProvider must implement `Dataset` method") + raise NotImplementedError( + "Subclass of DatasetProvider must implement `Dataset` method" + ) def _uri( self, @@ -505,7 +538,9 @@ def _uri( """ # TODO: qlib-server support inst_processors - return DiskDatasetCache._uri(instruments, fields, start_time, end_time, freq, disk_cache, inst_processors) + return DiskDatasetCache._uri( + instruments, fields, start_time, end_time, freq, disk_cache, inst_processors + ) @staticmethod def get_instruments_d(instruments, freq): @@ -517,7 +552,9 @@ def get_instruments_d(instruments, freq): if isinstance(instruments, dict): if "market" in instruments: # dict of stockpool config - instruments_d = Inst.list_instruments(instruments=instruments, freq=freq, as_list=False) + instruments_d = Inst.list_instruments( + instruments=instruments, freq=freq, as_list=False + ) else: # dict of instruments and timestamp instruments_d = instruments @@ -545,7 +582,9 @@ def parse_fields(fields): return [ExpressionD.get_expression_instance(f) for f in fields] @staticmethod - def dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors=[]): + def dataset_processor( + instruments_d, column_names, start_time, end_time, freq, inst_processors=[] + ): """ Load and process the data, return the data set. - default using multi-kernel method. @@ -567,14 +606,25 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i inst_l.append(inst) task_l.append( delayed(DatasetProvider.inst_calculator)( - inst, start_time, end_time, freq, normalize_column_names, spans, C, inst_processors + inst, + start_time, + end_time, + freq, + normalize_column_names, + spans, + C, + inst_processors, ) ) data = dict( zip( inst_l, - ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)(task_l), + ParallelExt( + n_jobs=workers, + backend=C.joblib_backend, + maxtasksperchild=C.maxtasksperchild, + )(task_l), ) ) @@ -589,7 +639,9 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i data = DiskDatasetCache.cache_to_origin_data(data, column_names) else: data = pd.DataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), + index=pd.MultiIndex.from_arrays( + [[], []], names=("instrument", "datetime") + ), columns=column_names, dtype=np.float32, ) @@ -597,7 +649,16 @@ def dataset_processor(instruments_d, column_names, start_time, end_time, freq, i return data @staticmethod - def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, g_config=None, inst_processors=[]): + def inst_calculator( + inst, + start_time, + end_time, + freq, + column_names, + spans=None, + g_config=None, + inst_processors=[], + ): """ Calculate the expressions for **one** instrument, return a df result. If the expression has been calculated before, load from cache. 
@@ -629,7 +690,9 @@ def inst_calculator(inst, start_time, end_time, freq, column_names, spans=None, for _processor in inst_processors: if _processor: - _processor_obj = init_instance_by_config(_processor, accept_types=InstProcessor) + _processor_obj = init_instance_by_config( + _processor, accept_types=InstProcessor + ) data = _processor_obj(data, instrument=inst) return data @@ -688,7 +751,9 @@ def __init__(self, backend={}) -> None: def _load_instruments(self, market, freq): return self.backend_obj(market=market, freq=freq).data - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): market = instruments["market"] if market in H["i"]: _instruments = H["i"][market] @@ -704,19 +769,31 @@ def list_instruments(self, instruments, start_time=None, end_time=None, freq="da inst: list( filter( lambda x: x[0] <= x[1], - [(max(start_time, pd.Timestamp(x[0])), min(end_time, pd.Timestamp(x[1]))) for x in spans], + [ + ( + max(start_time, pd.Timestamp(x[0])), + min(end_time, pd.Timestamp(x[1])), + ) + for x in spans + ], ) ) for inst, spans in _instruments.items() } - _instruments_filtered = {key: value for key, value in _instruments_filtered.items() if value} + _instruments_filtered = { + key: value for key, value in _instruments_filtered.items() if value + } # filter filter_pipe = instruments["filter_pipe"] for filter_config in filter_pipe: from . import filter as F # pylint: disable=C0415 - filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config) - _instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq) + filter_t = getattr(F, filter_config["filter_type"]).from_config( + filter_config + ) + _instruments_filtered = filter_t( + _instruments_filtered, start_time, end_time, freq + ) # as list if as_list: return list(_instruments_filtered) @@ -738,14 +815,18 @@ def feature(self, instrument, field, start_index, end_index, freq): # validate field = str(field)[1:] instrument = code_to_fname(instrument) - return self.backend_obj(instrument=instrument, field=field, freq=freq)[start_index : end_index + 1] + return self.backend_obj(instrument=instrument, field=field, freq=freq)[ + start_index : end_index + 1 + ] class LocalPITProvider(PITProvider): # TODO: Add PIT backend file storage # NOTE: This class is not multi-threading-safe!!!! - def period_feature(self, instrument, field, start_index, end_index, cur_time, period=None): + def period_feature( + self, instrument, field, start_index, end_index, cur_time, period=None + ): if not isinstance(cur_time, pd.Timestamp): raise ValueError( f"Expected pd.Timestamp for `cur_time`, got '{cur_time}'. Advices: you can't query PIT data directly(e.g. '$$roewa_q'), you must use `P` operator to convert data to each day (e.g. 
'P($$roewa_q)')" @@ -779,8 +860,12 @@ def period_feature(self, instrument, field, start_index, end_index, cur_time, pe if not field.endswith("_q") and not field.endswith("_a"): raise ValueError("period field must ends with '_q' or '_a'") quarterly = field.endswith("_q") - index_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index" - data_path = C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data" + index_path = ( + C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.index" + ) + data_path = ( + C.dpm.get_data_uri() / "financial" / instrument.lower() / f"{field}.data" + ) if not (index_path.exists() and data_path.exists()): raise FileNotFoundError("No file is found.") # NOTE: The most significant performance loss is here. @@ -793,7 +878,9 @@ def period_feature(self, instrument, field, start_index, end_index, cur_time, pe data = np.fromfile(data_path, dtype=DATA_RECORDS) # find all revision periods before `cur_time` - cur_time_int = int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day) + cur_time_int = ( + int(cur_time.year) * 10000 + int(cur_time.month) * 100 + int(cur_time.day) + ) loc = np.searchsorted(data["date"], cur_time_int, side="right") if loc <= 0: return pd.Series(dtype=C.pit_record_type["value"]) @@ -807,12 +894,19 @@ def period_feature(self, instrument, field, start_index, end_index, cur_time, pe else: period_list = [period] else: - period_list = period_list[max(0, len(period_list) + start_index - 1) : len(period_list) + end_index] + period_list = period_list[ + max(0, len(period_list) + start_index - 1) : len(period_list) + + end_index + ] value = np.full((len(period_list),), np.nan, dtype=VALUE_DTYPE) for i, p in enumerate(period_list): # last_period_index = self.period_index[field].get(period) # For acceleration value[i], now_period_index = read_period_data( - index_path, data_path, p, cur_time_int, quarterly # , last_period_index # For acceleration + index_path, + data_path, + p, + cur_time_int, + quarterly, # , last_period_index # For acceleration ) # self.period_index[field].update({period: now_period_index}) # For acceleration # NOTE: the index is period_list; So it may result in unexpected values(e.g. 
nan) @@ -849,7 +943,9 @@ def expression(self, instrument, field, start_time=None, end_time=None, freq="da # - Index-based expression: this may save a lot of memory because the datetime index is not saved on the disk # - Data with datetime index expression: this will make it more convenient to integrating with some existing databases if self.time2idx: - _, _, start_index, end_index = Cal.locate_index(start_time, end_time, freq=freq, future=False) + _, _, start_index, end_index = Cal.locate_index( + start_time, end_time, freq=freq, future=False + ) lft_etd, rght_etd = expression.get_extended_window_size() query_start, query_end = max(0, start_index - lft_etd), end_index + rght_etd else: @@ -916,18 +1012,28 @@ def dataset( cal = Cal.calendar(start_time, end_time, freq) if len(cal) == 0: return pd.DataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names + index=pd.MultiIndex.from_arrays( + [[], []], names=("instrument", "datetime") + ), + columns=column_names, ) start_time = cal[0] end_time = cal[-1] data = self.dataset_processor( - instruments_d, column_names, start_time, end_time, freq, inst_processors=inst_processors + instruments_d, + column_names, + start_time, + end_time, + freq, + inst_processors=inst_processors, ) return data @staticmethod - def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq="day"): + def multi_cache_walker( + instruments, fields, start_time=None, end_time=None, freq="day" + ): """ This method is used to prepare the expression cache for the client. Then the client will load the data from expression cache by itself. @@ -942,8 +1048,14 @@ def multi_cache_walker(instruments, fields, start_time=None, end_time=None, freq end_time = cal[-1] workers = max(min(C.kernels, len(instruments_d)), 1) - ParallelExt(n_jobs=workers, backend=C.joblib_backend, maxtasksperchild=C.maxtasksperchild)( - delayed(LocalDatasetProvider.cache_walker)(inst, start_time, end_time, freq, column_names) + ParallelExt( + n_jobs=workers, + backend=C.joblib_backend, + maxtasksperchild=C.maxtasksperchild, + )( + delayed(LocalDatasetProvider.cache_walker)( + inst, start_time, end_time, freq, column_names + ) for inst in instruments_d ) @@ -974,9 +1086,16 @@ def set_conn(self, conn): def calendar(self, start_time=None, end_time=None, freq="day", future=False): self.conn.send_request( request_type="calendar", - request_content={"start_time": str(start_time), "end_time": str(end_time), "freq": freq, "future": future}, + request_content={ + "start_time": str(start_time), + "end_time": str(end_time), + "freq": freq, + "future": future, + }, msg_queue=self.queue, - msg_proc_func=lambda response_content: [pd.Timestamp(c) for c in response_content], + msg_proc_func=lambda response_content: [ + pd.Timestamp(c) for c in response_content + ], ) result = self.queue.get(timeout=C["timeout"]) return result @@ -995,11 +1114,14 @@ def __init__(self): def set_conn(self, conn): self.conn = conn - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): def inst_msg_proc_func(response_content): if isinstance(response_content, dict): instrument = { - i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] for i, t in response_content.items() + i: [(pd.Timestamp(s), pd.Timestamp(e)) for s, e in t] + for i, t in response_content.items() } else: instrument = response_content @@ -1083,13 +1205,22 @@ 
def dataset( cal = Cal.calendar(start_time, end_time, freq) if len(cal) == 0: return pd.DataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), + index=pd.MultiIndex.from_arrays( + [[], []], names=("instrument", "datetime") + ), columns=column_names, ) start_time = cal[0] end_time = cal[-1] - data = self.dataset_processor(instruments_d, column_names, start_time, end_time, freq, inst_processors) + data = self.dataset_processor( + instruments_d, + column_names, + start_time, + end_time, + freq, + inst_processors, + ) if return_uri: return data, feature_uri else: @@ -1127,14 +1258,20 @@ def dataset( get_module_logger("data").debug("get result") try: # pre-mound nfs, used for demo - mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath(C.dataset_cache_dir_name, feature_uri) - df = DiskDatasetCache.read_data_from_cache(mnt_feature_uri, start_time, end_time, fields) + mnt_feature_uri = C.dpm.get_data_uri(freq).joinpath( + C.dataset_cache_dir_name, feature_uri + ) + df = DiskDatasetCache.read_data_from_cache( + mnt_feature_uri, start_time, end_time, fields + ) get_module_logger("data").debug("finish slicing data") if return_uri: return df, feature_uri return df except AttributeError as attribute_e: - raise IOError("Unable to fetch instruments from remote server!") from attribute_e + raise IOError( + "Unable to fetch instruments from remote server!" + ) from attribute_e class BaseProvider: @@ -1148,7 +1285,9 @@ class BaseProvider: def calendar(self, start_time=None, end_time=None, freq="day", future=False): return Cal.calendar(start_time, end_time, freq, future=future) - def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None): + def instruments( + self, market="all", filter_pipe=None, start_time=None, end_time=None + ): if start_time is not None or end_time is not None: get_module_logger("Provider").warning( "The instruments corresponds to a stock pool. " @@ -1156,7 +1295,9 @@ def instruments(self, market="all", filter_pipe=None, start_time=None, end_time= ) return InstrumentProvider.instruments(market, filter_pipe) - def list_instruments(self, instruments, start_time=None, end_time=None, freq="day", as_list=False): + def list_instruments( + self, instruments, start_time=None, end_time=None, freq="day", as_list=False + ): return Inst.list_instruments(instruments, start_time, end_time, freq, as_list) def features( @@ -1184,10 +1325,23 @@ def features( fields = list(fields) # In case of tuple. 
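# An illustrative sketch of the `instruments` / `list_instruments` pair wrapped just
# above, combined with the filters from qlib/data/filter.py that appear later in this
# diff (assumes `qlib.init()` has been called; the market, regex and expression are
# placeholders).
from qlib.data import D
from qlib.data.filter import NameDFilter, ExpressionDFilter

name_filter = NameDFilter(name_rule_re="SH60[0-9]{4}")          # keep SH60xxxx codes only
expr_filter = ExpressionDFilter(rule_expression="$close > 10")  # dynamic price-based filter
stock_pool = D.instruments(market="csi300", filter_pipe=[name_filter, expr_filter])
codes = D.list_instruments(
    instruments=stock_pool,
    start_time="2019-01-01",
    end_time="2019-12-31",
    as_list=True,
)
print(len(codes))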
try: return DatasetD.dataset( - instruments, fields, start_time, end_time, freq, disk_cache, inst_processors=inst_processors + instruments, + fields, + start_time, + end_time, + freq, + disk_cache, + inst_processors=inst_processors, ) except TypeError: - return DatasetD.dataset(instruments, fields, start_time, end_time, freq, inst_processors=inst_processors) + return DatasetD.dataset( + instruments, + fields, + start_time, + end_time, + freq, + inst_processors=inst_processors, + ) class LocalProvider(BaseProvider): @@ -1206,7 +1360,9 @@ def _uri(self, type, **kwargs): elif type == "feature": return DatasetD._uri(**kwargs) - def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1): + def features_uri( + self, instruments, fields, start_time, end_time, freq, disk_cache=1 + ): """features_uri Return the uri of the generated cache of features/dataset @@ -1218,7 +1374,9 @@ def features_uri(self, instruments, fields, start_time, end_time, freq, disk_cac :param end_time: :param freq: """ - return DatasetD._dataset_uri(instruments, fields, start_time, end_time, freq, disk_cache) + return DatasetD._dataset_uri( + instruments, fields, start_time, end_time, freq, disk_cache + ) class ClientProvider(BaseProvider): @@ -1296,7 +1454,9 @@ def register_all_wrappers(C): _calendar_provider = init_instance_by_config(C.calendar_provider, module) if getattr(C, "calendar_cache", None) is not None: - _calendar_provider = init_instance_by_config(C.calendar_cache, module, provide=_calendar_provider) + _calendar_provider = init_instance_by_config( + C.calendar_cache, module, provide=_calendar_provider + ) register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calendar_cache}") @@ -1318,13 +1478,19 @@ def register_all_wrappers(C): # This provider is unnecessary in client provider _eprovider = init_instance_by_config(C.expression_provider, module) if getattr(C, "expression_cache", None) is not None: - _eprovider = init_instance_by_config(C.expression_cache, module, provider=_eprovider) + _eprovider = init_instance_by_config( + C.expression_cache, module, provider=_eprovider + ) register_wrapper(ExpressionD, _eprovider, "qlib.data") - logger.debug(f"registering ExpressionD {C.expression_provider}-{C.expression_cache}") + logger.debug( + f"registering ExpressionD {C.expression_provider}-{C.expression_cache}" + ) _dprovider = init_instance_by_config(C.dataset_provider, module) if getattr(C, "dataset_cache", None) is not None: - _dprovider = init_instance_by_config(C.dataset_cache, module, provider=_dprovider) + _dprovider = init_instance_by_config( + C.dataset_cache, module, provider=_dprovider + ) register_wrapper(DatasetD, _dprovider, "qlib.data") logger.debug(f"registering DatasetD {C.dataset_provider}-{C.dataset_cache}") diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index a6cace3730..293899a370 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -116,7 +116,9 @@ def __init__( 'outsample': ("2017-01-01", "2020-08-01",), } """ - self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) + self.handler: DataHandler = init_instance_by_config( + handler, accept_types=DataHandler + ) self.segments = segments.copy() self.fetch_kwargs = copy(fetch_kwargs) super().__init__(**kwargs) @@ -240,8 +242,12 @@ def prepare( return self._prepare_seg(self.segments[segments], **seg_kwargs) # 1.2) fetch multiple splits like ["train", "valid"] ["train", "valid", 
"test"] - if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments): - return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments] + if isinstance(segments, (list, tuple)) and all( + seg in self.segments for seg in segments + ): + return [ + self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments + ] # 2) Use pass it directly to prepare a single seg return self._prepare_seg(segments, **seg_kwargs) @@ -393,7 +399,9 @@ def __init__( if dtype is not None: kwargs["dtype"] = dtype - self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values! + self.data_arr = np.array( + **kwargs + ) # Get index from numpy.array will much faster than DataFrame.values! # NOTE: # - append last line with full NaN for better performance in `__getitem__` # - Keep the same dtype will result in a better performance @@ -402,7 +410,9 @@ def __init__( np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0, ) - self.nan_idx = len(self.data_arr) - 1 # The last line is all NaN; setting it to -1 can cause bug #1716 + self.nan_idx = ( + len(self.data_arr) - 1 + ) # The last line is all NaN; setting it to -1 can cause bug #1716 # the data type will be changed # The index of usable data is between start_idx and end_idx @@ -425,7 +435,9 @@ def __init__( self.idx_map, self.idx_df, self.data_index, start, end ) - self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + self.idx_arr = np.array( + self.idx_df.values, dtype=np.float64 + ) # for better performance del self.data # save memory @staticmethod @@ -440,9 +452,13 @@ def slice_idx_map_and_data_index( len(idx_map) == data_index.shape[0] ) # make sure idx_map and data_index is same so index of idx_map can be used on data_index - start_row_idx, end_row_idx = idx_df.index.slice_locs(start=time_to_slc_point(start), end=time_to_slc_point(end)) + start_row_idx, end_row_idx = idx_df.index.slice_locs( + start=time_to_slc_point(start), end=time_to_slc_point(end) + ) - time_flter_idx = (idx_map[:, 0] < end_row_idx) & (idx_map[:, 0] >= start_row_idx) + time_flter_idx = (idx_map[:, 0] < end_row_idx) & ( + idx_map[:, 0] >= start_row_idx + ) return idx_map[time_flter_idx], data_index[time_flter_idx] @staticmethod @@ -479,7 +495,9 @@ def get_index(self): Get the pandas index of the data, it will be useful in following scenarios - Special sampler will be used (e.g. 
user want to sample day by day) """ - return self.data_index.swaplevel() # to align the order of multiple index of original data received by __init__ + return ( + self.data_index.swaplevel() + ) # to align the order of multiple index of original data received by __init__ def config(self, **kwargs): # Config the attributes @@ -552,7 +570,9 @@ def _get_indices(self, row: int, col: int) -> np.array: indices = self.idx_arr[max(row - self.step_len + 1, 0) : row + 1, col] if len(indices) < self.step_len: - indices = np.concatenate([np.full((self.step_len - len(indices),), np.nan), indices]) + indices = np.concatenate( + [np.full((self.step_len - len(indices),), np.nan), indices] + ) if self.fillna_type == "ffill": indices = np_ffill(indices) @@ -580,7 +600,9 @@ def _get_row_col(self, idx) -> Tuple[int]: if isinstance(idx, (int, np.integer)): real_idx = idx if 0 <= real_idx < len(self.idx_map): - i, j = self.idx_map[real_idx] # TODO: The performance of this line is not good + i, j = self.idx_map[ + real_idx + ] # TODO: The performance of this line is not good else: raise KeyError(f"{real_idx} is out of [0, {len(self.idx_map)})") elif isinstance(idx, tuple): @@ -623,9 +645,13 @@ def __getitem__(self, idx: Union[int, Tuple[object, str], List[int]]): # 1) for better performance, use the last nan line for padding the lost date # 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in # precision problems. It will not cause any problems in my tests at least - indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int) + indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype( + int + ) - if (np.diff(indices) == 1).all(): # slicing instead of indexing for speeding up. + if ( + np.diff(indices) == 1 + ).all(): # slicing instead of indexing for speeding up. 
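# An illustrative sketch of consuming the TSDataSampler defined above.  The toy
# MultiIndex frame and its column names are assumptions; in practice the frame comes
# from a DataHandler inside TSDatasetH.
import numpy as np
import pandas as pd
from qlib.data.dataset import TSDataSampler

dates = pd.date_range("2020-01-01", periods=6)
index = pd.MultiIndex.from_product(
    [dates, ["SH600000", "SZ000001"]], names=["datetime", "instrument"]
)
raw = pd.DataFrame(
    np.random.randn(len(index), 3), index=index, columns=["f0", "f1", "label"]
)
sampler = TSDataSampler(
    raw, start=dates[2], end=dates[-1], step_len=3, fillna_type="ffill+bfill"
)
window = sampler[0]  # one (step_len, n_columns) time-series window as np.ndarray
print(window.shape, len(sampler))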
data = self.data_arr[indices[0] : indices[-1] + 1] else: data = self.data_arr[indices] @@ -660,7 +686,9 @@ class TSDatasetH(DatasetH): DEFAULT_STEP_LEN = 30 - def __init__(self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs): + def __init__( + self, step_len=DEFAULT_STEP_LEN, flt_col: Optional[str] = None, **kwargs + ): self.step_len = step_len self.flt_col = flt_col super().__init__(**kwargs) @@ -673,7 +701,11 @@ def config(self, **kwargs): def setup_data(self, **kwargs): super().setup_data(**kwargs) # make sure the calendar is updated to latest when loading data from new config - cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() + cal = ( + self.handler.fetch(col_set=self.handler.CS_RAW) + .index.get_level_values("datetime") + .unique() + ) self.cal = sorted(cal) @staticmethod diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index b6ee957947..0dac1e7077 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -130,12 +130,18 @@ def __init__( """ # Setup data loader - assert data_loader is not None # to make start_time end_time could have None default value + assert ( + data_loader is not None + ) # to make start_time end_time could have None default value # what data source to load data self.data_loader = init_instance_by_config( data_loader, - None if (isinstance(data_loader, dict) and "module_path" in data_loader) else data_loader_module, + ( + None + if (isinstance(data_loader, dict) and "module_path" in data_loader) + else data_loader_module + ), accept_types=DataLoader, ) @@ -192,7 +198,9 @@ def setup_data(self, enable_cache: bool = False): # _data may be with multiple column index level. The outer level indicates the feature set name with TimeInspector.logt("Loading data"): # make sure the fetch method is based on an index-sorted pd.DataFrame - self._data = lazy_sort_index(self.data_loader.load(self.instruments, self.start_time, self.end_time)) + self._data = lazy_sort_index( + self.data_loader.load(self.instruments, self.start_time, self.end_time) + ) # TODO: cache def fetch( @@ -298,25 +306,42 @@ def _fetch_data( try: selector = slice(*selector) except ValueError: - get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly") + get_module_logger("DataHandlerLP").info( + f"Fail to converting to query to slice. It will used directly" + ) if isinstance(data_storage, pd.DataFrame): data_df = data_storage if proc_func is not None: # FIXME: fetching by time first will be more friendly to `proc_func` # Copy in case of `proc_func` changing the data inplace.... 
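# An illustrative sketch of the fetch() path being rewrapped here, using the Alpha158
# handler from qlib.contrib as an example (an assumption of this note); it presumes
# `qlib.init()` has been called, and the dates/market are placeholders.
import pandas as pd
from qlib.contrib.data.handler import Alpha158

handler = Alpha158(
    instruments="csi300",
    start_time="2019-01-01",
    end_time="2020-06-30",
    fit_start_time="2019-01-01",
    fit_end_time="2019-12-31",
)
features = handler.fetch(
    selector=slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-03-31")),
    level="datetime",
    col_set="feature",
    data_key=handler.DK_I,  # processed data used for inference
)
labels = handler.fetch(col_set="label", data_key=handler.DK_L)
print(features.shape, labels.shape)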
- data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy()) + data_df = proc_func( + fetch_df_by_index( + data_df, selector, level, fetch_orig=self.fetch_orig + ).copy() + ) data_df = fetch_df_by_col(data_df, col_set) else: # Fetch column first will be more friendly to SepDataFrame data_df = fetch_df_by_col(data_df, col_set) - data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) + data_df = fetch_df_by_index( + data_df, selector, level, fetch_orig=self.fetch_orig + ) elif isinstance(data_storage, BaseHandlerStorage): if proc_func is not None: - raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + raise ValueError( + f"proc_func is not supported by the storage {type(data_storage)}" + ) + data_df = data_storage.fetch( + selector=selector, + level=level, + col_set=col_set, + fetch_orig=self.fetch_orig, + ) else: - raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}") + raise TypeError( + f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}" + ) if squeeze: # squeeze columns @@ -344,7 +369,9 @@ def get_cols(self, col_set=DataHandlerABC.CS_ALL) -> list: df = fetch_df_by_col(df, col_set) return df.columns.to_list() - def get_range_selector(self, cur_date: Union[pd.Timestamp, str], periods: int) -> slice: + def get_range_selector( + self, cur_date: Union[pd.Timestamp, str], periods: int + ) -> slice: """ get range selector by number of periods @@ -420,7 +447,11 @@ class DataHandlerLP(DataHandler): _learn: pd.DataFrame # data for learning models # map data_key to attribute name - ATTR_MAP = {DataHandler.DK_R: "_data", DataHandler.DK_I: "_infer", DataHandler.DK_L: "_learn"} + ATTR_MAP = { + DataHandler.DK_R: "_data", + DataHandler.DK_I: "_infer", + DataHandler.DK_L: "_learn", + } # process type PTYPE_I = "independent" @@ -499,7 +530,11 @@ def __init__( getattr(self, pname).append( init_instance_by_config( proc, - None if (isinstance(proc, dict) and "module_path" in proc) else processor_module, + ( + None + if (isinstance(proc, dict) and "module_path" in proc) + else processor_module + ), accept_types=processor_module.Processor, ) ) @@ -529,11 +564,16 @@ def fit_process_data(self): @staticmethod def _run_proc_l( - df: pd.DataFrame, proc_l: List[processor_module.Processor], with_fit: bool, check_for_infer: bool + df: pd.DataFrame, + proc_l: List[processor_module.Processor], + with_fit: bool, + check_for_infer: bool, ) -> pd.DataFrame: for proc in proc_l: if check_for_infer and not proc.is_for_infer(): - raise TypeError("Only processors usable for inference can be used in `infer_processors` ") + raise TypeError( + "Only processors usable for inference can be used in `infer_processors` " + ) with TimeInspector.logt(f"{proc.__class__.__name__}"): if with_fit: proc.fit(df) @@ -578,18 +618,26 @@ def process_data(self, with_fit: bool = False): # shared data processors # 1) assign _shared_df = self._data - if not self._is_proc_readonly(self.shared_processors): # avoid modifying the original data + if not self._is_proc_readonly( + self.shared_processors + ): # avoid modifying the original data _shared_df = _shared_df.copy() # 2) process - _shared_df = self._run_proc_l(_shared_df, self.shared_processors, with_fit=with_fit, check_for_infer=True) + _shared_df = self._run_proc_l( + _shared_df, self.shared_processors, 
with_fit=with_fit, check_for_infer=True + ) # data for inference # 1) assign _infer_df = _shared_df - if not self._is_proc_readonly(self.infer_processors): # avoid modifying the original data + if not self._is_proc_readonly( + self.infer_processors + ): # avoid modifying the original data _infer_df = _infer_df.copy() # 2) process - _infer_df = self._run_proc_l(_infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True) + _infer_df = self._run_proc_l( + _infer_df, self.infer_processors, with_fit=with_fit, check_for_infer=True + ) self._infer = _infer_df @@ -602,10 +650,14 @@ def process_data(self, with_fit: bool = False): _learn_df = _infer_df else: raise NotImplementedError(f"This type of input is not supported") - if not self._is_proc_readonly(self.learn_processors): # avoid modifying the original data + if not self._is_proc_readonly( + self.learn_processors + ): # avoid modifying the original data _learn_df = _learn_df.copy() # 2) process - _learn_df = self._run_proc_l(_learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False) + _learn_df = self._run_proc_l( + _learn_df, self.learn_processors, with_fit=with_fit, check_for_infer=False + ) self._learn = _learn_df @@ -627,7 +679,9 @@ def config(self, processor_kwargs: dict = None, **kwargs): processor.config(**processor_kwargs) # init type - IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor + IT_FIT_SEQ = ( + "fit_seq" # the input of `fit` will be the output of the previous processor + ) IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df IT_LS = "load_state" # The state of the object has been load by pickle @@ -663,7 +717,9 @@ def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): # TODO: Be able to cache handler data. 
Save the memory for data processing - def _get_df_by_key(self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> pd.DataFrame: + def _get_df_by_key( + self, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I + ) -> pd.DataFrame: if data_key == self.DK_R and self.drop_raw: raise AttributeError( "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" @@ -710,7 +766,9 @@ def fetch( proc_func=proc_func, ) - def get_cols(self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I) -> list: + def get_cols( + self, col_set=DataHandler.CS_ALL, data_key: DATA_KEY_TYPE = DataHandlerABC.DK_I + ) -> list: """ get the column names diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index d283cb4f67..926b6e4144 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -88,7 +88,10 @@ def __init__(self, config: Union[list, tuple, dict]): self.is_group = isinstance(config, dict) if self.is_group: - self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()} + self.fields = { + grp: self._parse_fields_info(fields_info) + for grp, fields_info in config.items() + } else: self.fields = self._parse_fields_info(config) @@ -139,7 +142,9 @@ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame if self.is_group: df = pd.concat( { - grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp) + grp: self.load_group_df( + instruments, exprs, names, start_time, end_time, grp + ) for grp, (exprs, names) in self.fields.items() }, axis=1, @@ -214,16 +219,29 @@ def load_group_df( if isinstance(instruments, str): instruments = D.instruments(instruments, filter_pipe=self.filter_pipe) elif self.filter_pipe is not None: - warnings.warn("`filter_pipe` is not None, but it will not be used with `instruments` as list") + warnings.warn( + "`filter_pipe` is not None, but it will not be used with `instruments` as list" + ) freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq inst_processors = ( - self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, []) + self.inst_processors + if isinstance(self.inst_processors, list) + else self.inst_processors.get(gp_name, []) + ) + df = D.features( + instruments, + exprs, + start_time, + end_time, + freq=freq, + inst_processors=inst_processors, ) - df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors) df.columns = names if self.swap_level: - df = df.swaplevel().sort_index() # NOTE: if swaplevel, return + df = ( + df.swaplevel().sort_index() + ) # NOTE: if swaplevel, return return df @@ -273,7 +291,10 @@ def _maybe_load_raw_data(self): return if isinstance(self._config, dict): self._data = pd.concat( - {fields_group: load_dataset(path_or_obj) for fields_group, path_or_obj in self._config.items()}, + { + fields_group: load_dataset(path_or_obj) + for fields_group, path_or_obj in self._config.items() + }, axis=1, join=self.join, ) @@ -322,7 +343,8 @@ def __init__(self, dataloader_l: List[Dict], join="left") -> None: """ super().__init__() self.data_loader_l = [ - (dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l + (dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) + for dl in dataloader_l ] self.join = join @@ -335,15 +357,25 @@ def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame warnings.warn( "If the value of `instruments` 
cannot be processed, it will set instruments to None to get all the data." ) - df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time) + df_current = dl.load( + instruments=None, start_time=start_time, end_time=end_time + ) if df_full is None: df_full = df_current else: current_columns = df_current.columns.tolist() full_columns = df_full.columns.tolist() - columns_to_drop = [col for col in current_columns if col in full_columns] + columns_to_drop = [ + col for col in current_columns if col in full_columns + ] df_full.drop(columns=columns_to_drop, inplace=True) - df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join) + df_full = pd.merge( + df_full, + df_current, + left_index=True, + right_index=True, + how=self.join, + ) return df_full.sort_index(axis=1) @@ -388,10 +420,13 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False if is_group: self.handlers = { - grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items() + grp: init_instance_by_config(config, accept_types=DataHandler) + for grp, config in handler_config.items() } else: - self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler) + self.handlers = init_instance_by_config( + handler_config, accept_types=DataHandler + ) self.is_group = is_group self.fetch_kwargs = {"col_set": DataHandler.CS_RAW} @@ -399,16 +434,26 @@ def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame: if instruments is not None: - get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored") + get_module_logger(self.__class__.__name__).warning( + f"instruments[{instruments}] is ignored" + ) if self.is_group: df = pd.concat( { - grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) + grp: dh.fetch( + selector=slice(start_time, end_time), + level="datetime", + **self.fetch_kwargs, + ) for grp, dh in self.handlers.items() }, axis=1, ) else: - df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs) + df = self.handlers.fetch( + selector=slice(start_time, end_time), + level="datetime", + **self.fetch_kwargs, + ) return df diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index d05dbe381c..01c43e8117 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -134,7 +134,9 @@ def __init__(self, fields_group="feature", col_list=[]): def __call__(self, df): cols = get_group_columns(df, self.fields_group) all_cols = df.columns - diff_cols = np.setdiff1d(all_cols.get_level_values(-1), cols.get_level_values(-1)) + diff_cols = np.setdiff1d( + all_cols.get_level_values(-1), cols.get_level_values(-1) + ) self.col_list = np.union1d(diff_cols, self.col_list) mask = df.columns.get_level_values(-1).isin(self.col_list) return df.loc[:, mask] @@ -166,7 +168,9 @@ def replace_inf(data): def process_inf(df): for col in df.columns: # FIXME: Such behavior is very weird - df[col] = df[col].replace([np.inf, -np.inf], df[col][~np.isinf(df[col])].mean()) + df[col] = df[col].replace( + [np.inf, -np.inf], df[col][~np.isinf(df[col])].mean() + ) return df data = datetime_groupby_apply(data, process_inf) @@ -202,7 +206,9 @@ def __init__(self, fit_start_time, fit_end_time, fields_group=None): self.fields_group = fields_group def fit(self, df: pd.DataFrame = None): - df = 
fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") + df = fetch_df_by_index( + df, slice(self.fit_start_time, self.fit_end_time), level="datetime" + ) cols = get_group_columns(df, self.fields_group) self.min_val = np.nanmin(df[cols].values, axis=0) self.max_val = np.nanmax(df[cols].values, axis=0) @@ -236,7 +242,9 @@ def __init__(self, fit_start_time, fit_end_time, fields_group=None): self.fields_group = fields_group def fit(self, df: pd.DataFrame = None): - df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") + df = fetch_df_by_index( + df, slice(self.fit_start_time, self.fit_end_time), level="datetime" + ) cols = get_group_columns(df, self.fields_group) self.mean_train = np.nanmean(df[cols].values, axis=0) self.std_train = np.nanstd(df[cols].values, axis=0) @@ -270,7 +278,9 @@ class RobustZScoreNorm(Processor): https://en.wikipedia.org/wiki/Median_absolute_deviation. """ - def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True): + def __init__( + self, fit_start_time, fit_end_time, fields_group=None, clip_outlier=True + ): # NOTE: correctly set the `fit_start_time` and `fit_end_time` is very important !!! # `fit_end_time` **must not** include any information from the test data!!! self.fit_start_time = fit_start_time @@ -279,7 +289,9 @@ def __init__(self, fit_start_time, fit_end_time, fields_group=None, clip_outlier self.clip_outlier = clip_outlier def fit(self, df: pd.DataFrame = None): - df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime") + df = fetch_df_by_index( + df, slice(self.fit_start_time, self.fit_end_time), level="datetime" + ) self.cols = get_group_columns(df, self.fields_group) X = df[self.cols].values self.mean_train = np.nanmedian(X, axis=0) @@ -319,7 +331,11 @@ def __call__(self, df): with pd.option_context("mode.chained_assignment", None): for g in self.fields_group: cols = get_group_columns(df, g) - df[cols] = df[cols].groupby("datetime", group_keys=False).apply(self.zscore_func) + df[cols] = ( + df[cols] + .groupby("datetime", group_keys=False) + .apply(self.zscore_func) + ) return df @@ -367,7 +383,11 @@ def __init__(self, fields_group=None): def __call__(self, df): cols = get_group_columns(df, self.fields_group) - df[cols] = df[cols].groupby("datetime", group_keys=False).apply(lambda x: x.fillna(x.mean())) + df[cols] = ( + df[cols] + .groupby("datetime", group_keys=False) + .apply(lambda x: x.fillna(x.mean())) + ) return df diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index ca3325a28c..10ccff23d9 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -77,7 +77,9 @@ def fetch( try: selector = slice(*selector) except ValueError: - get_module_logger("DataHandlerLP").info(f"Fail to converting to query to slice. It will used directly") + get_module_logger("DataHandlerLP").info( + f"Fail to converting to query to slice. 
It will used directly" + ) data_df = self.df data_df = fetch_df_by_col(data_df, col_set) @@ -150,8 +152,12 @@ def _fetch_hash_df_by_stock(self, selector, level): elif isinstance(selector, (list, str)): stock_selector = selector - if not isinstance(stock_selector, (list, str)) and stock_selector != slice(None): - raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}") + if not isinstance(stock_selector, (list, str)) and stock_selector != slice( + None + ): + raise TypeError( + f"stock selector must be type str|list, or slice(None), rather than {stock_selector}" + ) if stock_selector == slice(None): return self.hash_df, time_selector @@ -172,18 +178,29 @@ def fetch( col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, ) -> pd.DataFrame: - fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level) + fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock( + selector=selector, level=level + ) fetch_stock_df_list = list(fetch_stock_df_list.values()) for _index, stock_df in enumerate(fetch_stock_df_list): fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set) fetch_index_df = fetch_df_by_index( - df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig + df=fetch_col_df, + selector=time_selector, + level="datetime", + fetch_orig=fetch_orig, ) fetch_stock_df_list[_index] = fetch_index_df if len(fetch_stock_df_list) == 0: - index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument") + index_names = ( + ("instrument", "datetime") + if self.stock_level == 0 + else ("datetime", "instrument") + ) return pd.DataFrame( - index=pd.MultiIndex.from_arrays([[], []], names=index_names), columns=self.columns, dtype=np.float32 + index=pd.MultiIndex.from_arrays([[], []], names=index_names), + columns=self.columns, + dtype=np.float32, ) elif len(fetch_stock_df_list) == 1: return fetch_stock_df_list[0] diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index 688cde99af..750329281c 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -89,7 +89,9 @@ def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.Data return df.loc(axis=1)[col_set] -def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]: +def convert_index_format( + df: Union[pd.DataFrame, pd.Series], level: str = "datetime" +) -> Union[pd.DataFrame, pd.Series]: """ Convert the format of df.MultiIndex according to the following rules: - If `level` is the first level of df.MultiIndex, do nothing diff --git a/qlib/data/filter.py b/qlib/data/filter.py index 5057e20a4b..b7e5978165 100644 --- a/qlib/data/filter.py +++ b/qlib/data/filter.py @@ -34,7 +34,9 @@ def from_config(config): config : dict dict of config parameters. """ - raise NotImplementedError("Subclass of BaseDFilter must reimplement `from_config` method") + raise NotImplementedError( + "Subclass of BaseDFilter must reimplement `from_config` method" + ) @abstractmethod def to_config(self): @@ -45,7 +47,9 @@ def to_config(self): dict return the dict of config parameters. 
""" - raise NotImplementedError("Subclass of BaseDFilter must reimplement `to_config` method") + raise NotImplementedError( + "Subclass of BaseDFilter must reimplement `to_config` method" + ) class SeriesDFilter(BaseDFilter): @@ -123,7 +127,9 @@ def _toSeries(self, time_range, target_timestamp): timestamp_series = pd.Series(timestamp_series) # Fill the date within target_timestamp with TRUE for start, end in target_timestamp: - timestamp_series[Cal.calendar(start_time=start, end_time=end, freq=self.filter_freq)] = True + timestamp_series[ + Cal.calendar(start_time=start, end_time=end, freq=self.filter_freq) + ] = True return timestamp_series def _filterSeries(self, timestamp_series, filter_series): @@ -142,7 +148,9 @@ def _filterSeries(self, timestamp_series, filter_series): the series of bool value indicating whether the date satisfies the filter condition and exists in target timestamp. """ fstart, fend = list(filter_series.keys())[0], list(filter_series.keys())[-1] - filter_series = filter_series.astype("bool") # Make sure the filter_series is boolean + filter_series = filter_series.astype( + "bool" + ) # Make sure the filter_series is boolean timestamp_series[fstart:fend] = timestamp_series[fstart:fend] & filter_series return timestamp_series @@ -211,7 +219,9 @@ def _getFilterSeries(self, instruments, fstart, fend): pd.Dataframe a series of {pd.Timestamp => bool}. """ - raise NotImplementedError("Subclass of SeriesDFilter must reimplement `getFilterSeries` method") + raise NotImplementedError( + "Subclass of SeriesDFilter must reimplement `getFilterSeries` method" + ) def filter_main(self, instruments, start_time=None, end_time=None): """Implement this method to filter the instruments. @@ -234,13 +244,21 @@ def filter_main(self, instruments, start_time=None, end_time=None): start_time = pd.Timestamp(start_time or lbound) end_time = pd.Timestamp(end_time or ubound) _instruments_filtered = {} - _all_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=self.filter_freq) + _all_calendar = Cal.calendar( + start_time=start_time, end_time=end_time, freq=self.filter_freq + ) _filter_calendar = Cal.calendar( - start_time=self.filter_start_time and max(self.filter_start_time, _all_calendar[0]) or _all_calendar[0], - end_time=self.filter_end_time and min(self.filter_end_time, _all_calendar[-1]) or _all_calendar[-1], + start_time=self.filter_start_time + and max(self.filter_start_time, _all_calendar[0]) + or _all_calendar[0], + end_time=self.filter_end_time + and min(self.filter_end_time, _all_calendar[-1]) + or _all_calendar[-1], freq=self.filter_freq, ) - _all_filter_series = self._getFilterSeries(instruments, _filter_calendar[0], _filter_calendar[-1]) + _all_filter_series = self._getFilterSeries( + instruments, _filter_calendar[0], _filter_calendar[-1] + ) for inst, timestamp in instruments.items(): # Construct a whole map of date _timestamp_series = self._toSeries(_all_calendar, timestamp) @@ -249,9 +267,13 @@ def filter_main(self, instruments, start_time=None, end_time=None): _filter_series = _all_filter_series[inst] else: if self.keep: - _filter_series = pd.Series({timestamp: True for timestamp in _filter_calendar}) + _filter_series = pd.Series( + {timestamp: True for timestamp in _filter_calendar} + ) else: - _filter_series = pd.Series({timestamp: False for timestamp in _filter_calendar}) + _filter_series = pd.Series( + {timestamp: False for timestamp in _filter_calendar} + ) # Calculate bool value within the range of filter _timestamp_series = 
self._filterSeries(_timestamp_series, _filter_series) # Reform the map to (start_timestamp, end_timestamp) format @@ -283,12 +305,18 @@ def __init__(self, name_rule_re, fstart_time=None, fend_time=None): def _getFilterSeries(self, instruments, fstart, fend): all_filter_series = {} - filter_calendar = Cal.calendar(start_time=fstart, end_time=fend, freq=self.filter_freq) + filter_calendar = Cal.calendar( + start_time=fstart, end_time=fend, freq=self.filter_freq + ) for inst, timestamp in instruments.items(): if re.match(self.name_rule_re, inst): - _filter_series = pd.Series({timestamp: True for timestamp in filter_calendar}) + _filter_series = pd.Series( + {timestamp: True for timestamp in filter_calendar} + ) else: - _filter_series = pd.Series({timestamp: False for timestamp in filter_calendar}) + _filter_series = pd.Series( + {timestamp: False for timestamp in filter_calendar} + ) all_filter_series[inst] = _filter_series return all_filter_series @@ -304,8 +332,16 @@ def to_config(self): return { "filter_type": "NameDFilter", "name_rule_re": self.name_rule_re, - "filter_start_time": str(self.filter_start_time) if self.filter_start_time else self.filter_start_time, - "filter_end_time": str(self.filter_end_time) if self.filter_end_time else self.filter_end_time, + "filter_start_time": ( + str(self.filter_start_time) + if self.filter_start_time + else self.filter_start_time + ), + "filter_end_time": ( + str(self.filter_end_time) + if self.filter_end_time + else self.filter_end_time + ), } @@ -351,7 +387,9 @@ def _getFilterSeries(self, instruments, fstart, fend): ) except TypeError: # use LocalDatasetProvider - _features = DatasetD.dataset(instruments, [self.rule_expression], fstart, fend, freq=self.filter_freq) + _features = DatasetD.dataset( + instruments, [self.rule_expression], fstart, fend, freq=self.filter_freq + ) rule_expression_field_name = list(_features.keys())[0] all_filter_series = _features[rule_expression_field_name] return all_filter_series @@ -369,7 +407,15 @@ def to_config(self): return { "filter_type": "ExpressionDFilter", "rule_expression": self.rule_expression, - "filter_start_time": str(self.filter_start_time) if self.filter_start_time else self.filter_start_time, - "filter_end_time": str(self.filter_end_time) if self.filter_end_time else self.filter_end_time, + "filter_start_time": ( + str(self.filter_start_time) + if self.filter_start_time + else self.filter_start_time + ), + "filter_end_time": ( + str(self.filter_end_time) + if self.filter_end_time + else self.filter_end_time + ), "keep": self.keep, } diff --git a/qlib/data/ops.py b/qlib/data/ops.py index d9a2ffbb3e..d8ab648dc7 100644 --- a/qlib/data/ops.py +++ b/qlib/data/ops.py @@ -23,8 +23,12 @@ ) raise except ValueError: - print("!!!!!!!! A error occurs when importing operators implemented based on Cython.!!!!!!!!") - print("!!!!!!!! They will be disabled. Please Upgrade your numpy to enable them !!!!!!!!") + print( + "!!!!!!!! A error occurs when importing operators implemented based on Cython.!!!!!!!!" + ) + print( + "!!!!!!!! They will be disabled. Please Upgrade your numpy to enable them !!!!!!!!" + ) # We catch this error because some platform can't upgrade there package (e.g. 
Kaggle) # https://www.kaggle.com/general/293387 # https://www.kaggle.com/product-feedback/98562 @@ -203,7 +207,9 @@ def __init__(self, feature, instrument): self.instrument = instrument def __str__(self): - return "{}({},{})".format(type(self).__name__, self.feature, self.instrument.lower()) + return "{}({},{})".format( + type(self).__name__, self.feature, self.instrument.lower() + ) def _load_internal(self, instrument, start_index, end_index, *args): return self.feature.load(self.instrument, start_index, end_index, *args) @@ -249,7 +255,9 @@ def __init__(self, feature_left, feature_right): self.feature_right = feature_right def __str__(self): - return "{}({},{})".format(type(self).__name__, self.feature_left, self.feature_right) + return "{}({},{})".format( + type(self).__name__, self.feature_left, self.feature_right + ) def get_longest_back_rolling(self): if isinstance(self.feature_left, (Expression,)): @@ -300,14 +308,22 @@ def __init__(self, feature_left, feature_right, func): def _load_internal(self, instrument, start_index, end_index, *args): assert any( - [isinstance(self.feature_left, (Expression,)), self.feature_right, Expression] + [ + isinstance(self.feature_left, (Expression,)), + self.feature_right, + Expression, + ] ), "at least one of two inputs is Expression instance" if isinstance(self.feature_left, (Expression,)): - series_left = self.feature_left.load(instrument, start_index, end_index, *args) + series_left = self.feature_left.load( + instrument, start_index, end_index, *args + ) else: series_left = self.feature_left # numeric value if isinstance(self.feature_right, (Expression,)): - series_right = self.feature_right.load(instrument, start_index, end_index, *args) + series_right = self.feature_right.load( + instrument, start_index, end_index, *args + ) else: series_right = self.feature_right check_length = isinstance(series_left, (np.ndarray, pd.Series)) and isinstance( @@ -655,19 +671,27 @@ def __init__(self, condition, feature_left, feature_right): self.feature_right = feature_right def __str__(self): - return "If({},{},{})".format(self.condition, self.feature_left, self.feature_right) + return "If({},{},{})".format( + self.condition, self.feature_left, self.feature_right + ) def _load_internal(self, instrument, start_index, end_index, *args): series_cond = self.condition.load(instrument, start_index, end_index, *args) if isinstance(self.feature_left, (Expression,)): - series_left = self.feature_left.load(instrument, start_index, end_index, *args) + series_left = self.feature_left.load( + instrument, start_index, end_index, *args + ) else: series_left = self.feature_left if isinstance(self.feature_right, (Expression,)): - series_right = self.feature_right.load(instrument, start_index, end_index, *args) + series_right = self.feature_right.load( + instrument, start_index, end_index, *args + ) else: series_right = self.feature_right - series = pd.Series(np.where(series_cond, series_left, series_right), index=series_cond.index) + series = pd.Series( + np.where(series_cond, series_left, series_right), index=series_cond.index + ) return series def get_longest_back_rolling(self): @@ -765,7 +789,9 @@ def get_extended_window_size(self): if self.N == 0: # FIXME: How to make this accurate and efficiently? Or should we # remove such support for N == 0? 
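# An illustrative sketch of the rolling / pair-rolling operators touched in this file,
# exercised through expression strings (assumes `qlib.init()` has been called; the
# window sizes are arbitrary).
from qlib.data import D

fields = [
    "Mean($close, 5)",            # rolling mean
    "Ref($close, 1)",             # previous value
    "Skew($close, 10)",           # rolling skewness, window must be >= 3
    "Rank($close, 20)",           # rolling percentile rank
    "Corr($close, $volume, 10)",  # pair-rolling correlation
]
df = D.features(D.instruments("csi300"), fields, "2020-01-01", "2020-06-30", freq="day")
print(df.columns.tolist())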
- get_module_logger(self.__class__.__name__).warning("The Rolling(ATTR, 0) will not be accurately calculated") + get_module_logger(self.__class__.__name__).warning( + "The Rolling(ATTR, 0) will not be accurately calculated" + ) return self.feature.get_extended_window_size() elif 0 < self.N < 1: lft_etd, rght_etd = self.feature.get_extended_window_size() @@ -815,7 +841,9 @@ def get_longest_back_rolling(self): def get_extended_window_size(self): if self.N == 0: - get_module_logger(self.__class__.__name__).warning("The Ref(ATTR, 0) will not be accurately calculated") + get_module_logger(self.__class__.__name__).warning( + "The Ref(ATTR, 0) will not be accurately calculated" + ) return self.feature.get_extended_window_size() else: lft_etd, rght_etd = self.feature.get_extended_window_size() @@ -922,7 +950,9 @@ class Skew(Rolling): def __init__(self, feature, N): if N != 0 and N < 3: - raise ValueError("The rolling window size of Skewness operation should >= 3") + raise ValueError( + "The rolling window size of Skewness operation should >= 3" + ) super(Skew, self).__init__(feature, N, "skew") @@ -944,7 +974,9 @@ class Kurt(Rolling): def __init__(self, feature, N): if N != 0 and N < 4: - raise ValueError("The rolling window size of Kurtosis operation should >= 5") + raise ValueError( + "The rolling window size of Kurtosis operation should >= 5" + ) super(Kurt, self).__init__(feature, N, "kurt") @@ -990,9 +1022,13 @@ def __init__(self, feature, N): def _load_internal(self, instrument, start_index, end_index, *args): series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: - series = series.expanding(min_periods=1).apply(lambda x: x.argmax() + 1, raw=True) + series = series.expanding(min_periods=1).apply( + lambda x: x.argmax() + 1, raw=True + ) else: - series = series.rolling(self.N, min_periods=1).apply(lambda x: x.argmax() + 1, raw=True) + series = series.rolling(self.N, min_periods=1).apply( + lambda x: x.argmax() + 1, raw=True + ) return series @@ -1038,9 +1074,13 @@ def __init__(self, feature, N): def _load_internal(self, instrument, start_index, end_index, *args): series = self.feature.load(instrument, start_index, end_index, *args) if self.N == 0: - series = series.expanding(min_periods=1).apply(lambda x: x.argmin() + 1, raw=True) + series = series.expanding(min_periods=1).apply( + lambda x: x.argmin() + 1, raw=True + ) else: - series = series.rolling(self.N, min_periods=1).apply(lambda x: x.argmin() + 1, raw=True) + series = series.rolling(self.N, min_periods=1).apply( + lambda x: x.argmin() + 1, raw=True + ) return series @@ -1065,7 +1105,9 @@ def __init__(self, feature, N, qscore): self.qscore = qscore def __str__(self): - return "{}({},{},{})".format(type(self).__name__, self.feature, self.N, self.qscore) + return "{}({},{},{})".format( + type(self).__name__, self.feature, self.N, self.qscore + ) def _load_internal(self, instrument, start_index, end_index, *args): series = self.feature.load(instrument, start_index, end_index, *args) @@ -1153,7 +1195,11 @@ def __init__(self, feature, N): def _load_internal(self, instrument, start_index, end_index, *args): series = self.feature.load(instrument, start_index, end_index, *args) - rolling_or_expending = series.expanding(min_periods=1) if self.N == 0 else series.rolling(self.N, min_periods=1) + rolling_or_expending = ( + series.expanding(min_periods=1) + if self.N == 0 + else series.rolling(self.N, min_periods=1) + ) if hasattr(rolling_or_expending, "rank"): return rolling_or_expending.rank(pct=True) @@ -1278,8 
+1324,12 @@ def _load_internal(self, instrument, start_index, end_index, *args): if self.N == 0: series = pd.Series(expanding_rsquare(_series.values), index=_series.index) else: - series = pd.Series(rolling_rsquare(_series.values, self.N), index=_series.index) - series.loc[np.isclose(_series.rolling(self.N, min_periods=1).std(), 0, atol=2e-05)] = np.nan + series = pd.Series( + rolling_rsquare(_series.values, self.N), index=_series.index + ) + series.loc[ + np.isclose(_series.rolling(self.N, min_periods=1).std(), 0, atol=2e-05) + ] = np.nan return series @@ -1342,7 +1392,9 @@ def weighted_mean(x): if self.N == 0: series = series.expanding(min_periods=1).apply(weighted_mean, raw=True) else: - series = series.rolling(self.N, min_periods=1).apply(weighted_mean, raw=True) + series = series.rolling(self.N, min_periods=1).apply( + weighted_mean, raw=True + ) return series @@ -1410,7 +1462,9 @@ def __init__(self, feature_left, feature_right, N, func): self.func = func def __str__(self): - return "{}({},{},{})".format(type(self).__name__, self.feature_left, self.feature_right, self.N) + return "{}({},{},{})".format( + type(self).__name__, self.feature_left, self.feature_right, self.N + ) def _load_internal(self, instrument, start_index, end_index, *args): assert any( @@ -1418,18 +1472,26 @@ def _load_internal(self, instrument, start_index, end_index, *args): ), "at least one of two inputs is Expression instance" if isinstance(self.feature_left, Expression): - series_left = self.feature_left.load(instrument, start_index, end_index, *args) + series_left = self.feature_left.load( + instrument, start_index, end_index, *args + ) else: series_left = self.feature_left # numeric value if isinstance(self.feature_right, Expression): - series_right = self.feature_right.load(instrument, start_index, end_index, *args) + series_right = self.feature_right.load( + instrument, start_index, end_index, *args + ) else: series_right = self.feature_right if self.N == 0: - series = getattr(series_left.expanding(min_periods=1), self.func)(series_right) + series = getattr(series_left.expanding(min_periods=1), self.func)( + series_right + ) else: - series = getattr(series_left.rolling(self.N, min_periods=1), self.func)(series_right) + series = getattr(series_left.rolling(self.N, min_periods=1), self.func)( + series_right + ) return series def get_longest_back_rolling(self): @@ -1486,14 +1548,20 @@ def __init__(self, feature_left, feature_right, N): super(Corr, self).__init__(feature_left, feature_right, N, "corr") def _load_internal(self, instrument, start_index, end_index, *args): - res: pd.Series = super(Corr, self)._load_internal(instrument, start_index, end_index, *args) + res: pd.Series = super(Corr, self)._load_internal( + instrument, start_index, end_index, *args + ) # NOTE: Load uses MemCache, so calling load again will not cause performance degradation series_left = self.feature_left.load(instrument, start_index, end_index, *args) - series_right = self.feature_right.load(instrument, start_index, end_index, *args) + series_right = self.feature_right.load( + instrument, start_index, end_index, *args + ) res.loc[ np.isclose(series_left.rolling(self.N, min_periods=1).std(), 0, atol=2e-05) - | np.isclose(series_right.rolling(self.N, min_periods=1).std(), 0, atol=2e-05) + | np.isclose( + series_right.rolling(self.N, min_periods=1).std(), 0, atol=2e-05 + ) ] = np.nan return res @@ -1650,11 +1718,17 @@ def register(self, ops_list: List[Union[Type[ExpressionOps], dict]]): _ops_class = _operator if not issubclass(_ops_class, 
(Expression,)): - raise TypeError("operator must be subclass of ExpressionOps, not {}".format(_ops_class)) + raise TypeError( + "operator must be subclass of ExpressionOps, not {}".format( + _ops_class + ) + ) if _ops_class.__name__ in self._ops: get_module_logger(self.__class__.__name__).warning( - "The custom operator [{}] will override the qlib default definition".format(_ops_class.__name__) + "The custom operator [{}] will override the qlib default definition".format( + _ops_class.__name__ + ) ) self._ops[_ops_class.__name__] = _ops_class diff --git a/qlib/data/pit.py b/qlib/data/pit.py index 33d5e0c5cc..3c775eb029 100644 --- a/qlib/data/pit.py +++ b/qlib/data/pit.py @@ -37,13 +37,20 @@ def _load_internal(self, instrument, start_index, end_index, freq): # The calculated value will always the last element, so the end_offset is zero. try: s = self._load_feature(instrument, -start_ws, 0, cur_time) - resample_data[cur_index - start_index] = s.iloc[-1] if len(s) > 0 else np.nan + resample_data[cur_index - start_index] = ( + s.iloc[-1] if len(s) > 0 else np.nan + ) except FileNotFoundError: - get_module_logger("base").warning(f"WARN: period data not found for {str(self)}") + get_module_logger("base").warning( + f"WARN: period data not found for {str(self)}" + ) return pd.Series(dtype="float32", name=str(self)) resample_series = pd.Series( - resample_data, index=pd.RangeIndex(start_index, end_index + 1), dtype="float32", name=str(self) + resample_data, + index=pd.RangeIndex(start_index, end_index + 1), + dtype="float32", + name=str(self), ) return resample_series @@ -68,4 +75,6 @@ def __str__(self): return f"{super().__str__()}[{self.period}]" def _load_feature(self, instrument, start_index, end_index, cur_time): - return self.feature.load(instrument, start_index, end_index, cur_time, self.period) + return self.feature.load( + instrument, start_index, end_index, cur_time, self.period + ) diff --git a/qlib/data/storage/__init__.py b/qlib/data/storage/__init__.py index a77fb0e3e4..5bbfe749f8 100644 --- a/qlib/data/storage/__init__.py +++ b/qlib/data/storage/__init__.py @@ -1,7 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from .storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstVT, InstKT +from .storage import ( + CalendarStorage, + InstrumentStorage, + FeatureStorage, + CalVT, + InstVT, + InstKT, +) -__all__ = ["CalendarStorage", "InstrumentStorage", "FeatureStorage", "CalVT", "InstVT", "InstKT"] +__all__ = [ + "CalendarStorage", + "InstrumentStorage", + "FeatureStorage", + "CalVT", + "InstVT", + "InstKT", +] diff --git a/qlib/data/storage/file_storage.py b/qlib/data/storage/file_storage.py index 8a100a2d19..08c4fa0486 100644 --- a/qlib/data/storage/file_storage.py +++ b/qlib/data/storage/file_storage.py @@ -13,7 +13,14 @@ from qlib.config import C from qlib.data.cache import H from qlib.log import get_module_logger -from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT +from qlib.data.storage import ( + CalendarStorage, + InstrumentStorage, + FeatureStorage, + CalVT, + InstKT, + InstVT, +) logger = get_module_logger("file_storage") @@ -30,7 +37,11 @@ class FileStorageMixin: @property def provider_uri(self): - return C["provider_uri"] if getattr(self, "_provider_uri", None) is None else self._provider_uri + return ( + C["provider_uri"] + if getattr(self, "_provider_uri", None) is None + else self._provider_uri + ) @property def dpm(self): @@ -48,7 +59,12 @@ def support_freq(self) -> List[str]: if len(self.provider_uri) == 1 and C.DEFAULT_FREQ in self.provider_uri: freq_l = filter( lambda _freq: not _freq.endswith("_future"), - map(lambda x: x.stem, self.dpm.get_data_uri(C.DEFAULT_FREQ).joinpath("calendars").glob("*.txt")), + map( + lambda x: x.stem, + self.dpm.get_data_uri(C.DEFAULT_FREQ) + .joinpath("calendars") + .glob("*.txt"), + ), ) else: freq_l = self.provider_uri.keys() @@ -59,8 +75,12 @@ def support_freq(self) -> List[str]: @property def uri(self) -> Path: if self.freq not in self.support_freq: - raise ValueError(f"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}") - return self.dpm.get_data_uri(self.freq).joinpath(f"{self.storage_name}s", self.file_name) + raise ValueError( + f"{self.storage_name}: {self.provider_uri} does not contain data for {self.freq}" + ) + return self.dpm.get_data_uri(self.freq).joinpath( + f"{self.storage_name}s", self.file_name + ) def check(self): """check self.uri @@ -77,13 +97,21 @@ class FileCalendarStorage(FileStorageMixin, CalendarStorage): def __init__(self, freq: str, future: bool, provider_uri: dict = None, **kwargs): super(FileCalendarStorage, self).__init__(freq, future, **kwargs) self.future = future - self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) + self._provider_uri = ( + None + if provider_uri is None + else C.DataPathManager.format_provider_uri(provider_uri) + ) self.enable_read_cache = True # TODO: make it configurable self.region = C["region"] @property def file_name(self) -> str: - return f"{self._freq_file}_future.txt" if self.future else f"{self._freq_file}.txt".lower() + return ( + f"{self._freq_file}_future.txt" + if self.future + else f"{self._freq_file}.txt".lower() + ) @property def _freq_file(self) -> str: @@ -98,7 +126,9 @@ def _freq_file(self) -> str: freq = Freq.get_recent_freq(freq, self.support_freq) if freq is None: - raise ValueError(f"can't find a freq from {self.support_freq} that can resample to {self.freq}!") + raise ValueError( + f"can't find a freq from {self.support_freq} that can resample to {self.freq}!" 
+ ) self._freq_file_cache = freq return self._freq_file_cache @@ -125,7 +155,9 @@ def _write_calendar(self, values: Iterable[CalVT], mode: str = "wb"): @property def uri(self) -> Path: - return self.dpm.get_data_uri(self._freq_file).joinpath(f"{self.storage_name}s", self.file_name) + return self.dpm.get_data_uri(self._freq_file).joinpath( + f"{self.storage_name}s", self.file_name + ) @property def data(self) -> List[CalVT]: @@ -140,12 +172,17 @@ def data(self) -> List[CalVT]: _calendar = self._read_calendar() if Freq(self._freq_file) != Freq(self.freq): _calendar = resam_calendar( - np.array(list(map(pd.Timestamp, _calendar))), self._freq_file, self.freq, self.region + np.array(list(map(pd.Timestamp, _calendar))), + self._freq_file, + self.freq, + self.region, ) return _calendar def _get_storage_freq(self) -> List[str]: - return sorted(set(map(lambda x: x.stem.split("_")[0], self.uri.parent.glob("*.txt")))) + return sorted( + set(map(lambda x: x.stem.split("_")[0], self.uri.parent.glob("*.txt"))) + ) def extend(self, values: Iterable[CalVT]) -> None: self._write_calendar(values, mode="ab") @@ -170,7 +207,9 @@ def remove(self, value: CalVT) -> None: calendar = np.delete(calendar, index) self._write_calendar(values=calendar) - def __setitem__(self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]]) -> None: + def __setitem__( + self, i: Union[int, slice], values: Union[CalVT, Iterable[CalVT]] + ) -> None: calendar = self._read_calendar() calendar[i] = values self._write_calendar(values=calendar) @@ -197,7 +236,11 @@ class FileInstrumentStorage(FileStorageMixin, InstrumentStorage): def __init__(self, market: str, freq: str, provider_uri: dict = None, **kwargs): super(FileInstrumentStorage, self).__init__(market, freq, **kwargs) - self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) + self._provider_uri = ( + None + if provider_uri is None + else C.DataPathManager.format_provider_uri(provider_uri) + ) self.file_name = f"{market.lower()}.txt" def _read_instrument(self) -> Dict[InstKT, InstVT]: @@ -209,7 +252,11 @@ def _read_instrument(self) -> Dict[InstKT, InstVT]: self.uri, sep="\t", usecols=[0, 1, 2], - names=[self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], + names=[ + self.SYMBOL_FIELD_NAME, + self.INSTRUMENT_START_FIELD, + self.INSTRUMENT_END_FIELD, + ], dtype={self.SYMBOL_FIELD_NAME: str}, parse_dates=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD], ) @@ -225,14 +272,21 @@ def _write_instrument(self, data: Dict[InstKT, InstVT] = None) -> None: res = [] for inst, v_list in data.items(): - _df = pd.DataFrame(v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]) + _df = pd.DataFrame( + v_list, columns=[self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD] + ) _df[self.SYMBOL_FIELD_NAME] = inst res.append(_df) df = pd.concat(res, sort=False) - df.loc[:, [self.SYMBOL_FIELD_NAME, self.INSTRUMENT_START_FIELD, self.INSTRUMENT_END_FIELD]].to_csv( - self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False - ) + df.loc[ + :, + [ + self.SYMBOL_FIELD_NAME, + self.INSTRUMENT_START_FIELD, + self.INSTRUMENT_END_FIELD, + ], + ].to_csv(self.uri, header=False, sep=self.INSTRUMENT_SEP, index=False) df.to_csv(self.uri, sep="\t", encoding="utf-8", header=False, index=False) def clear(self) -> None: @@ -283,9 +337,20 @@ def __len__(self) -> int: class FileFeatureStorage(FileStorageMixin, FeatureStorage): - def __init__(self, instrument: str, field: str, freq: str, provider_uri: dict 
= None, **kwargs): + def __init__( + self, + instrument: str, + field: str, + freq: str, + provider_uri: dict = None, + **kwargs, + ): super(FileFeatureStorage, self).__init__(instrument, field, freq, **kwargs) - self._provider_uri = None if provider_uri is None else C.DataPathManager.format_provider_uri(provider_uri) + self._provider_uri = ( + None + if provider_uri is None + else C.DataPathManager.format_provider_uri(provider_uri) + ) self.file_name = f"{instrument.lower()}/{field.lower()}.{freq.lower()}.bin" def clear(self): @@ -313,17 +378,25 @@ def write(self, data_array: Union[List, np.ndarray], index: int = None) -> None: # append index = 0 if index is None else index with self.uri.open("ab+") as fp: - np.hstack([[np.nan] * (index - self.end_index - 1), data_array]).astype(" Iterable[CalVT]: ValueError If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of CalendarStorage must implement `data` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `data` method" + ) def clear(self) -> None: - raise NotImplementedError("Subclass of CalendarStorage must implement `clear` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `clear` method" + ) def extend(self, iterable: Iterable[CalVT]) -> None: - raise NotImplementedError("Subclass of CalendarStorage must implement `extend` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `extend` method" + ) def index(self, value: CalVT) -> int: """ @@ -115,13 +121,19 @@ def index(self, value: CalVT) -> int: ValueError If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of CalendarStorage must implement `index` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `index` method" + ) def insert(self, index: int, value: CalVT) -> None: - raise NotImplementedError("Subclass of CalendarStorage must implement `insert` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `insert` method" + ) def remove(self, value: CalVT) -> None: - raise NotImplementedError("Subclass of CalendarStorage must implement `remove` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `remove` method" + ) @overload def __setitem__(self, i: int, value: CalVT) -> None: @@ -185,7 +197,9 @@ def __len__(self) -> int: If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of CalendarStorage must implement `__len__` method") + raise NotImplementedError( + "Subclass of CalendarStorage must implement `__len__` method" + ) class InstrumentStorage(BaseStorage): @@ -203,10 +217,14 @@ def data(self) -> Dict[InstKT, InstVT]: ValueError If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `data` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `data` method" + ) def clear(self) -> None: - raise NotImplementedError("Subclass of InstrumentStorage must implement `clear` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `clear` method" + ) def update(self, *args, **kwargs) -> None: """D.update([E, ]**F) -> None. Update D from mapping/iterable E and F. 
@@ -220,11 +238,15 @@ def update(self, *args, **kwargs) -> None: In either case, this is followed by: for k, v in F.items(): D[k] = v """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `update` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `update` method" + ) def __setitem__(self, k: InstKT, v: InstVT) -> None: """Set self[key] to value.""" - raise NotImplementedError("Subclass of InstrumentStorage must implement `__setitem__` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `__setitem__` method" + ) def __delitem__(self, k: InstKT) -> None: """Delete self[key]. @@ -234,11 +256,15 @@ def __delitem__(self, k: InstKT) -> None: ValueError If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__delitem__` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `__delitem__` method" + ) def __getitem__(self, k: InstKT) -> InstVT: """x.__getitem__(k) <==> x[k]""" - raise NotImplementedError("Subclass of InstrumentStorage must implement `__getitem__` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `__getitem__` method" + ) def __len__(self) -> int: """ @@ -249,7 +275,9 @@ def __len__(self) -> int: If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of InstrumentStorage must implement `__len__` method") + raise NotImplementedError( + "Subclass of InstrumentStorage must implement `__len__` method" + ) class FeatureStorage(BaseStorage): @@ -267,7 +295,9 @@ def data(self) -> pd.Series: ------ if data(storage) does not exist, return empty pd.Series: `return pd.Series(dtype=np.float32)` """ - raise NotImplementedError("Subclass of FeatureStorage must implement `data` method") + raise NotImplementedError( + "Subclass of FeatureStorage must implement `data` method" + ) @property def start_index(self) -> Union[int, None]: @@ -277,7 +307,9 @@ def start_index(self) -> Union[int, None]: ----- If the data(storage) does not exist, return None """ - raise NotImplementedError("Subclass of FeatureStorage must implement `start_index` method") + raise NotImplementedError( + "Subclass of FeatureStorage must implement `start_index` method" + ) @property def end_index(self) -> Union[int, None]: @@ -291,10 +323,14 @@ def end_index(self) -> Union[int, None]: If the data(storage) does not exist, return None """ - raise NotImplementedError("Subclass of FeatureStorage must implement `end_index` method") + raise NotImplementedError( + "Subclass of FeatureStorage must implement `end_index` method" + ) def clear(self) -> None: - raise NotImplementedError("Subclass of FeatureStorage must implement `clear` method") + raise NotImplementedError( + "Subclass of FeatureStorage must implement `clear` method" + ) def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None): """Write data_array to FeatureStorage starting from index. @@ -349,7 +385,9 @@ def write(self, data_array: Union[List, np.ndarray, Tuple], index: int = None): 9 8 """ - raise NotImplementedError("Subclass of FeatureStorage must implement `write` method") + raise NotImplementedError( + "Subclass of FeatureStorage must implement `write` method" + ) def rebase(self, start_index: int = None, end_index: int = None): """Rebase the start_index and end_index of the FeatureStorage. 
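As a side note on the `write` semantics touched here: when the write index is more than one past the current end, the gap is padded with NaN before the new values are appended (see the `np.hstack` line in the FileFeatureStorage.write hunk earlier in this patch). A minimal standalone sketch of that rule, with hypothetical index values:

import numpy as np

end_index = 4                       # assume the storage currently ends at index 4
index = 7                           # requested write position
data_array = np.array([1.0, 2.0])

# indices 5 and 6 are filled with NaN, the new values land at 7 and 8
padded = np.hstack([[np.nan] * (index - end_index - 1), data_array]).astype("<f")
print(padded)                       # [nan nan  1.  2.]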
@@ -411,13 +449,17 @@ def rebase(self, start_index: int = None, end_index: int = None): storage_si = self.start_index storage_ei = self.end_index if storage_si is None or storage_ei is None: - raise ValueError("storage.start_index or storage.end_index is None, storage may not exist") + raise ValueError( + "storage.start_index or storage.end_index is None, storage may not exist" + ) start_index = storage_si if start_index is None else start_index end_index = storage_ei if end_index is None else end_index if start_index is None or end_index is None: - logger.warning("both start_index and end_index are None, or storage does not exist; rebase is ignored") + logger.warning( + "both start_index and end_index are None, or storage does not exist; rebase is ignored" + ) return if start_index < 0 or end_index < 0: @@ -491,4 +533,6 @@ def __len__(self) -> int: If the data(storage) does not exist, raise ValueError """ - raise NotImplementedError("Subclass of FeatureStorage must implement `__len__` method") + raise NotImplementedError( + "Subclass of FeatureStorage must implement `__len__` method" + ) diff --git a/qlib/log.py b/qlib/log.py index f7683d5116..e49a1c4196 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -215,7 +215,9 @@ def set_global_logger_level(level: int, return_orig_handler_level: bool = False) """ _handler_level_map = {} - qlib_logger = logging.root.manager.loggerDict.get("qlib", None) # pylint: disable=E1101 + qlib_logger = logging.root.manager.loggerDict.get( + "qlib", None + ) # pylint: disable=E1101 if qlib_logger is not None: for _handler in qlib_logger.handlers: _handler_level_map[_handler] = _handler.level diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 1670a6538e..1e63e7f070 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -48,7 +48,9 @@ class SingleKeyEnsemble(Ensemble): dict: the readable dict. """ - def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object: + def __call__( + self, ensemble_dict: Union[dict, object], recursion: bool = True + ) -> object: if not isinstance(ensemble_dict, dict): return ensemble_dict if recursion: @@ -78,7 +80,9 @@ class RollingEnsemble(Ensemble): """ def __call__(self, ensemble_dict: dict) -> pd.DataFrame: - get_module_logger("RollingEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}") + get_module_logger("RollingEnsemble").info( + f"keys in group: {list(ensemble_dict.keys())}" + ) artifact_list = list(ensemble_dict.values()) artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) artifact = pd.concat(artifact_list) @@ -121,12 +125,16 @@ def __call__(self, ensemble_dict: dict) -> pd.DataFrame: """ # need to flatten the nested dict ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE) - get_module_logger("AverageEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}") + get_module_logger("AverageEnsemble").info( + f"keys in group: {list(ensemble_dict.keys())}" + ) values = list(ensemble_dict.values()) # NOTE: this may change the style underlying data!!!! 
# from pd.DataFrame to pd.Series results = pd.concat(values, axis=1) - results = results.groupby("datetime", group_keys=False).apply(lambda df: (df - df.mean()) / df.std()) + results = results.groupby("datetime", group_keys=False).apply( + lambda df: (df - df.mean()) / df.std() + ) results = results.mean(axis=1) results = results.sort_index() return results diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index ba6f9f8071..21ae7cfe5e 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -64,7 +64,9 @@ def reduce(self, *args, **kwargs) -> dict: else: raise NotImplementedError(f"Please specify valid `_ens_func`.") - def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict: + def __call__( + self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs + ) -> dict: """ Group the ungrouped_dict into different groups. diff --git a/qlib/model/interpret/base.py b/qlib/model/interpret/base.py index a490d77442..6d80f95460 100644 --- a/qlib/model/interpret/base.py +++ b/qlib/model/interpret/base.py @@ -39,7 +39,8 @@ def get_feature_importance(self, *args, **kwargs) -> pd.Series: https://lightgbm.readthedocs.io/en/latest/pythonapi/lightgbm.Booster.html?highlight=feature_importance#lightgbm.Booster.feature_importance """ return pd.Series( - self.model.feature_importance(*args, **kwargs), index=self.model.feature_name() + self.model.feature_importance(*args, **kwargs), + index=self.model.feature_name(), ).sort_values( # pylint: disable=E1101 ascending=False ) diff --git a/qlib/model/meta/dataset.py b/qlib/model/meta/dataset.py index 34a9b949b3..c8cbc3479f 100644 --- a/qlib/model/meta/dataset.py +++ b/qlib/model/meta/dataset.py @@ -36,7 +36,9 @@ def __init__(self, segments: Union[Dict[Text, Tuple], float], *args, **kwargs): super().__init__(*args, **kwargs) self.segments = segments - def prepare_tasks(self, segments: Union[List[Text], Text], *args, **kwargs) -> List[MetaTask]: + def prepare_tasks( + self, segments: Union[List[Text], Text], *args, **kwargs + ) -> List[MetaTask]: """ Prepare the data in each meta-task and ready for training. diff --git a/qlib/model/meta/task.py b/qlib/model/meta/task.py index a051acf146..82b9a77271 100644 --- a/qlib/model/meta/task.py +++ b/qlib/model/meta/task.py @@ -40,7 +40,9 @@ def __init__(self, task: dict, meta_info: object, mode: str = PROC_MODE_FULL): the input for meta model """ self.task = task - self.meta_info = meta_info # the original meta input information, it will be processed later + self.meta_info = ( + meta_info # the original meta input information, it will be processed later + ) self.mode = mode def get_dataset(self) -> Dataset: diff --git a/qlib/model/riskmodel/base.py b/qlib/model/riskmodel/base.py index 7afacfe8ff..b65e4e6857 100644 --- a/qlib/model/riskmodel/base.py +++ b/qlib/model/riskmodel/base.py @@ -19,7 +19,12 @@ class RiskModel(BaseModel): FILL_NAN = "fill" IGNORE_NAN = "ignore" - def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True): + def __init__( + self, + nan_option: str = "ignore", + assume_centered: bool = False, + scale_return: bool = True, + ): """ Args: nan_option (str): nan handling option (`ignore`/`mask`/`fill`). 
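The risk-model constructor shown here is normally used through one of the concrete estimators in this package. A hedged usage sketch (the import path, the "lw" option string, and the DataFrame return are assumptions drawn from this file set and qlib's examples; the data is synthetic):

import numpy as np
import pandas as pd
from qlib.model.riskmodel import ShrinkCovEstimator  # assumed export of this package

# synthetic daily return panel indexed by (datetime, instrument), one column
idx = pd.MultiIndex.from_product(
    [pd.date_range("2024-01-01", periods=60), ["AAA", "BBB", "CCC"]],
    names=["datetime", "instrument"],
)
returns = pd.DataFrame({"return": np.random.randn(len(idx)) * 0.01}, index=idx)

# nan_option / assume_centered / scale_return are the knobs shown in __init__ above;
# "lw" is assumed to be the Ledoit-Wolf value of SHR_LW
est = ShrinkCovEstimator(alpha="lw", nan_option="ignore")
cov = est.predict(returns)  # instrument-by-instrument covariance estimate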
@@ -65,7 +70,9 @@ def predict( else: if isinstance(X.index, pd.MultiIndex): if isinstance(X, pd.DataFrame): - X = X.iloc[:, 0].unstack(level="instrument") # always use the first column + X = X.iloc[:, 0].unstack( + level="instrument" + ) # always use the first column else: X = X.unstack(level="instrument") else: @@ -88,10 +95,13 @@ def predict( # return decomposed components if needed if return_decomposed_components: assert ( - "return_decomposed_components" in inspect.getfullargspec(self._predict).args + "return_decomposed_components" + in inspect.getfullargspec(self._predict).args ), "This risk model does not support return decomposed components of the covariance matrix " - F, cov_b, var_u = self._predict(X, return_decomposed_components=True) # pylint: disable=E1123 + F, cov_b, var_u = self._predict( + X, return_decomposed_components=True + ) # pylint: disable=E1123 return F, cov_b, var_u # estimate covariance diff --git a/qlib/model/riskmodel/poet.py b/qlib/model/riskmodel/poet.py index 42388d84cb..0f2ab53fec 100644 --- a/qlib/model/riskmodel/poet.py +++ b/qlib/model/riskmodel/poet.py @@ -16,7 +16,13 @@ class POETCovEstimator(RiskModel): THRESH_HARD = "hard" THRESH_SCAD = "scad" - def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs): + def __init__( + self, + num_factors: int = 0, + thresh: float = 1.0, + thresh_method: str = "soft", + **kwargs, + ): """ Args: num_factors (int): number of factors (if set to zero, no factor model will be used). @@ -71,8 +77,18 @@ def _predict(self, X: np.ndarray) -> np.ndarray: res = (res + np.abs(res)) / 2 M = np.sign(R) * res else: - M1 = (np.abs(R) < 2 * lamb) * np.sign(R) * (np.abs(R) - lamb) * (np.abs(R) > lamb) - M2 = (np.abs(R) < 3.7 * lamb) * (np.abs(R) >= 2 * lamb) * (2.7 * R - 3.7 * np.sign(R) * lamb) / 1.7 + M1 = ( + (np.abs(R) < 2 * lamb) + * np.sign(R) + * (np.abs(R) - lamb) + * (np.abs(R) > lamb) + ) + M2 = ( + (np.abs(R) < 3.7 * lamb) + * (np.abs(R) >= 2 * lamb) + * (2.7 * R - 3.7 * np.sign(R) * lamb) + / 1.7 + ) M3 = (np.abs(R) >= 3.7 * lamb) * R M = M1 + M2 + M3 diff --git a/qlib/model/riskmodel/shrink.py b/qlib/model/riskmodel/shrink.py index c3c0e48ef8..c305bd555a 100644 --- a/qlib/model/riskmodel/shrink.py +++ b/qlib/model/riskmodel/shrink.py @@ -51,7 +51,12 @@ class ShrinkCovEstimator(RiskModel): TGT_CONST_CORR = "const_corr" TGT_SINGLE_FACTOR = "single_factor" - def __init__(self, alpha: Union[str, float] = 0.0, target: Union[str, np.ndarray] = "const_var", **kwargs): + def __init__( + self, + alpha: Union[str, float] = 0.0, + target: Union[str, np.ndarray] = "const_var", + **kwargs, + ): """ Args: alpha (str or float): shrinking parameter or estimator (`lw`/`oas`) @@ -62,7 +67,10 @@ def __init__(self, alpha: Union[str, float] = 0.0, target: Union[str, np.ndarray # alpha if isinstance(alpha, str): - assert alpha in [self.SHR_LW, self.SHR_OAS], f"shrinking method `{alpha}` is not supported" + assert alpha in [ + self.SHR_LW, + self.SHR_OAS, + ], f"shrinking method `{alpha}` is not supported" elif isinstance(alpha, (float, np.floating)): assert 0 <= alpha <= 1, "alpha should be between [0, 1]" else: @@ -81,7 +89,9 @@ def __init__(self, alpha: Union[str, float] = 0.0, target: Union[str, np.ndarray else: raise TypeError("invalid argument type for `target`") if alpha == self.SHR_OAS and target != self.TGT_CONST_VAR: - raise NotImplementedError("currently `oas` can only support `const_var` as target") + raise NotImplementedError( + "currently `oas` can only support `const_var` as target" + ) 
self.target = target def _predict(self, X: np.ndarray) -> np.ndarray: @@ -138,7 +148,9 @@ def _get_shrink_target_const_corr(self, X: np.ndarray, S: np.ndarray) -> np.ndar np.fill_diagonal(F, var) return F - def _get_shrink_target_single_factor(self, X: np.ndarray, S: np.ndarray) -> np.ndarray: + def _get_shrink_target_single_factor( + self, X: np.ndarray, S: np.ndarray + ) -> np.ndarray: """get shrinking target with single factor model""" X_mkt = np.nanmean(X, axis=1) cov_mkt = np.asarray(X.T.dot(X_mkt) / len(X)) @@ -164,7 +176,9 @@ def _get_shrink_param(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> floa return self._get_shrink_param_lw_single_factor(X, S, F) return self.alpha - def _get_shrink_param_oas(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float: + def _get_shrink_param_oas( + self, X: np.ndarray, S: np.ndarray, F: np.ndarray + ) -> float: """Oracle Approximating Shrinkage Estimator This method uses the following formula to estimate the `alpha` @@ -185,7 +199,9 @@ def _get_shrink_param_oas(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> return alpha - def _get_shrink_param_lw_const_var(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float: + def _get_shrink_param_lw_const_var( + self, X: np.ndarray, S: np.ndarray, F: np.ndarray + ) -> float: """Ledoit-Wolf Shrinkage Estimator (Constant Variance) This method shrinks the covariance matrix towards the constand variance target. @@ -202,7 +218,9 @@ def _get_shrink_param_lw_const_var(self, X: np.ndarray, S: np.ndarray, F: np.nda return alpha - def _get_shrink_param_lw_const_corr(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float: + def _get_shrink_param_lw_const_corr( + self, X: np.ndarray, S: np.ndarray, F: np.ndarray + ) -> float: """Ledoit-Wolf Shrinkage Estimator (Constant Correlation) This method shrinks the covariance matrix towards the constand correlation target. @@ -219,7 +237,9 @@ def _get_shrink_param_lw_const_corr(self, X: np.ndarray, S: np.ndarray, F: np.nd theta_mat = (X**3).T.dot(X) / t - var[:, None] * S np.fill_diagonal(theta_mat, 0) - rho = np.sum(np.diag(phi_mat)) + r_bar * np.sum(np.outer(1 / sqrt_var, sqrt_var) * theta_mat) + rho = np.sum(np.diag(phi_mat)) + r_bar * np.sum( + np.outer(1 / sqrt_var, sqrt_var) * theta_mat + ) gamma = np.linalg.norm(S - F, "fro") ** 2 @@ -228,7 +248,9 @@ def _get_shrink_param_lw_const_corr(self, X: np.ndarray, S: np.ndarray, F: np.nd return alpha - def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np.ndarray) -> float: + def _get_shrink_param_lw_single_factor( + self, X: np.ndarray, S: np.ndarray, F: np.ndarray + ) -> float: """Ledoit-Wolf Shrinkage Estimator (Single Factor Model) This method shrinks the covariance matrix towards the single factor model target. 
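All of the _get_shrink_target_* / _get_shrink_param_* helpers in this file feed a single combination step: the estimate is a convex blend of the sample covariance and the structured target. A rough numpy sketch of that step (standard Ledoit-Wolf form, not qlib's exact code path):

import numpy as np

def shrink_cov(X: np.ndarray, F: np.ndarray, alpha: float) -> np.ndarray:
    # X: observations in rows; F: shrinkage target; alpha: intensity in [0, 1]
    S = np.cov(X, rowvar=False)          # sample covariance
    return alpha * F + (1 - alpha) * S   # shrink S towards F

X = np.random.randn(200, 5) * 0.01
F = np.mean(np.var(X, axis=0)) * np.eye(5)  # constant-variance style target, for illustration
cov = shrink_cov(X, F, alpha=0.2)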
@@ -245,9 +267,15 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np rdiag = np.sum(y**2) / t - np.sum(np.diag(S) ** 2) z = X * X_mkt[:, None] v1 = y.T.dot(z) / t - cov_mkt[:, None] * S - roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt + roff1 = ( + np.sum(v1 * cov_mkt[:, None].T) / var_mkt + - np.sum(np.diag(v1) * cov_mkt) / var_mkt + ) v3 = z.T.dot(z) / t - var_mkt * S - roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 + roff3 = ( + np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 + - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2 + ) roff = 2 * roff1 - roff3 rho = rdiag + roff diff --git a/qlib/model/riskmodel/structured.py b/qlib/model/riskmodel/structured.py index 71e442536d..9bc067366e 100644 --- a/qlib/model/riskmodel/structured.py +++ b/qlib/model/riskmodel/structured.py @@ -50,9 +50,9 @@ def __init__(self, factor_model: str = "pca", num_factors: int = 10, **kwargs): kwargs: see `RiskModel` for more information """ if "nan_option" in kwargs: - assert kwargs["nan_option"] in [self.DEFAULT_NAN_OPTION], "nan_option={} is not supported".format( - kwargs["nan_option"] - ) + assert kwargs["nan_option"] in [ + self.DEFAULT_NAN_OPTION + ], "nan_option={} is not supported".format(kwargs["nan_option"]) else: kwargs["nan_option"] = self.DEFAULT_NAN_OPTION @@ -66,7 +66,9 @@ def __init__(self, factor_model: str = "pca", num_factors: int = 10, **kwargs): self.num_factors = num_factors - def _predict(self, X: np.ndarray, return_decomposed_components=False) -> Union[np.ndarray, tuple]: + def _predict( + self, X: np.ndarray, return_decomposed_components=False + ) -> Union[np.ndarray, tuple]: """ covariance estimation implementation diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index ce204420f8..444ee40ee4 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -43,7 +43,9 @@ def _exe_task(task_config: dict): rec = R.get_recorder() # model & dataset initialization model: Model = init_instance_by_config(task_config["model"], accept_types=Model) - dataset: Dataset = init_instance_by_config(task_config["dataset"], accept_types=Dataset) + dataset: Dataset = init_instance_by_config( + task_config["dataset"], accept_types=Dataset + ) reweighter: Reweighter = task_config.get("reweighter", None) # model training auto_filter_kwargs(model.fit)(dataset, reweighter=reweighter) @@ -71,7 +73,9 @@ def _exe_task(task_config: dict): r.generate() -def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder: +def begin_task_train( + task_config: dict, experiment_name: str, recorder_name: str = None +) -> Recorder: """ Begin task training to start a recorder and save the task config. @@ -99,13 +103,17 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: Returns: Recorder: the model recorder """ - with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True): + with R.start( + experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True + ): task_config = R.load_object("task") _exe_task(task_config) return rec -def task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder: +def task_train( + task_config: dict, experiment_name: str, recorder_name: str = None +) -> Recorder: """ Task based training, will be divided into two steps. 
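The two steps mentioned in that docstring map onto begin_task_train and end_task_train (both reformatted in this file), while task_train runs them back to back. A hedged usage sketch; task_config stands in for a full qlib task dict and qlib.init(...) is assumed to have been called already:

from qlib.model.trainer import begin_task_train, end_task_train, task_train

task_config = {
    "model": {...},    # placeholder, e.g. an LGBModel config
    "dataset": {...},  # placeholder, e.g. a DatasetH + handler config
}

# one-shot path
rec = task_train(task_config, experiment_name="demo_exp")

# or split it: start the recorder and save the config now, fit later
rec = begin_task_train(task_config, experiment_name="demo_exp")
rec = end_task_train(rec, experiment_name="demo_exp")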
@@ -241,7 +249,11 @@ def __init__( self._call_in_subproc = call_in_subproc def train( - self, tasks: list, train_func: Optional[Callable] = None, experiment_name: Optional[str] = None, **kwargs + self, + tasks: list, + train_func: Optional[Callable] = None, + experiment_name: Optional[str] = None, + **kwargs, ) -> List[Recorder]: """ Given a list of `tasks` and return a list of trained Recorder. The order can be guaranteed. @@ -266,9 +278,13 @@ def train( recs = [] for task in tqdm(tasks, desc="train tasks"): if self._call_in_subproc: - get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).") + get_module_logger("TrainerR").info( + "running models in sub process (for forcing release memroy)." + ) train_func = call_in_subproc(train_func, C) - rec = train_func(task, experiment_name, recorder_name=self.default_rec_name, **kwargs) + rec = train_func( + task, experiment_name, recorder_name=self.default_rec_name, **kwargs + ) rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN}) recs.append(rec) return recs @@ -296,7 +312,11 @@ class DelayTrainerR(TrainerR): """ def __init__( - self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train, **kwargs + self, + experiment_name: str = None, + train_func=begin_task_train, + end_train_func=end_task_train, + **kwargs, ): """ Init TrainerRM. @@ -310,7 +330,9 @@ def __init__( self.end_train_func = end_train_func self.delay = True - def end_train(self, models, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + def end_train( + self, models, end_train_func=None, experiment_name: str = None, **kwargs + ) -> List[Recorder]: """ Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. @@ -521,7 +543,9 @@ def __init__( self.delay = True self.skip_run_task = skip_run_task - def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + def train( + self, tasks: list, train_func=None, experiment_name: str = None, **kwargs + ) -> List[Recorder]: """ Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE. @@ -549,7 +573,9 @@ def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwa self.skip_run_task = _skip_run_task return res - def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]: + def end_train( + self, recs, end_train_func=None, experiment_name: str = None, **kwargs + ) -> List[Recorder]: """ Given a list of Recorder and return a list of trained Recorder. This class will finish real data loading and model fitting. 
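For lists of tasks the same flow goes through the trainers reformatted above: TrainerR fits each task as it iterates (optionally in a subprocess via the call_in_subproc flag shown in its __init__), while DelayTrainerR only starts recorders in train() and defers fitting to end_train(). A brief sketch under the same assumptions as the previous one:

from qlib.model.trainer import TrainerR, DelayTrainerR

tasks = [task_config]  # task dicts as in the sketch above

recs = TrainerR(experiment_name="demo_exp").train(tasks)

delayed = DelayTrainerR(experiment_name="demo_exp")
started = delayed.train(tasks)          # recorders created, configs saved
finished = delayed.end_train(started)   # actual data loading and model fitting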
diff --git a/qlib/rl/__init__.py b/qlib/rl/__init__.py index a12afc3996..9684cc9d52 100644 --- a/qlib/rl/__init__.py +++ b/qlib/rl/__init__.py @@ -5,4 +5,11 @@ from .reward import Reward, RewardCombination from .simulator import Simulator -__all__ = ["Interpreter", "StateInterpreter", "ActionInterpreter", "Reward", "RewardCombination", "Simulator"] +__all__ = [ + "Interpreter", + "StateInterpreter", + "ActionInterpreter", + "Reward", + "RewardCombination", + "Simulator", +] diff --git a/qlib/rl/contrib/backtest.py b/qlib/rl/contrib/backtest.py index 60602c10d3..af70026381 100644 --- a/qlib/rl/contrib/backtest.py +++ b/qlib/rl/contrib/backtest.py @@ -38,7 +38,11 @@ def _get_multi_level_executor_config( "kwargs": { "time_per_step": data_granularity, "verbose": False, - "trade_type": SimulatorExecutor.TT_PARAL if cash_limit is not None else SimulatorExecutor.TT_SERIAL, + "trade_type": ( + SimulatorExecutor.TT_PARAL + if cash_limit is not None + else SimulatorExecutor.TT_SERIAL + ), "generate_report": generate_report, "track_data": True, }, @@ -80,7 +84,9 @@ def _convert_indicator_to_dataframe(indicator: dict) -> Optional[pd.DataFrame]: if not record_list: return None - records: pd.DataFrame = pd.concat(record_list, 0).reset_index().rename(columns={"index": "instrument"}) + records: pd.DataFrame = ( + pd.concat(record_list, 0).reset_index().rename(columns={"index": "instrument"}) + ) records = records.set_index(["instrument", "datetime"]) return records @@ -110,11 +116,17 @@ def _generate_report( indicator_his[key].append(indicator_obj.order_indicator_his) report = {} - decision_details = pd.concat([getattr(d, "details") for d in decisions if hasattr(d, "details")]) + decision_details = pd.concat( + [getattr(d, "details") for d in decisions if hasattr(d, "details")] + ) for key in indicator_dict: cur_dict = pd.concat(indicator_dict[key]) - cur_his = pd.concat([_convert_indicator_to_dataframe(his) for his in indicator_his[key]]) - cur_details = decision_details[decision_details.freq == key].set_index(["instrument", "datetime"]) + cur_his = pd.concat( + [_convert_indicator_to_dataframe(his) for his in indicator_his[key]] + ) + cur_details = decision_details[decision_details.freq == key].set_index( + ["instrument", "datetime"] + ) if len(cur_details) > 0: cur_details.pop("freq") cur_his = cur_his.join(cur_details, how="outer") @@ -163,8 +175,12 @@ def single_with_simulator( decisions = [] for _, row in orders.iterrows(): date = pd.Timestamp(row["datetime"]) - start_time = pd.Timestamp(backtest_config["start_time"]).replace(year=date.year, month=date.month, day=date.day) - end_time = pd.Timestamp(backtest_config["end_time"]).replace(year=date.year, month=date.month, day=date.day) + start_time = pd.Timestamp(backtest_config["start_time"]).replace( + year=date.year, month=date.month, day=date.day + ) + end_time = pd.Timestamp(backtest_config["end_time"]).replace( + year=date.year, month=date.month, day=date.day + ) order = Order( stock_id=row["instrument"], amount=row["amount"], @@ -200,12 +216,16 @@ def single_with_simulator( decisions += simulator.decisions indicator_1day_objs = [report["indicator_dict"]["1day"][1] for report in reports] - indicator_info = {k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items()} + indicator_info = { + k: v for obj in indicator_1day_objs for k, v in obj.order_indicator_his.items() + } records = _convert_indicator_to_dataframe(indicator_info) assert records is None or not np.isnan(records["ffr"]).any() if generate_report: - _report = 
_generate_report(decisions, [report["indicator"] for report in reports]) + _report = _generate_report( + decisions, [report["indicator"] for report in reports] + ) if split == "stock": stock_id = orders.iloc[0].instrument @@ -295,10 +315,16 @@ def single_with_collect_data_loop( ) report_dict: dict = {} - decisions = list(collect_data_loop(trade_start_time, trade_end_time, strategy, executor, report_dict)) + decisions = list( + collect_data_loop( + trade_start_time, trade_end_time, strategy, executor, report_dict + ) + ) indicator_dict = cast(INDICATOR_METRIC, report_dict.get("indicator_dict")) - records = _convert_indicator_to_dataframe(indicator_dict["1day"][1].order_indicator_his) + records = _convert_indicator_to_dataframe( + indicator_dict["1day"][1].order_indicator_his + ) assert records is None or not np.isnan(records["ffr"]).any() if generate_report: @@ -324,7 +350,11 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram stock_pool.sort() single = single_with_simulator if with_simulator else single_with_collect_data_loop - mp_config = {"n_jobs": backtest_config["concurrency"], "verbose": 10, "backend": "multiprocessing"} + mp_config = { + "n_jobs": backtest_config["concurrency"], + "verbose": 10, + "backend": "multiprocessing", + } torch.set_num_threads(1) # https://github.com/pytorch/pytorch/issues/17199 res = Parallel(**mp_config)( delayed(single)( @@ -364,8 +394,14 @@ def backtest(backtest_config: dict, with_simulator: bool = False) -> pd.DataFram warnings.filterwarnings("ignore", category=RuntimeWarning) parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") - parser.add_argument("--use_simulator", action="store_true", help="Whether to use simulator as the backend") + parser.add_argument( + "--config_path", type=str, required=True, help="Path to the config file" + ) + parser.add_argument( + "--use_simulator", + action="store_true", + help="Whether to use simulator as the backend", + ) parser.add_argument( "--n_jobs", type=int, diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index 5608cbd1ef..d81febcaff 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -38,7 +38,9 @@ def parse_backtest_config(path: str) -> dict: raise IOError("Only py/yml/yaml/json type are supported now!") with tempfile.TemporaryDirectory() as tmp_config_dir: - with tempfile.NamedTemporaryFile(dir=tmp_config_dir, suffix=file_ext_name) as tmp_config_file: + with tempfile.NamedTemporaryFile( + dir=tmp_config_dir, suffix=file_ext_name + ) as tmp_config_file: if platform.system() == "Windows": tmp_config_file.close() @@ -51,7 +53,9 @@ def parse_backtest_config(path: str) -> dict: module = import_module(tmp_module_name) sys.path.pop(0) - config = {k: v for k, v in module.__dict__.items() if not k.startswith("__")} + config = { + k: v for k, v in module.__dict__.items() if not k.startswith("__") + } del sys.modules[tmp_module_name] else: @@ -65,7 +69,9 @@ def parse_backtest_config(path: str) -> dict: base_file_name = [base_file_name] for f in base_file_name: - base_config = parse_backtest_config(os.path.join(os.path.dirname(abs_path), f)) + base_config = parse_backtest_config( + os.path.join(os.path.dirname(abs_path), f) + ) config = merge_a_into_b(a=config, b=base_config) return config @@ -90,8 +96,12 @@ def get_backtest_config_fromfile(path: str) -> dict: "trade_unit": 100.0, "cash_limit": None, } - 
backtest_config["exchange"] = merge_a_into_b(a=backtest_config["exchange"], b=exchange_config_default) - backtest_config["exchange"] = _convert_all_list_to_tuple(backtest_config["exchange"]) + backtest_config["exchange"] = merge_a_into_b( + a=backtest_config["exchange"], b=exchange_config_default + ) + backtest_config["exchange"] = _convert_all_list_to_tuple( + backtest_config["exchange"] + ) backtest_config_default = { "debug_single_stock": None, diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py index 83dd924103..0e32caef70 100644 --- a/qlib/rl/contrib/train_onpolicy.py +++ b/qlib/rl/contrib/train_onpolicy.py @@ -91,7 +91,9 @@ def __getitem__(self, index: int) -> Order: amount=row["amount"], direction=OrderDir(int(row["order_type"])), start_time=date + self._ticks_index[self._default_start_time_index], - end_time=date + self._ticks_index[self._default_end_time_index - 1] + ONE_MIN, + end_time=date + + self._ticks_index[self._default_end_time_index - 1] + + ONE_MIN, ) return order @@ -118,7 +120,9 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: order=order, data_dir=data_config["source"]["feature_root_dir"], feature_columns_today=data_config["source"]["feature_columns_today"], - feature_columns_yesterday=data_config["source"]["feature_columns_yesterday"], + feature_columns_yesterday=data_config["source"][ + "feature_columns_yesterday" + ], data_granularity=data_granularity, ticks_per_step=simulator_config["time_per_step"], vol_threshold=simulator_config["vol_limit"], @@ -132,15 +136,21 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: LazyLoadDataset( data_dir=data_config["source"]["feature_root_dir"], order_file_path=order_root_path / tag, - default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, - default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + default_start_time_index=data_config["source"][ + "default_start_time_index" + ] + // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] + // data_granularity, ) for tag in ("train", "valid") ] callbacks: List[Callback] = [] if "checkpoint_path" in trainer_config: - callbacks.append(MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"]))) + callbacks.append( + MetricsWriter(dirpath=Path(trainer_config["checkpoint_path"])) + ) callbacks.append( Checkpoint( dirpath=Path(trainer_config["checkpoint_path"]) / "checkpoints", @@ -184,8 +194,10 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: test_dataset = LazyLoadDataset( data_dir=data_config["source"]["feature_root_dir"], order_file_path=order_root_path / "test", - default_start_time_index=data_config["source"]["default_start_time_index"] // data_granularity, - default_end_time_index=data_config["source"]["default_end_time_index"] // data_granularity, + default_start_time_index=data_config["source"]["default_start_time_index"] + // data_granularity, + default_end_time_index=data_config["source"]["default_end_time_index"] + // data_granularity, ) backtest( @@ -203,7 +215,9 @@ def _simulator_factory_simple(order: Order) -> SingleAssetOrderExecutionSimple: def main(config: dict, run_training: bool, run_backtest: bool) -> None: if not run_training and not run_backtest: - warnings.warn("Skip the entire job since training and backtest are both skipped.") + warnings.warn( + "Skip the entire job since training and backtest are both skipped." 
+ ) return if "seed" in config["runtime"]: @@ -212,8 +226,12 @@ def main(config: dict, run_training: bool, run_backtest: bool) -> None: for extra_module_path in config["env"].get("extra_module_paths", []): sys.path.append(extra_module_path) - state_interpreter: StateInterpreter = init_instance_by_config(config["state_interpreter"]) - action_interpreter: ActionInterpreter = init_instance_by_config(config["action_interpreter"]) + state_interpreter: StateInterpreter = init_instance_by_config( + config["state_interpreter"] + ) + action_interpreter: ActionInterpreter = init_instance_by_config( + config["action_interpreter"] + ) reward: Reward = init_instance_by_config(config["reward"]) additional_policy_kwargs = { @@ -225,7 +243,9 @@ def main(config: dict, run_training: bool, run_backtest: bool) -> None: if "network" in config: if "kwargs" not in config["network"]: config["network"]["kwargs"] = {} - config["network"]["kwargs"].update({"obs_space": state_interpreter.observation_space}) + config["network"]["kwargs"].update( + {"obs_space": state_interpreter.observation_space} + ) additional_policy_kwargs["network"] = init_instance_by_config(config["network"]) # Create policy @@ -257,9 +277,15 @@ def main(config: dict, run_training: bool, run_backtest: bool) -> None: warnings.filterwarnings("ignore", category=RuntimeWarning) parser = argparse.ArgumentParser() - parser.add_argument("--config_path", type=str, required=True, help="Path to the config file") - parser.add_argument("--no_training", action="store_true", help="Skip training workflow.") - parser.add_argument("--run_backtest", action="store_true", help="Run backtest workflow.") + parser.add_argument( + "--config_path", type=str, required=True, help="Path to the config file" + ) + parser.add_argument( + "--no_training", action="store_true", help="Skip training workflow." + ) + parser.add_argument( + "--run_backtest", action="store_true", help="Run backtest workflow." 
+ ) args = parser.parse_args() with open(args.config_path, "r") as input_stream: diff --git a/qlib/rl/contrib/utils.py b/qlib/rl/contrib/utils.py index cad25e0dba..a73f77f7a0 100644 --- a/qlib/rl/contrib/utils.py +++ b/qlib/rl/contrib/utils.py @@ -23,7 +23,9 @@ def read_order_file(order_file: Path | pd.DataFrame) -> pd.DataFrame: if "date" in order_df.columns: # legacy dataframe columns - order_df = order_df.rename(columns={"date": "datetime", "order_type": "direction"}) + order_df = order_df.rename( + columns={"date": "datetime", "order_type": "direction"} + ) order_df["datetime"] = order_df["datetime"].astype(str) return order_df diff --git a/qlib/rl/data/integration.py b/qlib/rl/data/integration.py index e123b6c8cf..a03184b703 100644 --- a/qlib/rl/data/integration.py +++ b/qlib/rl/data/integration.py @@ -12,7 +12,17 @@ import qlib from qlib.constant import REG_CN -from qlib.contrib.ops.high_freq import BFillNan, Cut, Date, DayCumsum, DayLast, FFillNan, IsInf, IsNull, Select +from qlib.contrib.ops.high_freq import ( + BFillNan, + Cut, + Date, + DayCumsum, + DayLast, + FFillNan, + IsInf, + IsNull, + Select, +) def init_qlib(qlib_config: dict) -> None: @@ -46,12 +56,24 @@ def _convert_to_path(path: str | Path) -> Path: provider_uri_map = {} for granularity in ["1min", "5min", "day"]: if f"provider_uri_{granularity}" in qlib_config: - provider_uri_map[f"{granularity}"] = _convert_to_path(qlib_config[f"provider_uri_{granularity}"]).as_posix() + provider_uri_map[f"{granularity}"] = _convert_to_path( + qlib_config[f"provider_uri_{granularity}"] + ).as_posix() qlib.init( region=REG_CN, auto_mount=False, - custom_ops=[DayLast, FFillNan, BFillNan, Date, Select, IsNull, IsInf, Cut, DayCumsum], + custom_ops=[ + DayLast, + FFillNan, + BFillNan, + Date, + Select, + IsNull, + IsInf, + Cut, + DayCumsum, + ], expression_cache=None, calendar_provider={ "class": "LocalCalendarProvider", diff --git a/qlib/rl/data/native.py b/qlib/rl/data/native.py index ceb5408829..531956006b 100644 --- a/qlib/rl/data/native.py +++ b/qlib/rl/data/native.py @@ -13,7 +13,11 @@ from qlib.backtest import Exchange, Order from qlib.backtest.decision import TradeRange, TradeRangeByTime from qlib.constant import EPS_T -from .base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider +from .base import ( + BaseIntradayBacktestData, + BaseIntradayProcessedData, + ProcessedDataProvider, +) def get_ticks_slice( @@ -86,13 +90,25 @@ def get_time_index(self) -> pd.DatetimeIndex: class DataframeIntradayBacktestData(BaseIntradayBacktestData): """Backtest data from dataframe""" - def __init__(self, df: pd.DataFrame, price_column: str = "$close0", volume_column: str = "$volume0") -> None: + def __init__( + self, + df: pd.DataFrame, + price_column: str = "$close0", + volume_column: str = "$volume0", + ) -> None: self.df = df self.price_column = price_column self.volume_column = volume_column def __repr__(self) -> str: - with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + with pd.option_context( + "memory_usage", + False, + "display.max_info_columns", + 1, + "display.large_repr", + "info", + ): return f"{self.__class__.__name__}({self.df})" def __len__(self) -> int: @@ -159,11 +175,17 @@ def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame: df = df.drop(columns=["instrument"]) return df.set_index(["datetime"]) - path = os.path.join(data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl") - start_time, end_time = date.replace(hour=0, minute=0, 
second=0), date.replace(hour=23, minute=59, second=59) + path = os.path.join( + data_dir, "backtest" if backtest else "feature", f"{stock_id}.pkl" + ) + start_time, end_time = date.replace(hour=0, minute=0, second=0), date.replace( + hour=23, minute=59, second=59 + ) with open(path, "rb") as fstream: dataset = pickle.load(fstream) - data = dataset.handler.fetch(pd.IndexSlice[stock_id, start_time:end_time], level=None) + data = dataset.handler.fetch( + pd.IndexSlice[stock_id, start_time:end_time], level=None + ) if index_only: self.today = _drop_stock_id(data[[]]) @@ -173,7 +195,14 @@ def _drop_stock_id(df: pd.DataFrame) -> pd.DataFrame: self.yesterday = _drop_stock_id(data[feature_columns_yesterday]) def __repr__(self) -> str: - with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + with pd.option_context( + "memory_usage", + False, + "display.max_info_columns", + 1, + "display.large_repr", + "info", + ): return f"{self.__class__.__name__}({self.today}, {self.yesterday})" @@ -196,7 +225,13 @@ def load_handler_intraday_processed_data( index_only: bool = False, ) -> HandlerIntradayProcessedData: return HandlerIntradayProcessedData( - data_dir, stock_id, date, feature_columns_today, feature_columns_yesterday, backtest, index_only + data_dir, + stock_id, + date, + feature_columns_today, + feature_columns_yesterday, + backtest, + index_only, ) diff --git a/qlib/rl/data/pickle_styled.py b/qlib/rl/data/pickle_styled.py index 4905b026a2..7e97053239 100644 --- a/qlib/rl/data/pickle_styled.py +++ b/qlib/rl/data/pickle_styled.py @@ -29,7 +29,11 @@ from cachetools.keys import hashkey from qlib.backtest.decision import Order, OrderDir -from qlib.rl.data.base import BaseIntradayBacktestData, BaseIntradayProcessedData, ProcessedDataProvider +from qlib.rl.data.base import ( + BaseIntradayBacktestData, + BaseIntradayProcessedData, + ProcessedDataProvider, +) from qlib.typehint import Literal DealPriceType = Literal["bid_or_ask", "bid_or_ask_fill", "close"] @@ -75,9 +79,13 @@ def _find_pickle(filename_without_suffix: Path) -> Path: if path.exists(): paths.append(path) if not paths: - raise FileNotFoundError(f"No file starting with '{filename_without_suffix}' found") + raise FileNotFoundError( + f"No file starting with '{filename_without_suffix}' found" + ) if len(paths) > 1: - raise ValueError(f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}") + raise ValueError( + f"Multiple paths are found with prefix '{filename_without_suffix}': {paths}" + ) return paths[0] @@ -108,7 +116,9 @@ def __init__( ) -> None: super(SimpleIntradayBacktestData, self).__init__() - backtest = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) + backtest = _read_pickle( + (data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id + ) backtest = backtest.loc[pd.IndexSlice[stock_id, :, date]] # No longer need for pandas >= 1.4 @@ -119,7 +129,14 @@ def __init__( self.order_dir = order_dir def __repr__(self) -> str: - with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + with pd.option_context( + "memory_usage", + False, + "display.max_info_columns", + 1, + "display.large_repr", + "info", + ): return f"{self.__class__.__name__}({self.data})" def __len__(self) -> int: @@ -130,7 +147,9 @@ def get_deal_price(self) -> pd.Series: See :attribute:`DealPriceType` for details.""" if self.deal_price_type in ("bid_or_ask", "bid_or_ask_fill"): if self.order_dir is 
None: - raise ValueError("Order direction cannot be none when deal_price_type is not close.") + raise ValueError( + "Order direction cannot be none when deal_price_type is not close." + ) if self.order_dir == OrderDir.SELL: col = "$bid0" else: # BUY @@ -169,7 +188,9 @@ def __init__( feature_dim: int, time_index: pd.Index, ) -> None: - proc = _read_pickle((data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id) + proc = _read_pickle( + (data_dir if isinstance(data_dir, Path) else Path(data_dir)) / stock_id + ) # We have to infer the names here because, # unfortunately they are not included in the original data. @@ -182,15 +203,23 @@ def __init__( proc = proc.loc[pd.IndexSlice[stock_id, :, date]] assert len(proc) == time_length and len(proc.columns) == feature_dim * 2 proc_today = proc[cnames] - proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename(columns=lambda c: c[:-2]) + proc_yesterday = proc[[f"{c}_1" for c in cnames]].rename( + columns=lambda c: c[:-2] + ) except (IndexError, KeyError): # legacy data proc = proc.loc[pd.IndexSlice[stock_id, date]] assert time_length * feature_dim * 2 == len(proc) - proc_today = proc.to_numpy()[: time_length * feature_dim].reshape((time_length, feature_dim)) - proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape((time_length, feature_dim)) + proc_today = proc.to_numpy()[: time_length * feature_dim].reshape( + (time_length, feature_dim) + ) + proc_yesterday = proc.to_numpy()[time_length * feature_dim :].reshape( + (time_length, feature_dim) + ) proc_today = pd.DataFrame(proc_today, index=time_index, columns=cnames) - proc_yesterday = pd.DataFrame(proc_yesterday, index=time_index, columns=cnames) + proc_yesterday = pd.DataFrame( + proc_yesterday, index=time_index, columns=cnames + ) self.today: pd.DataFrame = proc_today self.yesterday: pd.DataFrame = proc_yesterday @@ -198,7 +227,14 @@ def __init__( assert len(self.today) == len(self.yesterday) == time_length def __repr__(self) -> str: - with pd.option_context("memory_usage", False, "display.max_info_columns", 1, "display.large_repr", "info"): + with pd.option_context( + "memory_usage", + False, + "display.max_info_columns", + 1, + "display.large_repr", + "info", + ): return f"{self.__class__.__name__}({self.today}, {self.yesterday})" @@ -215,7 +251,9 @@ def load_simple_intraday_backtest_data( @cachetools.cached( # type: ignore cache=cachetools.LRUCache(100), # 100 * 50K = 5MB - key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey(data_dir, stock_id, date), + key=lambda data_dir, stock_id, date, feature_dim, time_index: hashkey( + data_dir, stock_id, date + ), ) def load_pickle_intraday_processed_data( data_dir: Path, @@ -224,7 +262,9 @@ def load_pickle_intraday_processed_data( feature_dim: int, time_index: pd.Index, ) -> BaseIntradayProcessedData: - return PickleIntradayProcessedData(data_dir, stock_id, date, feature_dim, time_index) + return PickleIntradayProcessedData( + data_dir, stock_id, date, feature_dim, time_index + ) class PickleProcessedDataProvider(ProcessedDataProvider): @@ -288,8 +328,14 @@ def load_orders( row["instrument"], row["amount"], OrderDir(int(row["order_type"])), - row["datetime"].replace(hour=start_time.hour, minute=start_time.minute, second=start_time.second), - row["datetime"].replace(hour=end_time.hour, minute=end_time.minute, second=end_time.second), + row["datetime"].replace( + hour=start_time.hour, + minute=start_time.minute, + second=start_time.second, + ), + row["datetime"].replace( + hour=end_time.hour, 
minute=end_time.minute, second=end_time.second + ), ), ) diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py index 5c9cc26c4e..8bb320e4e3 100644 --- a/qlib/rl/interpreter.py +++ b/qlib/rl/interpreter.py @@ -5,9 +5,9 @@ from typing import Any, Generic, TypeVar -import gym +import gymnasium as gymnasium import numpy as np -from gym import spaces +from gymnasium import spaces from qlib.typehint import final from .simulator import ActType, StateType @@ -36,7 +36,7 @@ class StateInterpreter(Generic[StateType, ObsType], Interpreter): """State Interpreter that interpret execution result of qlib executor into rl env state""" @property - def observation_space(self) -> gym.Space: + def observation_space(self) -> gymnasium.Space: raise NotImplementedError() @final # no overridden @@ -47,7 +47,7 @@ def __call__(self, simulator_state: StateType) -> ObsType: def validate(self, obs: ObsType) -> None: """Validate whether an observation belongs to the pre-defined observation space.""" - _gym_space_contains(self.observation_space, obs) + _gymnasium_space_contains(self.observation_space, obs) def interpret(self, simulator_state: StateType) -> ObsType: """Interpret the state of simulator. @@ -68,7 +68,7 @@ class ActionInterpreter(Generic[StateType, PolicyActType, ActType], Interpreter) """Action Interpreter that interpret rl agent action into qlib orders""" @property - def action_space(self) -> gym.Space: + def action_space(self) -> gymnasium.Space: raise NotImplementedError() @final # no overridden @@ -79,7 +79,7 @@ def __call__(self, simulator_state: StateType, action: PolicyActType) -> ActType def validate(self, action: PolicyActType) -> None: """Validate whether an action belongs to the pre-defined action space.""" - _gym_space_contains(self.action_space, action) + _gymnasium_space_contains(self.action_space, action) def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActType: """Convert the policy action to simulator action. @@ -98,41 +98,76 @@ def interpret(self, simulator_state: StateType, action: PolicyActType) -> ActTyp raise NotImplementedError("interpret is not implemented!") -def _gym_space_contains(space: gym.Space, x: Any) -> None: - """Strengthened version of gym.Space.contains. +def _gymnasium_space_contains(space: gymnasium.Space, x: Any) -> None: + """Strengthened version of gymnasium.Space.contains. Giving more diagnostic information on why validation fails. Throw exception rather than returning true or false. 
""" if isinstance(space, spaces.Dict): if not isinstance(x, dict) or len(x) != len(space): - raise GymSpaceValidationError("Sample must be a dict with same length as space.", space, x) +<<<<<<< HEAD + raise gymnasiumSpaceValidationError("Sample must be a dict with same length as space.", space, x) for k, subspace in space.spaces.items(): if k not in x: - raise GymSpaceValidationError(f"Key {k} not found in sample.", space, x) + raise gymnasiumSpaceValidationError(f"Key {k} not found in sample.", space, x) try: - _gym_space_contains(subspace, x[k]) - except GymSpaceValidationError as e: - raise GymSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e + _gymnasium_space_contains(subspace, x[k]) + except gymnasiumSpaceValidationError as e: + raise gymnasiumSpaceValidationError(f"Subspace of key {k} validation error.", space, x) from e +======= + raise gymnasiumSpaceValidationError( + "Sample must be a dict with same length as space.", space, x + ) + for k, subspace in space.spaces.items(): + if k not in x: + raise gymnasiumSpaceValidationError( + f"Key {k} not found in sample.", space, x + ) + try: + _gymnasium_space_contains(subspace, x[k]) + except gymnasiumSpaceValidationError as e: + raise gymnasiumSpaceValidationError( + f"Subspace of key {k} validation error.", space, x + ) from e +>>>>>>> f180e36a (fix: migrate from gym to gymnasium for NumPy 2.0+ compatibility) elif isinstance(space, spaces.Tuple): if isinstance(x, (list, np.ndarray)): x = tuple(x) # Promote list and ndarray to tuple for contains check if not isinstance(x, tuple) or len(x) != len(space): - raise GymSpaceValidationError("Sample must be a tuple with same length as space.", space, x) +<<<<<<< HEAD + raise gymnasiumSpaceValidationError("Sample must be a tuple with same length as space.", space, x) +======= + raise gymnasiumSpaceValidationError( + "Sample must be a tuple with same length as space.", space, x + ) +>>>>>>> f180e36a (fix: migrate from gym to gymnasium for NumPy 2.0+ compatibility) for i, (subspace, part) in enumerate(zip(space, x)): try: - _gym_space_contains(subspace, part) - except GymSpaceValidationError as e: - raise GymSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e + _gymnasium_space_contains(subspace, part) + except gymnasiumSpaceValidationError as e: +<<<<<<< HEAD + raise gymnasiumSpaceValidationError(f"Subspace of index {i} validation error.", space, x) from e + + else: + if not space.contains(x): + raise gymnasiumSpaceValidationError("Validation error reported by gymnasium.", space, x) +======= + raise gymnasiumSpaceValidationError( + f"Subspace of index {i} validation error.", space, x + ) from e else: if not space.contains(x): - raise GymSpaceValidationError("Validation error reported by gym.", space, x) + raise gymnasiumSpaceValidationError( + "Validation error reported by gymnasium.", space, x + ) +>>>>>>> f180e36a (fix: migrate from gym to gymnasium for NumPy 2.0+ compatibility) -class GymSpaceValidationError(Exception): - def __init__(self, message: str, space: gym.Space, x: Any) -> None: +class gymnasiumSpaceValidationError(Exception): + def __init__(self, message: str, space: gymnasium.Space, x: Any) -> None: self.message = message self.space = space self.x = x diff --git a/qlib/rl/order_execution/interpreter.py b/qlib/rl/order_execution/interpreter.py index 01b0811530..34289c9f5e 100644 --- a/qlib/rl/order_execution/interpreter.py +++ b/qlib/rl/order_execution/interpreter.py @@ -8,7 +8,7 @@ import numpy as np import pandas as pd 
-from gym import spaces +from gymnasium import spaces from qlib.constant import EPS from qlib.rl.data.base import ProcessedDataProvider @@ -27,13 +27,19 @@ from qlib.utils import init_instance_by_config -def canonicalize(value: int | float | np.ndarray | pd.DataFrame | dict) -> np.ndarray | dict: +def canonicalize( + value: int | float | np.ndarray | pd.DataFrame | dict, +) -> np.ndarray | dict: """To 32-bit numeric types. Recursively.""" if isinstance(value, pd.DataFrame): return value.to_numpy() - if isinstance(value, (float, np.floating)) or (isinstance(value, np.ndarray) and value.dtype.kind == "f"): + if isinstance(value, (float, np.floating)) or ( + isinstance(value, np.ndarray) and value.dtype.kind == "f" + ): return np.array(value, dtype=np.float32) - elif isinstance(value, (int, bool, np.integer)) or (isinstance(value, np.ndarray) and value.dtype.kind == "i"): + elif isinstance(value, (int, bool, np.integer)) or ( + isinstance(value, np.ndarray) and value.dtype.kind == "i" + ): return np.array(value, dtype=np.int32) elif isinstance(value, dict): return {k: canonicalize(v) for k, v in value.items()} @@ -62,7 +68,9 @@ def interpret(self, state: SAOEState) -> dict: @property def observation_space(self) -> spaces.Dict: - return spaces.Dict({"DUMMY": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)}) + return spaces.Dict( + {"DUMMY": spaces.Box(-np.inf, np.inf, shape=(), dtype=np.int32)} + ) class FullHistoryStateInterpreter(StateInterpreter[SAOEState, FullHistoryObs]): @@ -108,7 +116,9 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: position_history = np.full(self.max_step + 1, 0.0, dtype=np.float32) position_history[0] = state.order.amount - position_history[1 : len(state.history_steps) + 1] = state.history_steps["position"].to_numpy() + position_history[1 : len(state.history_steps) + 1] = state.history_steps[ + "position" + ].to_numpy() # The min, slice here are to make sure that indices fit into the range, # even after the final step of the simulator (in the done step), @@ -117,10 +127,17 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: FullHistoryObs, canonicalize( { - "data_processed": np.array(self._mask_future_info(processed.today, state.cur_time)), + "data_processed": np.array( + self._mask_future_info(processed.today, state.cur_time) + ), "data_processed_prev": np.array(processed.yesterday), "acquiring": _to_int32(state.order.direction == state.order.BUY), - "cur_tick": _to_int32(min(int(np.sum(state.ticks_index < state.cur_time)), self.data_ticks - 1)), + "cur_tick": _to_int32( + min( + int(np.sum(state.ticks_index < state.cur_time)), + self.data_ticks - 1, + ) + ), "cur_step": _to_int32(min(state.cur_step, self.max_step - 1)), "num_step": _to_int32(self.max_step), "target": _to_float32(state.order.amount), @@ -133,13 +150,19 @@ def interpret(self, state: SAOEState) -> FullHistoryObs: @property def observation_space(self) -> spaces.Dict: space = { - "data_processed": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), - "data_processed_prev": spaces.Box(-np.inf, np.inf, shape=(self.data_ticks, self.data_dim)), + "data_processed": spaces.Box( + -np.inf, np.inf, shape=(self.data_ticks, self.data_dim) + ), + "data_processed_prev": spaces.Box( + -np.inf, np.inf, shape=(self.data_ticks, self.data_dim) + ), "acquiring": spaces.Discrete(2), "cur_tick": spaces.Box(0, self.data_ticks - 1, shape=(), dtype=np.int32), "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32), # TODO: support arbitrary length index - "num_step": 
spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32), + "num_step": spaces.Box( + self.max_step, self.max_step, shape=(), dtype=np.int32 + ), "target": spaces.Box(-EPS, np.inf, shape=()), "position": spaces.Box(-EPS, np.inf, shape=()), "position_history": spaces.Box(-EPS, np.inf, shape=(self.max_step,)), @@ -178,7 +201,9 @@ def observation_space(self) -> spaces.Dict: space = { "acquiring": spaces.Discrete(2), "cur_step": spaces.Box(0, self.max_step - 1, shape=(), dtype=np.int32), - "num_step": spaces.Box(self.max_step, self.max_step, shape=(), dtype=np.int32), + "num_step": spaces.Box( + self.max_step, self.max_step, shape=(), dtype=np.int32 + ), "target": spaces.Box(-EPS, np.inf, shape=()), "position": spaces.Box(-EPS, np.inf, shape=()), } @@ -210,7 +235,9 @@ class CategoricalActionInterpreter(ActionInterpreter[SAOEState, int, float]): Total number of steps (an upper-bound estimation). For example, 390min / 30min-per-step = 13 steps. """ - def __init__(self, values: int | List[float], max_step: Optional[int] = None) -> None: + def __init__( + self, values: int | List[float], max_step: Optional[int] = None + ) -> None: super().__init__() if isinstance(values, int): @@ -244,7 +271,9 @@ def action_space(self) -> spaces.Box: return spaces.Box(0, np.inf, shape=(), dtype=np.float32) def interpret(self, state: SAOEState, action: float) -> float: - estimated_total_steps = math.ceil(len(state.ticks_for_order) / state.ticks_per_step) + estimated_total_steps = math.ceil( + len(state.ticks_for_order) / state.ticks_per_step + ) twap_volume = state.position / (estimated_total_steps - state.cur_step) return min(state.position, twap_volume * action) diff --git a/qlib/rl/order_execution/network.py b/qlib/rl/order_execution/network.py index d6a11189cf..03c0d4e7de 100644 --- a/qlib/rl/order_execution/network.py +++ b/qlib/rl/order_execution/network.py @@ -46,13 +46,26 @@ def __init__( self.rnn_class = rnn_classes[rnn_type] self.rnn_layers = rnn_num_layers - self.raw_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) - self.prev_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) - self.pri_rnn = self.rnn_class(hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers) + self.raw_rnn = self.rnn_class( + hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers + ) + self.prev_rnn = self.rnn_class( + hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers + ) + self.pri_rnn = self.rnn_class( + hidden_dim, hidden_dim, batch_first=True, num_layers=self.rnn_layers + ) - self.raw_fc = nn.Sequential(nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU()) + self.raw_fc = nn.Sequential( + nn.Linear(obs_space["data_processed"].shape[-1], hidden_dim), nn.ReLU() + ) self.pri_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU()) - self.dire_fc = nn.Sequential(nn.Linear(2, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU()) + self.dire_fc = nn.Sequential( + nn.Linear(2, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + ) self._init_extra_branches() @@ -66,16 +79,25 @@ def __init__( def _init_extra_branches(self) -> None: pass - def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[List[torch.Tensor], torch.Tensor]: + def _source_features( + self, obs: FullHistoryObs, device: torch.device + ) -> Tuple[List[torch.Tensor], torch.Tensor]: bs, _, data_dim = obs["data_processed"].size() - data = 
torch.cat((torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1)
+        data = torch.cat(
+            (torch.zeros(bs, 1, data_dim, device=device), obs["data_processed"]), 1
+        )
 
         cur_step = obs["cur_step"].long()
         cur_tick = obs["cur_tick"].long()
         bs_indices = torch.arange(bs, device=device)
 
-        position = obs["position_history"] / obs["target"].unsqueeze(-1)  # [bs, num_step]
+        position = obs["position_history"] / obs["target"].unsqueeze(
+            -1
+        )  # [bs, num_step]
         steps = (
-            torch.arange(position.size(-1), device=device).unsqueeze(0).repeat(bs, 1).float()
+            torch.arange(position.size(-1), device=device)
+            .unsqueeze(0)
+            .repeat(bs, 1)
+            .float()
             / obs["num_step"].unsqueeze(-1).float()
         )  # [bs, num_step]
         priv = torch.stack((position.float(), steps), -1)
@@ -91,7 +113,9 @@ def _source_features(self, obs: FullHistoryObs, device: torch.device) -> Tuple[L
 
         sources = [data_out_slice, priv_out]
 
-        dir_out = self.dire_fc(torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float())
+        dir_out = self.dire_fc(
+            torch.stack((obs["acquiring"], 1 - obs["acquiring"]), -1).float()
+        )
         sources.append(dir_out)
 
         return sources, data_out
diff --git a/qlib/rl/order_execution/policy.py b/qlib/rl/order_execution/policy.py
index a46b587aa1..c99a6256f3 100644
--- a/qlib/rl/order_execution/policy.py
+++ b/qlib/rl/order_execution/policy.py
@@ -6,11 +6,11 @@
 from pathlib import Path
 from typing import Any, Dict, Generator, Iterable, Optional, OrderedDict, Tuple, cast
 
-import gym
+import gymnasium as gymnasium
 import numpy as np
 import torch
 import torch.nn as nn
-from gym.spaces import Discrete
+from gymnasium.spaces import Discrete
 from tianshou.data import Batch, ReplayBuffer, to_torch
 from tianshou.policy import BasePolicy, PPOPolicy, DQNPolicy
 
@@ -28,7 +34,13 @@ class NonLearnablePolicy(BasePolicy):
     This could be moved outside in future.
     """
 
-    def __init__(self, obs_space: gym.Space, action_space: gym.Space) -> None:
+    def __init__(
+        self, obs_space: gymnasium.Space, action_space: gymnasium.Space
+    ) -> None:
         super().__init__()
 
     def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, Any]:
@@ -49,7 +55,16 @@ class AllOne(NonLearnablePolicy):
 
     Useful when implementing some baselines (e.g., TWAP).
""" - def __init__(self, obs_space: gym.Space, action_space: gym.Space, fill_value: float | int = 1.0) -> None: +<<<<<<< HEAD + def __init__(self, obs_space: gymnasium.Space, action_space: gymnasium.Space, fill_value: float | int = 1.0) -> None: +======= + def __init__( + self, + obs_space: gymnasium.Space, + action_space: gymnasium.Space, + fill_value: float | int = 1.0, + ) -> None: +>>>>>>> f180e36a (fix: migrate from gym to gymnasium for NumPy 2.0+ compatibility) super().__init__(obs_space, action_space) self.fill_value = fill_value @@ -70,7 +85,9 @@ class PPOActor(nn.Module): def __init__(self, extractor: nn.Module, action_dim: int) -> None: super().__init__() self.extractor = extractor - self.layer_out = nn.Sequential(nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1)) + self.layer_out = nn.Sequential( + nn.Linear(cast(int, extractor.output_dim), action_dim), nn.Softmax(dim=-1) + ) def forward( self, @@ -114,8 +131,8 @@ class PPO(PPOPolicy): def __init__( self, network: nn.Module, - obs_space: gym.Space, - action_space: gym.Space, + obs_space: gymnasium.Space, + action_space: gymnasium.Space, lr: float, weight_decay: float = 0.0, discount_factor: float = 1.0, @@ -173,8 +190,8 @@ class DQN(DQNPolicy): def __init__( self, network: nn.Module, - obs_space: gym.Space, - action_space: gym.Space, + obs_space: gymnasium.Space, + action_space: gymnasium.Space, lr: float, weight_decay: float = 0.0, discount_factor: float = 0.99, diff --git a/qlib/rl/order_execution/reward.py b/qlib/rl/order_execution/reward.py index 0dcfd24bb3..dc41365fd3 100644 --- a/qlib/rl/order_execution/reward.py +++ b/qlib/rl/order_execution/reward.py @@ -33,17 +33,23 @@ def __init__(self, penalty: float = 100.0, scale: float = 1.0) -> None: def reward(self, simulator_state: SAOEState) -> float: whole_order = simulator_state.order.amount assert whole_order > 0 - last_step = cast(SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict()) + last_step = cast( + SAOEMetrics, simulator_state.history_steps.reset_index().iloc[-1].to_dict() + ) pa = last_step["pa"] * last_step["amount"] / whole_order # Inspect the "break-down" of the latest step: trading amount at every tick last_step_breakdown = simulator_state.history_exec.loc[last_step["datetime"] :] - penalty = -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum() + penalty = ( + -self.penalty * ((last_step_breakdown["amount"] / whole_order) ** 2).sum() + ) reward = pa + penalty # Throw error in case of NaN - assert not (np.isnan(reward) or np.isinf(reward)), f"Invalid reward for simulator state: {simulator_state}" + assert not ( + np.isnan(reward) or np.isinf(reward) + ), f"Invalid reward for simulator state: {simulator_state}" self.log("reward/pa", pa) self.log("reward/penalty", penalty) @@ -63,13 +69,18 @@ class PPOReward(Reward[SAOEState]): Last time index that allowed to trade. 
""" - def __init__(self, max_step: int, start_time_index: int = 0, end_time_index: int = 239) -> None: + def __init__( + self, max_step: int, start_time_index: int = 0, end_time_index: int = 239 + ) -> None: self.max_step = max_step self.start_time_index = start_time_index self.end_time_index = end_time_index def reward(self, simulator_state: SAOEState) -> float: - if simulator_state.cur_step == self.max_step - 1 or simulator_state.position < 1e-6: + if ( + simulator_state.cur_step == self.max_step - 1 + or simulator_state.position < 1e-6 + ): if simulator_state.history_exec["deal_amount"].sum() == 0.0: vwap_price = cast( float, diff --git a/qlib/rl/order_execution/simulator_qlib.py b/qlib/rl/order_execution/simulator_qlib.py index 1417e2ab4a..dde3fcf2c6 100644 --- a/qlib/rl/order_execution/simulator_qlib.py +++ b/qlib/rl/order_execution/simulator_qlib.py @@ -43,19 +43,30 @@ def __init__( ) -> None: super().__init__(initial=order) - assert order.start_time.date() == order.end_time.date(), "Start date and end date must be the same." + assert ( + order.start_time.date() == order.end_time.date() + ), "Start date and end date must be the same." strategy_config = { "class": "SingleOrderStrategy", "module_path": "qlib.rl.strategy.single_order", "kwargs": { "order": order, - "trade_range": TradeRangeByTime(order.start_time.time(), order.end_time.time()), + "trade_range": TradeRangeByTime( + order.start_time.time(), order.end_time.time() + ), }, } self._collect_data_loop: Optional[Generator] = None - self.reset(order, strategy_config, executor_config, exchange_config, qlib_config, cash_limit) + self.reset( + order, + strategy_config, + executor_config, + exchange_config, + qlib_config, + cash_limit, + ) def reset( self, @@ -108,11 +119,19 @@ def _iter_strategy(self, action: Optional[float] = None) -> SAOEStrategy: """Iterate the _collect_data_loop until we get the next yield SAOEStrategy.""" assert self._collect_data_loop is not None - obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + obj = ( + next(self._collect_data_loop) + if action is None + else self._collect_data_loop.send(action) + ) while not isinstance(obj, SAOEStrategy): if isinstance(obj, BaseTradeDecision): self.decisions.append(obj) - obj = next(self._collect_data_loop) if action is None else self._collect_data_loop.send(action) + obj = ( + next(self._collect_data_loop) + if action is None + else self._collect_data_loop.send(action) + ) assert isinstance(obj, SAOEStrategy) return obj diff --git a/qlib/rl/order_execution/simulator_simple.py b/qlib/rl/order_execution/simulator_simple.py index 48aa03a170..edb337f23e 100644 --- a/qlib/rl/order_execution/simulator_simple.py +++ b/qlib/rl/order_execution/simulator_simple.py @@ -12,7 +12,10 @@ from qlib.backtest.decision import Order, OrderDir from qlib.constant import EPS, EPS_T, float_or_ndarray from qlib.rl.data.base import BaseIntradayBacktestData -from qlib.rl.data.native import DataframeIntradayBacktestData, load_handler_intraday_processed_data +from qlib.rl.data.native import ( + DataframeIntradayBacktestData, + load_handler_intraday_processed_data, +) from qlib.rl.data.pickle_styled import load_simple_intraday_backtest_data from qlib.rl.simulator import Simulator from qlib.rl.utils import LogLevel @@ -100,17 +103,26 @@ def __init__( self.ticks_index = self.backtest_data.get_time_index() # Get time index available for trading - self.ticks_for_order = self._get_ticks_slice(self.order.start_time, self.order.end_time) + self.ticks_for_order 
= self._get_ticks_slice( + self.order.start_time, self.order.end_time + ) self.cur_time = self.ticks_for_order[0] self.cur_step = 0 # NOTE: astype(float) is necessary in some systems. # this will align the precision with `.to_numpy()` in `_split_exec_vol` - self.twap_price = float(self.backtest_data.get_deal_price().loc[self.ticks_for_order].astype(float).mean()) + self.twap_price = float( + self.backtest_data.get_deal_price() + .loc[self.ticks_for_order] + .astype(float) + .mean() + ) self.position = order.amount - metric_keys = list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member + metric_keys = list( + SAOEMetrics.__annotations__.keys() + ) # pylint: disable=no-member # NOTE: can empty dataframe contain index? self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") @@ -166,7 +178,9 @@ def step(self, amount: float) -> None: if abs(self.position) < 1e-6: self.position = 0.0 if self.position < -EPS or (exec_vol < -EPS).any(): - raise ValueError(f"Execution volume is invalid: {exec_vol} (position = {self.position})") + raise ValueError( + f"Execution volume is invalid: {exec_vol} (position = {self.position})" + ) # Get time index available for this step time_index = self._get_ticks_slice(self.cur_time, self._next_time()) @@ -189,19 +203,29 @@ def step(self, amount: float) -> None: trade_value=self.market_price * exec_vol, position=ticks_position, ffr=exec_vol / self.order.amount, - pa=price_advantage(self.market_price, self.twap_price, self.order.direction), + pa=price_advantage( + self.market_price, self.twap_price, self.order.direction + ), ), ) self.history_steps = self._dataframe_append( self.history_steps, - [self._metrics_collect(self.cur_time, self.market_vol, self.market_price, amount, exec_vol)], + [ + self._metrics_collect( + self.cur_time, self.market_vol, self.market_price, amount, exec_vol + ) + ], ) if self.done(): if self.env is not None: - self.env.logger.add_any("history_steps", self.history_steps, loglevel=LogLevel.DEBUG) - self.env.logger.add_any("history_exec", self.history_exec, loglevel=LogLevel.DEBUG) + self.env.logger.add_any( + "history_steps", self.history_steps, loglevel=LogLevel.DEBUG + ) + self.env.logger.add_any( + "history_exec", self.history_exec, loglevel=LogLevel.DEBUG + ) self.metrics = self._metrics_collect( self.ticks_index[0], # start time @@ -257,7 +281,10 @@ def _next_time(self) -> pd.Timestamp: # as long as ticks_per_step is a multiple of something, each step won't cross morning and afternoon. 
next_loc = next_loc - next_loc % self.ticks_per_step - if next_loc < len(self.ticks_index) and self.ticks_index[next_loc] < self.order.end_time: + if ( + next_loc < len(self.ticks_index) + and self.ticks_index[next_loc] < self.order.end_time + ): return self.ticks_index[next_loc] else: return self.order.end_time @@ -274,16 +301,30 @@ def _split_exec_vol(self, exec_vol_sum: float) -> np.ndarray: next_time = self._next_time() # get the backtest data for next interval - self.market_vol = self.backtest_data.get_volume().loc[self.cur_time : next_time - EPS_T].to_numpy() - self.market_price = self.backtest_data.get_deal_price().loc[self.cur_time : next_time - EPS_T].to_numpy() + self.market_vol = ( + self.backtest_data.get_volume() + .loc[self.cur_time : next_time - EPS_T] + .to_numpy() + ) + self.market_price = ( + self.backtest_data.get_deal_price() + .loc[self.cur_time : next_time - EPS_T] + .to_numpy() + ) assert self.market_vol is not None and self.market_price is not None # split the volume equally into each minute - exec_vol = np.repeat(exec_vol_sum / len(self.market_price), len(self.market_price)) + exec_vol = np.repeat( + exec_vol_sum / len(self.market_price), len(self.market_price) + ) # apply the volume threshold - market_vol_limit = self.vol_threshold * self.market_vol if self.vol_threshold is not None else np.inf + market_vol_limit = ( + self.vol_threshold * self.market_vol + if self.vol_threshold is not None + else np.inf + ) exec_vol = np.minimum(exec_vol, market_vol_limit) # type: ignore # Complete all the order amount at the last moment. @@ -306,7 +347,9 @@ def _metrics_collect( if np.abs(np.sum(exec_vol)) < EPS: exec_avg_price = 0.0 else: - exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan + exec_avg_price = cast( + float, np.average(market_price, weights=exec_vol) + ) # could be nan if hasattr(exec_avg_price, "item"): # could be numpy scalar exec_avg_price = exec_avg_price.item() # type: ignore @@ -326,7 +369,9 @@ def _metrics_collect( pa=price_advantage(exec_avg_price, self.twap_price, self.order.direction), ) - def _get_ticks_slice(self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False) -> pd.DatetimeIndex: + def _get_ticks_slice( + self, start: pd.Timestamp, end: pd.Timestamp, include_end: bool = False + ) -> pd.DatetimeIndex: if not include_end: end = end - EPS_T return self.ticks_index[self.ticks_index.slice_indexer(start, end)] diff --git a/qlib/rl/order_execution/strategy.py b/qlib/rl/order_execution/strategy.py index 7e66a1f085..76d0bfb500 100644 --- a/qlib/rl/order_execution/strategy.py +++ b/qlib/rl/order_execution/strategy.py @@ -15,7 +15,12 @@ from tianshou.policy import BasePolicy from qlib.backtest import CommonInfrastructure, Order -from qlib.backtest.decision import BaseTradeDecision, TradeDecisionWithDetails, TradeDecisionWO, TradeRange +from qlib.backtest.decision import ( + BaseTradeDecision, + TradeDecisionWithDetails, + TradeDecisionWO, + TradeRange, +) from qlib.backtest.exchange import Exchange from qlib.backtest.executor import BaseExecutor from qlib.backtest.utils import LevelInfrastructure, get_start_end_idx @@ -97,11 +102,15 @@ def __init__( self.executor = executor self.exchange = exchange self.backtest_data = backtest_data - self.start_idx, _ = get_start_end_idx(self.executor.trade_calendar, trade_decision) + self.start_idx, _ = get_start_end_idx( + self.executor.trade_calendar, trade_decision + ) self.twap_price = self.backtest_data.get_deal_price().mean() - metric_keys = 
list(SAOEMetrics.__annotations__.keys()) # pylint: disable=no-member + metric_keys = list( + SAOEMetrics.__annotations__.keys() + ) # pylint: disable=no-member self.history_exec = pd.DataFrame(columns=metric_keys).set_index("datetime") self.history_steps = pd.DataFrame(columns=metric_keys).set_index("datetime") self.metrics: Optional[SAOEMetrics] = None @@ -134,7 +143,9 @@ def update( exec_vol = np.zeros(last_step_size) for order, _, __, ___ in execute_result: - idx, _ = get_day_min_idx_range(order.start_time, order.end_time, f"{self.data_granularity}min", REG_CN) + idx, _ = get_day_min_idx_range( + order.start_time, order.end_time, f"{self.data_granularity}min", REG_CN + ) exec_vol[idx - last_step_range[0]] = order.deal_amount if exec_vol.sum() > self.position and exec_vol.sum() > 0.0: @@ -165,20 +176,29 @@ def update( direction=self.order.direction, ), ) - market_price = fill_missing_data(np.array(market_price, dtype=float).reshape(-1)) - market_volume = fill_missing_data(np.array(market_volume, dtype=float).reshape(-1)) + market_price = fill_missing_data( + np.array(market_price, dtype=float).reshape(-1) + ) + market_volume = fill_missing_data( + np.array(market_volume, dtype=float).reshape(-1) + ) assert market_price.shape == market_volume.shape == exec_vol.shape # Get data from the current level executor's indicator current_trade_account = self.executor.trade_account - current_df = current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + current_df = ( + current_trade_account.get_trade_indicator().generate_trade_indicators_dataframe() + ) self.history_exec = dataframe_append( self.history_exec, self._collect_multi_order_metric( order=self.order, datetime=_get_all_timestamps( - start_time, end_time, include_end=True, granularity=ONE_MIN * self.data_granularity + start_time, + end_time, + include_end=True, + granularity=ONE_MIN * self.data_granularity, ), market_vol=market_volume, market_price=market_price, @@ -260,7 +280,9 @@ def _collect_single_order_metric( if np.abs(np.sum(exec_vol)) < EPS: exec_avg_price = 0.0 else: - exec_avg_price = cast(float, np.average(market_price, weights=exec_vol)) # could be nan + exec_avg_price = cast( + float, np.average(market_price, weights=exec_vol) + ) # could be nan if hasattr(exec_avg_price, "item"): # could be numpy scalar exec_avg_price = exec_avg_price.item() # type: ignore @@ -340,8 +362,12 @@ def _create_qlib_backtest_adapter( data_granularity=self._data_granularity, ) - def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: - super(SAOEStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) + def reset( + self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any + ) -> None: + super(SAOEStrategy, self).reset( + outer_trade_decision=outer_trade_decision, **kwargs + ) self.adapter_dict = {} self._last_step_range = (0, 0) @@ -353,8 +379,10 @@ def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: self.adapter_dict = {} for decision in outer_trade_decision.get_decision(): order = cast(Order, decision) - self.adapter_dict[order.key_by_day] = self._create_qlib_backtest_adapter( - order, outer_trade_decision, trade_range + self.adapter_dict[order.key_by_day] = ( + self._create_qlib_backtest_adapter( + order, outer_trade_decision, trade_range + ) ) def get_saoe_state_by_order(self, order: Order) -> SAOEState: @@ -418,9 +446,13 @@ def __init__( common_infra: CommonInfrastructure | None = None, **kwargs: Any, ) -> None: - 
super().__init__(None, outer_trade_decision, level_infra, common_infra, **kwargs) + super().__init__( + None, outer_trade_decision, level_infra, common_infra, **kwargs + ) - def _generate_trade_decision(self, execute_result: list | None = None) -> Generator[Any, Any, BaseTradeDecision]: + def _generate_trade_decision( + self, execute_result: list | None = None + ) -> Generator[Any, Any, BaseTradeDecision]: # Once the following line is executed, this ProxySAOEStrategy (self) will be yielded to the outside # of the entire executor, and the execution will be suspended. When the execution is resumed by `send()`, # the item will be captured by `exec_vol`. The outside policy could communicate with the inner @@ -432,7 +464,9 @@ def _generate_trade_decision(self, execute_result: list | None = None) -> Genera return TradeDecisionWO([order], self) - def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: + def reset( + self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any + ) -> None: super().reset(outer_trade_decision=outer_trade_decision, **kwargs) assert isinstance(outer_trade_decision, TradeDecisionWO) @@ -502,14 +536,20 @@ def __init__( if self._policy is not None: self._policy.eval() - def reset(self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any) -> None: + def reset( + self, outer_trade_decision: BaseTradeDecision | None = None, **kwargs: Any + ) -> None: super().reset(outer_trade_decision=outer_trade_decision, **kwargs) - def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd.DataFrame: + def _generate_trade_details( + self, act: np.ndarray, exec_vols: List[float] + ) -> pd.DataFrame: assert hasattr(self.outer_trade_decision, "order_list") trade_details = [] - for a, v, o in zip(act, exec_vols, getattr(self.outer_trade_decision, "order_list")): + for a, v, o in zip( + act, exec_vols, getattr(self.outer_trade_decision, "order_list") + ): trade_details.append( { "instrument": o.stock_id, @@ -522,7 +562,9 @@ def _generate_trade_details(self, act: np.ndarray, exec_vols: List[float]) -> pd trade_details[-1]["rl_action"] = a return pd.DataFrame.from_records(trade_details) - def _generate_trade_decision(self, execute_result: list | None = None) -> BaseTradeDecision: + def _generate_trade_decision( + self, execute_result: list | None = None + ) -> BaseTradeDecision: states = [] obs_batch = [] for decision in self.outer_trade_decision.get_decision(): @@ -534,12 +576,20 @@ def _generate_trade_decision(self, execute_result: list | None = None) -> BaseTr with torch.no_grad(): policy_out = self._policy(Batch(obs_batch)) - act = policy_out.act.numpy() if torch.is_tensor(policy_out.act) else policy_out.act - exec_vols = [self._action_interpreter.interpret(s, a) for s, a in zip(states, act)] + act = ( + policy_out.act.numpy() + if torch.is_tensor(policy_out.act) + else policy_out.act + ) + exec_vols = [ + self._action_interpreter.interpret(s, a) for s, a in zip(states, act) + ] oh = self.trade_exchange.get_order_helper() order_list = [] - for decision, exec_vol in zip(self.outer_trade_decision.get_decision(), exec_vols): + for decision, exec_vol in zip( + self.outer_trade_decision.get_decision(), exec_vols + ): if exec_vol != 0: order = cast(Order, decision) order_list.append(oh.create(order.stock_id, exec_vol, order.direction)) diff --git a/qlib/rl/strategy/single_order.py b/qlib/rl/strategy/single_order.py index 45db0d9c89..58310a1b07 100644 --- a/qlib/rl/strategy/single_order.py +++ 
b/qlib/rl/strategy/single_order.py @@ -21,7 +21,9 @@ def __init__( self._order = order self._trade_range = trade_range - def generate_trade_decision(self, execute_result: list | None = None) -> TradeDecisionWO: + def generate_trade_decision( + self, execute_result: list | None = None + ) -> TradeDecisionWO: oh: OrderHelper = self.common_infra.get("trade_exchange").get_order_helper() order_list = [ oh.create( diff --git a/qlib/rl/trainer/callbacks.py b/qlib/rl/trainer/callbacks.py index 9d1bf4ba28..5f0cadf590 100644 --- a/qlib/rl/trainer/callbacks.py +++ b/qlib/rl/trainer/callbacks.py @@ -122,7 +122,12 @@ def __init__( self.min_delta *= -1 def state_dict(self) -> dict: - return {"wait": self.wait, "best": self.best, "best_weights": self.best_weights, "best_iter": self.best_iter} + return { + "wait": self.wait, + "best": self.best, + "best_weights": self.best_weights, + "best_iter": self.best_iter, + } def load_state_dict(self, state_dict: dict) -> None: self.wait = state_dict["wait"] @@ -155,9 +160,7 @@ def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: if self.baseline is None or self._is_improvement(current, self.baseline): self.wait = 0 - msg = ( - f"#{trainer.current_iter} current reward: {current:.4f}, best reward: {self.best:.4f} in #{self.best_iter}" - ) + msg = f"#{trainer.current_iter} current reward: {current:.4f}, best reward: {self.best:.4f} in #{self.best_iter}" _logger.info(msg) # Only check after the first epoch. @@ -165,7 +168,10 @@ def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: trainer.should_stop = True _logger.info(f"On iteration %d: early stopping", trainer.current_iter + 1) if self.restore_best_weights and self.best_weights is not None: - _logger.info("Restoring model weights from the end of the best iteration: %d", self.best_iter + 1) + _logger.info( + "Restoring model weights from the end of the best iteration: %d", + self.best_iter + 1, + ) vessel.load_state_dict(self.best_weights) def get_monitor_value(self, trainer: Trainer) -> Any: @@ -192,12 +198,20 @@ def __init__(self, dirpath: Path) -> None: self.valid_records: List[dict] = [] def on_train_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: - self.train_records.append({k: v for k, v in trainer.metrics.items() if not k.startswith("val/")}) - pd.DataFrame.from_records(self.train_records).to_csv(self.dirpath / "train_result.csv", index=True) + self.train_records.append( + {k: v for k, v in trainer.metrics.items() if not k.startswith("val/")} + ) + pd.DataFrame.from_records(self.train_records).to_csv( + self.dirpath / "train_result.csv", index=True + ) def on_validate_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: - self.valid_records.append({k: v for k, v in trainer.metrics.items() if k.startswith("val/")}) - pd.DataFrame.from_records(self.valid_records).to_csv(self.dirpath / "validation_result.csv", index=True) + self.valid_records.append( + {k: v for k, v in trainer.metrics.items() if k.startswith("val/")} + ) + pd.DataFrame.from_records(self.valid_records).to_csv( + self.dirpath / "validation_result.csv", index=True + ) class Checkpoint(Callback): @@ -253,15 +267,21 @@ def __init__( self._last_checkpoint_time: float | None = None def on_fit_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: - if self.save_on_fit_end and (trainer.current_iter != self._last_checkpoint_iter): + if self.save_on_fit_end and ( + trainer.current_iter != self._last_checkpoint_iter + ): self._save_checkpoint(trainer) def 
on_iter_end(self, trainer: Trainer, vessel: TrainingVesselBase) -> None: should_save_ckpt = False - if self.every_n_iters is not None and (trainer.current_iter + 1) % self.every_n_iters == 0: + if ( + self.every_n_iters is not None + and (trainer.current_iter + 1) % self.every_n_iters == 0 + ): should_save_ckpt = True if self.time_interval is not None and ( - self._last_checkpoint_time is None or (time.time() - self._last_checkpoint_time) >= self.time_interval + self._last_checkpoint_time is None + or (time.time() - self._last_checkpoint_time) >= self.time_interval ): should_save_ckpt = True if should_save_ckpt: @@ -287,5 +307,7 @@ def _save_checkpoint(self, trainer: Trainer) -> None: def _new_checkpoint_name(self, trainer: Trainer) -> str: return self.filename.format( - iter=trainer.current_iter, time=datetime.now().strftime("%Y%m%d%H%M%S"), **trainer.metrics + iter=trainer.current_iter, + time=datetime.now().strftime("%Y%m%d%H%M%S"), + **trainer.metrics, ) diff --git a/qlib/rl/trainer/trainer.py b/qlib/rl/trainer/trainer.py index a1046e966e..192299e930 100644 --- a/qlib/rl/trainer/trainer.py +++ b/qlib/rl/trainer/trainer.py @@ -14,7 +14,15 @@ from qlib.log import get_module_logger from qlib.rl.simulator import InitialStateType -from qlib.rl.utils import EnvWrapper, FiniteEnvType, LogBuffer, LogCollector, LogLevel, LogWriter, vectorize_env +from qlib.rl.utils import ( + EnvWrapper, + FiniteEnvType, + LogBuffer, + LogCollector, + LogLevel, + LogWriter, + vectorize_env, +) from qlib.rl.utils.finite_env import FiniteVectorEnv from qlib.typehint import Literal @@ -109,7 +117,9 @@ def __init__( else: self.loggers = [] - self.loggers.append(LogBuffer(self._metrics_callback, loglevel=self._min_loglevel())) + self.loggers.append( + LogBuffer(self._metrics_callback, loglevel=self._min_loglevel()) + ) self.callbacks: List[Callback] = callbacks if callbacks is not None else [] self.finite_env_type = finite_env_type @@ -144,8 +154,14 @@ def state_dict(self) -> dict: """ return { "vessel": self.vessel.state_dict(), - "callbacks": {name: callback.state_dict() for name, callback in self.named_callbacks().items()}, - "loggers": {name: logger.state_dict() for name, logger in self.named_loggers().items()}, + "callbacks": { + name: callback.state_dict() + for name, callback in self.named_callbacks().items() + }, + "loggers": { + name: logger.state_dict() + for name, logger in self.named_loggers().items() + }, "should_stop": self.should_stop, "current_iter": self.current_iter, "current_episode": self.current_episode, @@ -226,7 +242,10 @@ def fit(self, vessel: TrainingVesselBase, ckpt_path: Path | None = None) -> None self._call_callback_hooks("on_train_end") - if self.val_every_n_iters is not None and (self.current_iter + 1) % self.val_every_n_iters == 0: + if ( + self.val_every_n_iters is not None + and (self.current_iter + 1) % self.val_every_n_iters == 0 + ): # Implementation of validation loop self.current_stage = "val" self._call_callback_hooks("on_validate_start") @@ -271,7 +290,9 @@ def test(self, vessel: TrainingVesselBase) -> None: del vector_env # FIXME: Explicitly delete this object to avoid memory leak. 
self._call_callback_hooks("on_test_end") - def venv_from_iterator(self, iterator: Iterable[InitialStateType]) -> FiniteVectorEnv: + def venv_from_iterator( + self, iterator: Iterable[InitialStateType] + ) -> FiniteVectorEnv: """Create a vectorized environment from iterator and the training vessel.""" def env_factory(): @@ -306,7 +327,9 @@ def env_factory(): self.loggers, ) - def _metrics_callback(self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer) -> None: + def _metrics_callback( + self, on_episode: bool, on_collect: bool, log_buffer: LogBuffer + ) -> None: if on_episode: # Update the global counter. self.current_episode = log_buffer.global_episode @@ -349,7 +372,9 @@ def _named_collection(seq: Sequence[T]) -> Dict[str, T]: retry_cnt: collections.Counter = collections.Counter() for item in seq: typename = type(item).__name__.lower() - key = typename if retry_cnt[typename] == 0 else f"{typename}{retry_cnt[typename]}" + key = ( + typename if retry_cnt[typename] == 0 else f"{typename}{retry_cnt[typename]}" + ) retry_cnt[typename] += 1 res[key] = item return res diff --git a/qlib/rl/trainer/vessel.py b/qlib/rl/trainer/vessel.py index b7912b488b..9c9a38f554 100644 --- a/qlib/rl/trainer/vessel.py +++ b/qlib/rl/trainer/vessel.py @@ -4,7 +4,18 @@ from __future__ import annotations import weakref -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Dict, Generic, Iterable, Sequence, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Dict, + Generic, + Iterable, + Sequence, + TypeVar, + cast, +) import numpy as np from tianshou.data import Collector, VectorReplayBuffer @@ -13,7 +24,14 @@ from qlib.constant import INF from qlib.log import get_module_logger -from qlib.rl.interpreter import ActionInterpreter, ActType, ObsType, PolicyActType, StateInterpreter, StateType +from qlib.rl.interpreter import ( + ActionInterpreter, + ActType, + ObsType, + PolicyActType, + StateInterpreter, + StateType, +) from qlib.rl.reward import Reward from qlib.rl.simulator import InitialStateType, Simulator from qlib.rl.utils import DataQueue @@ -31,7 +49,9 @@ class SeedIteratorNotAvailable(BaseException): pass -class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType]): +class TrainingVesselBase( + Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType] +): """A ship that contains simulator, interpreter, and policy, will be sent to trainer. This class controls algorithm-related parts of training, while trainer is responsible for runtime part. @@ -39,7 +59,9 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, and (optionally) some callbacks to insert customized logics at specific events. 
""" - simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]] + simulator_fn: Callable[ + [InitialStateType], Simulator[InitialStateType, StateType, ActType] + ] state_interpreter: StateInterpreter[StateType, ObsType] action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType] policy: BasePolicy @@ -49,17 +71,23 @@ class TrainingVesselBase(Generic[InitialStateType, StateType, ActType, ObsType, def assign_trainer(self, trainer: Trainer) -> None: self.trainer = weakref.proxy(trainer) # type: ignore - def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def train_seed_iterator( + self, + ) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: """Override this to create a seed iterator for training. If the iterable is a context manager, the whole training will be invoked in the with-block, and the iterator will be automatically closed after the training is done.""" raise SeedIteratorNotAvailable("Seed iterator for training is not available.") - def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def val_seed_iterator( + self, + ) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: """Override this to create a seed iterator for validation.""" raise SeedIteratorNotAvailable("Seed iterator for validation is not available.") - def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def test_seed_iterator( + self, + ) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: """Override this to create a seed iterator for testing.""" raise SeedIteratorNotAvailable("Seed iterator for testing is not available.") @@ -115,7 +143,9 @@ class TrainingVessel(TrainingVesselBase): def __init__( self, *, - simulator_fn: Callable[[InitialStateType], Simulator[InitialStateType, StateType, ActType]], + simulator_fn: Callable[ + [InitialStateType], Simulator[InitialStateType, StateType, ActType] + ], state_interpreter: StateInterpreter[StateType, ObsType], action_interpreter: ActionInterpreter[StateType, PolicyActType, ActType], policy: BasePolicy, @@ -139,25 +169,46 @@ def __init__( self.episode_per_iter = episode_per_iter self.update_kwargs = update_kwargs or {} - def train_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def train_seed_iterator( + self, + ) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: if self.train_initial_states is not None: - _logger.info("Training initial states collection size: %d", len(self.train_initial_states)) + _logger.info( + "Training initial states collection size: %d", + len(self.train_initial_states), + ) # Implement fast_dev_run here. 
- train_initial_states = self._random_subset("train", self.train_initial_states, self.trainer.fast_dev_run) + train_initial_states = self._random_subset( + "train", self.train_initial_states, self.trainer.fast_dev_run + ) return DataQueue(train_initial_states, repeat=-1, shuffle=True) return super().train_seed_iterator() - def val_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def val_seed_iterator( + self, + ) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: if self.val_initial_states is not None: - _logger.info("Validation initial states collection size: %d", len(self.val_initial_states)) - val_initial_states = self._random_subset("val", self.val_initial_states, self.trainer.fast_dev_run) + _logger.info( + "Validation initial states collection size: %d", + len(self.val_initial_states), + ) + val_initial_states = self._random_subset( + "val", self.val_initial_states, self.trainer.fast_dev_run + ) return DataQueue(val_initial_states, repeat=1) return super().val_seed_iterator() - def test_seed_iterator(self) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: + def test_seed_iterator( + self, + ) -> ContextManager[Iterable[InitialStateType]] | Iterable[InitialStateType]: if self.test_initial_states is not None: - _logger.info("Testing initial states collection size: %d", len(self.test_initial_states)) - test_initial_states = self._random_subset("test", self.test_initial_states, self.trainer.fast_dev_run) + _logger.info( + "Testing initial states collection size: %d", + len(self.test_initial_states), + ) + test_initial_states = self._random_subset( + "test", self.test_initial_states, self.trainer.fast_dev_run + ) return DataQueue(test_initial_states, repeat=1) return super().test_seed_iterator() @@ -169,7 +220,10 @@ def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: with vector_env.collector_guard(): collector = Collector( - self.policy, vector_env, VectorReplayBuffer(self.buffer_size, len(vector_env)), exploration_noise=True + self.policy, + vector_env, + VectorReplayBuffer(self.buffer_size, len(vector_env)), + exploration_noise=True, ) # Number of episodes collected in each training iteration can be overridden by fast dev run. 
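# --- Editorial sketch (not part of the patch) --------------------------------
# The collect/update cycle wrapped by TrainingVessel.train() above follows the
# usual tianshou pattern: collect a number of episodes into a replay buffer
# shared across the worker envs, then let the policy update itself from that
# buffer. A minimal sketch of that flow; `my_policy` and `my_vector_env` are
# placeholders, and exact return types depend on the installed tianshou version.
from tianshou.data import Collector, VectorReplayBuffer

def one_training_iteration(my_policy, my_vector_env, buffer_size=100_000, episodes=100):
    collector = Collector(
        my_policy,
        my_vector_env,
        VectorReplayBuffer(buffer_size, len(my_vector_env)),  # one slot group per worker env
        exploration_noise=True,
    )
    collect_stats = collector.collect(n_episode=episodes)
    # sample_size=0 lets the policy learn from the whole collected buffer,
    # mirroring the update call made in this patch.
    update_stats = my_policy.update(sample_size=0, buffer=collector.buffer)
    return {**collect_stats, **update_stats}
# ------------------------------------------------------------------------------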
@@ -179,7 +233,9 @@ def train(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: episodes = self.episode_per_iter col_result = collector.collect(n_episode=episodes) - update_result = self.policy.update(sample_size=0, buffer=collector.buffer, **self.update_kwargs) + update_result = self.policy.update( + sample_size=0, buffer=collector.buffer, **self.update_kwargs + ) res = {**col_result, **update_result} self.log_dict(res) return res @@ -203,7 +259,9 @@ def test(self, vector_env: FiniteVectorEnv) -> Dict[str, Any]: return res @staticmethod - def _random_subset(name: str, collection: Sequence[T], size: int | None) -> Sequence[T]: + def _random_subset( + name: str, collection: Sequence[T], size: int | None + ) -> Sequence[T]: if size is None: # Size = None -> original collection return collection diff --git a/qlib/rl/utils/data_queue.py b/qlib/rl/utils/data_queue.py index 71c2dff65b..ca503b2051 100644 --- a/qlib/rl/utils/data_queue.py +++ b/qlib/rl/utils/data_queue.py @@ -67,7 +67,9 @@ def __init__( if queue_maxsize == 0: if os.cpu_count() is not None: queue_maxsize = cast(int, os.cpu_count()) - _logger.info(f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming.") + _logger.info( + f"Automatically set data queue maxsize to {queue_maxsize} to avoid overwhelming." + ) else: queue_maxsize = 1 _logger.warning(f"CPU count not available. Setting queue maxsize to 1.") @@ -78,7 +80,9 @@ def __init__( self.producer_num_workers: int = producer_num_workers self._activated: bool = False - self._queue: multiprocessing.Queue = multiprocessing.Queue(maxsize=queue_maxsize) + self._queue: multiprocessing.Queue = multiprocessing.Queue( + maxsize=queue_maxsize + ) # Mypy 0.981 brought '"SynchronizedBase[Any]" has no attribute "value" [attr-defined]' bug. # Therefore, add this type casting to pass Mypy checking. self._done = cast(Synchronized, multiprocessing.Value("i", 0)) @@ -95,7 +99,10 @@ def cleanup(self) -> None: self._done.value += 1 for repeat in range(500): if repeat >= 1: - warnings.warn(f"After {repeat} cleanup, the queue is still not empty.", category=RuntimeWarning) + warnings.warn( + f"After {repeat} cleanup, the queue is still not empty.", + category=RuntimeWarning, + ) while not self._queue.empty(): try: self._queue.get(block=False) @@ -108,7 +115,9 @@ def cleanup(self) -> None: time.sleep(1.0) if self._queue.empty(): break - _logger.debug(f"Remaining items in queue collection done. Empty: {self._queue.empty()}") + _logger.debug( + f"Remaining items in queue collection done. 
Empty: {self._queue.empty()}" + ) def get(self, block: bool = True) -> Any: if not hasattr(self, "_first_get"): @@ -166,7 +175,10 @@ def _consumer(self) -> Generator[Any, None, None]: def _producer(self) -> None: # pytorch dataloader is used here only because we need its sampler and multi-processing - from torch.utils.data import DataLoader, Dataset # pylint: disable=import-outside-toplevel + from torch.utils.data import ( + DataLoader, + Dataset, + ) # pylint: disable=import-outside-toplevel try: dataloader = DataLoader( diff --git a/qlib/rl/utils/env_wrapper.py b/qlib/rl/utils/env_wrapper.py index e863b709a1..a0e593bef4 100644 --- a/qlib/rl/utils/env_wrapper.py +++ b/qlib/rl/utils/env_wrapper.py @@ -4,13 +4,28 @@ from __future__ import annotations import weakref -from typing import Any, Callable, cast, Dict, Generic, Iterable, Iterator, Optional, Tuple - -import gym -from gym import Space +from typing import ( + Any, + Callable, + cast, + Dict, + Generic, + Iterable, + Iterator, + Optional, + Tuple, +) + +import gymnasium as gymnasium +from gymnasium import Space from qlib.rl.aux_info import AuxiliaryInfoCollector -from qlib.rl.interpreter import ActionInterpreter, ObsType, PolicyActType, StateInterpreter +from qlib.rl.interpreter import ( + ActionInterpreter, + ObsType, + PolicyActType, + StateInterpreter, +) from qlib.rl.reward import Reward from qlib.rl.simulator import ActType, InitialStateType, Simulator, StateType from qlib.typehint import TypedDict @@ -49,10 +64,10 @@ class EnvWrapperStatus(TypedDict): class EnvWrapper( - gym.Env[ObsType, PolicyActType], + gymnasium.Env[ObsType, PolicyActType], Generic[InitialStateType, StateType, ActType, ObsType, PolicyActType], ): - """Qlib-based RL environment, subclassing ``gym.Env``. + """Qlib-based RL environment, subclassing ``gymnasium.Env``. A wrapper of components, including simulator, state-interpreter, action-interpreter, reward. This is what the framework of simulator - interpreter - policy looks like in RL training. @@ -114,7 +129,12 @@ def __init__( # 3. Avoid circular reference. # 4. When the components get serialized, we can throw away the env without any burden. # (though this part is not implemented yet) - for obj in [state_interpreter, action_interpreter, reward_fn, aux_info_collector]: + for obj in [ + state_interpreter, + action_interpreter, + reward_fn, + aux_info_collector, + ]: if obj is not None: obj.env = weakref.proxy(self) # type: ignore @@ -151,7 +171,9 @@ def reset(self, **kwargs: Any) -> ObsType: try: if self.seed_iterator is None: - raise RuntimeError("You can trying to get a state from a dead environment wrapper.") + raise RuntimeError( + "You can trying to get a state from a dead environment wrapper." 
+ ) # TODO: simulator/observation might need seed to prefetch something # as only seed has the ability to do the work beforehands @@ -166,7 +188,9 @@ def reset(self, **kwargs: Any) -> ObsType: initial_state = None self.simulator = cast(Callable[[], Simulator], self.simulator_fn)() else: - initial_state = next(cast(Iterator[InitialStateType], self.seed_iterator)) + initial_state = next( + cast(Iterator[InitialStateType], self.seed_iterator) + ) self.simulator = self.simulator_fn(initial_state) self.status = EnvWrapperStatus( @@ -192,14 +216,18 @@ def reset(self, **kwargs: Any) -> ObsType: self.seed_iterator = None return generate_nan_observation(self.observation_space) - def step(self, policy_action: PolicyActType, **kwargs: Any) -> Tuple[ObsType, float, bool, InfoDict]: + def step( + self, policy_action: PolicyActType, **kwargs: Any + ) -> Tuple[ObsType, float, bool, InfoDict]: """Environment step. See the code along with comments to get a sequence of things happening here. """ if self.seed_iterator is None: - raise RuntimeError("State queue is already exhausted, but the environment is still receiving action.") + raise RuntimeError( + "State queue is already exhausted, but the environment is still receiving action." + ) # Clear the logged information from last step self.logger.reset() diff --git a/qlib/rl/utils/finite_env.py b/qlib/rl/utils/finite_env.py index 87f0900e16..8cc9522787 100644 --- a/qlib/rl/utils/finite_env.py +++ b/qlib/rl/utils/finite_env.py @@ -11,9 +11,21 @@ import copy import warnings from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union, cast - -import gym +from typing import ( + Any, + Callable, + Dict, + Generator, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +import gymnasium as gymnasium import numpy as np from tianshou.env import BaseVectorEnv, DummyVectorEnv, ShmemVectorEnv, SubprocVectorEnv @@ -69,7 +81,7 @@ def is_invalid(arr: int | float | bool | T) -> bool: return True -def generate_nan_observation(obs_space: gym.Space) -> Any: +def generate_nan_observation(obs_space: gymnasium.Space) -> Any: """The NaN observation that indicates the environment receives no seed. We assume that obs is complex and there must be something like float. 
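To sanity-check the ``gym`` to ``gymnasium`` migration above, here is a minimal sketch (not part of the patch; the Box space is purely illustrative) showing that ``generate_nan_observation`` accepts a ``gymnasium`` space just as it previously accepted a ``gym`` one: ::

    import numpy as np
    import gymnasium

    from qlib.rl.utils.finite_env import generate_nan_observation

    # Any gymnasium.Space should be accepted; Box is just an example.
    space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
    nan_obs = generate_nan_observation(space)  # sentinel observation meaning "no seed left"
    print(type(nan_obs))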
@@ -123,7 +135,10 @@ class FiniteVectorEnv(BaseVectorEnv):
     _logger: list[LogWriter]
 
     def __init__(
-        self, logger: LogWriter | list[LogWriter] | None, env_fns: list[Callable[..., gym.Env]], **kwargs: Any
+        self,
+        logger: LogWriter | list[LogWriter] | None,
+        env_fns: list[Callable[..., gymnasium.Env]],
+        **kwargs: Any,
     ) -> None:
         super().__init__(env_fns, **kwargs)
 
@@ -311,7 +330,7 @@ class FiniteShmemVectorEnv(FiniteVectorEnv, ShmemVectorEnv):
 
 
 def vectorize_env(
-    env_factory: Callable[..., gym.Env],
+    env_factory: Callable[..., gymnasium.Env],
     env_type: FiniteEnvType,
     concurrency: int,
     logger: LogWriter | List[LogWriter],
@@ -320,11 +339,11 @@ def vectorize_env(
     For example, once you wrote: ::
 
-        DummyVectorEnv([lambda: gym.make(task) for _ in range(env_num)])
+        DummyVectorEnv([lambda: gymnasium.make(task) for _ in range(env_num)])
 
     Now you can replace it with: ::
 
-        finite_env_factory(lambda: gym.make(task), "dummy", env_num, my_logger)
+        finite_env_factory(lambda: gymnasium.make(task), "dummy", env_num, my_logger)
 
     By doing such replacement, you have two additional features enabled (compared to normal VectorEnv):
 
@@ -335,7 +354,7 @@ def vectorize_env(
     Parameters
     ----------
     env_factory
-        Callable to instantiate one single ``gym.Env``.
+        Callable to instantiate one single ``gymnasium.Env``.
         All concurrent workers will have the same ``env_factory``.
     env_type
         dummy or subproc or shmem. Corresponding to
diff --git a/qlib/rl/utils/log.py b/qlib/rl/utils/log.py
index 75aab20688..65598c3c81 100644
--- a/qlib/rl/utils/log.py
+++ b/qlib/rl/utils/log.py
@@ -21,7 +21,18 @@
 from collections import defaultdict
 from enum import IntEnum
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, List, Sequence, Set, Tuple, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Sequence,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 import numpy as np
 import pandas as pd
@@ -32,7 +43,14 @@
 from .env_wrapper import InfoDict
 
-__all__ = ["LogCollector", "LogWriter", "LogLevel", "LogBuffer", "ConsoleWriter", "CsvWriter"]
+__all__ = [
+    "LogCollector",
+    "LogWriter",
+    "LogLevel",
+    "LogBuffer",
+    "ConsoleWriter",
+    "CsvWriter",
+]
 
 ObsType = TypeVar("ObsType")
 ActType = TypeVar("ActType")
@@ -77,10 +95,14 @@ def reset(self) -> None:
 
     def _add_metric(self, name: str, metric: Any, loglevel: int | LogLevel) -> None:
         if name in self._logged:
-            raise ValueError(f"A metric with {name} is already added. Please change a name or reset the log collector.")
+            raise ValueError(
+                f"A metric with {name} is already added. Please change a name or reset the log collector."
+ ) self._logged[name] = (int(loglevel), metric) - def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def add_string( + self, name: str, string: str, loglevel: int | LogLevel = LogLevel.PERIODIC + ) -> None: """Add a string with name into logged contents.""" if loglevel < self._min_loglevel: return @@ -88,7 +110,9 @@ def add_string(self, name: str, string: str, loglevel: int | LogLevel = LogLevel raise TypeError(f"{string} is not a string.") self._add_metric(name, string, loglevel) - def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def add_scalar( + self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel.PERIODIC + ) -> None: """Add a scalar with name into logged contents. Scalar will be converted into a float. """ @@ -99,7 +123,9 @@ def add_scalar(self, name: str, scalar: Any, loglevel: int | LogLevel = LogLevel # could be single-item number scalar = scalar.item() if not isinstance(scalar, (float, int)): - raise TypeError(f"{scalar} is not and can not be converted into float or integer.") + raise TypeError( + f"{scalar} is not and can not be converted into float or integer." + ) scalar = float(scalar) self._add_metric(name, scalar, loglevel) @@ -117,7 +143,9 @@ def add_array( raise TypeError(f"{array} is not one of ndarray, DataFrame and Series.") self._add_metric(name, array, loglevel) - def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def add_any( + self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIODIC + ) -> None: """Log something with any type. As it's an "any" object, the only LogWriter accepting it is pickle. @@ -131,7 +159,10 @@ def add_any(self, name: str, obj: Any, loglevel: int | LogLevel = LogLevel.PERIO self._add_metric(name, obj, loglevel) def logs(self) -> Dict[str, np.ndarray]: - return {key: np.asanyarray(value, dtype="object") for key, value in self._logged.items()} + return { + key: np.asanyarray(value, dtype="object") + for key, value in self._logged.items() + } class LogWriter(Generic[ObsType, ActType]): @@ -233,7 +264,9 @@ def aggregation(array: Sequence[Any], name: str | None = None) -> Any: else: return array[0] - def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: + def log_episode( + self, length: int, rewards: List[float], contents: List[Dict[str, Any]] + ) -> None: """This is triggered at the end of each trajectory. Parameters @@ -257,7 +290,9 @@ def log_step(self, reward: float, contents: Dict[str, Any]) -> None: Logged contents for this step. 
""" - def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict) -> None: + def on_env_step( + self, env_id: int, obs: ObsType, rew: float, done: bool, info: InfoDict + ) -> None: """Callback for finite env, on each step.""" # Update counter @@ -272,7 +307,9 @@ def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: I values: Dict[str, Any] = {} for key, (loglevel, value) in info["log"].items(): - if loglevel >= self.loglevel: # FIXME: this is actually incorrect (see last FIXME) + if ( + loglevel >= self.loglevel + ): # FIXME: this is actually incorrect (see last FIXME) values[key] = value self.episode_logs[env_id].append(values) @@ -283,7 +320,11 @@ def on_env_step(self, env_id: int, obs: ObsType, rew: float, done: bool, info: I self.global_episode += 1 self.episode_count += 1 - self.log_episode(self.episode_lengths[env_id], self.episode_rewards[env_id], self.episode_logs[env_id]) + self.log_episode( + self.episode_lengths[env_id], + self.episode_rewards[env_id], + self.episode_logs[env_id], + ) def on_env_reset(self, env_id: int, _: ObsType) -> None: """Callback for finite env. @@ -328,7 +369,11 @@ class LogBuffer(LogWriter): # FIXME: needs a metric count - def __init__(self, callback: Callable[[bool, bool, LogBuffer], None], loglevel: int | LogLevel = LogLevel.PERIODIC): + def __init__( + self, + callback: Callable[[bool, bool, LogBuffer], None], + loglevel: int | LogLevel = LogLevel.PERIODIC, + ): super().__init__(loglevel) self.callback = callback @@ -349,7 +394,9 @@ def clear(self): self._latest_metrics: dict[str, float] | None = None self._aggregated_metrics: dict[str, float] = defaultdict(float) - def log_episode(self, length: int, rewards: list[float], contents: list[dict[str, Any]]) -> None: + def log_episode( + self, length: int, rewards: list[float], contents: list[dict[str, Any]] + ) -> None: # FIXME Dup of ConsoleWriter episode_wise_contents: dict[str, list] = defaultdict(list) for step_contents in contents: @@ -379,7 +426,10 @@ def episode_metrics(self) -> dict[str, float]: def collect_metrics(self) -> dict[str, float]: """Retrieve the aggregated metrics of the latest collect.""" - return {name: value / self.episode_count for name, value in self._aggregated_metrics.items()} + return { + name: value / self.episode_count + for name, value in self._aggregated_metrics.items() + } class ConsoleWriter(LogWriter): @@ -422,7 +472,9 @@ def clear(self) -> None: self.metric_counts: Dict[str, int] = defaultdict(int) self.metric_sums: Dict[str, float] = defaultdict(float) - def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: + def log_episode( + self, length: int, rewards: List[float], contents: List[Dict[str, Any]] + ) -> None: # Aggregate step-wise to episode-wise episode_wise_contents: Dict[str, list] = defaultdict(list) @@ -441,7 +493,10 @@ def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str self.metric_counts[name] += 1 self.metric_sums[name] += value - if self.episode_count % self.log_every_n_episode == 0 or self.episode_count == self.total_episodes: + if ( + self.episode_count % self.log_every_n_episode == 0 + or self.episode_count == self.total_episodes + ): # Only log periodically or at the end self.console_logger.info(self.generate_log_message(logs)) @@ -453,14 +508,20 @@ def generate_log_message(self, logs: Dict[str, float]) -> str: if self.total_episodes is None: msg_prefix += "[Step {" + self.counter_format + "}]" else: - msg_prefix += "[{" + 
self.counter_format + "}/" + str(self.total_episodes) + "]" + msg_prefix += ( + "[{" + self.counter_format + "}/" + str(self.total_episodes) + "]" + ) msg_prefix = msg_prefix.format(self.episode_count) msg = "" for name, value in logs.items(): # Double-space as delimiter - format_template = r" {} {" + self.float_format + "} ({" + self.float_format + "})" - msg += format_template.format(name, value, self.metric_sums[name] / self.metric_counts[name]) + format_template = ( + r" {} {" + self.float_format + "} ({" + self.float_format + "})" + ) + msg += format_template.format( + name, value, self.metric_sums[name] / self.metric_counts[name] + ) msg = msg_prefix + " " + msg @@ -479,7 +540,9 @@ class CsvWriter(LogWriter): # FIXME: save & reload - def __init__(self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC) -> None: + def __init__( + self, output_dir: Path, loglevel: int | LogLevel = LogLevel.PERIODIC + ) -> None: super().__init__(loglevel) self.output_dir = output_dir self.output_dir.mkdir(exist_ok=True) @@ -488,7 +551,9 @@ def clear(self) -> None: super().clear() self.all_records = [] - def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str, Any]]) -> None: + def log_episode( + self, length: int, rewards: List[float], contents: List[Dict[str, Any]] + ) -> None: # FIXME Same as ConsoleLogger, needs a refactor to eliminate code-dup episode_wise_contents: Dict[str, list] = defaultdict(list) @@ -505,7 +570,9 @@ def log_episode(self, length: int, rewards: List[float], contents: List[Dict[str def on_env_all_done(self) -> None: # FIXME: this is temporary - pd.DataFrame.from_records(self.all_records).to_csv(self.output_dir / "result.csv", index=False) + pd.DataFrame.from_records(self.all_records).to_csv( + self.output_dir / "result.csv", index=False + ) # The following are not implemented yet. diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index a9e138fdbb..c0b1a449d6 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -13,7 +13,11 @@ from typing import Tuple from ..backtest.decision import BaseTradeDecision -from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager +from ..backtest.utils import ( + CommonInfrastructure, + LevelInfrastructure, + TradeCalendarManager, +) from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config @@ -56,7 +60,11 @@ def __init__( - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. 
""" - self._reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) + self._reset( + level_infra=level_infra, + common_infra=common_infra, + outer_trade_decision=outer_trade_decision, + ) self._trade_exchange = trade_exchange @property @@ -74,7 +82,9 @@ def trade_position(self) -> BasePosition: @property def trade_exchange(self) -> Exchange: """get trade exchange in a prioritized order""" - return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange") + return getattr(self, "_trade_exchange", None) or self.common_infra.get( + "trade_exchange" + ) def reset_level_infra(self, level_infra: LevelInfrastructure) -> None: if not hasattr(self, "level_infra"): @@ -202,7 +212,9 @@ def update_trade_decision( # default to return None, which indicates that the trade decision is not changed return None - def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision: + def alter_outer_trade_decision( + self, outer_trade_decision: BaseTradeDecision + ) -> BaseTradeDecision: """ A method for updating the outer_trade_decision. The outer strategy may change its decision during updating. @@ -254,7 +266,9 @@ def __init__( policy : RL policy for generate action """ - super(RLStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) + super(RLStrategy, self).__init__( + outer_trade_decision, level_infra, common_infra, **kwargs + ) self.policy = policy @@ -283,14 +297,22 @@ def __init__( end_time : Union[str, pd.Timestamp], optional end time of trading, by default None """ - super(RLIntStrategy, self).__init__(policy, outer_trade_decision, level_infra, common_infra, **kwargs) + super(RLIntStrategy, self).__init__( + policy, outer_trade_decision, level_infra, common_infra, **kwargs + ) self.policy = policy - self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter) - self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter) + self.state_interpreter = init_instance_by_config( + state_interpreter, accept_types=StateInterpreter + ) + self.action_interpreter = init_instance_by_config( + action_interpreter, accept_types=ActionInterpreter + ) def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision: - _interpret_state = self.state_interpreter.interpret(execute_result=execute_result) + _interpret_state = self.state_interpreter.interpret( + execute_result=execute_result + ) _action = self.policy.step(_interpret_state) _trade_decision = self.action_interpreter.interpret(action=_action) return _trade_decision diff --git a/qlib/tests/__init__.py b/qlib/tests/__init__.py index f9793cdabd..9030f4e752 100644 --- a/qlib/tests/__init__.py +++ b/qlib/tests/__init__.py @@ -10,7 +10,14 @@ from qlib.data.filter import NameDFilter from qlib.data import D from qlib.data.data import Cal, DatasetD -from qlib.data.storage import CalendarStorage, InstrumentStorage, FeatureStorage, CalVT, InstKT, InstVT +from qlib.data.storage import ( + CalendarStorage, + InstrumentStorage, + FeatureStorage, + CalVT, + InstKT, + InstVT, +) class TestAutoData(unittest.TestCase): @@ -206,14 +213,17 @@ def __len__(self) -> int: class MockFeatureStorage(MockStorageBase, FeatureStorage): def __init__(self, instrument: str, field: str, freq: str, db_region: str = None, **kwargs): # type: ignore - super().__init__(instrument=instrument, field=field, freq=freq, db_region=db_region, **kwargs) + super().__init__( + 
instrument=instrument, field=field, freq=freq, db_region=db_region, **kwargs + ) self.field = field calendar = sorted(self.df["datetime"].unique()) df_calendar = pd.DataFrame(calendar, columns=["datetime"]).set_index("datetime") df = self.df[self.df["symbol"] == instrument] data_dt_field = "datetime" cal_df = df_calendar[ - (df_calendar.index >= df[data_dt_field].min()) & (df_calendar.index <= df[data_dt_field].max()) + (df_calendar.index >= df[data_dt_field].min()) + & (df_calendar.index <= df[data_dt_field].max()) ] df = df.set_index(data_dt_field) df_data = df.reindex(cal_df.index) @@ -269,21 +279,36 @@ class TestMockData(unittest.TestCase): "calendar_provider": { "class": "LocalCalendarProvider", "module_path": "qlib.data.data", - "kwargs": {"backend": {"class": "MockCalendarStorage", "module_path": "qlib.tests"}}, + "kwargs": { + "backend": {"class": "MockCalendarStorage", "module_path": "qlib.tests"} + }, }, "instrument_provider": { "class": "LocalInstrumentProvider", "module_path": "qlib.data.data", - "kwargs": {"backend": {"class": "MockInstrumentStorage", "module_path": "qlib.tests"}}, + "kwargs": { + "backend": { + "class": "MockInstrumentStorage", + "module_path": "qlib.tests", + } + }, }, "feature_provider": { "class": "LocalFeatureProvider", "module_path": "qlib.data.data", - "kwargs": {"backend": {"class": "MockFeatureStorage", "module_path": "qlib.tests"}}, + "kwargs": { + "backend": {"class": "MockFeatureStorage", "module_path": "qlib.tests"} + }, }, } @classmethod def setUpClass(cls) -> None: provider_uri = "Not necessary." - init(region=REG_TW, provider_uri=provider_uri, expression_cache=None, dataset_cache=None, **cls._setup_kwargs) + init( + region=REG_TW, + provider_uri=provider_uri, + expression_cache=None, + dataset_cache=None, + **cls._setup_kwargs, + ) diff --git a/qlib/tests/config.py b/qlib/tests/config.py index ea1b236594..bfb4641c57 100644 --- a/qlib/tests/config.py +++ b/qlib/tests/config.py @@ -98,7 +98,9 @@ def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKE } -def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): +def get_record_lgb_config( + dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET} +): return { "model": { "class": "LGBModel", @@ -109,7 +111,9 @@ def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI3 } -def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}): +def get_record_xgboost_config( + dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET} +): return { "model": { "class": "XGBModel", @@ -120,11 +124,17 @@ def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": } -CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET}) +CSI300_DATASET_CONFIG = get_dataset_config( + handler_kwargs={"instruments": CSI300_MARKET} +) CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET}) -CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET}) -CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET}) +CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config( + handler_kwargs={"instruments": CSI100_MARKET} +) +CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config( + handler_kwargs={"instruments": CSI100_MARKET} +) # use for rolling_online_managment.py ROLLING_HANDLER_CONFIG = { diff --git a/qlib/tests/data.py 
b/qlib/tests/data.py index 2fa76855b5..3a4dea9b01 100644 --- a/qlib/tests/data.py +++ b/qlib/tests/data.py @@ -39,7 +39,11 @@ def merge_remote_url(self, file_name: str): The file name can be accompanied by a version number, (e.g.: v2/qlib_data_simple_cn_1d_latest.zip), if no version number is attached, it will be downloaded from v0 by default. """ - return f"{self.REMOTE_URL}/{file_name}" if "/" in file_name else f"{self.REMOTE_URL}/v0/{file_name}" + return ( + f"{self.REMOTE_URL}/{file_name}" + if "/" in file_name + else f"{self.REMOTE_URL}/v0/{file_name}" + ) def download(self, url: str, target_path: [Path, str]): """ @@ -69,7 +73,9 @@ def download(self, url: str, target_path: [Path, str]): fp.write(chunk) p_bar.update(chunk_size) - def download_data(self, file_name: str, target_dir: [Path, str], delete_old: bool = True): + def download_data( + self, file_name: str, target_dir: [Path, str], delete_old: bool = True + ): """ Download the specified file to the target folder. @@ -98,7 +104,11 @@ def download_data(self, file_name: str, target_dir: [Path, str], delete_old: boo target_dir = Path(target_dir).expanduser() target_dir.mkdir(exist_ok=True, parents=True) # saved file name - _target_file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "_" + os.path.basename(file_name) + _target_file_name = ( + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + + "_" + + os.path.basename(file_name) + ) target_path = target_dir.joinpath(_target_file_name) url = self.merge_remote_url(file_name) @@ -117,7 +127,9 @@ def check_dataset(self, file_name: str): return status @staticmethod - def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True): + def _unzip( + file_path: [Path, str], target_dir: [Path, str], delete_old: bool = True + ): file_path = Path(file_path) target_dir = Path(target_dir) if delete_old: @@ -133,7 +145,13 @@ def _unzip(file_path: [Path, str], target_dir: [Path, str], delete_old: bool = T @staticmethod def _delete_qlib_data(file_dir: Path): rm_dirs = [] - for _name in ["features", "calendars", "instruments", "features_cache", "dataset_cache"]: + for _name in [ + "features", + "calendars", + "instruments", + "features_cache", + "dataset_cache", + ]: _p = file_dir.joinpath(_name) if _p.exists(): rm_dirs.append(str(_p.resolve())) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 2a94ebd555..b38094c8c9 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -150,7 +150,9 @@ def read_period_data( # find the first index of linked revisions if last_period_index is None: with open(index_path, "rb") as fi: - (first_year,) = struct.unpack(PERIOD_DTYPE, fi.read(struct.calcsize(PERIOD_DTYPE))) + (first_year,) = struct.unpack( + PERIOD_DTYPE, fi.read(struct.calcsize(PERIOD_DTYPE)) + ) all_periods = np.fromfile(fi, dtype=INDEX_DTYPE) offset = get_period_offset(first_year, period, quarterly) _next = all_periods[offset] @@ -164,7 +166,9 @@ def read_period_data( with open(data_path, "rb") as fd: while _next != NAN_INDEX: fd.seek(_next) - date, period, value, new_next = struct.unpack(DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE))) + date, period, value, new_next = struct.unpack( + DATA_DTYPE, fd.read(struct.calcsize(DATA_DTYPE)) + ) if date > cur_date_int: break prev_next = _next @@ -380,7 +384,9 @@ def is_tradable_date(cur_date): """ from ..data import D # pylint: disable=C0415 - return str(cur_date.date()) == str(D.calendar(start_time=cur_date, future=True)[0].date()) + return str(cur_date.date()) == str( + 
D.calendar(start_time=cur_date, future=True)[0].date() + ) def get_date_range(trading_date, left_shift=0, right_shift=0, future=False): @@ -444,7 +450,9 @@ def get_date_by_shift( if clip_shift: shift_index = np.clip(shift_index, 0, len(cal) - 1) else: - raise IndexError(f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range") + raise IndexError( + f"The shift_index({shift_index}) of the trading day ({trading_date}) is out of range" + ) return cal[shift_index] @@ -481,7 +489,11 @@ def transform_end_date(end_date=None, freq="day"): from ..data import D # pylint: disable=C0415 last_date = D.calendar(freq=freq)[-1] - if end_date is None or (str(end_date) == "-1") or (pd.Timestamp(last_date) < pd.Timestamp(end_date)): + if ( + end_date is None + or (str(end_date) == "-1") + or (pd.Timestamp(last_date) < pd.Timestamp(end_date)) + ): log.warning( "\nInfo: the end_date in the configuration file is {}, " "so the default last date {} is used.".format(end_date, last_date) @@ -595,7 +607,9 @@ def exists_qlib_data(qlib_dir): return False # check instruments - code_names = set(map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir())) + code_names = set( + map(lambda x: fname_to_code(x.name.lower()), features_dir.iterdir()) + ) _instrument = instruments_dir.joinpath("all.txt") # Removed two possible ticker names "NA" and "NULL" from the default na_values list for column 0 miss_code = set( @@ -865,7 +879,9 @@ def register(self, provider): self._provider = provider def __repr__(self): - return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider) + return "{name}(provider={provider})".format( + name=self.__class__.__name__, provider=self._provider + ) def __getattr__(self, key): if self.__dict__.get("_provider", None) is None: diff --git a/qlib/utils/file.py b/qlib/utils/file.py index 1e17a574a9..8c8c8fc22a 100644 --- a/qlib/utils/file.py +++ b/qlib/utils/file.py @@ -25,7 +25,9 @@ def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False): if path: if return_dir and not os.path.exists(path): os.makedirs(path) - elif not return_dir: # return a file, thus we need to create its parent directory + elif ( + not return_dir + ): # return a file, thus we need to create its parent directory xpath = os.path.abspath(os.path.join(path, "..")) if not os.path.exists(xpath): os.makedirs(xpath) @@ -74,7 +76,9 @@ def save_multiple_parts_file(filename, format="gztar"): # Create model dir if os.path.exists(file_path): - raise FileExistsError("ERROR: file exists: {}, cannot be create the directory.".format(file_path)) + raise FileExistsError( + "ERROR: file exists: {}, cannot be create the directory.".format(file_path) + ) os.makedirs(file_path) @@ -185,6 +189,8 @@ def get_io_object(file: Union[IO, str, Path], *args, **kwargs) -> IO: if isinstance(file, str): file = Path(file) if not isinstance(file, Path): - raise NotImplementedError(f"This type[{type(file)}] of input is not supported") + raise NotImplementedError( + f"This type[{type(file)}] of input is not supported" + ) with file.open(*args, **kwargs) as f: yield f diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index c707240d09..783b838c61 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -54,7 +54,9 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData: raise ValueError(f"axis must be 0 or 1") -def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> SingleData: +def sum_by_index( + data_list: 
Union[SingleData], new_index: list, fill_value=0 +) -> SingleData: """concat all SingleData by new index. Parameters @@ -98,7 +100,9 @@ class Index: """ def __init__(self, idx_list: Union[List, pd.Index, "Index", int]): - self.idx_list: np.ndarray = None # using array type for index list will make things easier + self.idx_list: np.ndarray = ( + None # using array type for index list will make things easier + ) if isinstance(idx_list, Index): # Fast read-only copy self.idx_list = idx_list.idx_list @@ -112,8 +116,12 @@ def __init__(self, idx_list: Union[List, pd.Index, "Index", int]): if not all(isinstance(x, type(idx_list[0])) for x in idx_list): raise TypeError("All elements in idx_list must be of the same type") # Check if all elements in idx_list are of the same datetime64 precision - if isinstance(idx_list[0], np.datetime64) and not all(x.dtype == idx_list[0].dtype for x in idx_list): - raise TypeError("All elements in idx_list must be of the same datetime64 precision") + if isinstance(idx_list[0], np.datetime64) and not all( + x.dtype == idx_list[0].dtype for x in idx_list + ): + raise TypeError( + "All elements in idx_list must be of the same datetime64 precision" + ) self.idx_list = np.array(idx_list) # NOTE: only the first appearance is indexed self.index_map = dict(zip(self.idx_list, range(len(self)))) @@ -212,14 +220,18 @@ class LocIndexer: Modifications will results in new Index. """ - def __init__(self, index_data: "IndexData", indices: List[Index], int_loc: bool = False): + def __init__( + self, index_data: "IndexData", indices: List[Index], int_loc: bool = False + ): self._indices: List[Index] = indices self._bind_id = index_data # bind index data self._int_loc = int_loc assert self._bind_id.data.ndim == len(self._indices) @staticmethod - def proc_idx_l(indices: List[Union[List, pd.Index, Index]], data_shape: Tuple = None) -> List[Index]: + def proc_idx_l( + indices: List[Union[List, pd.Index, Index]], data_shape: Tuple = None + ) -> List[Index]: """process the indices from user and output a list of `Index`""" res = [] for i, idx in enumerate(indices): @@ -243,8 +255,16 @@ def _slc_convert(self, index: Index, indexing: slice) -> slice: the integer based slicing """ if index.is_sorted(): - int_start = None if indexing.start is None else bisect.bisect_left(index, indexing.start) - int_stop = None if indexing.stop is None else bisect.bisect_right(index, indexing.stop) + int_start = ( + None + if indexing.start is None + else bisect.bisect_left(index, indexing.start) + ) + int_stop = ( + None + if indexing.stop is None + else bisect.bisect_right(index, indexing.stop) + ) else: int_start = None if indexing.start is None else index.index(indexing.start) int_stop = None if indexing.stop is None else index.index(indexing.stop) + 1 @@ -275,7 +295,9 @@ def __getitem__(self, indexing): for dim, index in enumerate(self._indices): if dim < len(indexing): _indexing = indexing[dim] - if not self._int_loc: # type converting is only necessary when it is not `iloc` + if ( + not self._int_loc + ): # type converting is only necessary when it is not `iloc` if isinstance(_indexing, slice): _indexing = self._slc_convert(index, _indexing) elif isinstance(_indexing, (IndexData, np.ndarray)): @@ -283,7 +305,9 @@ def __getitem__(self, indexing): _indexing = _indexing.data assert _indexing.ndim == 1 if _indexing.dtype != bool: - _indexing = np.array(list(index.index(i) for i in _indexing)) + _indexing = np.array( + list(index.index(i) for i in _indexing) + ) else: _indexing = index.index(_indexing) else: 
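The datetime64 precision check reformatted above is easiest to see with a toy example; the dates below are made up and the snippet only assumes that qlib is importable: ::

    import numpy as np

    from qlib.utils.index_data import Index

    # Same (nanosecond) precision: accepted.
    Index([np.datetime64("2024-01-01", "ns"), np.datetime64("2024-01-02", "ns")])

    # Mixed precisions (ns vs. s): rejected with the TypeError shown in the hunk above.
    try:
        Index([np.datetime64("2024-01-01", "ns"), np.datetime64("2024-01-02", "s")])
    except TypeError as err:
        print(err)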
@@ -297,7 +321,9 @@ def __getitem__(self, indexing): if new_data.ndim == 0: return new_data # otherwise we go on to the index part - new_indices = [idx[indexing] for idx, indexing in zip(self._indices, int_indexing)] + new_indices = [ + idx[indexing] for idx, indexing in zip(self._indices, int_indexing) + ] # 3) squash dimensions new_indices = [ @@ -329,7 +355,9 @@ def __call__(self, other): return self.obj.__class__(self_data_method(other), *self.obj.indices) elif isinstance(other, self.obj.__class__): other_aligned = self.obj._align_indices(other) - return self.obj.__class__(self_data_method(other_aligned.data), *self.obj.indices) + return self.obj.__class__( + self_data_method(other_aligned.data), *self.obj.indices + ) else: return NotImplemented @@ -338,7 +366,16 @@ def index_data_ops_creator(*args, **kwargs): """ meta class for auto generating operations for index data. """ - for method_name in ["__add__", "__sub__", "__rsub__", "__mul__", "__truediv__", "__eq__", "__gt__", "__lt__"]: + for method_name in [ + "__add__", + "__sub__", + "__rsub__", + "__mul__", + "__truediv__", + "__eq__", + "__gt__", + "__lt__", + ]: args[2][method_name] = BinaryOps(method_name=method_name) return type(*args) @@ -472,7 +509,9 @@ def __len__(self): return len(self.data) def sum(self, axis=None, dtype=None, out=None): - assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" + assert ( + out is None and dtype is None + ), "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nansum(self.data) @@ -486,7 +525,9 @@ def sum(self, axis=None, dtype=None, out=None): raise ValueError(f"axis must be None, 0 or 1") def mean(self, axis=None, dtype=None, out=None): - assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" + assert ( + out is None and dtype is None + ), "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nanmean(self.data) @@ -528,7 +569,9 @@ def values(self): class SingleData(IndexData): def __init__( - self, data: Union[int, float, np.number, list, dict, pd.Series] = [], index: Union[List, pd.Index, Index] = [] + self, + data: Union[int, float, np.number, list, dict, pd.Series] = [], + index: Union[List, pd.Index, Index] = [], ): """A data structure of index and numpy data. It's used to replace pd.Series due to high-speed. 
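Since ``SingleData`` is positioned above as a light replacement for ``pd.Series``, a rough usage sketch may help (values are illustrative and the exact return types are simplified): ::

    from qlib.utils.index_data import SingleData

    sd = SingleData([1.0, 2.0, 3.0], index=["a", "b", "c"])
    print(sd.loc["b"])     # label-based access goes through the LocIndexer shown earlier
    print(sd.iloc[0:2])    # positional slicing skips the label-to-integer conversion
    print((sd + 1).sum())  # arithmetic is wired up via the auto-generated BinaryOps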
@@ -651,4 +694,8 @@ def _align_indices(self, other): ) def __repr__(self) -> str: - return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist())) + return str( + pd.DataFrame( + self.data, index=self.index.tolist(), columns=self.columns.tolist() + ) + ) diff --git a/qlib/utils/mod.py b/qlib/utils/mod.py index 12fbc58703..28652ec77d 100644 --- a/qlib/utils/mod.py +++ b/qlib/utils/mod.py @@ -36,8 +36,14 @@ def get_module_by_module_path(module_path: Union[str, ModuleType]): module = module_path else: if module_path.endswith(".py"): - module_name = re.sub("^[^a-zA-Z_]+", "", re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_"))) - module_spec = importlib.util.spec_from_file_location(module_name, module_path) + module_name = re.sub( + "^[^a-zA-Z_]+", + "", + re.sub("[^0-9a-zA-Z_]", "", module_path[:-3].replace("/", "_")), + ) + module_spec = importlib.util.spec_from_file_location( + module_name, module_path + ) module = importlib.util.module_from_spec(module_spec) sys.modules[module_name] = module module_spec.loader.exec_module(module) @@ -64,7 +70,9 @@ def split_module_path(module_path: str) -> Tuple[str, str]: return m_path, cls -def get_callable_kwargs(config: InstConf, default_module: Union[str, ModuleType] = None) -> (type, dict): +def get_callable_kwargs( + config: InstConf, default_module: Union[str, ModuleType] = None +) -> (type, dict): """ extract class/func and kwargs from config info @@ -116,7 +124,9 @@ def get_callable_kwargs(config: InstConf, default_module: Union[str, ModuleType] return _callable, kwargs -get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the previous version +get_cls_kwargs = ( + get_callable_kwargs # NOTE: this is for compatibility for the previous version +) def init_instance_by_config( diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index a617783341..184c84fbd9 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -25,13 +25,22 @@ def __init__(self, *args, **kwargs): # 2025-05-04 joblib released version 1.5.0, in which _backend_args was removed and replaced by _backend_kwargs. # Ref: https://github.com/joblib/joblib/pull/1525/files#diff-e4dff8042ce45b443faf49605b75a58df35b8c195978d4a57f4afa695b406bdc if joblib.__version__ < "1.5.0": - self._backend_args["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101 + self._backend_args["maxtasksperchild"] = ( + maxtasksperchild # pylint: disable=E1101 + ) else: - self._backend_kwargs["maxtasksperchild"] = maxtasksperchild # pylint: disable=E1101 + self._backend_kwargs["maxtasksperchild"] = ( + maxtasksperchild # pylint: disable=E1101 + ) def datetime_groupby_apply( - df, apply_func: Union[Callable, Text], axis=0, level="datetime", resample_rule="ME", n_jobs=-1 + df, + apply_func: Union[Callable, Text], + axis=0, + level="datetime", + resample_rule="ME", + n_jobs=-1, ): """datetime_groupby_apply This function will apply the `apply_func` on the datetime level index. 
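For readers less familiar with the config convention handled by the ``qlib.utils.mod`` helpers touched above, the following hedged sketch shows both entry points; the class name and kwargs are only an example borrowed from the test configs earlier in this patch: ::

    from qlib.utils import init_instance_by_config
    from qlib.utils.mod import get_callable_kwargs

    config = {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
        "kwargs": {"loss": "mse"},
    }

    # Two-step form: resolve the class and its kwargs, then construct it yourself ...
    klass, kwargs = get_callable_kwargs(config)
    model = klass(**kwargs)

    # ... or the one-step helper used throughout qlib (e.g. in RLIntStrategy above).
    model = init_instance_by_config(config)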
@@ -57,12 +66,15 @@ def datetime_groupby_apply( def _naive_group_apply(df): if isinstance(apply_func, str): - return getattr(df.groupby(axis=axis, level=level, group_keys=False), apply_func)() + return getattr( + df.groupby(axis=axis, level=level, group_keys=False), apply_func + )() return df.groupby(level=level, group_keys=False).apply(apply_func) if n_jobs != 1: dfs = ParallelExt(n_jobs=n_jobs)( - delayed(_naive_group_apply)(sub_df) for idx, sub_df in df.resample(resample_rule, level=level) + delayed(_naive_group_apply)(sub_df) + for idx, sub_df in df.resample(resample_rule, level=level) ) return pd.concat(dfs, axis=axis).sort_index() else: diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 99aedfcd50..b5ac436330 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -10,7 +10,10 @@ def resam_calendar( - calendar_raw: np.ndarray, freq_raw: Union[str, Freq], freq_sam: Union[str, Freq], region: str = None + calendar_raw: np.ndarray, + freq_raw: Union[str, Freq], + freq_sam: Union[str, Freq], + region: str = None, ) -> np.ndarray: """ Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam @@ -43,16 +46,27 @@ def resam_calendar( # if freq_sam is xminute, divide each trading day into several bars evenly if freq_sam.base == Freq.NORM_FREQ_MINUTE: if freq_raw.base != Freq.NORM_FREQ_MINUTE: - raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") + raise ValueError( + "when sampling minute calendar, freq of raw calendar must be minute or min" + ) else: if freq_raw.count > freq_sam.count: raise ValueError("raw freq must be higher than sampling freq") - _calendar_minute = np.unique(list(map(lambda x: cal_sam_minute(x, freq_sam.count, region), calendar_raw))) + _calendar_minute = np.unique( + list(map(lambda x: cal_sam_minute(x, freq_sam.count, region), calendar_raw)) + ) return _calendar_minute # else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly else: - _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) + _calendar_day = np.unique( + list( + map( + lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), + calendar_raw, + ) + ) + ) if freq_sam.base == Freq.NORM_FREQ_DAY: return _calendar_day[:: freq_sam.count] @@ -69,7 +83,9 @@ def resam_calendar( raise ValueError("sampling freq must be xmin, xd, xw, xm") -def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1): +def get_higher_eq_freq_feature( + instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1 +): """get the feature with higher or equal frequency than `freq`. 
Returns ------- @@ -80,19 +96,42 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No from ..data.data import D # pylint: disable=C0415 try: - _result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache) + _result = D.features( + instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache + ) _freq = freq except (ValueError, KeyError) as value_key_e: _, norm_freq = Freq.parse(freq) if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]: try: - _result = D.features(instruments, fields, start_time, end_time, freq="day", disk_cache=disk_cache) + _result = D.features( + instruments, + fields, + start_time, + end_time, + freq="day", + disk_cache=disk_cache, + ) _freq = "day" except (ValueError, KeyError): - _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) + _result = D.features( + instruments, + fields, + start_time, + end_time, + freq="1min", + disk_cache=disk_cache, + ) _freq = "1min" elif norm_freq == Freq.NORM_FREQ_MINUTE: - _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) + _result = D.features( + instruments, + fields, + start_time, + end_time, + freq="1min", + disk_cache=disk_cache, + ) _freq = "1min" else: raise ValueError(f"freq {freq} is not supported") from value_key_e @@ -194,9 +233,13 @@ def resam_ts_data( if isinstance(feature.index, pd.MultiIndex): if callable(method): method_func = method - return feature.groupby(level="instrument", group_keys=False).apply(method_func, **method_kwargs) + return feature.groupby(level="instrument", group_keys=False).apply( + method_func, **method_kwargs + ) elif isinstance(method, str): - return getattr(feature.groupby(level="instrument", group_keys=False), method)(**method_kwargs) + return getattr( + feature.groupby(level="instrument", group_keys=False), method + )(**method_kwargs) else: if callable(method): method_func = method @@ -232,7 +275,9 @@ def _ts_data_valid(ts_feature, last=False): elif isinstance(ts_feature, pd.Series): return get_valid_value(ts_feature, last=last) else: - raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") + raise TypeError( + f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}" + ) ts_data_last = partial(_ts_data_valid, last=True) diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 720dbd7928..9b3879b9d4 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -33,7 +33,9 @@ class Serializable: def __init__(self): self._dump_all = self.default_dump_all - self._exclude = None # this attribute have higher priorities than `exclude_attr` + self._exclude = ( + None # this attribute have higher priorities than `exclude_attr` + ) def _is_kept(self, key): if key in self.config_attr: @@ -151,7 +153,9 @@ def load(cls, filepath): if isinstance(object, cls): return object else: - raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!") + raise TypeError( + f"The instance of {type(object)} is not a valid `{type(cls)}`!" 
+ ) @classmethod def get_backend(cls): diff --git a/qlib/utils/time.py b/qlib/utils/time.py index 238b3f0dd5..b7af3937c0 100644 --- a/qlib/utils/time.py +++ b/qlib/utils/time.py @@ -50,19 +50,23 @@ def get_min_cal(shift: int = 0, region: str = REG_CN) -> List[time]: if region == REG_CN: for ts in list( - pd.date_range(CN_TIME[0], CN_TIME[1] - timedelta(minutes=1), freq="1min") - pd.Timedelta(minutes=shift) + pd.date_range(CN_TIME[0], CN_TIME[1] - timedelta(minutes=1), freq="1min") + - pd.Timedelta(minutes=shift) ) + list( - pd.date_range(CN_TIME[2], CN_TIME[3] - timedelta(minutes=1), freq="1min") - pd.Timedelta(minutes=shift) + pd.date_range(CN_TIME[2], CN_TIME[3] - timedelta(minutes=1), freq="1min") + - pd.Timedelta(minutes=shift) ): cal.append(ts.time()) elif region == REG_TW: for ts in list( - pd.date_range(TW_TIME[0], TW_TIME[1] - timedelta(minutes=1), freq="1min") - pd.Timedelta(minutes=shift) + pd.date_range(TW_TIME[0], TW_TIME[1] - timedelta(minutes=1), freq="1min") + - pd.Timedelta(minutes=shift) ): cal.append(ts.time()) elif region == REG_US: for ts in list( - pd.date_range(US_TIME[0], US_TIME[1] - timedelta(minutes=1), freq="1min") - pd.Timedelta(minutes=shift) + pd.date_range(US_TIME[0], US_TIME[1] - timedelta(minutes=1), freq="1min") + - pd.Timedelta(minutes=shift) ): cal.append(ts.time()) else: @@ -108,15 +112,22 @@ def is_single_value(start_time, end_time, freq, region: str = REG_CN): return True return False else: - raise NotImplementedError(f"please implement the is_single_value func for {region}") + raise NotImplementedError( + f"please implement the is_single_value func for {region}" + ) class Freq: NORM_FREQ_MONTH = "month" NORM_FREQ_WEEK = "week" NORM_FREQ_DAY = "day" - NORM_FREQ_MINUTE = "min" # using min instead of minute for align with Qlib's data filename - SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE, NORM_FREQ_DAY] # FIXME: this list should from data + NORM_FREQ_MINUTE = ( + "min" # using min instead of minute for align with Qlib's data filename + ) + SUPPORT_CAL_LIST = [ + NORM_FREQ_MINUTE, + NORM_FREQ_DAY, + ] # FIXME: this list should from data def __init__(self, freq: Union[str, "Freq"]) -> None: if isinstance(freq, str): @@ -132,7 +143,9 @@ def __eq__(self, freq): def __str__(self): # trying to align to the filename of Qlib: day, 30min, 5min, 1min... 
- return f"{self.count if self.count != 1 or self.base != 'day' else ''}{self.base}" + return ( + f"{self.count if self.count != 1 or self.base != 'day' else ''}{self.base}" + ) def __repr__(self) -> str: return f"{self.__class__.__name__}({str(self)})" @@ -226,7 +239,9 @@ def get_min_delta(left_frq: str, right_freq: str): return left_minutes - right_minutes @staticmethod - def get_recent_freq(base_freq: Union[str, "Freq"], freq_list: List[Union[str, "Freq"]]) -> Optional["Freq"]: + def get_recent_freq( + base_freq: Union[str, "Freq"], freq_list: List[Union[str, "Freq"]] + ) -> Optional["Freq"]: """Get the closest freq to base_freq from freq_list Parameters @@ -265,22 +280,30 @@ def time_to_day_index(time_obj: Union[str, datetime], region: str = REG_CN): elif CN_TIME[2] <= time_obj < CN_TIME[3]: return int((time_obj - CN_TIME[2]).total_seconds() / 60) + 120 else: - raise ValueError(f"{time_obj} is not the opening time of the {region} stock market") + raise ValueError( + f"{time_obj} is not the opening time of the {region} stock market" + ) elif region == REG_US: if US_TIME[0] <= time_obj < US_TIME[1]: return int((time_obj - US_TIME[0]).total_seconds() / 60) else: - raise ValueError(f"{time_obj} is not the opening time of the {region} stock market") + raise ValueError( + f"{time_obj} is not the opening time of the {region} stock market" + ) elif region == REG_TW: if TW_TIME[0] <= time_obj < TW_TIME[1]: return int((time_obj - TW_TIME[0]).total_seconds() / 60) else: - raise ValueError(f"{time_obj} is not the opening time of the {region} stock market") + raise ValueError( + f"{time_obj} is not the opening time of the {region} stock market" + ) else: raise ValueError(f"{region} is not supported") -def get_day_min_idx_range(start: str, end: str, freq: str, region: str) -> Tuple[int, int]: +def get_day_min_idx_range( + start: str, end: str, freq: str, region: str +) -> Tuple[int, int]: """ get the min-bar index in a day for a time range (both left and right is closed) given a fixed frequency Parameters @@ -320,7 +343,9 @@ def concat_date_time(date_obj: date, time_obj: time) -> pd.Timestamp: ) -def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str = REG_CN) -> pd.Timestamp: +def cal_sam_minute( + x: pd.Timestamp, sam_minutes: int, region: str = REG_CN +) -> pd.Timestamp: """ align the minute-level data to a down sampled calendar @@ -346,7 +371,9 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str = REG_CN) -> p return concat_date_time(_date, new_time) -def epsilon_change(date_time: pd.Timestamp, direction: str = "backward") -> pd.Timestamp: +def epsilon_change( + date_time: pd.Timestamp, direction: str = "backward" +) -> pd.Timestamp: """ change the time by infinitely small quantity. 
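To make the ``Freq`` normalization above concrete, a small illustrative sketch follows; the commented outputs are what the ``__str__`` shown in the hunk implies, not verified results: ::

    from qlib.utils.time import Freq

    print(str(Freq("day")))    # "day": a count of 1 with base "day" omits the count
    print(str(Freq("30min")))  # "30min"
    print(Freq.get_min_delta("30min", "5min"))  # difference between the two freqs, in minutes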
diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index a29e471c04..2a948bf421 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -32,7 +32,9 @@ def __init__(self, exp_manager: ExpManager): self.exp_manager: ExpManager = exp_manager def __repr__(self): - return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager) + return "{name}(manager={manager})".format( + name=self.__class__.__name__, manager=self.exp_manager + ) @contextmanager def start( @@ -90,7 +92,9 @@ def start( try: yield run except Exception as e: - self.end_exp(Recorder.STATUS_FA) # end the experiment if something went wrong + self.end_exp( + Recorder.STATUS_FA + ) # end the experiment if something went wrong raise e self.end_exp(Recorder.STATUS_FI) @@ -237,10 +241,17 @@ def list_recorders(self, experiment_id=None, experiment_name=None): ------- A dictionary (id -> recorder) of recorder information that being stored. """ - return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders() + return self.get_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ).list_recorders() def get_exp( - self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False + self, + *, + experiment_id=None, + experiment_name=None, + create: bool = True, + start: bool = False, ) -> Experiment: """ Method for retrieving an experiment with given id or name. Once the `create` argument is set to @@ -454,9 +465,9 @@ def get_recorder( ------- A recorder instance. """ - return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder( - recorder_id, recorder_name, create=False, start=False - ) + return self.get_exp( + experiment_name=experiment_name, experiment_id=experiment_id, create=False + ).get_recorder(recorder_id, recorder_name, create=False, start=False) def delete_recorder(self, recorder_id=None, recorder_name=None): """ @@ -478,7 +489,9 @@ def delete_recorder(self, recorder_id=None, recorder_name=None): """ self.get_exp().delete_recorder(recorder_id, recorder_name) - def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any]): + def save_objects( + self, local_path=None, artifact_path=None, **kwargs: Dict[Text, Any] + ): """ Method for saving objects as artifacts in the experiment to the uri. It supports either saving from a local file/directory, or directly saving objects. User can use valid python's keywords arguments @@ -531,7 +544,9 @@ def save_objects(self, local_path=None, artifact_path=None, **kwargs: Dict[Text, raise ValueError( "You can choose only one of `local_path`(save the files in a path) or `kwargs`(pass in the objects directly)" ) - self.get_exp().get_recorder(start=True).save_objects(local_path, artifact_path, **kwargs) + self.get_exp().get_recorder(start=True).save_objects( + local_path, artifact_path, **kwargs + ) def load_object(self, name: Text): """ @@ -603,7 +618,9 @@ def log_artifact(self, local_path: str, artifact_path: Optional[str] = None): artifact_path : Optional[str] If provided, the directory in ``artifact_uri`` to write to. 
""" - self.get_exp(start=True).get_recorder(start=True).log_artifact(local_path, artifact_path) + self.get_exp(start=True).get_recorder(start=True).log_artifact( + local_path, artifact_path + ) def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str: """ @@ -625,7 +642,9 @@ def download_artifact(self, path: str, dst_path: Optional[str] = None) -> str: str Local path of desired artifact. """ - self.get_exp(start=True).get_recorder(start=True).download_artifact(path, dst_path) + self.get_exp(start=True).get_recorder(start=True).download_artifact( + path, dst_path + ) def set_tags(self, **kwargs): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index ae165ef1f8..3e52c94aea 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -25,7 +25,9 @@ def __init__(self, id, name): self._default_rec_name = "abstract_recorder" def __repr__(self): - return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) + return "{name}(id={id}, info={info})".format( + name=self.__class__.__name__, id=self.id, info=self.info + ) def __str__(self): return str(self.info) @@ -37,7 +39,9 @@ def info(self): output["class"] = "Experiment" output["id"] = self.id output["name"] = self.name - output["active_recorder"] = self.active_recorder.id if self.active_recorder is not None else None + output["active_recorder"] = ( + self.active_recorder.id if self.active_recorder is not None else None + ) output["recorders"] = list(recorders.keys()) return output @@ -111,7 +115,13 @@ def delete_recorder(self, recorder_id): """ raise NotImplementedError(f"Please implement the `delete_recorder` method.") - def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False) -> Recorder: + def get_recorder( + self, + recorder_id=None, + recorder_name=None, + create: bool = True, + start: bool = False, + ) -> Recorder: """ Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the specific recorder. When user does not provide recorder id or name, the method will try to return the current @@ -163,10 +173,14 @@ def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True return self.active_recorder recorder_name = self._default_rec_name if create: - recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) + recorder, is_new = self._get_or_create_rec( + recorder_id=recorder_id, recorder_name=recorder_name + ) else: recorder, is_new = ( - self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), + self._get_recorder( + recorder_id=recorder_id, recorder_name=recorder_name + ), False, ) if is_new and start: @@ -175,7 +189,9 @@ def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True self.active_recorder.start_run() return recorder - def _get_or_create_rec(self, recorder_id=None, recorder_name=None) -> (object, bool): + def _get_or_create_rec( + self, recorder_id=None, recorder_name=None + ) -> (object, bool): """ Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will automatically create a new recorder based on the given id and name. 
@@ -184,13 +200,17 @@ def _get_or_create_rec(self, recorder_id=None, recorder_name=None) -> (object, b if recorder_id is None and recorder_name is None: recorder_name = self._default_rec_name return ( - self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), + self._get_recorder( + recorder_id=recorder_id, recorder_name=recorder_name + ), False, ) except ValueError: if recorder_name is None: recorder_name = self._default_rec_name - logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.") + logger.info( + f"No valid recorder found. Create a new recorder with name {recorder_name}." + ) return self.create_recorder(recorder_name), True def _get_recorder(self, recorder_id=None, recorder_name=None): @@ -252,7 +272,9 @@ def __init__(self, id, name, uri): self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) def __repr__(self): - return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) + return "{name}(id={id}, info={info})".format( + name=self.__class__.__name__, id=self.id, info=self.info + ) def start(self, *, recorder_id=None, recorder_name=None, resume=False): logger.info(f"Experiment {self.id} starts running ...") @@ -261,7 +283,9 @@ def start(self, *, recorder_id=None, recorder_name=None, resume=False): recorder_name = self._default_rec_name # resume the recorder if resume: - recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) + recorder, _ = self._get_or_create_rec( + recorder_id=recorder_id, recorder_name=recorder_name + ) # create a new recorder else: recorder = self.create_recorder(recorder_name) @@ -312,15 +336,25 @@ def _get_recorder(self, recorder_id=None, recorder_name=None): for rid in recorders: if recorders[rid].name == recorder_name: return recorders[rid] - raise ValueError("No valid recorder has been found, please make sure the input recorder name is correct.") + raise ValueError( + "No valid recorder has been found, please make sure the input recorder name is correct." + ) def search_records(self, **kwargs): - filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") - run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") - max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") + filter_string = ( + "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") + ) + run_view_type = ( + 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") + ) + max_results = ( + 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") + ) order_by = kwargs.get("order_by") - return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by) + return self._client.search_runs( + [self.id], filter_string, run_view_type, max_results, order_by + ) def delete_recorder(self, recorder_id=None, recorder_name=None): assert ( @@ -361,7 +395,10 @@ def list_recorders( mlflow supported filter string like 'params."my_param"="a" and tags."my_tag"="b"', use this will help to reduce too much run number. 
""" runs = self._client.search_runs( - self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results, filter_string=filter_string + self.id, + run_view_type=ViewType.ACTIVE_ONLY, + max_results=max_results, + filter_string=filter_string, ) rids = [] recorders = [] diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5047ccfb26..74504d631b 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -150,7 +150,14 @@ def search_records(self, experiment_ids=None, **kwargs): """ raise NotImplementedError(f"Please implement the `search_records` method.") - def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False): + def get_exp( + self, + *, + experiment_id=None, + experiment_name=None, + create: bool = True, + start: bool = False, + ): """ Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment. @@ -206,42 +213,56 @@ def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = Tr experiment_name = self._default_exp_name if create: - exp, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) + exp, _ = self._get_or_create_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ) else: - exp = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name) + exp = self._get_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ) if self.active_experiment is None and start: self.active_experiment = exp # start the recorder self.active_experiment.start() return exp - def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (object, bool): + def _get_or_create_exp( + self, experiment_id=None, experiment_name=None + ) -> (object, bool): """ Method for getting or creating an experiment. It will try to first get a valid experiment, if exception occurs, it will automatically create a new experiment based on the given id and name. """ try: return ( - self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), + self._get_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ), False, ) except ValueError: if experiment_name is None: experiment_name = self._default_exp_name - logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.") + logger.warning( + f"No valid experiment found. Create a new experiment with name {experiment_name}." 
+ ) # NOTE: mlflow doesn't consider the lock for recording multiple runs # So we supported it in the interface wrapper pr = urlparse(self.uri) if pr.scheme == "file": - with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110 + with FileLock( + Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock")) + ): # pylint: disable=E0110 return self.create_exp(experiment_name), True # NOTE: for other schemes like http, we double check to avoid create exp conflicts try: return self.create_exp(experiment_name), True except ExpAlreadyExistError: return ( - self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), + self._get_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ), False, ) @@ -338,11 +359,15 @@ def _start_exp( # Create experiment if experiment_name is None: experiment_name = self._default_exp_name - experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) + experiment, _ = self._get_or_create_exp( + experiment_id=experiment_id, experiment_name=experiment_name + ) # Set up active experiment self.active_experiment = experiment # Start the experiment - self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume) + self.active_experiment.start( + recorder_id=recorder_id, recorder_name=recorder_name, resume=resume + ) return self.active_experiment @@ -389,7 +414,9 @@ def _get_exp(self, experiment_id=None, experiment_name=None): exp = self.client.get_experiment_by_name(experiment_name) if exp is None or exp.lifecycle_stage.upper() == "DELETED": raise MlflowException("No valid experiment has been found.") - experiment = MLflowExperiment(exp.experiment_id, experiment_name, self.uri) + experiment = MLflowExperiment( + exp.experiment_id, experiment_name, self.uri + ) return experiment except MlflowException as e: raise ValueError( @@ -397,11 +424,19 @@ def _get_exp(self, experiment_id=None, experiment_name=None): ) from e def search_records(self, experiment_ids=None, **kwargs): - filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") - run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") - max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") + filter_string = ( + "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") + ) + run_view_type = ( + 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") + ) + max_results = ( + 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") + ) order_by = kwargs.get("order_by") - return self.client.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by) + return self.client.search_runs( + experiment_ids, filter_string, run_view_type, max_results, order_by + ) def delete_exp(self, experiment_id=None, experiment_name=None): assert ( @@ -426,7 +461,9 @@ def list_experiments(self): if mlflow_version >= 2: exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY) else: - exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101 + exps = self.client.list_experiments( + view_type=ViewType.ACTIVE_ONLY + ) # pylint: disable=E1101 experiments = dict() for exp in exps: experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri) diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 09e96d444f..27c2eac9bd 100644 --- 
a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -153,7 +153,9 @@ def _postpone_action(self): """ return self.status == self.STATUS_SIMULATING and self.trainer.is_delay() - def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}): + def first_train( + self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {} + ): """ Get tasks from every strategy's first_tasks method and train them. If using DelayTrainer, it can finish training all together after every strategy's first_tasks. @@ -179,7 +181,9 @@ def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dic if not self._postpone_action(): for strategy, models in zip(strategies, models_list): - models = self.trainer.end_train(models, experiment_name=strategy.name_id) + models = self.trainer.end_train( + models, experiment_name=strategy.name_id + ) def routine( self, @@ -224,7 +228,9 @@ def routine( if not self._postpone_action(): for strategy, models in zip(self.strategies, models_list): - models = self.trainer.end_train(models, experiment_name=strategy.name_id) + models = self.trainer.end_train( + models, experiment_name=strategy.name_id + ) self.prepare_signals(**signal_kwargs) def get_collector(self, **kwargs) -> MergeCollector: @@ -255,7 +261,9 @@ def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]): self.first_train(strategies) self.strategies.extend(strategies) - def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False): + def prepare_signals( + self, prepare_func: Callable = AverageEnsemble(), over_write=False + ): """ After preparing the data of the last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for the next routine. @@ -300,7 +308,12 @@ def get_signals(self) -> Union[pd.Series, pd.DataFrame]: SIM_LOG_NAME = "SIMULATE_INFO" def simulate( - self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={} + self, + end_time=None, + frequency="day", + task_kwargs={}, + model_kwargs={}, + signal_kwargs={}, ) -> Union[pd.Series, pd.DataFrame]: """ Starting from the current time, this method will simulate every routine in OnlineManager until the end time. @@ -329,7 +342,9 @@ def simulate( logging.addLevelName(simulate_level, self.SIM_LOG_NAME) for cur_time in cal: - self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......") + self.logger.log( + level=simulate_level, msg=f"Simulating at {str(cur_time)}......" 
+ ) self.routine( cur_time, task_kwargs=task_kwargs, @@ -365,7 +380,9 @@ def delay_prepare(self, model_kwargs={}, signal_kwargs={}): for strategy, models in strategy_models.items(): # only new online models need to prepare if last_models.setdefault(strategy, set()) != set(models): - models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs) + models = self.trainer.end_train( + models, experiment_name=strategy.name_id, **model_kwargs + ) strategy.tool.reset_online_tag(models) need_prepare = True last_models[strategy] = set(models) diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index d545e4bc9a..8f62ef84df 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -116,11 +116,19 @@ def __init__( task_template = [task_template] self.task_template = task_template self.rg = rolling_gen - assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen" + assert issubclass( + self.rg.__class__, RollingGen + ), "The rolling strategy relies on the feature if RollingGen" self.tool = OnlineToolR(self.exp_name) self.ta = TimeAdjuster() - def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): + def get_collector( + self, + process_list=[RollingGroup()], + rec_key_func=None, + rec_filter_func=None, + artifacts_key=None, + ): """ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models. @@ -200,9 +208,15 @@ def _list_latest(self, rec_list: List[Recorder]): """ if len(rec_list) == 0: return rec_list, None - max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list) + max_test = max( + rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] + for rec in rec_list + ) latest_rec = [] for rec in rec_list: - if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: + if ( + rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] + == max_test + ): latest_rec.append(rec) return latest_rec, max_test diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 5047a1bd25..09ed96c2a1 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -27,7 +27,11 @@ def __init__(self, rec: Recorder): self.rec = rec def get_dataset( - self, start_time, end_time, segments=None, unprepared_dataset: Optional[DatasetH] = None + self, + start_time, + end_time, + segments=None, + unprepared_dataset: Optional[DatasetH] = None, ) -> DatasetH: """ Load, config and setup dataset. @@ -55,7 +59,10 @@ def get_dataset( dataset: DatasetH = self.rec.load_object("dataset") else: dataset = unprepared_dataset - dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments) + dataset.config( + handler_kwargs={"start_time": start_time, "end_time": end_time}, + segments=segments, + ) dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS}) return dataset @@ -173,7 +180,9 @@ def __init__( if from_date is None: # dropna is for being compatible to some data with future information(e.g. 
label) # The recent label data should be updated together - self.last_end = self.old_data.dropna().index.get_level_values("datetime").max() + self.last_end = ( + self.old_data.dropna().index.get_level_values("datetime").max() + ) else: self.last_end = get_date_by_shift(from_date, -1, align="right") @@ -190,7 +199,11 @@ def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> Dataset """ # automatically getting the historical dependency if not specified if self.hist_ref is None: - dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset + dataset: DatasetH = ( + self.record.load_object("dataset") + if unprepared_dataset is None + else unprepared_dataset + ) # Special treatment of historical dependencies if isinstance(dataset, TSDatasetH): hist_ref = dataset.step_len - 1 @@ -200,15 +213,23 @@ def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> Dataset hist_ref = self.hist_ref start_time_buffer = get_date_by_shift( - self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq # pylint: disable=E1130 + self.last_end, + -hist_ref + 1, + clip_shift=False, + freq=self.freq, # pylint: disable=E1130 ) start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) seg = {"test": (start_time, self.to_date)} return self.rmdl.get_dataset( - start_time=start_time_buffer, end_time=self.to_date, segments=seg, unprepared_dataset=unprepared_dataset + start_time=start_time_buffer, + end_time=self.to_date, + segments=seg, + unprepared_dataset=unprepared_dataset, ) - def update(self, dataset: DatasetH = None, write: bool = True, ret_new: bool = False) -> Optional[object]: + def update( + self, dataset: DatasetH = None, write: bool = True, ret_new: bool = False + ) -> Optional[object]: """ Parameters ---------- @@ -277,7 +298,9 @@ def get_update_data(self, dataset: Dataset) -> pd.DataFrame: model = self.rmdl.get_model() new_pred: pd.Series = model.predict(dataset) data = _replace_range(self.old_data, new_pred.to_frame("score")) - self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.") + self.logger.info( + f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}." + ) return data diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index c390ca0092..cb017e7655 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -154,7 +154,11 @@ def online_models(self, exp_name: str = None) -> list: list: a list of `online` models. """ exp_name = self._get_exp_name(exp_name) - return list(list_recorders(exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) + return list( + list_recorders( + exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG + ).values() + ) def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None): """ @@ -171,11 +175,15 @@ def update_online_pred(self, to_date=None, from_date=None, exp_name: str = None) updater = PredUpdater(rec, to_date=to_date, from_date=from_date) except LoadObjectError as e: # skip the recorder without pred - self.logger.warn(f"An exception `{str(e)}` happened when load `pred.pkl`, skip it.") + self.logger.warn( + f"An exception `{str(e)}` happened when load `pred.pkl`, skip it." + ) continue updater.update() - self.logger.info(f"Finished updating {len(online_models)} online model predictions of {exp_name}.") + self.logger.info( + f"Finished updating {len(online_models)} online model predictions of {exp_name}." 
+ ) def _get_exp_name(self, exp_name): if exp_name is None: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 844914d469..987dcb8cef 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -199,7 +199,9 @@ def generate(self, **kwargs): f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) # print out results - pprint(f"The following are prediction results of the {type(self.model).__name__} model.") + pprint( + f"The following are prediction results of the {type(self.model).__name__} model." + ) pprint(pred.head(5)) if isinstance(self.dataset, DatasetH): @@ -260,7 +262,9 @@ def __init__(self, recorder, **kwargs): def generate(self): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) + long_pre, short_pre = calc_long_short_prec( + pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True + ) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -272,7 +276,9 @@ def generate(self): } objects = {"ic.pkl": ic, "ric.pkl": ric} objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) - long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0]) + long_short_r, long_avg_r = calc_long_short_return( + pred.iloc[:, 0], raw_label.iloc[:, 0] + ) metrics.update( { "Long-Short Average Return": long_short_r.mean(), @@ -290,7 +296,14 @@ def generate(self): pprint(metrics) def list(self): - return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"] + return [ + "ic.pkl", + "ric.pkl", + "long_pre.pkl", + "short_pre.pkl", + "long_short_r.pkl", + "long_avg_r.pkl", + ] class SigAnaRecord(ACRecordTemp): @@ -302,7 +315,14 @@ class SigAnaRecord(ACRecordTemp): artifact_path = "sig_analysis" depend_cls = SignalRecord - def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False): + def __init__( + self, + recorder, + ana_long_short=False, + ann_scaler=252, + label_col=0, + skip_existing=False, + ): super().__init__(recorder=recorder, skip_existing=skip_existing) self.ana_long_short = ana_long_short self.ann_scaler = ann_scaler @@ -330,13 +350,19 @@ def _generate(self, label: Optional[pd.DataFrame] = None, **kwargs): } objects = {"ic.pkl": ic, "ric.pkl": ric} if self.ana_long_short: - long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, self.label_col]) + long_short_r, long_avg_r = calc_long_short_return( + pred.iloc[:, 0], label.iloc[:, self.label_col] + ) metrics.update( { "Long-Short Ann Return": long_short_r.mean() * self.ann_scaler, - "Long-Short Ann Sharpe": long_short_r.mean() / long_short_r.std() * self.ann_scaler**0.5, + "Long-Short Ann Sharpe": long_short_r.mean() + / long_short_r.std() + * self.ann_scaler**0.5, "Long-Avg Ann Return": long_avg_r.mean() * self.ann_scaler, - "Long-Avg Ann Sharpe": long_avg_r.mean() / long_avg_r.std() * self.ann_scaler**0.5, + "Long-Avg Ann Sharpe": long_avg_r.mean() + / long_avg_r.std() + * self.ann_scaler**0.5, } ) objects.update( @@ -447,10 +473,12 @@ def __init__( indicator_analysis_freq = [indicator_analysis_freq] self.risk_analysis_freq = [ - "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq + "{0}{1}".format(*Freq.parse(_analysis_freq)) + for _analysis_freq in risk_analysis_freq ] self.indicator_analysis_freq = [ - 
"{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq + "{0}{1}".format(*Freq.parse(_analysis_freq)) + for _analysis_freq in indicator_analysis_freq ] self.indicator_analysis_method = indicator_analysis_method @@ -460,7 +488,9 @@ def _get_report_freq(self, executor_config): _count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"]) ret_freq.append(f"{_count}{_freq}") if "inner_executor" in executor_config["kwargs"]: - ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["inner_executor"])) + ret_freq.extend( + self._get_report_freq(executor_config["kwargs"]["inner_executor"]) + ) return ret_freq def _generate(self, **kwargs): @@ -481,15 +511,21 @@ def _generate(self, **kwargs): artifact_objects = {} # custom strategy and get backtest portfolio_metric_dict, indicator_dict = normal_backtest( - executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config + executor=self.executor_config, + strategy=self.strategy_config, + **self.backtest_config, ) for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items(): artifact_objects.update({f"report_normal_{_freq}.pkl": report_normal}) artifact_objects.update({f"positions_normal_{_freq}.pkl": positions_normal}) for _freq, indicators_normal in indicator_dict.items(): - artifact_objects.update({f"indicators_normal_{_freq}.pkl": indicators_normal[0]}) - artifact_objects.update({f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]}) + artifact_objects.update( + {f"indicators_normal_{_freq}.pkl": indicators_normal[0]} + ) + artifact_objects.update( + {f"indicators_normal_{_freq}_obj.pkl": indicators_normal[1]} + ) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq not in portfolio_metric_dict: @@ -500,27 +536,41 @@ def _generate(self, **kwargs): report_normal, _ = portfolio_metric_dict.get(_analysis_freq) analysis = dict() analysis["excess_return_without_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"], freq=_analysis_freq + report_normal["return"] - report_normal["bench"], + freq=_analysis_freq, ) analysis["excess_return_with_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=_analysis_freq + report_normal["return"] + - report_normal["bench"] + - report_normal["cost"], + freq=_analysis_freq, ) analysis_df = pd.concat(analysis) # type: pd.DataFrame # log metrics analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict()) - self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) + self.recorder.log_metrics( + **{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()} + ) # save results - artifact_objects.update({f"port_analysis_{_analysis_freq}.pkl": analysis_df}) + artifact_objects.update( + {f"port_analysis_{_analysis_freq}.pkl": analysis_df} + ) logger.info( f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) # print out results - pprint(f"The following are analysis results of benchmark return({_analysis_freq}).") + pprint( + f"The following are analysis results of benchmark return({_analysis_freq})." + ) pprint(risk_analysis(report_normal["bench"], freq=_analysis_freq)) - pprint(f"The following are analysis results of the excess return without cost({_analysis_freq}).") + pprint( + f"The following are analysis results of the excess return without cost({_analysis_freq})." 
+ ) pprint(analysis["excess_return_without_cost"]) - pprint(f"The following are analysis results of the excess return with cost({_analysis_freq}).") + pprint( + f"The following are analysis results of the excess return with cost({_analysis_freq})." + ) pprint(analysis["excess_return_with_cost"]) for _analysis_freq in self.indicator_analysis_freq: @@ -531,16 +581,24 @@ def _generate(self, **kwargs): if self.indicator_analysis_method is None: analysis_df = indicator_analysis(indicators_normal) else: - analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method) + analysis_df = indicator_analysis( + indicators_normal, method=self.indicator_analysis_method + ) # log metrics analysis_dict = analysis_df["value"].to_dict() - self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) + self.recorder.log_metrics( + **{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()} + ) # save results - artifact_objects.update({f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}) + artifact_objects.update( + {f"indicator_analysis_{_analysis_freq}.pkl": analysis_df} + ) logger.info( f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) - pprint(f"The following are analysis results of indicators({_analysis_freq}).") + pprint( + f"The following are analysis results of indicators({_analysis_freq})." + ) pprint(analysis_df) return artifact_objects @@ -605,9 +663,13 @@ def __init__(self, recorder, pass_num=10, shuffle_init_score=True, **kwargs): # Save original strategy so that pred df can be replaced in next generate self.original_strategy = deepcopy_basic_type(self.strategy_config) if not isinstance(self.original_strategy, dict): - raise QlibException("MultiPassPortAnaRecord require the passed in strategy to be a dict") + raise QlibException( + "MultiPassPortAnaRecord require the passed in strategy to be a dict" + ) if "signal" not in self.original_strategy.get("kwargs", {}): - raise QlibException("MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter") + raise QlibException( + "MultiPassPortAnaRecord require the passed in strategy to have signal as a parameter" + ) def random_init(self): pred_df = self.load("pred.pkl") @@ -642,7 +704,9 @@ def _generate(self, **kwargs): risk_analysis_df_list = risk_analysis_df_map.get(_analysis_freq, []) risk_analysis_df_map[_analysis_freq] = risk_analysis_df_list - analysis_df = single_run_artifacts[f"port_analysis_{_analysis_freq}.pkl"] + analysis_df = single_run_artifacts[ + f"port_analysis_{_analysis_freq}.pkl" + ] analysis_df["run_id"] = i risk_analysis_df_list.append(analysis_df) @@ -652,9 +716,15 @@ def _generate(self, **kwargs): combined_df = pd.concat(risk_analysis_df_map[_analysis_freq]) # Calculate return and information ratio's mean, std and mean/std - multi_pass_port_analysis_df = combined_df.groupby(level=[0, 1], group_keys=False).apply( + multi_pass_port_analysis_df = combined_df.groupby( + level=[0, 1], group_keys=False + ).apply( lambda x: pd.Series( - {"mean": x["risk"].mean(), "std": x["risk"].std(), "mean_std": x["risk"].mean() / x["risk"].std()} + { + "mean": x["risk"].mean(), + "std": x["risk"].std(), + "mean_std": x["risk"].mean() / x["risk"].std(), + } ) ) @@ -665,14 +735,20 @@ def _generate(self, **kwargs): pprint(multi_pass_port_analysis_df) # Save new df - result_artifacts.update({f"multi_pass_port_analysis_{_analysis_freq}.pkl": 
multi_pass_port_analysis_df}) + result_artifacts.update( + { + f"multi_pass_port_analysis_{_analysis_freq}.pkl": multi_pass_port_analysis_df + } + ) # Log metrics metrics = flatten_dict( { "mean": multi_pass_port_analysis_df["mean"].unstack().T.to_dict(), "std": multi_pass_port_analysis_df["std"].unstack().T.to_dict(), - "mean_std": multi_pass_port_analysis_df["mean_std"].unstack().T.to_dict(), + "mean_std": multi_pass_port_analysis_df["mean_std"] + .unstack() + .T.to_dict(), } ) self.recorder.log_metrics(**metrics) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 5fd99c0769..bc5cc88ec5 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -48,7 +48,9 @@ def __init__(self, experiment_id, name): self.status = Recorder.STATUS_S def __repr__(self): - return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info) + return "{name}(info={info})".format( + name=self.__class__.__name__, info=self.info + ) def __str__(self): return str(self.info) @@ -267,17 +269,23 @@ def __init__(self, experiment_id, uri, name=None, mlflow_run=None): self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) # construct from mlflow run if mlflow_run is not None: - assert isinstance(mlflow_run, mlflow.entities.run.Run), "Please input with a MLflow Run object." + assert isinstance( + mlflow_run, mlflow.entities.run.Run + ), "Please input with a MLflow Run object." self.name = mlflow_run.data.tags["mlflow.runName"] self.id = mlflow_run.info.run_id self.status = mlflow_run.info.status self.start_time = ( - datetime.fromtimestamp(float(mlflow_run.info.start_time) / 1000.0).strftime("%Y-%m-%d %H:%M:%S") + datetime.fromtimestamp( + float(mlflow_run.info.start_time) / 1000.0 + ).strftime("%Y-%m-%d %H:%M:%S") if mlflow_run.info.start_time is not None else None ) self.end_time = ( - datetime.fromtimestamp(float(mlflow_run.info.end_time) / 1000.0).strftime("%Y-%m-%d %H:%M:%S") + datetime.fromtimestamp( + float(mlflow_run.info.end_time) / 1000.0 + ).strftime("%Y-%m-%d %H:%M:%S") if mlflow_run.info.end_time is not None else None ) @@ -318,14 +326,18 @@ def get_local_dir(self): """ if self.artifact_uri is not None: if platform.system() == "Windows": - local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent + local_dir_path = Path( + self.artifact_uri.lstrip("file:").lstrip("/") + ).parent else: local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent local_dir_path = str(local_dir_path.resolve()) if os.path.isdir(local_dir_path): return local_dir_path else: - raise RuntimeError("This recorder is not saved in the local file system.") + raise RuntimeError( + "This recorder is not saved in the local file system." + ) else: raise ValueError( @@ -342,7 +354,9 @@ def start_run(self): self._artifact_uri = run.info.artifact_uri self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") self.status = Recorder.STATUS_R - logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...") + logger.info( + f"Recorder {self.id} starts running under Experiment {self.experiment_id} ..." + ) # NOTE: making logging async. # - This may cause delay when uploading results @@ -353,7 +367,9 @@ def start_run(self): # Maybe we can make this feature more general. 
self._log_uncommitted_code() - self.log_params(**{"cmd-sys.argv": " ".join(sys.argv)}) # log the command to produce current experiment + self.log_params( + **{"cmd-sys.argv": " ".join(sys.argv)} + ) # log the command to produce current experiment self.log_params( **{k: v for k, v in os.environ.items() if k.startswith("_QLIB_")} ) # Log necessary environment variables @@ -373,9 +389,13 @@ def _log_uncommitted_code(self): ]: try: out = subprocess.check_output(cmd, shell=True) - self.client.log_text(self.id, out.decode(), fname) # this behaves same as above + self.client.log_text( + self.id, out.decode(), fname + ) # this behaves same as above except subprocess.CalledProcessError: - logger.info(f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}.") + logger.info( + f"Fail to log the uncommitted code of $CWD({os.getcwd()}) when run {cmd}." + ) def end_run(self, status: str = Recorder.STATUS_S): assert status in [ @@ -395,7 +415,9 @@ def end_run(self, status: str = Recorder.STATUS_S): mlflow.end_run(status) def save_objects(self, local_path=None, artifact_path=None, **kwargs): - assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." + assert ( + self.uri is not None + ), "Please start the experiment and recorder first before using recorder directly." if local_path is not None: path = Path(local_path) if path.is_dir(): @@ -425,7 +447,9 @@ def load_object(self, name, unpickler=pickle.Unpickler): Returns: object: the saved object in mlflow. """ - assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." + assert ( + self.uri is not None + ), "Please start the experiment and recorder first before using recorder directly." path = None try: @@ -453,7 +477,9 @@ def log_metrics(self, step=None, **kwargs): self.client.log_metric(self.id, name, data, step=step) def log_artifact(self, local_path, artifact_path: Optional[str] = None): - self.client.log_artifact(self.id, local_path=local_path, artifact_path=artifact_path) + self.client.log_artifact( + self.id, local_path=local_path, artifact_path=artifact_path + ) @AsyncCaller.async_dec(ac_attr="async_log") def set_tags(self, **kwargs): @@ -473,7 +499,9 @@ def get_artifact_uri(self): ) def list_artifacts(self, artifact_path=None): - assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." + assert ( + self.uri is not None + ), "Please start the experiment and recorder first before using recorder directly." artifacts = self.client.list_artifacts(self.id, artifact_path) return [art.path for art in artifacts] diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index bedbd96d20..5152a45483 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -71,7 +71,9 @@ def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict: value = collected_dict[artifact] for process in process_list: if not callable(process): - raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.") + raise NotImplementedError( + f"{type(process)} is not supported in `process_collect`." 
+ ) value = process(value, *args, **kwargs) result[artifact] = value return result @@ -101,7 +103,12 @@ class MergeCollector(Collector): """ - def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None): + def __init__( + self, + collector_dict: Dict[str, Collector], + process_list: List[Callable] = [], + merge_func=None, + ): """ Init MergeCollector. @@ -181,7 +188,9 @@ def rec_key_func(rec): self.list_kwargs = list_kwargs self.status = status - def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict: + def collect( + self, artifacts_key=None, rec_filter_func=None, only_exist=True + ) -> dict: """ Collect different artifacts based on recorder after filtering. @@ -215,7 +224,8 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> rec for rec in recs if ( - (self.status is None or rec.status in self.status) and (rec_filter_func is None or rec_filter_func(rec)) + (self.status is None or rec.status in self.status) + and (rec_filter_func is None or rec_filter_func(rec)) ) ] @@ -235,7 +245,9 @@ def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> except LoadObjectError as e: if only_exist: # only collect existing artifact - logger.warning(f"Fail to load {self.artifacts_path[key]} and it is ignored.") + logger.warning( + f"Fail to load {self.artifacts_path[key]} and it is ignored." + ) continue raise e # give user some warning if the values are overridden diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index bd98e501db..3e9cd580cb 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -124,7 +124,9 @@ def handler_mod(task: dict, rolling_gen): pass -def trunc_segments(ta: TimeAdjuster, segments: Dict[str, pd.Timestamp], days, test_key="test"): +def trunc_segments( + ta: TimeAdjuster, segments: Dict[str, pd.Timestamp], days, test_key="test" +): """ To avoid the leakage of future information, the segments should be truncated according to the test start_time @@ -222,7 +224,9 @@ def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]: break prev_seg = segments - t = self.task_copy_func(task) # deepcopy is necessary to avoid replace task inplace + t = self.task_copy_func( + task + ) # deepcopy is necessary to avoid replace task inplace self._update_task_segs(t, segments) yield t @@ -284,11 +288,16 @@ def generate(self, task: dict) -> List[dict]: # First rolling # 1) prepare the end point - segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"])) + segments: dict = copy.deepcopy( + self.ta.align_seg(t["dataset"]["kwargs"]["segments"]) + ) test_end = transform_end_date(segments[self.test_key][1]) # 2) and init test segments test_start_idx = self.ta.align_idx(segments[self.test_key][0]) - segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1)) + segments[self.test_key] = ( + self.ta.get(test_start_idx), + self.ta.get(test_start_idx + self.step - 1), + ) if self.trunc_days is not None: trunc_segments(self.ta, segments, self.trunc_days, self.test_key) @@ -345,7 +354,9 @@ def generate(self, task: dict): # adjust segment segments = self.ta.align_seg(t["dataset"]["kwargs"]["segments"]) - trunc_segments(self.ta, segments, days=hr + self.label_leak_n, test_key=self.test_key) + trunc_segments( + self.ta, segments, days=hr + self.label_leak_n, test_key=self.test_key + ) t["dataset"]["kwargs"]["segments"] = segments res.append(t) return res diff 
--git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 7fe9f58d66..63ea8d2c40 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -92,7 +92,9 @@ def __init__(self, task_pool: str): task_pool: str the name of Collection in MongoDB """ - self.task_pool: pymongo.collection.Collection = getattr(get_mongodb(), task_pool) + self.task_pool: pymongo.collection.Collection = getattr( + get_mongodb(), task_pool + ) self.logger = get_module_logger(self.__class__.__name__) self.logger.info(f"task_pool:{task_pool}") @@ -110,7 +112,9 @@ def _encode_task(self, task): for prefix in self.ENCODE_FIELDS_PREFIX: for k in list(task.keys()): if k.startswith(prefix): - task[k] = Binary(pickle.dumps(task[k], protocol=C.dump_protocol_version)) + task[k] = Binary( + pickle.dumps(task[k], protocol=C.dump_protocol_version) + ) return task def _decode_task(self, task): @@ -275,7 +279,9 @@ def fetch_task(self, query={}, status=STATUS_WAITING) -> dict: query = self._decode_query(query) query.update({"status": status}) task = self.task_pool.find_one_and_update( - query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)] + query, + {"$set": {"status": self.STATUS_RUNNING}}, + sort=[("priority", pymongo.DESCENDING)], ) # null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority if task is None: @@ -300,10 +306,15 @@ def safe_fetch_task(self, query={}, status=STATUS_WAITING): task = self.fetch_task(query=query, status=status) try: yield task - except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception + except ( + Exception, + KeyboardInterrupt, + ): # KeyboardInterrupt is not a subclass of Exception if task is not None: self.logger.info("Returning task before raising error") - self.return_task(task, status=status) # return task as the original status + self.return_task( + task, status=status + ) # return task as the original status self.logger.info("Task returned") raise @@ -363,7 +374,12 @@ def commit_task_res(self, task, res, status=STATUS_DONE): status = TaskManager.STATUS_DONE self.task_pool.update_one( {"_id": task["_id"]}, - {"$set": {"status": status, "res": Binary(pickle.dumps(res, protocol=C.dump_protocol_version))}}, + { + "$set": { + "status": status, + "res": Binary(pickle.dumps(res, protocol=C.dump_protocol_version)), + } + }, ) def return_task(self, task, status=STATUS_WAITING): @@ -466,7 +482,9 @@ def wait(self, query={}): last_undone_n = self._get_undone_n(task_stat) if last_undone_n == 0: return - self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.") + self.logger.warning( + f"Waiting for {last_undone_n} undone tasks. Please make sure they are running." + ) with tqdm(total=total, initial=total - last_undone_n) as pbar: while True: time.sleep(10) @@ -537,7 +555,9 @@ def (task_def, \**kwargs) -> elif before_status == TaskManager.STATUS_PART_DONE: param = task["res"] else: - raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!") + raise ValueError( + "The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!" 
+ ) if force_release: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: res = executor.submit(task_func, param, **kwargs).result() diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index 4b4a7c06b8..b3bb1a24f5 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -50,7 +50,9 @@ def get_mongodb() -> Database: try: cfg = C["mongo"] except KeyError: - get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager") + get_module_logger("task").error( + "Please configure `C['mongo']` before using TaskManager" + ) raise get_module_logger("task").info(f"mongo config:{cfg}") client = MongoClient(cfg["task_url"]) @@ -196,7 +198,9 @@ def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]: if isinstance(segment, dict): return {k: self.align_seg(seg) for k, seg in segment.items()} elif isinstance(segment, (tuple, list)): - return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end") + return self.align_time(segment[0], tp_type="start"), self.align_time( + segment[1], tp_type="end" + ) else: raise NotImplementedError(f"This type of input is not supported") @@ -265,7 +269,9 @@ def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple: shift will raise error if the index(both start and end) is out of self.cal """ if isinstance(seg, tuple): - start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end") + start_idx, end_idx = self.align_idx( + seg[0], tp_type="start" + ), self.align_idx(seg[1], tp_type="end") if rtype == self.SHIFT_SD: start_idx = self._add_step(start_idx, step) end_idx = self._add_step(end_idx, step) @@ -280,7 +286,9 @@ def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple: raise NotImplementedError(f"This type of input is not supported") -def replace_task_handler_with_cache(task: dict, cache_dir: Union[str, Path] = ".") -> dict: +def replace_task_handler_with_cache( + task: dict, cache_dir: Union[str, Path] = "." +) -> dict: """ Replace the handler in task with a cache handler. It will automatically cache the file and save it in cache_dir. diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index 0f48c74f0b..e9eca19fb2 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -25,7 +25,9 @@ def experiment_exit_handler(): - If pdb is used in your program, excepthook will not be triggered when it ends. The status will be finished """ sys.excepthook = experiment_exception_hook # handle uncaught exception - atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends + atexit.register( + R.end_exp, recorder_status=Recorder.STATUS_FI + ) # will not take effect if experiment ends def experiment_exception_hook(exc_type, value, tb): diff --git a/scripts/check_data_health.py b/scripts/check_data_health.py index c91ad3e715..bc3959bd93 100644 --- a/scripts/check_data_health.py +++ b/scripts/check_data_health.py @@ -28,7 +28,9 @@ def __init__( missing_data_num=0, ): assert csv_path or qlib_dir, "One of csv_path or qlib_dir should be provided." - assert not (csv_path and qlib_dir), "Only one of csv_path or qlib_dir should be provided." + assert not ( + csv_path and qlib_dir + ), "Only one of csv_path or qlib_dir should be provided." 
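Editor's note on usage: the assertions above require exactly one of `csv_path` / `qlib_dir`, and the checks further down are usually driven as a small standalone script. A rough sketch, assuming the checker class defined in scripts/check_data_health.py is importable as `DataHealthChecker` and that `freq` is a constructor argument (both names are assumptions; the path below is illustrative):

    from check_data_health import DataHealthChecker  # module/class names assumed

    checker = DataHealthChecker(qlib_dir="~/.qlib/qlib_data/cn_data", freq="day")
    checker.load_qlib_data()                     # fills checker.data with per-instrument OHLCV frames
                                                 # (the constructor may already do this; shown explicitly)
    missing = checker.check_missing_data()       # instruments exceeding missing_data_num, or None
    jumps = checker.check_large_step_changes()   # columns whose pct_change() exceeds the price/volume thresholds
    required = checker.check_required_columns()  # instruments lacking any required OHLCV column

Each check returns an `Optional[pd.DataFrame]`, so `None` means the corresponding problem was not found.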
self.data = {} self.problems = {} @@ -50,7 +52,9 @@ def __init__( def load_qlib_data(self): instruments = D.instruments(market="all") - instrument_list = D.list_instruments(instruments=instruments, as_list=True, freq=self.freq) + instrument_list = D.list_instruments( + instruments=instruments, as_list=True, freq=self.freq + ) required_fields = ["$open", "$close", "$low", "$high", "$volume", "$factor"] for instrument in instrument_list: df = D.features([instrument], required_fields, freq=self.freq) @@ -79,7 +83,11 @@ def check_missing_data(self) -> Optional[pd.DataFrame]: "volume": [], } for filename, df in self.data.items(): - missing_data_columns = df.isnull().sum()[df.isnull().sum() > self.missing_data_num].index.tolist() + missing_data_columns = ( + df.isnull() + .sum()[df.isnull().sum() > self.missing_data_num] + .index.tolist() + ) if len(missing_data_columns) > 0: result_dict["instruments"].append(filename) result_dict["open"].append(df.isnull().sum()["open"]) @@ -108,12 +116,18 @@ def check_large_step_changes(self) -> Optional[pd.DataFrame]: for col in ["open", "high", "low", "close", "volume"]: if col in df.columns: pct_change = df[col].pct_change(fill_method=None).abs() - threshold = self.large_step_threshold_volume if col == "volume" else self.large_step_threshold_price + threshold = ( + self.large_step_threshold_volume + if col == "volume" + else self.large_step_threshold_price + ) if pct_change.max() > threshold: large_steps = pct_change[pct_change > threshold] result_dict["instruments"].append(filename) result_dict["col_name"].append(col) - result_dict["date"].append(large_steps.index.to_list()[0][1].strftime("%Y-%m-%d")) + result_dict["date"].append( + large_steps.index.to_list()[0][1].strftime("%Y-%m-%d") + ) result_dict["pct_change"].append(pct_change.max()) affected_columns.append(col) @@ -121,7 +135,9 @@ def check_large_step_changes(self) -> Optional[pd.DataFrame]: if not result_df.empty: return result_df else: - logger.info(f"✅ There are no large step changes in the OHLCV column above the threshold.") + logger.info( + f"✅ There are no large step changes in the OHLCV column above the threshold." 
+ ) return None def check_required_columns(self) -> Optional[pd.DataFrame]: @@ -133,7 +149,9 @@ def check_required_columns(self) -> Optional[pd.DataFrame]: } for filename, df in self.data.items(): if not all(column in df.columns for column in required_columns): - missing_required_columns = [column for column in required_columns if column not in df.columns] + missing_required_columns = [ + column for column in required_columns if column not in df.columns + ] result_dict["instruments"].append(filename) result_dict["missing_col"] += missing_required_columns diff --git a/scripts/check_dump_bin.py b/scripts/check_dump_bin.py index 7ae8a26ab0..fa24d1966c 100644 --- a/scripts/check_dump_bin.py +++ b/scripts/check_dump_bin.py @@ -62,12 +62,20 @@ def __init__( redis_port=-1, ) csv_path = Path(csv_path).expanduser() - self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path]) + self.csv_files = sorted( + csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path] + ) if check_fields is None: - check_fields = list(map(lambda x: x.name.split(".")[0], bin_path_list[0].glob(f"*.bin"))) + check_fields = list( + map(lambda x: x.name.split(".")[0], bin_path_list[0].glob(f"*.bin")) + ) else: - check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields + check_fields = ( + check_fields.split(",") + if isinstance(check_fields, str) + else check_fields + ) self.check_fields = list(map(lambda x: x.strip(), check_fields)) self.qlib_fields = list(map(lambda x: f"${x}", self.check_fields)) self.max_workers = max_workers @@ -82,13 +90,19 @@ def _compare(self, file_path: Path): return self.NOT_IN_FEATURES # qlib data qlib_df = D.features([symbol], self.qlib_fields, freq=self.freq) - qlib_df.rename(columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True) + qlib_df.rename( + columns={_c: _c.strip("$") for _c in qlib_df.columns}, inplace=True + ) # csv data origin_df = pd.read_csv(file_path) - origin_df[self.date_field_name] = pd.to_datetime(origin_df[self.date_field_name]) + origin_df[self.date_field_name] = pd.to_datetime( + origin_df[self.date_field_name] + ) if self.symbol_field_name not in origin_df.columns: origin_df[self.symbol_field_name] = symbol - origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True) + origin_df.set_index( + [self.symbol_field_name, self.date_field_name], inplace=True + ) origin_df.index.names = qlib_df.index.names origin_df = origin_df.reindex(qlib_df.index) try: @@ -116,7 +130,9 @@ def check(self): compare_false = [] with tqdm(total=len(self.csv_files)) as p_bar: with ProcessPoolExecutor(max_workers=self.max_workers) as executor: - for file_path, _check_res in zip(self.csv_files, executor.map(self._compare, self.csv_files)): + for file_path, _check_res in zip( + self.csv_files, executor.map(self._compare, self.csv_files) + ): symbol = file_path.name.strip(self.file_suffix) if _check_res == self.NOT_IN_FEATURES: not_in_features.append(symbol) diff --git a/scripts/collect_info.py b/scripts/collect_info.py index 9e7a6395ef..6cae858798 100644 --- a/scripts/collect_info.py +++ b/scripts/collect_info.py @@ -45,7 +45,7 @@ def qlib(self): "pymongo", "loguru", "lightgbm", - "gym", + "gymnasium", "cvxpy", "joblib", "matplotlib", diff --git a/scripts/data_collector/baostock_5min/collector.py b/scripts/data_collector/baostock_5min/collector.py index 0a69beefb8..36ca5adaa8 100644 --- a/scripts/data_collector/baostock_5min/collector.py +++ 
b/scripts/data_collector/baostock_5min/collector.py @@ -20,7 +20,10 @@ sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import generate_minutes_calendar_from_daily, calc_adjusted_price +from data_collector.utils import ( + generate_minutes_calendar_from_daily, + calc_adjusted_price, +) class BaostockCollectorHS3005min(BaseCollector): @@ -87,17 +90,41 @@ def get_trade_calendar(self): @staticmethod def process_interval(interval: str): if interval == "1d": - return {"interval": "d", "fields": "date,code,open,high,low,close,volume,amount,adjustflag"} + return { + "interval": "d", + "fields": "date,code,open,high,low,close,volume,amount,adjustflag", + } if interval == "5min": - return {"interval": "5", "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag"} + return { + "interval": "5", + "fields": "date,time,code,open,high,low,close,volume,amount,adjustflag", + } def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> pd.DataFrame: df = self.get_data_from_remote( - symbol=symbol, interval=interval, start_datetime=start_datetime, end_datetime=end_datetime + symbol=symbol, + interval=interval, + start_datetime=start_datetime, + end_datetime=end_datetime, ) - df.columns = ["date", "time", "symbol", "open", "high", "low", "close", "volume", "amount", "adjustflag"] + df.columns = [ + "date", + "time", + "symbol", + "open", + "high", + "low", + "close", + "volume", + "amount", + "adjustflag", + ] df["time"] = pd.to_datetime(df["time"], format="%Y%m%d%H%M%S%f") df["date"] = df["time"].dt.strftime("%Y-%m-%d %H:%M:%S") df["date"] = df["date"].map(lambda x: pd.Timestamp(x) - pd.Timedelta(minutes=5)) @@ -107,7 +134,10 @@ def get_data( @staticmethod def get_data_from_remote( - symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> pd.DataFrame: df = pd.DataFrame() rs = bs.query_history_k_data_plus( @@ -115,7 +145,9 @@ def get_data_from_remote( BaostockCollectorHS3005min.process_interval(interval=interval)["fields"], start_date=str(start_datetime.strftime("%Y-%m-%d")), end_date=str(end_datetime.strftime("%Y-%m-%d")), - frequency=BaostockCollectorHS3005min.process_interval(interval=interval)["interval"], + frequency=BaostockCollectorHS3005min.process_interval(interval=interval)[ + "interval" + ], adjustflag="3", ) if rs.error_code == "0" and len(rs.data) > 0: @@ -151,7 +183,11 @@ class BaostockNormalizeHS3005min(BaseNormalize): PM_RANGE = ("13:00:00", "14:59:00") def __init__( - self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + qlib_data_1d_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -166,8 +202,14 @@ def __init__( """ bs.login() qlib.init(provider_uri=qlib_data_1d_dir) - self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") - super(BaostockNormalizeHS3005min, self).__init__(date_field_name, symbol_field_name) + self.all_1d_data = D.features( + D.instruments("all"), + ["$paused", "$volume", "$factor", "$close"], + freq="day", + ) + super(BaostockNormalizeHS3005min, self).__init__( + date_field_name, symbol_field_name + ) 
@staticmethod def calc_change(df: pd.DataFrame, last_close: float) -> pd.Series: @@ -209,11 +251,19 @@ def normalize_baostock( if calendar_list is not None: df = df.reindex( pd.DataFrame(index=calendar_list) - .loc[pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timedelta(days=1)] + .loc[ + pd.Timestamp(df.index.min()) + .date() : pd.Timestamp(df.index.max()) + .date() + + pd.Timedelta(days=1) + ] .index ) df.sort_index(inplace=True) - df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan + df.loc[ + (df["volume"] <= 0) | np.isnan(df["volume"]), + list(set(df.columns) - {symbol_field_name}), + ] = np.nan df["change"] = BaostockNormalizeHS3005min.calc_change(df, last_close) @@ -244,14 +294,23 @@ def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_baostock(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + df = self.normalize_baostock( + df, self._calendar_list, self._date_field_name, self._symbol_field_name + ) # adjusted price df = self.adjusted_price(df) return df class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="5min", region="HS300"): + def __init__( + self, + source_dir=None, + normalize_dir=None, + max_workers=1, + interval="5min", + region="HS300", + ): """ Changed the default value of: scripts.data_collector.base.BaseRun. """ @@ -291,7 +350,9 @@ def download_data( # get hs300 5min data $ python collector.py download_data --source_dir ~/.qlib/stock_data/source/hs300_5min_original --start 2022-01-01 --end 2022-01-30 --interval 5min --region HS300 """ - super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) + super(Run, self).download_data( + max_collector_count, delay, start, end, check_data_length, limit_nums + ) def normalize_data( self, @@ -320,7 +381,10 @@ def normalize_data( "If normalize 5min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( - date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir + date_field_name, + symbol_field_name, + end_date=end_date, + qlib_data_1d_dir=qlib_data_1d_dir, ) diff --git a/scripts/data_collector/base.py b/scripts/data_collector/base.py index 2efc2feadc..576d203145 100644 --- a/scripts/data_collector/base.py +++ b/scripts/data_collector/base.py @@ -22,8 +22,12 @@ class BaseCollector(abc.ABC): NORMAL_FLAG = "NORMAL" DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01") - DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1)).date() - DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1)).date() + DEFAULT_START_DATETIME_1MIN = pd.Timestamp( + datetime.datetime.now() - pd.Timedelta(days=5 * 6 - 1) + ).date() + DEFAULT_END_DATETIME_1D = pd.Timestamp( + datetime.datetime.now() + pd.Timedelta(days=1) + ).date() DEFAULT_END_DATETIME_1MIN = DEFAULT_END_DATETIME_1D INTERVAL_1min = "1min" @@ -72,7 +76,9 @@ def __init__( self.max_collector_count = max_collector_count self.mini_symbol_map = {} self.interval = interval - self.check_data_length = max(int(check_data_length) if check_data_length is not None else 0, 0) + self.check_data_length = max( + 
int(check_data_length) if check_data_length is not None else 0, 0 + ) self.start_datetime = self.normalize_start_datetime(start) self.end_datetime = self.normalize_end_datetime(end) @@ -83,7 +89,9 @@ def __init__( try: self.instrument_list = self.instrument_list[: int(limit_nums)] except Exception as e: - logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored") + logger.warning( + f"Cannot use limit_nums={limit_nums}, the parameter will be ignored" + ) def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None): return ( @@ -110,7 +118,11 @@ def normalize_symbol(self, symbol: str): @abc.abstractmethod def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> pd.DataFrame: """get data with symbol @@ -141,7 +153,9 @@ def _simple_collector(self, symbol: str): """ self.sleep() - df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime) + df = self.get_data( + symbol, self.interval, self.start_datetime, self.end_datetime + ) _result = self.NORMAL_FLAG if self.check_data_length > 0: _result = self.cache_small_data(symbol, df) @@ -174,7 +188,9 @@ def save_instrument(self, symbol, df: pd.DataFrame): def cache_small_data(self, symbol, df): if len(df) < self.check_data_length: - logger.warning(f"the number of trading days of {symbol} is less than {self.check_data_length}!") + logger.warning( + f"the number of trading days of {symbol} is less than {self.check_data_length}!" + ) _temp = self.mini_symbol_map.setdefault(symbol, []) _temp.append(df.copy()) return self.CACHE_FLAG @@ -210,14 +226,22 @@ def collector_data(self): for _symbol, _df_list in self.mini_symbol_map.items(): _df = pd.concat(_df_list, sort=False) if not _df.empty: - self.save_instrument(_symbol, _df.drop_duplicates(["date"]).sort_values(["date"])) + self.save_instrument( + _symbol, _df.drop_duplicates(["date"]).sort_values(["date"]) + ) if self.mini_symbol_map: - logger.warning(f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}") - logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}") + logger.warning( + f"less than {self.check_data_length} instrument list: {list(self.mini_symbol_map.keys())}" + ) + logger.info( + f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}" + ) class BaseNormalize(abc.ABC): - def __init__(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): + def __init__( + self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): """ Parameters @@ -282,7 +306,9 @@ def __init__( self._max_workers = max_workers self._normalize_obj = normalize_class( - date_field_name=date_field_name, symbol_field_name=symbol_field_name, **kwargs + date_field_name=date_field_name, + symbol_field_name=symbol_field_name, + **kwargs, ) def _executor(self, file_path: Path): @@ -298,14 +324,19 @@ def _executor(self, file_path: Path): file_path, dtype={self._symbol_field_name: str}, keep_default_na=False, - na_values={col: symbol_na if col == self._symbol_field_name else default_na for col in columns}, + na_values={ + col: symbol_na if col == self._symbol_field_name else default_na + for col in columns + }, ) # NOTE: It has been reported that there may be some problems here, and the specific issues will be dealt with when they are identified. 
df = self._normalize_obj.normalize(df) if df is not None and not df.empty: if self._end_date is not None: - _mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp(self._end_date) + _mask = pd.to_datetime(df[self._date_field_name]) <= pd.Timestamp( + self._end_date + ) df = df[_mask] df.to_csv(self._target_dir.joinpath(file_path.name), index=False) @@ -320,7 +351,9 @@ def normalize(self): class BaseRun(abc.ABC): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): + def __init__( + self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d" + ): """ Parameters @@ -398,7 +431,9 @@ def download_data( $ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - _class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector] + _class = getattr( + self._cur_module, self.collector_class_name + ) # type: Type[BaseCollector] _class( self.source_dir, max_workers=self.max_workers, @@ -412,7 +447,9 @@ def download_data( **kwargs, ).collector_data() - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs): + def normalize_data( + self, date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + ): """normalize data Parameters diff --git a/scripts/data_collector/br_index/collector.py b/scripts/data_collector/br_index/collector.py index 04b2f96d9f..36ff8b7c82 100644 --- a/scripts/data_collector/br_index/collector.py +++ b/scripts/data_collector/br_index/collector.py @@ -32,7 +32,11 @@ def __init__( retry_sleep: int = 3, ): super(IBOVIndex, self).__init__( - index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep + index_name=index_name, + qlib_dir=qlib_dir, + freq=freq, + request_retry=request_retry, + retry_sleep=retry_sleep, ) self.today: datetime = datetime.date.today() @@ -97,13 +101,17 @@ def get_four_month_period(self): now = datetime.datetime.now() current_year = now.year current_month = now.month - for year in [item for item in range(init_year, current_year)]: # pylint: disable=R1721 + for year in [ + item for item in range(init_year, current_year) + ]: # pylint: disable=R1721 for el in four_months_period: self.years_4_month_periods.append(str(year) + "_" + el) # For current year the logic must be a little different current_4_month_period = self.get_current_4_month_period(current_month) for i in range(int(current_4_month_period[0])): - self.years_4_month_periods.append(str(current_year) + "_" + str(i + 1) + "Q") + self.years_4_month_periods.append( + str(current_year) + "_" + str(i + 1) + "Q" + ) return self.years_4_month_periods def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame: @@ -122,7 +130,9 @@ def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame: logger.info("Formatting Datetime") if self.freq != "day": inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply( - lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime("%Y-%m-%d %H:%M:%S") + lambda x: ( + pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59) + ).strftime("%Y-%m-%d %H:%M:%S") ) else: inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply( @@ -187,10 +197,14 @@ def get_changes(self): df_changes_list = [] for i in tqdm(range(len(self.years_4_month_periods) - 1)): df = pd.read_csv( - self.ibov_index_composition.format(self.years_4_month_periods[i]), on_bad_lines="skip" + 
self.ibov_index_composition.format(self.years_4_month_periods[i]), + on_bad_lines="skip", )["symbol"] df_ = pd.read_csv( - self.ibov_index_composition.format(self.years_4_month_periods[i + 1]), on_bad_lines="skip" + self.ibov_index_composition.format( + self.years_4_month_periods[i + 1] + ), + on_bad_lines="skip", )["symbol"] ## Remove Dataframe @@ -216,7 +230,11 @@ def get_changes(self): ) list_add = list(df_[~df_.isin(df)]) df_added = pd.DataFrame( - {"date": len(list_add) * [add_date], "type": len(list_add) * ["add"], "symbol": list_add} + { + "date": len(list_add) * [add_date], + "type": len(list_add) * ["add"], + "symbol": list_add, + } ) df_changes_list.append(pd.concat([df_added, df_removed], sort=False)) @@ -226,7 +244,11 @@ def get_changes(self): return df except Exception as E: - logger.error("An error occured while downloading 2008 index composition - {}".format(E)) + logger.error( + "An error occured while downloading 2008 index composition - {}".format( + E + ) + ) def get_new_companies(self): """ @@ -257,17 +279,26 @@ def get_new_companies(self): ## Get index composition df_index = pd.read_csv( - self.ibov_index_composition.format(self.year + "_" + self.current_4_month_period), on_bad_lines="skip" + self.ibov_index_composition.format( + self.year + "_" + self.current_4_month_period + ), + on_bad_lines="skip", ) df_date_first_added = pd.read_csv( - self.ibov_index_composition.format("date_first_added_" + self.year + "_" + self.current_4_month_period), + self.ibov_index_composition.format( + "date_first_added_" + self.year + "_" + self.current_4_month_period + ), on_bad_lines="skip", ) - df = df_index.merge(df_date_first_added, on="symbol")[["symbol", "Date First Added"]] + df = df_index.merge(df_date_first_added, on="symbol")[ + ["symbol", "Date First Added"] + ] df[self.START_DATE_FIELD] = df["Date First Added"].map(self.format_quarter) # end_date will be our current quarter + 1, since the IBOV index updates itself every quarter - df[self.END_DATE_FIELD] = self.year + "-" + quarter_dict[self.current_4_month_period] + df[self.END_DATE_FIELD] = ( + self.year + "-" + quarter_dict[self.current_4_month_period] + ) df = df[["symbol", self.START_DATE_FIELD, self.END_DATE_FIELD]] df["symbol"] = df["symbol"].astype(str) + ".SA" diff --git a/scripts/data_collector/cn_index/collector.py b/scripts/data_collector/cn_index/collector.py index fb6914d24a..c52fea273e 100644 --- a/scripts/data_collector/cn_index/collector.py +++ b/scripts/data_collector/cn_index/collector.py @@ -19,13 +19,15 @@ sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.index import IndexBase -from data_collector.utils import get_calendar_list, get_trading_date_by_shift, deco_retry +from data_collector.utils import ( + get_calendar_list, + get_trading_date_by_shift, + deco_retry, +) from data_collector.utils import get_instruments -NEW_COMPANIES_URL = ( - "https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls" -) +NEW_COMPANIES_URL = "https://oss-ch.csindex.com.cn/static/html/csindex/public/uploads/file/autofile/cons/{index_code}cons.xls" INDEX_CHANGES_URL = "https://www.csindex.com.cn/csindex-home/search/search-content?lang=cn&searchInput=%E5%85%B3%E4%BA%8E%E8%B0%83%E6%95%B4%E6%B2%AA%E6%B7%B1300%E5%92%8C%E4%B8%AD%E8%AF%81%E9%A6%99%E6%B8%AF100%E7%AD%89%E6%8C%87%E6%95%B0%E6%A0%B7%E6%9C%AC&pageNum={page_num}&pageSize={page_size}&sortField=date&dateRange=all&contentType=announcement" @@ -114,10 +116,14 @@ def format_datetime(self, inst_df: pd.DataFrame) -> 
pd.DataFrame: """ if self.freq != "day": inst_df[self.START_DATE_FIELD] = inst_df[self.START_DATE_FIELD].apply( - lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30)).strftime("%Y-%m-%d %H:%M:%S") + lambda x: ( + pd.Timestamp(x) + pd.Timedelta(hours=9, minutes=30) + ).strftime("%Y-%m-%d %H:%M:%S") ) inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply( - lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0)).strftime("%Y-%m-%d %H:%M:%S") + lambda x: ( + pd.Timestamp(x) + pd.Timedelta(hours=15, minutes=0) + ).strftime("%Y-%m-%d %H:%M:%S") ) return inst_df @@ -158,9 +164,15 @@ def normalize_symbol(symbol: str) -> str: symbol """ symbol = f"{int(symbol):06}" - return f"SH{symbol}" if symbol.startswith("60") or symbol.startswith("688") else f"SZ{symbol}" - - def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp) -> pd.DataFrame: + return ( + f"SH{symbol}" + if symbol.startswith("60") or symbol.startswith("688") + else f"SZ{symbol}" + ) + + def _parse_excel( + self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.Timestamp + ) -> pd.DataFrame: content = retry_request(excel_url, exclude_status=[404]).content _io = BytesIO(content) df_map = pd.read_excel(_io, sheet_name=None) @@ -169,7 +181,10 @@ def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.T ).open("wb") as fp: fp.write(content) tmp = [] - for _s_name, _type, _date in [("调入", self.ADD, add_date), ("调出", self.REMOVE, remove_date)]: + for _s_name, _type, _date in [ + ("调入", self.ADD, add_date), + ("调出", self.REMOVE, remove_date), + ]: _df = df_map[_s_name] _df = _df.loc[_df["指数代码"] == self.index_code, ["证券代码"]] _df = _df.applymap(self.normalize_symbol) @@ -180,7 +195,9 @@ def _parse_excel(self, excel_url: str, add_date: pd.Timestamp, remove_date: pd.T df = pd.concat(tmp) return df - def _parse_table(self, content: str, add_date: pd.DataFrame, remove_date: pd.DataFrame) -> pd.DataFrame: + def _parse_table( + self, content: str, add_date: pd.DataFrame, remove_date: pd.DataFrame + ) -> pd.DataFrame: df = pd.DataFrame() _tmp_count = 0 for _df in pd.read_html(content): @@ -242,13 +259,17 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: if "沪深300" not in title: return pd.DataFrame() - logger.info(f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}") + logger.info( + f"load index data from https://www.csindex.com.cn/#/about/newsDetail?id={url.split('id=')[-1]}" + ) _text = resp["content"] date_list = re.findall(r"(\d{4}).*?年.*?(\d+).*?月.*?(\d+).*?日", _text) if len(date_list) >= 2: add_date = pd.Timestamp("-".join(date_list[0])) else: - _date = pd.Timestamp("-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0])) + _date = pd.Timestamp( + "-".join(re.findall(r"(\d{4}).*?年.*?(\d+).*?月", _text)[0]) + ) add_date = get_trading_date_by_shift(self.calendar_list, _date, shift=0) if "盘后" in _text or "市后" in _text: add_date = get_trading_date_by_shift(self.calendar_list, add_date, shift=1) @@ -262,11 +283,15 @@ def _read_change_from_url(self, url: str) -> pd.DataFrame: if excel_url_list: excel_url = excel_url_list[0] if not excel_url.startswith("http"): - excel_url = excel_url if excel_url.startswith("/") else "/" + excel_url + excel_url = ( + excel_url if excel_url.startswith("/") else "/" + excel_url + ) excel_url = f"http://www.csindex.com.cn{excel_url}" if excel_url: try: - logger.info(f"get {add_date} changes from the excel, title={title}, excel_url={excel_url}") + logger.info( + f"get 
{add_date} changes from the excel, title={title}, excel_url={excel_url}" + ) df = self._parse_excel(excel_url, add_date, remove_date) except ValueError: logger.info( @@ -289,8 +314,12 @@ def _get_change_notices_url(self) -> Iterable[str]: """ page_num = 1 page_size = 5 - data = retry_request(self.changes_url.format(page_size=page_size, page_num=page_num)).json() - data = retry_request(self.changes_url.format(page_size=data["total"], page_num=page_num)).json() + data = retry_request( + self.changes_url.format(page_size=page_size, page_num=page_num) + ).json() + data = retry_request( + self.changes_url.format(page_size=data["total"], page_num=page_num) + ).json() for item in data["data"]: yield f"https://www.csindex.com.cn/csindex-home/announcement/queryAnnouncementById?id={item['id']}" @@ -319,7 +348,9 @@ def get_new_companies(self) -> pd.DataFrame: df = pd.read_excel(_io) df = df.iloc[:, [0, 4]] df.columns = [self.END_DATE_FIELD, self.SYMBOL_FIELD_NAME] - df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map(self.normalize_symbol) + df[self.SYMBOL_FIELD_NAME] = df[self.SYMBOL_FIELD_NAME].map( + self.normalize_symbol + ) df[self.END_DATE_FIELD] = pd.to_datetime(df[self.END_DATE_FIELD].astype(str)) df[self.START_DATE_FIELD] = self.bench_start_date logger.info("end of get new companies.") @@ -396,7 +427,9 @@ def get_history_companies(self) -> pd.DataFrame: """ bs.login() today = pd.Timestamp.now() - date_range = pd.DataFrame(pd.date_range(start="2007-01-15", end=today, freq="7D"))[0].dt.date + date_range = pd.DataFrame( + pd.date_range(start="2007-01-15", end=today, freq="7D") + )[0].dt.date ret_list = [] for date in tqdm(date_range, desc="Download CSI500"): result = self.get_data_from_baostock(date) diff --git a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py index 0a721298d3..8b475328aa 100644 --- a/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py +++ b/scripts/data_collector/contrib/fill_cn_1min_data/fill_cn_1min_data.py @@ -17,7 +17,9 @@ from data_collector.utils import generate_minutes_calendar_from_daily -def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date"): +def get_date_range( + data_1min_dir: Path, max_workers: int = 16, date_field_name: str = "date" +): csv_files = list(data_1min_dir.glob("*.csv")) min_date = None max_date = None @@ -28,9 +30,13 @@ def get_date_range(data_1min_dir: Path, max_workers: int = 16, date_field_name: _dates = pd.to_datetime(_result[date_field_name]) _tmp_min = _dates.min() - min_date = min(min_date, _tmp_min) if min_date is not None else _tmp_min + min_date = ( + min(min_date, _tmp_min) if min_date is not None else _tmp_min + ) _tmp_max = _dates.max() - max_date = max(max_date, _tmp_max) if max_date is not None else _tmp_max + max_date = ( + max(max_date, _tmp_max) if max_date is not None else _tmp_max + ) p_bar.update() return min_date, max_date @@ -69,9 +75,13 @@ def fill_1min_using_1d( symbols_1min = get_symbols(data_1min_dir) qlib.init(provider_uri=str(qlib_data_1d_dir)) - data_1d = D.features(D.instruments("all"), ["$close"], min_date, max_date, freq="day") + data_1d = D.features( + D.instruments("all"), ["$close"], min_date, max_date, freq="day" + ) - miss_symbols = set(data_1d.index.get_level_values(level="instrument").unique()) - set(symbols_1min) + miss_symbols = set( + data_1d.index.get_level_values(level="instrument").unique() + ) - set(symbols_1min) if not miss_symbols: 
logger.warning("More symbols in 1min than 1d, no padding required") return diff --git a/scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py b/scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py index 939ba7f6ad..27f9bf40e4 100644 --- a/scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py +++ b/scripts/data_collector/contrib/future_trading_date_collector/future_trading_date_collector.py @@ -40,7 +40,9 @@ def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]: return date_list elif freq == "1min": date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist() - return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list)) + return list( + map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list) + ) else: raise ValueError(f"Unsupported freq: {freq}") @@ -70,7 +72,9 @@ def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"): start_year = pd.Timestamp.now().year else: start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year - rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31") + rs = bs.query_trade_dates( + start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31" + ) data_list = [] while (rs.error_code == "0") & rs.next(): _row_data = rs.get_row_data() diff --git a/scripts/data_collector/crypto/collector.py b/scripts/data_collector/crypto/collector.py index 302b89e200..4e3bb12661 100644 --- a/scripts/data_collector/crypto/collector.py +++ b/scripts/data_collector/crypto/collector.py @@ -106,7 +106,9 @@ def __init__( def init_datetime(self): if self.interval == self.INTERVAL_1min: - self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + self.start_datetime = max( + self.start_datetime, self.DEFAULT_START_DATETIME_1MIN + ) elif self.interval == self.INTERVAL_1d: pass else: @@ -134,14 +136,22 @@ def get_data_from_remote(symbol, interval, start, end): error_msg = f"{symbol}-{interval}-{start}-{end}" try: cg = CoinGeckoAPI() - data = cg.get_coin_market_chart_by_id(id=symbol, vs_currency="usd", days="max") + data = cg.get_coin_market_chart_by_id( + id=symbol, vs_currency="usd", days="max" + ) _resp = pd.DataFrame(columns=["date"] + list(data.keys())) - _resp["date"] = [dt.fromtimestamp(mktime(time.localtime(x[0] / 1000))) for x in data["prices"]] + _resp["date"] = [ + dt.fromtimestamp(mktime(time.localtime(x[0] / 1000))) + for x in data["prices"] + ] for key in data.keys(): _resp[key] = [x[1] for x in data[key]] _resp["date"] = pd.to_datetime(_resp["date"]) _resp["date"] = [x.date() for x in _resp["date"]] - _resp = _resp[(_resp["date"] < pd.to_datetime(end).date()) & (_resp["date"] > pd.to_datetime(start).date())] + _resp = _resp[ + (_resp["date"] < pd.to_datetime(end).date()) + & (_resp["date"] > pd.to_datetime(start).date()) + ] if _resp.shape[0] != 0: _resp = _resp.reset_index() if isinstance(_resp, pd.DataFrame): @@ -150,7 +160,11 @@ def get_data_from_remote(symbol, interval, start, end): logger.warning(f"{error_msg}:{e}") def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> [pd.DataFrame]: def _get_simple(start_, end_): self.sleep() @@ -204,7 +218,9 @@ def normalize_crypto( df = df.reindex( pd.DataFrame(index=calendar_list) 
.loc[ - pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timestamp(df.index.min()) + .date() : pd.Timestamp(df.index.max()) + .date() + pd.Timedelta(hours=23, minutes=59) ] .index @@ -215,7 +231,9 @@ def normalize_crypto( return df.reset_index() def normalize(self, df: pd.DataFrame) -> pd.DataFrame: - df = self.normalize_crypto(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + df = self.normalize_crypto( + df, self._calendar_list, self._date_field_name, self._symbol_field_name + ) return df @@ -225,7 +243,9 @@ def _get_calendar_list(self): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d"): + def __init__( + self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d" + ): """ Parameters @@ -287,9 +307,13 @@ def download_data( $ python collector.py download_data --source_dir ~/.qlib/crypto_data/source/1d --start 2015-01-01 --end 2021-11-30 --delay 1 --interval 1d """ - super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) + super(Run, self).download_data( + max_collector_count, delay, start, end, check_data_length, limit_nums + ) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data( + self, date_field_name: str = "date", symbol_field_name: str = "symbol" + ): """normalize data Parameters diff --git a/scripts/data_collector/fund/collector.py b/scripts/data_collector/fund/collector.py index 937d3931db..028a5cd6b6 100644 --- a/scripts/data_collector/fund/collector.py +++ b/scripts/data_collector/fund/collector.py @@ -75,7 +75,9 @@ def __init__( def init_datetime(self): if self.interval == self.INTERVAL_1min: - self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + self.start_datetime = max( + self.start_datetime, self.DEFAULT_START_DATETIME_1MIN + ) elif self.interval == self.INTERVAL_1d: pass else: @@ -105,9 +107,16 @@ def get_data_from_remote(symbol, interval, start, end): try: # TODO: numberOfHistoricalDaysToCrawl should be bigger enough url = INDEX_BENCH_URL.format( - index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end + index_code=symbol, + numberOfHistoricalDaysToCrawl=10000, + startDate=start, + endDate=end, + ) + resp = requests.get( + url, + headers={"referer": "http://fund.eastmoney.com/110022.html"}, + timeout=None, ) - resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"}, timeout=None) if resp.status_code != 200: raise ValueError("request error") @@ -128,7 +137,11 @@ def get_data_from_remote(symbol, interval, start, end): logger.warning(f"{error_msg}:{e}") def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> [pd.DataFrame]: def _get_simple(start_, end_): self.sleep() @@ -186,7 +199,9 @@ def normalize_fund( df = df.reindex( pd.DataFrame(index=calendar_list) .loc[ - pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timestamp(df.index.min()) + .date() : pd.Timestamp(df.index.max()) + .date() + pd.Timedelta(hours=23, minutes=59) ] .index @@ -198,7 +213,9 @@ def normalize_fund( def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + df = self.normalize_fund( + df, 
self._calendar_list, self._date_field_name, self._symbol_field_name + ) return df @@ -216,7 +233,14 @@ class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN): + def __init__( + self, + source_dir=None, + normalize_dir=None, + max_workers=4, + interval="1d", + region=REGION_CN, + ): """ Parameters @@ -281,9 +305,13 @@ def download_data( $ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_data --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d """ - super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) + super(Run, self).download_data( + max_collector_count, delay, start, end, check_data_length, limit_nums + ) - def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"): + def normalize_data( + self, date_field_name: str = "date", symbol_field_name: str = "symbol" + ): """normalize data Parameters diff --git a/scripts/data_collector/future_calendar_collector.py b/scripts/data_collector/future_calendar_collector.py index 4dfd24e4b7..d04e91e3b5 100644 --- a/scripts/data_collector/future_calendar_collector.py +++ b/scripts/data_collector/future_calendar_collector.py @@ -18,7 +18,9 @@ class CollectorFutureCalendar: calendar_format = "%Y-%m-%d" - def __init__(self, qlib_dir: Union[str, Path], start_date: str = None, end_date: str = None): + def __init__( + self, qlib_dir: Union[str, Path], start_date: str = None, end_date: str = None + ): """ Parameters @@ -35,8 +37,14 @@ def __init__(self, qlib_dir: Union[str, Path], start_date: str = None, end_date: self.future_path = self.qlib_dir.joinpath("calendars/day_future.txt") self._calendar_list = self.calendar_list _latest_date = self._calendar_list[-1] - self.start_date = _latest_date if start_date is None else pd.Timestamp(start_date) - self.end_date = _latest_date + pd.Timedelta(days=365 * 2) if end_date is None else pd.Timestamp(end_date) + self.start_date = ( + _latest_date if start_date is None else pd.Timestamp(start_date) + ) + self.end_date = ( + _latest_date + pd.Timedelta(days=365 * 2) + if end_date is None + else pd.Timestamp(end_date) + ) @property def calendar_list(self) -> List[pd.Timestamp]: @@ -53,7 +61,9 @@ def _format_datetime(self, datetime_d: [str, pd.Timestamp]): return datetime_d.strftime(self.calendar_format) def write_calendar(self, calendar: Iterable): - calendars_list = [self._format_datetime(x) for x in sorted(set(self.calendar_list + calendar))] + calendars_list = [ + self._format_datetime(x) for x in sorted(set(self.calendar_list + calendar)) + ] np.savetxt(self.future_path, calendars_list, fmt="%s", encoding="utf-8") @abc.abstractmethod @@ -73,7 +83,8 @@ def collector(self) -> Iterable[pd.Timestamp]: if lg.error_code != "0": raise ValueError(f"login respond error_msg: {lg.error_msg}") rs = bs.query_trade_dates( - start_date=self._format_datetime(self.start_date), end_date=self._format_datetime(self.end_date) + start_date=self._format_datetime(self.start_date), + end_date=self._format_datetime(self.end_date), ) if rs.error_code != "0": raise ValueError(f"query_trade_dates respond error_msg: {rs.error_msg}") @@ -82,7 +93,9 @@ def collector(self) -> Iterable[pd.Timestamp]: data_list.append(rs.get_row_data()) calendar = pd.DataFrame(data_list, columns=rs.fields) calendar["is_trading_day"] = calendar["is_trading_day"].astype(int) - return 
pd.to_datetime(calendar[calendar["is_trading_day"] == 1]["calendar_date"]).to_list() + return pd.to_datetime( + calendar[calendar["is_trading_day"] == 1]["calendar_date"] + ).to_list() class CollectorFutureCalendarUS(CollectorFutureCalendar): @@ -91,7 +104,12 @@ def collector(self) -> Iterable[pd.Timestamp]: raise ValueError("Us calendar is not supported") -def run(qlib_dir: Union[str, Path], region: str = "cn", start_date: str = None, end_date: str = None): +def run( + qlib_dir: Union[str, Path], + region: str = "cn", + start_date: str = None, + end_date: str = None, +): """Collect future calendar(day) Parameters diff --git a/scripts/data_collector/index.py b/scripts/data_collector/index.py index caaaee53af..33ec027c2f 100644 --- a/scripts/data_collector/index.py +++ b/scripts/data_collector/index.py @@ -52,9 +52,13 @@ def __init__( self.index_name = index_name if qlib_dir is None: qlib_dir = Path(__file__).resolve().parent.joinpath("qlib_data") - self.instruments_dir = Path(qlib_dir).expanduser().resolve().joinpath("instruments") + self.instruments_dir = ( + Path(qlib_dir).expanduser().resolve().joinpath("instruments") + ) self.instruments_dir.mkdir(exist_ok=True, parents=True) - self.cache_dir = Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve() + self.cache_dir = ( + Path(f"~/.cache/qlib/index/{self.index_name}").expanduser().resolve() + ) self.cache_dir.mkdir(exist_ok=True, parents=True) self._request_retry = request_retry self._retry_sleep = retry_sleep @@ -143,10 +147,15 @@ def save_new_companies(self): raise ValueError(f"get new companies error: {self.index_name}") df = df.drop_duplicates([self.SYMBOL_FIELD_NAME]) df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv( - self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None + self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), + sep="\t", + index=False, + header=None, ) - def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> pd.DataFrame: + def get_changes_with_history_companies( + self, history_companies: pd.DataFrame + ) -> pd.DataFrame: """get changes with history companies Parameters @@ -174,25 +183,47 @@ def get_changes_with_history_companies(self, history_companies: pd.DataFrame) -> logger.info("parse changes from history companies......") last_code = [] result_df_list = [] - _columns = [self.DATE_FIELD_NAME, self.SYMBOL_FIELD_NAME, self.CHANGE_TYPE_FIELD] - for _trading_date in tqdm(sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True)): - _currenet_code = history_companies[history_companies[self.DATE_FIELD_NAME] == _trading_date][ - self.SYMBOL_FIELD_NAME - ].tolist() + _columns = [ + self.DATE_FIELD_NAME, + self.SYMBOL_FIELD_NAME, + self.CHANGE_TYPE_FIELD, + ] + for _trading_date in tqdm( + sorted(history_companies[self.DATE_FIELD_NAME].unique(), reverse=True) + ): + _currenet_code = history_companies[ + history_companies[self.DATE_FIELD_NAME] == _trading_date + ][self.SYMBOL_FIELD_NAME].tolist() if last_code: add_code = list(set(last_code) - set(_currenet_code)) remote_code = list(set(_currenet_code) - set(last_code)) for _code in add_code: result_df_list.append( pd.DataFrame( - [[get_trading_date_by_shift(self.calendar_list, _trading_date, 1), _code, self.ADD]], + [ + [ + get_trading_date_by_shift( + self.calendar_list, _trading_date, 1 + ), + _code, + self.ADD, + ] + ], columns=_columns, ) ) for _code in remote_code: result_df_list.append( pd.DataFrame( - 
[[get_trading_date_by_shift(self.calendar_list, _trading_date, 0), _code, self.REMOVE]], + [ + [ + get_trading_date_by_shift( + self.calendar_list, _trading_date, 0 + ), + _code, + self.REMOVE, + ] + ], columns=_columns, ) ) @@ -209,30 +240,49 @@ def parse_instruments(self): $ python collector.py parse_instruments --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data """ logger.info(f"start parse {self.index_name.lower()} companies.....") - instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD] + instruments_columns = [ + self.SYMBOL_FIELD_NAME, + self.START_DATE_FIELD, + self.END_DATE_FIELD, + ] changers_df = self.get_changes() new_df = self.get_new_companies() if new_df is None or new_df.empty: raise ValueError(f"get new companies error: {self.index_name}") new_df = new_df.copy() logger.info("parse history companies by changes......") - for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)): + for _row in tqdm( + changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples( + index=False + ) + ): if _row.type == self.ADD: - min_end_date = new_df.loc[new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD].min() + min_end_date = new_df.loc[ + new_df[self.SYMBOL_FIELD_NAME] == _row.symbol, self.END_DATE_FIELD + ].min() new_df.loc[ - (new_df[self.END_DATE_FIELD] == min_end_date) & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol), + (new_df[self.END_DATE_FIELD] == min_end_date) + & (new_df[self.SYMBOL_FIELD_NAME] == _row.symbol), self.START_DATE_FIELD, ] = _row.date else: - _tmp_df = pd.DataFrame([[_row.symbol, self.bench_start_date, _row.date]], columns=instruments_columns) + _tmp_df = pd.DataFrame( + [[_row.symbol, self.bench_start_date, _row.date]], + columns=instruments_columns, + ) new_df = pd.concat([new_df, _tmp_df], sort=False) inst_df = new_df.loc[:, instruments_columns] _inst_prefix = self.INST_PREFIX.strip() if _inst_prefix: - inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply(lambda x: f"{_inst_prefix}{x}") + inst_df["save_inst"] = inst_df[self.SYMBOL_FIELD_NAME].apply( + lambda x: f"{_inst_prefix}{x}" + ) inst_df = self.format_datetime(inst_df) inst_df.to_csv( - self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), sep="\t", index=False, header=None + self.instruments_dir.joinpath(f"{self.index_name.lower()}.txt"), + sep="\t", + index=False, + header=None, ) logger.info(f"parse {self.index_name.lower()} companies finished.") diff --git a/scripts/data_collector/pit/collector.py b/scripts/data_collector/pit/collector.py index c34b31348d..0a8e7785fb 100644 --- a/scripts/data_collector/pit/collector.py +++ b/scripts/data_collector/pit/collector.py @@ -93,14 +93,18 @@ def normalize_symbol(self, symbol: str) -> str: return f"{exchange}{symbol}" @staticmethod - def get_performance_express_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: + def get_performance_express_report_df( + code: str, start_date: str, end_date: str + ) -> pd.DataFrame: column_mapping = { "performanceExpPubDate": "date", "performanceExpStatDate": "period", "performanceExpressROEWa": "value", } - resp = bs.query_performance_express_report(code=code, start_date=start_date, end_date=end_date) + resp = bs.query_performance_express_report( + code=code, start_date=start_date, end_date=end_date + ) report_list = [] while (resp.error_code == "0") and resp.next(): report_list.append(resp.get_row_data()) @@ -121,7 +125,11 @@ def get_profit_df(code: str, start_date: str, 
end_date: str) -> pd.DataFrame: fields = bs.query_profit_data(code="sh.600519", year=2020, quarter=1).fields start_date = datetime.strptime(start_date, "%Y-%m-%d") end_date = datetime.strptime(end_date, "%Y-%m-%d") - args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)] + args = [ + (year, quarter) + for quarter in range(1, 5) + for year in range(start_date.year - 1, end_date.year + 1) + ] profit_list = [] for year, quarter in args: resp = bs.query_profit_data(code=code, year=year, quarter=quarter) @@ -143,23 +151,31 @@ def get_profit_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: return profit_df @staticmethod - def get_forecast_report_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: + def get_forecast_report_df( + code: str, start_date: str, end_date: str + ) -> pd.DataFrame: column_mapping = { "profitForcastExpPubDate": "date", "profitForcastExpStatDate": "period", "value": "value", } - resp = bs.query_forecast_report(code=code, start_date=start_date, end_date=end_date) + resp = bs.query_forecast_report( + code=code, start_date=start_date, end_date=end_date + ) forecast_list = [] while (resp.error_code == "0") and resp.next(): forecast_list.append(resp.get_row_data()) forecast_df = pd.DataFrame(forecast_list, columns=resp.fields) numeric_fields = ["profitForcastChgPctUp", "profitForcastChgPctDwn"] try: - forecast_df[numeric_fields] = forecast_df[numeric_fields].apply(pd.to_numeric, errors="ignore") + forecast_df[numeric_fields] = forecast_df[numeric_fields].apply( + pd.to_numeric, errors="ignore" + ) except KeyError: return pd.DataFrame() - forecast_df["value"] = (forecast_df["profitForcastChgPctUp"] + forecast_df["profitForcastChgPctDwn"]) / 200 + forecast_df["value"] = ( + forecast_df["profitForcastChgPctUp"] + forecast_df["profitForcastChgPctDwn"] + ) / 200 forecast_df = forecast_df[list(column_mapping.keys())] forecast_df.rename(columns=column_mapping, inplace=True) forecast_df["field"] = "YOYNI" @@ -171,7 +187,11 @@ def get_growth_df(code: str, start_date: str, end_date: str) -> pd.DataFrame: fields = bs.query_growth_data(code="sh.600519", year=2020, quarter=1).fields start_date = datetime.strptime(start_date, "%Y-%m-%d") end_date = datetime.strptime(end_date, "%Y-%m-%d") - args = [(year, quarter) for quarter in range(1, 5) for year in range(start_date.year - 1, end_date.year + 1)] + args = [ + (year, quarter) + for quarter in range(1, 5) + for year in range(start_date.year - 1, end_date.year + 1) + ] growth_list = [] for year, quarter in args: resp = bs.query_growth_data(code=code, year=year, quarter=quarter) @@ -207,7 +227,9 @@ def get_data( start_date = start_datetime.strftime("%Y-%m-%d") end_date = end_datetime.strftime("%Y-%m-%d") - performance_express_report_df = self.get_performance_express_report_df(code, start_date, end_date) + performance_express_report_df = self.get_performance_express_report_df( + code, start_date, end_date + ) profit_df = self.get_profit_df(code, start_date, end_date) forecast_report_df = self.get_forecast_report_df(code, start_date, end_date) growth_df = self.get_growth_df(code, start_date, end_date) @@ -227,14 +249,23 @@ def __init__(self, interval: str = "quarterly", *args, **kwargs): def normalize(self, df: pd.DataFrame) -> pd.DataFrame: dt = df["period"].apply( lambda x: ( - pd.to_datetime(x) + pd.DateOffset(days=(45 if self.interval == PitCollector.INTERVAL_QUARTERLY else 90)) + pd.to_datetime(x) + + pd.DateOffset( + days=( + 45 if self.interval == 
PitCollector.INTERVAL_QUARTERLY else 90 + ) + ) ).date() ) df["date"] = df["date"].fillna(dt.astype(str)) df["period"] = pd.to_datetime(df["period"]) df["period"] = df["period"].apply( - lambda x: x.year if self.interval == PitCollector.INTERVAL_ANNUAL else x.year * 100 + (x.month - 1) // 3 + 1 + lambda x: ( + x.year + if self.interval == PitCollector.INTERVAL_ANNUAL + else x.year * 100 + (x.month - 1) // 3 + 1 + ) ) return df diff --git a/scripts/data_collector/us_index/collector.py b/scripts/data_collector/us_index/collector.py index 50278d11ee..a4d53ffc2a 100644 --- a/scripts/data_collector/us_index/collector.py +++ b/scripts/data_collector/us_index/collector.py @@ -19,7 +19,11 @@ sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.index import IndexBase -from data_collector.utils import deco_retry, get_calendar_list, get_trading_date_by_shift +from data_collector.utils import ( + deco_retry, + get_calendar_list, + get_trading_date_by_shift, +) from data_collector.utils import get_instruments @@ -47,7 +51,11 @@ def __init__( retry_sleep: int = 3, ): super(WIKIIndex, self).__init__( - index_name=index_name, qlib_dir=qlib_dir, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep + index_name=index_name, + qlib_dir=qlib_dir, + freq=freq, + request_retry=request_retry, + retry_sleep=retry_sleep, ) self._target_url = f"{WIKI_URL}/{WIKI_INDEX_NAME_MAP[self.index_name.upper()]}" @@ -93,7 +101,9 @@ def format_datetime(self, inst_df: pd.DataFrame) -> pd.DataFrame: """ if self.freq != "day": inst_df[self.END_DATE_FIELD] = inst_df[self.END_DATE_FIELD].apply( - lambda x: (pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59)).strftime("%Y-%m-%d %H:%M:%S") + lambda x: ( + pd.Timestamp(x) + pd.Timedelta(hours=23, minutes=59) + ).strftime("%Y-%m-%d %H:%M:%S") ) return inst_df @@ -107,7 +117,11 @@ def calendar_list(self) -> List[pd.Timestamp]: """ _calendar_list = getattr(self, "_calendar_list", None) if _calendar_list is None: - _calendar_list = list(filter(lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL"))) + _calendar_list = list( + filter( + lambda x: x >= self.bench_start_date, get_calendar_list("US_ALL") + ) + ) setattr(self, "_calendar_list", _calendar_list) return _calendar_list @@ -127,7 +141,9 @@ def set_default_date_range(self, df: pd.DataFrame) -> pd.DataFrame: def get_new_companies(self): logger.info(f"get new companies {self.index_name} ......") - _data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)(self._request_new_companies)() + _data = deco_retry(retry=self._request_retry, retry_sleep=self._retry_sleep)( + self._request_new_companies + )() df_list = pd.read_html(_data.text) for _df in df_list: _df = self.filter_df(_df) @@ -142,9 +158,7 @@ def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: class NASDAQ100Index(WIKIIndex): - HISTORY_COMPANIES_URL = ( - "https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD" - ) + HISTORY_COMPANIES_URL = "https://indexes.nasdaqomx.com/Index/WeightingData?id=NDX&tradeDate={trade_date}T00%3A00%3A00.000&timeOfDay=SOD" MAX_WORKERS = 16 def filter_df(self, df: pd.DataFrame) -> pd.DataFrame: @@ -156,7 +170,9 @@ def bench_start_date(self) -> pd.Timestamp: return pd.Timestamp("2003-01-02") @deco_retry - def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = True) -> pd.DataFrame: + def _request_history_companies( + self, trade_date: pd.Timestamp, use_cache: bool = True + ) -> pd.DataFrame: trade_date = 
trade_date.strftime("%Y-%m-%d") cache_path = self.cache_dir.joinpath(f"{trade_date}_history_companies.pkl") if cache_path.exists() and use_cache: @@ -168,7 +184,9 @@ def _request_history_companies(self, trade_date: pd.Timestamp, use_cache: bool = raise ValueError(f"request error: {url}") df = pd.DataFrame(resp.json()["aaData"]) df[self.DATE_FIELD_NAME] = trade_date - df.rename(columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True) + df.rename( + columns={"Name": "name", "Symbol": self.SYMBOL_FIELD_NAME}, inplace=True + ) if not df.empty: df.to_pickle(cache_path) return df @@ -180,7 +198,8 @@ def get_history_companies(self): with tqdm(total=len(self.calendar_list)) as p_bar: with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor: for _trading_date, _df in zip( - self.calendar_list, executor.map(self._request_history_companies, self.calendar_list) + self.calendar_list, + executor.map(self._request_history_companies, self.calendar_list), ): if _df.empty: error_list.append(_trading_date) @@ -229,7 +248,9 @@ def get_changes(self) -> pd.DataFrame: changes_df = pd.read_html(self.WIKISP500_CHANGES_URL)[-1] changes_df = changes_df.iloc[:, [0, 1, 3]] changes_df.columns = [self.DATE_FIELD_NAME, self.ADD, self.REMOVE] - changes_df[self.DATE_FIELD_NAME] = pd.to_datetime(changes_df[self.DATE_FIELD_NAME]) + changes_df[self.DATE_FIELD_NAME] = pd.to_datetime( + changes_df[self.DATE_FIELD_NAME] + ) _result = [] for _type in [self.ADD, self.REMOVE]: _df = changes_df.copy() @@ -244,7 +265,15 @@ def get_changes(self) -> pd.DataFrame: _df[self.DATE_FIELD_NAME] = _df[self.DATE_FIELD_NAME].apply( lambda x: get_trading_date_by_shift(self.calendar_list, x, -1) ) - _result.append(_df[[self.DATE_FIELD_NAME, self.CHANGE_TYPE_FIELD, self.SYMBOL_FIELD_NAME]]) + _result.append( + _df[ + [ + self.DATE_FIELD_NAME, + self.CHANGE_TYPE_FIELD, + self.SYMBOL_FIELD_NAME, + ] + ] + ) logger.info(f"end of get sp500 history changes.") return pd.concat(_result, sort=False) diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index f25b1ec7a2..4b1342b776 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -73,11 +73,26 @@ def _get_calendar(url): calendar = _CALENDAR_MAP.get(bench_code, None) if calendar is None: - if bench_code.startswith("US_") or bench_code.startswith("IN_") or bench_code.startswith("BR_"): + if ( + bench_code.startswith("US_") + or bench_code.startswith("IN_") + or bench_code.startswith("BR_") + ): print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code])) - print(Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max")) - df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history(interval="1d", period="max") - calendar = df.index.get_level_values(level="date").map(pd.Timestamp).unique().tolist() + print( + Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( + interval="1d", period="max" + ) + ) + df = Ticker(CALENDAR_BENCH_URL_MAP[bench_code]).history( + interval="1d", period="max" + ) + calendar = ( + df.index.get_level_values(level="date") + .map(pd.Timestamp) + .unique() + .tolist() + ) else: if bench_code.upper() == "ALL": @@ -86,7 +101,8 @@ def _get_calendar(month): _cal = [] try: resp = requests.get( - SZSE_CALENDAR_URL.format(month=month, random=random.random), timeout=None + SZSE_CALENDAR_URL.format(month=month, random=random.random), + timeout=None, ).json() for _r in resp["data"]: if int(_r["jybz"]): @@ -95,7 +111,11 @@ def _get_calendar(month): raise ValueError(f"{month}-->{e}") from e return _cal - 
month_range = pd.date_range(start="2000-01", end=pd.Timestamp.now() + pd.Timedelta(days=31), freq="M") + month_range = pd.date_range( + start="2000-01", + end=pd.Timestamp.now() + pd.Timedelta(days=31), + freq="M", + ) calendar = [] for _m in month_range: cal = _get_calendar(_m.strftime("%Y-%m")) @@ -165,7 +185,9 @@ def get_calendar_list_by_ratio( p_bar.update() logger.info(f"count how many funds have founded in this day......") - _dict_count_founding = {date: _number_all_funds for date in _dict_count_trade} # dict{date:count} + _dict_count_founding = { + date: _number_all_funds for date in _dict_count_trade + } # dict{date:count} with tqdm(total=_number_all_funds) as p_bar: for oldest_date in all_oldest_list: for date in _dict_count_founding.keys(): @@ -173,7 +195,9 @@ def get_calendar_list_by_ratio( _dict_count_founding[date] -= 1 calendar = [ - date for date, count in _dict_count_trade.items() if count >= max(int(count * threshold), minimum_count) + date + for date, count in _dict_count_trade.items() + if count >= max(int(count * threshold), minimum_count) ] return calendar @@ -225,14 +249,21 @@ def _get_symbol(): data = resp.json() # Check if response contains valid data - if not data or "data" not in data or not data["data"] or "diff" not in data["data"]: + if ( + not data + or "data" not in data + or not data["data"] + or "diff" not in data["data"] + ): logger.warning(f"Invalid response structure on page {page}") break # fetch the current page data current_symbols = [_v["f12"] for _v in data["data"]["diff"]] - if not current_symbols: # It's the last page if there is no data in current page + if ( + not current_symbols + ): # It's the last page if there is no data in current page logger.info(f"Last page reached: {page - 1}") break @@ -253,7 +284,9 @@ def _get_symbol(): f"Request to {base_url} failed with status code {resp.status_code}" ) from e except Exception as e: - logger.warning("An error occurred while extracting data from the response.") + logger.warning( + "An error occurred while extracting data from the response." + ) raise if len(_symbols) < 3900: @@ -261,7 +294,11 @@ def _get_symbol(): # Add suffix after the stock code to conform to yahooquery standard, otherwise the data will not be fetched. 
_symbols = [ - _symbol + ".ss" if _symbol.startswith("6") else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None + ( + _symbol + ".ss" + if _symbol.startswith("6") + else _symbol + ".sz" if _symbol.startswith(("0", "3")) else None + ) for _symbol in _symbols ] _symbols = [_symbol for _symbol in _symbols if _symbol is not None] @@ -307,7 +344,10 @@ def _get_eastmoney(): raise ValueError("request error") try: - _symbols = [_v["f12"].replace("_", "-P") for _v in resp.json()["data"]["diff"].values()] + _symbols = [ + _v["f12"].replace("_", "-P") + for _v in resp.json()["data"]["diff"].values() + ] except Exception as e: logger.warning(f"request error: {e}") raise @@ -372,7 +412,14 @@ def _format(s_): s_ = s_.strip("*") return s_ - _US_SYMBOLS = sorted(set(map(_format, filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols)))) + _US_SYMBOLS = sorted( + set( + map( + _format, + filter(lambda x: len(x) < 8 and not x.endswith("WS"), _all_symbols), + ) + ) + ) return _US_SYMBOLS @@ -442,7 +489,9 @@ def _get_ibovespa(): children = tbody.findChildren("a", recursive=True) for child in children: - _symbols.append(str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0]) + _symbols.append( + str(child).rsplit('"', maxsplit=1)[-1].split(">")[1].split("<")[0] + ) return _symbols @@ -486,7 +535,10 @@ def _get_eastmoney(): raise ValueError("request error") try: _symbols = [] - for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")): + for sub_data in re.findall( + r"[\[](.*?)[\]]", + resp.content.decode().split("= [")[-1].replace("];", ""), + ): data = sub_data.replace('"', "").replace("'", "") # TODO: do we need other information, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE'] _symbols.append(data.split(",")[0]) @@ -567,7 +619,9 @@ def wrapper(*args, **kwargs): return deco_func(retry) if callable(retry) else deco_func -def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, shift: int = 1): +def get_trading_date_by_shift( + trading_list: list, trading_date: pd.Timestamp, shift: int = 1 +): """get trading date by shift Parameters @@ -665,17 +719,28 @@ def get_instruments( $ python collector.py --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies """ - _cur_module = importlib.import_module("data_collector.{}.collector".format(market_index)) + _cur_module = importlib.import_module( + "data_collector.{}.collector".format(market_index) + ) obj = getattr(_cur_module, f"{index_name.upper()}Index")( - qlib_dir=qlib_dir, index_name=index_name, freq=freq, request_retry=request_retry, retry_sleep=retry_sleep + qlib_dir=qlib_dir, + index_name=index_name, + freq=freq, + request_retry=request_retry, + retry_sleep=retry_sleep, ) getattr(obj, method)() -def _get_all_1d_data(_date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame): +def _get_all_1d_data( + _date_field_name: str, _symbol_field_name: str, _1d_data_all: pd.DataFrame +): df = copy.deepcopy(_1d_data_all) df.reset_index(inplace=True) - df.rename(columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, inplace=True) + df.rename( + columns={"datetime": _date_field_name, "instrument": _symbol_field_name}, + inplace=True, + ) df.columns = list(map(lambda x: x[1:] if x.startswith("$") else x, df.columns)) return df @@ -738,8 +803,12 @@ def calc_adjusted_price( df[_date_field_name] = pd.to_datetime(df[_date_field_name]) # get 1d data from qlib _start = 
pd.Timestamp(df[_date_field_name].min()).strftime("%Y-%m-%d") - _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") - data_1d: pd.DataFrame = get_1d_data(_date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all) + _end = (pd.Timestamp(df[_date_field_name].max()) + pd.Timedelta(days=1)).strftime( + "%Y-%m-%d" + ) + data_1d: pd.DataFrame = get_1d_data( + _date_field_name, _symbol_field_name, symbol, _start, _end, _1d_data_all + ) data_1d = data_1d.copy() if data_1d is None or data_1d.empty: df["factor"] = 1 / df.loc[df["close"].first_valid_index()]["close"] @@ -759,27 +828,38 @@ def calc_adjusted_price( # - data_1d.close: `data_1d.adjclose / (close for the first trading day that is not np.nan)` def _calc_factor(df_1d: pd.DataFrame): try: - _date = pd.Timestamp(pd.Timestamp(df_1d[_date_field_name].iloc[0]).date()) - df_1d["factor"] = data_1d.loc[_date]["close"] / df_1d.loc[df_1d["close"].last_valid_index()]["close"] + _date = pd.Timestamp( + pd.Timestamp(df_1d[_date_field_name].iloc[0]).date() + ) + df_1d["factor"] = ( + data_1d.loc[_date]["close"] + / df_1d.loc[df_1d["close"].last_valid_index()]["close"] + ) df_1d["paused"] = data_1d.loc[_date]["paused"] except Exception: df_1d["factor"] = np.nan df_1d["paused"] = np.nan return df_1d - df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply(_calc_factor) + df = df.groupby([df[_date_field_name].dt.date], group_keys=False).apply( + _calc_factor + ) if consistent_1d: # the date sequence is consistent with 1d df.set_index(_date_field_name, inplace=True) df = df.reindex( generate_minutes_calendar_from_daily( - calendars=pd.to_datetime(data_1d.reset_index()[_date_field_name].drop_duplicates()), + calendars=pd.to_datetime( + data_1d.reset_index()[_date_field_name].drop_duplicates() + ), freq=frequence, am_range=("09:30:00", "11:29:00"), pm_range=("13:00:00", "14:59:00"), ) ) - df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][_symbol_field_name] + df[_symbol_field_name] = df.loc[df[_symbol_field_name].first_valid_index()][ + _symbol_field_name + ] df.index.names = [_date_field_name] df.reset_index(inplace=True) for _col in ["open", "close", "high", "low", "volume"]: @@ -821,7 +901,10 @@ def calc_paused_num(df: pd.DataFrame, _date_field_name, _symbol_field_name): _date_field_name, _symbol_field_name, } - if _df.loc[:, list(check_fields)].isna().values.all() or (_df["volume"] == 0).all(): + if ( + _df.loc[:, list(check_fields)].isna().values.all() + or (_df["volume"] == 0).all() + ): all_nan_nums += 1 not_nan_nums = 0 _df["paused"] = 1 diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index a1b4d64f65..bfd2aa8298 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -99,7 +99,9 @@ def __init__( def init_datetime(self): if self.interval == self.INTERVAL_1min: - self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN) + self.start_datetime = max( + self.start_datetime, self.DEFAULT_START_DATETIME_1MIN + ) elif self.interval == self.INTERVAL_1d: pass else: @@ -123,7 +125,9 @@ def _timezone(self): raise NotImplementedError("rewrite get_timezone") @staticmethod - def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False): + def get_data_from_remote( + symbol, interval, start, end, show_1min_logging: bool = False + ): error_msg = f"{symbol}-{interval}-{start}-{end}" def _show_logging_func(): @@ -132,13 
+136,16 @@ def _show_logging_func(): interval = "1m" if interval in ["1m", "1min"] else interval try: - _resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end) + _resp = Ticker(symbol, asynchronous=False).history( + interval=interval, start=start, end=end + ) if isinstance(_resp, pd.DataFrame): return _resp.reset_index() elif isinstance(_resp, dict): _temp_data = _resp.get(symbol, {}) if isinstance(_temp_data, str) or ( - isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None + isinstance(_resp, dict) + and _temp_data.get("indicators", {}).get("quote", None) is None ): _show_logging_func() else: @@ -150,7 +157,11 @@ def _show_logging_func(): ) def get_data( - self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp + self, + symbol: str, + interval: str, + start_datetime: pd.Timestamp, + end_datetime: pd.Timestamp, ) -> pd.DataFrame: @deco_retry(retry_sleep=self.delay, retry=self.retry) def _get_simple(start_, end_): @@ -164,7 +175,8 @@ def _get_simple(start_, end_): ) if resp is None or resp.empty: raise ValueError( - f"get data error: {symbol}--{start_}--{end_}" + "The stock may be delisted, please check" + f"get data error: {symbol}--{start_}--{end_}" + + "The stock may be delisted, please check" ) return resp @@ -225,21 +237,37 @@ def download_index_data(self): _format = "%Y%m%d" _begin = self.start_datetime.strftime(_format) _end = self.end_datetime.strftime(_format) - for _index_name, _index_code in {"csi300": "000300", "csi100": "000903", "csi500": "000905"}.items(): + for _index_name, _index_code in { + "csi300": "000300", + "csi100": "000903", + "csi500": "000905", + }.items(): logger.info(f"get bench data: {_index_name}({_index_code})......") try: df = pd.DataFrame( map( lambda x: x.split(","), requests.get( - INDEX_BENCH_URL.format(index_code=_index_code, begin=_begin, end=_end), timeout=None + INDEX_BENCH_URL.format( + index_code=_index_code, begin=_begin, end=_end + ), + timeout=None, ).json()["data"]["klines"], ) ) except Exception as e: logger.warning(f"get {_index_name} error: {e}") continue - df.columns = ["date", "open", "close", "high", "low", "volume", "money", "change"] + df.columns = [ + "date", + "open", + "close", + "high", + "low", + "volume", + "money", + "change", + ] df["date"] = pd.to_datetime(df["date"]) df = df.astype(float, errors="ignore") df["adjclose"] = df["close"] @@ -399,13 +427,18 @@ def normalize_yahoo( df = df.reindex( pd.DataFrame(index=calendar_list) .loc[ - pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date() + pd.Timestamp(df.index.min()) + .date() : pd.Timestamp(df.index.max()) + .date() + pd.Timedelta(hours=23, minutes=59) ] .index ) df.sort_index(inplace=True) - df.loc[(df["volume"] <= 0) | np.isnan(df["volume"]), list(set(df.columns) - {symbol_field_name})] = np.nan + df.loc[ + (df["volume"] <= 0) | np.isnan(df["volume"]), + list(set(df.columns) - {symbol_field_name}), + ] = np.nan change_series = YahooNormalize.calc_change(df, last_close) # NOTE: The data obtained by Yahoo finance sometimes has exceptions @@ -438,7 +471,9 @@ def normalize_yahoo( def normalize(self, df: pd.DataFrame) -> pd.DataFrame: # normalize - df = self.normalize_yahoo(df, self._calendar_list, self._date_field_name, self._symbol_field_name) + df = self.normalize_yahoo( + df, self._calendar_list, self._date_field_name, self._symbol_field_name + ) # adjusted price df = self.adjusted_price(df) return df @@ -509,7 +544,11 @@ def _manual_adj_data(self, 
df: pd.DataFrame) -> pd.DataFrame: class YahooNormalize1dExtend(YahooNormalize1d): def __init__( - self, old_qlib_data_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + old_qlib_data_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -523,7 +562,15 @@ def __init__( symbol field name, default is symbol """ super(YahooNormalize1dExtend, self).__init__(date_field_name, symbol_field_name) - self.column_list = ["open", "high", "low", "close", "volume", "factor", "change"] + self.column_list = [ + "open", + "high", + "low", + "close", + "volume", + "factor", + "change", + ] self.old_qlib_data = self._get_old_data(old_qlib_data_dir) def _get_old_data(self, qlib_data_dir: [str, Path]): @@ -537,7 +584,9 @@ def normalize(self, df: pd.DataFrame) -> pd.DataFrame: df = super(YahooNormalize1dExtend, self).normalize(df) df.set_index(self._date_field_name, inplace=True) symbol_name = df[self._symbol_field_name].iloc[0] - old_symbol_list = self.old_qlib_data.index.get_level_values("instrument").unique().to_list() + old_symbol_list = ( + self.old_qlib_data.index.get_level_values("instrument").unique().to_list() + ) if str(symbol_name).upper() not in old_symbol_list: return df.reset_index() old_df = self.old_qlib_data.loc[str(symbol_name).upper()] @@ -564,7 +613,11 @@ class YahooNormalize1min(YahooNormalize, ABC): CALC_PAUSED_NUM = True def __init__( - self, qlib_data_1d_dir: [str, Path], date_field_name: str = "date", symbol_field_name: str = "symbol", **kwargs + self, + qlib_data_1d_dir: [str, Path], + date_field_name: str = "date", + symbol_field_name: str = "symbol", + **kwargs, ): """ @@ -579,7 +632,11 @@ def __init__( """ super(YahooNormalize1min, self).__init__(date_field_name, symbol_field_name) qlib.init(provider_uri=qlib_data_1d_dir) - self.all_1d_data = D.features(D.instruments("all"), ["$paused", "$volume", "$factor", "$close"], freq="day") + self.all_1d_data = D.features( + D.instruments("all"), + ["$paused", "$volume", "$factor", "$close"], + freq="day", + ) def _get_1d_calendar_list(self) -> Iterable[pd.Timestamp]: return list(D.calendar(freq="day")) @@ -693,7 +750,11 @@ def _get_calendar_list(self) -> Iterable[pd.Timestamp]: def symbol_to_yahoo(self, symbol): if "." not in symbol: _exchange = symbol[:2] - _exchange = ("ss" if _exchange.islower() else "SS") if _exchange.lower() == "sh" else _exchange + _exchange = ( + ("ss" if _exchange.islower() else "SS") + if _exchange.lower() == "sh" + else _exchange + ) symbol = symbol[2:] + "." 
+ _exchange return symbol @@ -725,7 +786,14 @@ def symbol_to_yahoo(self, symbol): class Run(BaseRun): - def __init__(self, source_dir=None, normalize_dir=None, max_workers=1, interval="1d", region=REGION_CN): + def __init__( + self, + source_dir=None, + normalize_dir=None, + max_workers=1, + interval="1d", + region=REGION_CN, + ): """ Parameters @@ -796,10 +864,14 @@ def download_data( # get 1m data $ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m """ - if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp(datetime.datetime.now().strftime("%Y-%m-%d")): + if self.interval == "1d" and pd.Timestamp(end) > pd.Timestamp( + datetime.datetime.now().strftime("%Y-%m-%d") + ): raise ValueError(f"end_date: {end} is greater than the current date.") - super(Run, self).download_data(max_collector_count, delay, start, end, check_data_length, limit_nums) + super(Run, self).download_data( + max_collector_count, delay, start, end, check_data_length, limit_nums + ) def normalize_data( self, @@ -833,16 +905,25 @@ def normalize_data( $ python collector.py normalize_data --qlib_data_1d_dir ~/.qlib/qlib_data/cn_data --source_dir ~/.qlib/stock_data/source_cn_1min --normalize_dir ~/.qlib/stock_data/normalize_cn_1min --region CN --interval 1min """ if self.interval.lower() == "1min": - if qlib_data_1d_dir is None or not Path(qlib_data_1d_dir).expanduser().exists(): + if ( + qlib_data_1d_dir is None + or not Path(qlib_data_1d_dir).expanduser().exists() + ): raise ValueError( "If normalize 1min, the qlib_data_1d_dir parameter must be set: --qlib_data_1d_dir , Reference: https://github.com/microsoft/qlib/tree/main/scripts/data_collector/yahoo#automatic-update-of-daily-frequency-datafrom-yahoo-finance" ) super(Run, self).normalize_data( - date_field_name, symbol_field_name, end_date=end_date, qlib_data_1d_dir=qlib_data_1d_dir + date_field_name, + symbol_field_name, + end_date=end_date, + qlib_data_1d_dir=qlib_data_1d_dir, ) def normalize_data_1d_extend( - self, old_qlib_data_dir, date_field_name: str = "date", symbol_field_name: str = "symbol" + self, + old_qlib_data_dir, + date_field_name: str = "date", + symbol_field_name: str = "symbol", ): """normalize data extend; extending yahoo qlib data(from: https://github.com/microsoft/qlib/tree/main/scripts#download-cn-data) @@ -973,19 +1054,31 @@ def update_data_to_bin( qlib_data_1d_dir = str(Path(qlib_data_1d_dir).expanduser().resolve()) if not exists_qlib_data(qlib_data_1d_dir): GetData().qlib_data( - target_dir=qlib_data_1d_dir, interval=self.interval, region=self.region, exists_skip=exists_skip + target_dir=qlib_data_1d_dir, + interval=self.interval, + region=self.region, + exists_skip=exists_skip, ) # start/end date calendar_df = pd.read_csv(Path(qlib_data_1d_dir).joinpath("calendars/day.txt")) - trading_date = (pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1)).strftime("%Y-%m-%d") + trading_date = ( + pd.Timestamp(calendar_df.iloc[-1, 0]) - pd.Timedelta(days=1) + ).strftime("%Y-%m-%d") if end_date is None: - end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime("%Y-%m-%d") + end_date = (pd.Timestamp(trading_date) + pd.Timedelta(days=1)).strftime( + "%Y-%m-%d" + ) # download data from yahoo # NOTE: when downloading data from YahooFinance, max_workers is recommended to be 1 - self.download_data(delay=delay, start=trading_date, end=end_date, check_data_length=check_data_length) + self.download_data( + delay=delay, + start=trading_date, 
+ end=end_date, + check_data_length=check_data_length, + ) # NOTE: a larger max_workers setting here would be faster self.max_workers = ( max(multiprocessing.cpu_count() - 2, 1) @@ -1007,14 +1100,23 @@ def update_data_to_bin( # parse index _region = self.region.lower() if _region not in ["cn", "us"]: - logger.warning(f"Unsupported region: region={_region}, component downloads will be ignored") + logger.warning( + f"Unsupported region: region={_region}, component downloads will be ignored" + ) return - index_list = ["CSI100", "CSI300"] if _region == "cn" else ["SP500", "NASDAQ100", "DJIA", "SP400"] + index_list = ( + ["CSI100", "CSI300"] + if _region == "cn" + else ["SP500", "NASDAQ100", "DJIA", "SP400"] + ) get_instruments = getattr( - importlib.import_module(f"data_collector.{_region}_index.collector"), "get_instruments" + importlib.import_module(f"data_collector.{_region}_index.collector"), + "get_instruments", ) for _index in index_list: - get_instruments(str(qlib_data_1d_dir), _index, market_index=f"{_region}_index") + get_instruments( + str(qlib_data_1d_dir), _index, market_index=f"{_region}_index" + ) if __name__ == "__main__": diff --git a/scripts/dump_bin.py b/scripts/dump_bin.py index cb8ed72dab..df28843f60 100644 --- a/scripts/dump_bin.py +++ b/scripts/dump_bin.py @@ -111,20 +111,32 @@ def __init__( exclude_fields = exclude_fields.split(",") if isinstance(include_fields, str): include_fields = include_fields.split(",") - self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) - self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) + self._exclude_fields = tuple( + filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)) + ) + self._include_fields = tuple( + filter(lambda x: len(x) > 0, map(str.strip, include_fields)) + ) self.file_suffix = file_suffix self.symbol_field_name = symbol_field_name - self.df_files = sorted(data_path.glob(f"*{self.file_suffix}") if data_path.is_dir() else [data_path]) + self.df_files = sorted( + data_path.glob(f"*{self.file_suffix}") + if data_path.is_dir() + else [data_path] + ) if limit_nums is not None: self.df_files = self.df_files[: int(limit_nums)] self.qlib_dir = Path(qlib_dir).expanduser() - self.backup_dir = backup_dir if backup_dir is None else Path(backup_dir).expanduser() + self.backup_dir = ( + backup_dir if backup_dir is None else Path(backup_dir).expanduser() + ) if backup_dir is not None: self._backup_qlib_dir(Path(backup_dir).expanduser()) self.freq = freq - self.calendar_format = self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT + self.calendar_format = ( + self.DAILY_FORMAT if self.freq == "day" else self.HIGH_FREQ_FORMAT + ) self.works = max_workers self.date_field_name = date_field_name @@ -146,7 +158,11 @@ def _format_datetime(self, datetime_d: [str, pd.Timestamp]): return datetime_d.strftime(self.calendar_format) def _get_date( - self, file_or_df: [Path, pd.DataFrame], *, is_begin_end: bool = False, as_set: bool = False + self, + file_or_df: [Path, pd.DataFrame], + *, + is_begin_end: bool = False, + as_set: bool = False, ) -> Iterable[pd.Timestamp]: if not isinstance(file_or_df, pd.DataFrame): df = self._get_source_data(file_or_df) @@ -180,7 +196,11 @@ def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]: return ( self._include_fields if self._include_fields - else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns + else ( + set(df_columns) - set(self._exclude_fields) + if 
self._exclude_fields + else df_columns + ) ) @staticmethod @@ -207,27 +227,41 @@ def _read_instruments(self, instrument_path: Path) -> pd.DataFrame: def save_calendars(self, calendars_data: list): self._calendars_dir.mkdir(parents=True, exist_ok=True) - calendars_path = str(self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve()) + calendars_path = str( + self._calendars_dir.joinpath(f"{self.freq}.txt").expanduser().resolve() + ) result_calendars_list = [self._format_datetime(x) for x in calendars_data] np.savetxt(calendars_path, result_calendars_list, fmt="%s", encoding="utf-8") def save_instruments(self, instruments_data: Union[list, pd.DataFrame]): self._instruments_dir.mkdir(parents=True, exist_ok=True) - instruments_path = str(self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve()) + instruments_path = str( + self._instruments_dir.joinpath(self.INSTRUMENTS_FILE_NAME).resolve() + ) if isinstance(instruments_data, pd.DataFrame): - _df_fields = [self.symbol_field_name, self.INSTRUMENTS_START_FIELD, self.INSTRUMENTS_END_FIELD] + _df_fields = [ + self.symbol_field_name, + self.INSTRUMENTS_START_FIELD, + self.INSTRUMENTS_END_FIELD, + ] instruments_data = instruments_data.loc[:, _df_fields] - instruments_data[self.symbol_field_name] = instruments_data[self.symbol_field_name].apply( - lambda x: fname_to_code(x.lower()).upper() + instruments_data[self.symbol_field_name] = instruments_data[ + self.symbol_field_name + ].apply(lambda x: fname_to_code(x.lower()).upper()) + instruments_data.to_csv( + instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False ) - instruments_data.to_csv(instruments_path, header=False, sep=self.INSTRUMENTS_SEP, index=False) else: np.savetxt(instruments_path, instruments_data, fmt="%s", encoding="utf-8") - def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestamp]) -> pd.DataFrame: + def data_merge_calendar( + self, df: pd.DataFrame, calendars_list: List[pd.Timestamp] + ) -> pd.DataFrame: # calendars calendars_df = pd.DataFrame(data=calendars_list, columns=[self.date_field_name]) - calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype("datetime64[ns]") + calendars_df[self.date_field_name] = calendars_df[self.date_field_name].astype( + "datetime64[ns]" + ) cal_df = calendars_df[ (calendars_df[self.date_field_name] >= df[self.date_field_name].min()) & (calendars_df[self.date_field_name] <= df[self.date_field_name].max()) @@ -242,7 +276,9 @@ def data_merge_calendar(self, df: pd.DataFrame, calendars_list: List[pd.Timestam def get_datetime_index(df: pd.DataFrame, calendar_list: List[pd.Timestamp]) -> int: return calendar_list.index(df.index.min()) - def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path): + def _data_to_bin( + self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], features_dir: Path + ): if df.empty: logger.warning(f"{features_dir.name} data is None or empty") return @@ -257,7 +293,9 @@ def _data_to_bin(self, df: pd.DataFrame, calendar_list: List[pd.Timestamp], feat # used when creating a bin file date_index = self.get_datetime_index(_df, calendar_list) for field in self.get_dump_fields(_df.columns): - bin_path = features_dir.joinpath(f"{field.lower()}.{self.freq}{self.DUMP_FILE_SUFFIX}") + bin_path = features_dir.joinpath( + f"{field.lower()}.{self.freq}{self.DUMP_FILE_SUFFIX}" + ) if field not in _df.columns: continue if bin_path.exists() and self._mode == self.UPDATE_MODE: @@ -266,16 +304,22 @@ def _data_to_bin(self, df: 
pd.DataFrame, calendar_list: List[pd.Timestamp], feat np.array(_df[field]).astype(" self._old_calendar_list[-1], self._all_data[self.date_field_name].unique()) + filter( + lambda x: x > self._old_calendar_list[-1], + self._all_data[self.date_field_name].unique(), + ) ) def _load_all_source_data(self): @@ -495,29 +565,44 @@ def _dump_features(self): error_code = {} with ProcessPoolExecutor(max_workers=self.works) as executor: futures = {} - for _code, _df in self._all_data.groupby(self.symbol_field_name, group_keys=False): + for _code, _df in self._all_data.groupby( + self.symbol_field_name, group_keys=False + ): _code = fname_to_code(str(_code).lower()).upper() _start, _end = self._get_date(_df, is_begin_end=True) - if not (isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp)): + if not ( + isinstance(_start, pd.Timestamp) and isinstance(_end, pd.Timestamp) + ): continue if _code in self._update_instruments: # exists stock, will append data _update_calendars = ( - _df[_df[self.date_field_name] > self._update_instruments[_code][self.INSTRUMENTS_END_FIELD]][ - self.date_field_name - ] + _df[ + _df[self.date_field_name] + > self._update_instruments[_code][ + self.INSTRUMENTS_END_FIELD + ] + ][self.date_field_name] .sort_values() .to_list() ) if _update_calendars: - self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end) - futures[executor.submit(self._dump_bin, _df, _update_calendars)] = _code + self._update_instruments[_code][self.INSTRUMENTS_END_FIELD] = ( + self._format_datetime(_end) + ) + futures[ + executor.submit(self._dump_bin, _df, _update_calendars) + ] = _code else: # new stock _dt_range = self._update_instruments.setdefault(_code, dict()) - _dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime(_start) + _dt_range[self.INSTRUMENTS_START_FIELD] = self._format_datetime( + _start + ) _dt_range[self.INSTRUMENTS_END_FIELD] = self._format_datetime(_end) - futures[executor.submit(self._dump_bin, _df, self._new_calendar_list)] = _code + futures[ + executor.submit(self._dump_bin, _df, self._new_calendar_list) + ] = _code with tqdm(total=len(futures)) as p_bar: for _future in as_completed(futures): @@ -539,4 +624,10 @@ def dump(self): if __name__ == "__main__": - fire.Fire({"dump_all": DumpDataAll, "dump_fix": DumpDataFix, "dump_update": DumpDataUpdate}) + fire.Fire( + { + "dump_all": DumpDataAll, + "dump_fix": DumpDataFix, + "dump_update": DumpDataUpdate, + } + ) diff --git a/scripts/dump_pit.py b/scripts/dump_pit.py index 806bbd0cc9..d02d6996ca 100644 --- a/scripts/dump_pit.py +++ b/scripts/dump_pit.py @@ -96,14 +96,22 @@ def __init__( exclude_fields = exclude_fields.split(",") if isinstance(include_fields, str): include_fields = include_fields.split(",") - self._exclude_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, exclude_fields))) - self._include_fields = tuple(filter(lambda x: len(x) > 0, map(str.strip, include_fields))) + self._exclude_fields = tuple( + filter(lambda x: len(x) > 0, map(str.strip, exclude_fields)) + ) + self._include_fields = tuple( + filter(lambda x: len(x) > 0, map(str.strip, include_fields)) + ) self.file_suffix = file_suffix - self.csv_files = sorted(csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path]) + self.csv_files = sorted( + csv_path.glob(f"*{self.file_suffix}") if csv_path.is_dir() else [csv_path] + ) if limit_nums is not None: self.csv_files = self.csv_files[: int(limit_nums)] self.qlib_dir = Path(qlib_dir).expanduser() - self.backup_dir = backup_dir if backup_dir 
is None else Path(backup_dir).expanduser() + self.backup_dir = ( + backup_dir if backup_dir is None else Path(backup_dir).expanduser() + ) if backup_dir is not None: self._backup_qlib_dir(Path(backup_dir).expanduser()) @@ -121,7 +129,9 @@ def _backup_qlib_dir(self, target_dir: Path): def get_source_data(self, file_path: Path) -> pd.DataFrame: df = pd.read_csv(str(file_path.resolve()), low_memory=False) df[self.value_column_name] = df[self.value_column_name].astype("float32") - df[self.date_column_name] = df[self.date_column_name].str.replace("-", "").astype("int32") + df[self.date_column_name] = ( + df[self.date_column_name].str.replace("-", "").astype("int32") + ) # df.drop_duplicates([self.date_field_name], inplace=True) return df @@ -183,7 +193,9 @@ def _dump_pit( logger.warning(f"{symbol} file is empty") return for field in self.get_dump_fields(df): - df_sub = df.query(f'{self.field_column_name}=="{field}"').sort_values(self.date_column_name) + df_sub = df.query(f'{self.field_column_name}=="{field}"').sort_values( + self.date_column_name + ) if df_sub.empty: logger.warning(f"field {field} of {symbol} is empty") continue @@ -199,7 +211,9 @@ def _dump_pit( # adjust `first_year` if existing data found if not overwrite and index_file.exists(): with open(index_file, "rb") as fi: - (first_year,) = struct.unpack(self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE)) + (first_year,) = struct.unpack( + self.PERIOD_DTYPE, fi.read(self.PERIOD_DTYPE_SIZE) + ) n_years = len(fi.read()) // self.INDEX_DTYPE_SIZE if interval == self.INTERVAL_quarterly: n_years //= 4 @@ -211,14 +225,18 @@ def _dump_pit( # if data already exists, continue to the next field if start_year > end_year: - logger.warning(f"{symbol}-{field} data already exists, continue to the next field") + logger.warning( + f"{symbol}-{field} data already exists, continue to the next field" + ) continue # dump index filled with NA with open(index_file, "ab") as fi: for year in range(start_year, end_year + 1): if interval == self.INTERVAL_quarterly: - fi.write(struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4)) + fi.write( + struct.pack(self.INDEX_DTYPE * 4, *[self.NA_INDEX] * 4) + ) else: fi.write(struct.pack(self.INDEX_DTYPE, self.NA_INDEX)) @@ -239,10 +257,14 @@ def _dump_pit( # update index if needed for i, row in df_sub.iterrows(): # get index - offset = get_period_offset(first_year, row.period, interval == self.INTERVAL_quarterly) + offset = get_period_offset( + first_year, row.period, interval == self.INTERVAL_quarterly + ) fi.seek(self.PERIOD_DTYPE_SIZE + self.INDEX_DTYPE_SIZE * offset) - (cur_index,) = struct.unpack(self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE)) + (cur_index,) = struct.unpack( + self.INDEX_DTYPE, fi.read(self.INDEX_DTYPE_SIZE) + ) # Case I: new data => update `_next` with current index if cur_index == self.NA_INDEX: @@ -252,16 +274,34 @@ def _dump_pit( else: _cur_fd = fd.tell() prev_index = self.NA_INDEX - while cur_index != self.NA_INDEX: # NOTE: first iter always != NA_INDEX - fd.seek(cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE) + while ( + cur_index != self.NA_INDEX + ): # NOTE: first iter always != NA_INDEX + fd.seek( + cur_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE + ) prev_index = cur_index - (cur_index,) = struct.unpack(self.INDEX_DTYPE, fd.read(self.INDEX_DTYPE_SIZE)) - fd.seek(prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE) - fd.write(struct.pack(self.INDEX_DTYPE, _cur_fd)) # NOTE: add _next pointer + (cur_index,) = struct.unpack( + self.INDEX_DTYPE, 
fd.read(self.INDEX_DTYPE_SIZE)
+                            )
+                        fd.seek(
+                            prev_index + self.DATA_DTYPE_SIZE - self.INDEX_DTYPE_SIZE
+                        )
+                        fd.write(
+                            struct.pack(self.INDEX_DTYPE, _cur_fd)
+                        )  # NOTE: add _next pointer
                         fd.seek(_cur_fd)
 
                     # dump data
-                    fd.write(struct.pack(self.DATA_DTYPE, row.date, row.period, row.value, self.NA_INDEX))
+                    fd.write(
+                        struct.pack(
+                            self.DATA_DTYPE,
+                            row.date,
+                            row.period,
+                            row.value,
+                            self.NA_INDEX,
+                        )
+                    )
 
     def dump(self, interval="quarterly", overwrite=False):
         logger.info("start dump pit data......")
diff --git a/setup.py b/setup.py
index 326dac8ed0..ba76046224 100644
--- a/setup.py
+++ b/setup.py
@@ -16,6 +16,46 @@ def read(rel_path: str) -> str:
 
 VERSION = get_version(root=".", relative_to=__file__)
 
+# Define base requirements
+install_requires = [
+    "numpy>=1.12.0",
+    "pandas>=0.25.1",
+    "scipy>=1.0.0",
+    "scikit-learn>=0.22.0",
+    "matplotlib>=3.0.0",
+    "seaborn>=0.9.0",
+    "tqdm",
+    "joblib>=0.17.0",
+    "ruamel.yaml>=0.16.0",
+    "fire>=0.3.0",
+    "cloudpickle",
+    "lxml",
+    "jinja2",
+    "statsmodels",
+    "plotly>=4.12.0",
+    "redis>=3.0.1",
+    "python-socketio",
+    "pymongo>=3.7.0",
+    "influxdb",
+    "pyarrow>=6.0.0",
+]
+
+# Define RL-specific optional requirements
+extras_require = {
+    "rl": [
+        "gymnasium>=0.28.0",  # gymnasium
+        "stable-baselines3>=1.2.0",
+        "tensorboard>=2.0.0",
+    ],
+    "dev": [
+        "black",
+        "flake8",
+        "pytest",
+        "pytest-cov",
+        "sphinx",
+    ],
+}
+
 setup(
     version=VERSION,
     ext_modules=[
diff --git a/tests/backtest/test_file_strategy.py b/tests/backtest/test_file_strategy.py
index 2e30f1a3cb..0f2967b9fe 100644
--- a/tests/backtest/test_file_strategy.py
+++ b/tests/backtest/test_file_strategy.py
@@ -40,14 +40,21 @@ def _gen_orders(self, dealt_num_for_1000) -> pd.DataFrame:
             # test selling all stocks
             ["20200110", self.TEST_INST, str(dealt_num_for_1000), "sell"],
         ]
-        return pd.DataFrame(orders, columns=headers).set_index(["datetime", "instrument"])
+        return pd.DataFrame(orders, columns=headers).set_index(
+            ["datetime", "instrument"]
+        )
 
     def test_file_str(self):
         # 0) basic settings
         account_money = 150000
 
         # 1) get information
-        df = D.features([self.TEST_INST], ["$close", "$factor"], start_time="20200103", end_time="20200103")
+        df = D.features(
+            [self.TEST_INST],
+            ["$close", "$factor"],
+            start_time="20200103",
+            end_time="20200103",
+        )
         price = df["$close"].item()
         factor = df["$factor"].item()
         price_unit = price / factor * 100
diff --git a/tests/backtest/test_high_freq_trading.py b/tests/backtest/test_high_freq_trading.py
index a538464db4..1e7a164f05 100644
--- a/tests/backtest/test_high_freq_trading.py
+++ b/tests/backtest/test_high_freq_trading.py
@@ -8,7 +8,9 @@
 import pandas as pd
 
 
-@unittest.skip("This test takes a lot of time due to the large size of high-frequency data")
+@unittest.skip(
+    "This test takes a lot of time due to the large size of high-frequency data"
+)
 class TestHFBacktest(TestAutoData):
     @classmethod
     def setUpClass(cls) -> None:
@@ -118,7 +120,12 @@ def test_trading(self):
         ret_val = {}
         decisions = list(
-
collect_data( + executor=executor_config, + strategy=strategy_config, + **backtest_config, + return_value=ret_val, + ) ) report, indicator = ret_val["report"], ret_val["indicator"] # NOTE: please refer to the docs of format_decisions diff --git a/tests/data_mid_layer_tests/test_dataloader.py b/tests/data_mid_layer_tests/test_dataloader.py index 8646e78587..8537802d86 100644 --- a/tests/data_mid_layer_tests/test_dataloader.py +++ b/tests/data_mid_layer_tests/test_dataloader.py @@ -25,13 +25,22 @@ def test_nested_data_loader(self): }, { "class": "qlib.contrib.data.loader.Alpha360DL", - "kwargs": {"config": {"label": (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])}}, + "kwargs": { + "config": { + "label": ( + ["Ref($close, -2)/Ref($close, -1) - 1"], + ["LABEL0"], + ) + } + }, }, ] ) # Of course you can use StaticDataLoader - dataset = nd.load(instruments="csi300", start_time="2020-01-01", end_time="2020-01-31") + dataset = nd.load( + instruments="csi300", start_time="2020-01-01", end_time="2020-01-31" + ) assert dataset is not None diff --git a/tests/data_mid_layer_tests/test_dataset.py b/tests/data_mid_layer_tests/test_dataset.py index 9eb2083aa7..5596aae2a6 100755 --- a/tests/data_mid_layer_tests/test_dataset.py +++ b/tests/data_mid_layer_tests/test_dataset.py @@ -26,13 +26,25 @@ def testTSDataset(self): "fit_end_time": "2017-12-31", "instruments": "csi300", "infer_processors": [ - {"class": "FilterCol", "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}}, - {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": "true"}}, + { + "class": "FilterCol", + "kwargs": {"col_list": ["RESI5", "WVMA5", "RSQR5"]}, + }, + { + "class": "RobustZScoreNorm", + "kwargs": { + "fields_group": "feature", + "clip_outlier": "true", + }, + }, {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, ], "learn_processors": [ "DropnaLabel", - {"class": "CSRankNorm", "kwargs": {"fields_group": "label"}}, # CSRankNorm + { + "class": "CSRankNorm", + "kwargs": {"fields_group": "label"}, + }, # CSRankNorm ], }, }, @@ -42,7 +54,9 @@ def testTSDataset(self): "test": ("2019-01-01", "2020-08-01"), }, ) - tsds_train = tsdh.prepare("train", data_key=DataHandlerLP.DK_L) # Test the correctness + tsds_train = tsdh.prepare( + "train", data_key=DataHandlerLP.DK_L + ) # Test the correctness tsds = tsdh.prepare("valid", data_key=DataHandlerLP.DK_L) t = time.time() @@ -104,14 +118,23 @@ def test_TSDataSampler(self): """ Test TSDataSampler for issue #1716 """ - datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", "2000-04-30", "2000-05-31"] + datetime_list = [ + "2000-01-31", + "2000-02-29", + "2000-03-31", + "2000-04-30", + "2000-05-31", + ] instruments = ["000001", "000002", "000003", "000004", "000005"] index = pd.MultiIndex.from_product( - [pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"] + [pd.to_datetime(datetime_list), instruments], + names=["datetime", "instrument"], ) data = np.random.randn(len(datetime_list) * len(instruments)) test_df = pd.DataFrame(data=data, index=index, columns=["factor"]) - dataset = TSDataSampler(test_df, datetime_list[0], datetime_list[-1], step_len=2) + dataset = TSDataSampler( + test_df, datetime_list[0], datetime_list[-1], step_len=2 + ) print() print("--------------dataset[0]--------------") print(dataset[0]) @@ -127,14 +150,23 @@ def test_TSDataSampler2(self): """ Extra test TSDataSampler to prevent incorrect filling of nan for the values at the front """ - datetime_list = ["2000-01-31", "2000-02-29", "2000-03-31", 
"2000-04-30", "2000-05-31"] + datetime_list = [ + "2000-01-31", + "2000-02-29", + "2000-03-31", + "2000-04-30", + "2000-05-31", + ] instruments = ["000001", "000002", "000003", "000004", "000005"] index = pd.MultiIndex.from_product( - [pd.to_datetime(datetime_list), instruments], names=["datetime", "instrument"] + [pd.to_datetime(datetime_list), instruments], + names=["datetime", "instrument"], ) data = np.random.randn(len(datetime_list) * len(instruments)) test_df = pd.DataFrame(data=data, index=index, columns=["factor"]) - dataset = TSDataSampler(test_df, datetime_list[2], datetime_list[-1], step_len=3) + dataset = TSDataSampler( + test_df, datetime_list[2], datetime_list[-1], step_len=3 + ) print() print("--------------dataset[0]--------------") print(dataset[0]) diff --git a/tests/data_mid_layer_tests/test_handler.py b/tests/data_mid_layer_tests/test_handler.py index 3ac813f5bf..d092516e71 100644 --- a/tests/data_mid_layer_tests/test_handler.py +++ b/tests/data_mid_layer_tests/test_handler.py @@ -12,7 +12,9 @@ def to_str(self, obj): return "".join(str(obj).split()) def test_handler_df(self): - df = D.features(["sh600519"], start_time="20190101", end_time="20190201", fields=["$close"]) + df = D.features( + ["sh600519"], start_time="20190101", end_time="20190201", fields=["$close"] + ) dh = DataHandlerLP.from_df(df) print(dh.fetch()) self.assertTrue(dh._data.equals(df)) diff --git a/tests/data_mid_layer_tests/test_handler_storage.py b/tests/data_mid_layer_tests/test_handler_storage.py index a8bb730f7b..7a39335bc2 100644 --- a/tests/data_mid_layer_tests/test_handler_storage.py +++ b/tests/data_mid_layer_tests/test_handler_storage.py @@ -21,8 +21,12 @@ def __init__( fit_end_time=None, drop_raw=True, ): - infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) - learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) + infer_processors = check_transform_proc( + infer_processors, fit_start_time, fit_end_time + ) + learn_processors = check_transform_proc( + learn_processors, fit_start_time, fit_end_time + ) data_loader = { "class": "QlibDataLoader", @@ -44,7 +48,14 @@ def __init__( ) def get_feature_config(self): - fields = ["Ref($open, 1)", "Ref($close, 1)", "Ref($volume, 1)", "$open", "$close", "$volume"] + fields = [ + "Ref($open, 1)", + "Ref($close, 1)", + "Ref($volume, 1)", + "$open", + "$close", + "$volume", + ] names = ["open_0", "close_0", "volume_0", "open_1", "close_1", "volume_1"] return fields, names @@ -70,13 +81,18 @@ def test_handler_storage(self): data_handler = TestHandler(**self.data_handler_kwargs) # init data handler with hasing storage - data_handler_hs = TestHandler(**self.data_handler_kwargs, infer_processors=["HashStockFormat"]) + data_handler_hs = TestHandler( + **self.data_handler_kwargs, infer_processors=["HashStockFormat"] + ) fetch_start_time = "2019-01-01" fetch_end_time = "2019-12-31" instruments = D.instruments(market=self.market) instruments = D.list_instruments( - instruments=instruments, start_time=fetch_start_time, end_time=fetch_end_time, as_list=True + instruments=instruments, + start_time=fetch_start_time, + end_time=fetch_end_time, + as_list=True, ) with TimeInspector.logt("random fetch with DataFrame Storage"): @@ -84,26 +100,38 @@ def test_handler_storage(self): for i in range(100): random_index = np.random.randint(len(instruments), size=1)[0] fetch_stock = instruments[random_index] - data_handler.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None) + 
data_handler.fetch( + selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), + level=None, + ) # multi stocks for i in range(100): random_indexs = np.random.randint(len(instruments), size=5) fetch_stocks = [instruments[_index] for _index in random_indexs] - data_handler.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None) + data_handler.fetch( + selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), + level=None, + ) with TimeInspector.logt("random fetch with HashingStock Storage"): # single stock for i in range(100): random_index = np.random.randint(len(instruments), size=1)[0] fetch_stock = instruments[random_index] - data_handler_hs.fetch(selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), level=None) + data_handler_hs.fetch( + selector=(fetch_stock, slice(fetch_start_time, fetch_end_time)), + level=None, + ) # multi stocks for i in range(100): random_indexs = np.random.randint(len(instruments), size=5) fetch_stocks = [instruments[_index] for _index in random_indexs] - data_handler_hs.fetch(selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), level=None) + data_handler_hs.fetch( + selector=(fetch_stocks, slice(fetch_start_time, fetch_end_time)), + level=None, + ) if __name__ == "__main__": diff --git a/tests/data_mid_layer_tests/test_processor.py b/tests/data_mid_layer_tests/test_processor.py index 46453b3162..ee9000d3ea 100644 --- a/tests/data_mid_layer_tests/test_processor.py +++ b/tests/data_mid_layer_tests/test_processor.py @@ -23,10 +23,14 @@ def normalize(df): df.loc(axis=1)[df.columns] = (df.values - min_val) / (max_val - min_val) return df - origin_df = D.features([self.TEST_INST], ["$high", "$open", "$low", "$close"]).tail(10) + origin_df = D.features( + [self.TEST_INST], ["$high", "$open", "$low", "$close"] + ).tail(10) origin_df["test"] = 0 df = origin_df.copy() - mmn = MinMaxNorm(fields_group=None, fit_start_time="2021-05-31", fit_end_time="2021-06-11") + mmn = MinMaxNorm( + fields_group=None, fit_start_time="2021-05-31", fit_end_time="2021-06-11" + ) mmn.fit(df) mmn.__call__(df) origin_df = normalize(origin_df) @@ -44,31 +48,50 @@ def normalize(df): df.loc(axis=1)[df.columns] = (df.values - mean_train) / std_train return df - origin_df = D.features([self.TEST_INST], ["$high", "$open", "$low", "$close"]).tail(10) + origin_df = D.features( + [self.TEST_INST], ["$high", "$open", "$low", "$close"] + ).tail(10) origin_df["test"] = 0 df = origin_df.copy() - zsn = ZScoreNorm(fields_group=None, fit_start_time="2021-05-31", fit_end_time="2021-06-11") + zsn = ZScoreNorm( + fields_group=None, fit_start_time="2021-05-31", fit_end_time="2021-06-11" + ) zsn.fit(df) zsn.__call__(df) origin_df = normalize(origin_df) assert (df == origin_df).all().all() def test_CSZFillna(self): - origin_df = D.features(D.instruments(market="csi300"), fields=["$high", "$open", "$low", "$close"]) - origin_df = origin_df.groupby("datetime", group_keys=False).apply(lambda x: x[97:99])[228:238] + origin_df = D.features( + D.instruments(market="csi300"), fields=["$high", "$open", "$low", "$close"] + ) + origin_df = origin_df.groupby("datetime", group_keys=False).apply( + lambda x: x[97:99] + )[228:238] df = origin_df.copy() CSZFillna(fields_group=None).__call__(df) assert ~df[1:2].isna().all().all() and origin_df[1:2].isna().all().all() def test_CSZScoreNorm(self): - origin_df = D.features(D.instruments(market="csi300"), fields=["$high", "$open", "$low", "$close"]) - origin_df = origin_df.groupby("datetime", 
group_keys=False).apply(lambda x: x[10:12])[50:60] + origin_df = D.features( + D.instruments(market="csi300"), fields=["$high", "$open", "$low", "$close"] + ) + origin_df = origin_df.groupby("datetime", group_keys=False).apply( + lambda x: x[10:12] + )[50:60] df = origin_df.copy() CSZScoreNorm(fields_group=None).__call__(df) # If we use the formula directly on the original data, we cannot get the correct result, # because the original data is processed by `groupby`, so we use the method of slicing, # taking the 2nd group of data from the original data, to calculate and compare. - assert (df[2:4] == ((origin_df[2:4] - origin_df[2:4].mean()).div(origin_df[2:4].std()))).all().all() + assert ( + ( + df[2:4] + == ((origin_df[2:4] - origin_df[2:4].mean()).div(origin_df[2:4].std())) + ) + .all() + .all() + ) if __name__ == "__main__": diff --git a/tests/dataset_tests/test_datalayer.py b/tests/dataset_tests/test_datalayer.py index 6509a77dec..8e0c2e7631 100644 --- a/tests/dataset_tests/test_datalayer.py +++ b/tests/dataset_tests/test_datalayer.py @@ -15,10 +15,18 @@ def testCSI300(self): print(size_desc) print(cnt_desc) - self.assertLessEqual(size_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks") - self.assertGreaterEqual(size_desc.loc["80%"], 290, "Insufficient number of CSI300 constituent stocks") - - self.assertLessEqual(cnt_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks") + self.assertLessEqual( + size_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks" + ) + self.assertGreaterEqual( + size_desc.loc["80%"], + 290, + "Insufficient number of CSI300 constituent stocks", + ) + + self.assertLessEqual( + cnt_desc.loc["max"], 305, "Excessive number of CSI300 constituent stocks" + ) # FIXME: Due to the low quality of data. Hard to make sure there are enough data # self.assertEqual(cnt_desc.loc["80%"], 300, "Insufficient number of CSI300 constituent stocks") @@ -26,8 +34,12 @@ def testClose(self): close_p = D.features(D.instruments("csi300"), ["Ref($close, 1)/$close - 1"]) close_desc = close_p.describe(percentiles=np.arange(0.1, 1.0, 0.1)) print(close_desc) - self.assertLessEqual(abs(close_desc.loc["90%"][0]), 0.1, "Close value is abnormal") - self.assertLessEqual(abs(close_desc.loc["10%"][0]), 0.1, "Close value is abnormal") + self.assertLessEqual( + abs(close_desc.loc["90%"][0]), 0.1, "Close value is abnormal" + ) + self.assertLessEqual( + abs(close_desc.loc["10%"][0]), 0.1, "Close value is abnormal" + ) # FIXME: The yahoo data is not perfect. 
We have to # self.assertLessEqual(abs(close_desc.loc["max"][0]), 0.2, "Close value is abnormal") # self.assertGreaterEqual(close_desc.loc["min"][0], -0.2, "Close value is abnormal") diff --git a/tests/misc/test_get_multi_proc.py b/tests/misc/test_get_multi_proc.py index 7e27781b6e..7489cca5b5 100644 --- a/tests/misc/test_get_multi_proc.py +++ b/tests/misc/test_get_multi_proc.py @@ -10,7 +10,12 @@ def get_features(fields): - qlib.init(provider_uri=TestAutoData.provider_uri, expression_cache=None, dataset_cache=None, joblib_backend="loky") + qlib.init( + provider_uri=TestAutoData.provider_uri, + expression_cache=None, + dataset_cache=None, + joblib_backend="loky", + ) return D.features(D.instruments("csi300"), fields) diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index 89fccb4d91..2989cad30d 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -45,7 +45,9 @@ def test_index_multi_data(self): idd.MultiData(range(10), index=["foo", "bar"], columns=["f", "g"]) # test indexing - sd = idd.MultiData(np.arange(4).reshape(2, 2), index=["foo", "bar"], columns=["f", "g"]) + sd = idd.MultiData( + np.arange(4).reshape(2, 2), index=["foo", "bar"], columns=["f", "g"] + ) print(sd) print(sd.iloc[1]) # get second row @@ -62,7 +64,9 @@ def test_index_multi_data(self): print(sd.loc[:, "g":]) def test_sorting(self): - sd = idd.MultiData(np.arange(4).reshape(2, 2), index=["foo", "bar"], columns=["f", "g"]) + sd = idd.MultiData( + np.arange(4).reshape(2, 2), index=["foo", "bar"], columns=["f", "g"] + ) print(sd) sd.sort_index() @@ -70,7 +74,9 @@ def test_sorting(self): print(sd.loc[:"c"]) def test_corner_cases(self): - sd = idd.MultiData([[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"]) + sd = idd.MultiData( + [[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"] + ) print(sd) self.assertTrue(np.isnan(sd.loc["bar", "g"])) @@ -124,7 +130,9 @@ def test_ops(self): self.assertTrue(np.isnan((sd1 + sd2).iloc[3])) self.assertTrue(sd1.add(sd2).sum() == 13) - self.assertTrue(idd.sum_by_index([sd1, sd2], sd1.index, fill_value=0.0).sum() == 13) + self.assertTrue( + idd.sum_by_index([sd1, sd2], sd1.index, fill_value=0.0).sum() == 13 + ) def test_todo(self): pass diff --git a/tests/misc/test_utils.py b/tests/misc/test_utils.py index db5b072488..b2e342a4f9 100644 --- a/tests/misc/test_utils.py +++ b/tests/misc/test_utils.py @@ -8,7 +8,13 @@ from qlib.config import C from qlib.log import TimeInspector from qlib.constant import REG_CN, REG_US, REG_TW -from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal, CN_TIME, US_TIME, TW_TIME +from qlib.utils.time import ( + cal_sam_minute as cal_sam_minute_new, + get_min_cal, + CN_TIME, + US_TIME, + TW_TIME, +) from qlib.utils.data import guess_horizon REG_MAP = {REG_CN: CN_TIME, REG_US: US_TIME, REG_TW: TW_TIME} @@ -66,7 +72,9 @@ def cal_sam_minute(x: pd.Timestamp, sam_minutes: int, region: str): elif 120 <= minute_index < 240: return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1) else: - raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C") + raise ValueError( + "calendar minute_index error, check `min_data_shift` in qlib.config.C" + ) class TimeUtils(TestCase): @@ -99,7 +107,9 @@ def gen_args(cal: List): for region in regions: cal_time = get_min_cal(region=region) for args in gen_args(cal_time): - assert cal_sam_minute(*args, region) == cal_sam_minute_new(*args, region=region) + assert cal_sam_minute(*args, region) == 
cal_sam_minute_new( + *args, region=region + ) # test the performance of the code args_l = list(gen_args(cal_time)) diff --git a/tests/ops/test_elem_operator.py b/tests/ops/test_elem_operator.py index 8349157ff4..f2c29b23fc 100644 --- a/tests/ops/test_elem_operator.py +++ b/tests/ops/test_elem_operator.py @@ -18,7 +18,9 @@ def setUp(self) -> None: def test_Abs(self): field = "Abs($close-Ref($close, 1))" - result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq) + result = ExpressionD.expression( + self.instrument, field, self.start_time, self.end_time, self.freq + ) self.assertGreaterEqual(result.min(), 0) result = result.to_numpy() prev_close = self.mock_df["close"].shift(1) @@ -29,7 +31,9 @@ def test_Abs(self): def test_Sign(self): field = "Sign($close-Ref($close, 1))" - result = ExpressionD.expression(self.instrument, field, self.start_time, self.end_time, self.freq) + result = ExpressionD.expression( + self.instrument, field, self.start_time, self.end_time, self.freq + ) result = result.to_numpy() prev_close = self.mock_df["close"].shift(1) close = self.mock_df["close"] @@ -55,7 +59,14 @@ def setUp(self) -> None: ] columns = ["change", "abs"] self.data = DatasetProvider.inst_calculator( - self.inst, self.start_time, self.end_time, freq, expressions, self.spans, C, [] + self.inst, + self.start_time, + self.end_time, + freq, + expressions, + self.spans, + C, + [], ) self.data.columns = columns diff --git a/tests/ops/test_special_ops.py b/tests/ops/test_special_ops.py index 6c4a4ec499..60d75b69c1 100644 --- a/tests/ops/test_special_ops.py +++ b/tests/ops/test_special_ops.py @@ -12,11 +12,21 @@ def test_setting(self): df = D.features(["SH600519"], ["ChangeInstrument('SH000300', $close)"]) # get market return for "SH600519" - df = D.features(["SH600519"], ["ChangeInstrument('SH000300', Feature('close')/Ref(Feature('close'),1) -1)"]) - df = D.features(["SH600519"], ["ChangeInstrument('SH000300', $close/Ref($close,1) -1)"]) + df = D.features( + ["SH600519"], + [ + "ChangeInstrument('SH000300', Feature('close')/Ref(Feature('close'),1) -1)" + ], + ) + df = D.features( + ["SH600519"], ["ChangeInstrument('SH000300', $close/Ref($close,1) -1)"] + ) # excess return df = D.features( - ["SH600519"], ["($close/Ref($close,1) -1) - ChangeInstrument('SH000300', $close/Ref($close,1) -1)"] + ["SH600519"], + [ + "($close/Ref($close,1) -1) - ChangeInstrument('SH000300', $close/Ref($close,1) -1)" + ], ) print(df) @@ -29,10 +39,16 @@ def test_case(instruments, queries, note=None): print(df) return df - test_case(["SH600519"], ["ChangeInstrument('SH000300', $close)"], "get market index close") test_case( ["SH600519"], - ["ChangeInstrument('SH000300', Feature('close')/Ref(Feature('close'),1) -1)"], + ["ChangeInstrument('SH000300', $close)"], + "get market index close", + ) + test_case( + ["SH600519"], + [ + "ChangeInstrument('SH000300', Feature('close')/Ref(Feature('close'),1) -1)" + ], "get market index return with Feature", ) test_case( @@ -42,7 +58,9 @@ def test_case(instruments, queries, note=None): ) test_case( ["SH600519"], - ["($close/Ref($close,1) -1) - ChangeInstrument('SH000300', $close/Ref($close,1) -1)"], + [ + "($close/Ref($close,1) -1) - ChangeInstrument('SH000300', $close/Ref($close,1) -1)" + ], "get excess return with expression with beta=1", ) @@ -61,13 +79,19 @@ def test_case(instruments, queries, note=None): beta, excess_return, ] - test_case(["SH600519"], fields[5:], "get market beta and excess_return with estimated beta") + test_case( + ["SH600519"], 
+ fields[5:], + "get market beta and excess_return with estimated beta", + ) instrument = "sh600519" ret = Feature("close") / Ref(Feature("close"), 1) - 1 benchmark = "sh000300" n_period = 252 - marketRet = ChangeInstrument(benchmark, Feature("close") / Ref(Feature("close"), 1) - 1) + marketRet = ChangeInstrument( + benchmark, Feature("close") / Ref(Feature("close"), 1) - 1 + ) marketVar = ChangeInstrument(benchmark, Var(marketRet, n_period)) beta = Cov(ret, marketRet, n_period) / marketVar fields = [ @@ -78,7 +102,14 @@ def test_case(instruments, queries, note=None): beta, ret - beta * marketRet, ] - names = ["close", "marketClose", "ret", "marketRet", f"beta_{n_period}", "excess_return"] + names = [ + "close", + "marketClose", + "ret", + "marketRet", + f"beta_{n_period}", + "excess_return", + ] data_loader_config = {"feature": (fields, names)} data_loader = QlibDataLoader(config=data_loader_config) df = data_loader.load(instruments=[instrument]) # , start_time=start_time) diff --git a/tests/rl/test_data_queue.py b/tests/rl/test_data_queue.py index 0b0c61280d..e052913cb5 100644 --- a/tests/rl/test_data_queue.py +++ b/tests/rl/test_data_queue.py @@ -17,7 +17,9 @@ def __init__(self, length): def __getitem__(self, index): assert 0 <= index < self.length - return pd.DataFrame(np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD")) + return pd.DataFrame( + np.random.randint(0, 100, size=(index + 1, 4)), columns=list("ABCD") + ) def __len__(self): return self.length @@ -50,7 +52,9 @@ def test_multiprocess_shared_dataloader(): queue = multiprocessing.Queue() processes = [] for _ in range(3): - processes.append(multiprocessing.Process(target=_worker, args=(data_queue, queue))) + processes.append( + multiprocessing.Process(target=_worker, args=(data_queue, queue)) + ) processes[-1].start() for p in processes: p.join() diff --git a/tests/rl/test_finite_env.py b/tests/rl/test_finite_env.py index d6f2a2ec95..c1dc237a72 100644 --- a/tests/rl/test_finite_env.py +++ b/tests/rl/test_finite_env.py @@ -25,7 +25,10 @@ "position": gym.spaces.Box(low=-100, high=100, shape=(3,)), "velocity": gym.spaces.Box(low=-1, high=1, shape=(3,)), "front_cam": gym.spaces.Tuple( - (gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), gym.spaces.Box(low=0, high=1, shape=(10, 10, 3))) + ( + gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), + gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), + ) ), "rear_cam": gym.spaces.Box(low=0, high=1, shape=(10, 10, 3)), } @@ -52,7 +55,11 @@ def __init__(self, dataset, num_replicas, rank): self.dataset = dataset self.num_replicas = num_replicas self.rank = rank - self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None) + self.loader = DataLoader( + dataset, + sampler=DistributedSampler(dataset, num_replicas, rank), + batch_size=None, + ) self.iterator = None self.observation_space = gym.spaces.Discrete(255) self.action_space = gym.spaces.Discrete(2) @@ -84,7 +91,11 @@ def __init__(self, dataset, num_replicas, rank): self.dataset = dataset self.num_replicas = num_replicas self.rank = rank - self.loader = DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas, rank), batch_size=None) + self.loader = DataLoader( + dataset, + sampler=DistributedSampler(dataset, num_replicas, rank), + batch_size=None, + ) self.iterator = None self.observation_space = gym.spaces.Discrete(255) self.action_space = gym.spaces.Discrete(2) @@ -167,7 +178,9 @@ def on_env_step(self, *args, **kwargs): def test_finite_dummy_vector_env(): length 
= 100 dataset = DummyDataset(length) - envs = FiniteDummyVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteDummyVectorEnv( + MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) envs._collector_guarded = True policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) @@ -183,7 +196,9 @@ def test_finite_dummy_vector_env(): def test_finite_shmem_vector_env(): length = 100 dataset = DummyDataset(length) - envs = FiniteShmemVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteShmemVectorEnv( + MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) envs._collector_guarded = True policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) @@ -199,7 +214,9 @@ def test_finite_shmem_vector_env(): def test_finite_subproc_vector_env(): length = 100 dataset = DummyDataset(length) - envs = FiniteSubprocVectorEnv(MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)]) + envs = FiniteSubprocVectorEnv( + MetricTracker(length), [_finite_env_factory(dataset, 5, i) for i in range(5)] + ) envs._collector_guarded = True policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) @@ -221,7 +238,8 @@ def test_finite_dummy_vector_env_complex(): length = 100 dataset = DummyDataset(length) envs = FiniteDummyVectorEnv( - DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)] + DoNothingTracker(), + [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)], ) envs._collector_guarded = True policy = AnyPolicy() @@ -237,7 +255,8 @@ def test_finite_shmem_vector_env_complex(): length = 100 dataset = DummyDataset(length) envs = FiniteShmemVectorEnv( - DoNothingTracker(), [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)] + DoNothingTracker(), + [_finite_env_factory(dataset, 5, i, complex=True) for i in range(5)], ) envs._collector_guarded = True policy = AnyPolicy() diff --git a/tests/rl/test_logger.py b/tests/rl/test_logger.py index e100e5046b..797901cc8b 100644 --- a/tests/rl/test_logger.py +++ b/tests/rl/test_logger.py @@ -76,7 +76,9 @@ def test_simple_env_logger(caplog): for venv_cls_name in ["dummy", "shmem", "subproc"]: writer = ConsoleWriter() csv_writer = CsvWriter(Path(__file__).parent / ".output") - venv = vectorize_env(lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer]) + venv = vectorize_env( + lambda: SimpleEnv(), venv_cls_name, 4, [writer, csv_writer] + ) with venv.collector_guard(): collector = Collector(AnyPolicy(), venv) collector.collect(n_episode=30) @@ -89,7 +91,10 @@ def test_simple_env_logger(caplog): line = line.strip() if line: line_counter += 1 - assert re.match(r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", line) + assert re.match( + r".*reward .* {2}a .* \(([456])\.\d+\) {2}c .* \((14|15|16)\.\d+\)", + line, + ) assert line_counter >= 3 @@ -152,7 +157,9 @@ def env_wrapper_factory(): # loglevel can be debugged here because metrics can all dump into csv # otherwise, csv writer might crash - csv_writer = CsvWriter(Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG) + csv_writer = CsvWriter( + Path(__file__).parent / ".output", loglevel=LogLevel.DEBUG + ) venv = vectorize_env(env_wrapper_factory, "shmem", 4, csv_writer) with venv.collector_guard(): collector = Collector(RandomFivePolicy(), venv) @@ -161,7 +168,9 @@ def env_wrapper_factory(): 
output_df = pd.read_csv(Path(__file__).parent / ".output" / "result.csv") assert len(output_df) == 20 # obs has an increasing trend - assert output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum() + assert ( + output_df["obs"].to_numpy()[:10].sum() < output_df["obs"].to_numpy()[10:].sum() + ) assert (output_df["test_a"] == 233).all() assert (output_df["test_b"] == 200).all() assert "steps_per_episode" in output_df and "reward" in output_df diff --git a/tests/rl/test_qlib_simulator.py b/tests/rl/test_qlib_simulator.py index 382609e5e1..3fdc56dc1f 100644 --- a/tests/rl/test_qlib_simulator.py +++ b/tests/rl/test_qlib_simulator.py @@ -15,7 +15,9 @@ TOTAL_POSITION = 2100.0 -python_version_request = pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +python_version_request = pytest.mark.skipif( + sys.version_info < (3, 8), reason="requires python3.8 or higher" +) def is_close(a: float, b: float, epsilon: float = 1e-4) -> bool: @@ -38,7 +40,10 @@ def get_configs(order: Order) -> Tuple[dict, dict]: "module_path": "qlib.backtest.executor", "kwargs": { "time_per_step": "1day", - "inner_strategy": {"class": "ProxySAOEStrategy", "module_path": "qlib.rl.order_execution.strategy"}, + "inner_strategy": { + "class": "ProxySAOEStrategy", + "module_path": "qlib.rl.order_execution.strategy", + }, "track_data": True, "inner_executor": { "class": "NestedExecutor", @@ -138,7 +143,9 @@ def test_simulator_first_step(): assert (state.history_exec["deal_amount"] == AMOUNT / 30).all() assert is_close(state.history_exec["trade_price"].iloc[0], 149.566483) assert is_close(state.history_exec["trade_value"].iloc[0], 1495.664825) - assert is_close(state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30) + assert is_close( + state.history_exec["position"].iloc[0], TOTAL_POSITION - AMOUNT / 30 + ) assert is_close(state.history_exec["ffr"].iloc[0], AMOUNT / TOTAL_POSITION / 30) assert is_close(state.history_steps["market_volume"].iloc[0], 1254848.5756835938) @@ -146,7 +153,8 @@ def test_simulator_first_step(): assert state.history_steps["deal_amount"].iloc[0] == AMOUNT assert state.history_steps["ffr"].iloc[0] == AMOUNT / TOTAL_POSITION assert is_close( - state.history_steps["pa"].iloc[0] * (1.0 if order.direction == OrderDir.SELL else -1.0), + state.history_steps["pa"].iloc[0] + * (1.0 if order.direction == OrderDir.SELL else -1.0), (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000, ) @@ -163,14 +171,23 @@ def test_simulator_stop_twap() -> None: state = simulator.get_state() assert len(state.history_exec) == HISTORY_STEP_LENGTH - assert (state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH).all() - assert is_close(state.history_steps["position"].iloc[0], TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS) + assert ( + state.history_exec["deal_amount"] == TOTAL_POSITION / HISTORY_STEP_LENGTH + ).all() + assert is_close( + state.history_steps["position"].iloc[0], + TOTAL_POSITION * (NUM_STEPS - 1) / NUM_STEPS, + ) assert is_close(state.history_steps["position"].iloc[-1], 0.0) assert is_close(state.position, 0.0) assert is_close(state.metrics["ffr"], 1.0) - assert is_close(state.metrics["market_price"], state.backtest_data.get_deal_price().mean()) - assert is_close(state.metrics["market_volume"], state.backtest_data.get_volume().sum()) + assert is_close( + state.metrics["market_price"], state.backtest_data.get_deal_price().mean() + ) + assert is_close( + state.metrics["market_volume"], 
state.backtest_data.get_volume().sum()
+    )

     assert is_close(state.metrics["trade_price"], state.metrics["market_price"])
     assert is_close(state.metrics["pa"], 0.0)

@@ -192,4 +209,6 @@ def test_interpreter() -> None:
         state = simulator.get_state()
         position_history.append(state.position)

-        assert position_history[-1] == max(TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0)
+        assert position_history[-1] == max(
+            TOTAL_POSITION - TOTAL_POSITION / NUM_EXECUTION * (i + 1), 0.0
+        )
diff --git a/tests/rl/test_saoe_simple.py b/tests/rl/test_saoe_simple.py
index d1711bb289..bf114fecb6 100644
--- a/tests/rl/test_saoe_simple.py
+++ b/tests/rl/test_saoe_simple.py
@@ -21,7 +21,9 @@ from qlib.rl.trainer import backtest, train
 from qlib.rl.utils import ConsoleWriter, CsvWriter, EnvWrapperStatus

-pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
+pytestmark = pytest.mark.skipif(
+    sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8"
+)

 DATA_ROOT_DIR = Path(__file__).parent.parent / ".data" / "rl" / "intraday_saoe"

@@ -37,7 +39,9 @@


 def test_pickle_data_inspect():
-    data = pickle_styled.load_simple_intraday_backtest_data(BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0)
+    data = pickle_styled.load_simple_intraday_backtest_data(
+        BACKTEST_DATA_DIR, "AAL", "2013-12-11", "close", 0
+    )
     assert len(data) == 390

     provider = PickleProcessedDataProvider(DATA_DIR / "processed")
@@ -46,7 +50,13 @@ def test_pickle_data_inspect():


 def test_simulator_first_step():
-    order = Order("AAL", 30.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
+    order = Order(
+        "AAL",
+        30.0,
+        0,
+        pd.Timestamp("2013-12-11 00:00:00"),
+        pd.Timestamp("2013-12-11 23:59:59"),
+    )

     simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     state = simulator.get_state()
@@ -72,7 +82,8 @@ def test_simulator_first_step():
     assert state.history_steps["ffr"].iloc[0] == 0.5
     assert (
         state.history_steps["pa"].iloc[0]
-        == (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1) * 10000
+        == (state.history_steps["trade_price"].iloc[0] / simulator.twap_price - 1)
+        * 10000
     )

     assert state.position == 15.0
@@ -80,7 +91,13 @@ def test_simulator_first_step():


 def test_simulator_stop_twap():
-    order = Order("AAL", 13.0, 0, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
+    order = Order(
+        "AAL",
+        13.0,
+        0,
+        pd.Timestamp("2013-12-11 00:00:00"),
+        pd.Timestamp("2013-12-11 23:59:59"),
+    )

     simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     for _ in range(13):
@@ -89,11 +106,19 @@ def test_simulator_stop_twap():
     state = simulator.get_state()
     assert len(state.history_exec) == 390
     assert (state.history_exec["deal_amount"] == 13 / 390).all()
-    assert state.history_steps["position"].iloc[0] == 12 and state.history_steps["position"].iloc[-1] == 0
+    assert (
+        state.history_steps["position"].iloc[0] == 12
+        and state.history_steps["position"].iloc[-1] == 0
+    )

     assert (state.metrics["ffr"] - 1) < 1e-3
-    assert abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean()) < 1e-4
-    assert np.isclose(state.metrics["market_volume"], state.backtest_data.get_volume().sum())
+    assert (
+        abs(state.metrics["market_price"] - state.backtest_data.get_deal_price().mean())
+        < 1e-4
+    )
+    assert np.isclose(
+        state.metrics["market_volume"], state.backtest_data.get_volume().sum()
+    )
     assert state.position == 0.0
     assert abs(state.metrics["trade_price"] - state.metrics["market_price"]) < 1e-4
     assert abs(state.metrics["pa"]) < 1e-2
@@ -102,7 +127,13 @@ def test_simulator_stop_twap():


 def test_simulator_stop_early():
-    order = Order("AAL", 1.0, 1, pd.Timestamp("2013-12-11 00:00:00"), pd.Timestamp("2013-12-11 23:59:59"))
+    order = Order(
+        "AAL",
+        1.0,
+        1,
+        pd.Timestamp("2013-12-11 00:00:00"),
+        pd.Timestamp("2013-12-11 23:59:59"),
+    )

     with pytest.raises(ValueError):
         simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
@@ -116,7 +147,13 @@ def test_simulator_stop_early():


 def test_simulator_start_middle():
-    order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
+    order = Order(
+        "AAL",
+        15.0,
+        1,
+        pd.Timestamp("2013-12-11 10:15:00"),
+        pd.Timestamp("2013-12-11 15:44:59"),
+    )

     simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     assert len(simulator.ticks_for_order) == 330
@@ -135,7 +172,13 @@ def test_simulator_start_middle():


 def test_interpreter():
-    order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 10:15:00"), pd.Timestamp("2013-12-11 15:44:59"))
+    order = Order(
+        "AAL",
+        15.0,
+        1,
+        pd.Timestamp("2013-12-11 10:15:00"),
+        pd.Timestamp("2013-12-11 15:44:59"),
+    )

     simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     assert len(simulator.ticks_for_order) == 330
@@ -145,15 +188,21 @@ def test_interpreter():
     class EmulateEnvWrapper(NamedTuple):
         status: EnvWrapperStatus

-    interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
+    interpreter = FullHistoryStateInterpreter(
+        13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)
+    )
     interpreter_step = CurrentStepStateInterpreter(13)
     interpreter_action = CategoricalActionInterpreter(20)
     interpreter_action_twap = TwapRelativeActionInterpreter()

-    wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
+    wrapper_status_kwargs = dict(
+        initial_state=order, obs_history=[], action_history=[], reward_history=[]
+    )

     # first step
-    interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
+    interpreter.env = EmulateEnvWrapper(
+        status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs)
+    )

     obs = interpreter(simulator.get_state())
     assert obs["cur_tick"] == 45
@@ -165,7 +214,9 @@ class EmulateEnvWrapper(NamedTuple):
     assert obs["data_processed_prev"].shape == (390, 5)

     # first step: second interpreter
-    interpreter_step.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs))
+    interpreter_step.env = EmulateEnvWrapper(
+        status=EnvWrapperStatus(cur_step=0, done=False, **wrapper_status_kwargs)
+    )

     obs = interpreter_step(simulator.get_state())
     assert obs["acquiring"] == 1
@@ -173,7 +224,9 @@ class EmulateEnvWrapper(NamedTuple):

     # second step
     simulator.step(5.0)
-    interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs))
+    interpreter.env = EmulateEnvWrapper(
+        status=EnvWrapperStatus(cur_step=1, done=False, **wrapper_status_kwargs)
+    )

     obs = interpreter(simulator.get_state())
     assert obs["cur_tick"] == 60
@@ -200,7 +253,9 @@ class EmulateEnvWrapper(NamedTuple):
     # last step
     simulator.step(5.0)
     interpreter.env = EmulateEnvWrapper(
-        status=EnvWrapperStatus(cur_step=12, done=simulator.done(), **wrapper_status_kwargs)
+        status=EnvWrapperStatus(
+            cur_step=12, done=simulator.done(), **wrapper_status_kwargs
+        )
     )

     assert interpreter.env.status["done"]
@@ -216,7 +271,13 @@ class EmulateEnvWrapper(NamedTuple):


 def test_network_sanity():
     # we won't check the correctness of networks here
-    order = Order("AAL", 15.0, 1, pd.Timestamp("2013-12-11 9:30:00"), pd.Timestamp("2013-12-11 15:59:59"))
+    order = Order(
+        "AAL",
+        15.0,
+        1,
+        pd.Timestamp("2013-12-11 9:30:00"),
+        pd.Timestamp("2013-12-11 15:59:59"),
+    )

     simulator = SingleAssetOrderExecutionSimple(order, DATA_DIR)
     assert len(simulator.ticks_for_order) == 390
@@ -224,16 +285,24 @@ def test_network_sanity():
     class EmulateEnvWrapper(NamedTuple):
         status: EnvWrapperStatus

-    interpreter = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
+    interpreter = FullHistoryStateInterpreter(
+        13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)
+    )
     action_interp = CategoricalActionInterpreter(13)

-    wrapper_status_kwargs = dict(initial_state=order, obs_history=[], action_history=[], reward_history=[])
+    wrapper_status_kwargs = dict(
+        initial_state=order, obs_history=[], action_history=[], reward_history=[]
+    )

     network = Recurrent(interpreter.observation_space)
-    policy = PPO(network, interpreter.observation_space, action_interp.action_space, 1e-3)
+    policy = PPO(
+        network, interpreter.observation_space, action_interp.action_space, 1e-3
+    )

     for i in range(14):
-        interpreter.env = EmulateEnvWrapper(status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs))
+        interpreter.env = EmulateEnvWrapper(
+            status=EnvWrapperStatus(cur_step=i, done=False, **wrapper_status_kwargs)
+        )
         obs = interpreter(simulator.get_state())
         batch = Batch(obs=[obs])
         output = policy(batch)
@@ -252,7 +321,9 @@ def test_twap_strategy(finite_env_type):
     orders = pickle_styled.load_orders(ORDER_DIR)
     assert len(orders) == 248

-    state_interp = FullHistoryStateInterpreter(13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR))
+    state_interp = FullHistoryStateInterpreter(
+        13, 390, 5, PickleProcessedDataProvider(FEATURE_DATA_DIR)
+    )
     action_interp = TwapRelativeActionInterpreter()
     policy = AllOne(state_interp.observation_space, action_interp.action_space)
     csv_writer = CsvWriter(Path(__file__).parent / ".output")
@@ -278,18 +349,30 @@ def test_twap_strategy(finite_env_type):
 def test_cn_ppo_strategy():
     set_log_with_config(C.logging_config)
     # The data starts with 9:31 and ends with 15:00
-    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
+    orders = pickle_styled.load_orders(
+        CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58")
+    )
     assert len(orders) == 40

-    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
+    state_interp = FullHistoryStateInterpreter(
+        8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)
+    )
     action_interp = CategoricalActionInterpreter(4)
     network = Recurrent(state_interp.observation_space)
-    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
-    policy.load_state_dict(torch.load(CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"))
+    policy = PPO(
+        network, state_interp.observation_space, action_interp.action_space, 1e-4
+    )
+    policy.load_state_dict(
+        torch.load(
+            CN_POLICY_WEIGHTS_DIR / "ppo_recurrent_30min.pth", map_location="cpu"
+        )
+    )
     csv_writer = CsvWriter(Path(__file__).parent / ".output")

     backtest(
-        partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
+        partial(
+            SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30
+        ),
         state_interp,
         action_interp,
         orders,
@@ -309,21 +392,32 @@ def test_cn_ppo_strategy():
 def test_ppo_train():
     set_log_with_config(C.logging_config)
     # The data starts with 9:31 and ends with 15:00
-    orders = pickle_styled.load_orders(CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58"))
+    orders = pickle_styled.load_orders(
+        CN_ORDER_DIR, start_time=pd.Timestamp("9:31"), end_time=pd.Timestamp("14:58")
+    )
     assert len(orders) == 40

-    state_interp = FullHistoryStateInterpreter(8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR))
+    state_interp = FullHistoryStateInterpreter(
+        8, 240, 6, PickleProcessedDataProvider(CN_FEATURE_DATA_DIR)
+    )
     action_interp = CategoricalActionInterpreter(4)
     network = Recurrent(state_interp.observation_space)
-    policy = PPO(network, state_interp.observation_space, action_interp.action_space, 1e-4)
+    policy = PPO(
+        network, state_interp.observation_space, action_interp.action_space, 1e-4
+    )

     train(
-        partial(SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30),
+        partial(
+            SingleAssetOrderExecutionSimple, data_dir=CN_DATA_DIR, ticks_per_step=30
+        ),
         state_interp,
         action_interp,
         orders,
         policy,
         PAPenaltyReward(),
-        vessel_kwargs={"episode_per_iter": 100, "update_kwargs": {"batch_size": 64, "repeat": 5}},
+        vessel_kwargs={
+            "episode_per_iter": 100,
+            "update_kwargs": {"batch_size": 64, "repeat": 5},
+        },
         trainer_kwargs={"max_iters": 2, "loggers": ConsoleWriter(total_episodes=100)},
     )
diff --git a/tests/rl/test_trainer.py b/tests/rl/test_trainer.py
index f842d9781b..d35f41b7e3 100644
--- a/tests/rl/test_trainer.py
+++ b/tests/rl/test_trainer.py
@@ -17,7 +17,9 @@ from qlib.rl.reward import Reward
 from qlib.rl.trainer import Trainer, TrainingVessel, EarlyStopping, Checkpoint

-pytestmark = pytest.mark.skipif(sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8")
+pytestmark = pytest.mark.skipif(
+    sys.version_info < (3, 8), reason="Pickle styled data only supports Python >= 3.8"
+)


 class ZeroSimulator(Simulator):
@@ -173,7 +175,11 @@ def test_trainer_earlystop():
 def test_trainer_checkpoint():
     set_log_with_config(C.logging_config)
     output_dir = Path(__file__).parent / ".output"
-    trainer = Trainer(max_iters=2, finite_env_type="dummy", callbacks=[Checkpoint(output_dir, every_n_iters=1)])
+    trainer = Trainer(
+        max_iters=2,
+        finite_env_type="dummy",
+        callbacks=[Checkpoint(output_dir, every_n_iters=1)],
+    )

     policy = _ppo_policy()
     vessel = TrainingVessel(
diff --git a/tests/rolling_tests/test_update_pred.py b/tests/rolling_tests/test_update_pred.py
index b3ca2e0368..495c3ec554 100644
--- a/tests/rolling_tests/test_update_pred.py
+++ b/tests/rolling_tests/test_update_pred.py
@@ -33,7 +33,10 @@ def test_update_pred(self):
         train_end = latest_date - pd.Timedelta(days=41)
         task["dataset"]["kwargs"]["segments"] = {
             "train": (train_start, train_end),
-            "valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
+            "valid": (
+                latest_date - pd.Timedelta(days=40),
+                latest_date - pd.Timedelta(days=21),
+            ),
             "test": (latest_date - pd.Timedelta(days=20), latest_date),
         }
@@ -56,8 +59,12 @@ def test_update_pred(self):

         good_pred = rec.load_object("pred.pkl")

-        mod_range = slice(latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10))
-        mod_range2 = slice(latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2))
+        mod_range = slice(
+            latest_date - pd.Timedelta(days=20), latest_date - pd.Timedelta(days=10)
+        )
+        mod_range2 = slice(
+            latest_date - pd.Timedelta(days=9), latest_date - pd.Timedelta(days=2)
+        )

         mod_pred = good_pred.copy()
         mod_pred.loc[mod_range] = -1
@@ -65,13 +72,16 @@ def test_update_pred(self):
         rec.save_objects(**{"pred.pkl": mod_pred})

         online_tool.update_online_pred(
-            to_date=latest_date - pd.Timedelta(days=10), from_date=latest_date - pd.Timedelta(days=20)
+            to_date=latest_date - pd.Timedelta(days=10),
+            from_date=latest_date - pd.Timedelta(days=20),
         )

         updated_pred = rec.load_object("pred.pkl")

         # this range is not fixed
-        self.assertTrue((updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item())
+        self.assertTrue(
+            (updated_pred.loc[mod_range] == good_pred.loc[mod_range]).all().item()
+        )
         # this range is fixed now
         self.assertTrue((updated_pred.loc[mod_range2] == -2).all().item())
@@ -95,7 +105,10 @@ def test_update_label(self):
         train_end = latest_date - pd.Timedelta(days=41)
         task["dataset"]["kwargs"]["segments"] = {
             "train": (train_start, train_end),
-            "valid": (latest_date - pd.Timedelta(days=40), latest_date - pd.Timedelta(days=21)),
+            "valid": (
+                latest_date - pd.Timedelta(days=40),
+                latest_date - pd.Timedelta(days=21),
+            ),
             "test": (latest_date - pd.Timedelta(days=20), latest_date),
         }
@@ -128,7 +141,9 @@ def test_update_label(self):
         lu.update()
         new_label = rec.load_object("label.pkl")
         new_label_date = new_label.index.get_level_values("datetime").max()
-        self.assertTrue(new_label_date == pred_date)  # make sure the label is updated now
+        self.assertTrue(
+            new_label_date == pred_date
+        )  # make sure the label is updated now


 if __name__ == "__main__":
diff --git a/tests/storage_tests/test_storage.py b/tests/storage_tests/test_storage.py
index 92fed34ecd..f6e88d4c73 100644
--- a/tests/storage_tests/test_storage.py
+++ b/tests/storage_tests/test_storage.py
@@ -22,9 +22,15 @@ class TestStorage(TestAutoData):

     def test_calendar_storage(self):
-        calendar = CalendarStorage(freq="day", future=False, provider_uri=self.provider_uri)
-        assert isinstance(calendar[:], Iterable), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
-        assert isinstance(calendar.data, Iterable), f"{calendar.__class__.__name__}.data is not Iterable"
+        calendar = CalendarStorage(
+            freq="day", future=False, provider_uri=self.provider_uri
+        )
+        assert isinstance(
+            calendar[:], Iterable
+        ), f"{calendar.__class__.__name__}.__getitem__(s: slice) is not Iterable"
+        assert isinstance(
+            calendar.data, Iterable
+        ), f"{calendar.__class__.__name__}.data is not Iterable"

         print(f"calendar[1: 5]: {calendar[1:5]}")
         print(f"calendar[0]: {calendar[0]}")
@@ -74,7 +80,9 @@ def test_instrument_storage(self):
         """

-        instrument = InstrumentStorage(market="csi300", provider_uri=self.provider_uri, freq="day")
+        instrument = InstrumentStorage(
+            market="csi300", provider_uri=self.provider_uri, freq="day"
+        )

         for inst, spans in instrument.data.items():
             assert isinstance(inst, str) and isinstance(
@@ -87,7 +95,9 @@ def test_instrument_storage(self):

         print(f"instrument['SH600000']: {instrument['SH600000']}")

-        instrument = InstrumentStorage(market="csi300", provider_uri="not_found", freq="day")
+        instrument = InstrumentStorage(
+            market="csi300", provider_uri="not_found", freq="day"
+        )

         with self.assertRaises(ValueError):
             print(instrument.data)
@@ -148,19 +158,28 @@ def test_feature_storage(self):
         """

-        feature = FeatureStorage(instrument="SZ300677", field="close", freq="day", provider_uri=self.provider_uri)
+        feature = FeatureStorage(
+            instrument="SZ300677",
+            field="close",
+            freq="day",
+            provider_uri=self.provider_uri,
+        )

         with self.assertRaises(IndexError):
             print(feature[0])

         assert isinstance(
             feature[3049][1], (float, np.float32)
         ), f"{feature.__class__.__name__}.__getitem__(i: int) error"
-        assert len(feature[3049:3052]) == 3, f"{feature.__class__.__name__}.__getitem__(s: slice) error"
+        assert (
+            len(feature[3049:3052]) == 3
+        ), f"{feature.__class__.__name__}.__getitem__(s: slice) error"

         print(f"feature[3049: 3052]: \n{feature[3049: 3052]}")
         print(f"feature[:].tail(): \n{feature[:].tail()}")

-        feature = FeatureStorage(instrument="SH600004", field="close", freq="day", provider_uri="not_fount")
+        feature = FeatureStorage(
+            instrument="SH600004", field="close", freq="day", provider_uri="not_fount"
+        )

         with self.assertRaises(ValueError):
             print(feature[0])
diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py
index 7bbdaefe3c..332f2f2369 100644
--- a/tests/test_all_pipeline.py
+++ b/tests/test_all_pipeline.py
@@ -78,7 +78,11 @@ def fake_experiment():
         current_uri_to_check = R.get_uri()
     default_uri_to_check = R.get_uri()
-    return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
+    return (
+        default_uri == default_uri_to_check,
+        current_uri == current_uri_to_check,
+        current_uri,
+    )


 def backtest_analysis(pred, rid, uri_path: str = None):
@@ -148,7 +152,9 @@ class TestAllFlow(TestAutoData):
     REPORT_NORMAL = None
     POSITIONS = None
     RID = None
-    URI_PATH = "file:" + str(Path(__file__).parent.joinpath("test_all_flow_mlruns").resolve())
+    URI_PATH = "file:" + str(
+        Path(__file__).parent.joinpath("test_all_flow_mlruns").resolve()
+    )

     @classmethod
     def tearDownClass(cls) -> None:
@@ -162,9 +168,13 @@ def test_0_train(self):

     @pytest.mark.slow
     def test_1_backtest(self):
-        analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH)
+        analyze_df = backtest_analysis(
+            TestAllFlow.PRED_SCORE, TestAllFlow.RID, self.URI_PATH
+        )
         self.assertGreaterEqual(
-            analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
+            analyze_df.loc(axis=0)[
+                "excess_return_with_cost", "annualized_return"
+            ].values[0],
             0.05,
             "backtest failed",
         )
diff --git a/tests/test_contrib_model.py b/tests/test_contrib_model.py
index a82a3042ec..c29a073179 100644
--- a/tests/test_contrib_model.py
+++ b/tests/test_contrib_model.py
@@ -13,7 +13,11 @@ def test_0_initialize(self):
             if model_class is not None:
                 model = model_class()
                 num += 1
-        print("There are {:}/{:} valid models in total.".format(num, len(all_model_classes)))
+        print(
+            "There are {:}/{:} valid models in total.".format(
+                num, len(all_model_classes)
+            )
+        )


 def suite():
diff --git a/tests/test_contrib_workflow.py b/tests/test_contrib_workflow.py
index c556472c0d..b5d6775555 100644
--- a/tests/test_contrib_workflow.py
+++ b/tests/test_contrib_workflow.py
@@ -59,7 +59,9 @@ def train_mse(uri_path: str = None):


 class TestAllFlow(TestAutoData):
-    URI_PATH = "file:" + str(Path(__file__).parent.joinpath("test_contrib_mlruns").resolve())
+    URI_PATH = "file:" + str(
+        Path(__file__).parent.joinpath("test_contrib_mlruns").resolve()
+    )

     @classmethod
     def tearDownClass(cls) -> None:
diff --git a/tests/test_dump_data.py b/tests/test_dump_data.py
index e24e3c759a..83539777c5 100644
--- a/tests/test_dump_data.py
+++ b/tests/test_dump_data.py
@@ -36,8 +36,12 @@ class TestDumpData(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         GetData().download_data(file_name="csv_data_cn.zip", target_dir=SOURCE_DIR)
-        TestDumpData.DUMP_DATA = DumpDataAll(data_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS)
-        TestDumpData.STOCK_NAMES = list(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
+        TestDumpData.DUMP_DATA = DumpDataAll(
+            data_path=SOURCE_DIR, qlib_dir=QLIB_DIR, include_fields=cls.FIELDS
+        )
+        TestDumpData.STOCK_NAMES = list(
+            map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv"))
+        )
         provider_uri = str(QLIB_DIR.resolve())
         qlib.init(
             provider_uri=provider_uri,
@@ -56,34 +60,51 @@ def test_1_dump_calendars(self):
         ori_calendars = set(
             map(
                 pd.Timestamp,
-                pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None).loc[:, 0].values,
+                pd.read_csv(QLIB_DIR.joinpath("calendars", "day.txt"), header=None)
+                .loc[:, 0]
+                .values,
             )
         )
         res_calendars = set(D.calendar())
-        assert len(ori_calendars - res_calendars) == len(res_calendars - ori_calendars) == 0, "dump calendars failed"
+        assert (
+            len(ori_calendars - res_calendars)
+            == len(res_calendars - ori_calendars)
+            == 0
+        ), "dump calendars failed"

     def test_2_dump_instruments(self):
         ori_ins = set(map(lambda x: x.name[:-4].upper(), SOURCE_DIR.glob("*.csv")))
         res_ins = set(D.list_instruments(D.instruments("all"), as_list=True))
-        assert len(ori_ins - res_ins) == len(ori_ins - res_ins) == 0, "dump instruments failed"
+        assert (
+            len(ori_ins - res_ins) == len(ori_ins - res_ins) == 0
+        ), "dump instruments failed"

     def test_3_dump_features(self):
         df = D.features(self.STOCK_NAMES, self.QLIB_FIELDS)
         TestDumpData.SIMPLE_DATA = df.loc(axis=0)[self.STOCK_NAMES[0], :]
         self.assertFalse(df.dropna().empty, "features data failed")
-        self.assertListEqual(list(df.columns), self.QLIB_FIELDS, "features columns failed")
+        self.assertListEqual(
+            list(df.columns), self.QLIB_FIELDS, "features columns failed"
+        )

     def test_4_dump_features_simple(self):
         stock = self.STOCK_NAMES[0]
         dump_data = DumpDataFix(
-            data_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"), qlib_dir=QLIB_DIR, include_fields=self.FIELDS
+            data_path=SOURCE_DIR.joinpath(f"{stock.lower()}.csv"),
+            qlib_dir=QLIB_DIR,
+            include_fields=self.FIELDS,
         )
         dump_data.dump()

         df = D.features([stock], self.QLIB_FIELDS)
-        self.assertEqual(len(df), len(TestDumpData.SIMPLE_DATA), "dump features simple failed")
-        self.assertTrue(np.isclose(df.dropna(), self.SIMPLE_DATA.dropna()).all(), "dump features simple failed")
+        self.assertEqual(
+            len(df), len(TestDumpData.SIMPLE_DATA), "dump features simple failed"
+        )
+        self.assertTrue(
+            np.isclose(df.dropna(), self.SIMPLE_DATA.dropna()).all(),
+            "dump features simple failed",
+        )


 if __name__ == "__main__":
diff --git a/tests/test_get_data.py b/tests/test_get_data.py
index 125b9203e6..2cf0f91c3f 100644
--- a/tests/test_get_data.py
+++ b/tests/test_get_data.py
@@ -34,7 +34,12 @@ def tearDownClass(cls) -> None:

     def test_0_qlib_data(self):
         GetData().qlib_data(
-            name="qlib_data_simple", target_dir=QLIB_DIR, region="cn", interval="1d", delete_old=False, exists_skip=True
+            name="qlib_data_simple",
+            target_dir=QLIB_DIR,
+            region="cn",
+            interval="1d",
+            delete_old=False,
+            exists_skip=True,
         )
         df = D.features(D.instruments("csi300"), self.FIELDS)
         self.assertListEqual(list(df.columns), self.FIELDS, "get qlib data failed")
diff --git a/tests/test_pit.py b/tests/test_pit.py
index 548f91baaa..eaafc651a3 100644
--- a/tests/test_pit.py
+++ b/tests/test_pit.py
@@ -16,7 +16,9 @@
 sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts")))
 from dump_pit import DumpPitData

-sys.path.append(str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit")))
+sys.path.append(
+    str(Path(__file__).resolve().parent.parent.joinpath("scripts/data_collector/pit"))
+)
 from collector import Run

@@ -41,9 +43,19 @@ def setUpClass(cls) -> None:
         pit_dir = str(SOURCE_DIR.joinpath("pit").resolve())
         pit_normalized_dir = str(SOURCE_DIR.joinpath("pit_normalized").resolve())
         GetData().qlib_data(
-            name="qlib_data_simple", target_dir=cn_data_dir, region="cn", delete_old=False, exists_skip=True
+            name="qlib_data_simple",
+            target_dir=cn_data_dir,
+            region="cn",
+            delete_old=False,
+            exists_skip=True,
+        )
+        GetData().qlib_data(
+            name="qlib_data",
+            target_dir=pit_dir,
+            region="pit",
+            delete_old=False,
+            exists_skip=True,
         )
-        GetData().qlib_data(name="qlib_data", target_dir=pit_dir, region="pit", delete_old=False, exists_skip=True)

         # NOTE: This code does the same thing as line 43, but since baostock is not stable in downloading data, we have chosen to download offline data.
         # bs.login()
@@ -79,7 +91,13 @@ def test_query(self):
         fields = ["P($$roewa_q)", "P($$yoyni_q)"]
         # Mao Tai published 2019Q2 report at 2019-07-13 & 2019-07-18
         # - http://www.cninfo.com.cn/new/commonUrl/pageOfSearch?url=disclosure/list/search&lastPage=index
-        data = D.features(instruments, fields, start_time="2019-01-01", end_time="2019-07-19", freq="day")
+        data = D.features(
+            instruments,
+            fields,
+            start_time="2019-01-01",
+            end_time="2019-07-19",
+            freq="day",
+        )
         res = """
                P($$roewa_q)  P($$yoyni_q)
         count    133.000000    133.000000
@@ -106,7 +124,13 @@ def test_query(self):

     def test_no_exist_data(self):
         fields = ["P($$roewa_q)", "P($$yoyni_q)", "$close"]
-        data = D.features(["sh600519", "sh601988"], fields, start_time="2019-01-01", end_time="2019-07-19", freq="day")
+        data = D.features(
+            ["sh600519", "sh601988"],
+            fields,
+            start_time="2019-01-01",
+            end_time="2019-07-19",
+            freq="day",
+        )
         data["$close"] = 1  # in case of different dataset gives different values
         expect = """
                                P($$roewa_q)  P($$yoyni_q)  $close
@@ -137,7 +161,13 @@ def test_expr(self):
             "P((Ref($$roewa_q, 1) +$$roewa_q) / 2)",
         ]
         instruments = ["sh600519"]
-        data = D.features(instruments, fields, start_time="2019-01-01", end_time="2019-07-19", freq="day")
+        data = D.features(
+            instruments,
+            fields,
+            start_time="2019-01-01",
+            end_time="2019-07-19",
+            freq="day",
+        )
         expect = """
                                P(Mean($$roewa_q, 1))  P($$roewa_q)  P(Mean($$roewa_q, 2))  P(Ref($$roewa_q, 1))  P((Ref($$roewa_q, 1) +$$roewa_q) / 2)
         instrument datetime
@@ -164,7 +194,9 @@ def test_unlimit(self):
         fields = ["P($$roewa_q)"]
         instruments = ["sh600519"]
         _ = D.features(instruments, fields, freq="day")  # this should not raise error
-        data = D.features(instruments, fields, end_time="2020-01-01", freq="day")  # this should not raise error
+        data = D.features(
+            instruments, fields, end_time="2020-01-01", freq="day"
+        )  # this should not raise error
         s = data.iloc[:, 0]
         # You can check the expected value based on the content in `docs/advanced/PIT.rst`
         expect = """
@@ -233,7 +265,13 @@ def test_expr2(self):
         fields += ["P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)"]
         fields += ["P(Sum($$yoyni_q, 4))"]
         fields += ["$close", "P($$roewa_q) * $close"]
-        data = D.features(instruments, fields, start_time="2019-01-01", end_time="2020-01-01", freq="day")
+        data = D.features(
+            instruments,
+            fields,
+            start_time="2019-01-01",
+            end_time="2020-01-01",
+            freq="day",
+        )
         except_data = """
                                P($$roewa_q)  P($$yoyni_q)  P(($$roewa_q / $$yoyni_q) / Ref($$roewa_q / $$yoyni_q, 1) - 1)  P(Sum($$yoyni_q, 4))  $close  P($$roewa_q) * $close
         instrument datetime
@@ -261,7 +299,13 @@ def test_pref_operator(self):
             "P($$roewa_q)",
             "P($$roewa_q) / PRef($$roewa_q, 201801)",
         ]
-        data = D.features(instruments, fields, start_time="2018-04-28", end_time="2019-07-19", freq="day")
+        data = D.features(
+            instruments,
+            fields,
+            start_time="2018-04-28",
+            end_time="2019-07-19",
+            freq="day",
+        )
         except_data = """
                                PRef($$roewa_q, 201902)  PRef($$yoyni_q, 201801)  P($$roewa_q)  P($$roewa_q) / PRef($$roewa_q, 201801)
         instrument datetime
diff --git a/tests/test_register_ops.py b/tests/test_register_ops.py
index ac86be59ce..4137f2fd21 100644
--- a/tests/test_register_ops.py
+++ b/tests/test_register_ops.py
@@ -57,7 +57,15 @@ def setUpClass(cls) -> None:

     def test_regiter_custom_ops(self):
         instruments = ["SH600000"]
         fields = ["Diff($close)", "Distance($close, Ref($close, 1))"]
-        print(D.features(instruments, fields, start_time="2010-01-01", end_time="2017-12-31", freq="day"))
+        print(
+            D.features(
+                instruments,
+                fields,
+                start_time="2010-01-01",
+                end_time="2017-12-31",
+                freq="day",
+            )
+        )


 if __name__ == "__main__":
diff --git a/tests/test_structured_cov_estimator.py b/tests/test_structured_cov_estimator.py
index 494962cc33..022ea834b3 100644
--- a/tests/test_structured_cov_estimator.py
+++ b/tests/test_structured_cov_estimator.py
@@ -20,7 +20,9 @@ def test_random_covariance(self):
         X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)

         est_cov = estimator.predict(X, is_price=False)
-        np_cov = np.cov(X.T)  # While numpy assume row means variable, qlib assume the other wise.
+        np_cov = np.cov(
+            X.T
+        )  # While numpy assume row means variable, qlib assume the other wise.

         delta = abs(est_cov - np_cov)
         if_identical = (delta < EPS).all()
@@ -33,12 +35,16 @@ def test_nan_option_covariance(self):
         NUM_OBSERVATION = 200
         EPS = 1e-6

-        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, nan_option="fill")
+        estimator = StructuredCovEstimator(
+            scale_return=False, assume_centered=True, nan_option="fill"
+        )

         X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)

         est_cov = estimator.predict(X, is_price=False)
-        np_cov = np.cov(X.T)  # While numpy assume row means variable, qlib assume the other wise.
+        np_cov = np.cov(
+            X.T
+        )  # While numpy assume row means variable, qlib assume the other wise.

         delta = abs(est_cov - np_cov)
         if_identical = (delta < EPS).all()
@@ -50,11 +56,15 @@ def test_decompose_covariance(self):
         NUM_VARIABLE = 10
         NUM_OBSERVATION = 200

-        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, nan_option="fill")
+        estimator = StructuredCovEstimator(
+            scale_return=False, assume_centered=True, nan_option="fill"
+        )

         X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)

-        F, cov_b, var_u = estimator.predict(X, is_price=False, return_decomposed_components=True)
+        F, cov_b, var_u = estimator.predict(
+            X, is_price=False, return_decomposed_components=True
+        )

         self.assertTrue(F is not None and cov_b is not None and var_u is not None)
@@ -65,7 +75,9 @@ def test_constructed_covariance(self):
         NUM_OBSERVATION = 500
         EPS = 0.1

-        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, num_factors=NUM_VARIABLE - 1)
+        estimator = StructuredCovEstimator(
+            scale_return=False, assume_centered=True, num_factors=NUM_VARIABLE - 1
+        )

         sqrt_cov = None
         while sqrt_cov is None or (np.iscomplex(sqrt_cov)).any():
@@ -76,7 +88,9 @@ def test_constructed_covariance(self):
         X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE) @ sqrt_cov

         est_cov = estimator.predict(X, is_price=False)
-        np_cov = np.cov(X.T)  # While numpy assume row means variable, qlib assume the other wise.
+        np_cov = np.cov(
+            X.T
+        )  # While numpy assume row means variable, qlib assume the other wise.

         delta = abs(est_cov - np_cov)
         if_identical = (delta < EPS).all()
@@ -91,7 +105,9 @@ def test_decomposition(self):
         NUM_FACTOR = 10
         EPS = 0.1

-        estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, num_factors=NUM_FACTOR)
+        estimator = StructuredCovEstimator(
+            scale_return=False, assume_centered=True, num_factors=NUM_FACTOR
+        )

         F = np.random.rand(NUM_VARIABLE, NUM_FACTOR)
         B = np.random.rand(NUM_FACTOR, NUM_OBSERVATION)
@@ -99,7 +115,9 @@ def test_decomposition(self):
         X = (F @ B).T + U

         est_cov = estimator.predict(X, is_price=False)
-        np_cov = np.cov(X.T)  # While numpy assume row means variable, qlib assume the other wise.
+        np_cov = np.cov(
+            X.T
+        )  # While numpy assume row means variable, qlib assume the other wise.

         delta = abs(est_cov - np_cov)
         if_identical = (delta < EPS).all()
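The hunks above in tests/test_structured_cov_estimator.py repeatedly compare the estimator against np.cov(X.T). The comment they carry means that np.cov treats each row as a variable by default, while the test matrices hold one observation per row and one variable per column, hence the transpose. A minimal numpy-only sketch of that equivalence (illustrative only; it is not part of the patch and the sizes are arbitrary):

import numpy as np

# 200 observations (rows) x 10 variables (columns), as in the tests above.
X = np.random.rand(200, 10)

# np.cov expects variables in rows by default, so either transpose the data ...
cov_transposed = np.cov(X.T)

# ... or pass rowvar=False to treat columns as variables.
cov_rowvar = np.cov(X, rowvar=False)

# Both calls yield the same 10 x 10 covariance matrix.
assert np.allclose(cov_transposed, cov_rowvar)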