diff --git a/acceptance/bundle/deploy/registered_models/basic/databricks.yml.tmpl b/acceptance/bundle/deploy/registered_models/basic/databricks.yml.tmpl new file mode 100644 index 0000000000..9c10c6bbc4 --- /dev/null +++ b/acceptance/bundle/deploy/registered_models/basic/databricks.yml.tmpl @@ -0,0 +1,10 @@ +bundle: + name: deploy-registered-models-basic-$UNIQUE_NAME + +resources: + registered_models: + my_registered_model: + name: $NAME + comment: $COMMENT + catalog_name: $CATALOG_NAME + schema_name: $SCHEMA_NAME diff --git a/acceptance/bundle/deploy/registered_models/basic/out.test.toml b/acceptance/bundle/deploy/registered_models/basic/out.test.toml new file mode 100644 index 0000000000..c969c92e84 --- /dev/null +++ b/acceptance/bundle/deploy/registered_models/basic/out.test.toml @@ -0,0 +1,6 @@ +Local = true +Cloud = true +RequiresUnityCatalog = true + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct-exp"] diff --git a/acceptance/bundle/deploy/registered_models/basic/output.txt b/acceptance/bundle/deploy/registered_models/basic/output.txt new file mode 100644 index 0000000000..f03508fe3a --- /dev/null +++ b/acceptance/bundle/deploy/registered_models/basic/output.txt @@ -0,0 +1,132 @@ + +>>> export NAME=my-registered-model-[UNIQUE_NAME] + +>>> export COMMENT=original comment + +>>> export CATALOG_NAME=main + +>>> export SCHEMA_NAME=default + +=== create catalog and schema to test diff functionality +>>> [CLI] catalogs create mycatalog-[UNIQUE_NAME] +{ + "full_name": "mycatalog-[UNIQUE_NAME]" +} + +>>> [CLI] schemas create myschema-[UNIQUE_NAME] mycatalog-[UNIQUE_NAME] +{ + "full_name": "mycatalog-[UNIQUE_NAME].myschema-[UNIQUE_NAME]" +} + +=== create the registered model +>>> [CLI] bundle plan +create registered_models.my_registered_model + +Plan: 1 to add, 0 to change, 0 to delete, 0 unchanged + +>>> [CLI] bundle deploy +Uploading bundle files to /Workspace/Users/[USERNAME]/.bundle/deploy-registered-models-basic-[UNIQUE_NAME]/default/files... +Deploying resources... +Updating deployment state... +Deployment complete! + +>>> [CLI] registered-models get main.default.my-registered-model-[UNIQUE_NAME] +{ + "name": "my-registered-model-[UNIQUE_NAME]", + "comment": "original comment", + "catalog_name": "main", + "schema_name": "default" +} + +=== update the comment, this should not recreate +>>> [CLI] bundle plan +update registered_models.my_registered_model + +Plan: 0 to add, 1 to change, 0 to delete, 0 unchanged + +>>> [CLI] bundle deploy +Uploading bundle files to /Workspace/Users/[USERNAME]/.bundle/deploy-registered-models-basic-[UNIQUE_NAME]/default/files... +Deploying resources... +Updating deployment state... +Deployment complete! + +>>> [CLI] registered-models get main.default.my-registered-model-[UNIQUE_NAME] +{ + "name": "my-registered-model-[UNIQUE_NAME]", + "comment": "updated comment", + "catalog_name": "main", + "schema_name": "default" +} + +=== update the name, this should recreate +>>> [CLI] bundle plan +recreate registered_models.my_registered_model + +Plan: 1 to add, 0 to change, 1 to delete, 0 unchanged + +>>> [CLI] bundle deploy +Uploading bundle files to /Workspace/Users/[USERNAME]/.bundle/deploy-registered-models-basic-[UNIQUE_NAME]/default/files... +Deploying resources... +Updating deployment state... +Deployment complete! + +>>> [CLI] registered-models get main.default.my-registered-model-updated-[UNIQUE_NAME] +{ + "name": "my-registered-model-updated-[UNIQUE_NAME]", + "comment": "updated comment", + "catalog_name": "main", + "schema_name": "default" +} + +=== update the catalog name, this should recreate +>>> [CLI] bundle plan +recreate registered_models.my_registered_model + +Plan: 1 to add, 0 to change, 1 to delete, 0 unchanged + +>>> [CLI] bundle deploy +Uploading bundle files to /Workspace/Users/[USERNAME]/.bundle/deploy-registered-models-basic-[UNIQUE_NAME]/default/files... +Deploying resources... +Updating deployment state... +Deployment complete! + +>>> [CLI] registered-models get mycatalog-[UNIQUE_NAME].default.my-registered-model-updated-[UNIQUE_NAME] +{ + "name": "my-registered-model-updated-[UNIQUE_NAME]", + "comment": "updated comment", + "catalog_name": "mycatalog-[UNIQUE_NAME]", + "schema_name": "default" +} + +=== update the schema name, this should recreate +>>> [CLI] bundle plan +recreate registered_models.my_registered_model + +Plan: 1 to add, 0 to change, 1 to delete, 0 unchanged + +>>> [CLI] bundle deploy +Uploading bundle files to /Workspace/Users/[USERNAME]/.bundle/deploy-registered-models-basic-[UNIQUE_NAME]/default/files... +Deploying resources... +Updating deployment state... +Deployment complete! + +>>> [CLI] registered-models get mycatalog-[UNIQUE_NAME].myschema-[UNIQUE_NAME].my-registered-model-updated-[UNIQUE_NAME] +{ + "name": "my-registered-model-updated-[UNIQUE_NAME]", + "comment": "updated comment", + "catalog_name": "mycatalog-[UNIQUE_NAME]", + "schema_name": "myschema-[UNIQUE_NAME]" +} + +>>> [CLI] bundle destroy --auto-approve +The following resources will be deleted: + delete registered_model my_registered_model + +All files and directories at the following location will be deleted: /Workspace/Users/[USERNAME]/.bundle/deploy-registered-models-basic-[UNIQUE_NAME]/default + +Deleting files... +Destroy complete! + +>>> [CLI] schemas delete mycatalog-[UNIQUE_NAME].myschema-[UNIQUE_NAME] --force + +>>> [CLI] catalogs delete mycatalog-[UNIQUE_NAME] --force diff --git a/acceptance/bundle/deploy/registered_models/basic/script b/acceptance/bundle/deploy/registered_models/basic/script new file mode 100644 index 0000000000..42313057e0 --- /dev/null +++ b/acceptance/bundle/deploy/registered_models/basic/script @@ -0,0 +1,50 @@ +trace export NAME="my-registered-model-$UNIQUE_NAME" +trace export COMMENT="original comment" +trace export CATALOG_NAME="main" +trace export SCHEMA_NAME="default" +envsubst < databricks.yml.tmpl > databricks.yml + +title "create catalog and schema to test diff functionality" +catalog_name="mycatalog-${UNIQUE_NAME}" +schema_name="myschema-${UNIQUE_NAME}" +trace $CLI catalogs create ${catalog_name} | jq '{full_name}' +trace $CLI schemas create ${schema_name} ${catalog_name} | jq '{full_name}' + +cleanup() { + trace $CLI bundle destroy --auto-approve + trace $CLI schemas delete ${catalog_name}.${schema_name} --force + trace $CLI catalogs delete ${catalog_name} --force +} +trap cleanup EXIT + +deploy_registered_model() { + trace $CLI bundle plan + trace $CLI bundle deploy + registered_model_id=$($CLI bundle summary --output json | jq -r '.resources.registered_models.my_registered_model.id') + trace $CLI registered-models get "${registered_model_id}" | jq '{name, comment, catalog_name, schema_name}' +} + +title "create the registered model" +deploy_registered_model + +export COMMENT="updated comment" +envsubst < databricks.yml.tmpl > databricks.yml + +title "update the comment, this should not recreate" +deploy_registered_model + +export NAME="my-registered-model-updated-$UNIQUE_NAME" +envsubst < databricks.yml.tmpl > databricks.yml + +title "update the name, this should recreate" +deploy_registered_model + +title "update the catalog name, this should recreate" +export CATALOG_NAME="${catalog_name}" +envsubst < databricks.yml.tmpl > databricks.yml +deploy_registered_model + +title "update the schema name, this should recreate" +export SCHEMA_NAME="${schema_name}" +envsubst < databricks.yml.tmpl > databricks.yml +deploy_registered_model diff --git a/acceptance/bundle/deploy/registered_models/basic/test.toml b/acceptance/bundle/deploy/registered_models/basic/test.toml new file mode 100644 index 0000000000..80d5c3424e --- /dev/null +++ b/acceptance/bundle/deploy/registered_models/basic/test.toml @@ -0,0 +1,3 @@ +Cloud = true +Local = true +RequiresUnityCatalog = true diff --git a/acceptance/bundle/refschema/out.fields.txt b/acceptance/bundle/refschema/out.fields.txt index 91ebbdc1ad..92b66ab7ef 100644 --- a/acceptance/bundle/refschema/out.fields.txt +++ b/acceptance/bundle/refschema/out.fields.txt @@ -2571,6 +2571,33 @@ resources.pipelines.*.trigger.cron.quartz_cron_schedule string INPUT STATE resources.pipelines.*.trigger.cron.timezone_id string INPUT STATE resources.pipelines.*.trigger.manual *pipelines.ManualTrigger INPUT STATE resources.pipelines.*.url string INPUT +resources.registered_models.*.aliases []catalog.RegisteredModelAlias REMOTE +resources.registered_models.*.aliases[*] catalog.RegisteredModelAlias REMOTE +resources.registered_models.*.aliases[*].alias_name string REMOTE +resources.registered_models.*.aliases[*].version_num int REMOTE +resources.registered_models.*.browse_only bool REMOTE +resources.registered_models.*.catalog_name string ALL +resources.registered_models.*.comment string ALL +resources.registered_models.*.created_at int64 REMOTE +resources.registered_models.*.created_by string REMOTE +resources.registered_models.*.full_name string REMOTE +resources.registered_models.*.grants []resources.Grant INPUT +resources.registered_models.*.grants[*] resources.Grant INPUT +resources.registered_models.*.grants[*].principal string INPUT +resources.registered_models.*.grants[*].privileges []string INPUT +resources.registered_models.*.grants[*].privileges[*] string INPUT +resources.registered_models.*.id string INPUT +resources.registered_models.*.lifecycle resources.Lifecycle INPUT +resources.registered_models.*.lifecycle.prevent_destroy bool INPUT +resources.registered_models.*.metastore_id string REMOTE +resources.registered_models.*.modified_status string INPUT +resources.registered_models.*.name string ALL +resources.registered_models.*.owner string REMOTE +resources.registered_models.*.schema_name string ALL +resources.registered_models.*.storage_location string ALL +resources.registered_models.*.updated_at int64 REMOTE +resources.registered_models.*.updated_by string REMOTE +resources.registered_models.*.url string INPUT resources.schemas.*.browse_only bool REMOTE resources.schemas.*.catalog_name string ALL resources.schemas.*.catalog_type catalog.CatalogType REMOTE diff --git a/bundle/direct/dresources/all.go b/bundle/direct/dresources/all.go index fcf68e3836..dfb79a6343 100644 --- a/bundle/direct/dresources/all.go +++ b/bundle/direct/dresources/all.go @@ -19,6 +19,7 @@ var SupportedResources = map[string]any{ "database_catalogs": (*ResourceDatabaseCatalog)(nil), "synced_database_tables": (*ResourceSyncedDatabaseTable)(nil), "alerts": (*ResourceAlert)(nil), + "registered_models": (*ResourceRegisteredModel)(nil), } func InitAll(client *databricks.WorkspaceClient) (map[string]*Adapter, error) { diff --git a/bundle/direct/dresources/all_test.go b/bundle/direct/dresources/all_test.go index 18116b5be1..455d94bdaa 100644 --- a/bundle/direct/dresources/all_test.go +++ b/bundle/direct/dresources/all_test.go @@ -61,6 +61,17 @@ var testConfig map[string]any = map[string]any{ Name: "main.myschema.my_synced_table", }, }, + + "registered_models": &resources.RegisteredModel{ + CreateRegisteredModelRequest: catalog.CreateRegisteredModelRequest{ + Name: "my_registered_model", + Comment: "Test registered model", + CatalogName: "main", + SchemaName: "default", + StorageLocation: "s3://my-bucket/my-path", + }, + }, + "experiments": &resources.MlflowExperiment{ CreateExperiment: ml.CreateExperiment{ Name: "my-experiment", @@ -73,6 +84,7 @@ var testConfig map[string]any = map[string]any{ ArtifactLocation: "s3://my-bucket/my-experiment", }, }, + "models": &resources.MlflowModel{ CreateModelRequest: ml.CreateModelRequest{ Name: "my_mlflow_model", @@ -162,22 +174,26 @@ func testCRUD(t *testing.T, group string, adapter *Adapter, client *databricks.W require.Equal(t, remote, remoteStateFromWaitCreate) } + remappedState, err := adapter.RemapState(remote) + require.NoError(t, err) + require.NotNil(t, remappedState) + remoteStateFromUpdate, err := adapter.DoUpdate(ctx, createdID, newState) require.NoError(t, err, "DoUpdate failed") if remoteStateFromUpdate != nil { - require.Equal(t, remote, remoteStateFromUpdate) + remappedStateFromUpdate, err := adapter.RemapState(remoteStateFromUpdate) + require.NoError(t, err) + require.Equal(t, remappedState, remappedStateFromUpdate) } remoteStateFromWaitUpdate, err := adapter.WaitAfterUpdate(ctx, newState) require.NoError(t, err) if remoteStateFromWaitUpdate != nil { - require.Equal(t, remote, remoteStateFromWaitUpdate) + remappedStateFromWaitUpdate, err := adapter.RemapState(remoteStateFromWaitUpdate) + require.NoError(t, err) + require.Equal(t, remappedState, remappedStateFromWaitUpdate) } - remappedState, err := adapter.RemapState(remote) - require.NoError(t, err) - require.NotNil(t, remappedState) - require.NoError(t, structwalk.Walk(newState, func(path *structpath.PathNode, val any, field *reflect.StructField) { remoteValue, err := structaccess.Get(remappedState, path) if err != nil { diff --git a/bundle/direct/dresources/registered_model.go b/bundle/direct/dresources/registered_model.go new file mode 100644 index 0000000000..b01b90b319 --- /dev/null +++ b/bundle/direct/dresources/registered_model.go @@ -0,0 +1,93 @@ +package dresources + +import ( + "context" + + "github.com/databricks/cli/bundle/config/resources" + "github.com/databricks/cli/bundle/deployplan" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/catalog" +) + +type ResourceRegisteredModel struct { + client *databricks.WorkspaceClient +} + +func (*ResourceRegisteredModel) New(client *databricks.WorkspaceClient) *ResourceRegisteredModel { + return &ResourceRegisteredModel{ + client: client, + } +} + +func (*ResourceRegisteredModel) PrepareState(input *resources.RegisteredModel) *catalog.CreateRegisteredModelRequest { + return &input.CreateRegisteredModelRequest +} + +func (*ResourceRegisteredModel) RemapState(model *catalog.RegisteredModelInfo) *catalog.CreateRegisteredModelRequest { + return &catalog.CreateRegisteredModelRequest{ + CatalogName: model.CatalogName, + Comment: model.Comment, + Name: model.Name, + SchemaName: model.SchemaName, + StorageLocation: model.StorageLocation, + ForceSendFields: filterFields[catalog.CreateRegisteredModelRequest](model.ForceSendFields), + } +} + +func (r *ResourceRegisteredModel) DoRefresh(ctx context.Context, id string) (*catalog.RegisteredModelInfo, error) { + return r.client.RegisteredModels.Get(ctx, catalog.GetRegisteredModelRequest{ + FullName: id, + IncludeAliases: false, + IncludeBrowse: false, + ForceSendFields: nil, + }) +} + +func (r *ResourceRegisteredModel) DoCreate(ctx context.Context, config *catalog.CreateRegisteredModelRequest) (string, *catalog.RegisteredModelInfo, error) { + response, err := r.client.RegisteredModels.Create(ctx, *config) + if err != nil { + return "", nil, err + } + + return response.FullName, response, nil +} + +func (r *ResourceRegisteredModel) DoUpdate(ctx context.Context, id string, config *catalog.CreateRegisteredModelRequest) (*catalog.RegisteredModelInfo, error) { + updateRequest := catalog.UpdateRegisteredModelRequest{ + FullName: id, + Comment: config.Comment, + ForceSendFields: filterFields[catalog.UpdateRegisteredModelRequest](config.ForceSendFields, "Owner", "NewName"), + + // Owner is not part of the configuration tree + Owner: "", + + // Name updates are not supported yet without recreating. Can be added as a follow-up. + // Note: TF also does not support changing name without a recreate so the current behavior matches TF. + NewName: "", + } + + response, err := r.client.RegisteredModels.Update(ctx, updateRequest) + if err != nil { + return nil, err + } + + return response, nil +} + +func (r *ResourceRegisteredModel) DoDelete(ctx context.Context, id string) error { + return r.client.RegisteredModels.Delete(ctx, catalog.DeleteRegisteredModelRequest{ + FullName: id, + }) +} + +func (*ResourceRegisteredModel) FieldTriggers() map[string]deployplan.ActionType { + return map[string]deployplan.ActionType{ + // The name can technically be updated without recreated. We recreate for now though + // to match TF implementation. + "name": deployplan.ActionTypeRecreate, + + "catalog_name": deployplan.ActionTypeRecreate, + "schema_name": deployplan.ActionTypeRecreate, + "storage_location": deployplan.ActionTypeRecreate, + } +} diff --git a/libs/testserver/catalogs.go b/libs/testserver/catalogs.go new file mode 100644 index 0000000000..25d7cbd5f7 --- /dev/null +++ b/libs/testserver/catalogs.go @@ -0,0 +1,90 @@ +package testserver + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/databricks/databricks-sdk-go/service/catalog" + "github.com/google/uuid" +) + +func (s *FakeWorkspace) CatalogsCreate(req Request) Response { + defer s.LockUnlock()() + + var createRequest catalog.CreateCatalog + if err := json.Unmarshal(req.Body, &createRequest); err != nil { + return Response{ + Body: fmt.Sprintf("internal error: %s", err), + StatusCode: http.StatusInternalServerError, + } + } + + catalogInfo := catalog.CatalogInfo{ + Name: createRequest.Name, + Comment: createRequest.Comment, + StorageRoot: createRequest.StorageRoot, + ProviderName: createRequest.ProviderName, + ShareName: createRequest.ShareName, + Options: createRequest.Options, + Properties: createRequest.Properties, + FullName: createRequest.Name, + CreatedAt: time.Now().UnixMilli(), + CreatedBy: s.CurrentUser().UserName, + UpdatedAt: time.Now().UnixMilli(), + UpdatedBy: s.CurrentUser().UserName, + MetastoreId: uuid.New().String(), + Owner: s.CurrentUser().UserName, + CatalogType: catalog.CatalogTypeManagedCatalog, + } + + s.Catalogs[createRequest.Name] = catalogInfo + return Response{ + Body: catalogInfo, + } +} + +func (s *FakeWorkspace) CatalogsUpdate(req Request, name string) Response { + defer s.LockUnlock()() + + existing, ok := s.Catalogs[name] + if !ok { + return Response{ + StatusCode: http.StatusNotFound, + Body: fmt.Sprintf("catalog %s not found", name), + } + } + + var updateRequest catalog.UpdateCatalog + if err := json.Unmarshal(req.Body, &updateRequest); err != nil { + return Response{ + Body: fmt.Sprintf("internal error: %s", err), + StatusCode: http.StatusInternalServerError, + } + } + + // Update only the fields that can be updated + if updateRequest.Comment != "" { + existing.Comment = updateRequest.Comment + } + if updateRequest.Owner != "" { + existing.Owner = updateRequest.Owner + } + if updateRequest.NewName != "" { + existing.Name = updateRequest.NewName + existing.FullName = updateRequest.NewName + + // Delete the old entry and create with new name + delete(s.Catalogs, name) + name = updateRequest.NewName + } + + existing.UpdatedAt = time.Now().UnixMilli() + existing.UpdatedBy = s.CurrentUser().UserName + + s.Catalogs[name] = existing + return Response{ + Body: existing, + } +} diff --git a/libs/testserver/fake_workspace.go b/libs/testserver/fake_workspace.go index ceeb1a5a26..a02c0ae0c5 100644 --- a/libs/testserver/fake_workspace.go +++ b/libs/testserver/fake_workspace.go @@ -78,6 +78,8 @@ type FakeWorkspace struct { Alerts map[string]sql.AlertV2 Experiments map[string]ml.GetExperimentResponse ModelRegistryModels map[string]ml.Model + Catalogs map[string]catalog.CatalogInfo + RegisteredModels map[string]catalog.RegisteredModelInfo Acls map[string][]workspace.AclItem @@ -165,7 +167,9 @@ func NewFakeWorkspace(url, token string) *FakeWorkspace { PipelineUpdates: map[string]bool{}, Monitors: map[string]catalog.MonitorInfo{}, Apps: map[string]apps.App{}, + Catalogs: map[string]catalog.CatalogInfo{}, Schemas: map[string]catalog.SchemaInfo{}, + RegisteredModels: map[string]catalog.RegisteredModelInfo{}, Volumes: map[string]catalog.VolumeInfo{}, Dashboards: map[string]dashboards.Dashboard{}, SqlWarehouses: map[string]sql.GetWarehouseResponse{}, diff --git a/libs/testserver/handlers.go b/libs/testserver/handlers.go index ef31dde2f2..2c6343100d 100644 --- a/libs/testserver/handlers.go +++ b/libs/testserver/handlers.go @@ -357,6 +357,42 @@ func AddDefaultHandlers(server *Server) { return req.Workspace.SchemasGetGrants(req, req.Vars["full_name"]) }) + // Catalogs: + + server.Handle("GET", "/api/2.1/unity-catalog/catalogs/{name}", func(req Request) any { + return MapGet(req.Workspace, req.Workspace.Catalogs, req.Vars["name"]) + }) + + server.Handle("POST", "/api/2.1/unity-catalog/catalogs", func(req Request) any { + return req.Workspace.CatalogsCreate(req) + }) + + server.Handle("PATCH", "/api/2.1/unity-catalog/catalogs/{name}", func(req Request) any { + return req.Workspace.CatalogsUpdate(req, req.Vars["name"]) + }) + + server.Handle("DELETE", "/api/2.1/unity-catalog/catalogs/{name}", func(req Request) any { + return MapDelete(req.Workspace, req.Workspace.Catalogs, req.Vars["name"]) + }) + + // Registered Models: + + server.Handle("GET", "/api/2.1/unity-catalog/models/{full_name}", func(req Request) any { + return MapGet(req.Workspace, req.Workspace.RegisteredModels, req.Vars["full_name"]) + }) + + server.Handle("POST", "/api/2.1/unity-catalog/models", func(req Request) any { + return req.Workspace.RegisteredModelsCreate(req) + }) + + server.Handle("PATCH", "/api/2.1/unity-catalog/models/{full_name}", func(req Request) any { + return req.Workspace.RegisteredModelsUpdate(req, req.Vars["full_name"]) + }) + + server.Handle("DELETE", "/api/2.1/unity-catalog/models/{full_name}", func(req Request) any { + return MapDelete(req.Workspace, req.Workspace.RegisteredModels, req.Vars["full_name"]) + }) + // Volumes: server.Handle("GET", "/api/2.1/unity-catalog/volumes/{full_name}", func(req Request) any { diff --git a/libs/testserver/registered_models.go b/libs/testserver/registered_models.go new file mode 100644 index 0000000000..865e6c3b5a --- /dev/null +++ b/libs/testserver/registered_models.go @@ -0,0 +1,87 @@ +package testserver + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/databricks/databricks-sdk-go/service/catalog" + "github.com/google/uuid" +) + +func (s *FakeWorkspace) RegisteredModelsCreate(req Request) Response { + defer s.LockUnlock()() + + var createRequest catalog.CreateRegisteredModelRequest + if err := json.Unmarshal(req.Body, &createRequest); err != nil { + return Response{ + Body: fmt.Sprintf("internal error: %s", err), + StatusCode: http.StatusInternalServerError, + } + } + + // Build full name from catalog.schema.name + fullName := createRequest.CatalogName + "." + createRequest.SchemaName + "." + createRequest.Name + + registeredModel := catalog.RegisteredModelInfo{ + CatalogName: createRequest.CatalogName, + Comment: createRequest.Comment, + Name: createRequest.Name, + SchemaName: createRequest.SchemaName, + StorageLocation: createRequest.StorageLocation, + FullName: fullName, + CreatedAt: time.Now().UnixMilli(), + CreatedBy: s.CurrentUser().UserName, + UpdatedAt: time.Now().UnixMilli(), + UpdatedBy: s.CurrentUser().UserName, + MetastoreId: uuid.New().String(), + Owner: s.CurrentUser().UserName, + } + + s.RegisteredModels[fullName] = registeredModel + return Response{ + Body: registeredModel, + } +} + +func (s *FakeWorkspace) RegisteredModelsUpdate(req Request, fullName string) Response { + defer s.LockUnlock()() + + existing, ok := s.RegisteredModels[fullName] + if !ok { + return Response{ + StatusCode: http.StatusNotFound, + Body: fmt.Sprintf("registered model %s not found", fullName), + } + } + + var updateRequest catalog.UpdateRegisteredModelRequest + if err := json.Unmarshal(req.Body, &updateRequest); err != nil { + return Response{ + Body: fmt.Sprintf("internal error: %s", err), + StatusCode: http.StatusInternalServerError, + } + } + + // Update only the fields that can be updated + if updateRequest.Comment != "" { + existing.Comment = updateRequest.Comment + } + if updateRequest.Owner != "" { + existing.Owner = updateRequest.Owner + } + if updateRequest.NewName != "" { + existing.Name = updateRequest.NewName + + // Delete the old entry and set full name to the new name + delete(s.RegisteredModels, fullName) + fullName = existing.CatalogName + "." + existing.SchemaName + "." + updateRequest.NewName + } + + existing.UpdatedAt = time.Now().UnixMilli() + s.RegisteredModels[fullName] = existing + return Response{ + Body: existing, + } +}