diff --git a/src/platform/assets/components/AssetBrowserModal.vue b/src/platform/assets/components/AssetBrowserModal.vue index 01e03782b..322c12945 100644 --- a/src/platform/assets/components/AssetBrowserModal.vue +++ b/src/platform/assets/components/AssetBrowserModal.vue @@ -1,6 +1,6 @@ - - {{ modelType }} + + ) const baseModels = ref(getAssetBaseModels(asset)) const additionalTags = ref(getAssetAdditionalTags(asset)) +const selectedModelType = ref( + getAssetModelType(asset) ?? undefined +) watch( () => asset, () => { baseModels.value = getAssetBaseModels(asset) additionalTags.value = getAssetAdditionalTags(asset) + selectedModelType.value = getAssetModelType(asset) ?? undefined } ) const description = computed(() => getAssetDescription(asset)) const triggerPhrases = computed(() => getAssetTriggerPhrases(asset)) const isImmutable = computed(() => asset.is_immutable ?? true) -const modelType = computed(() => { - const typeTag = asset.tags.find((tag) => tag !== 'models') - if (!typeTag) return null - return typeTag.includes('/') ? typeTag.split('/').pop() : typeTag -}) +const { modelTypes } = useModelTypes() const assetsStore = useAssetsStore() @@ -200,6 +211,19 @@ async function saveMetadata() { ) } +async function saveModelType(newModelType: string | undefined) { + if (isImmutable.value || !newModelType) return + + const currentModelType = getAssetModelType(asset) + if (currentModelType === newModelType) return + + const newTags = asset.tags + .filter((tag) => tag !== currentModelType) + .concat(newModelType) + await assetsStore.updateAssetTags(asset.id, newTags, cacheKey) +} + watchDebounced(baseModels, saveMetadata, { debounce: 500 }) watchDebounced(additionalTags, saveMetadata, { debounce: 500 }) +watchDebounced(selectedModelType, saveModelType, { debounce: 500 }) diff --git a/src/platform/assets/composables/useModelTypes.ts b/src/platform/assets/composables/useModelTypes.ts index a60d94813..3f10f6c07 100644 --- a/src/platform/assets/composables/useModelTypes.ts +++ b/src/platform/assets/composables/useModelTypes.ts @@ -46,9 +46,10 @@ const DISALLOWED_MODEL_TYPES = ['nlf'] as const export const useModelTypes = createSharedComposable(() => { const { state: modelTypes, + isReady, isLoading, error, - execute: fetchModelTypes + execute } = useAsyncState( async (): Promise => { const response = await api.getModelFolders() @@ -74,6 +75,11 @@ export const useModelTypes = createSharedComposable(() => { } ) + function fetchModelTypes() { + if (isReady.value || isLoading.value) return + return execute() + } + return { modelTypes, isLoading, diff --git a/src/platform/assets/utils/assetMetadataUtils.ts b/src/platform/assets/utils/assetMetadataUtils.ts index 47712cef3..e69aa79ce 100644 --- a/src/platform/assets/utils/assetMetadataUtils.ts +++ b/src/platform/assets/utils/assetMetadataUtils.ts @@ -110,3 +110,14 @@ export function getSourceName(url: string): string { if (url.includes('huggingface.co')) return 'Hugging Face' return 'Source' } + +/** + * Extracts the model type from asset tags + * @param asset - The asset to extract model type from + * @returns The model type string or null if not present + */ +export function getAssetModelType(asset: AssetItem): string | null { + const typeTag = asset.tags?.find((tag) => tag !== 'models') + if (!typeTag) return null + return typeTag.includes('/') ? (typeTag.split('/').pop() ?? null) : typeTag +} diff --git a/src/stores/assetsStore.ts b/src/stores/assetsStore.ts index 369b84266..0cee838c6 100644 --- a/src/stores/assetsStore.ts +++ b/src/stores/assetsStore.ts @@ -384,13 +384,29 @@ export const useAssetsStore = defineStore('assets', () => { await assetService.updateAsset(assetId, { user_metadata: userMetadata }) } + /** + * Update asset tags with optimistic cache update + * @param assetId The asset ID to update + * @param tags The tags array to save + * @param cacheKey Optional cache key to target for optimistic update + */ + async function updateAssetTags( + assetId: string, + tags: string[], + cacheKey?: string + ) { + updateAssetInCache(assetId, { tags }, cacheKey) + await assetService.updateAsset(assetId, { tags }) + } + return { modelAssetsByNodeType, modelLoadingByNodeType, modelErrorByNodeType, updateModelsForNodeType, updateModelsForTag, - updateAssetMetadata + updateAssetMetadata, + updateAssetTags } } @@ -400,7 +416,8 @@ export const useAssetsStore = defineStore('assets', () => { modelErrorByNodeType: shallowReactive(new Map()), updateModelsForNodeType: async () => [], updateModelsForTag: async () => [], - updateAssetMetadata: async () => {} + updateAssetMetadata: async () => {}, + updateAssetTags: async () => {} } } @@ -410,7 +427,8 @@ export const useAssetsStore = defineStore('assets', () => { modelErrorByNodeType, updateModelsForNodeType, updateModelsForTag, - updateAssetMetadata + updateAssetMetadata, + updateAssetTags } = getModelState() // Watch for completed downloads and refresh model caches @@ -476,6 +494,7 @@ export const useAssetsStore = defineStore('assets', () => { modelErrorByNodeType, updateModelsForNodeType, updateModelsForTag, - updateAssetMetadata + updateAssetMetadata, + updateAssetTags } })